In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def init_bn(bn):
    """Reset a batch-norm layer to the identity affine transform.

    The learnable scale (``weight``) becomes 1 and the learnable shift
    (``bias``) becomes 0; running statistics are left untouched.
    """
    nn.init.ones_(bn.weight)
    nn.init.zeros_(bn.bias)

def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias``.

    Layers without a ``bias`` attribute, or whose bias is ``None`` (e.g.
    convolutions built with ``bias=False``), only get their weight reset.
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is None:
        return
    nn.init.zeros_(bias)

In [None]:
class FilmModule(nn.Module):
    """Additive FiLM-style conditioning: shift every channel of a feature map by an embedding-dependent amount.

    A small MLP (Linear -> ReLU -> Linear -> ReLU) maps the conditioning
    embedding to one value per channel. That value is broadcast over all
    remaining (spatial) dimensions of the feature map and added to it.

    Args:
        input_size: size of the last dimension of the conditioning embedding.
        output_size: number of channels of the feature map being modulated.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # NOTE(review): the trailing ReLU forces the predicted shift to be
        # non-negative; FiLM usually allows signed shifts. Kept as-is.
        self.linear = nn.Sequential(
            nn.Linear(input_size, output_size * 2),
            nn.ReLU(inplace=True),
            nn.Linear(output_size * 2, output_size),
            nn.ReLU(inplace=True),
        )

    def forward(self, data, embedding_vector):
        """Add the per-channel shift predicted from ``embedding_vector`` to ``data``.

        Args:
            data: feature map laid out as ``(batch, output_size, *spatial)``.
            embedding_vector: ``(batch, input_size)`` embedding, or a single
                ``(input_size,)`` embedding shared by the whole batch.

        Returns:
            Tensor with the same shape as ``data``.
        """
        shift = self.linear(embedding_vector)  # (..., output_size)
        # Append one singleton axis per spatial dim of `data` so the shift
        # lines up with the channel axis: (batch, C) -> (batch, C, 1, ..., 1).
        shift = shift.reshape(tuple(shift.shape) + (1,) * (data.dim() - 2))
        return data + shift


Film1_1 = FilmModule(512, 32)

# Smoke test: one embedding conditioning a (batch=1, channels=32, H, W) feature map.
random_embedding = torch.rand(1, 512)
random_value = torch.rand(1, 32, 513, 501)

with torch.no_grad():
    film_output = Film1_1(random_value, random_embedding)

assert film_output.shape == random_value.shape
film_output.shape



In [None]:
class EncoderBlock(nn.Module):
    """Residual, FiLM-conditioned U-Net encoder stage followed by average pooling.

    ``forward`` runs ``BatchNorm -> FiLM -> LeakyReLU -> Conv3x3`` twice, adds a
    shortcut branch, and returns both the full-resolution result (for the
    decoder skip connection) and an average-pooled copy (for the next stage).
    The shortcut is a 1x1 convolution when the channel count changes and the
    identity otherwise.

    Args:
        input_channels: number of channels of ``input_tensor``.
        output_channels: number of channels produced by the block.
        embedding_size: size of the conditioning embedding given to the FiLM layers.
        momentum: momentum of the ``BatchNorm2d`` running statistics.
        downsample: kernel size (and stride) of the final ``F.avg_pool2d``.
    """

    def __init__(self, input_channels, output_channels, embedding_size, momentum, downsample):
        super().__init__()
        self.downsample = downsample
        self.Film1 = FilmModule(embedding_size, input_channels)
        self.Film2 = FilmModule(embedding_size, output_channels)

        self.bn1 = nn.BatchNorm2d(input_channels, momentum=momentum)

        # 3x3 / stride 1 / padding 1 keeps the spatial size, so the shortcut
        # can be added element-wise at the end of forward().
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(output_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_channels,
            out_channels=output_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        # True when the shortcut needs a 1x1 projection to match channel
        # counts; when False the shortcut is the identity.
        self.has_residual_connection = input_channels != output_channels
        if self.has_residual_connection:
            self.residual_convolution = nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )

        self.init_weights()

    def init_weights(self):
        """Reset BatchNorm layers to identity and Xavier-initialise the convolutions."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.has_residual_connection:
            init_layer(self.residual_convolution)

    def forward(self, input_tensor, embedding_vector):
        """Apply the residual block and pool its output.

        Args:
            input_tensor: 4-D feature map with ``input_channels`` channels
                (``BatchNorm2d`` requires ``(batch, C, H, W)``).
            embedding_vector: conditioning embedding passed to both FiLM layers.

        Returns:
            ``(x, x_pool)``: the block output (``output_channels`` channels, input
            resolution) and ``F.avg_pool2d(x, downsample)``.
        """
        x = self.bn1(input_tensor)
        x = self.Film1(x, embedding_vector)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = self.conv1(x)

        x = self.bn2(x)
        x = self.Film2(x, embedding_vector)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = self.conv2(x)

        if self.has_residual_connection:
            shortcut = self.residual_convolution(input_tensor)
        else:
            shortcut = input_tensor
        x = x + shortcut

        x_pool = F.avg_pool2d(x, self.downsample)
        return x, x_pool

