In [3]:
import torch 
import torch.nn as nn 
class DS_Block(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size = 3):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_filters)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size)
        self.bn2 = nn.BatchNorm2d(out_filters)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size)
        self.maxpool = nn.MaxPool2d(2)
                
    def forward(self, x):
        x = self.bn1(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.conv2(x)
        x_pool = self.maxpool(x)
        x_skip = x
        return x_skip, x_pool

class Bottom_Block(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size = 3):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_filters)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size)
        self.bn2 = nn.BatchNorm2d(out_filters)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size)
                
    def forward(self, x):
        x = self.bn1(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.conv2(x)
        return x
    
class MMFF_Block(nn.Module):
    '''Multi-Modality Feature Fusion block'''
    def __init__(self, in_filters, out_filters):
        super().__init__()
       
        self.bn1_1 = nn.BatchNorm2d(in_filters)        
        self.bn1_2 = nn.BatchNorm2d(in_filters)
        self.bn1_3 = nn.BatchNorm2d(in_filters)
        # 1x1 conv
        self.conv1_1 = nn.Conv2d(in_filters, out_filters, 1)
        self.conv1_2 = nn.Conv2d(in_filters, out_filters, 1)
        self.conv1_3 = nn.Conv2d(in_filters, out_filters, 1)
        
        # batch norm
        self.bn2_1 = nn.BatchNorm2d(out_filters)
        self.bn2_2 = nn.BatchNorm2d(out_filters)
        self.bn2_3 = nn.BatchNorm2d(out_filters)
        
        # 3x3 conv
        self.conv2_1 = nn.Conv2d(out_filters, out_filters, 3, padding = 1)
        self.conv2_2 = nn.Conv2d(out_filters, out_filters, 3, padding = 1)
        self.conv2_3 = nn.Conv2d(out_filters, out_filters, 3, padding = 1)
        
    def forward(self, x_skip_1, x_skip_2, x_skip_3):
        
        x_1 = self.bn1_1(x_skip_1)
        x_1 = self.conv1_1(x_1)
        x_1 = self.bn2_1(x_1)
        x_1 = self.conv2_1(x_1)
        
        x_2 = self.bn1_2(x_skip_2)
        x_2 = self.conv1_2(x_2)
        x_2 = self.bn2_2(x_2)
        x_2 = self.conv2_2(x_2)
        
        x_3 = self.bn1_3(x_skip_3)
        x_3 = self.conv1_3(x_3)
        x_3 = self.bn2_3(x_3)
        x_3 = self.conv2_3(x_3)
        
        # concat all x_1, x_2, x_3
        x = torch.cat((x_1, x_2, x_3), dim = 1)
        return x

counter = 0 
class MSFU_Block(nn.Module):
    '''
    Multi-Scale Feature Up-sampling block
    '''
    def __init__(self, in_filters, out_filters):
        super().__init__()
        in_filters = in_filters * 3
        out_filters = out_filters * 3
        print('[INFO] MSFU-Block:------------ ')
        print('------ in_filters: ', in_filters)
        print('------ out_filters: ', out_filters)
        
        self.bn1 = nn.BatchNorm2d(in_filters)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size = 1, padding = 2)
        # 2x2 upsample
        self.upsample = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
        
        # CONCATENATE
        # after concate
        self.bn2 = nn.BatchNorm2d(out_filters*2)        
        self.conv2 = nn.Conv2d(out_filters*2, out_filters, 1)
        self.bn3 = nn.BatchNorm2d(out_filters)
        self.conv3 = nn.Conv2d(out_filters, out_filters, 3)

    def forward(self, x_low, x_high):
        ''' 
        low-resolution input to the MSFU-Block came from the bottom- or another MSFU-Block
        high-resolution came from MMFF-Block
        '''
        print('[SHAPE] x_low before: ', x_low.shape)
        global counter
        counter += 1
        print('[COUNTER]  counter  = ', counter)
        x_low = self.bn1(x_low)
        x_low = self.conv1(x_low)
        x_low = self.upsample(x_low)
        
        print('[SHAPE] x_low after: ', x_low.shape)
        print('[SHAPE] x_high: ', x_high.shape)
        x = torch.cat((x_low, x_high), dim = 1)
        x = self.bn2(x)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.conv3(x)
        return x

