<a href="https://colab.research.google.com/github/MihaiDogariu/CV3/blob/main/colocviu/Colocviu_S4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Subiectul 4 - generarea cu VAE
Modificați codul de mai jos astfel încât, la final, să se afișeze pe ecran trecerea de la cifra 2 la cifra 4 în 20 de pași. Imaginile vor fi afișate pe 2 rânduri a câte 10 coloane.

In [None]:
import sys

%matplotlib inline
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters shared by the model definition and the training loop below.
latent_dims = 10  # size of the latent vector (output width of fc_mu / fc_logvar)
num_epochs = 10  # number of passes over the training set
batch_size = 128  # samples per batch for both the train and test DataLoader
capacity = 64  # base number of convolution channels (c); the deeper conv layer uses 2*c
learning_rate = 1e-3  # learning rate passed to the Adam optimizer
variational_beta = 1  # weight of the KL term relative to the reconstruction term in vae_loss

In [None]:
# Descarcarea bazei de date MNIST Digits
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# Every sample goes through ToTensor so the model receives tensors rather than PIL images.
img_transform = transforms.Compose([transforms.ToTensor()])

# Training split: downloaded on first use into ./data/MNIST, reshuffled every epoch.
train_dataset = MNIST(
    root='./data/MNIST',
    train=True,
    download=True,
    transform=img_transform,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: same storage location and transform; also shuffled, so the
# samples met first differ from run to run.
test_dataset = MNIST(
    root='./data/MNIST',
    train=False,
    download=True,
    transform=img_transform,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Definirea modelului VAE

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of q(z|x).

    The two stride-2 convolutions halve the spatial size twice, and the linear
    heads expect a 7 x 7 feature map, i.e. the network is sized for 28 x 28
    single-channel inputs such as MNIST digits.

    Args:
        base_channels: number of feature maps of the first convolution (the
            second one uses twice as many). When omitted, the module-level
            ``capacity`` is used.
        latent_size: dimensionality of the latent space. When omitted, the
            module-level ``latent_dims`` is used.
    """

    def __init__(self, base_channels=None, latent_size=None):
        super().__init__()
        # Fall back to the notebook-wide hyperparameters only when the caller
        # did not pick explicit sizes.
        c = capacity if base_channels is None else base_channels
        z = latent_dims if latent_size is None else latent_size
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1)  # out: c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)  # out: 2c x 7 x 7
        self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=z)  # mu: vector of latent means
        self.fc_logvar = nn.Linear(in_features=c*2*7*7, out_features=z)  # logvar: vector of latent log-variances

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Tuple ``(x_mu, x_logvar)``, each of shape (batch, latent size).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    Mirrors the encoder: a linear layer expands the latent vector to a
    2c x 7 x 7 feature map, which two stride-2 transposed convolutions
    upsample to a single-channel 28 x 28 image.

    Args:
        base_channels: number of feature maps before the last transposed
            convolution (the first one receives twice as many). When omitted,
            the module-level ``capacity`` is used.
        latent_size: dimensionality of the latent space. When omitted, the
            module-level ``latent_dims`` is used.
    """

    def __init__(self, base_channels=None, latent_size=None):
        super().__init__()
        # Fall back to the notebook-wide hyperparameters only when the caller
        # did not pick explicit sizes.
        c = capacity if base_channels is None else base_channels
        z = latent_dims if latent_size is None else latent_size
        self.fc = nn.Linear(in_features=z, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)  # out: c x 14 x 14
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)  # out: 1 x 28 x 28

    def forward(self, x):
        """Decode a batch of latent vectors.

        Args:
            x: tensor of shape (batch, latent size).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [0, 1].
        """
        x = self.fc(x)
        # Unflatten using the channel count this instance was built with, not
        # the global hyperparameter, which may have been changed since.
        x = x.view(x.size(0), self.conv2.in_channels, 7, 7)
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv1(x))  # sigmoid keeps outputs in [0, 1], as required by the BCE reconstruction loss
        return x
    
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, draw a latent code and decode it.

        Returns:
            Tuple ``(reconstruction, latent_mu, latent_logvar)``.
        """
        mu, logvar = self.encoder(x)
        code = self.latent_sample(mu, logvar)
        return self.decoder(code), mu, logvar

    def latent_sample(self, mu, logvar):
        """Return a latent code for the given mean and log-variance.

        In evaluation mode the mean itself is returned; in training mode a
        sample is drawn with the reparameterization trick.
        """
        if not self.training:
            return mu
        # Reparameterization: z = mu + eps * sigma, with eps ~ N(0, I) and
        # sigma = exp(logvar / 2).
        sigma = torch.exp(0.5 * logvar)
        noise = torch.empty_like(sigma).normal_()
        return mu + noise * sigma
    
def vae_loss(x_reconstruit, x, mu, logvar, beta=None):
    """Compute the VAE loss: summed reconstruction BCE plus weighted KL term.

    Args:
        x_reconstruit: decoder output, same number of elements per sample as
            ``x``, with values in [0, 1].
        x: target images with values in [0, 1].
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.
        beta: weight of the KL term. When omitted, the module-level
            ``variational_beta`` is used.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over the batch.
    """
    if beta is None:
        beta = variational_beta
    # Flatten per sample instead of assuming 784 pixels, so any image size works.
    flat_reconstruction = x_reconstruit.reshape(x_reconstruit.size(0), -1)
    flat_target = x.reshape(x.size(0), -1)
    loss_reconstructie = F.binary_cross_entropy(flat_reconstruction, flat_target, reduction='sum')
    # Kullback-Leibler divergence between N(mu, exp(logvar)) and N(0, I).
    loss_similaritate = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta balances the contribution of the two terms to the final loss.
    return loss_reconstructie + beta * loss_similaritate

# Antrenarea VAE

In [None]:
# Build the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vae = VariationalAutoencoder().to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate, weight_decay=1e-5)

# Training mode: latent codes are sampled rather than set to the mean.
vae.train()

# One entry per epoch: the loss averaged over the batches of that epoch.
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for num_batches, (image_batch, _) in enumerate(train_dataloader, start=1):
        image_batch = image_batch.to(device)

        # Forward pass: x -> encoder -> latent code -> decoder -> reconstruction.
        image_batch_reconstruit, latent_mu, latent_logvar = vae(image_batch)

        # Loss returned by vae_loss (reconstruction term plus weighted KL term).
        loss = vae_loss(image_batch_reconstruit, image_batch, latent_mu, latent_logvar)

        # Backpropagate and update the network weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average the accumulated loss over the batches of this epoch.
    train_loss_avg.append(running_loss / num_batches)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

# Interpolare în spațiul latent

In [None]:
# Switch the trained model to evaluation mode; in this mode
# VariationalAutoencoder.latent_sample returns the mean instead of sampling.
vae.eval()

def to_img(x):
    """Return ``x`` with every value limited to the displayable range [0, 1]."""
    return torch.clamp(x, min=0, max=1)

def interpolation(lambda1, model, img1, img2):
    """Decode a linear blend of the latent codes of two images.

    The blend is ``lambda1 * z1 + (1 - lambda1) * z2``, where ``z1`` and
    ``z2`` are the latent means of ``img1`` and ``img2``. Hence
    ``lambda1 = 1`` reproduces ``img1`` and ``lambda1 = 0`` reproduces
    ``img2``.

    Args:
        lambda1: blending weight given to ``img1``.
        model: object exposing ``encoder`` (returning a ``(mu, logvar)``
            pair) and ``decoder``.
        img1: image batch encoded with weight ``lambda1``.
        img2: image batch encoded with weight ``1 - lambda1``.

    Returns:
        The decoded image batch, moved to the CPU.
    """
    # Run on whatever device the model's own weights live on instead of
    # relying on a notebook-level global; a model without parameters leaves
    # the inputs where they are.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    with torch.no_grad():
        if target_device is not None:
            img1 = img1.to(target_device)
            img2 = img2.to(target_device)

        # Only the means are used; the log-variances are discarded.
        latent_1, _ = model.encoder(img1)
        latent_2, _ = model.encoder(img2)

        inter_latent = lambda1 * latent_1 + (1 - lambda1) * latent_2

        inter_image = model.decoder(inter_latent)
        return inter_image.cpu()
    
# Bucket test images by label: digits[k] holds single-image batches showing digit k.
# Stop after the first batch that brings the total to at least 1000 images.
digits = [[] for _ in range(10)]
for img_batch, label_batch in test_dataloader:
    for idx, label in enumerate(label_batch):
        # Slice (rather than index) to keep the leading batch dimension.
        digits[label].append(img_batch[idx:idx + 1])
    if sum(map(len, digits)) >= 1000:
        break

# Show the latent-space transition from a "2" to a "4" in 20 steps,
# laid out on 2 rows of 10 columns.
start_digit, end_digit = 2, 4
n_rows, n_cols = 2, 10

lambda_range = np.linspace(0, 1, n_rows * n_cols)

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 4))
fig.subplots_adjust(hspace=.5, wspace=.001)
axs = axs.ravel()

img_start = digits[start_digit][0]
img_end = digits[end_digit][0]

for ax, lam in zip(axs, lambda_range):
    # interpolation() weights its first image by lambda and its second by
    # (1 - lambda), so the end image goes first: lambda=0 shows the start
    # digit and lambda=1 shows the end digit.
    inter_image = interpolation(float(lam), vae, img_end, img_start)
    image = to_img(inter_image).numpy()

    ax.imshow(image[0, 0, :, :], cmap='gray')
    ax.axis('off')  # tick labels only clutter a dense image grid
plt.show()