In [None]:
class DecoderBlock(nn.Module):
    """FiLM-conditioned U-Net decoder stage: upsample, concatenate the skip tensor, then a residual conv block.

    ``forward`` performs:
      1. ``BatchNorm -> FiLM -> LeakyReLU -> ConvTranspose2d``, which maps
         ``input_size`` to ``output_size`` channels and scales the spatial
         dimensions by ``upsample``.
      2. Channel-wise concatenation with ``concat_tensor`` (the encoder skip),
         giving ``2 * output_size`` channels.
      3. A residual block of two ``BatchNorm -> FiLM -> LeakyReLU -> Conv3x3``
         stages mapping ``2 * output_size`` to ``output_size`` channels, whose
         shortcut is a 1x1 convolution of the concatenated tensor.

    Args:
        input_size: number of channels of ``input_tensor``.
        output_size: number of output channels; ``concat_tensor`` must also
            have this many channels.
        embedding_size: size of the conditioning embedding given to the FiLM layers.
        momentum: momentum of the ``BatchNorm2d`` running statistics.
        upsample: kernel size and stride of the transposed convolution.
    """

    def __init__(self, input_size, output_size, embedding_size, momentum, upsample):
        super().__init__()
        self.upsample = upsample

        # kernel_size == stride with no padding: each input pixel expands into
        # an `upsample`-sized patch, i.e. exact integer upsampling.
        self.conv1 = nn.ConvTranspose2d(
            in_channels=input_size,
            out_channels=output_size,
            kernel_size=self.upsample,
            stride=self.upsample,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(input_size, momentum=momentum)

        # One FiLM layer per BatchNorm: before upsampling, then the two stages
        # of the residual block that follows the skip concatenation.
        self.Film1 = FilmModule(embedding_size, input_size)
        self.Film2 = FilmModule(embedding_size, output_size * 2)
        self.Film3 = FilmModule(embedding_size, output_size)

        self.bn2 = nn.BatchNorm2d(output_size * 2, momentum=momentum)
        self.bn3 = nn.BatchNorm2d(output_size, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_size * 2,
            out_channels=output_size,
            kernel_size=(3, 3),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        self.conv3 = nn.Conv2d(
            in_channels=output_size,
            out_channels=output_size,
            kernel_size=(3, 3),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        # The residual block maps 2*output_size -> output_size channels at the
        # upsampled resolution, so its shortcut is always a 1x1 projection of
        # the concatenated tensor.
        self.residual_convolution = nn.Conv2d(
            in_channels=output_size * 2,
            out_channels=output_size,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
        )
        # Always True here; kept for parity with EncoderBlock's attribute.
        self.has_residual_connection = True

        self.init_weights()

    def init_weights(self):
        """Reset BatchNorm layers to identity and Xavier-initialise all convolutions."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)
        init_layer(self.residual_convolution)

    def forward(self, input_tensor, concat_tensor, embedding_vector):
        """Upsample ``input_tensor``, merge the skip tensor and apply the residual block.

        Args:
            input_tensor: 4-D feature map with ``input_size`` channels.
            concat_tensor: skip tensor with ``output_size`` channels and the same
                spatial size as the upsampled ``input_tensor`` (required by
                ``torch.cat`` along dim 1).
            embedding_vector: conditioning embedding passed to every FiLM layer.

        Returns:
            Feature map with ``output_size`` channels at the upsampled resolution.
        """
        x = self.bn1(input_tensor)
        x = self.Film1(x, embedding_vector)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = self.conv1(x)

        merged = torch.cat((x, concat_tensor), dim=1)

        x = self.bn2(merged)
        x = self.Film2(x, embedding_vector)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = self.conv2(x)

        x = self.bn3(x)
        x = self.Film3(x, embedding_vector)
        x = F.leaky_relu(x, negative_slope=0.01)
        x = self.conv3(x)

        return x + self.residual_convolution(merged)
        


        

In [None]:
class ResUnet(nn.Module):

    def __init__(self, input_size, output_size):
        """Build the ResUNet layers (work in progress: encoder/decoder stacks are not wired in yet).

        Args:
            input_size: number of channels of the network input, consumed by
                the 1x1 pre-convolution.
            output_size: number of output channels; ``after_conv`` emits
                ``output_size * 3`` channels (presumably three values per
                output channel, e.g. mask magnitude and phase terms; TODO confirm).
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.momentum = 0.01

        # Feature width between the pre- and post-convolutions; both 1x1
        # convs below must agree on it.
        base_channels = 32

        # TODO: instantiate the stack of EncoderBlock modules.
        # TODO: instantiate the stack of DecoderBlock modules.

        # NOTE(review): 513 is presumably the number of STFT frequency bins
        # (n_fft = 1024 -> 1024 // 2 + 1). It must match the axis this
        # BatchNorm2d normalises in forward(); TODO confirm.
        self.batch_norm0 = nn.BatchNorm2d(513, momentum=self.momentum)

        # 1x1 pre-convolution lifting the input to `base_channels` channels.
        self.preconvolution = nn.Conv2d(
            in_channels=input_size,
            out_channels=base_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        # 1x1 post-convolution mapping the feature maps to the final outputs.
        self.after_conv = nn.Conv2d(
            in_channels=base_channels,
            out_channels=output_size * 3,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )


    def forward(self,input):

