In [None]:
import copy
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from torch.func import jacrev, vmap
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Load the experiment configuration (path is relative to the notebook's directory).
with open("../config.yaml", 'r') as f:
    config = yaml.safe_load(f)

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = config['dataset']['batch_size']
# Tensor conversion -> single channel -> (x - 0) / 0.3 rescaling.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Grayscale(),
                                transforms.Normalize((0,), (0.3,))])

# `data` is the name of the torchvision dataset class to load. It used to be
# read here without ever being defined in the notebook (NameError on a fresh
# kernel), so it is now taken from the configuration.
# NOTE(review): the 'MNIST' fallback is an assumption -- confirm the intended
# default, or set config['model']['dataset'] explicitly in config.yaml.
data = config['model'].get('dataset', 'MNIST')
config['model']["dataset"] = data

dataset_cls = getattr(datasets, data)
train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Number of scalar features in one (transformed) sample.
input_feat = train_dataset[0][0].numel()

criterium = nn.L1Loss(reduction="mean")

In [None]:
def plot(loss_train: list,
         loss_val: list,
         model,
         ds,
         config: dict,
         transform,
         pieces_of_loss_train: dict = None,
         pieces_of_loss_val: dict = None):
    """Plot the loss curves and one reconstruction per class.

    Args:
        loss_train: per-epoch training loss.
        loss_val: per-epoch validation loss.
        model: trained model; must expose ``hidden_dim`` and, for model types
            2 and 3, ``device``.
        ds: dataset exposing ``data`` and ``targets`` (``targets`` is either a
            list or a tensor).
        config: parsed configuration; reads ``plot.n_images``,
            ``plot.show_pieces``, ``plot.show_rec``, ``plot.show_training``,
            ``model.dataset``, ``model.model`` and ``paths.images``.
        transform: callable applied to a raw sample before it is fed to the model.
        pieces_of_loss_train: optional dict of per-epoch loss components, with
            an ``'epoch'`` key used as x axis. Skipped when ``None``.
        pieces_of_loss_val: same as above, for the validation set.

    Side effects:
        Writes ``piece_train_*.html`` / ``piece_val_*.html`` (when the
        corresponding dict is given) and ``rec_*.png`` under
        ``<paths.images>/<model_type>/``; the directory is created if missing.
    """
    labels = ds.targets
    labels_are_list = isinstance(labels, list)
    if labels_are_list:
        targets = torch.tensor(labels).unique()
    else:
        targets = labels.unique()
    n_images = min(config['plot']['n_images'], len(targets))
    targets = targets[:n_images]
    dataset = config['model']['dataset']
    show_pieces = config['plot']['show_pieces']
    show_rec = config['plot']['show_rec']
    show_training = config['plot']['show_training']
    model_kind = config['model']['model']
    model_type = {
        1: "auto",
        2: "vae",
        3: "vae_recurrent"
    }[model_kind]

    # Every artefact of this call goes in the same folder; make sure it exists.
    out_dir = os.path.join(config['paths']['images'], model_type)
    os.makedirs(out_dir, exist_ok=True)

    # --- overall loss curves -------------------------------------------------
    curves = {'epoch': range(1, len(loss_train) + 1),
              'training': loss_train,
              'validation': loss_val}
    fig = px.line(curves,
                  x='epoch',
                  y=['training', 'validation'],
                  title="Loss of the training",
                  width=700,
                  height=600)
    if show_training:
        fig.show()

    # --- loss components (optional) -----------------------------------------
    piece_plots = (
        (pieces_of_loss_train, "pieces of the training loss", f'piece_train_{dataset}.html'),
        (pieces_of_loss_val, "pieces of the validation loss", f'piece_val_{dataset}.html'),
    )
    for pieces, plot_title, file_name in piece_plots:
        if pieces is None:
            # The parameter defaults to None: nothing to draw for this split.
            continue
        series = [key for key in pieces.keys() if key != 'epoch']
        fig = px.line(pieces,
                      x='epoch',
                      y=series,
                      title=plot_title,
                      width=800,
                      height=700)
        fig.write_html(os.path.join(out_dir, file_name))
        if show_pieces:
            fig.show()

    # --- reconstruction part -------------------------------------------------
    # squeeze=False keeps `axes` two-dimensional even when n_images == 1.
    fig, axes = plt.subplots(nrows=n_images,
                             ncols=2,
                             figsize=(6, n_images * 3),
                             constrained_layout=True,
                             squeeze=False)
    title = {
        1: "Autoencoder",
        2: "Variational autoencoder",
        3: "Vae Recurrent"
    }[model_kind]

    fig.suptitle(f"{title} {model.hidden_dim}D for {dataset}")
    with torch.no_grad():
        for i, target in enumerate(targets):
            if labels_are_list:
                # First sample whose label equals this class (the label value
                # itself, not the loop index).
                sample = ds.data[labels.index(target.item())]
            else:
                sample = ds.data[labels == target][0].numpy()

            # Keep the first channel only, then restore a leading dimension.
            sample = transform(sample)[0]
            if len(sample.shape) == 2:
                sample = sample.unsqueeze(0)

            if model_kind in [2, 3]:
                recon, _, _, _, _ = model(sample.float().to(model.device))
            else:
                # NOTE: model.cpu() moves the model to the CPU in place.
                recon = model.cpu()(sample.float().flatten(1))
            axes[i, 0].imshow(sample[0], cmap='gray')
            axes[i, 1].imshow(recon.detach().cpu().numpy().reshape(sample.shape[1:]), cmap='gray')

            axes[i, 0].set_title("real")
            axes[i, 1].set_title("reconstructed")
    fig.savefig(os.path.join(out_dir, f'rec_{model.hidden_dim}D_{dataset}.png'))
    if show_rec:
        plt.show()
    else:
        plt.close(fig)


