In [83]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import models
import torchinfo
from torchvision import transforms


In [62]:
class BuildingBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        super().__init__()
        self.left = nn.Sequential(
            nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=first_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_feats),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_feats, out_feats, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_feats)
        )

        if first_stride > 1:
            self.right = nn.Sequential(
                nn.Conv2d(in_feats, out_feats, kernel_size=1, stride=first_stride, padding=0, bias=False),
                nn.BatchNorm2d(out_feats)
            )
        else:
            self.right = nn.Sequential( # Redundant, but for consistency
                nn.Identity()
            )
        

    def forward(self, x):
        left_out = self.left(x)
        right_out = self.right(x)
        print(left_out.shape, right_out.shape)
        return F.relu(left_out + right_out)

In [63]:
class Group(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        super().__init__()
        self.blocks = nn.Sequential(
            BuildingBlock(in_feats, out_feats, first_stride),
            *[BuildingBlock(out_feats, out_feats) for _ in range(n_blocks - 1)]
        )

    def forward(self, x):
        return self.blocks(x)

In [64]:
class ResNet34(nn.Module):
    def __init__(self, groupSizes=[3, 4, 6, 3], groupFeats=[64, 128, 256, 512], groupFirstStrides=[1, 2, 2, 2], num_classes=1000):
        super().__init__()
        inGroupFeats = [groupFeats[0]] + groupFeats[:-1]
        print(inGroupFeats)

        self.net = nn.Sequential(
            nn.Conv2d(3, groupFeats[0], 7, 2, 3, bias=False),
            nn.BatchNorm2d(groupFeats[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
            *[Group(groupSizes[i], inGroupFeats[i], groupFeats[i], groupFirstStrides[i]) for i in range(len(groupSizes))],
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(groupFeats[-1], num_classes)
        )

    def forward(self, x):
        return self.net(x)

In [65]:
temp = ResNet34()
temp(torch.randn(1, 3, 224, 224)).shape

[64, 64, 128, 256]
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])


torch.Size([1, 1000])

In [66]:
resnet = models.resnet34()
print(torchinfo.summary(resnet, input_size=(1, 3, 64, 64)))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 32, 32]           9,408
├─BatchNorm2d: 1-2                       [1, 64, 32, 32]           128
├─ReLU: 1-3                              [1, 64, 32, 32]           --
├─MaxPool2d: 1-4                         [1, 64, 16, 16]           --
├─Sequential: 1-5                        [1, 64, 16, 16]           --
│    └─BasicBlock: 2-1                   [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-1                  [1, 64, 16, 16]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 16, 16]           128
│    │    └─ReLU: 3-3                    [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-4                  [1, 64, 16, 16]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 16, 16]           128
│    │    └─ReLU: 3-6                    [1, 64, 16, 16]           --
│

In [67]:
def copy_weights(myresnet, pretrained_resnet):
    mydict = myresnet.state_dict()
    pretraineddict = pretrained_resnet.state_dict()
    state_dict_to_load = {
        mykey: pretrainedvalue
        for (mykey, _), (_, pretrainedvalue) in zip(mydict.items(), pretraineddict.items())
    }

    myresnet.load_state_dict(state_dict_to_load)

    return myresnet


pretrained_resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
myresnet = ResNet34()
myresnet = copy_weights(myresnet, pretrained_resnet)

[64, 64, 128, 256]


In [68]:
save_path = "model.pth"
torch.save(myresnet.state_dict(), save_path)

In [80]:
from PIL import Image

In [90]:
astro = Image.open('astro.jpg')
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

prepared_image = IMAGENET_TRANSFORM(astro).unsqueeze(0)




In [101]:
import requests
pretrained_resnet.eval()
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = requests.get(LABELS_URL).json()

with torch.no_grad():
    output = pretrained_resnet(prepared_image)
    print(output.argmax().item(), labels[output.argmax().item()])

570 gas mask
