""
Residual Network (ResNet). We train a ResNet (7) with 32 convolutional layers. The ResNet-32
has a sequence of 15 residual blocks: the first 5 blocks have an output of shape 32 × 32 × 16, the
following 5 blocks have an output of shape 16×16×32 and the last 5 blocks have an output of shape
8×8×64. On top of these blocks, there is a 2×2 average pooling layer with stride of 2, followed by
a output layer of size 10 with softmax non-linearity. The ResNet-32 has ≈467k trainable parameters
in total.
""

In [36]:
import torch
from torch import nn

class ResBlock(nn.Module):
  """
  Residual block of 2 conv layers:
  Conv -> Norm -> Act -> Conv -> Norm -> Act
     |__[Optional: 1x1 Conv -> Norm]__|
  """
  def __init__(self, in_channels, mid_channels, out_channels, downsample=None):
    super().__init__()
    self.downsample = isinstance(downsample, int)
    self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, padding=1, stride=(downsample or 1))
    self.norm1 = nn.BatchNorm2d(mid_channels)
    self.act1 = nn.ReLU()
    self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, padding=1)
    self.norm2 = nn.BatchNorm2d(out_channels)
    self.act2 = nn.ReLU()
    if self.downsample:
      self.convp = nn.Conv2d(in_channels, out_channels, 1, padding=0, stride=downsample)
      self.normp = nn.BatchNorm2d(out_channels)
    
  def forward(self, x):
    x_ = self.act1(self.norm1(self.conv1(x)))
    x_ = self.norm2(self.conv2(x_))
    if self.downsample:
      x = self.normp(self.convp(x))
    x = x + x_
    return self.act2(x)

    
class ResNet(nn.Module):
  def __init__(self, channels=[16,32,64],
               num_classes=10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, channels[0], 3, padding=1)
    self.block1 = nn.Sequential(
                                *[
                                ResBlock(channels[0], channels[0], channels[0])
                                for i in range(5)
                                ])
    self.block2 = nn.Sequential(ResBlock(channels[0], channels[1], channels[1], downsample=2),
                                *[
                                ResBlock(channels[1], channels[1], channels[1])
                                for i in range(4)
                                ])
    self.block3 = nn.Sequential(ResBlock(channels[1], channels[2], channels[2], downsample=2),
                                *[
                                ResBlock(channels[2], channels[2], channels[2])
                                for i in range(4)
                                ])
    self.pool = nn.AvgPool2d(8)
    self.flat_channels = channels[2]
    self.fc = nn.Linear(channels[2], num_classes)
    self.prob = nn.Softmax(dim=1)

  def forward(self, x):
    B = x.shape[0]
    x = self.conv1(x)
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.pool(x)
    x = torch.flatten(x,1)
    x = self.fc(x)
    return x

In [37]:
#Quick Test
model = ResNet()
out = model(torch.rand(128,3,32,32))
print(out)

tensor([[-0.5709,  0.2133,  0.2073,  ...,  0.8972, -0.6640, -0.6786],
        [-0.4053,  0.4282,  0.0434,  ...,  0.8353, -0.5160, -0.6282],
        [-0.5241,  0.1957,  0.2826,  ...,  0.9182, -0.8445, -0.6121],
        ...,
        [-0.3689,  0.3178,  0.1000,  ...,  0.8236, -0.5086, -0.7364],
        [-0.5768,  0.4614,  0.1615,  ...,  0.8969, -0.4816, -0.7008],
        [-0.4936,  0.4332,  0.0858,  ...,  0.8797, -0.5195, -0.5641]],
       grad_fn=<AddmmBackward0>)


In [7]:
#"Load CIFAR10 Data"
import torchvision
import torchvision.transforms as transforms
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [38]:
import torch.optim as optim

net = ResNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [39]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
#         print(loss)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