class MultiBranch2DNet(nn.Module):
    def __init__(self, filter_list_in = [1, 32, 64, 128, 256, 512], filter_list_out = [1, 32, 64, 128, 256, 512]):
        # filter_list_in = [3, 32, 64, 128, 256, 512]
        # filter_list_out = [1, 32, 64, 128, 256, 512]
        # len(filter_list_in) == len(filter_list_out)
        super().__init__()
        self.DS_list_1 = [DS_Block(filter_list_in[i], filter_list_in[i+1]) for i in range(len(filter_list_in)-2)]
        self.DS_list_2 = [DS_Block(filter_list_in[i], filter_list_in[i+1]) for i in range(len(filter_list_in)-2)]
        self.DS_list_3 = [DS_Block(filter_list_in[i], filter_list_in[i+1]) for i in range(len(filter_list_in)-2)]
        
        self.BTM_1 = Bottom_Block(filter_list_in[-2], filter_list_in[-1])
        self.BTM_2 = Bottom_Block(filter_list_in[-2], filter_list_in[-1])
        self.BTM_3 = Bottom_Block(filter_list_in[-2], filter_list_in[-1])
        
        self.MMFF_list = [MMFF_Block(filter_list_out[i+1], filter_list_out[i]) for i in range(len(filter_list_out)-1)]
        self.MSFU_list = [MSFU_Block(filter_list_out[i+1], filter_list_out[i]) for i in range(len(filter_list_out)-2)]
        self.len_ds_list = len(self.DS_list_1 )
    def forward(self, x_1, x_2, x_3):
        # use setattr to create new variables
        # init input
        print('[CHECK] x_1.shape', x_1.shape)
        setattr(self, 'x_ds_1_0_skip', x_1)
        setattr(self, 'x_ds_2_0_skip', x_2)
        setattr(self, 'x_ds_3_0_skip', x_3)
        
        # pass through DS blocks                
        for i in [1, 2, 3]:
            for j in range(self.len_ds_list):
                output_ds = self.DS_list_1[j](getattr(self, f'x_ds_{i}_{j}_skip'))
                setattr(self, f'x_ds_{i}_{j+1}_skip', output_ds[0])
                setattr(self, f'x_ds_{i}_{j+1}_pool', output_ds[1])
        print('[DONE] DS blocks') 

        x_btm_1 = self.BTM_1(getattr(self, 'x_ds_1_'+str(self.len_ds_list) + '_pool'))
        x_btm_2 = self.BTM_2(getattr(self, 'x_ds_2_'+str(self.len_ds_list) + '_pool'))
        x_btm_3 = self.BTM_3(getattr(self, 'x_ds_3_'+str(self.len_ds_list) + '_pool'))
        print('[DONE] BTM blocks')        
        
        setattr(self, 'x_mmff_5', self.MMFF_list[self.len_ds_list](x_btm_1, x_btm_2, x_btm_3))
        for i in reversed(range(self.len_ds_list)):
            setattr(self, 'x_mmff_' + str(i+1), self.MMFF_list[i](
                                    getattr(self, 'x_ds_1_'+str(i+1) + '_skip'), \
                                    getattr(self, 'x_ds_2_'+str(i+1) + '_skip'), \
                                    getattr(self, 'x_ds_3_'+str(i+1) + '_skip')))
        print('[DONE] MMFF blocks')
        
        print('len_ds_list = ', self.len_ds_list)
        print('x_mmff_5.shape = ', getattr(self, 'x_mmff_5').shape)
        print('x_mmff_4.shape = ', getattr(self, 'x_mmff_4').shape)
        print('x_mmff_3.shape = ', getattr(self, 'x_mmff_3').shape)
        print('x_mmff_2.shape = ', getattr(self, 'x_mmff_2').shape)
        print('x_mmff_1.shape = ', getattr(self, 'x_mmff_1').shape)
        # print('x_mmff_0.shape = ', getattr(self, 'x_mmff_0').shape)
        
        setattr(self, 'x_msfu_4' ,  self.MSFU_list[self.len_ds_list - 1](getattr(self, 'x_mmff_5'), getattr(self, 'x_mmff_4')))
        print('[SHAPE 2222222222222] x_msfu_4.shape = ', getattr(self, 'x_msfu_4').shape)
        
        for i in reversed(range(self.len_ds_list - 1)):
            print('[i] = ', i)
            setattr(self, 'x_msfu_' + str(i+1), self.MSFU_list[i](getattr(self, 'x_mmff_' + str(i+1)), getattr(self, 'x_msfu_' + str(i+2))))
        print('[DONE] MSFU blocks')
        
        x = getattr(self, 'x_msfu_1')
        return x 
    
x_1 = torch.randn(12, 1, 256, 256)
x_2 = torch.randn(12, 1, 256, 256)
x_3 = torch.randn(12, 1, 256, 256)
multi_branch = MultiBranch2DNet()
multi_branch(x_1, x_2, x_3)





