In [151]:
import numpy as np
import pandas as pd
import math

from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.utils.data import DataLoader, Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.auto import tqdm

In [152]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
    
print(f'Actual device: {device}')

Actual device: mps


In [153]:
X = load_iris()['data']
X
min_max = MinMaxScaler()
#std_scl = StandardScaler()
X = min_max.fit_transform(X)
#X = std_scl.fit_transform(X)
X[:10]

array([[0.22222222, 0.625     , 0.06779661, 0.04166667],
       [0.16666667, 0.41666667, 0.06779661, 0.04166667],
       [0.11111111, 0.5       , 0.05084746, 0.04166667],
       [0.08333333, 0.45833333, 0.08474576, 0.04166667],
       [0.19444444, 0.66666667, 0.06779661, 0.04166667],
       [0.30555556, 0.79166667, 0.11864407, 0.125     ],
       [0.08333333, 0.58333333, 0.06779661, 0.08333333],
       [0.19444444, 0.58333333, 0.08474576, 0.04166667],
       [0.02777778, 0.375     , 0.06779661, 0.04166667],
       [0.16666667, 0.45833333, 0.08474576, 0.        ]])

In [154]:
class IrisDataset(Dataset):
    def __init__(self, data):
        self.data = torch.tensor(data, dtype = torch.float32)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = IrisDataset(X)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

for batch in dataloader:
    print(f'{batch.shape}\n{batch}')
    break

