"""
Residual Network (ResNet). We train a ResNet (7) with 32 convolutional layers. The ResNet-32
has a sequence of 15 residual blocks: the first 5 blocks have an output of shape 32 × 32 × 16, the
following 5 blocks have an output of shape 16×16×32 and the last 5 blocks have an output of shape
8×8×64. On top of these blocks, there is a 2×2 average pooling layer with stride of 2, followed by
an output layer of size 10 with softmax non-linearity. The ResNet-32 has ≈467k trainable parameters
in total.
"""

In [6]:
import torch
from torch import nn

class ResBlock(nn.Module):
  """
  Residual block of 2 conv layers:
  Conv -> Norm -> Act -> Conv -> Norm -> (+) -> Act
     |_____[Optional: 1x1 Conv -> Norm]____|

  The shortcut branch is the identity unless its shape has to change, in
  which case the input is projected with a 1x1 convolution followed by
  batch normalisation before being added to the main branch.

  Args:
    in_channels: number of channels of the input feature map.
    mid_channels: number of channels produced by the first 3x3 convolution.
    out_channels: number of channels produced by the second 3x3 convolution,
      i.e. the channel count of the block output.
    downsample: optional integer stride applied by the first convolution and
      by the shortcut projection. ``None`` keeps the spatial resolution.
  """
  def __init__(self, in_channels, mid_channels, out_channels, downsample=None):
    super().__init__()
    stride = downsample or 1
    # A projection shortcut is needed when a stride was requested, and also
    # when the channel counts differ: without it the residual addition in
    # forward() would fail on mismatched tensor shapes.
    self.downsample = isinstance(downsample, int) or in_channels != out_channels
    self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, padding=1, stride=stride)
    self.norm1 = nn.BatchNorm2d(mid_channels)
    self.act1 = nn.ReLU()
    self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, padding=1)
    self.norm2 = nn.BatchNorm2d(out_channels)
    self.act2 = nn.ReLU()
    if self.downsample:
      # Same stride as conv1 so both branches end up with the same spatial size.
      self.convp = nn.Conv2d(in_channels, out_channels, 1, padding=0, stride=stride)
      self.normp = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    """Return ``act(shortcut(x) + main(x))`` for the input feature map ``x``."""
    # Main branch: the second activation is deferred until after the addition.
    x_ = self.act1(self.norm1(self.conv1(x)))
    x_ = self.norm2(self.conv2(x_))
    if self.downsample:
      x = self.normp(self.convp(x))
    x = x + x_
    return self.act2(x)

    
class ResNet(nn.Module):
  """
  ResNet classifier built from three stages of five residual blocks each.

  A 3x3 stem convolution maps the 3-channel input to ``channels[0]`` feature
  maps. Stage 1 keeps the resolution; stages 2 and 3 each start with a
  stride-2 block that halves the resolution and switches to the next channel
  width. An 8x8 average pool, a flatten and a linear layer produce the class
  scores.

  The pooling window is sized for 32x32 inputs: two stride-2 stages leave an
  8x8 feature map, which the pool reduces to 1x1 so that the flattened vector
  has exactly ``channels[2]`` entries, as the linear layer expects.

  Args:
    channels: sequence of three channel widths, one per stage.
    num_classes: number of output classes.
  """
  def __init__(self, channels=(16, 32, 64),
               num_classes=10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, channels[0], 3, padding=1)
    self.block1 = nn.Sequential(
                                *[
                                ResBlock(channels[0], channels[0], channels[0])
                                for _ in range(5)
                                ])
    self.block2 = nn.Sequential(ResBlock(channels[0], channels[1], channels[1], downsample=2),
                                *[
                                ResBlock(channels[1], channels[1], channels[1])
                                for _ in range(4)
                                ])
    self.block3 = nn.Sequential(ResBlock(channels[1], channels[2], channels[2], downsample=2),
                                *[
                                ResBlock(channels[2], channels[2], channels[2])
                                for _ in range(4)
                                ])
    self.pool = nn.AvgPool2d(8)
    self.flat_channels = channels[2]
    self.fc = nn.Linear(channels[2], num_classes)
    # Not applied in forward(); callers can use it to turn the returned
    # scores into probabilities.
    self.prob = nn.Softmax(dim=1)

  def forward(self, x):
    """Return the unnormalised class scores (logits), one row per sample."""
    x = self.conv1(x)
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.pool(x)
    # Keep the batch dimension and collapse everything else.
    x = torch.flatten(x,1)
    x = self.fc(x)
    return x

In [7]:
# Smoke test: push one random CIFAR-sized batch through a freshly built
# network and print the resulting class scores.
model = ResNet()
dummy_batch = torch.rand(128, 3, 32, 32)
out = model(dummy_batch)
print(out)

tensor([[ 0.0141, -1.2120, -1.8141,  ..., -0.6015, -0.2684, -0.0842],
        [-0.2129, -1.0757, -1.7816,  ..., -0.5973, -0.1430,  0.0569],
        [-0.0033, -1.2660, -1.7629,  ..., -0.6755, -0.2155, -0.0640],
        ...,
        [-0.0920, -1.2395, -1.7326,  ..., -0.7915, -0.2694, -0.0891],
        [-0.0588, -1.1318, -1.4636,  ..., -0.4433, -0.3099,  0.1310],
        [-0.0652, -1.1346, -1.6320,  ..., -0.5530, -0.0829, -0.0390]],
       grad_fn=<AddmmBackward0>)


In [8]:
#"Load CIFAR10 Data"
import torchvision
import torchvision.transforms as transforms
import ssl
# WARNING(security): this turns off TLS certificate verification for every
# HTTPS request made by this process, not only for the dataset download.
# Presumably a workaround for local certificate errors -- prefer fixing the
# CA store and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

# Tensor conversion and normalisation shared by training and evaluation.
base_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training pipeline: the shared steps followed by random flip augmentation.
transform = transforms.Compose(
    base_transforms
    + [transforms.RandomVerticalFlip(),
       transforms.RandomHorizontalFlip()]
    )

# Evaluation pipeline: no random augmentation, so test results are
# deterministic and measured on the unmodified images.
test_transform = transforms.Compose(base_transforms)

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Class names in label-index order.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [9]:
import torch.optim as optim

# Network to train, its classification loss, and an SGD optimiser with
# momentum over all of the network's parameters.
net = ResNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    momentum=0.9,
    lr=0.001,
)

In [10]:
# Number of mini-batches between loss reports. It has to be smaller than the
# number of mini-batches in one epoch, otherwise nothing is ever printed; at
# a batch size of 128 an epoch of CIFAR-10 has only a few hundred of them.
log_interval = 100

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # zero the parameter gradients accumulated by the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print the mean loss over the last `log_interval` mini-batches
        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

print('Finished Training')

KeyboardInterrupt: 