/
resnet.py
156 lines (118 loc) · 4.73 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn as nn
from torch.nn.functional import relu
from torch.nn import init
import math
class DownSample(nn.Module):
    """Parameter-free shortcut: spatial subsampling plus zero channel padding.

    The input is passed through an average pool with ``kernel_size=1`` and
    the given stride (with a 1x1 window this simply keeps every
    ``stride``-th pixel), and then the same number of all-zero channels is
    concatenated along dim 1, so the output has twice as many channels as
    the input.

    Args:
        stride: Subsampling stride. Only ``2`` is accepted.
    """

    def __init__(self, stride):
        super(DownSample, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        """Return the subsampled ``x`` with zero channels appended on dim 1."""
        x = self.avg(x)
        # Build the padding with zeros_like instead of ``x.mul(0)``:
        # multiplying by zero turns any inf/NaN activation into NaN, which
        # would leak into channels that are meant to be exactly zero.
        # zeros_like keeps the dtype and device of ``x``.
        return torch.cat((x, torch.zeros_like(x)), 1)
class ResNetBasicblock(nn.Module):
    """Residual block with two 3x3 convolutions.

    The main branch is ``conv_a -> bn_a -> ReLU -> conv_b -> bn_b``. The
    block returns ``relu(shortcut + branch)``, where the shortcut is the
    input itself or, when ``down_sample`` is given, ``down_sample(input)``.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        down_sample: Optional callable applied to the input to form the
            shortcut; ``None`` means the input is used as-is.
    """

    # Ratio of the block's output channels to ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, down_sample=None):
        super(ResNetBasicblock, self).__init__()

        # First conv carries the (optional) spatial stride.
        self.conv_a = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn_a = nn.BatchNorm2d(planes)

        # Second conv keeps both resolution and width.
        self.conv_b = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn_b = nn.BatchNorm2d(planes)

        self.down_sample = down_sample

    def forward(self, x):
        """Apply the residual block to ``x``."""
        out = relu(self.bn_a(self.conv_a(x)), inplace=True)
        out = self.bn_b(self.conv_b(out))

        shortcut = x if self.down_sample is None else self.down_sample(x)
        return relu(shortcut + out, inplace=True)
class CifarResNet(nn.Module):
    """Three-stage residual network for CIFAR-10 / CIFAR-100 style inputs.

    Layout: 3x3 stem conv -> BN -> ReLU -> stage_1 (stride 1) -> stage_2
    (stride 2) -> stage_3 (stride 2) -> 8x8 average pool -> flatten ->
    linear classifier.

    The classifier expects ``cfg[2] * block.expansion`` features after the
    8x8 average pool, so the final feature map has to be 8x8 (which is what
    a 32x32 input produces with the strides above).
    """

    def __init__(self, block, depth, num_classes, cfg=(16, 32, 64)):
        """Build the network.

        Args:
            block: Residual block class. It must expose an ``expansion``
                class attribute and be constructible as
                ``block(inplanes, planes, stride, down_sample)``.
            depth: Total number of layers; one of 20, 32, 44, 56, 110.
                Each of the three stages gets ``(depth - 2) // 6`` blocks.
            num_classes: Number of output classes.
            cfg: Sequence of three channel widths, one per stage. The
                default is an immutable tuple so it cannot be shared and
                mutated across instances.
                NOTE(review): the shortcut used between stages doubles the
                channel count, so widths presumably need to double from
                one stage to the next -- confirm before using other values.
        """
        super(CifarResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert depth in [20, 32, 44, 56, 110]
        layer_blocks = (depth - 2) // 6
        print('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))

        self.num_classes = num_classes

        self.conv_1_3x3 = nn.Conv2d(3, cfg[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(cfg[0])

        # ``inplanes`` tracks the channel count entering the next block and
        # is advanced by ``_make_layer``.
        self.inplanes = cfg[0]
        self.stage_1 = self._make_layer(block, cfg[0], layer_blocks, 1)
        self.stage_2 = self._make_layer(block, cfg[1], layer_blocks, 2)
        self.stage_3 = self._make_layer(block, cfg[2], layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(cfg[2] * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init with std = sqrt(2 / fan_out), where
                # fan_out = kernel_height * kernel_width * out_channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                init.normal_(m.weight, 0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` consecutive residual blocks.

        Only the first block receives ``stride`` and, when the stride is not
        1 or the channel count changes, a ``DownSample`` shortcut. The
        remaining blocks use the block's default stride and no shortcut
        module. ``self.inplanes`` is updated to the stage's output width.

        Args:
            block: Residual block class (see ``__init__``).
            planes: Channel width of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block.

        Returns:
            ``nn.Sequential`` holding the stage's blocks in order.
        """
        down_sample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_sample = DownSample(stride)

        layers = [block(self.inplanes, planes, stride, down_sample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.conv_1_3x3(x)
        x = relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        return self.classifier(x)
def resnet20(num_classes=10, cfg=(16, 32, 64)):
    """Constructs a ResNet-20 model for CIFAR-10 (by default)

    Args:
        num_classes (int): number of classes
        cfg (sequence of int): channel widths of the three stages,
            forwarded unchanged to ``CifarResNet``

    Returns:
        CifarResNet: depth-20 network built from ``ResNetBasicblock``
    """
    return CifarResNet(ResNetBasicblock, 20, num_classes, cfg)
def resnet32(num_classes=10, cfg=(16, 32, 64)):
    """Constructs a ResNet-32 model for CIFAR-10 (by default)

    Args:
        num_classes (int): number of classes
        cfg (sequence of int): channel widths of the three stages,
            forwarded unchanged to ``CifarResNet``. The default is a tuple
            so the shared default object cannot be mutated between calls.

    Returns:
        CifarResNet: depth-32 network built from ``ResNetBasicblock``
    """
    return CifarResNet(ResNetBasicblock, 32, num_classes, cfg)
def resnet44(num_classes=10, cfg=(16, 32, 64)):
    """Constructs a ResNet-44 model for CIFAR-10 (by default)

    Args:
        num_classes (int): number of classes
        cfg (sequence of int): channel widths of the three stages,
            forwarded unchanged to ``CifarResNet``

    Returns:
        CifarResNet: depth-44 network built from ``ResNetBasicblock``
    """
    return CifarResNet(ResNetBasicblock, 44, num_classes, cfg)
def resnet56(num_classes=10, cfg=(16, 32, 64)):
    """Constructs a ResNet-56 model for CIFAR-10 (by default)

    Args:
        num_classes (int): number of classes
        cfg (sequence of int): channel widths of the three stages,
            forwarded unchanged to ``CifarResNet``

    Returns:
        CifarResNet: depth-56 network built from ``ResNetBasicblock``
    """
    return CifarResNet(ResNetBasicblock, 56, num_classes, cfg)
def resnet110(num_classes=10, cfg=(16, 32, 64)):
    """Constructs a ResNet-110 model for CIFAR-10 (by default)

    Args:
        num_classes (int): number of classes
        cfg (sequence of int): channel widths of the three stages,
            forwarded unchanged to ``CifarResNet``

    Returns:
        CifarResNet: depth-110 network built from ``ResNetBasicblock``
    """
    return CifarResNet(ResNetBasicblock, 110, num_classes, cfg)