In [2]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# Metadata e inizializzazione datasets

In [None]:
torch.manual_seed(0)

batch_size = 128
temperature = 1.0
seed = 0
log_interval = 100
log_interval_writer = 100
hard = False
latent_dim = 20
categorical_dim = 2
temp_min = 0.5
ANNEAL_RATE = 0.00003

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Utilizzo Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Utilizzo NVIDIA GPU (CUDA)")
else:
    device = torch.device("cpu")
    print("Utilizzo la CPU")


torch.manual_seed(seed)
if device.type == "cuda": 
    torch.cuda.manual_seed(seed)
elif device.type == "mps": 
    torch.mps.manual_seed(seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == "cuda" or device.type == "mps" else {} # pin_memory può essere utile anche per MPS


Utilizzo Apple Silicon GPU (MPS)


## MNIST

In [None]:

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data/MNIST',
        train=True,
        download=True,
        transform=transforms.ToTensor()
        ),
    batch_size=batch_size,
    shuffle=True,
    **kwargs
    )

val_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data/MNIST',
        train=False,
        download=True,
        transform=transforms.ToTensor()
        ),
    batch_size=batch_size,
    shuffle=False,
    **kwargs
    )


# Gumbel-softmax

In [4]:

def sample_gumbel(shape, eps=1e-20):
    # sample from a uniform distribution
    U = torch.rand(shape)
    return -torch.log(-torch.log(U.to(device) + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    y = gumbel_softmax_sample(logits, temperature)
    
    if not hard:
        return y.view(-1, latent_dim * categorical_dim)
    
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # skip the gradient of y_hard
    y_hard = (y_hard - y).detach() + y 
    return y_hard.view(-1, latent_dim * categorical_dim)



## Gumbel-softmax alternativo

In [None]:
def Gumbel_softmax(logits, tau, hard=False):
    U = torch.rand_like(logits)
    G = -torch.log(-torch.log(U + 1e-20) + 1e-20)
    y = F.softmax((logits + G) / tau, dim=-1)

    if hard:
        y_hard = torch.zeros_like(y)
        y_hard.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.0)
        y = (y_hard - y).detach() + y  # straight-through estimator

    return y


# Class VAE

In [None]:
class VAE_model(nn.Module):
    def __init__(self):
        super(VAE_model, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def sample_img(self, img, temp, random=True):
        with torch.no_grad():
            logits_z = self.encode(img.view(-1, 784))
            logits_z = logits_z.view(-1, latent_dim, categorical_dim)
            if random:
                latent_z = gumbel_softmax(logits_z, temp, True)
            else:
                latent_z = logits_z.view(-1, latent_dim * categorical_dim)
            logits_x = self.decode(latent_z)
            dist_x = torch.distributions.Bernoulli(probs=logits_x)
            sampled_img = dist_x.sample()
        return sampled_img

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, data, temp, hard):
        logits_z = self.encode(data.view(-1, 784))
        logits_z = logits_z.view(-1, latent_dim, categorical_dim)

        probs_z = F.softmax(logits_z, dim=-1)
        posterior_distrib = torch.distributions.Categorical(probs=probs_z)
        probs_prior = torch.ones_like(logits_z)/categorical_dim
        prior_distrib = torch.distributions.Categorical(probs=probs_prior)

        latent_z = gumbel_softmax(logits_z, temp)
        latent_z = latent_z.view(-1, latent_dim * categorical_dim)

        probs_x = self.decode(latent_z)
        dist_x = torch.distributions.Bernoulli(probs=probs_x, validate_args=False)

        rec_loss = dist_x.log_prob(data.view(-1, 784)).sum(dim=-1)
        logits_z_log = F.log_softmax(logits_z, dim=-1)

        KL = (posterior_distrib.probs * (logits_z_log - prior_distrib.probs.log())).view(-1, latent_dim * categorical_dim).sum(dim=-1)
        elbo = rec_loss - KL
        loss = -elbo.mean()
        return loss, KL.mean(), rec_loss.mean()


# Train function

In [None]:
def train(model, optimizer, epochs):
    global_batch_idx = 0
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        temp = temperature
        train_KL = 0

        for batch_idx, (data, _) in enumerate(train_loader):
            global_batch_idx += 1
            # Sposta i dati sul device corretto
            data = data.to(device)
            optimizer.zero_grad()
            loss, KL, rec_loss = model(data, temp, hard)
            loss.backward()
            train_loss += loss.item() * len(data)
            optimizer.step()

            if batch_idx % 100 == 1:
                temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

            if global_batch_idx % log_interval_writer == 0:
                writer.add_scalar('KL/Train', KL, global_step=global_batch_idx)
                writer.add_scalar('rec_loss/Train', rec_loss, global_step=global_batch_idx)


        writer.add_scalar('Loss/Train', train_loss/len(train_loader.dataset), global_step=epoch)


        print('Epoch: {}/{}, Average loss: {:.4f}, Average KL: '.format(
            epoch, epochs, train_loss / len(train_loader.dataset)))
        
        # Validation
        
        model.eval()
        val_loss_sum = 0.0
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(val_loader):
                data = data.to(device)
                loss, KL, rec_loss = model(data, temp, hard=True)
                val_loss_sum += loss.item() * len(data)

        writer.add_scalar('Loss/Validation', val_loss_sum/len(val_loader.dataset), global_step=epoch)

        # Log histogram of weights and gradients
        for name, param in model.named_parameters():
            writer.add_histogram(f'Weights/{name}', param, global_step=epoch)
            if param.grad is not None:
                writer.add_histogram(f'Grads/{name}', param.grad, global_step=epoch)

    writer.close()
    print("Training completato e dati scritti su tensorboard")


## Train loop

In [None]:
writer = SummaryWriter(log_dir='runs/discrete_VAE_Categorical/_1')

In [None]:

my_model = VAE_model().to(device)
optimizer = optim.Adam(my_model.parameters(), lr=1e-3)
train(my_model, optimizer, epochs=15)



====> Epoch: 0 Average loss: 195.8797
====> Epoch: 1 Average loss: 161.6164
====> Epoch: 2 Average loss: 145.2841
====> Epoch: 3 Average loss: 137.5701
====> Epoch: 4 Average loss: 133.0866
====> Epoch: 5 Average loss: 129.0748
====> Epoch: 6 Average loss: 126.0739
====> Epoch: 7 Average loss: 123.8237
====> Epoch: 8 Average loss: 121.8725
====> Epoch: 9 Average loss: 120.0578
====> Epoch: 10 Average loss: 118.0473
====> Epoch: 11 Average loss: 116.4868
====> Epoch: 12 Average loss: 115.2364
====> Epoch: 13 Average loss: 114.0681
====> Epoch: 14 Average loss: 112.9505
Training completato e dati scritti su tensorboard
