In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
#Conv(1,1) : Only used to increase the number of features to the desired initial filter size 
# Input are either Batch x Channel x Height x Width or Channel x Height x Width
conv1_1 = nn.Conv2d(in_channels= 3, out_channels= 32, kernel_size=(1,1), stride=(1,1), padding='same') 
x = torch.randn((10,3,256,256))
y = conv1_1(x)
y.shape

torch.Size([10, 32, 256, 256])

In [12]:
# UpSample 2
# Input has to be Batch x channel x Height x Width
upsample2 = nn.Upsample(scale_factor= 2, mode = 'nearest')
x = torch.randn((3,323,128,128))
y = upsample2(x)
y.shape

torch.Size([3, 323, 256, 256])

In [13]:
## BatchNorm2D
# Applies over a 4D input (a mini-batch of 2D inputs with additional channel dimension)
norm = nn.BatchNorm2d(num_features=3)
x = torch.randn((3,3,128,128))
y = norm(x)
y.shape

torch.Size([3, 3, 128, 128])

In [85]:
# Max Pooling 2D
maxpool = nn.MaxPool2d(kernel_size=2)
x = torch.randn((3,3,256,256))
y = maxpool(x)
y.shape

torch.Size([3, 3, 128, 128])

In [28]:
# Concatenate, seems to be done on the channel axis
x = torch.randn((3,3,128,128))
y = torch.cat((x,x), dim = 1)
print(y.shape)

torch.Size([3, 6, 128, 128])


In [36]:

class ResUNeta_block(nn.Module):
    def __init__(self,input_channels,output_channels,d,kernel_size = 3,stride = 1, padding = 'same' ):
        
        self.block = []
        self.n_block = len(d) # have to reuse it 
        for dilatation_rate in d: # for every dilation rate we create a new sequential block
            self.block.append(
                nn.Sequential(
                    nn.BatchNorm2d(input_channels),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=input_channels,
                              out_channels=output_channels, 
                              kernel_size=kernel_size, 
                              stride=stride, 
                              padding= padding,
                              dilation= dilatation_rate),
                    nn.BatchNorm2d(output_channels),
                    nn.ReLU(),
                    nn.Conv2d(in_channels=output_channels,
                              out_channels=output_channels, 
                              kernel_size=kernel_size, 
                              stride=stride, 
                              padding= padding,
                              dilation= dilatation_rate)
                )
            )
    def forward(self,x):
        result = []
        for block in self.block:
            rate = block(x)
            result.append(rate)
        return torch.stack(result, dim=0).sum(dim=0) # this return the sum of all the differents results
        
block_test = ResUNeta_block(32,32,d = [1,3,15,31])
x = torch.randn((64,32,128,128))
y = block_test(x)
y.shape


torch.Size([64, 32, 128, 128])

In [84]:
class Combine(nn.Module):
    def __init__(self, in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding='same',
                dilation=1) -> None:
        super(Combine,self).__init__()
        self.act = nn.ReLU()
        self.conv2DN = nn.Sequential(
            nn.Conv2d(
                in_channels=2*in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation),
                nn.BatchNorm2d(out_channels)
        )

       
    def forward(self,x1,x2):
        x1 = self.act(x1)
        concat = torch.concat([x1,x2], dim = 1)
        print(concat.shape)
        return self.conv2DN(concat)

combine_test = Combine(32,32)
x = torch.randn((10,32,256,256))
y = combine_test(x,x)
print(y.shape)


torch.Size([10, 64, 256, 256])
torch.Size([10, 32, 256, 256])
