In [4]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class MonodepthModel(nn.Module):
    """
    Encoder/decoder CNN that predicts 2-channel disparity maps at four scales.

    Layer table this model follows:
    Encoder:
        Layer | Kernel size | stride | input-output | downsample input/output | input
        conv1       7            2       3/32               1/2                 left image
        conv1b      7            1       32/32              2/2                 conv1
        conv2       5            2       32/64              2/4                 conv1b
        conv2b      5            1       64/64              4/4                 conv2
        conv3       3            2       64/128             4/8                 conv2b
        conv3b      3            1       128/128            8/8                 conv3
        conv4       3            2       128/256            8/16                conv3b
        conv4b      3            1       256/256            16/16               conv4
        conv5       3            2       256/512            16/32               conv4b
        conv5b      3            1       512/512            32/32               conv5
        conv6       3            2       512/512            32/64               conv5b
        conv6b      3            1       512/512            64/64               conv6
        conv7       3            2       512/512            64/128              conv6b
        conv7b      3            1       512/512            128/128             conv7
    Decoder:
        upconv7     3            2       512/512            128/64              conv7b
        iconv7      3            1       1024/512           64/64               upconv7+conv6b
        upconv6     3            2       512/512            64/32               iconv7
        iconv6      3            1       1024/512           32/32               upconv6+conv5b
        upconv5     3            2       512/256            32/16               iconv6
        iconv5      3            1       512/256            16/16               upconv5+conv4b
        upconv4     3            2       256/128            16/8                iconv5
        iconv4      3            1       256/128            8/8                 upconv4+conv3b
        disp4       3            1       128/2              8/8                 iconv4
        upconv3     3            2       128/64             8/4                 iconv4
        iconv3      3            1       130/64             4/4                 upconv3+conv2b+disp4*
        disp3       3            1       64/2               4/4                 iconv3
        upconv2     3            2       64/32              4/2                 iconv3
        iconv2      3            1       66/32              2/2                 upconv2+conv1b+disp3*
        disp2       3            1       32/2               2/2                 iconv2
        upconv1     3            2       32/16              2/1                 iconv2
        iconv1      3            1       18/16              1/1                 upconv1+disp2*
        disp1       3            1       16/2               1/1                 iconv1

    * is a 2x upsampling of the layer

    How this implementation differs from the table:
      - The encoder ``conv<N>b`` attributes are ``nn.BatchNorm2d(..., affine=False)``
        layers applied to the output of ``conv<N>``, not second convolutions.
      - Each ``upconv<N>`` is a stride-1 convolution applied after a 2x
        nearest-neighbour ``F.interpolate`` (rather than a stride-2 layer).
      - ELU is applied after every layer, including the disparity heads.

    API reminders:
    torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
    torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    Padding size calculation [image/feature-map width or height = W, kernel size = K, stride = S, padding = P]
    output height/width = floor((W - K + 2P) / S) + 1

    With the paddings used here every stride-2 conv maps a side of length W to
    ceil(W / 2), so the skip connections only line up when the encoder feature
    maps stay even-sized; input sides that are multiples of 128 satisfy this.
    """

    def __init__(self):
        super(MonodepthModel, self).__init__()

        # Layers are created in a fixed order so that seeded initialisation and
        # state_dict keys stay stable.

        # Encoder                                                  # shapes for a 512*256*3 input
        self.conv1 = nn.Conv2d(3, 32, 7, stride=2, padding=3)      # 256*128*32
        self.conv1b = nn.BatchNorm2d(32, affine=False)             # 256*128*32

        self.conv2 = nn.Conv2d(32, 64, 5, stride=2, padding=2)     # 128*64*64
        self.conv2b = nn.BatchNorm2d(64, affine=False)             # 128*64*64

        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)    # 64*32*128
        self.conv3b = nn.BatchNorm2d(128, affine=False)            # 64*32*128

        self.conv4 = nn.Conv2d(128, 256, 3, stride=2, padding=1)   # 32*16*256
        self.conv4b = nn.BatchNorm2d(256, affine=False)            # 32*16*256

        self.conv5 = nn.Conv2d(256, 512, 3, stride=2, padding=1)   # 16*8*512
        self.conv5b = nn.BatchNorm2d(512, affine=False)            # 16*8*512

        self.conv6 = nn.Conv2d(512, 512, 3, stride=2, padding=1)   # 8*4*512
        self.conv6b = nn.BatchNorm2d(512, affine=False)            # 8*4*512

        self.conv7 = nn.Conv2d(512, 512, 3, stride=2, padding=1)   # 4*2*512
        self.conv7b = nn.BatchNorm2d(512, affine=False)            # 4*2*512

        # Decoder: every upconv runs on a 2x-upsampled tensor; every iconv runs
        # on the channel-wise concatenation of the upconv output with a skip
        # connection (and, at the three finest scales, an upsampled disparity).
        self.upconv7 = nn.Conv2d(512, 512, 3, stride=1, padding=1)        # 8*4*512
        self.iconv7 = nn.Conv2d(512 + 512, 512, 3, stride=1, padding=1)   # 8*4*512

        self.upconv6 = nn.Conv2d(512, 512, 3, stride=1, padding=1)        # 16*8*512
        self.iconv6 = nn.Conv2d(512 + 512, 512, 3, stride=1, padding=1)   # 16*8*512

        self.upconv5 = nn.Conv2d(512, 256, 3, stride=1, padding=1)        # 32*16*256
        self.iconv5 = nn.Conv2d(256 + 256, 256, 3, stride=1, padding=1)   # 32*16*256

        self.upconv4 = nn.Conv2d(256, 128, 3, stride=1, padding=1)        # 64*32*128
        self.iconv4 = nn.Conv2d(128 + 128, 128, 3, stride=1, padding=1)   # 64*32*128
        self.disp4 = nn.Conv2d(128, 2, 3, stride=1, padding=1)            # 64*32*2

        self.upconv3 = nn.Conv2d(128, 64, 3, stride=1, padding=1)         # 128*64*64
        self.iconv3 = nn.Conv2d(130, 64, 3, stride=1, padding=1)          # 128*64*64, 130 = 64 + 64 + 2
        self.disp3 = nn.Conv2d(64, 2, 3, stride=1, padding=1)             # 128*64*2

        self.upconv2 = nn.Conv2d(64, 32, 3, stride=1, padding=1)          # 256*128*32
        self.iconv2 = nn.Conv2d(66, 32, 3, stride=1, padding=1)           # 256*128*32, 66 = 32 + 32 + 2
        self.disp2 = nn.Conv2d(32, 2, 3, stride=1, padding=1)             # 256*128*2

        self.upconv1 = nn.Conv2d(32, 16, 3, stride=1, padding=1)          # 512*256*16
        self.iconv1 = nn.Conv2d(18, 16, 3, stride=1, padding=1)           # 512*256*16, 18 = 16 + 2
        self.disp1 = nn.Conv2d(16, 2, 3, stride=1, padding=1)             # 512*256*2

    @staticmethod
    def _upsample2x(t):
        """Return ``t`` with both spatial sides doubled (nearest-neighbour)."""
        return F.interpolate(t, scale_factor=2, mode='nearest')

    def forward(self, x):
        """
        Run the encoder/decoder and return the multi-scale disparity maps.

        Args:
            x: batched image tensor with 3 channels, i.e. (N, 3, H, W).
               H and W should be multiples of 128 so that every skip
               connection matches its upsampled partner in size.

        Returns:
            list of four tensors ``[disp1, disp2, disp3, disp4]`` with 2
            channels each, at 1/1, 1/2, 1/4 and 1/8 of the input resolution.
        """
        up = self._upsample2x

        # Encoder: conv -> batch norm -> ELU, halving the resolution each stage.
        enc1 = F.elu(self.conv1b(self.conv1(x)))
        enc2 = F.elu(self.conv2b(self.conv2(enc1)))
        enc3 = F.elu(self.conv3b(self.conv3(enc2)))
        enc4 = F.elu(self.conv4b(self.conv4(enc3)))
        enc5 = F.elu(self.conv5b(self.conv5(enc4)))
        enc6 = F.elu(self.conv6b(self.conv6(enc5)))
        enc7 = F.elu(self.conv7b(self.conv7(enc6)))

        # Decoder. The local names are now consistently lower-case: previously
        # the upconv results were stored as X_77 / X_66 / X_55 but read back as
        # x_77 / x_66 / x_55, which raised NameError on every forward pass.
        up7 = F.elu(self.upconv7(up(enc7)))
        dec7 = F.elu(self.iconv7(torch.cat((up7, enc6), 1)))

        up6 = F.elu(self.upconv6(up(dec7)))
        dec6 = F.elu(self.iconv6(torch.cat((up6, enc5), 1)))

        up5 = F.elu(self.upconv5(up(dec6)))
        dec5 = F.elu(self.iconv5(torch.cat((up5, enc4), 1)))

        up4 = F.elu(self.upconv4(up(dec5)))
        dec4 = F.elu(self.iconv4(torch.cat((up4, enc3), 1)))
        # NOTE(review): ELU on the disparity heads allows negative values
        # (down to -1) -- confirm this is the intended output range.
        disp4 = F.elu(self.disp4(dec4))

        up3 = F.elu(self.upconv3(up(dec4)))
        dec3 = F.elu(self.iconv3(torch.cat((up3, enc2, up(disp4)), 1)))
        disp3 = F.elu(self.disp3(dec3))

        up2 = F.elu(self.upconv2(up(dec3)))
        dec2 = F.elu(self.iconv2(torch.cat((up2, enc1, up(disp3)), 1)))
        disp2 = F.elu(self.disp2(dec2))

        # Finest scale has no encoder skip, only the upsampled disparity.
        up1 = F.elu(self.upconv1(up(dec2)))
        dec1 = F.elu(self.iconv1(torch.cat((up1, up(disp2)), 1)))
        disp1 = F.elu(self.disp1(dec1))

        return [disp1, disp2, disp3, disp4]