# Model

In [None]:
class dist_fun_rec(nn.Module):
    """Scalar-valued MLP (Linear -> Sigmoid -> Linear) used as a learnable
    conditional distribution function ``F`` or as its inverse ``F^{-1}``.

    Input layout: column 0 is the variable of interest (``x`` for ``F``, the
    uniform sample ``u`` for ``F^{-1}``); the remaining ``input_feat - 1``
    columns are the conditioning values.

    Args:
        inverse: ``True`` when the network models ``F^{-1}``, ``False`` for ``F``.
        input_feat: number of input features (variable + conditioning values).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self,
                 inverse: bool,
                 input_feat: int,
                 hidden_dim: int = 16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.inverse = inverse
        self.input_feat = input_feat

        self.fc1 = nn.Sequential(nn.Linear(input_feat, hidden_dim),
                                 nn.Sigmoid(),
                                 nn.Linear(hidden_dim, 1))

    def forward(self, x):
        """Evaluate the network on ``x`` (last dimension = ``input_feat``)."""
        return self.fc1(x)

    def derivative(self, x):
        """Per-sample Jacobian of ``forward``.

        For ``x`` of shape ``(N, input_feat)`` the result has shape
        ``(N, 1, input_feat)``.
        """
        return vmap(jacrev(self.forward))(x)

    def functional_loss(self, conditional, var=None):
        """Penalty pushing the network towards a valid (inverse) distribution function.

        Args:
            conditional: conditioning values, shape ``(batch, input_feat - 1)``
                (may have zero columns).
            var: unused; kept for interface compatibility.

        Returns:
            Scalar tensor.
            * ``inverse=True``: ``|mean| + |variance - 1|`` of the network
              output over 500 uniform samples per batch row.
            * ``inverse=False``: boundary terms ``|F(-40)|`` and ``|F(40) - 1|``,
              plus a penalty on negative derivative values, plus
              ``|integral of dF/dx over [-30, 30] - 1|``.
        """
        batch = conditional.shape[0]
        # Keep the full device (including the CUDA index), not just its type.
        device = conditional.device
        n_samples = 500
        # Same conditioning values repeated for every sample: B x 500 x c
        cond = conditional.unsqueeze(1).expand(-1, n_samples, -1)

        if self.inverse:
            # The inverse function must produce samples with zero mean and
            # unit variance. Draw 500 random values in [0, 1] per batch row.
            u = torch.rand(batch, n_samples, 1).float().to(device)
            samples = self.forward(torch.cat((u, cond), -1))      # B x 500 x 1

            # We want mean = 0 and variance = 1.
            mean = torch.mean(samples)
            variance = torch.mean(samples ** 2) - mean ** 2

            zero = torch.abs(mean)
            one = torch.abs(variance - 1)
            return zero + one

        #### Distribution-function properties
        # 1) lim_{x --> -infty} F(x) = 0
        # 2) lim_{x --> +infty} F(x) = 1
        # The evaluation point goes in column 0, matching the layout used by
        # the derivative below and by the sampling code.
        x = torch.full((batch, 1), -40.0, device=device)
        lower = self.forward(torch.cat((x, conditional), -1))
        upper = self.forward(torch.cat((-x, conditional), -1))

        zero = torch.mean(torch.abs(lower))
        one = torch.mean(torch.abs(upper - 1))

        # 3) F is increasing ==> check on a bounded domain [a, b]
        a = -30
        b = 30
        domain = torch.rand(batch, n_samples, 1, device=device) * (b - a) + a
        input_pos = torch.cat((domain, cond), -1)                  # B x 500 x (c+1)
        # One vmap call over all B*500 points; [:, 0, 0] is dF/dx (column 0).
        flat = input_pos.reshape(batch * n_samples, input_pos.shape[-1])
        density = self.derivative(flat)[:, 0, 0].view(batch, n_samples)
        positivity = torch.sum(F.relu(-density))

        # The derivative is a density, so its integral must be 1. With uniform
        # samples on [a, b] the Monte Carlo estimate is (b - a) * mean(density).
        prob = density.mean(-1) * (b - a)
        normality = torch.mean(torch.abs(prob - 1))
        return zero + one + positivity + normality
        


class VAE_recurrent(nn.Module):
    """Autoencoder whose latent noise is produced by learned, sequentially
    conditioned distribution functions.

    For every latent dimension ``i`` the model owns a pair of networks:
    ``F_inv[i]`` maps ``(u_i, x_{i-1}, ..., x_0)`` to a noise sample ``x_i``,
    and ``F[i]`` maps ``(x_i, x_{i-1}, ..., x_0)`` back to an estimate of ``u_i``.

    Args:
        input_feat: number of features of a flattened input sample.
        criterium: loss module used for all reconstruction terms.
        device: device on which the model is expected to live.
        hidden_dim: size of the latent space.
    """

    def __init__(self,
                 input_feat: int,
                 criterium,
                 device,
                 hidden_dim: int):
        super().__init__()

        self.criterium = criterium
        self.hidden_dim = hidden_dim
        self.device = device
        self.input_feat = input_feat
        # Not referenced elsewhere in this class.
        self.upper_bound_var = torch.tensor([5.]*hidden_dim, device = device, requires_grad = True).float()

        # Encoder layers
        self.fc1 = nn.Sequential(nn.Flatten(1),
                                 nn.Linear(input_feat, 512),
                                 nn.Tanh(),
                                 nn.Linear(512, 256))

        self.fc_mu = nn.Sequential(nn.Linear(256, 128),
                                   nn.Tanh(),
                                   nn.Linear(128, hidden_dim))

        self.fc_logvar = nn.Sequential(nn.Linear(256, 128),
                                       nn.Tanh(),
                                       nn.Linear(128, hidden_dim))

        # Decoder layers
        self.fc2 = nn.Sequential(nn.Tanh(),
                                 nn.Linear(hidden_dim, 128),
                                 nn.Tanh(),
                                 nn.Linear(128, 256),
                                 nn.Tanh(),
                                 nn.Linear(256, 512),
                                 nn.Tanh(),
                                 nn.Linear(512, input_feat))

        #### F^{-1}(u) ####
        self.F_inv = nn.ModuleList([dist_fun_rec(inverse = True, input_feat = i+1) for i in range(hidden_dim)])

        #### F(F^{-1}(u)) ####
        self.F = nn.ModuleList([dist_fun_rec(inverse = False, input_feat = i+1) for i in range(hidden_dim)])

    def encode(self, x):
        """Return ``(mu, var)``; ``var = exp(log_var)`` with ``log_var`` clamped to [-4, 4]."""
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        log_var = torch.clamp(self.fc_logvar(h), min=-4, max=4)
        return mu, torch.exp(log_var)

    def decode(self, z):
        """Map a latent code back to input space."""
        return self.fc2(z)

    def reparameterize(self, mu, var):
        """Sample the latent noise one dimension at a time.

        Returns:
            z: ``mu + var * eps``.
            u: the uniform samples, column ``i`` holds ``u_i``.
            u_hat: ``F[i]`` applied to the generated noise, column ``i`` estimates ``u_i``.
            eps: generated noise. New samples are prepended, so the columns
                are ordered ``[x_{h-1}, ..., x_1, x_0]``.
        """
        b = mu.shape[0]
        # Start from an explicit (b, 0) tensor instead of relying on the legacy
        # behaviour of concatenating with an empty 1-D tensor.
        eps = mu.new_zeros((b, 0))
        uniform_inputs = []
        reconstructions = []
        for i in range(self.hidden_dim):
            u = torch.rand(b, 1).float().to(mu.device)
            uniform_inputs.append(u)
            x = self.F_inv[i](torch.cat((u, eps), -1))
            eps = torch.cat((x, eps), -1)
            reconstructions.append(self.F[i](eps))

        # all the inputs drawn from the uniform distribution
        u = torch.cat(uniform_inputs, -1)

        # all the reconstructions produced by the network
        u_hat = torch.cat(reconstructions, -1)
        ### Perturbing the embedding
        z = mu + var*eps
        return z, u, u_hat, eps

    def forward(self, x):
        """Return ``(x_reconstructed, u, eps, u_hat, var)`` for a batch ``x``."""
        mu, var = self.encode(x.reshape(-1, self.input_feat))
        z, u, u_hat, eps = self.reparameterize(mu, var)
        x_reconstructed = self.decode(z)

        return x_reconstructed, u, eps, u_hat, var

    def loss_density(self, u, eps, var):
        """Regularisation terms on the learned distribution functions.

        Args:
            u: uniform samples returned by ``forward`` (column ``i`` = ``u_i``).
            eps: noise returned by ``forward`` (columns ``[x_{h-1}, ..., x_0]``).
            var: encoder variance returned by ``forward``.

        Returns:
            ``(loss, (loss_derivative, loss_density_F, loss_density_F_inv, kl))``
            where the tuple holds Python floats.
        """
        loss_density_F = 0
        loss_density_F_inv = 0
        loss_derivative = 0
        h = self.hidden_dim

        ### Kullback-Leibler divergence
        kl = torch.tensor([0]).float().to(self.device)

        for i in range(h):
            # `reparameterize` prepends every new sample, so the values the
            # i-th pair was conditioned on sit at the right end of `eps`.
            cond = eps[:, h - i:]           # x_{i-1}, ..., x_0  (empty for i == 0)
            current = eps[:, h - i - 1:]    # x_i, x_{i-1}, ..., x_0

            # Derivative of F^{-1} w.r.t. u_i: taken on F_inv, not on F.
            dFinv_du = self.F_inv[i].derivative(torch.cat((u[:, i:i+1], cond), -1))[:, 0, 0]
            dF_dy = self.F[i].derivative(current)[:, 0, 0]

            loss_derivative += self.criterium(dFinv_du, dF_dy)
            loss_density_F += self.F[i].functional_loss(cond)
            loss_density_F_inv += self.F_inv[i].functional_loss(cond)

            # NOTE(review): the two log terms are identical, so this always
            # adds 0; the intended KL expression could not be determined from
            # this file and is left unchanged -- please confirm the formula.
            positive = dF_dy > 0
            if positive.any():
                kl += torch.mean(torch.log(var[:, i][positive]) - torch.log(var[:, i][positive]))

        l = loss_density_F + loss_density_F_inv + loss_derivative + kl
        return l, (loss_derivative.item(), loss_density_F.item(), loss_density_F_inv.item(), kl.item())

    def loss_functional(self, img, img_rec, u, eps, u_hat, var):
        """Reconstruction loss.

        Returns:
            ``(loss, (reconstruction1, reconstruction2))`` with
            ``loss = reconstruction1 + 500 * reconstruction2 + 1 / mean(prod(var, 1))``.
        """
        ### reconstruction loss for distribution
        reconstruction1 = self.criterium(u, u_hat)
        ### reconstruction loss for image
        reconstruction2 = self.criterium(img, img_rec)
        ### Guard against the variance collapsing to zero
        var_emb = torch.mean(torch.prod(var, 1))

        l = reconstruction1 + 500*reconstruction2 + 1/var_emb

        return l, (reconstruction1.item(), reconstruction2.item())




In [None]:
def step(model,
         dataloader,
         optimizer,
         pieces_of_loss: dict,
         training: bool = False):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: model exposing ``device``, ``loss_density`` and
            ``loss_functional``; calling it must return
            ``(x_reconstructed, u, eps, u_hat, var)``.
        dataloader: iterable of ``(data, label)`` batches; labels are ignored.
        optimizer: optimizer updating ``model``; only used when ``training``.
        pieces_of_loss: dict of lists whose last element is incremented in
            place with this pass's mean loss components. Keys are matched to
            components by position (derivative, density_F, density_F_inv, kl,
            reconstruction1, reconstruction2); the key ``'std_emb'`` receives
            the mean of ``prod(var, 1)`` instead.
        training: when ``True`` the model is put in train mode and updated;
            otherwise it is put in eval mode and no autograd graph is built.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    loss_epoch = 0.0
    len_load = len(dataloader)
    model.train(training)

    # Passing the loader itself (not iter(loader)) lets tqdm show the total.
    for data, _ in tqdm(dataloader):
        batch = data.to(model.device).float()

        if training:
            # reset the accumulated gradients
            optimizer.zero_grad()

        # Track gradients only when they will be used by backward().
        with torch.set_grad_enabled(training):
            # forward step
            x_reconstructed, u, eps, u_hat, var = model(batch)

            # computing the loss
            l1, dens = model.loss_density(u, eps, var)
            l2, func = model.loss_functional(batch.flatten(1),
                                             x_reconstructed, u, eps, u_hat, var)
            loss = l1 + l2

        pieces = dens + func
        if torch.any(torch.isnan(loss)).item():
            # Show which component produced the NaN.
            print(dens)
            print(func)

        # Backward and optimize
        if training:
            loss.backward()
            optimizer.step()

        loss_epoch += loss.item()
        for i, key in enumerate(pieces_of_loss.keys()):
            if 'std_emb' == key:
                pieces_of_loss[key][-1] += torch.mean(torch.prod(var, 1)).item()/len_load
            else:
                pieces_of_loss[key][-1] += pieces[i]/len_load
    return loss_epoch/len_load

def train_recurrent(model,
             train_loader,
             val_loader,
             num_epochs,
             optimizer):
    """Train ``model`` and return it with the best-validation weights loaded.

    Args:
        model: model to train (updated in place).
        train_loader: batches used for the optimisation pass of each epoch.
        val_loader: batches used for the evaluation pass of each epoch.
        num_epochs: number of epochs.
        optimizer: optimizer updating ``model``.

    Returns:
        ``(model, loss_train, loss_val, pieces_of_loss_train, pieces_of_loss_val)``.
        ``model`` is the object passed in, restored to the weights of the epoch
        with the lowest validation loss. The two ``pieces`` dicts hold one
        value per epoch for each loss component plus an ``'epoch'`` key.
    """
    loss_train = []
    loss_val = []
    best_val = float('inf')
    # Snapshot of the best weights. Keeping a reference to `model` itself
    # would only alias the object that keeps being trained.
    best_state = None

    piece_names = ('loss_derivative', 'loss_density_F', 'loss_density_F_inv', 'kl_loss',
                   'reconstruction1', 'reconstruction2', 'std_emb')
    pieces_of_loss_train = {name: [] for name in piece_names}
    pieces_of_loss_val = {name: [] for name in piece_names}
    model.train()

    for epoch in range(num_epochs):
        # Open a fresh accumulator slot for this epoch.
        for name in piece_names:
            pieces_of_loss_train[name].append(0)
            pieces_of_loss_val[name].append(0)

        loss_train.append(step(model, train_loader, optimizer, pieces_of_loss_train, True))
        loss_val.append(step(model, val_loader, optimizer, pieces_of_loss_val))

        if (epoch+1)%5==0:
            print(f"loss training at the {epoch+1}-th = {loss_train[-1]}")
            print(f"loss validation at the {epoch+1}-th = {loss_val[-1]}")

        if loss_val[-1] < best_val:
            best_val = loss_val[-1]
            best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)

    pieces_of_loss_train['epoch'] = list(range(1, num_epochs+1))
    pieces_of_loss_val['epoch'] = list(range(1, num_epochs+1))
    return model, loss_train, loss_val, pieces_of_loss_train, pieces_of_loss_val