torch.Size([32, 4])
tensor([[0.4167, 0.8333, 0.0339, 0.0417],
        [0.3056, 0.7083, 0.0847, 0.0417],
        [0.3611, 0.3750, 0.4407, 0.5000],
        [0.3056, 0.5833, 0.0847, 0.1250],
        [0.5000, 0.3333, 0.6271, 0.4583],
        [0.1667, 0.2083, 0.5932, 0.6667],
        [0.2222, 0.2083, 0.3390, 0.4167],
        [0.0000, 0.4167, 0.0169, 0.0000],
        [0.4722, 0.0833, 0.6780, 0.5833],
        [0.6111, 0.4167, 0.7627, 0.7083],
        [0.1944, 0.1250, 0.3898, 0.3750],
        [0.4167, 0.2917, 0.4915, 0.4583],
        [0.7778, 0.4167, 0.8305, 0.8333],
        [0.1944, 0.5833, 0.0847, 0.0417],
        [0.5278, 0.3333, 0.6441, 0.7083],
        [0.5833, 0.3750, 0.5593, 0.5000],
        [0.6667, 0.4167, 0.6780, 0.6667],
        [0.9167, 0.4167, 0.9492, 0.8333],
        [0.4167, 0.3333, 0.6949, 0.9583],
        [0.2500, 0.2917, 0.4915, 0.5417],
        [0.5556, 0.2083, 0.6610, 0.5833],
        [0.3056, 0.7917, 0.1186, 0.1250],
        [0.5833, 0.5000, 0.7288, 0.9167],
        [0.027

In [155]:
for i in dataloader:
    batch = i
    break

In [165]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, 
                 input_dim = 4, 
                 layers = 5,
                 dropout_rate = 0.5,
                 latent_space_dim = 2):
        super(VariationalAutoEncoder, self).__init__()
        
        max_neurons = 2 ** layers

        encoder_layers_list = [
            nn.Linear(input_dim, max_neurons),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate)
        ]
        
        current_dim = max_neurons
        while current_dim > latent_space_dim:
            next_dim = current_dim // 2
            encoder_layers_list.extend(nn.Sequential(
                nn.Linear(current_dim, next_dim),
                nn.LeakyReLU(),
                nn.Dropout(dropout_rate)
            ))
            current_dim = next_dim
        
        encoder_layers_list.append(nn.Linear(current_dim, latent_space_dim))

        self.Encoder = nn.Sequential(*encoder_layers_list)
        
        self.fc_mu = nn.Linear(next_dim, latent_space_dim)
        self.fc_logvar = nn.Linear(next_dim, latent_space_dim)
        
        decoder_layers_list = []
        current_dim = latent_space_dim
        while current_dim < max_neurons:
            next_dim = current_dim * 2
            decoder_layers_list.extend([
                nn.Linear(current_dim, next_dim),
                nn.LeakyReLU(),
                nn.Dropout(dropout_rate)
            ])
            current_dim = next_dim
        
        # L'ultimo strato del decoder per ricostruire l'input
        decoder_layers_list.append(nn.Linear(current_dim, input_dim))
        decoder_layers_list.append(nn.Tanh())
        
        self.Decoder = nn.Sequential(*decoder_layers_list)
        
        
    def reparameterize(self, mu, logvar):
        """
        Implementa il trick della rielaborazione per campionare dallo spazio latente.
        Args:
            mu (torch.Tensor): Media della distribuzione latente (dimensione: batch_size x latent_dim).
            logvar (torch.Tensor): Log-varianza della distribuzione latente (dimensione: batch_size x latent_dim).
        Returns:
            z (torch.Tensor): Campione dallo spazio latente (dimensione: batch_size x latent_dim).
        """
        std = torch.exp(0.5 * logvar)  # Calcola lo scarto quadratico medio
        epsilon = torch.randn_like(std)  # Campiona da una distribuzione normale standard
        z = mu + epsilon * std  # Applica il trick della rielaborazione
        return z
        
    def forward(self, x):
        x = self.Encoder(x)
        mu, log_var = self.fc_mu(x), self.fc_logvar(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.Decoder(z)
        
        return mu, log_var, reconstruction
        
        
vae = VariationalAutoEncoder(latent_space_dim = 4)
vae
vae(batch)

(tensor([[ 0.3996, -0.0539, -0.0260,  0.0943],
         [ 0.1689, -0.1291,  0.0132, -0.0819],
         [ 0.2273, -0.1118, -0.0586, -0.0346],
         [ 0.2408, -0.1092, -0.0404, -0.0197],
         [ 0.1696, -0.1284,  0.0143, -0.0814],
         [ 0.3756, -0.0614, -0.1363,  0.0705],
         [ 0.1691, -0.1289,  0.0136, -0.0817],
         [ 0.2269, -0.1095, -0.1506, -0.0449],
         [ 0.2839, -0.1001,  0.1769,  0.0354],
         [ 0.2219, -0.1116, -0.0557, -0.0443],
         [ 0.1691, -0.1289,  0.0135, -0.0818],
         [ 0.2342, -0.1076, -0.1754, -0.0403],
         [ 0.2572, -0.1029, -0.1068, -0.0115],
         [ 0.2083, -0.1160,  0.0398, -0.0503],
         [ 0.2621, -0.0988, -0.2171, -0.0188],
         [ 0.2627, -0.1018, -0.0945, -0.0049],
         [ 0.2483, -0.1121,  0.2666,  0.0118],
         [ 0.1689, -0.1291,  0.0132, -0.0819],
         [ 0.3482, -0.0709,  0.1320,  0.0620],
         [ 0.2410, -0.1136,  0.2531,  0.0046],
         [ 0.2341, -0.1101, -0.0737, -0.0295],
         [ 0.

In [166]:
lr = 1e-5

def loss_function(recon_x, x, mu, logvar):
    # Perdita di ricostruzione (Errore quadratico medio)
    reconstruction_loss = F.mse_loss(recon_x, x, reduction='sum')
    
    # Perdita di Divergenza KL
    # Divari di distribuzioni (appena appreso vs. normale standard)
    # Formula di KL: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # dove mu e sigma sono la media e la deviazione standard della distribuzione latente.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Funzione di perdita totale
    return reconstruction_loss + kl_divergence

optimizer = optim.AdamW(vae.parameters(), lr = lr)

In [168]:
epochs = 10000

for epoch in tqdm(range(epochs), desc = 'Training'):
    vae.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        
        data = data.float()
        optimizer.zero_grad()
        
        mu, log_var, rec = vae(data)
        loss = loss_function(rec, data, mu, log_var)
        
        loss.backward()
        
        optimizer.step()

        train_loss += loss.item()
        
    avg_train_loss = train_loss / len(dataloader.dataset)
    if epoch % 1000 == 0 and epoch != 0:
        print(f'Epoch [{epoch}/{epochs}], Loss: {avg_train_loss:.4f}')

Training:   0%|          | 0/10000 [00:00<?, ?it/s]

Epoch [1000/10000], Loss: 0.2902
Epoch [2000/10000], Loss: 0.2875
Epoch [3000/10000], Loss: 0.2788
Epoch [4000/10000], Loss: 0.2904
Epoch [5000/10000], Loss: 0.2785
Epoch [6000/10000], Loss: 0.2739
Epoch [7000/10000], Loss: 0.2843
Epoch [8000/10000], Loss: 0.2857
Epoch [9000/10000], Loss: 0.2739


In [129]:
load_iris()['data'][:32]

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4]])

In [131]:
std_scl.inverse_transform(vae(torch.randn(32,4))[2].detach().numpy())

array([[5.849788 , 3.056685 , 3.7667842, 1.2053581],
       [5.836423 , 3.0582294, 3.7513149, 1.1970277],
       [5.8478074, 3.0545444, 3.7679594, 1.209895 ],
       [5.8513784, 3.055823 , 3.7738783, 1.2105346],
       [5.8376656, 3.060751 , 3.742765 , 1.1895307],
       [5.836784 , 3.058817 , 3.747321 , 1.1895347],
       [5.8463125, 3.0536792, 3.7714396, 1.212753 ],
       [5.8386292, 3.0603087, 3.7459495, 1.1889246],
       [5.854311 , 3.0576048, 3.77677  , 1.2157733],
       [5.8467636, 3.0550447, 3.7678192, 1.2060039],
       [5.846452 , 3.0554574, 3.7659245, 1.2081906],
       [5.835979 , 3.0589201, 3.743428 , 1.1902221],
       [5.838807 , 3.058964 , 3.7507088, 1.1942899],
       [5.8481655, 3.0545468, 3.769882 , 1.2113799],
       [5.834735 , 3.058921 , 3.7400725, 1.1881965],
       [5.8339596, 3.0579345, 3.7447019, 1.1913834],
       [5.8468447, 3.0552444, 3.7678506, 1.2075266],
       [5.839301 , 3.0588505, 3.7528656, 1.1946388],
       [5.836112 , 3.0581193, 3.7465496, 1.193