tensor(2.8266, grad_fn=<NllLossBackward0>)
tensor(2.8425, grad_fn=<NllLossBackward0>)
tensor(2.7072, grad_fn=<NllLossBackward0>)
tensor(2.7832, grad_fn=<NllLossBackward0>)
tensor(2.7293, grad_fn=<NllLossBackward0>)
tensor(2.6233, grad_fn=<NllLossBackward0>)
tensor(2.6254, grad_fn=<NllLossBackward0>)
tensor(2.5808, grad_fn=<NllLossBackward0>)
tensor(2.5486, grad_fn=<NllLossBackward0>)
tensor(2.4508, grad_fn=<NllLossBackward0>)
tensor(2.4381, grad_fn=<NllLossBackward0>)
tensor(2.4479, grad_fn=<NllLossBackward0>)
tensor(2.4040, grad_fn=<NllLossBackward0>)
tensor(2.3363, grad_fn=<NllLossBackward0>)
tensor(2.3485, grad_fn=<NllLossBackward0>)
tensor(2.2748, grad_fn=<NllLossBackward0>)
tensor(2.3287, grad_fn=<NllLossBackward0>)
tensor(2.3511, grad_fn=<NllLossBackward0>)
tensor(2.3645, grad_fn=<NllLossBackward0>)
tensor(2.2906, grad_fn=<NllLossBackward0>)
tensor(2.3636, grad_fn=<NllLossBackward0>)
tensor(2.2634, grad_fn=<NllLossBackward0>)
tensor(2.2878, grad_fn=<NllLossBackward0>)
tensor(2.36

tensor(1.7752, grad_fn=<NllLossBackward0>)
tensor(1.8375, grad_fn=<NllLossBackward0>)
tensor(1.6683, grad_fn=<NllLossBackward0>)
tensor(1.7615, grad_fn=<NllLossBackward0>)
tensor(1.7918, grad_fn=<NllLossBackward0>)
tensor(1.8391, grad_fn=<NllLossBackward0>)
tensor(1.8736, grad_fn=<NllLossBackward0>)
tensor(1.8114, grad_fn=<NllLossBackward0>)
tensor(1.7053, grad_fn=<NllLossBackward0>)
tensor(1.6751, grad_fn=<NllLossBackward0>)
tensor(1.7012, grad_fn=<NllLossBackward0>)
tensor(1.7180, grad_fn=<NllLossBackward0>)
tensor(1.6665, grad_fn=<NllLossBackward0>)
tensor(1.6867, grad_fn=<NllLossBackward0>)
tensor(1.8155, grad_fn=<NllLossBackward0>)
tensor(1.5892, grad_fn=<NllLossBackward0>)
tensor(1.7539, grad_fn=<NllLossBackward0>)
tensor(1.7270, grad_fn=<NllLossBackward0>)
tensor(1.6993, grad_fn=<NllLossBackward0>)
tensor(1.6918, grad_fn=<NllLossBackward0>)
tensor(1.6667, grad_fn=<NllLossBackward0>)
tensor(1.7987, grad_fn=<NllLossBackward0>)
tensor(1.7323, grad_fn=<NllLossBackward0>)
tensor(1.60

tensor(1.6102, grad_fn=<NllLossBackward0>)
tensor(1.4700, grad_fn=<NllLossBackward0>)
tensor(1.5568, grad_fn=<NllLossBackward0>)
tensor(1.4058, grad_fn=<NllLossBackward0>)
tensor(1.3776, grad_fn=<NllLossBackward0>)
tensor(1.4904, grad_fn=<NllLossBackward0>)
tensor(1.5133, grad_fn=<NllLossBackward0>)
tensor(1.4415, grad_fn=<NllLossBackward0>)
tensor(1.5028, grad_fn=<NllLossBackward0>)
tensor(1.5534, grad_fn=<NllLossBackward0>)
tensor(1.3640, grad_fn=<NllLossBackward0>)
tensor(1.5604, grad_fn=<NllLossBackward0>)
tensor(1.5057, grad_fn=<NllLossBackward0>)
tensor(1.4157, grad_fn=<NllLossBackward0>)
tensor(1.5104, grad_fn=<NllLossBackward0>)
tensor(1.5762, grad_fn=<NllLossBackward0>)
tensor(1.5505, grad_fn=<NllLossBackward0>)
tensor(1.3368, grad_fn=<NllLossBackward0>)
tensor(1.6751, grad_fn=<NllLossBackward0>)
tensor(1.4886, grad_fn=<NllLossBackward0>)
tensor(1.6300, grad_fn=<NllLossBackward0>)
tensor(1.4378, grad_fn=<NllLossBackward0>)
tensor(1.6376, grad_fn=<NllLossBackward0>)
tensor(1.50

tensor(1.2943, grad_fn=<NllLossBackward0>)
tensor(1.3767, grad_fn=<NllLossBackward0>)
tensor(1.2973, grad_fn=<NllLossBackward0>)
tensor(1.4491, grad_fn=<NllLossBackward0>)
tensor(1.2355, grad_fn=<NllLossBackward0>)
tensor(1.2821, grad_fn=<NllLossBackward0>)
tensor(1.3228, grad_fn=<NllLossBackward0>)
tensor(1.2354, grad_fn=<NllLossBackward0>)
tensor(1.1499, grad_fn=<NllLossBackward0>)
tensor(1.4293, grad_fn=<NllLossBackward0>)
tensor(1.4289, grad_fn=<NllLossBackward0>)
tensor(1.3438, grad_fn=<NllLossBackward0>)
tensor(1.2490, grad_fn=<NllLossBackward0>)
tensor(1.3879, grad_fn=<NllLossBackward0>)
tensor(1.3611, grad_fn=<NllLossBackward0>)
tensor(1.4800, grad_fn=<NllLossBackward0>)
tensor(1.3949, grad_fn=<NllLossBackward0>)
tensor(1.3527, grad_fn=<NllLossBackward0>)
tensor(1.1377, grad_fn=<NllLossBackward0>)
tensor(1.2918, grad_fn=<NllLossBackward0>)
tensor(1.5727, grad_fn=<NllLossBackward0>)
tensor(1.5298, grad_fn=<NllLossBackward0>)
tensor(1.3267, grad_fn=<NllLossBackward0>)
tensor(1.34

tensor(1.3283, grad_fn=<NllLossBackward0>)
tensor(1.1713, grad_fn=<NllLossBackward0>)
tensor(1.1706, grad_fn=<NllLossBackward0>)
tensor(1.3185, grad_fn=<NllLossBackward0>)
tensor(1.2264, grad_fn=<NllLossBackward0>)
tensor(1.2334, grad_fn=<NllLossBackward0>)
tensor(1.3051, grad_fn=<NllLossBackward0>)
tensor(1.2245, grad_fn=<NllLossBackward0>)
tensor(1.1727, grad_fn=<NllLossBackward0>)
tensor(1.2028, grad_fn=<NllLossBackward0>)
tensor(1.0945, grad_fn=<NllLossBackward0>)
tensor(1.3852, grad_fn=<NllLossBackward0>)
tensor(1.2027, grad_fn=<NllLossBackward0>)
tensor(1.3019, grad_fn=<NllLossBackward0>)
tensor(1.2338, grad_fn=<NllLossBackward0>)
tensor(1.2114, grad_fn=<NllLossBackward0>)
tensor(1.2739, grad_fn=<NllLossBackward0>)
tensor(1.1069, grad_fn=<NllLossBackward0>)
Finished Training
