In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import STFT, ISTFT, magphase
import numpy as np

stft = STFT(n_fft=1024,
            hop_length=320,
            win_length=1024,
            window='hann',
            center=True,
            pad_mode='reflect',
            freeze_parameters=True)

def init_bn(bn):
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)

def init_layer(layer):
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)

def from_audio_to_spectogram(audios):
    magnitudes = []
    cosines = []
    sines = []

    for i in range(audios.shape[1]):

        (real,img) = stft(audios[:,i,:])
        mag = torch.clamp(real ** 2 + img ** 2, 1e-10, np.inf) ** 0.5
        cos = real / mag
        sin = img / mag
        magnitudes.append(real)
        cosines.append(cos)
        sines.append(sin)
    mags = torch.cat(magnitudes, dim=1)
    coss = torch.cat(cosines, dim=1)
    sins = torch.cat(sines, dim=1)
    
    return mags,coss,sins

In [5]:
class FilmModule(nn.Module):
    def __init__(self,input_size,output_size):
        super(FilmModule, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear = nn.Sequential(
            nn.Linear(input_size, output_size * 2),
            nn.ReLU(inplace=True),
            nn.Linear(output_size * 2, output_size),
            nn.ReLU(inplace=True)
        )
    
    def forward(self,data,embedding_vector):
        
        x = self.linear(embedding_vector)
        x = data + x[...,None,None,None]

        return x


Film1_1 = FilmModule(512,32)

random_embedding = torch.rand(1,512)
random_value = torch.rand(32,2,513,501)

print(Film1_1(random_value,random_embedding))



tensor([[[[[0.3891, 1.1309, 0.8346,  ..., 0.6606, 0.5293, 0.7677],
           [0.3518, 0.5991, 0.8349,  ..., 0.2228, 0.6167, 0.7026],
           [0.5854, 0.3541, 0.7800,  ..., 0.9493, 0.7858, 0.6160],
           ...,
           [0.4010, 0.5812, 0.5038,  ..., 1.0483, 0.3706, 0.4178],
           [0.4399, 0.3279, 0.2465,  ..., 0.2841, 0.8857, 0.2242],
           [0.3554, 0.2343, 0.2337,  ..., 0.5775, 0.8904, 0.4043]],

          [[0.5623, 0.7671, 0.3071,  ..., 0.9968, 0.3726, 0.8539],
           [0.1955, 1.0498, 1.1248,  ..., 0.8044, 0.6932, 0.2013],
           [1.1633, 0.2046, 0.3101,  ..., 1.0578, 0.7477, 0.3615],
           ...,
           [0.2425, 0.2156, 1.0004,  ..., 0.6145, 0.3223, 1.1251],
           [0.5598, 1.0637, 0.8213,  ..., 1.0693, 0.5720, 0.8641],
           [0.3681, 0.8517, 0.9446,  ..., 0.2835, 0.4556, 1.0525]]],


         [[[0.8250, 1.1154, 1.1134,  ..., 0.4161, 0.7373, 0.9631],
           [0.3882, 1.0031, 0.8650,  ..., 0.9688, 0.6531, 0.9990],
           [0.7946, 0.82

In [7]:
class EncoderBlock(nn.Module):
    def __init__(self,input_channels, output_channels, embedding_size, momentum,downsample):
        super(EncoderBlock, self).__init__()
        self.downsample = downsample
        self.Film1 = FilmModule(embedding_size,input_channels)
        self.Film2 = FilmModule(embedding_size,output_channels)

         
        self.bn1 = nn.BatchNorm2d(input_channels,momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
            )
        
        self.bn2 = nn.BatchNorm2d(output_channels,momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_channels,
            out_channels=output_channels,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        if input_channels != output_channels:
            self.residual_convolution = nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(1,1),
                stride=(1,1),
                padding=(0,0),
            )
            self.has_residual_connection = True
        else:
            self.has_residual_connection = False
        
        self.init_weights()
        
    
    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.has_residual_connection:
            init_layer(self.residual_convolution)
    
        

    def forward(self,input_tensor,embedding_vector):

        x = self.bn1(input_tensor)
        x = self.Film1(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.Film2(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv2(x)

        if self.has_residual_connection:
            y = self.residual_convolution(input_tensor)
            x = x + y

        x_pool = F.avg_pool2d(x,self.downsample)

        return x, x_pool    

In [8]:
class DecoderBlock(nn.Module):
    
    def __init__(self,input_size, output_size,embedding_size,momentum,upsample):
        super(DecoderBlock, self).__init__()
        self.upsample = upsample
        
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=input_size,
            out_channels=output_size,
            kernel_size=self.upsample,
            stride=self.upsample,
            padding=(0,0),
            bias=False,
            dilation=(1,1)
            
        )
        
        self.bn1 = nn.BatchNorm2d(input_size,momentum=momentum)
        
        #self.conv_block2 = ConvBlockRes(
        #    out_channels * 2, out_channels, kernel_size, momentum, has_film,
        
        self.Film1 = FilmModule(embedding_size,input_size)
        self.Film2 = FilmModule(embedding_size,output_size*2)
        self.Film3 = FilmModule(embedding_size,output_size)

        self.bn2 = nn.BatchNorm2d(output_size*2,momentum=momentum)
        self.bn3 = nn.BatchNorm2d(output_size,momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=output_size*2,
            out_channels=output_size,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        self.conv3 = nn.conv2d(
            in_channels=output_size,
            out_channels=output_size,
            kernel_size=(3,3),
            stride=(1,1),
            dilation=(1,1),
            padding=(1,1),
            bias=False
        )

        if input_size != output_size:
            self.residual_convolution = nn.Conv2d(
                in_channels=input_size,
                out_channels=output_size,
                kernel_size=(1,1),
                stride=(1,1),
                padding=(0,0),
            )
            self.has_residual_connection = True
        else:
            self.has_residual_connection = False
        
        self.bn4 = nn.BatchNorm2d(input_size,momentum=momentum)

        
        

        self.init_weights()
    
    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)
        
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)

        if self.has_residual_connection:
            init_layer(self.residual_convolution)

    def forward(self,input_tensor,concat_tensor,embedding_vector):
        x = self.bn1(input_tensor)
        x = self.Film1(x,embedding_vector)
        x = F.leaky_relu(x)

        x = self.conv1(x)

        x = torch.cat((x,concat_tensor), dim=1)

        x = self.bn2(x)
        x = self.Film2(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.Film3(x,embedding_vector)
        x = F.leaky_relu(x,negative_slope=0.01)
        x = self.conv3(x)

        if self.has_residual_connection:
            y = self.residual_convolution(input_tensor)
            x = x + y
        
        return x
        


        

In [9]:
class ResUnet(nn.Module):

    def __init__(self, input_size, output_size):
        super(ResUnet, self).__init__()

        self.input_size = input_size;
        self.output_size = output_size;

        self.momentum = 0.01 


        # instanziare la preconv che è una conv2d

        # definire la classe degli encoder block
        # definire la classe dei decoder block

        self.batch_norm0 = nn.BatchNorm2d(513,momentum=self.momentum)

        self.preconvolution = nn.Conv2d(
            input_channels=input_size,
            out_channels=32,
            kernel_size=(1,1),
            stride=(1,1),
            padding=(0,0),
            bias=True
        )

        self.EncoderBlock1 = EncoderBlock(
            input_channels=32,
            output_channels=32,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock2 = EncoderBlock(
            input_channels=32,
            output_channels=64,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock3 = EncoderBlock(
            input_channels=64,
            output_channels=128,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock4 = EncoderBlock(
            input_channels=128,
            output_channels=256,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock5 = EncoderBlock(
            input_channels=256,
            output_channels=384,
            downsample=(2,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock6 = EncoderBlock(
            input_channels=384,
            output_channels=384,
            downsample=(1,2),
            embedding_size=512,
            momentum=0.01
        )

        self.EncoderBlock7 = EncoderBlock(
            input_channels=384,
            output_channels=384,
            downsample=(1,1),
            momentum=0.01
        )

        self.DecoderBlock1 = DecoderBlock(
            input_size=384,
            output_size= 384,
            embedding_size= 512,
            momentum=0.01,
            upsample=(1,2)
            )
        
        self.DecoderBlock2 = DecoderBlock(
            input_size=384,
            output_size= 256,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )
        
        self.DecoderBlock3 = DecoderBlock(
            input_size=256,
            output_size= 128,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )
        
        self.DecoderBlock4 = DecoderBlock(
            input_size=128,
            output_size= 64,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )
        
        self.DecoderBlock5 = DecoderBlock(
            input_size=64,
            output_size= 32,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )
        
        self.DecoderBlock6 = DecoderBlock(
            input_size=32,
            output_size= 32,
            embedding_size= 512,
            momentum=0.01,
            upsample=(2,2)
            )

        
        self.after_conv = nn.Conv2d(
            in_channels=32,
            out_channels=output_size * 3,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()
    def init_weights(self):
        init_bn(self.batch_norm0)
        init_layer(self.preconvolution)
        init_layer(self.after_conv)

    def forward(self,input,embedding_vector):

        x = self.batch_norm0(input)
        #  eb1(x) = x1,x1_pool
        #  eb2(x1) = x2,x2_pool
        #  eb3(x2) = x3,x3_pool
        #  eb4(x3) = x4,x4_pool
        #  eb5(x4) = x5,x5_pool
        #  eb6(x5) = x6,x6_pool
        #  eb7(x6) = x7
        #  db1(x7,x6_pool) = x8
        #  db2(x8,x5_pool) = x9
        #  db3(x9,x4_pool) = x10
        #  db4(x10,x3_pool) = x11
        #  db5(x11,x2_pool) = x12
        #  db6(x12,x1_pool) = x13
        #  x = afterconv(x13) 




In [None]:
rete = ResUnet(2,2)