In [1]:
import numpy as np
import nibabel as nib
import torchvision.transforms as transforms
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch.utils.data import Dataset
import torch.distributions as dist
from pathlib import Path
from sklearn.model_selection import train_test_split

import math
from torch.autograd import grad

from glob import glob
from IPython.display import display
from PIL import Image
import matplotlib
import cv2
from random import *
import sklearn.metrics as sk
import matplotlib.pyplot as plt

import abc
from warnings import warn

import skimage.util as skutil
import skimage.transform as skt
import sklearn.preprocessing as skp

from torch.utils.data import random_split
from torch.utils.data import DataLoader 
from sklearn.preprocessing import MinMaxScaler

import os
import pandas as pd
from tqdm import tqdm
from copy import deepcopy





In [2]:
# Use the current CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def kl_divergence(mu, logsigma):
    """Compute KL divergence KL(q_i(z) || p(z)) for each q_i in the batch.

    Each q_i is the diagonal Gaussian N(mu_i, diag(sigma_i^2)) and p(z) is the
    standard normal N(0, I). The closed form used per sample is

        0.5 * sum_j (sigma_ij^2 + mu_ij^2 - 2 * log(sigma_ij) - 1)

    Args:
        mu: Means of the q_i distributions, shape [batch_size, latent_dim]
        logsigma: Logarithm of standard deviations of the q_i distributions,
                  shape [batch_size, latent_dim]

    Returns:
        kl: KL divergence for each of the q_i distributions, shape [batch_size]
    """
    # Stay in log-space: log(sigma**2) == 2 * logsigma exactly, and this avoids
    # log(0) = -inf when exp(logsigma) underflows.
    variance = torch.exp(2 * logsigma)
    # Reduce over the latent dimension only, so one value per sample remains as
    # documented (the previous version also summed over the batch, giving a scalar).
    return 0.5 * torch.sum(variance + mu ** 2 - 2 * logsigma - 1, dim=-1)

In [4]:
class brain_dataset(Dataset):
    """Slice-level dataset: one image file per row of a metadata DataFrame.

    ``__getitem__`` reads the row's columns positionally
    (0 = file name, 1 = status folder, 2 = MRI type, 3 = patient folder) and
    opens ``root_dir/<status>/<patient>/<mri_type>/<file name>``.
    """
    # Image sizes noted by the original author: [64,64],[80,80],[128,128]

    def __init__(self, csv_file, root_dir, transform=None):
        # Despite its name, ``csv_file`` is the already-built DataFrame, not a path.
        self.df = csv_file
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        table = self.df
        path_parts = (
            table.iloc[idx, 1],       # status folder (e.g. healthy)
            str(table.iloc[idx, 3]),  # patient folder
            str(table.iloc[idx, 2]),  # MRI type folder
            table.iloc[idx, 0],       # slice file name, e.g. 000.png
        )
        img_name = os.path.join(self.root_dir, *path_parts)

        picture = Image.open(img_name)
        return {'image': self.transform(picture) if self.transform else picture}

In [5]:
class brain_dataset_final(Dataset):
    """Slice-level dataset for the "final" directory layout.

    Columns are read positionally (0 = file name, 1 = status, 2 = MRI type,
    3 = patient folder) and the file opened is
    ``root_dir/<patient>/brainmetshare/metshare/train/<status>/id/<mri_type>/<file name>``.
    """
    # Image sizes noted by the original author: [64,64],[80,80],[128,128]

    def __init__(self, csv_file, root_dir, transform=None):
        # ``csv_file`` is the already-built DataFrame, not a path.
        self.df = csv_file
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        table = self.df
        patient = str(table.iloc[idx, 3])
        status_dir = table.iloc[idx, 1] + '/id'
        mri_type = str(table.iloc[idx, 2])
        file_name = table.iloc[idx, 0]
        img_name = os.path.join(self.root_dir, patient, 'brainmetshare/metshare/train',
                                status_dir, mri_type, file_name)

        picture = Image.open(img_name)
        return {'image': self.transform(picture) if self.transform else picture}

In [6]:
def _collect_slices(status_dir, status, mri_type):
    """Return one record per slice file in ``status_dir/<patient>/<mri_type>/``.

    Patients are visited in numeric order of the suffix after their last '_'
    (Mets_2 before Mets_10) and slices in numeric order of the file stem
    (2.png before 10.png). Patients without an ``mri_type`` folder are skipped.
    """
    records = []
    patients = sorted(next(os.walk(status_dir))[1], key=lambda name: int(name.split("_")[-1]))
    for patient_id_folder in patients:
        type_dir = os.path.join(status_dir, patient_id_folder, mri_type)
        if not os.path.isdir(type_dir):
            continue
        slices = sorted(next(os.walk(type_dir))[2], key=lambda name: int(name.split(".")[0]))
        for img in slices:
            records.append({'Name': img, 'status': status, 'MRI_type': str(mri_type),
                            'patient_id_folder': patient_id_folder})
    return records


def create_csv(path):
    """Build the slice-level metadata DataFrame for one split directory.

    Expected layout: ``path/<status>/<patient>/<mri_type>/<n>.png``.

    If ``path`` contains a ``disease`` folder, the rows are, in order:
      1. the disease MRI-type '2' slices;
      2. the disease 'seg' masks;
      3. the MRI-type '2' slices of the alphabetically first non-disease status folder, if any.

    Otherwise only the MRI-type '2' slices of the alphabetically first status
    folder are listed. A split with no status folders yields an empty DataFrame.

    Args:
        path: split directory (e.g. ``<root>/train``).

    Returns:
        DataFrame with columns Name, status, MRI_type, patient_id_folder.
    """
    # os.walk returns sub-directories in arbitrary file-system order, so the old
    # positional picks ([-1] for 'disease', [0] for the other folder) were not
    # reliable. Sort and look the folders up by name instead.
    statuses = sorted(next(os.walk(path))[1])
    data = []

    if 'disease' in statuses:
        disease_dir = os.path.join(path, 'disease')
        data += _collect_slices(disease_dir, 'disease', '2')
        data += _collect_slices(disease_dir, 'disease', 'seg')

    # Exclude 'disease' here so a disease-only split is not listed twice.
    others = [status for status in statuses if status != 'disease']
    if others:
        status = others[0]
        data += _collect_slices(os.path.join(path, status), status, '2')

    return pd.DataFrame(data)


In [7]:
def create_csv_final(path = '/kaggle/input/dataset-finale-finale/Dataset_Finale_finale',
                     exclude=('Mets_052',)):
    """Index the healthy MRI-type '2' slices of every patient in the "final" dataset.

    Each patient's slices are read from
    ``path/<patient>/brainmetshare/metshare/train/healthy/id/2/``.
    Patients are visited in numeric order of the suffix after their last '_'
    and slices in numeric order of the file stem.

    Args:
        path: dataset root containing one folder per patient.
        exclude: patient folder names to skip. A single string is treated as
            one name. Defaults to ('Mets_052',), the patient the original code
            skipped (the reason is not recorded here).

    Returns:
        DataFrame with columns Name, status ('healthy'), MRI_type ('2'),
        patient_id_folder.

    Raises:
        FileNotFoundError: if a non-excluded patient lacks the slice directory.
    """
    if isinstance(exclude, str):
        # Avoid a substring test when a single name is passed.
        exclude = (exclude,)

    data = []
    patients = sorted(next(os.walk(path))[1], key=lambda name: int(name.split("_")[-1]))
    for patient_id_folder in patients:
        if patient_id_folder in exclude:
            continue
        slice_dir = os.path.join(path, patient_id_folder, 'brainmetshare', 'metshare',
                                 'train', 'healthy', 'id', '2')
        # next(os.walk(missing_dir)) would raise a bare StopIteration; fail clearly instead.
        if not os.path.isdir(slice_dir):
            raise FileNotFoundError(
                'No slice directory for patient {}: {}'.format(patient_id_folder, slice_dir))
        for img in sorted(next(os.walk(slice_dir))[2], key=lambda name: int(name.split(".")[0])):
            data.append({'Name': img, 'status': 'healthy', 'MRI_type': str(2),
                         'patient_id_folder': patient_id_folder})

    return pd.DataFrame(data)

In [8]:
def create_dataloaders(path,batch_size,trasform,var):
    """Build datasets and DataLoaders for the splits stored under ``path``.

    Each split directory (``path/train``, ``path/val`` and, when ``var`` is
    truthy, ``path/test``) is indexed with ``create_csv``. Its length is
    printed, and it is wrapped in ``brain_dataset`` using the shared
    ``trasform``. The loaders use ``batch_size`` and otherwise default
    DataLoader settings (no shuffling).

    Returns:
        (trainloader, validationloader, testloader, brainTrain, brainVal, brainTest).
        Without a test split (``var`` falsy), the validation loader and dataset
        fill the test positions.
    """
    split_names = ['train', 'val', 'test'] if var else ['train', 'val']
    length_messages = {
        'train': 'Length of training set: ',
        'val': 'Length of validation set: ',
        'test': 'Length of test set: ',
    }

    frames = {}
    for split in split_names:
        frames[split] = create_csv(path + '/' + split)
        print(length_messages[split], len(frames[split]))

    datasets = {
        split: brain_dataset(csv_file=frames[split], root_dir=path + '/' + split, transform=trasform)
        for split in split_names
    }
    loaders = {split: DataLoader(datasets[split], batch_size) for split in split_names}

    held_out = 'test' if var else 'val'
    return (loaders['train'], loaders['val'], loaders[held_out],
            datasets['train'], datasets['val'], datasets[held_out])

In [9]:
# Autoencoder with fully connected (nn.Linear) layers; the original label said "convolutional", but no convolutions are used

class Encoder(nn.Module):
    """Fully connected encoder: flatten -> Linear(256*256, 512) -> ReLU -> Linear(512, latent_dims).

    Each input sample must flatten to 256*256 values (e.g. a 1x256x256 image).
    """

    def __init__(self, latent_dims):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they stay fixed.
        self.linear1 = nn.Linear(256*256, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        hidden = F.relu(self.linear1(x.flatten(start_dim=1)))
        return self.linear2(hidden)
    
class Decoder(nn.Module):
    """Fully connected decoder: latent code -> 512 ReLU units -> 256*256 sigmoid
    outputs, reshaped to (batch, 1, 256, 256)."""

    def __init__(self, latent_dims):
        super().__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 256*256)

    def forward(self, z):
        hidden = F.relu(self.linear1(z))
        pixels = torch.sigmoid(self.linear2(hidden))
        return pixels.reshape((-1, 1, 256, 256))
    
class Autoencoder(nn.Module):
    """Deterministic autoencoder: ``Encoder`` followed by ``Decoder``, both sized by ``latent_dims``."""

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        return self.decoder(self.encoder(x))
    
    
def train(autoencoder, trainloader, trainset, validationloader, valset, epochs,
          lr=0.001, patience_early_stopping=150, patience_plateau=100,
          display_every=30, sample_index=0):
    """Train a deterministic autoencoder on a summed squared-error loss.

    Optimises with Adam and lowers the learning rate with ReduceLROnPlateau on
    the validation loss. Training stops early after ``patience_early_stopping``
    epochs without improvement, and the function returns a deep copy of the
    model from the best epoch.

    Args:
        autoencoder: model mapping an image batch to a reconstruction of the same
            shape; presumably already moved to ``device`` by the caller.
        trainloader, validationloader: DataLoaders yielding dicts with an
            ``'image'`` tensor.
        trainset, valset: the datasets behind the loaders. They are kept for
            backward compatibility; progress-bar lengths now use ``len(loader)``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        patience_early_stopping: non-improving epochs tolerated before stopping.
        patience_plateau: ReduceLROnPlateau patience.
        display_every: show a real/reconstructed/difference triplet every this
            many epochs (0 or None disables it).
        sample_index: index, within the first validation batch, of the shown
            sample (clamped to the batch size).

    Returns:
        Deep copy of the best model (captured while in eval mode), or
        ``autoencoder`` itself if no epoch improved the validation loss.

    Raises:
        ValueError: if ``validationloader`` yields no batches.
    """
    def to_uint8(arr):
        # Map [0, 1] intensities to 8-bit; clipping prevents uint8 wrap-around.
        return np.uint8(np.clip(arr * 255, 0, 255))

    def show_triplet(orig, recon):
        print('Real input')
        display(Image.fromarray(to_uint8(orig)))
        print('Reconstructed input')
        display(Image.fromarray(to_uint8(recon)))
        print('Difference')
        display(Image.fromarray(to_uint8(np.absolute(orig - recon))))

    opt = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
        factor=0.1, patience=patience_plateau, threshold=0.0001, threshold_mode='abs', verbose = True)

    best_val_loss = float('inf')
    best_epoch = 0
    best_model = None
    best_sample = None
    counter_early = 0

    for epoch in range(epochs):
        autoencoder.train()
        running_loss = 0.0
        n_batches = 0  # reset every epoch so train_loss is a per-epoch average
        for batch in tqdm(trainloader, total=len(trainloader)):
            x = batch['image'].to(device)
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat) ** 2).sum()
            loss.backward()
            opt.step()
            running_loss += loss.item()
            n_batches += 1
        train_loss = running_loss / max(n_batches, 1)

        # Evaluate in eval mode (the original left the model in train mode here).
        autoencoder.eval()
        running_loss_val = 0.0
        n_val = 0
        sample = None
        with torch.no_grad():
            for batch in tqdm(validationloader, total=len(validationloader)):
                x = batch['image'].to(device)
                x_hat = autoencoder(x)
                loss = ((x - x_hat) ** 2).sum()
                running_loss_val += loss.item()
                n_val += 1
                if sample is None:
                    # Capture every epoch so the "best" images really come from the best epoch.
                    k = min(sample_index, x.shape[0] - 1)
                    sample = (np.squeeze(x[k].cpu().numpy()), np.squeeze(x_hat[k].cpu().numpy()))

        if n_val == 0:
            raise ValueError('validationloader yielded no batches')
        val_loss = running_loss_val / n_val
        scheduler.step(val_loss)

        if display_every and epoch % display_every == 0:
            show_triplet(*sample)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_model = deepcopy(autoencoder)
            best_sample = sample
            counter_early = 0
        else:
            counter_early += 1
            if counter_early > patience_early_stopping:
                break

        print('LR:', opt.param_groups[0]['lr'])
        print('Best val loss:', best_val_loss, ' Epoch:', best_epoch, ' Counter:', counter_early)
        print('epoch:{} \t'.format(epoch + 1), 'trainloss:{}'.format(train_loss), '\t', 'valloss:{}'.format(val_loss))

    if best_model is None:
        return autoencoder

    show_triplet(*best_sample)
    return best_model

# NOTE(review): the code below is disabled by being wrapped in a bare string
# literal, so this cell only defines the model classes and ``train``.
# Un-quoting it alone would still not train anything, because the ``train(...)``
# call inside is itself commented out. It would save freshly initialised
# weights to /kaggle/working/Model/best_model_AE.
'''latent_dims = 256
autoencoder = Autoencoder(latent_dims).to(device) # GPU

#autoencoder = train(autoencoder, trainloader, trainset, validationloader, valset, 1000)

model_path = '/kaggle/working/Model'
cartellaDaVerificare= Path(model_path)
if not cartellaDaVerificare.is_dir():
    os.mkdir(model_path)
torch.save(autoencoder.state_dict(), model_path + '/' + 'best_model_AE')'''

"latent_dims = 256\nautoencoder = Autoencoder(latent_dims).to(device) # GPU\n\n#autoencoder = train(autoencoder, trainloader, trainset, validationloader, valset, 1000)\n\nmodel_path = '/kaggle/working/Model'\ncartellaDaVerificare= Path(model_path)\nif not cartellaDaVerificare.is_dir():\n    os.mkdir(model_path)\ntorch.save(autoencoder.state_dict(), model_path + '/' + 'best_model_AE')"

In [10]:
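# Variational autoencoder (VAE) with fully connected (nn.Linear) layers; the original label said "convolutional", but no convolutions are used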

def weights_init(m):
    """Re-initialise the weight of an ``nn.Linear`` with Kaiming-uniform init.

    Meant for ``module.apply(weights_init)``. Other module types and all biases
    are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.kaiming_uniform_(m.weight)


class Decoder(nn.Module):
    """Fully connected decoder: latent code -> 512 ReLU units -> 256*256 sigmoid
    outputs, reshaped to (batch, 1, 256, 256).

    NOTE(review): this redefines an identical ``Decoder`` from an earlier cell;
    one copy could be removed.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 256*256)

    def forward(self, z):
        out = torch.sigmoid(self.linear2(F.relu(self.linear1(z))))
        return out.reshape((-1, 1, 256, 256))

class VariationalEncoder(nn.Module):
    """Fully connected VAE encoder using the reparameterisation trick.

    ``forward`` returns a sampled code ``z = mu + sigma * eps`` with
    ``eps ~ N(0, I)``. It also stores in ``self.kl`` the KL divergence between
    q(z|x) = N(mu, sigma^2) and the standard normal prior, summed over the batch
    and the latent dimensions.

    Args:
        latent_dims: size of the latent code.
        in_features: flattened size of one input sample (default 256*256).
        hidden_dims: width of the hidden layer (default 512).
    """

    def __init__(self, latent_dims, in_features=256*256, hidden_dims=512):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(in_features, hidden_dims)
        self.linear2 = nn.Linear(hidden_dims, latent_dims)  # mean head
        self.linear3 = nn.Linear(hidden_dims, latent_dims)  # log-sigma head
        # Same effect as ``self.linear3.apply(weights_init)``: Kaiming-uniform weights.
        torch.nn.init.kaiming_uniform_(self.linear3.weight)

        # Kept for backward compatibility only. Sampling below uses randn_like,
        # which follows mu's device; the old hard-coded ``.cuda()`` crashed on
        # machines without a GPU.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        logsigma = self.linear3(x)
        sigma = torch.exp(logsigma)
        z = mu + sigma * torch.randn_like(mu)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 2*log(sigma) - 1).
        # The previous expression (sigma^2 + mu^2 - log(sigma) - 1/2) is not the KL
        # and is minimised at sigma = 1/sqrt(2) instead of 1. Using logsigma directly
        # also avoids log(0) when exp() underflows.
        self.kl = 0.5 * (sigma ** 2 + mu ** 2 - 2 * logsigma - 1).sum()
        return z
    
    
class VariationalAutoencoder(nn.Module):
    """VAE: ``VariationalEncoder`` sampling a code that ``Decoder`` maps back to image space.

    After each forward pass the KL term is available as ``self.encoder.kl``.
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code)
    
    
def train(autoencoder, trainloader, trainset, validationloader, valset, epochs,
          lr=0.0001, patience_early_stopping=20, patience_plateau=10,
          display_every=5, sample_index=50):
    """Train a VAE on summed squared error plus the encoder's KL term.

    The model must expose ``autoencoder.encoder.kl``, refreshed on every forward
    pass. Optimises with Adam and lowers the learning rate with
    ReduceLROnPlateau on the validation loss. Training stops early after
    ``patience_early_stopping`` epochs without improvement, and the function
    returns a deep copy of the model from the best epoch.

    Args:
        autoencoder: VAE mapping an image batch to a reconstruction of the same
            shape; presumably already moved to ``device`` by the caller.
        trainloader, validationloader: DataLoaders yielding dicts with an
            ``'image'`` tensor.
        trainset, valset: the datasets behind the loaders. They are kept for
            backward compatibility; progress-bar lengths now use ``len(loader)``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        patience_early_stopping: non-improving epochs tolerated before stopping.
        patience_plateau: ReduceLROnPlateau patience.
        display_every: show a real/reconstructed/difference triplet every this
            many epochs (0 or None disables it).
        sample_index: index, within the first validation batch, of the shown
            sample. It is clamped to the batch size; the old fixed ``x[50]``
            raised IndexError for batches of fewer than 51 samples.

    Returns:
        Deep copy of the best model (captured while in eval mode), or
        ``autoencoder`` itself if no epoch improved the validation loss.

    Raises:
        ValueError: if ``validationloader`` yields no batches.
    """
    def to_uint8(arr):
        # Map [0, 1] intensities to 8-bit; clipping prevents uint8 wrap-around.
        return np.uint8(np.clip(arr * 255, 0, 255))

    def show_triplet(orig, recon):
        print('Real input')
        display(Image.fromarray(to_uint8(orig)))
        print('Reconstructed input')
        display(Image.fromarray(to_uint8(recon)))
        print('Difference')
        display(Image.fromarray(to_uint8(np.absolute(orig - recon))))

    def vae_loss(x, x_hat):
        # Summed reconstruction error plus the KL term stored by the encoder's forward().
        return ((x - x_hat) ** 2).sum() + autoencoder.encoder.kl

    opt = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
        factor=0.1, patience=patience_plateau, threshold=0.0001, threshold_mode='abs', verbose = True)

    best_val_loss = float('inf')
    best_epoch = 0
    best_model = None
    best_sample = None
    counter_early = 0

    for epoch in range(epochs):
        autoencoder.train()
        running_loss = 0.0
        n_batches = 0  # reset every epoch so train_loss is a per-epoch average
        for batch in tqdm(trainloader, total=len(trainloader)):
            x = batch['image'].to(device)
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = vae_loss(x, x_hat)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            n_batches += 1
        train_loss = running_loss / max(n_batches, 1)

        # Evaluate in eval mode (the original left the model in train mode here).
        autoencoder.eval()
        running_loss_val = 0.0
        n_val = 0
        sample = None
        with torch.no_grad():
            for batch in tqdm(validationloader, total=len(validationloader)):
                x = batch['image'].to(device)
                x_hat = autoencoder(x)
                running_loss_val += vae_loss(x, x_hat).item()
                n_val += 1
                if sample is None:
                    # Capture every epoch so the "best" images really come from the best epoch.
                    k = min(sample_index, x.shape[0] - 1)
                    sample = (np.squeeze(x[k].cpu().numpy()), np.squeeze(x_hat[k].cpu().numpy()))

        if n_val == 0:
            raise ValueError('validationloader yielded no batches')
        val_loss = running_loss_val / n_val
        scheduler.step(val_loss)

        if display_every and epoch % display_every == 0:
            show_triplet(*sample)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_model = deepcopy(autoencoder)
            best_sample = sample
            counter_early = 0
        else:
            counter_early += 1
            if counter_early > patience_early_stopping:
                break

        print('LR:', opt.param_groups[0]['lr'])
        print('Best val loss:', best_val_loss, ' Epoch:', best_epoch, ' Counter:', counter_early)
        print('epoch:{} \t'.format(epoch + 1), 'trainloss:{}'.format(train_loss), '\t', 'valloss:{}'.format(val_loss))

    if best_model is None:
        return autoencoder

    show_triplet(*best_sample)
    return best_model

# NOTE(review): the code below is disabled by being wrapped in a bare string
# literal, so this cell only defines the VAE classes and ``train``.
# Un-quoting it alone would still not train anything, because the ``train(...)``
# call inside is itself commented out. It would save freshly initialised
# weights to /kaggle/working/Model/best_model_VAE.
'''latent_dims = 256
autoencoder = VariationalAutoencoder(latent_dims).to(device) # GPU
autoencoder.apply(weights_init)

#autoencoder = train(autoencoder, trainloader, trainset, validationloader, valset, 100)


model_path = '/kaggle/working/Model'

cartellaDaVerificare= Path(model_path)
if not cartellaDaVerificare.is_dir():
    os.mkdir(model_path)
    
torch.save(autoencoder.state_dict(), model_path + '/' + 'best_model_VAE')'''

"latent_dims = 256\nautoencoder = VariationalAutoencoder(latent_dims).to(device) # GPU\nautoencoder.apply(weights_init)\n\n#autoencoder = train(autoencoder, trainloader, trainset, validationloader, valset, 100)\n\n\nmodel_path = '/kaggle/working/Model'\n\ncartellaDaVerificare= Path(model_path)\nif not cartellaDaVerificare.is_dir():\n    os.mkdir(model_path)\n    \ntorch.save(autoencoder.state_dict(), model_path + '/' + 'best_model_VAE')"

In [11]:
# Model from paper 2
# MODEL
class NoOp(nn.Module):
    """Identity module whose ``forward`` returns its first argument unchanged.

    Any extra positional or keyword arguments, at construction or call time,
    are accepted and ignored. This lets it stand in for any configurable op
    that should be disabled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class ConvModule(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """Single conv -> normalization -> activation unit.

        Computes ``activation(normalization(conv(x)))``. A stage is skipped when
        its op is ``None`` or given as a string.

        Args:
            in_channels (int): number of input channels / feature maps.
            out_channels (int): number of output channels / feature maps; also the
                first argument passed to ``normalization_op``.
            conv_op: class called as ``conv_op(in_channels, out_channels, **conv_params)``.
                Defaults to nn.Conv2d.
            conv_params (dict, optional): keyword arguments for ``conv_op``. Defaults to {}.
            normalization_op: normalization class (e.g. BatchNorm, InstanceNorm), called as
                ``normalization_op(out_channels, **normalization_params)``. Defaults to None.
            normalization_params (dict, optional): keyword arguments for ``normalization_op``.
                Defaults to {}.
            activation_op: non-linearity class (e.g. ReLU, Sigmoid), called as
                ``activation_op(**activation_params)``. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): keyword arguments for ``activation_op``.
                Defaults to {}.
        """
        super(ConvModule, self).__init__()

        def enabled(op):
            # Build a stage only for a real op, not for None or a placeholder string.
            return op is not None and not isinstance(op, str)

        self.conv_params = {} if conv_params is None else conv_params
        self.activation_params = {} if activation_params is None else activation_params
        self.normalization_params = {} if normalization_params is None else normalization_params

        # Built in conv -> normalization -> activation order, so parameter
        # initialisation consumes the RNG in the same sequence and the
        # state_dict keys ('conv.*', 'normalization.*') stay the same.
        self.conv = conv_op(in_channels, out_channels, **self.conv_params) if enabled(conv_op) else None
        self.normalization = (
            normalization_op(out_channels, **self.normalization_params)
            if enabled(normalization_op)
            else None
        )
        self.activation = activation_op(**self.activation_params) if enabled(activation_op) else None

    def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None):
        """Apply the enabled stages in order.

        Each ``*_add_input`` dict, when given, is unpacked as extra keyword
        arguments for the matching stage.
        """
        out = input
        for stage, extra_kwargs in (
            (self.conv, conv_add_input),
            (self.normalization, normalization_add_input),
            (self.activation, activation_add_input),
        ):
            if stage is None:
                continue
            out = stage(out) if extra_kwargs is None else stage(out, **extra_kwargs)
        return out


class ConvBlock(nn.Module):
    def __init__(
        self,
        n_convs: int,
        n_featmaps: int,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=nn.BatchNorm2d,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """Stack of ``n_convs`` ConvModules that all keep ``n_featmaps`` channels.

        Args:
            n_convs (int): number of ConvModules in the stack.
            n_featmaps (int): input and output channels of every ConvModule.
            conv_op: convolution class, see ConvModule. Defaults to nn.Conv2d.
            conv_params (dict, optional): keyword arguments for the convolution. Defaults to None.
            normalization_op: normalization class, see ConvModule. Defaults to nn.BatchNorm2d.
            normalization_params (dict, optional): keyword arguments for the normalization.
                Defaults to None.
            activation_op: activation class, see ConvModule. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): keyword arguments for the activation.
                Defaults to None.
        """
        super(ConvBlock, self).__init__()

        self.n_featmaps = n_featmaps
        self.n_convs = n_convs
        self.conv_params = {} if conv_params is None else conv_params

        # Each ConvModule receives the raw ``conv_params`` (possibly None) and
        # does its own None handling.
        layer_kwargs = dict(
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
        )
        self.conv_list = nn.ModuleList(
            [ConvModule(n_featmaps, n_featmaps, **layer_kwargs) for _ in range(self.n_convs)]
        )

    def forward(self, input, **frwd_params):
        """Run the ConvModules sequentially; ``frwd_params`` is accepted but unused."""
        out = input
        for layer in self.conv_list:
            out = layer(out)
        return out


class ResBlock(nn.Module):
    def __init__(
        self,
        n_convs,
        n_featmaps,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=nn.BatchNorm2d,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """ConvBlock with a residual (skip) connection.

            x = conv_block(input)
            out = x + input

        Args:
            n_convs (int): number of convolutions in the conv block.
            n_featmaps (int): channels kept by every convolution; the input must
                have this many channels for the residual sum to work.
            conv_op: convolution class, see ConvModule. Defaults to nn.Conv2d.
            conv_params (dict, optional): keyword arguments for the convolution. Defaults to None.
            normalization_op: normalization class, see ConvModule. Defaults to nn.BatchNorm2d.
            normalization_params (dict, optional): keyword arguments for the normalization.
                Defaults to None.
            activation_op: activation class, see ConvModule. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): keyword arguments for the activation.
                Defaults to None.
        """
        super(ResBlock, self).__init__()

        self.n_featmaps = n_featmaps
        self.n_convs = n_convs
        self.conv_params = conv_params
        if self.conv_params is None:
            self.conv_params = {}

        # Pass by keyword. ConvBlock's positional order is (n_convs, n_featmaps),
        # and the previous positional call swapped them: it built ``n_featmaps``
        # convolutions of ``n_convs`` channels, which only worked when the two
        # values were equal.
        self.conv_block = ConvBlock(
            n_convs=n_convs,
            n_featmaps=n_featmaps,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
        )

    def forward(self, input, **frwd_params):
        """Return ``conv_block(input) + input``; ``frwd_params`` is accepted but unused."""
        x = self.conv_block(input)
        return x + input


# Basic Generator
class BasicGenerator(nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=256,
        fmap_sizes=(256, 128, 64),
        upsample_op=nn.ConvTranspose2d,
        conv_params=None,
        normalization_op=NoOp,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
        block_op=NoOp,
        block_params=None,
        to_1x1=True,
    ):
        """Basic configurable Generator / Decoder.

        Built from one level per entry of ``fmap_sizes``:
        - a start ConvModule (stride 1, no padding) maps the latent with ``z_dim`` channels to
          ``fmap_sizes[0]`` feature maps;
        - for every further entry, a ``block_op`` is followed by an up-sampling ConvModule;
        - an end ConvModule up-samples to ``input_size[0]`` channels, without normalization or activation.

        Args:
            input_size ((int, int, int)): Size of the generated output in format C x H x W.
            z_dim (int, optional): Number of channels of the latent input. If None, the start block is a
                NoOp and the input is fed directly to the first level (``fmap_sizes[0]`` channels).
                Defaults to 256.
            fmap_sizes (list/tuple of int, optional): Number of feature maps per level, from the latent side
                outwards; must contain at least two entries. Defaults to (256, 128, 64).
            upsample_op (torch.nn.Module, optional): Operation used by all ConvModules (start, levels, end).
                Defaults to nn.ConvTranspose2d.
            conv_params (dict, optional): Init parameters of the level and end up-sampling ops.
                Defaults to dict(kernel_size=4, stride=2, padding=1, bias=False).
            normalization_op (torch.nn.Module, optional): Normalization operation (e.g. BatchNorm,
                InstanceNorm, ...) -> see ConvModule. Defaults to NoOp.
            normalization_params (dict, optional): Init parameters for the normalization of the start block;
                the level blocks receive an empty dict. Defaults to None.
            activation_op (torch.nn.Module, optional): Activation / non-linearity (e.g. ReLU, Sigmoid, ...)
                -> see ConvModule. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the activation. Defaults to None.
            block_op (torch.nn.Module, optional): Block inserted before each level's up-sampling op
                (e.g. ConvBlock / ResBlock); None is treated as NoOp. Defaults to NoOp.
            block_params (dict, optional): Init parameters for ``block_op``. Defaults to None.
            to_1x1 (bool, optional): If True, the latent is z_dim x 1 x 1 and the start kernel equals
                ``input_size[1:] // 2 ** len(fmap_sizes)``. If False, the latent keeps a spatial
                resolution and the start kernel is ``min(conv_params["kernel_size"], size)`` per
                dimension. Defaults to True.

        Raises:
            AttributeError: if ``fmap_sizes`` is not a list/tuple or has fewer than two entries. Also raised
                when ``z_dim`` is not None and an image dimension would drop below 2 at the lowest level.
        """

        super(BasicGenerator, self).__init__()

        if conv_params is None:
            conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False)
        if block_op is None:
            block_op = NoOp
        if block_params is None:
            block_params = {}

        n_channels = input_size[0]
        input_size_ = np.array(input_size[1:])

        if not isinstance(fmap_sizes, (list, tuple)):
            raise AttributeError("fmap_sizes has to be either a list or tuple")
        if len(fmap_sizes) < 2:
            raise AttributeError("fmap_sizes has to contain at least two elements")
        h_size_bot = fmap_sizes[0]

        # Spatial size of the lowest-resolution level. The len(fmap_sizes) - 1 level up-samplings plus
        # the end block each double the resolution (with the default stride-2 conv_params).
        input_size_new = input_size_ // (2 ** len(fmap_sizes))
        if np.min(input_size_new) < 2 and z_dim is not None:
            raise AttributeError("fmap_sizes to long, one image dimension has already perished")

        ### Start block: latent (z_dim channels) -> fmap_sizes[0] feature maps, stride 1, no padding.
        if not to_1x1:
            kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new]
        else:
            # With a transposed-conv upsample_op, a 1x1 latent is expanded to a map of kernel size.
            kernel_size_start = input_size_new.tolist()

        if z_dim is not None:
            self.start = ConvModule(
                z_dim,
                h_size_bot,
                conv_op=upsample_op,
                conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False),
                normalization_op=normalization_op,
                normalization_params=normalization_params,
                activation_op=activation_op,
                activation_params=activation_params,
            )
        else:
            self.start = NoOp()

        ### Middle blocks: block_op + one up-sampling ConvModule per remaining entry of fmap_sizes.
        self.middle_blocks = nn.ModuleList()

        for h_size_top in fmap_sizes[1:]:
            self.middle_blocks.append(block_op(h_size_bot, **block_params))
            self.middle_blocks.append(
                ConvModule(
                    h_size_bot,
                    h_size_top,
                    conv_op=upsample_op,
                    conv_params=conv_params,
                    normalization_op=normalization_op,
                    # NOTE(review): only the start block receives normalization_params; the level
                    # blocks always get an empty dict. Kept as-is — confirm this is intended.
                    normalization_params={},
                    activation_op=activation_op,
                    activation_params=activation_params,
                )
            )
            h_size_bot = h_size_top

        ### End block: up-sample to the output channel count, without normalization or activation.
        self.end = ConvModule(
            h_size_bot,
            n_channels,
            conv_op=upsample_op,
            conv_params=conv_params,
            normalization_op=None,
            activation_op=None,
        )

    def forward(self, inpt, **kwargs):
        """Run the start block, the middle blocks and the end block in sequence; kwargs go to each."""
        output = self.start(inpt, **kwargs)
        for middle in self.middle_blocks:
            output = middle(output, **kwargs)
        output = self.end(output, **kwargs)
        return output


# Basic Encoder
class BasicEncoder(nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=256,
        fmap_sizes=(64, 128, 256),
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=NoOp,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
        block_op=NoOp,
        block_params=None,
        to_1x1=True,
    ):
        """Basic configurable Encoder.

        The encoder is assembled from three parts:
        - a start ConvModule maps the input to ``fmap_sizes[0]`` feature maps;
        - for every further entry of ``fmap_sizes``, a ``block_op`` is followed by a down-sampling ConvModule;
        - an end ConvModule (stride 1, no padding, no normalization / activation) maps to ``z_dim`` channels.

        Args:
            input_size ((int, int, int)): Size of the input in format C x H x W.
            z_dim (int, optional): Number of output (latent) channels. If None, the end block is a NoOp and
                ``output_size`` is only the tracked spatial size. Defaults to 256.
            fmap_sizes (list/tuple of int, optional): Number of feature maps per down-sampling level.
                Defaults to (64, 128, 256).
            conv_op (torch.nn.Module, optional): Convolution used by all ConvModules. Defaults to nn.Conv2d.
            conv_params (dict, optional): Init parameters of the start and level convs.
                Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False).
            normalization_op (torch.nn.Module, optional): Normalization operation (e.g. BatchNorm,
                InstanceNorm, ...) -> see ConvModule. Defaults to NoOp.
            normalization_params (dict, optional): Currently not forwarded: the start and level blocks
                receive an empty dict. Defaults to None.
            activation_op (torch.nn.Module, optional): Activation / non-linearity (e.g. ReLU, Sigmoid, ...)
                -> see ConvModule. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the activation. Defaults to None.
            block_op (torch.nn.Module, optional): Block inserted before each level's down-sampling conv
                (e.g. ConvBlock / ResBlock); None is treated as NoOp. Defaults to NoOp.
            block_params (dict, optional): Init parameters for ``block_op``. Defaults to None.
            to_1x1 (bool, optional): If True, the end conv uses a kernel the size of the last feature map,
                giving a z_dim x 1 x 1 output (similar to a fully connected layer). If False, the kernel is
                ``min(conv_params["kernel_size"], size)`` per dimension and the output keeps a spatial
                resolution. Defaults to True.

        Raises:
            AttributeError: if ``fmap_sizes`` is not a list/tuple. Also raised when ``z_dim`` is not None and
                an image dimension drops below 2 after a level's down-sampling.
        """
        super(BasicEncoder, self).__init__()

        if conv_params is None:
            conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False)
        if block_op is None:
            block_op = NoOp
        if block_params is None:
            block_params = {}

        n_channels = input_size[0]
        # Tracked spatial size. Each start/level conv is assumed to halve it, which holds for the
        # default stride-2 conv_params.
        input_size_new = np.array(input_size[1:])

        if not isinstance(fmap_sizes, (list, tuple)):
            raise AttributeError("fmap_sizes has to be either a list or tuple")
        h_size_bot = fmap_sizes[0]

        ### Start block: input channels -> fmap_sizes[0] feature maps.
        self.start = ConvModule(
            n_channels,
            h_size_bot,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            # NOTE(review): normalization_params is never forwarded in this encoder — confirm intended.
            normalization_params={},
            activation_op=activation_op,
            activation_params=activation_params,
        )
        input_size_new = input_size_new // 2

        ### Middle blocks: block_op + one down-sampling ConvModule per remaining entry of fmap_sizes.
        self.middle_blocks = nn.ModuleList()

        for h_size_top in fmap_sizes[1:]:
            self.middle_blocks.append(block_op(h_size_bot, **block_params))
            self.middle_blocks.append(
                ConvModule(
                    h_size_bot,
                    h_size_top,
                    conv_op=conv_op,
                    conv_params=conv_params,
                    normalization_op=normalization_op,
                    normalization_params={},
                    activation_op=activation_op,
                    activation_params=activation_params,
                )
            )

            h_size_bot = h_size_top
            input_size_new = input_size_new // 2

            if np.min(input_size_new) < 2 and z_dim is not None:
                raise AttributeError("fmap_sizes to long, one image dimension has already perished")

        ### End block: last feature maps -> z_dim channels (stride 1, no padding).
        if not to_1x1:
            kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new]
        else:
            kernel_size_end = input_size_new.tolist()

        if z_dim is not None:
            self.end = ConvModule(
                h_size_bot,
                z_dim,
                conv_op=conv_op,
                conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False),
                normalization_op=None,
                activation_op=None,
            )

            if to_1x1:
                self.output_size = (z_dim, 1, 1)
            else:
                # Unpadded stride-1 conv: each spatial dimension shrinks by kernel_size - 1.
                self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)])
        else:
            self.end = NoOp()
            # NOTE(review): this is only the spatial size (a numpy array without a channel entry),
            # unlike the (C, H, W) tuples above.
            self.output_size = input_size_new

    def forward(self, inpt, **kwargs):
        """Run the start block, the middle blocks and the end block in sequence; kwargs go to each."""
        output = self.start(inpt, **kwargs)
        for middle in self.middle_blocks:
            output = middle(output, **kwargs)
        output = self.end(output, **kwargs)
        return output

class VAE(torch.nn.Module):
    """Convolutional variational auto-encoder made of a BasicEncoder and a BasicGenerator.

    The encoder produces ``2 * z_dim`` channels. These are split along dim 1 into the mean and the log
    standard deviation of a Normal posterior. The generator decodes a ``z_dim``-channel latent back to
    ``input_size``.

    Args:
        input_size (sequence of int): Input size as C x H x W.
        z_dim (int): Number of latent channels.
        fmap_sizes (tuple of int): Encoder feature-map sizes; the generator uses them in reverse order.
        to_1x1 (bool): Forwarded to encoder and generator.
        conv_op, conv_params: Convolution and its parameters for the encoder.
        tconv_op, tconv_params: Up-sampling operation and its parameters for the generator.
        normalization_op, normalization_params, activation_op, activation_params, block_op, block_params:
            Forwarded unchanged to both encoder and generator.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        input_size,
        z_dim=256,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        super(VAE, self).__init__()

        # Options that encoder and generator share verbatim.
        shared_opts = dict(
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        # Mean and log-std are stacked along the channel axis, hence 2 * z_dim encoder outputs.
        self.enc = BasicEncoder(
            input_size=list(input_size),
            fmap_sizes=fmap_sizes,
            z_dim=z_dim * 2,
            conv_op=conv_op,
            conv_params=conv_params,
            **shared_opts
        )
        self.dec = BasicGenerator(
            input_size=list(input_size),
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            **shared_opts
        )

        self.hidden_size = self.enc.output_size

    @staticmethod
    def _split_posterior(enc_out):
        """Split encoder output along dim 1 into (mu, std) with std = exp(log_std)."""
        mu, log_std = torch.chunk(enc_out, 2, dim=1)
        return mu, torch.exp(log_std)

    def forward(self, inpt, sample=True, no_dist=False, **kwargs):
        """Encode, sample (or take the mean of) the posterior, and decode.

        Returns ``x_rec`` if ``no_dist`` else ``(x_rec, z_dist)``, where ``z_dist`` is the Normal posterior.
        """
        mu, std = self._split_posterior(self.enc(inpt, **kwargs))
        z_dist = dist.Normal(mu, std)
        # rsample keeps the draw differentiable w.r.t. mu and std (reparameterisation).
        z_sample = z_dist.rsample() if sample else mu

        x_rec = self.dec(z_sample)

        if no_dist:
            return x_rec
        return x_rec, z_dist

    def encode(self, inpt, **kwargs):
        """Return the posterior parameters ``(mu, std)`` for ``inpt``."""
        return self._split_posterior(self.enc(inpt, **kwargs))

    def decode(self, inpt, **kwargs):
        """Decode a latent tensor with the generator."""
        return self.dec(inpt, **kwargs)

class AE(torch.nn.Module):
    """Deterministic convolutional auto-encoder made of a BasicEncoder and a BasicGenerator.

    Args:
        input_size (sequence of int): Input size as C x H x W.
        z_dim (int): Number of latent channels produced by the encoder and consumed by the generator.
        fmap_sizes (tuple of int): Encoder feature-map sizes; the generator uses them in reverse order.
        to_1x1 (bool): Forwarded to encoder and generator.
        conv_op, conv_params: Convolution and its parameters for the encoder.
        tconv_op, tconv_params: Up-sampling operation and its parameters for the generator.
        normalization_op, normalization_params, activation_op, activation_params, block_op, block_params:
            Forwarded unchanged to both encoder and generator.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        input_size,
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        super(AE, self).__init__()

        # Options that encoder and generator share verbatim.
        shared_opts = dict(
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.enc = BasicEncoder(
            input_size=list(input_size),
            fmap_sizes=fmap_sizes,
            z_dim=z_dim,
            conv_op=conv_op,
            conv_params=conv_params,
            **shared_opts
        )
        self.dec = BasicGenerator(
            input_size=list(input_size),
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            **shared_opts
        )

        self.hidden_size = self.enc.output_size

    def forward(self, inpt, **kwargs):
        """Encode then decode; kwargs are passed to the encoder only."""
        return self.dec(self.enc(inpt, **kwargs))

    def encode(self, inpt, **kwargs):
        """Return the latent code for ``inpt``."""
        return self.enc(inpt, **kwargs)

    def decode(self, inpt, **kwargs):
        """Decode a latent code with the generator."""
        return self.dec(inpt, **kwargs)
    
    
    
    
def kl_loss_fn(z_post, sum_samples=True, correct=False, sumdim=(1,2,3)):
    """KL divergence of a posterior distribution to the standard Normal prior N(0, 1).

    Args:
        z_post (torch.distributions.Distribution): Posterior, e.g. the Normal returned by ``VAE.forward``.
        sum_samples (bool): If True, return the batch mean of the per-sample KL (a scalar);
            otherwise return the per-sample values.
        correct (bool): If True, sum the element-wise KL over ``sumdim``; if False, average over it.
        sumdim (tuple of int): Dimensions reduced per sample.

    Returns:
        torch.Tensor: A scalar if ``sum_samples`` is True, else the per-sample reduced KL.
    """
    prior = dist.Normal(0, 1.0)
    reducer = torch.sum if correct else torch.mean
    per_sample = reducer(dist.kl_divergence(z_post, prior), dim=sumdim)
    if not sum_samples:
        return per_sample
    return torch.mean(per_sample)

def rec_loss_fn(recon_x, x, sum_samples=True, correct=False, sumdim=(1,2,3)):
    """Reconstruction loss as a negative log-likelihood.

    Args:
        recon_x (torch.Tensor): Reconstruction.
        x (torch.Tensor): Target.
        sum_samples (bool): If True, return the batch mean (a scalar); otherwise the per-sample losses.
        correct (bool): If True, use the log-density of Laplace(recon_x, 1) summed over ``sumdim``.
            If False, use the mean absolute error over ``sumdim``.
        sumdim (tuple of int): Dimensions reduced per sample.

    Returns:
        torch.Tensor: The negated (mean) per-sample log-likelihood.
    """
    if not correct:
        per_sample_ll = torch.mean(-torch.abs(recon_x - x), dim=sumdim)
    else:
        likelihood = dist.Laplace(recon_x, 1.0)
        per_sample_ll = torch.sum(likelihood.log_prob(x), dim=sumdim)
    return -torch.mean(per_sample_ll) if sum_samples else -per_sample_ll

def geco_beta_update(beta, error_ema, goal, step_size, min_clamp=1e-10, max_clamp=1e4, speedup=None):
    """Multiplicatively update the weight ``beta`` from the constraint ``error_ema - goal``.

    The update is ``beta * exp(rate * (error_ema - goal))``. The rate is ``speedup * step_size`` when
    ``speedup`` is given and the constraint is positive; otherwise it is ``step_size``. The result is
    then clamped to ``[min_clamp, max_clamp]``; either bound may be None to skip it.

    ``error_ema - goal`` must support ``.detach()`` (i.e. be a tensor).

    Returns:
        np.float64 if any clamp is applied, otherwise a torch tensor.
    """
    constraint = (error_ema - goal).detach()
    use_speedup = speedup is not None and constraint > 0.0
    rate = speedup * step_size if use_speedup else step_size
    beta = beta * torch.exp(rate * constraint)

    if min_clamp is not None:
        beta = np.max((beta.item(), min_clamp))
    if max_clamp is not None:
        beta = np.min((beta.item(), max_clamp))
    return beta

def get_ema(new, old, alpha):
    """Exponential moving average step: ``(1 - alpha) * new + alpha * old``, or ``new`` when ``old`` is None."""
    if old is not None:
        new = (1.0 - alpha) * new + alpha * old
    return new

import random
def get_range_val(value, rnd_type="uniform"):
    """Resolve a scalar-or-range parameter to a concrete value.

    Args:
        value: Either a plain value, returned unchanged, or a list/tuple/np.ndarray of length 1 or 2.
            - Length 1 returns its only element.
            - Length 2 with equal entries returns ``value[0]``.
            - Otherwise a random value is drawn and cast back with ``type(value[0])``, so integer
              ranges yield ints (``int()`` truncates).
        rnd_type (str): "uniform" draws from U(value[0], value[1]); "normal" draws from a Normal with
            mean ``value[0]`` and standard deviation ``value[1]``.

    Returns:
        The resolved value.

    Raises:
        RuntimeError: if ``value`` is a sequence whose length is not 1 or 2.
        ValueError: if a random draw is needed and ``rnd_type`` is neither "uniform" nor "normal".
    """
    if not isinstance(value, (list, tuple, np.ndarray)):
        return value
    if len(value) == 1:
        return value[0]
    if len(value) != 2:
        raise RuntimeError("value must be either a single value or a list/tuple of len 2")

    low, high = value[0], value[1]
    if low == high:
        return low

    if rnd_type == "uniform":
        n_val = random.uniform(low, high)
    elif rnd_type == "normal":
        n_val = random.normalvariate(low, high)
    else:
        # Previously fell through with n_val unbound (UnboundLocalError); fail with a clear message.
        raise ValueError("unknown rnd_type %r, expected 'uniform' or 'normal'" % (rnd_type,))
    return type(low)(n_val)
    
def get_square_mask(data_shape, square_size, n_squares, noise_val=(0, 0), channel_wise_n_val=False, square_pos=None):
    """Return an array of shape ``data_shape`` that is zero except inside randomly placed squares.

    For every sample (axis 0), squares are painted into ``ret_data[i]``, whose shape is
    ``data_shape[1:]`` with height and width as the last two axes. The number of squares is drawn from
    ``n_squares``, and all squares of that sample share one edge length drawn from ``square_size``.
    Each square is filled with a value drawn from ``noise_val``; where squares overlap, the one painted
    last wins.

    Args:
        data_shape (tuple of int): Shape of the returned array. ``data_shape[1:]`` must have 2, 3 or 4
            dimensions for anything to be painted.
        square_size (int or (int, int)): Square edge length, or a (min, max) range resolved by
            ``get_range_val``.
        n_squares (int or (int, int)): Number of squares per sample, or a (min, max) range.
        noise_val (number or (number, number), optional): Fill value, or a (min, max) range. Defaults to
            (0, 0), which leaves the mask all zeros; pass a non-zero range to obtain visible squares.
        channel_wise_n_val (bool, optional): For 3-D/4-D per-sample shapes, draw a separate fill value
            for each channel. Defaults to False.
        square_pos (sequence of (w, h), optional): Candidate top-left corners. One is picked at random for
            each square instead of a uniformly random position. Defaults to None.

    Returns:
        np.ndarray: float64 array of shape ``data_shape``.
    """

    def paint_random_square(img, square_size, n_val, channel_wise_n_val=False, square_pos=None):
        """Paint one square with edge ``square_size`` into ``img`` in place."""
        img_h = img.shape[-2]
        img_w = img.shape[-1]

        if square_pos is None:
            # randint's upper bound is exclusive. The +1 lets the square touch the bottom/right border
            # and makes a square as large as the image valid.
            w_start = np.random.randint(0, img_w - square_size + 1)
            h_start = np.random.randint(0, img_h - square_size + 1)
        else:
            pos_wh = square_pos[np.random.randint(0, len(square_pos))]
            w_start = pos_wh[0]
            h_start = pos_wh[1]

        rows = slice(h_start, h_start + square_size)
        cols = slice(w_start, w_start + square_size)

        if img.ndim == 2:
            img[rows, cols] = get_range_val(n_val)
        elif img.ndim == 3:
            if channel_wise_n_val:
                for c in range(img.shape[0]):
                    img[c, rows, cols] = get_range_val(n_val)
            else:
                img[:, rows, cols] = get_range_val(n_val)
        elif img.ndim == 4:
            if channel_wise_n_val:
                # Channels sit on axis 1 here (img[:, c]), so iterate over img.shape[1].
                for c in range(img.shape[1]):
                    img[:, c, rows, cols] = get_range_val(n_val)
            else:
                img[:, :, rows, cols] = get_range_val(n_val)

    ret_data = np.zeros(data_shape)

    for sample_idx in range(data_shape[0]):
        rnd_square_size = get_range_val(square_size)
        rnd_n_squares = get_range_val(n_squares)

        # Integer indexing returns a view, so all squares of this sample accumulate directly in ret_data.
        sample_mask = ret_data[sample_idx]
        for _ in range(rnd_n_squares):
            paint_random_square(
                sample_mask,
                rnd_square_size,
                noise_val,
                channel_wise_n_val=channel_wise_n_val,
                square_pos=square_pos,
            )

    return ret_data

In [12]:
class VQVAEModel(nn.Module):
    """VQ-VAE: Encoder -> 1x1 conv to ``embedding_dim`` -> VectorQuantizer -> Decoder.

    Args:
        num_hiddens (int): Hidden channels of encoder and decoder.
        num_residual_layers (int): Residual layers in the encoder and decoder stacks.
        num_residual_hiddens (int): Hidden channels inside each residual layer.
        num_embeddings (int): Codebook size.
        embedding_dim (int): Dimension of each codebook vector.
        commitment_cost (float): Weight of the commitment term in the quantizer loss.
        decay (float, optional): Accepted but not used by this class.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, 
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(VQVAEModel, self).__init__()

        # The encoder's first argument is the number of input channels (1: single-channel images).
        self._encoder = Encoder(1 , num_hiddens,
                                num_residual_layers, 
                                num_residual_hiddens)

        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, 
                                      out_channels=embedding_dim,
                                      kernel_size=1, 
                                      stride=1)

        self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)

        self._decoder = Decoder(embedding_dim,
                                num_hiddens, #dim input decoder
                                num_residual_layers, 
                                num_residual_hiddens)
        self.initialize_weights()

    def forward(self, x):
        """Return ``(x_recon, vq_loss, perplexity)`` for input ``x``."""
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return x_recon, loss, perplexity

    def initialize_weights(self):
        """Initialise the weights of all submodules.

        - Conv2d and Linear weights get Kaiming-uniform init.
        - BatchNorm2d gets weight 1.
        - Biases are set to 0.
        - Absent parameters (``bias=False`` layers, ``affine=False`` norms) are skipped.
        """
        for m in self.modules():
            # NOTE(review): nn.ConvTranspose2d is not a subclass of nn.Conv2d, so transposed convs keep
            # PyTorch's default init. Confirm this is intended.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

class Decoder(nn.Module):
    """VQ-VAE decoder: 3x3 conv, residual stack, then two 4x4 / stride-2 transposed convs (4x up-sampling).

    Args:
        in_channels (int): Channels of the latent input.
        num_hiddens (int): Channels after the first conv and inside the residual stack; the first
            transposed conv outputs ``num_hiddens // 2``.
        num_residual_layers (int): Number of residual layers in the stack.
        num_residual_hiddens (int): Hidden channels inside each residual layer.
        out_channels (int, optional): Channels of the reconstruction. Defaults to 1.
        dropout (float, optional): Dropout probability applied after the residual stack.
            Defaults to 0.0, i.e. no dropout.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens,
                 out_channels=1, dropout=0.0):
        super(Decoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)

        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        # 4x4 / stride 2 / padding 1 transposed convs each double H and W.
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2, padding=1)

        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=out_channels,
                                                kernel_size=4,
                                                stride=2, padding=1)

        # Registered as a submodule, instead of building nn.Dropout inside forward, so that
        # train()/eval() control it. With the default p=0.0 it is an identity.
        self._dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._dropout(x)
        x = F.relu(self._conv_trans_1(x))
        # No final activation: the raw reconstruction is returned.
        return self._conv_trans_2(x)

class Encoder(nn.Module):
    """VQ-VAE encoder: two 4x4 / stride-2 convs (4x down-sampling), a 3x3 conv and a residual stack.

    Args:
        in_channels (int): Channels of the input image.
        num_hiddens (int): Output channels; the first conv outputs ``num_hiddens // 2``.
        num_residual_layers (int): Number of residual layers in the stack.
        num_residual_hiddens (int): Hidden channels inside each residual layer.
        dropout (float, optional): Dropout probability applied before the 3x3 conv.
            Defaults to 0.0, i.e. no dropout.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, dropout=0.0):
        super(Encoder, self).__init__()

        # 4x4 / stride 2 / padding 1 convs each halve H and W (floor).
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens // 2,  # integer (floor) division
                                 kernel_size=4,
                                 stride=2, padding=1)

        self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

        # Registered as a submodule, instead of building nn.Dropout inside forward, so that
        # train()/eval() control it. With the default p=0.0 it is an identity.
        self._dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = F.relu(self._conv_1(inputs))
        x = F.relu(self._conv_2(x))
        x = self._dropout(x)
        x = self._conv_3(x)
        return self._residual_stack(x)

class ResidualStack(nn.Module):
    """A chain of ``num_residual_layers`` Residual blocks followed by a final ReLU.

    Args:
        in_channels (int): Channels of the input to each Residual block.
        num_hiddens (int): Output channels of each Residual block.
        num_residual_layers (int): Number of Residual blocks.
        num_residual_hiddens (int): Hidden channels inside each Residual block.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        blocks = []
        for _ in range(num_residual_layers):
            blocks.append(Residual(in_channels, num_hiddens, num_residual_hiddens))
        self._layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self._layers:
            out = block(out)
        return F.relu(out)

class Residual(nn.Module):
    """Pre-activation residual block: ``x + Conv1x1(ReLU(Conv3x3(ReLU(x))))``.

    Args:
        in_channels (int): Channels of the input.
        num_hiddens (int): Output channels of the block; must match ``in_channels`` for the skip addition.
        num_residual_hiddens (int): Channels of the intermediate 3x3 conv.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            # Not in-place: an in-place ReLU here would overwrite the caller's ``x`` before the skip
            # addition in forward(). That would mutate the input and turn the identity shortcut into relu(x).
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            # Safe in-place: it only overwrites the conv output above.
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learned codebook and straight-through gradients.

    Args:
        num_embeddings (int): Number of codebook vectors.
        embedding_dim (int): Size of each codebook vector.
        commitment_cost (float): Weight of the commitment (encoder-side) term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook, initialised uniformly in [-1/num_embeddings, 1/num_embeddings].
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        bound = 1 / self._num_embeddings
        self._embedding.weight.data.uniform_(-bound, bound)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize a 4-D ``inputs`` tensor laid out as B x C x H x W.

        C is expected to equal ``embedding_dim``.

        Returns:
            tuple: ``(vq_loss, quantized (B x C x H x W), perplexity, one_hot_encodings)``. The encodings
            have shape (B*H*W*C / embedding_dim, num_embeddings).
        """
        # Channels last, so every spatial position becomes one embedding_dim-sized row.
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        bhwc_shape = channels_last.shape
        rows = channels_last.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Squared Euclidean distance ||x||^2 + ||e||^2 - 2 x.e for every (row, code) pair.
        row_sq = torch.sum(rows ** 2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook ** 2, dim=1)
        distances = row_sq + code_sq - 2 * torch.matmul(rows, codebook.t())

        # One-hot encode the nearest code of every row.
        nearest = torch.argmin(distances, dim=1).unsqueeze(1)
        one_hot = torch.zeros(nearest.shape[0], self._num_embeddings, device=channels_last.device)
        one_hot.scatter_(1, nearest, 1)

        snapped = torch.matmul(one_hot, codebook).view(bhwc_shape)

        # The codebook term moves codes towards the (frozen) inputs; the commitment term does the reverse.
        commitment_term = F.mse_loss(snapped.detach(), channels_last)
        codebook_term = F.mse_loss(snapped, channels_last.detach())
        vq_loss = codebook_term + self._commitment_cost * commitment_term

        # Straight-through estimator: the forward value is the code, the gradient flows to the inputs.
        snapped = channels_last + (snapped - channels_last).detach()

        usage = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        # Back to B x C x H x W.
        return vq_loss, snapped.permute(0, 3, 1, 2).contiguous(), perplexity, one_hot

In [13]:
# Training loop for the VQ-VAE, optionally mixed with a context-encoding (CE) objective.
def train_VQVAE(model, train_loader, validation_loader, num_epochs, optimizer, scaler, device, beta, theta,
                path_checkpoint, image_path, displ, train_loader_patch, ce_factor):
    """Train a VQ-VAE with mixed precision, tracking the weights with the best validation loss.

    ``model(batch)`` must return ``(reconstruction, vq_loss, perplexity)``; the objective of a
    batch is ``MSE(reconstruction, batch) + vq_loss``.

    If ``ce_factor > 0`` each batch of ``train_loader`` is zipped with a batch of
    ``train_loader_patch`` (so an epoch stops at the shorter loader) and the loss becomes
    ``(1 - ce_factor) * vq_objective + ce_factor * MSE(model(patch_batch)[0], batch)``;
    the VQ objective is skipped when ``ce_factor >= 1``. Validation always uses the VQ objective.

    Every 10 epochs (and once more after the last epoch) the reconstruction of the first
    sample of the last validation batch is written to ``image_path`` as
    ``Rec_epoch_<e>_loss_<val>.PNG``. The best weights so far are written to ``path_checkpoint``
    as ``Epoch_<e>.pt``. The matching original image is written once, at the first refresh.
    Reported losses are per-batch means, multiplied by 1000 and rounded to one decimal.

    Args:
        model: VQ-VAE returning ``(reconstruction, vq_loss, perplexity)``.
        train_loader: loader yielding dicts with an ``'image'`` tensor.
        validation_loader: loader yielding dicts with an ``'image'`` tensor.
        num_epochs: number of epochs to run.
        optimizer: optimizer over ``model``'s parameters.
        scaler: ``GradScaler`` used for the mixed-precision backward / step.
        device: device the batches are moved to.
        beta: unused; kept for signature parity with ``train``.
        theta: unused; kept for signature parity with ``train``.
        path_checkpoint: directory for the ``Epoch_<e>.pt`` checkpoints.
        image_path: directory for the ``Ori_*`` / ``Rec_*`` PNG samples.
        displ: if truthy, sample images are also shown inline whenever they are refreshed.
        train_loader_patch: loader of patched inputs; only read when ``ce_factor > 0``.
        ce_factor: weight of the context-encoding term.

    Returns:
        ``model`` itself, holding the weights of the last epoch.

    Raises:
        ValueError: if the training or validation loader yields no batches in an epoch.
    """
    use_ce = ce_factor > 0

    def snapshot():
        # state_dict() returns references to the live tensors; clone them so later optimizer
        # steps cannot silently overwrite the stored "best" weights.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    def to_pil(image):
        # Scale by 255 (intensities presumably in [0, 1]) and clip to the uint8 range so
        # out-of-range values cannot wrap around during the cast.
        array = np.squeeze(image.detach().float().cpu().numpy()) * 255
        return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))

    def scaled_mean(total, count, loader_name):
        # Mean per-batch loss, x1000 and rounded for readable logs / file names.
        if count == 0:
            raise ValueError(loader_name + ' yielded no batches')
        return round(total / count * 1000, 1)

    def vq_objective(batch):
        reconstruction, vq_loss, _ = model(batch)
        return reconstruction, F.mse_loss(reconstruction, batch) + vq_loss

    best_val = float('inf')
    # Start from the initial weights so a checkpoint always exists, even if the validation
    # loss never improves (e.g. it is NaN).
    best_weights = snapshot()
    original_saved = False
    epoch = None

    for epoch in tqdm(range(num_epochs)):
        # ---- training ----
        model.train()
        if use_ce:
            batches = zip(train_loader, train_loader_patch)
        else:
            batches = ((data, None) for data in train_loader)

        loss_tot, n_train = 0.0, 0
        for data, data_patch in batches:
            img = data['image'].to(device)
            optimizer.zero_grad()

            with autocast():
                if data_patch is None:
                    _, loss = vq_objective(img)
                else:
                    loss_vq = 0
                    if ce_factor < 1:
                        _, loss_vq = vq_objective(img)
                    # CE part: reconstruct the clean image from the patched input
                    reconstruction_ce, _, _ = model(data_patch['image'].to(device))
                    loss_ce = F.mse_loss(reconstruction_ce, img)
                    loss = (1.0 - ce_factor) * loss_vq + ce_factor * loss_ce

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_tot += loss.item()
            n_train += 1
        train_loss = scaled_mean(loss_tot, n_train, 'train_loader')

        # ---- validation ----
        model.eval()
        val_tot, n_val = 0.0, 0
        with torch.no_grad():
            for data in validation_loader:
                img = data['image'].to(device)
                reconstruction, loss = vq_objective(img)
                val_tot += loss.item()
                n_val += 1
        val_loss = scaled_mean(val_tot, n_val, 'validation_loader')

        if val_loss < best_val:
            best_val = val_loss
            best_weights = snapshot()

        # ---- periodic artifacts ----
        if epoch % 10 == 0:
            if not original_saved:
                orig = to_pil(img[0])
                orig.save(os.path.join(image_path, 'Ori_epoch_{}.PNG'.format(epoch)))
                original_saved = True
                if displ:
                    display(orig)
            rec = to_pil(reconstruction[0])
            rec.save(os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
            if displ:
                display(rec)
            torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

        print("Epoch: ", epoch, ", Training loss: ", train_loss, "Validation loss: " + str(val_loss))

    if epoch is not None:
        # Final artifacts: the last epoch's reconstruction and the overall best weights
        to_pil(reconstruction[0]).save(
            os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
        torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

    return model

In [14]:
# Training loop for the plain autoencoder, optionally mixed with a context-encoding (CE) objective.
def train_ae(model, train_loader, validation_loader, num_epochs, optimizer, scaler, device, beta, theta,
             path_checkpoint, image_path, displ, train_loader_patch, batch_size, ce_factor):
    """Train an autoencoder with mixed precision, tracking the weights with the best validation loss.

    ``model(batch)`` must return the reconstruction. The objective of a batch is
    ``MSELoss(batch, reconstruction)`` (mean reduction).

    If ``ce_factor > 0`` each batch of ``train_loader`` is zipped with a batch of
    ``train_loader_patch`` (so an epoch stops at the shorter loader) and the loss becomes
    ``(1 - ce_factor) * ae_objective + ce_factor * MSE(model(patch_batch), batch)``;
    the AE objective is skipped when ``ce_factor >= 1``. Validation always uses the AE objective.

    Every 10 epochs (and once more after the last epoch) the reconstruction of the first
    sample of the last validation batch is written to ``image_path`` as
    ``Rec_epoch_<e>_loss_<val>.PNG``. The best weights so far are written to ``path_checkpoint``
    as ``Epoch_<e>.pt``. The matching original image is written once, at the first refresh.
    Reported losses are per-batch means, multiplied by 1000 and rounded to one decimal.

    Args:
        model: autoencoder returning the reconstruction.
        train_loader: loader yielding dicts with an ``'image'`` tensor.
        validation_loader: loader yielding dicts with an ``'image'`` tensor.
        num_epochs: number of epochs to run.
        optimizer: optimizer over ``model``'s parameters.
        scaler: ``GradScaler`` used for the mixed-precision backward / step.
        device: device the batches are moved to.
        beta: unused; kept for signature parity with ``train``.
        theta: unused; kept for signature parity with ``train``.
        path_checkpoint: directory for the ``Epoch_<e>.pt`` checkpoints.
        image_path: directory for the ``Ori_*`` / ``Rec_*`` PNG samples.
        displ: if truthy, sample images are also shown inline whenever they are refreshed.
        train_loader_patch: loader of patched inputs; only read when ``ce_factor > 0``.
        batch_size: unused; kept for signature compatibility. It previously entered the loss
            normalisation as ``len(loader) / batch_size``, which scaled the reported losses by
            ``batch_size``; losses are now plain per-batch means.
        ce_factor: weight of the context-encoding term.

    Returns:
        ``model`` itself, holding the weights of the last epoch.

    Raises:
        ValueError: if the training or validation loader yields no batches in an epoch.
    """
    criterion = nn.MSELoss(reduction='mean')
    use_ce = ce_factor > 0

    def snapshot():
        # state_dict() returns references to the live tensors; clone them so later optimizer
        # steps cannot silently overwrite the stored "best" weights.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    def to_pil(image):
        # Scale by 255 (intensities presumably in [0, 1]) and clip to the uint8 range so
        # out-of-range values cannot wrap around during the cast.
        array = np.squeeze(image.detach().float().cpu().numpy()) * 255
        return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))

    def scaled_mean(total, count, loader_name):
        # Mean per-batch loss, x1000 and rounded for readable logs / file names.
        if count == 0:
            raise ValueError(loader_name + ' yielded no batches')
        return round(total / count * 1000, 1)

    best_val = float('inf')
    # Start from the initial weights so a checkpoint always exists, even if the validation
    # loss never improves (e.g. it is NaN).
    best_weights = snapshot()
    original_saved = False
    epoch = None

    for epoch in tqdm(range(num_epochs)):
        # ---- training ----
        model.train()
        if use_ce:
            batches = zip(train_loader, train_loader_patch)
        else:
            batches = ((data, None) for data in train_loader)

        loss_tot, n_train = 0.0, 0
        for data, data_patch in batches:
            img = data['image'].to(device)
            optimizer.zero_grad()

            with autocast():
                if data_patch is None:
                    loss = criterion(img, model(img))
                else:
                    loss_vae = 0
                    if ce_factor < 1:
                        loss_vae = criterion(img, model(img))
                    # CE part: reconstruct the clean image from the patched input
                    x_rec_ce = model(data_patch['image'].to(device))
                    loss_ce = criterion(x_rec_ce, img)
                    loss = (1.0 - ce_factor) * loss_vae + ce_factor * loss_ce

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_tot += loss.item()
            n_train += 1
        train_loss = scaled_mean(loss_tot, n_train, 'train_loader')

        # ---- validation ----
        model.eval()
        val_tot, n_val = 0.0, 0
        with torch.no_grad():
            for data in validation_loader:
                img = data['image'].to(device)
                x_r = model(img)
                val_tot += criterion(img, x_r).item()
                n_val += 1
        val_loss = scaled_mean(val_tot, n_val, 'validation_loader')

        if val_loss < best_val:
            best_val = val_loss
            best_weights = snapshot()

        # ---- periodic artifacts ----
        if epoch % 10 == 0:
            if not original_saved:
                orig = to_pil(img[0])
                orig.save(os.path.join(image_path, 'Ori_epoch_{}.PNG'.format(epoch)))
                original_saved = True
                if displ:
                    display(orig)
            rec = to_pil(x_r[0])
            rec.save(os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
            if displ:
                display(rec)
            torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

        print("Epoch: ", epoch, ", Training loss: ", train_loss, "Validation loss: " + str(val_loss))

    if epoch is not None:
        # Final artifacts: the last epoch's reconstruction and the overall best weights
        to_pil(x_r[0]).save(
            os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
        torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

    return model
        

In [15]:
# Training loop for the VAE, optionally mixed with a context-encoding (CE) objective.
def train(model, train_loader, validation_loader, num_epochs, optimizer, scaler, device, beta, theta,
          path_checkpoint, image_path, displ, train_loader_patch, ce_factor):
    """Train a VAE with mixed precision, tracking the weights with the best validation loss.

    ``model(batch)`` must return ``(x_r, z_dist)``. The VAE objective of a batch is
    ``kl_loss_fn(z_dist, sumdim=(1, 2, 3)) * beta + rec_loss_fn(x_r, batch, sumdim=(1, 2, 3)) * theta``,
    where ``kl_loss_fn`` and ``rec_loss_fn`` are the notebook-level helpers of those names.

    If ``ce_factor > 0`` each batch of ``train_loader`` is zipped with a batch of
    ``train_loader_patch`` (so an epoch stops at the shorter loader) and the loss becomes
    ``(1 - ce_factor) * vae_objective + ce_factor * rec_loss_fn(model(patch_batch)[0], batch, ...)``;
    the VAE objective is skipped when ``ce_factor >= 1``. Validation always uses the VAE objective.

    Every 10 epochs (and once more after the last epoch) the reconstruction of the first
    sample of the last validation batch is written to ``image_path`` as
    ``Rec_epoch_<e>_loss_<val>.PNG``. The best weights so far are written to ``path_checkpoint``
    as ``Epoch_<e>.pt``. The matching original image is written once, at the first refresh.
    Reported losses are per-batch means, multiplied by 1000 and rounded to one decimal.

    Args:
        model: VAE returning ``(x_r, z_dist)``.
        train_loader: loader yielding dicts with an ``'image'`` tensor.
        validation_loader: loader yielding dicts with an ``'image'`` tensor.
        num_epochs: number of epochs to run.
        optimizer: optimizer over ``model``'s parameters.
        scaler: ``GradScaler`` used for the mixed-precision backward / step.
        device: device the batches are moved to.
        beta: weight of the KL term.
        theta: weight of the reconstruction term. It is only adapted when the local
            ``use_geco`` switch is enabled, and that switch is off.
        path_checkpoint: directory for the ``Epoch_<e>.pt`` checkpoints.
        image_path: directory for the ``Ori_*`` / ``Rec_*`` PNG samples.
        displ: if truthy, sample images are also shown inline whenever they are refreshed.
        train_loader_patch: loader of patched inputs; only read when ``ce_factor > 0``.
        ce_factor: weight of the context-encoding term.

    Returns:
        ``model`` itself, holding the weights of the last epoch (for every ``ce_factor``).

    Raises:
        ValueError: if the training or validation loader yields no batches in an epoch.
    """
    use_ce = ce_factor > 0
    # GECO adaptation of `theta` is disabled; set to True to enable it.
    use_geco = False
    vae_loss_ema = 1

    def snapshot():
        # state_dict() returns references to the live tensors; clone them so later optimizer
        # steps cannot silently overwrite the stored "best" weights.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    def to_pil(image):
        # Scale by 255 (intensities presumably in [0, 1]) and clip to the uint8 range so
        # out-of-range values cannot wrap around during the cast.
        array = np.squeeze(image.detach().float().cpu().numpy()) * 255
        return Image.fromarray(np.clip(array, 0, 255).astype(np.uint8))

    def scaled_mean(total, count, loader_name):
        # Mean per-batch loss, x1000 and rounded for readable logs / file names.
        if count == 0:
            raise ValueError(loader_name + ' yielded no batches')
        return round(total / count * 1000, 1)

    def vae_objective(batch):
        # Reads the current `theta`, which GECO may have updated.
        x_r, z_dist = model(batch)
        kl_loss = kl_loss_fn(z_dist, sumdim=(1, 2, 3)) * beta
        rec_loss = rec_loss_fn(x_r, batch, sumdim=(1, 2, 3))
        return x_r, rec_loss, kl_loss + rec_loss * theta

    best_val = float('inf')
    # Start from the initial weights so a checkpoint always exists, even if the validation
    # loss never improves (e.g. it is NaN).
    best_weights = snapshot()
    original_saved = False
    epoch = None

    for epoch in tqdm(range(num_epochs)):
        # ---- training ----
        model.train()
        if use_ce:
            batches = zip(train_loader, train_loader_patch)
        else:
            batches = ((data, None) for data in train_loader)

        loss_tot, n_train = 0.0, 0
        for data, data_patch in batches:
            img = data['image'].to(device)
            optimizer.zero_grad()

            with autocast():
                loss_vae = 0
                if ce_factor < 1:
                    _, rec_loss_vae, loss_vae = vae_objective(img)
                if data_patch is None:
                    loss = loss_vae
                else:
                    # CE part: reconstruct the clean image from the patched input
                    x_rec_ce, _ = model(data_patch['image'].to(device))
                    loss_ce = rec_loss_fn(x_rec_ce, img, sumdim=(1, 2, 3))
                    loss = (1.0 - ce_factor) * loss_vae + ce_factor * loss_ce

            if use_geco and ce_factor < 1:
                g_goal = 0.1
                g_lr = 1e-4
                # Detach so the EMA (and therefore theta) does not keep every batch's graph alive
                vae_loss_ema = (1.0 - 0.9) * rec_loss_vae.detach() + 0.9 * vae_loss_ema
                theta = geco_beta_update(theta, vae_loss_ema, g_goal, g_lr, speedup=2)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_tot += loss.item()
            n_train += 1
        train_loss = scaled_mean(loss_tot, n_train, 'train_loader')

        # ---- validation ----
        model.eval()
        val_tot, n_val = 0.0, 0
        with torch.no_grad():
            for data in validation_loader:
                img = data['image'].to(device)
                x_r, _, loss_vae = vae_objective(img)
                val_tot += loss_vae.item()
                n_val += 1
        val_loss = scaled_mean(val_tot, n_val, 'validation_loader')

        if val_loss < best_val:
            best_val = val_loss
            best_weights = snapshot()

        # ---- periodic artifacts ----
        if epoch % 10 == 0:
            if not original_saved:
                orig = to_pil(img[0])
                orig.save(os.path.join(image_path, 'Ori_epoch_{}.PNG'.format(epoch)))
                original_saved = True
                if displ:
                    display(orig)
            rec = to_pil(x_r[0])
            rec.save(os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
            if displ:
                display(rec)
            torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

        print("Epoch: ", epoch, ", Training loss: ", str(train_loss), "Validation loss: " + str(val_loss))

    if epoch is not None:
        # Final artifacts: the last epoch's reconstruction and the overall best weights
        to_pil(x_r[0]).save(
            os.path.join(image_path, 'Rec_epoch_{}_loss_{}.PNG'.format(epoch, val_loss)))
        torch.save(best_weights, os.path.join(path_checkpoint, 'Epoch_{}.pt'.format(epoch)))

    return model

In [22]:
##### Set model    
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.autograd import Variable
    
# Slice geometry and AE capacity
input_dim = (192, 192)        # slice shape (H, W)
input_size = (1, 192, 192)    # (channels, H, W)
z_dim = 1024
model_feature_map_sizes = (16, 64, 256, 1024)  # compact VAE

# Tells the model code it is dealing with 2D images: 2D (transposed) convolutions and d = 2
conv = nn.Conv2d
convt = nn.ConvTranspose2d
d = 2

model = AE(
    input_size=input_size,
    z_dim=z_dim,
    fmap_sizes=model_feature_map_sizes,
    conv_op=conv,
    tconv_op=convt,
    activation_op=torch.nn.PReLU,
)
model.d = d
model.to(device)

# Memory footprint of all parameters and buffers, in MiB
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))


model size: 1178.662MB


In [17]:
import torch
# Release unoccupied cached GPU memory held by PyTorch's caching allocator before the next run.
torch.cuda.empty_cache()

In [18]:
# Dataset locations on Kaggle (earlier variants kept for reference):
# path = '/kaggle/input/168-image/Train100ImagesOnly/brainmetshare/metshare'
# path = '/kaggle/input/brainmetshare/DataSplittedHealthyDiseased/brainmetshare/metshare'
path = '/kaggle/input/dataset-finale/Dataset_progetto_advanced_deep_learning'
path_patch = '/kaggle/input/dataset-finale-patch/Dataset_progetto_advanced_deep_learning_patch'

batch_size = 64
input_image = [192, 192]  # target size after resizing

# Resize each slice to `input_image`, then convert it to a tensor
transform = transforms.Compose([transforms.Resize(input_image), transforms.ToTensor()])

In [19]:
# ---- VQ-VAE hyper-parameters ----
input_channel = 1
# num_training_updates = 15000

num_hiddens = 256            # latent-space width (512 / 1024 also considered)
num_residual_hiddens = 4
num_residual_layers = 1

embedding_dim = 256          # try these three first; increase the first dimension first
num_embeddings = 512

commitment_cost = 0.25       # higher = encoder outputs pulled closer to the codebook values
decay = 0.99
learning_rate = 1e-4

# VQ-VAE construction is currently disabled, so the size printed below is that of whatever
# `model` currently refers to.
# model = VQVAEModel(num_hiddens, num_residual_layers, num_residual_hiddens,
#     num_embeddings, embedding_dim, commitment_cost, decay).to(device)

param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 1754.662MB


In [None]:
# ---- Optimisation setup ----
lr = 1e-4
optimizer = Adam(model.parameters(), lr=lr)
scaler = GradScaler()

path_checkpoint = '/kaggle/working/Checkpoint'
save_image_path = '/kaggle/working/Images'

# Output folders for checkpoints and sample images (no error if they already exist)
Path(path_checkpoint).mkdir(parents=True, exist_ok=True)
Path(save_image_path).mkdir(parents=True, exist_ok=True)

# ---- Data: sequential 80/20 train / validation split ----
df = create_csv_final()
# Split point taken from the actual frame length; iloc[:index] and iloc[index:] are
# complementary, so every row lands in exactly one split.
index = int(len(df) * 0.8)
train_data = df.iloc[:index, :]
val_data = df.iloc[index:, :]

DATA_ROOT = '/kaggle/input/dataset-finale-finale/Dataset_Finale_finale'
brainTrain = brain_dataset_final(csv_file=train_data, root_dir=DATA_ROOT, transform=transform)
brainVal = brain_dataset_final(csv_file=val_data, root_dir=DATA_ROOT, transform=transform)

# NOTE(review): the training loader is not shuffled; confirm this is intended.
trainloader = DataLoader(brainTrain, batch_size)
validationloader = DataLoader(brainVal, batch_size)

# Loss weights; of the three training functions only `train` (the VAE loop) uses them
beta = 0.01
theta = 1

# Alternative runs with the context-encoding (CE) objective, ce_factor = 0.5:
# print('CE')
# trainloader, validationloader, testloader, trainset, valset, testset = create_dataloaders(path, batch_size, transform, False)
# trainloader_patch, validationloader_patch, testloader_patch, trainset_patch, valset_patch, testset_patch = create_dataloaders(path_patch, batch_size, transform, False)
# train(model, trainloader, validationloader, 200, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader_patch, 0.5)
# train_ae(model, trainloader, validationloader, 200, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader_patch, batch_size, 0.5)
# train_VQVAE(model, trainloader, validationloader, 200, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader_patch, 0.5)

# ---- Current run: plain AE, no CE objective ----
print('NO CE')
# trainloader, validationloader, testloader, trainset, valset, testset = create_dataloaders(path, batch_size, transform, False)
# train(model, trainloader, validationloader, 100, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader, 0)
# train_VQVAE(model, trainloader, validationloader, 200, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader, 0)
train_ae(model, trainloader, validationloader, 100, optimizer, scaler, device, beta, theta, path_checkpoint, save_image_path, False, trainloader, batch_size, 0)

NO CE


  1%|          | 1/100 [02:03<3:23:12, 123.15s/it]

In [None]:
def train_step(model,trainloader,trainset,optimizer,beta):
    """Run one training epoch of the VAE and return the mean per-batch loss.

    The model is called as ``rec, mu, std = model(images)`` and each batch loss is
    ``rec_loss_fn(rec, images) + kl_divergence(mu, std) * beta``.

    Args:
        model: VAE returning ``(reconstruction, mu, third_output)``.
        trainloader: loader yielding dicts with an ``'image'`` tensor.
        trainset: unused (the progress-bar length comes from ``len(trainloader)``); kept for
            signature compatibility.
        optimizer: optimizer over ``model``'s parameters.
        beta: weight of the KL term.

    Returns:
        Mean loss per batch, normalised the same way as ``validation_step`` so the two are
        directly comparable.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    counter = 0

    # tqdm takes its length from len(trainloader), which also counts a final partial batch
    for data in tqdm(trainloader):
        images = data['image'].to(device)

        optimizer.zero_grad()

        rec_vae, mu, std = model(images)
        # NOTE(review): the third output is passed as kl_divergence's `logsigma`; confirm the
        # model returns a log standard deviation rather than a standard deviation.
        kl_loss = kl_divergence(mu, std)
        rec_loss_vae = rec_loss_fn(rec_vae, images)
        loss = rec_loss_vae + kl_loss * beta
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        counter += 1

    if counter == 0:
        raise ValueError('trainloader yielded no batches')
    return running_loss / counter


def validation_step(model,validationloader,valset,beta):
    """Evaluate the VAE on the validation loader without gradients.

    Each batch loss is ``rec_loss_fn(rec, images) + kl_divergence(mu, std) * beta`` with
    ``rec, mu, std = model(images)``.

    Args:
        model: VAE returning ``(reconstruction, mu, third_output)``.
        validationloader: loader yielding dicts with an ``'image'`` tensor.
        valset: unused (the progress-bar length comes from ``len(validationloader)``); kept
            for signature compatibility.
        beta: weight of the KL term.

    Returns:
        ``(mean_loss_per_batch, last_images, last_reconstructions)``; the last two are the
        final batch's inputs and reconstructions, detached and moved to the CPU.

    Raises:
        ValueError: if ``validationloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    counter = 0
    images = rec_vae = None

    with torch.no_grad():
        # tqdm takes its length from len(validationloader), including a final partial batch
        for data in tqdm(validationloader):
            images = data['image'].to(device)

            rec_vae, mu, std = model(images)
            kl_loss = kl_divergence(mu, std)
            rec_loss_vae = rec_loss_fn(rec_vae, images)
            loss = rec_loss_vae + kl_loss * beta

            running_loss += loss.item()
            counter += 1

    if counter == 0:
        raise ValueError('validationloader yielded no batches')
    return running_loss / counter, images.cpu().detach(), rec_vae.cpu().detach()





def kl_divergence(mu, logsigma):
    """KL(q_i(z) || N(0, I)) for diagonal Gaussians q_i, summed over the whole batch.

    Args:
        mu: means of the q_i distributions, e.g. shape [batch_size, latent_dim].
        logsigma: logarithm of the standard deviations, same shape as ``mu``.

    Returns:
        Scalar tensor ``0.5 * sum(sigma^2 + mu^2 - log(sigma^2) - 1)`` over all elements.
        Per-sample values are summed, not returned individually.
    """
    # log(sigma^2) == 2 * logsigma. Using it directly avoids exp-then-log, which yields -inf
    # (and an infinite KL) once exp(logsigma)**2 underflows to 0 for very negative logsigma.
    return 0.5 * torch.sum(torch.exp(2 * logsigma) + mu ** 2 - 2 * logsigma - 1)
    
def rec_loss_fn(recon_x, x):
    """Reconstruction loss for the VAE helpers in this cell.

    Mean squared error between ``x`` and ``recon_x``, averaged over every element
    (the same as ``nn.MSELoss()`` with its default ``reduction='mean'``).

    NOTE(review): this definition takes no ``sumdim`` keyword, while ``train(...)`` calls
    ``rec_loss_fn(..., sumdim=(1, 2, 3))``; once this cell has run, that call raises a
    TypeError.
    """
    return F.mse_loss(x, recon_x)



In [None]:
def _show_reconstruction(orig, recon):
    """Display an input slice, its VAE reconstruction and their absolute difference.

    Args:
        orig: Input image (CPU tensor or array); singleton dims are squeezed away.
        recon: Reconstruction with the same shape as ``orig``.
    """
    orig = np.squeeze(np.asarray(orig, dtype=np.float32))
    recon = np.squeeze(np.asarray(recon, dtype=np.float32))
    panels = (('Real input', orig), ('Rec_vae', recon), ('Difference', np.absolute(orig - recon)))
    for title, panel in panels:
        print(title)
        # Clip to [0, 1] before scaling: values above 1 would otherwise wrap around in uint8.
        display(Image.fromarray(np.uint8(np.clip(panel, 0.0, 1.0) * 255)))


def training_vae_papers(trainloader,trainset,validationloader,valset,epochs,beta,patience_early_stopping,
                       patience_plateu,learning_rate):
    """Train a ``VAE`` with LR-on-plateau, best-weights reloading and early stopping.

    The model is ``VAE([1, 256, 256], (256, 128, 64), 256)`` optimized with Adam.

    - ``ReduceLROnPlateau`` is stepped on the validation loss every epoch.
    - After ``patience_plateu`` consecutive epochs without improvement, the best
      weights seen so far are reloaded in place, then that counter restarts.
    - Training stops once more than ``patience_early_stopping`` consecutive epochs
      have passed without improvement. This counter is not reset by the reload, so
      early stopping can fire even when ``patience_plateu`` is smaller.
    - Every 5 epochs one validation slice, its reconstruction and their difference
      are displayed.

    Args:
        trainloader, trainset: Training DataLoader and dataset (forwarded to ``train_step``).
        validationloader, valset: Validation DataLoader and dataset (forwarded to ``validation_step``).
        epochs: Maximum number of epochs.
        beta: Weight of the KL term.
        patience_early_stopping: Non-improving epochs tolerated before stopping.
        patience_plateu: Non-improving epochs before reloading the best weights;
            also used as the LR scheduler patience.
        learning_rate: Initial Adam learning rate.

    Returns:
        The model with the weights of the lowest-validation-loss epoch loaded.
    """
    weight_decay = 0

    z_dim = 256
    h_size = (256, 128, 64)
    input_size = [1, 256, 256]

    model = VAE(input_size, h_size, z_dim)
    model = model.to(device)
    # model.apply(utils.weights_init)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
    factor=0.1, patience=patience_plateu, threshold=0.0001, threshold_mode='abs', verbose = True)

    best_val_loss = 99999999
    best_epoch = 0
    best_state = None
    epochs_without_improvement = 0  # drives early stopping
    plateau_counter = 0             # drives reloading of the best weights

    for i in range(epochs):

        train_epoch_loss = train_step(model, trainloader, trainset, optimizer, beta)
        val_epoch_loss, data, rec_vae = validation_step(model, validationloader, valset, beta)
        scheduler.step(val_epoch_loss)

        if i % 5 == 0:
            _show_reconstruction(data[0], rec_vae[0])

        if val_epoch_loss < best_val_loss:
            # state_dict() returns live references to the parameters, which keep
            # changing as training continues; deepcopy freezes this snapshot.
            best_state = deepcopy(model.state_dict())
            best_val_loss = val_epoch_loss
            best_epoch = i
            epochs_without_improvement = 0
            plateau_counter = 0
        else:
            epochs_without_improvement += 1
            plateau_counter += 1
            if epochs_without_improvement > patience_early_stopping: break
        if plateau_counter == patience_plateu and best_state is not None:
            print('Loading best model at epoch: ',best_epoch, 'with validation loss of:' , best_val_loss* 1000)
            # Load in place: replacing `model` with a copy would leave the optimizer
            # (and scheduler) updating the old parameter tensors.
            model.load_state_dict(best_state)
            plateau_counter = 0
        print('LR:',optimizer.state_dict()['param_groups'][0]['lr'])
        print('Best val loss:',best_val_loss * 1000,' Epoch:',best_epoch,' Counter:',epochs_without_improvement )
        print('epoch:{} \t'.format(i+1),'trainloss:{}'.format(train_epoch_loss*1000),'\t','valloss:{}'.format(val_epoch_loss*1000))

    print('Loading best model at epoch ',best_epoch,' with validation loss of :',best_val_loss * 1000)
    if best_state is not None:
        model.load_state_dict(best_state)

    return model

In [None]:
# Training is expensive, so it is disabled by default. Set TRAIN_VAE_PAPERS = True to
# retrain the VAE and save the returned model's weights.
TRAIN_VAE_PAPERS = False

if TRAIN_VAE_PAPERS:
    model_vae_papers = training_vae_papers(trainloader=trainloader, trainset=trainset,
                                           validationloader=validationloader, valset=valset,
                                           epochs=100, beta=1, patience_early_stopping=20,
                                           patience_plateu=10, learning_rate=0.0001)

    model_path = Path('/kaggle/working/Model')
    model_path.mkdir(parents=True, exist_ok=True)

    # The previous version saved `best_model`, which exists only inside
    # training_vae_papers (NameError here); persist the model it returns instead.
    torch.save(model_vae_papers.state_dict(), model_path / 'best_model_CEVAE_standard')

In [None]:
# Test-time preprocessing: resize every slice to 192x192, then convert it to a tensor.
transform = transforms.Compose([transforms.Resize([192, 192]), transforms.ToTensor()])

test_root_dir = '/kaggle/input/dataset-progetto-advanced-test-set/Dataset_progetto_advanced_TEST_SET'
test_csv = create_csv(test_root_dir)

brainTest = brain_dataset(csv_file=test_csv, root_dir=test_root_dir, transform=transform)

# One slice per batch.
testloader = DataLoader(brainTest, batch_size=1)


In [None]:
from skimage import filters
from skimage import morphology
from scipy import ndimage

input_dim = (192,192) # Slice shapes
input_size = (1,192,192)
z_dim = 512
model_feature_map_sizes=(16, 64, 256, 512) # Compact vae

# Tell the model constructor it works with 2D images (2D conv / transposed conv ops).
conv = nn.Conv2d
convt = nn.ConvTranspose2d

model = AE(input_size=input_size, z_dim=z_dim, fmap_sizes=model_feature_map_sizes,
           conv_op=conv,
           tconv_op=convt,
           activation_op=torch.nn.PReLU)

# NOTE(review): `d` is not defined in this cell; it must come from an earlier cell - confirm.
model.d = d
# The model is not moved to `device` here, so load the checkpoint onto the CPU.
# This also works when the checkpoint was saved from a GPU and no GPU is available now.
model.load_state_dict(torch.load('/kaggle/input/vaep2-short-epoch-199/VAEP2_short_Epoch_199.pt',
                                 map_location='cpu'))
model.eval()


def _to_uint8_image(arr):
    """Squeeze ``arr`` and convert values in [0, 1] to a uint8 PIL image (out-of-range values are clipped)."""
    return Image.fromarray(np.uint8(np.clip(np.squeeze(arr), 0.0, 1.0) * 255))


for i, data in enumerate(testloader):
    img = data['image']

    # Inference only: no_grad avoids building an autograd graph for every test slice.
    with torch.no_grad(), autocast():
        x_r = model(img)

    # Clamp negative values of the first reconstruction (batch size is 1) to 0.
    x_r = x_r.float()
    x_r[0][0][x_r[0][0] < 0] = 0

    img_np = img.detach().cpu().numpy()
    x_r_np = x_r.detach().cpu().numpy()

    # Binary masks of the pixels that are positive in the input / the reconstruction.
    or_mask = np.copy(img_np[0][0])
    rec_mask = np.copy(x_r_np[0][0])

    or_mask[or_mask > 0] = 1
    rec_mask[rec_mask > 0] = 1

    intersection_mask = or_mask * rec_mask

    # Residual (input - reconstruction), restricted to where both images are non-zero.
    diff_mask = (img_np - x_r_np) * intersection_mask

    display(_to_uint8_image(img_np))
    display(_to_uint8_image(x_r_np))

    # Manual thresholding of the residual.
    m_diff_mask = diff_mask.copy()
    m_diff_mask[m_diff_mask <= 0.25] = 0
    m_diff_mask[m_diff_mask > 0.25] = 1

    display(_to_uint8_image(m_diff_mask))

    # Otsu thresholding.
    # NOTE(review): Otsu runs on the already-binarized m_diff_mask, so it can only
    # reproduce that mask; confirm whether diff_mask was intended here.
    val = filters.threshold_otsu(m_diff_mask)
    thr = m_diff_mask > val

    # Morphological area opening: drop connected components smaller than 20 pixels,
    # then keep only pixels inside the input image (non-zero input).
    final = np.zeros_like(thr)
    for b in range(thr.shape[0]):  # `b`, not `i`: must not clobber the outer loop index
        final[b, 0] = morphology.area_opening(thr[b, 0], area_threshold=20)
    final[img_np == 0] = 0

    display(_to_uint8_image(final))