[INFO] MSFU-Block:------------ 
------ in_filters:  96
------ out_filters:  3
[INFO] MSFU-Block:------------ 
------ in_filters:  192
------ out_filters:  96
[INFO] MSFU-Block:------------ 
------ in_filters:  384
------ out_filters:  192
[INFO] MSFU-Block:------------ 
------ in_filters:  768
------ out_filters:  384
[CHECK] x_1.shape torch.Size([12, 1, 256, 256])
[DONE] DS blocks
[DONE] BTM blocks
[DONE] MMFF blocks
len_ds_list =  4
x_mmff_5.shape =  torch.Size([12, 768, 116, 116])
x_mmff_4.shape =  torch.Size([12, 384, 240, 240])
x_mmff_3.shape =  torch.Size([12, 192, 244, 244])
x_mmff_2.shape =  torch.Size([12, 96, 248, 248])
x_mmff_1.shape =  torch.Size([12, 3, 252, 252])
[SHAPE] x_low before:  torch.Size([12, 768, 116, 116])
[COUNTER]  counter  =  1
[SHAPE] x_low after:  torch.Size([12, 384, 240, 240])
[SHAPE] x_high:  torch.Size([12, 384, 240, 240])
[SHAPE 2222222222222] x_msfu_4.shape =  torch.Size([12, 384, 238, 238])
[i] =  2
[SHAPE] x_low before:  torch.Size([12, 192, 244, 2

RuntimeError: running_mean should contain 192 elements not 384

In [None]:
import torch 
x_msfu_4 = torch.rand([12, 384, 238, 238])
x = torch.rand([12, 384, 240, 240])
y = torch.rand([12, 384, 240, 240])
# x_mmff_5.shape =  torch.Size([12, 768, 116, 116])
# x_mmff_4.shape =  torch.Size([12, 384, 240, 240])
# x_mmff_3.shape =  torch.Size([12, 192, 244, 244])
# x_mmff_2.shape =  torch.Size([12, 96, 248, 248])
# x_mmff_1.shape =  torch.Size([12, 3, 252, 252])
msfu_ok   = MSFU_Block(256, 128)
msfu_ok(x, x)

In [10]:
# import F from torch
from torch.nn import functional as F
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        ''' 
        low-resolution input to the MSFU-Block came from the bottom- or another MSFU-Block
        high-resolution came from MMFF-Block
        '''
        super(UpSample, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x):
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv1(x)
        x = self.upsample(x)
        return x
x_mmff_5 =  torch.rand([12, 768, 116, 116])
x_mmff_4 =  torch.rand([12, 384, 240, 240])
x_mmff_3 =  torch.rand([12, 192, 244, 244])
x_mmff_2 =  torch.rand([12, 96, 248, 248])
x_mmff_1 =  torch.rand([12, 3, 252, 252])

up_sample = UpSample(128*3, 64*3)
up_sample(x_mmff_4)

tensor([[[[-8.2347e-03, -4.9215e-02, -9.0194e-02,  ..., -4.4169e-01,
           -3.1098e-01, -1.8027e-01],
          [-6.5949e-02, -1.1866e-01, -1.7137e-01,  ..., -5.0108e-01,
           -5.3456e-01, -5.6805e-01],
          [-1.2366e-01, -1.8811e-01, -2.5255e-01,  ..., -5.6046e-01,
           -7.5815e-01, -9.5583e-01],
          ...,
          [-2.5923e-01, -6.6383e-01, -1.0684e+00,  ..., -3.2061e-01,
           -3.3419e-01, -3.4778e-01],
          [-1.4079e-02, -4.6083e-01, -9.0759e-01,  ..., -4.4952e-01,
           -4.3964e-01, -4.2975e-01],
          [ 2.3108e-01, -2.5783e-01, -7.4674e-01,  ..., -5.7843e-01,
           -5.4508e-01, -5.1173e-01]],

         [[-3.0458e-02, -3.3570e-02, -3.6681e-02,  ..., -6.1116e-01,
           -3.8773e-01, -1.6430e-01],
          [ 1.6179e-01,  8.9186e-02,  1.6585e-02,  ..., -4.0656e-01,
           -2.3250e-01, -5.8451e-02],
          [ 3.5403e-01,  2.1194e-01,  6.9850e-02,  ..., -2.0195e-01,
           -7.7276e-02,  4.7403e-02],
          ...,
     

: 

In [11]:
import torch.nn as nn 
import torch
class model_ok(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
    
    def forward(self, x1, x2):
        # stack the input
        print('x_1 shape = ', x1.shape)
        print('x_2 shape = ', x2.shape)
        
        x = torch.stack([x1, x2], dim = 0) 
        print('x shape = ', x.shape)
        x = self.conv1(x)
        return x
    
x_1 = torch.randn( 3, 256, 256)
x_2 = torch.randn( 3, 256, 256)

model = model_ok()
y_truth = torch.randn(2, 64, 256, 256)

for i in range(10):
    y_pred = model(x_1, x_2)
    loss = torch.mean((y_truth - y_pred)**2)
    print(loss)

x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  torch.Size([3, 256, 256])
x shape =  torch.Size([2, 3, 256, 256])
tensor(1.3481, grad_fn=<MeanBackward0>)
x_1 shape =  torch.Size([3, 256, 256])
x_2 shape =  