resnet.py
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List
from torch import Tensor
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerChannelFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import IntBias
from brevitas.quant import TruncTo8bit
from brevitas.quant_tensor import QuantTensor


def make_quant_conv2d(
in_channels,
out_channels,
kernel_size,
weight_bit_width,
weight_quant,
stride=1,
padding=0,
bias=False):
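    """Build a QuantConv2d with the given weight quantizer and weight bit width."""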
return qnn.QuantConv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
weight_quant=weight_quant,
weight_bit_width=weight_bit_width)


class QuantBasicBlock(nn.Module):
"""
    Quantized BasicBlock implementation with extra ReLU activations to respect FINN constraints on the sign of residual
    adds. Fine to train from scratch, but it does not lend itself to, e.g., retraining from torchvision pretrained weights.
"""
expansion = 1

    def __init__(
self,
in_planes,
planes,
stride=1,
bias=False,
shared_quant_act=None,
act_bit_width=8,
weight_bit_width=8,
weight_quant=Int8WeightPerChannelFloat):
super(QuantBasicBlock, self).__init__()
self.conv1 = make_quant_conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=bias,
weight_bit_width=weight_bit_width,
weight_quant=weight_quant)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
self.conv2 = make_quant_conv2d(
planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
weight_bit_width=weight_bit_width,
weight_quant=weight_quant)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = nn.Sequential(
make_quant_conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
padding=0,
bias=bias,
weight_bit_width=weight_bit_width,
weight_quant=weight_quant),
nn.BatchNorm2d(self.expansion * planes),
# We add a ReLU activation here because FINN requires the same sign along residual adds
qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True))
            # Redefine shared_quant_act whenever the shortcut performs downsampling
shared_quant_act = self.downsample[-1]
if shared_quant_act is None:
shared_quant_act = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
# We add a ReLU activation here because FINN requires the same sign along residual adds
self.relu2 = shared_quant_act
self.relu_out = qnn.QuantReLU(return_quant_tensor=True, bit_width=act_bit_width)

    def forward(self, x):
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
if len(self.downsample):
x = self.downsample(x)
# Check that the addition is made explicitly among QuantTensor structures
assert isinstance(out, QuantTensor), "Perform add among QuantTensors"
assert isinstance(x, QuantTensor), "Perform add among QuantTensors"
out = out + x
out = self.relu_out(out)
return out


class QuantResNet(nn.Module):
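    """
    Quantized ResNet with FINN-friendly residual adds. The first and last layers are kept at 8b,
    and the final average pool truncates its output to 8b.
    """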
def __init__(
self,
block_impl,
num_blocks: List[int],
first_maxpool=False,
zero_init_residual=False,
num_classes=10,
act_bit_width=8,
weight_bit_width=8,
round_average_pool=False,
last_layer_bias_quant=IntBias,
weight_quant=Int8WeightPerChannelFloat,
first_layer_weight_quant=Int8WeightPerChannelFloat,
last_layer_weight_quant=Int8WeightPerTensorFloat):
super(QuantResNet, self).__init__()
self.in_planes = 64
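        # Keep the first layer at 8b, with its own weight quantizer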
self.conv1 = make_quant_conv2d(
3,
64,
kernel_size=3,
stride=1,
padding=1,
weight_bit_width=8,
weight_quant=first_layer_weight_quant)
self.bn1 = nn.BatchNorm2d(64)
shared_quant_act = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
self.relu = shared_quant_act
# MaxPool is typically present for ImageNet but not for CIFAR10
if first_maxpool:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
self.maxpool = nn.Identity()
self.layer1, shared_quant_act = self._make_layer(
block_impl, 64, num_blocks[0], 1, shared_quant_act, weight_bit_width, act_bit_width, weight_quant)
self.layer2, shared_quant_act = self._make_layer(
block_impl, 128, num_blocks[1], 2, shared_quant_act, weight_bit_width, act_bit_width, weight_quant)
self.layer3, shared_quant_act = self._make_layer(
block_impl, 256, num_blocks[2], 2, shared_quant_act, weight_bit_width, act_bit_width, weight_quant)
self.layer4, _ = self._make_layer(
block_impl, 512, num_blocks[3], 2, shared_quant_act, weight_bit_width, act_bit_width, weight_quant)
        # Quantize the average pool output to 8b; FLOOR (truncation without rounding) is what FINN supports, ROUND is optional
avgpool_float_to_int_impl_type = 'ROUND' if round_average_pool else 'FLOOR'
self.final_pool = qnn.TruncAvgPool2d(
kernel_size=4,
trunc_quant=TruncTo8bit,
float_to_int_impl_type=avgpool_float_to_int_impl_type)
# Keep last layer at 8b
self.linear = qnn.QuantLinear(
512 * block_impl.expansion,
num_classes,
weight_bit_width=8,
bias=True,
bias_quant=last_layer_bias_quant,
weight_quant=last_layer_weight_quant)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch
if zero_init_residual:
for m in self.modules():
if isinstance(m, QuantBasicBlock) and m.bn2.weight is not None:
nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
self,
block_impl,
planes,
num_blocks,
stride,
shared_quant_act,
weight_bit_width,
act_bit_width,
weight_quant):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
block = block_impl(
in_planes=self.in_planes,
planes=planes,
stride=stride,
bias=False,
shared_quant_act=shared_quant_act,
act_bit_width=act_bit_width,
weight_bit_width=weight_bit_width,
weight_quant=weight_quant)
layers.append(block)
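            # Share the block's output quantizer with the next block, so that both operands
            # of the following residual add are quantized by the same quantizer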
shared_quant_act = layers[-1].relu_out
self.in_planes = planes * block_impl.expansion
return nn.Sequential(*layers), shared_quant_act

    def forward(self, x: Tensor):
        # There is no input quantizer; we assume the input is already 8b RGB
out = self.relu(self.bn1(self.conv1(x)))
out = self.maxpool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.final_pool(out)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out


def quant_resnet18(cfg) -> QuantResNet:
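    """Build a quantized ResNet-18, reading bit widths and the number of classes from a configparser-style cfg."""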
weight_bit_width = cfg.getint('QUANT', 'WEIGHT_BIT_WIDTH')
act_bit_width = cfg.getint('QUANT', 'ACT_BIT_WIDTH')
num_classes = cfg.getint('MODEL', 'NUM_CLASSES')
model = QuantResNet(
block_impl=QuantBasicBlock,
num_blocks=[2, 2, 2, 2],
num_classes=num_classes,
weight_bit_width=weight_bit_width,
act_bit_width=act_bit_width)
return model
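

if __name__ == "__main__":
    # Minimal usage sketch (a hypothetical example, not part of the original training flow):
    # it assumes a configparser-style cfg exposing the QUANT and MODEL sections read by
    # quant_resnet18, and a CIFAR10-sized 8b RGB input (32x32), which matches the 3x3
    # stride-1 stem, the default first_maxpool=False, and the 4x4 final average pool.
    import configparser

    import torch

    cfg = configparser.ConfigParser()
    cfg["QUANT"] = {"WEIGHT_BIT_WIDTH": "4", "ACT_BIT_WIDTH": "4"}
    cfg["MODEL"] = {"NUM_CLASSES": "10"}

    model = quant_resnet18(cfg)
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: torch.Size([1, 10])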