In [1]:
import torch
import torchvision
import torch.nn as nn

In [18]:
class Block(nn.Module):
    
    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        
    def forward(self, x):
        identity = x
        print(identity.shape)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        print(identity.shape)
        x += identity
        x = self.relu(x)
        return x

In [20]:
class ResNet_18(nn.Module):
    
    def __init__(self, image_channels, num_classes):
        
        super(ResNet_18, self).__init__()
        self.in_channels = 4
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        #resnet layers
        self.layer1 = self.__make_layer(4, 64, stride=2)
    
        
    def __make_layer(self, in_channels, out_channels, stride):
        
        identity_downsample = None
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
            
        return nn.Sequential(
            Block(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride), 
#             Block(out_channels, out_channels)
        )
        
    def forward(self, x):    
        x = self.layer1(x)
      
        return x 
    
    def identity_downsample(self, in_channels, out_channels):
        
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), 
            nn.BatchNorm2d(out_channels)
        )
model = ResNet_18(1, 10)
zers = torch.zeros(size=(1,4,33,33))
model.forward(zers)

torch.Size([1, 4, 33, 33])
torch.Size([1, 64, 17, 17])


tensor([[[[2.7313e-05, 6.1476e-05, 6.1476e-05,  ..., 6.1476e-05,
           6.1476e-05, 1.2509e-04],
          [2.1423e-05, 6.1082e-06, 6.1082e-06,  ..., 6.1082e-06,
           6.1082e-06, 2.7313e-05],
          [2.1423e-05, 6.1082e-06, 6.1082e-06,  ..., 6.1082e-06,
           6.1082e-06, 2.7313e-05],
          ...,
          [2.1423e-05, 6.1082e-06, 6.1082e-06,  ..., 6.1082e-06,
           6.1082e-06, 2.7313e-05],
          [2.1423e-05, 6.1082e-06, 6.1082e-06,  ..., 6.1082e-06,
           6.1082e-06, 2.7313e-05],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 2.9213e-05],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 1.1015e-04],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 1.1015e-04],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000