In [1]:
import torch 
import torch.nn as nn
# https://medium.com/@karuneshu21/resnet-paper-walkthrough-b7f3bdba55f0
# https://arxiv.org/pdf/1512.03385.pdf
# https://medium.com/@karuneshu21/how-to-resnet-in-pytorch-9acb01f36cf5

In [41]:
class ResNet(nn.Module):
    """Small ResNet-style classifier that can trace tensor shapes stage by stage.

    Simplified variant with a single 3x3 conv per residual stage:
      stem:   7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool            (64 channels)
      conv2:  3x3/1 conv -> BN, identity shortcut, ReLU              (64 channels)
      conv3:  3x3/2 conv -> BN, 1x1/2 projection shortcut, ReLU      (128 channels)
      conv4:  3x3/2 conv -> BN, 1x1/2 projection shortcut, ReLU      (256 channels)
      conv5:  3x3/2 conv -> BN, 1x1/2 projection shortcut, ReLU      (512 channels)
      head:   adaptive average pool to 1x1 -> flatten -> Linear(512, num_classes)

    Args:
        in_channels: number of channels in the input images.
        num_classes: length of the output logit vector.
        verbose: if True (the default, matching the original behaviour),
            ``forward`` prints the tensor shape after every stage.
    """

    def __init__(self, in_channels, num_classes, verbose=True) -> None:
        super().__init__()
        self.verbose = verbose

        # Stem.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self.in_channels = 64
        self.out_channels = 64

        # conv2 keeps both spatial size and channel count, so its shortcut is
        # the plain identity.
        self.conv2_1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=1)
        self.bn2_1 = nn.BatchNorm2d(self.out_channels)

        # conv3..conv5 halve H/W and double the channels, so their shortcuts need
        # a 1x1 stride-2 projection. The projections are built here rather than
        # in forward() so they are registered submodules: their weights are
        # created once, trained, saved in state_dict and moved by .to(device).
        self.out_channels = self.out_channels * 2
        self.conv3_1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=2)
        self.bn3_1 = nn.BatchNorm2d(self.out_channels)
        self.downsample3 = self._projection(self.in_channels, self.out_channels)

        self.in_channels = self.out_channels
        self.out_channels = self.out_channels * 2
        self.conv4_1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=2)
        self.bn4_1 = nn.BatchNorm2d(self.out_channels)
        self.downsample4 = self._projection(self.in_channels, self.out_channels)

        self.in_channels = self.out_channels
        self.out_channels = self.out_channels * 2
        self.conv5_1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=2)
        self.bn5_1 = nn.BatchNorm2d(self.out_channels)
        self.downsample5 = self._projection(self.in_channels, self.out_channels)

    @staticmethod
    def _projection(in_channels, out_channels):
        """Return a 1x1 stride-2 conv + BN that maps a shortcut to the next stage's shape."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, padding=0),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def _log(self, message):
        """Print ``message`` only when the model was built with ``verbose=True``."""
        if self.verbose:
            print(message)

    def _residual(self, x, conv, bn, shortcut, label):
        """Run one residual stage and return ``relu(bn(conv(x)) + shortcut(x))``.

        Args:
            x: stage input.
            conv, bn: the stage's main-path convolution and batch norm.
            shortcut: None for an identity shortcut, otherwise the projection
                module built by ``_projection``.
            label: stage name used in the verbose shape printout.
        """
        identity = x
        residual = bn(conv(x))
        skip = identity if shortcut is None else shortcut(identity)
        # Add before the ReLU, out-of-place: the previous `relu(...); x += identity`
        # mutated the ReLU output that autograd saved, so backward() failed.
        out = self.relu(residual + skip)
        self._log(f"identity shape {identity.shape}")
        # Build the label from the layer itself so it cannot drift from the config.
        self._log(
            f"after {label} channels: {conv.out_channels} kernel_size: {conv.kernel_size[0]} "
            f"stride: {conv.stride[0]} padding: {conv.padding[0]} {out.shape}"
        )
        return out

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: input batch; expected to be 4-D (N, in_channels, H, W), because the
               flatten step below keeps dim 0 as the batch dimension.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = self.conv1(x)
        self._log(f"after conv1 channels: 64 kernel_size: 7 stride: 2 padding: 3 {x.shape}")
        x = self.relu(self.bn1(x))
        x = self.maxpool(x)
        self._log(f"after maxpool kernel_size: 3 stride: 2 padding: 1 {x.shape}")

        x = self._residual(x, self.conv2_1, self.bn2_1, None, "conv2")
        x = self._residual(x, self.conv3_1, self.bn3_1, self.downsample3, "conv3")
        x = self._residual(x, self.conv4_1, self.bn4_1, self.downsample4, "conv4")
        x = self._residual(x, self.conv5_1, self.bn5_1, self.downsample5, "conv5")

        x = self.avgpool(x)
        self._log(f"after avgpool {x.shape}")
        x = x.reshape(x.shape[0], -1)
        self._log(f"after reshape {x.shape}")

        return self.fc(x)

In [43]:
# Smoke test: push a random batch of two 3x224x224 images through a 4-class model.
torch.manual_seed(0)  # fixed seed so weights, input and output are reproducible on re-run
net = ResNet(3, 4)
x = torch.randn(2, 3, 224, 224)
y = net(x)
assert y.shape == (2, 4), f"unexpected output shape {tuple(y.shape)}"
print(f"out shape {y.shape}")


after conv1 channels: 64 kernel_size: 7 stride: 2 padding: 3 torch.Size([2, 64, 112, 112])
after maxpool kernel_size: 3 stride: 2 padding: 1 torch.Size([2, 64, 56, 56])
identity shape torch.Size([2, 64, 56, 56])
after conv2 channels: 64 kernel_size: 3 stride: 1 padding: 1 torch.Size([2, 64, 56, 56])
identity shape torch.Size([2, 64, 56, 56])
after conv3 channels: 128 kernel_size: 3 stride: 2 padding: 1 torch.Size([2, 128, 28, 28])
identity shape torch.Size([2, 128, 28, 28])
after conv4 channels: 256 kernel_size: 3 stride: 2 padding: 1 torch.Size([2, 256, 14, 14])
identity shape torch.Size([2, 256, 14, 14])
after conv4 channels: 256 kernel_size: 3 stride: 2 padding: 1 torch.Size([2, 512, 7, 7])
after avgpool torch.Size([2, 512, 1, 1])
after reshape torch.Size([2, 512])
out shape torch.Size([2, 4])
