<a href="https://colab.research.google.com/github/Squirrelcoding/mini-projects/blob/main/ResNet_performance_visualizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
from torch import nn

In [9]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

print(f"Train size: {len(mnist_train)}")
print(f"Test size: {len(mnist_test)}")


Train size: 60000
Test size: 10000


In [58]:
class BasicBlock(nn.Module):
  def __init__(self, channels: int, out_channels: int, downsampling=False) -> None:
    super().__init__()

    stride = 2 if downsampling else 1

    self.channels = channels
    # First convolutional layer. If downsampling, we set the stride to 2.
    self.conv1 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=stride, padding=1)
    # batch normalization
    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    # ReLU. Not really necessary but helps with keeping track of stuff.
    self.relu1 = nn.ReLU()
    # second convolutional layer that increases channels if downsampling
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
    # batch normalization
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)
    self.relu2 = nn.ReLU()

    self.downsampling = None
    if downsampling:
      # double the channels to keep time complexity of layers while halving the input
      self.downsampling = nn.Sequential(nn.Conv2d(in_channels=self.channels,
                                    out_channels=self.channels * 2,
                                    kernel_size=1,
                                    padding=0,
                                    stride=2), nn.BatchNorm2d(num_features=out_channels))

  def forward(self, x):
    x_copy = x
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu1(x)
    x = self.conv2(x)
    x = self.bn2(x)

    # downsample stuff to match dimensions of x and x_copy
    if self.downsampling:
      x_copy = self.downsampling(x_copy)

    x = x + x_copy
    x = self.relu2(x)
    return x

In [63]:
import torch

class ResNet18(nn.Module):
  def __init__(self, in_channels: int) -> None:
    super().__init__()

    # Deepnet part
    self.conv1 = nn.Conv2d(in_channels, out_channels=64, stride=2, kernel_size=7, padding=1)
    self.conv10 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    self.conv2_1 = BasicBlock(64, 64)
    self.conv2_2 = BasicBlock(64, 64)
    self.conv3_1 = BasicBlock(64, 128, downsampling=True)
    self.conv3_2 = BasicBlock(128, 128)
    self.conv4_1 = BasicBlock(128,256, downsampling=True)
    self.conv4_2 = BasicBlock(256, 256)
    self.conv5_1 = BasicBlock(256, 512, downsampling=True)
    self.conv5_2 = BasicBlock(512, 512)

    # Classification head

    # Average pooling
    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    self.fc_layer = nn.Linear(512, 1000)

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv10(x)
    x = self.conv2_1(x)
    x = self.conv2_2(x)
    x = self.conv3_1(x)
    x = self.conv3_2(x)
    x = self.conv4_1(x)
    x = self.conv4_2(x)
    x = self.conv5_1(x)
    x = self.conv5_2(x)

    x = self.avg_pool(x)

    x = torch.flatten(x, 1)

    x = self.fc_layer(x)
    return x

In [64]:
from PIL import Image

img = Image.open("goldfish.JPEG")

model = ResNet18(in_channels=3)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], # values per colour channel [red, green, blue]
                                 std=[0.229, 0.224, 0.225]) # values per colour channel [red, green, blue]

# Compose transforms into a pipeline
simple_transform = transforms.Compose([
    transforms.Resize((224, 224)), # 1. Resize the images
    transforms.ToTensor(), # 2. Turn the images into tensors with values between 0 & 1
    normalize # 3. Normalize the images so their distributions match the ImageNet dataset
])

model(simple_transform(img).unsqueeze(0))

tensor([[-4.8363e-01,  1.1057e-01, -2.4305e-01, -5.2720e-01,  6.9542e-01,
          8.2866e-01, -2.9087e-01,  4.4073e-02, -5.8548e-01,  3.9606e-01,
          1.2809e-02, -2.7932e-01, -7.9662e-01, -4.3640e-01, -5.8902e-01,
         -6.3348e-01,  3.2920e-01,  1.2048e-01, -6.0553e-02,  6.2218e-01,
         -6.0868e-01, -6.6579e-02, -1.5499e-01, -4.5170e-01, -1.0995e+00,
         -4.2036e-01, -7.6332e-01,  3.1332e-01, -7.3152e-02, -1.0582e+00,
          2.8788e-01, -7.5309e-01,  1.3131e-01, -7.4459e-01, -5.0049e-01,
          3.6227e-01,  5.8785e-02,  1.6576e-01, -6.2292e-01,  1.6695e-01,
          3.6272e-01,  5.9019e-01, -5.0217e-02,  5.2307e-01,  4.1649e-02,
         -4.8069e-01, -5.6264e-01,  2.4871e-01, -6.9793e-02, -4.8371e-01,
          4.8500e-01, -3.5336e-02, -6.0810e-01, -8.8588e-03,  4.0147e-02,
         -8.9444e-03,  1.1247e-01, -8.4132e-02,  2.9706e-01, -6.7474e-01,
          1.3389e-01, -1.6477e-01, -7.6186e-01,  9.9340e-01, -1.6738e-01,
          2.1913e-01, -1.2276e-01, -3.