In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torchsummary import summary


In [42]:
class resenet_block(nn.Module):
    """Two-convolution residual block (ResNet "basic block" layout).

    Both convolutions use a 3x3 kernel, stride 1 and ``padding='same'``, and keep
    the channel count, so the block output has the same shape as its input and
    the input can be added directly to the residual branch.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        # NOTE(review): the conv biases are redundant because each conv is
        # followed by BatchNorm; they are kept so existing state_dicts load.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding='same')
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding='same')
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        residual = nnf.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # The shortcut is added *after* the last BatchNorm so it stays an
        # identity path; normalising (x + residual) would also rescale/shift x.
        return nnf.relu(residual + x)
        


class Resnet(nn.Module):
    """Configurable ResNet-style network built as a single ``nn.Sequential``.

    Layout:
      * stem: 7x7 conv (``padding='same'``) + BN + ReLU + stride-2 max-pool;
      * one stage of ``resenet_block`` modules per entry of the architecture;
        between consecutive stages a stride-2 3x3 conv doubles the channels and
        halves the spatial size, followed by BN/ReLU and a second 3x3 conv +
        BN/ReLU;
      * head: global average pool, flatten, Linear -> ReLU -> Linear.

    Args:
        in_channels: number of channels of the input images.
        output_dimension: size of the final linear output.
        downscale_architecture: a known name (``'resnet18'`` or ``'resnet34'``)
            or a sequence giving the number of residual blocks in each stage.
        hidden_dim: width of the hidden layer of the MLP head.
        base_channels: channels produced by the stem (doubled at each stage
            transition).

    Raises:
        ValueError: if ``downscale_architecture`` is an unknown name.
    """

    # Residual blocks per stage of the reference ResNets (He et al., 2015, Table 1).
    _NAMED_ARCHITECTURES = {
        'resnet18': [2, 2, 2, 2],
        'resnet34': [3, 4, 6, 3],
    }

    def __init__(self, in_channels, output_dimension=1, downscale_architecture='resnet34',
                 hidden_dim=1000, base_channels=64):
        super().__init__()
        self.downscale_architecture = self._resolve_architecture(downscale_architecture)
        self.model_layers = []

        # Stem: full-resolution conv, then halve the spatial size with max-pool.
        self.model_layers += [
            nn.Conv2d(in_channels, base_channels, 7, padding='same'),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        channels = base_channels

        last_stage = len(self.downscale_architecture) - 1
        for stage, num_blocks in enumerate(self.downscale_architecture):
            self.model_layers += [resenet_block(channels) for _ in range(num_blocks)]
            if stage < last_stage:
                # Transition: the stride-2 conv doubles channels and halves H and W.
                self.model_layers += [
                    nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(channels * 2),
                    nn.ReLU(),
                    nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding='same'),
                    nn.BatchNorm2d(channels * 2),
                    nn.ReLU(),
                ]
                channels *= 2

        # Head: global average pool to 1x1, flatten to (batch, channels), then MLP.
        self.model_layers += [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dimension),
        ]

        self.model = nn.Sequential(*self.model_layers)

    @classmethod
    def _resolve_architecture(cls, architecture):
        """Turn a named architecture into its blocks-per-stage list.

        Non-string values are returned unchanged. A string must be one of the
        keys of ``_NAMED_ARCHITECTURES``; a fresh list is returned so callers
        can mutate it without touching the shared table.

        Raises:
            ValueError: if ``architecture`` is an unknown name.
        """
        if isinstance(architecture, str):
            try:
                return list(cls._NAMED_ARCHITECTURES[architecture])
            except KeyError:
                known = ', '.join(sorted(cls._NAMED_ARCHITECTURES))
                raise ValueError(
                    f"unknown downscale_architecture {architecture!r}; "
                    f"expected one of {known} or a list of block counts"
                ) from None
        return architecture

    def forward(self, x):
        """Run the network; a (batch, C, H, W) input gives (batch, output_dimension)."""
        return self.model(x)

    # TODO: training loop — a model_train(dataloader, epochs) method was sketched here.




In [44]:
# Sanity-check a 5-stage variant (3 residual blocks per stage, 32 outputs)
# on 3-channel 64x64 inputs.
net = Resnet(in_channels=3, output_dimension=32, downscale_architecture=[3, 3, 3, 3, 3])
# torchsummary defaults to device="cuda"; the model is never moved to a GPU,
# so summarise on CPU to avoid a device mismatch on CUDA machines.
summary(net, (3, 64, 64), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,472
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,928
       BatchNorm2d-6           [-1, 64, 32, 32]             128
            Conv2d-7           [-1, 64, 32, 32]          36,928
       BatchNorm2d-8           [-1, 64, 32, 32]             128
     resenet_block-9           [-1, 64, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]          36,928
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]          36,928
      BatchNorm2d-13           [-1, 64, 32, 32]             128
    resenet_block-14           [-1, 64,

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
