<a href="https://colab.research.google.com/github/avrymi-asraf/AML/blob/main/Ex1/AML_ex1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils import data

import torchvision
import torchvision.transforms as transforms

import pandas as pd
import time
from tqdm import tqdm
import random

import plotly.express as px
import plotly.graph_objects as go

from IPython.display import clear_output


from typing import Tuple
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tools and Classes

In [None]:
# @title Tools code

class DataSetWithIndices(data.Dataset):
    """Wrap a dataset so that every item also carries its own position.

    The wrapped dataset must yield ``(data, target)`` pairs; this wrapper
    yields ``(data, target, index)`` triples instead. That lets a model look
    up per-example parameters, such as the latent tables of ``ConvVAElo``.
    """

    def __init__(self, dataset):
        # Underlying dataset; only indexing and len() are used.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, label, index


def import_MNIST_dataset(with_index=False, batch_size=64, test=True, amount=None):
    """
    Download MNIST and wrap it in DataLoader objects.

    Pixel values are mapped to [-1, 1] by ``Normalize((0.5,), (0.5,))`` after
    ``ToTensor``. The full dataset has 60,000 training and 10,000 test images.

    :param with_index: if True, training items are ``(image, label, index)``
        triples (via ``DataSetWithIndices``) instead of ``(image, label)`` pairs.
    :param batch_size: batch size used for both loaders.
    :param test: if False, the test split is not loaded and only the training
        loader is returned.
    :param amount: if given, keep only the first ``amount`` training images
        (clamped to the size of the training split). ``None`` keeps all of them.
    :return: ``(train_loader, test_loader)`` when ``test`` is True, otherwise
        just ``train_loader``. The training loader shuffles; the test loader does not.
    :raises ValueError: if ``amount`` is given but not positive.
    """
    # Define a transform to normalize the data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    # Download and load the training dataset
    trainset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    # `is not None` rather than truthiness, so amount=0 is not silently read as "everything".
    if amount is not None:
        if amount <= 0:
            raise ValueError(f"amount must be a positive integer or None, got {amount!r}")
        # Clamp so an oversized request cannot produce out-of-range Subset indices.
        subset_indices = list(range(min(amount, len(trainset))))
        trainset = data.Subset(trainset, subset_indices)
    if with_index:
        trainset = DataSetWithIndices(trainset)

    train_loader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    if not test:
        return train_loader

    # Download and load the testing dataset
    testset = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )
    test_loader = data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader



def import_MNIST_examples(mnist: data.DataLoader, with_index=False, num_classes=10):
    """
    Pick the first example of every class from a loader's underlying dataset.

    Args:
        mnist: a DataLoader whose ``dataset[i]`` is a tuple with the image at
            position 0 and the integer label at position 1. This holds for plain
            MNIST items as well as ``DataSetWithIndices`` triples.
        with_index: if True, also return the dataset positions of the chosen images.
        num_classes: labels ``0 .. num_classes - 1`` are collected (default 10).

    Returns:
        A tensor of shape ``(num_classes, 1, H, W)`` whose entry ``i`` is the
        first image labelled ``i``. If ``with_index`` is True, a second tensor
        (dtype ``torch.long``) holding those images' positions in the dataset
        is also returned.

    Raises:
        ValueError: if some label in ``0 .. num_classes - 1`` never occurs.
    """
    dataset = mnist.dataset
    found = {}  # label -> (position in dataset, image)
    # A single pass over the dataset (instead of one scan per class), stopping
    # as soon as every class has been seen. Each item is loaded only once.
    for pos in range(len(dataset)):
        sample = dataset[pos]
        label = int(sample[1])
        if 0 <= label < num_classes and label not in found:
            found[label] = (pos, sample[0])
            if len(found) == num_classes:
                break

    missing = [c for c in range(num_classes) if c not in found]
    if missing:
        raise ValueError(f"import_MNIST_examples: no example found for label(s) {missing}")

    indices = torch.tensor([found[c][0] for c in range(num_classes)], dtype=torch.long)
    images = []
    for c in range(num_classes):
        img = torch.as_tensor(found[c][1], dtype=torch.get_default_dtype())
        # Accept (H, W) or (1, H, W) images; the output always gets one channel axis.
        images.append(img.reshape(img.shape[-2], img.shape[-1]))
    examples = torch.stack(images).unsqueeze(1)

    if not with_index:
        return examples
    return examples, indices



def import_set_examples(n_per_digit=5, num_classes=10):
    """
    Randomly sample ``n_per_digit`` MNIST images of every digit from both splits.

    Images get the same [-1, 1] normalization as in ``import_MNIST_dataset``.
    Within one split and digit the sampled images are distinct. Sampling uses
    Python's ``random`` module, so call ``random.seed`` first for reproducibility.

    :param n_per_digit: number of images drawn per digit and split (default 5).
    :param num_classes: digits ``0 .. num_classes - 1`` are sampled (default 10).
    :return: ``(test_examples, train_examples)``, each a tensor of shape
        ``(num_classes, n_per_digit, 1, 28, 28)`` indexed as ``[digit, example]``.
        The test split comes FIRST (unchanged return order).
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    # Download and load both splits
    train_set = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    test_set = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )

    def _sample_split(dataset):
        # Choose positions from the label tensor so that only the chosen images
        # are loaded/transformed (no rejection sampling over random images).
        out = torch.empty(num_classes, n_per_digit, 1, 28, 28)  # digit, example, C, H, W
        targets = torch.as_tensor(dataset.targets)
        for digit in range(num_classes):
            candidates = (targets == digit).nonzero(as_tuple=True)[0].tolist()
            for slot, pos in enumerate(random.sample(candidates, n_per_digit)):
                out[digit, slot] = dataset[pos][0]
        return out

    # Each tensor is filled from its own split. The previous version filled the
    # "test" tensor from train_set and vice versa.
    return _sample_split(test_set), _sample_split(train_set)


def train_model(model: nn.Module, data_loader, epochs=30, lr=1e-3, device='cpu', loss_func=None):
    """
    Train a VAE-style model with Adam.

    ``model(x)`` must return ``(recon_x, mu, logvar)``.

    Args:
        model: the model to train. It is moved to ``device`` and left in train mode.
        data_loader: iterable of batches whose first element is the image tensor;
            ``(x, y)`` and ``(x, y, index)`` batches both work.
        epochs: number of passes over ``data_loader``.
        lr: Adam learning rate.
        device: device to train on.
        loss_func: callable ``(recon_x, x, mu, logvar) -> scalar tensor``.
            Defaults to ``vae_loss``, looked up at call time so the loss cell
            may be defined after this one.

    Returns:
        list[float]: mean per-batch loss for each epoch.
    """
    if loss_func is None:
        loss_func = vae_loss
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    epoch_losses = []
    for _ in tqdm(range(epochs)):
        total, n_batches = 0.0, 0
        for batch in data_loader:
            x = batch[0].to(device)  # only the images are needed
            optimizer.zero_grad()
            recon_x, mu, logvar = model(x)
            loss = loss_func(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()
            total += loss.item()
            n_batches += 1
        epoch_losses.append(total / max(n_batches, 1))
    return epoch_losses



def min_max_normailze(x):
    """
    Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Works for objects supporting ``min()``, ``max()`` and element-wise
    arithmetic (torch tensors, numpy arrays, pandas Series/DataFrames).
    When ``x`` has a single scalar range of zero (all values equal), the result
    is all zeros instead of the NaNs a 0/0 division would give.
    """
    lo = x.min()
    span = x.max() - lo
    # Only guard scalar ranges; e.g. a DataFrame yields a per-column Series here.
    if getattr(span, "ndim", 0) == 0 and span == 0:
        return x - lo
    return (x - lo) / span

In [None]:
# @title ConvVAEamortized model


class ConvVAEamortized(nn.Module):
    """Convolutional VAE with an amortized (shared) encoder for 1x28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian posterior; the decoder maps a latent vector back to a 1x28x28
    image (the last layer is a transposed conv with no output activation).
    """

    def __init__(self, latent_dim: int = 200):
        """
        Initialize the ConvVAEamortized model.

        Args:
            latent_dim (int): Dimension of the latent space. Default is 200.
        """
        super(ConvVAEamortized, self).__init__()

        self.latent_dim = latent_dim

        # Encoder. Submodule order must stay fixed so saved state_dicts
        # ("encoder.N.*" / "decoder.N.*") keep loading.
        self.encoder = nn.Sequential(
            nn.Conv2d(
                1, 32, kernel_size=3, stride=2, padding=1
            ),  # (batch_size, 32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(
                32, 64, kernel_size=3, stride=2, padding=1
            ),  # (batch_size, 64, 7, 7)
            nn.ReLU(),
            nn.Conv2d(
                64, 128, kernel_size=3, stride=2, padding=1
            ),  # (batch_size, 128, 4, 4)
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=2),  # (batch_size, 128, 3, 3); pooled to 1x1 in encode()
        )

        # Latent space
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 128)

        # Decoder (input is the fc_decode output viewed as (batch_size, 128, 1, 1))
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2),  # (batch_size, 128, 2, 2)
            nn.ReLU(),
            nn.ConvTranspose2d(
                128, 128, kernel_size=3, stride=2, padding=1, output_padding=1
            ),  # (batch_size, 128, 4, 4)
            nn.ReLU(),
            nn.ConvTranspose2d(
                128, 64, kernel_size=3, stride=2, padding=1
            ),  # (batch_size, 64, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),  # (batch_size, 32, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),  # (batch_size, 1, 28, 28)
        )

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Perform the reparameterization trick: z = mu + std * eps, eps ~ N(0, I).

        Args:
            mu (torch.Tensor): Mean of the latent Gaussian. Shape: (batch_size, latent_dim)
            logvar (torch.Tensor): Log variance of the latent Gaussian. Shape: (batch_size, latent_dim)

        Returns:
            torch.Tensor: Sampled latent vector. Shape: (batch_size, latent_dim)
        """
        std = torch.exp(0.5 * logvar)  # sqrt(exp(logvar))
        eps = torch.randn_like(mu)  # already on mu's device/dtype
        return mu + eps * std

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the input image into the latent space.

        Args:
            x (torch.Tensor): Input image. Shape: (batch_size, 1, 28, 28)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mean and log variance of the latent Gaussian.
                                               Each shape: (batch_size, latent_dim)
        """
        x = self.encoder(x)  # (batch_size, 128, 3, 3)
        # Average-pool the remaining spatial map down to a single 128-d vector.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)  # Flatten -> (batch_size, 128)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode the latent vector into an image.

        Args:
            z (torch.Tensor): Latent vector. Shape: (batch_size, latent_dim)

        Returns:
            torch.Tensor: Reconstructed image. Shape: (batch_size, 1, 28, 28)
        """
        z = self.fc_decode(z)
        z = z.view(z.size(0), 128, 1, 1)
        z = self.decoder(z)
        return z

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the VAE.

        Args:
            x (torch.Tensor): Input image. Shape: (batch_size, 1, 28, 28)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Reconstructed image, mean, and log variance.
                                                             Shapes: (batch_size, 1, 28, 28), (batch_size, latent_dim), (batch_size, latent_dim)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

In [None]:
# @title ConvVAElo model


class ConvVAElo(nn.Module):
    """VAE decoder trained by latent optimization (no encoder).

    Instead of an encoder, every training example owns a free ``(mu, logvar)``
    pair stored in the parameter tables ``mus`` / ``logvars``; the rows used in
    a batch are selected by the examples' dataset indices.
    """

    def __init__(self, num_train_examples: int, latent_dim: int = 200):
        """
        Build the per-example latent tables and the decoder.

        Args:
            num_train_examples (int): Number of rows in the latent tables
                (one per training example).
            latent_dim (int): Dimension of the latent space. Default is 200.
        """
        super().__init__()

        self.latent_dim = latent_dim
        # Both tables start from standard-normal draws and are optimized directly.
        table_shape = (num_train_examples, latent_dim)
        self.mus = nn.Parameter(torch.randn(*table_shape))
        self.logvars = nn.Parameter(torch.randn(*table_shape))

        # Latent vector -> 128 features, viewed as a 128x1x1 map before the decoder.
        self.fc_decode = nn.Linear(latent_dim, 128)

        def _upsample(c_in, c_out, extra_pad):
            # Stride-2 transposed conv; `extra_pad` (output_padding) picks even vs. odd output size.
            return nn.ConvTranspose2d(
                c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=extra_pad
            )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2),  # 1x1   -> 2x2
            nn.ReLU(),
            _upsample(128, 128, 1),                       # 2x2   -> 4x4
            nn.ReLU(),
            _upsample(128, 64, 0),                        # 4x4   -> 7x7
            nn.ReLU(),
            _upsample(64, 32, 1),                         # 7x7   -> 14x14
            nn.ReLU(),
            _upsample(32, 1, 1),                          # 14x14 -> 28x28
        )

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Draw z = mu + std * noise with noise ~ N(0, I) and std = exp(logvar / 2).

        Args:
            mu (torch.Tensor): Latent means, shape (batch_size, latent_dim).
            logvar (torch.Tensor): Latent log-variances, shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: One latent sample per row, shape (batch_size, latent_dim).
        """
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(mu).to(next(self.parameters()).device)
        return noise * std + mu

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Map latent vectors to images.

        Args:
            z (torch.Tensor): Latent vectors, shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: Images of shape (batch_size, 1, 28, 28).
        """
        hidden = self.fc_decode(z).view(z.size(0), 128, 1, 1)
        return self.decoder(hidden)

    def forward(self, x: torch.Tensor, indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reconstruct the examples at ``indices`` from their own latent parameters.

        Args:
            x (torch.Tensor): The batch images. Not used by the computation; kept
                so the call signature mirrors the amortized model.
            indices (torch.Tensor): Dataset positions of the batch, shape (batch_size,).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (reconstruction, mu, logvar)
            with shapes (batch_size, 1, 28, 28), (batch_size, latent_dim), (batch_size, latent_dim).
        """
        mu, logvar = self.mus[indices], self.logvars[indices]
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

In [None]:
# @title Loss functions
def vae_loss(input, target, mu, logvar):
    """
    VAE objective: reconstruction MSE plus KL divergence to the N(0, I) prior.

    Args:
        input: reconstructions, shape (batch, ...).
        target: original images, same number of elements per sample as ``input``.
        mu: posterior means, shape (batch, latent_dim).
        logvar: posterior log-variances, shape (batch, latent_dim).

    Returns:
        Scalar tensor ``mse + mean_batch(mean_latent(KL))``, where per latent dimension
        KL(N(mu, s^2) || N(0, 1)) = 0.5 * (mu^2 + s^2 - log s^2 - 1) >= 0.
        Both terms are averaged (not summed), which acts as a fixed weighting
        between reconstruction and KL.
    """
    input = input.reshape(input.size(0), -1)
    target = target.reshape(target.size(0), -1)
    # Work in log-variance directly: log s^2 == logvar and s^2 == exp(logvar).
    # The earlier form used -log(s) (not -log(s^2)) and dropped the 0.5. That
    # put its minimum at s^2 = 0.5 and let the "KL" become negative.
    kld = 0.5 * torch.mean(mu.pow(2) + logvar.exp() - logvar - 1, dim=1)
    return F.mse_loss(input, target) + kld.mean()


# Q1: Amortized VAE.
Train an amortized VAE on your MNIST subset, for 30 epochs. Plot the loss values after each epoch. Additionally, choose 10 random validation images (one from each class) and plot them and their reconstructions at epochs 1, 5, 10, 20 and 30. Do the same for 10 random images from the training set. Did the
Auto-Encoder overfit the training data? Explain.


In [None]:
# Training subset for Q1: the first 20,000 MNIST training images (no test loader).
data_loader = import_MNIST_dataset(amount=20000, test=False)
# One fixed image per digit, used to track reconstructions during training.
examples = import_MNIST_examples(data_loader)
examples = examples.to(device)

In [None]:
# Training hyper-parameters for Q1/Q2.
epochs, lr = 30, 1e-3

In [None]:
# Amortized VAE and its optimizer. NOTE: later cells use the name `optimazer` (sic).
model = ConvVAEamortized().to(device)
optimazer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# @title run Q1 & Q2

# Sum of batch losses per epoch (float column so it plots as numbers).
record_data_ma = pd.DataFrame({"epoch_loss": float("nan")}, index=range(epochs))

# Snapshot before training: reconstructions of the fixed examples, plus the sources.
model.eval()
with torch.no_grad():
    reconstruct_images_ma = {
        "pre_train": model(examples)[0].cpu(),
        "source": examples.detach().cpu(),
    }
prior_dist_examples_ma = {}
# Q2 asks to decode the SAME 10 prior samples at every checkpoint, so draw them
# once here instead of re-sampling at each checkpoint.
prior_latents_ma = torch.randn(10, model.latent_dim, device=device)

for epoch in range(epochs):
    model.train()
    loss_epoch = 0.0
    for x, _ in tqdm(data_loader):
        x = x.to(device)
        recon_x, mu, logvar = model(x)
        loss = vae_loss(recon_x, x, mu, logvar)

        optimazer.zero_grad()
        loss.backward()
        optimazer.step()

        loss_epoch += loss.item()
    # Record once per epoch rather than after every batch.
    record_data_ma.iloc[epoch] = [loss_epoch]

    clear_output(wait=True)
    px.line(record_data_ma).show()
    if (epoch + 1) % 5 == 0 or epoch == 0:
        model.eval()
        with torch.no_grad():  # evaluation only: no autograd graph needed
            prior_dist_examples_ma[epoch + 1] = model.decode(prior_latents_ma).cpu()
            reconstruct_images_ma[epoch + 1] = model(examples)[0].cpu()
        px.imshow(prior_dist_examples_ma[epoch + 1].squeeze(1), facet_col=0).show()
        px.imshow(reconstruct_images_ma[epoch + 1].squeeze(1), facet_col=0).show()
    print(f"epoch {epoch+1}, loss: {loss_epoch:.5f}")
torch.save(model.state_dict(), "ConvVAEamortized.pth")
clear_output(wait=True)

In [None]:
# @title Show examples { run: "auto", display-mode: "form"}
# The slider starts at 0; there is no epoch-0 snapshot, so alias it to epoch 1.
reconstruct_images_ma[0] = reconstruct_images_ma[1]  # only for indexes
epoch = 30 # @param {type:"slider", min:0, max:30, step:5}
# First figure: the fixed source images; second: their reconstructions at `epoch`.
px.imshow(reconstruct_images_ma["source"].squeeze(dim=1), facet_col=0).show()
px.imshow(reconstruct_images_ma[epoch].squeeze(dim=1), facet_col=0).show()

<div dir="rtl" lang="he" xml:lang="he">

## האם יש overfit באימון של VAE?
קשה לנסח
overfit
ברשת יוצרת.
המשמעות היחידה יכולה להיות, שהדוגמאות יהיו מתוך סט מצומצם של דוגמאות
(כאלה שהמודל רואה באימון)
ותהליך היצירה לא ישקף את התמונות האמיתיות.

</div>

#  Q2: Sampling from a VAE.
Sample 10 latent variables from your prior distribution, pass them in the generators
from epochs 1, 5, 10, 20 and 30. Plot the generations from each epoch, and observe how the generator changed
over-time (No explanation needed).

In [None]:
# @title Show examples { run: "auto", display-mode: "form"}
# The slider starts at 0; alias it to the epoch-1 samples (no epoch-0 snapshot).
prior_dist_examples_ma[0] = prior_dist_examples_ma[1]  # only for indexes
epoch = 30 # @param {type:"slider", min:0, max:30, step:5}
# Decoded prior samples at the selected epoch, one facet per sample.
px.imshow(prior_dist_examples_ma[epoch].squeeze(dim=1), facet_col=0).show()


# Q3: Latent Optimization.
Train a generator by variational inference, using latent optimization of the $q$ vectors instead of a shared encoder. Initialize the $q$ vectors by sampling from the Gaussian distribution $q \sim \mathcal{N}(0, I)$. This will be our prior distribution for this experiment. Use the same dimensions for $q$ as in Q1 and Q2.

In [None]:
# Latent optimization needs one latent row per training image, so the loader
# also yields each image's dataset index.
data_set_size = 20000
data_loader_with_indices = import_MNIST_dataset(
    test=False, amount=data_set_size, with_index=True
)
# (images, dataset positions) of the first example of each digit.
examples = import_MNIST_examples(data_loader_with_indices, with_index=True)

In [None]:
# Training hyper-parameters for Q3 (same as Q1).
epochs, lr = 30, 1e-3

In [None]:
model = ConvVAElo(data_set_size).to(device)
# Parameter-group setup suggested by ChatGPT: the per-example latent tables use
# lr 0.01, while fc_decode and the decoder use the default lr of 1e-4.
latent_lr = 0.01
optimazer = optim.Adam(
    [
        {"params": model.mus, "lr": latent_lr},
        {"params": model.logvars, "lr": latent_lr},
        {"params": model.fc_decode.parameters()},
        {"params": model.decoder.parameters()},
    ],
    lr=0.0001,
)

In [None]:
# @title run Q3

# Sum of batch losses per epoch for the latent-optimization model.
record_data_lo = pd.DataFrame({"epoch_loss": float("nan")}, index=range(epochs))

# Fixed examples: (images, dataset indices). ConvVAElo only reads the indices.
lo_images, lo_indices = (t.to(device) for t in examples)
model.eval()
with torch.no_grad():
    reconstruct_images_lo = {
        "pre_train": model(lo_images, lo_indices)[0].cpu(),
        "source": lo_images.cpu(),
    }
prior_dist_examples_lo = {}
# Decode the SAME 10 prior samples at every checkpoint (drawn once).
prior_latents_lo = torch.randn(10, model.latent_dim, device=device)

for epoch in range(epochs):
    model.train()
    loss_epoch = 0.0
    for x, _, indices in tqdm(data_loader_with_indices):
        x, indices = x.to(device), indices.to(device)
        recon_x, mu, logvar = model(x, indices)
        loss = vae_loss(recon_x, x, mu, logvar)

        optimazer.zero_grad()
        loss.backward()
        optimazer.step()

        loss_epoch += loss.item()
    # Record into the Q3 frame. The earlier version wrote to (and plotted)
    # record_data_ma, overwriting Q1's loss history.
    record_data_lo.iloc[epoch] = [loss_epoch]

    clear_output(wait=True)
    px.line(record_data_lo).show()
    if (epoch + 1) % 5 == 0 or epoch == 0:
        model.eval()
        with torch.no_grad():  # evaluation only: no autograd graph needed
            prior_dist_examples_lo[epoch + 1] = model.decode(prior_latents_lo).cpu()
            reconstruct_images_lo[epoch + 1] = model(lo_images, lo_indices)[0].cpu()
        px.imshow(prior_dist_examples_lo[epoch + 1].squeeze(1), facet_col=0).show()
        px.imshow(reconstruct_images_lo[epoch + 1].squeeze(1), facet_col=0).show()
    print(f"epoch {epoch+1}, loss: {loss_epoch:.5f}")


## (a)
Plot the reconstructions of 10 images (one from each class) from the training set, at epochs 1, 5, 10, 20 and 30.
Compare these reconstructions to the ones from Q1. Which method proposed better q vectors? Explain.

In [None]:
# @title Show examples { run: "auto", display-mode: "form"}
epoch = 30 # @param {type:"slider", min:0, max:30, step:5}
# Slider value 0 reuses the epoch-1 reconstructions (there is no epoch-0 snapshot).
reconstruct_images_lo[0] = reconstruct_images_lo[1]  # only for indexes
# First figure: the fixed source images; second: their reconstructions at `epoch`.
px.imshow(reconstruct_images_lo["source"].squeeze(dim=1), facet_col=0).show()
px.imshow(reconstruct_images_lo[epoch].squeeze(dim=1), facet_col=0).show()

## (b)
Sample from your new model, by inputting it 10 latent vectors sampled from the prior distribution. Compare
these to the samples from Q2. Was our initialization sufficient to establish a good prior distribution for this
problem? Explain.

In [None]:
# @title Show examples { run: "auto", display-mode: "form"}
prior_dist_examples_lo[0] = prior_dist_examples_lo[1] # only for indexes
epoch = 0 # @param {type:"slider", min:0, max:30, step:5}
# Q3(b): show prior samples decoded by the latent-optimization model. The
# earlier version displayed prior_dist_examples_ma (the Q2 amortized model).
px.imshow(prior_dist_examples_lo[epoch].squeeze(1),facet_col=0).show()

# Q4: Computing the log-probability of an image.

For each digit (0 − 9) sample 10 images: 5 images from the
training set and 5 from the test set. Compute the log-probability of each image as described in Eq. 9.


In [None]:
# Written with help from ChatGPT
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

def estimate_log_probability(x, model, M=1000, sigma_p=0.4):
    """
    Importance-sampling estimate of log p(x) under a trained VAE.

        log p(x) ~= logsumexp_m [log p(z_m) + log p(x | z_m) - log q(z_m | x)] - log M,
        z_m ~ q(z | x) = N(mu(x), diag(std(x)^2)),

    with prior p(z) = N(0, I) and a Gaussian decoder likelihood
    p(x | z) = N(decode(z), sigma_p^2 I).

    Args:
        x (torch.Tensor): one image. It is expected to be (1, 28, 28), so that
            ``x.unsqueeze(0)`` is a batch of one for ``model.encode``; it is
            flattened for the likelihood.
        model (nn.Module): object with ``encode(batch) -> (mu, logvar)`` and
            ``decode(z)``, e.g. ``ConvVAEamortized``.
        M (int): number of Monte-Carlo samples.
        sigma_p (float): standard deviation of the decoder likelihood.

    Returns:
        float: the estimated log-probability.
    """
    import math

    device = next(model.parameters()).device
    x = x.to(device)
    log_2pi = math.log(2 * math.pi)

    # Pure evaluation: do not build an autograd graph through M decoder passes.
    with torch.no_grad():
        mu, logvar = model.encode(x.unsqueeze(0))  # each (1, latent_dim): parameters of q(z|x)
        std = torch.exp(0.5 * logvar)

        # z_m = mu + std * eps_m, shape (M, latent_dim)
        eps = torch.randn(M, mu.shape[1], device=mu.device, dtype=mu.dtype)
        z = mu + std * eps

        # Diagonal Gaussians: sum independent per-dimension log densities
        # (avoids a dense covariance and works for latent_dim == 1).
        log_p_z = (-0.5 * z.pow(2) - 0.5 * log_2pi).sum(dim=1)
        # (z - mu) / std == eps, and log(std) == 0.5 * logvar
        log_q_z = (-0.5 * eps.pow(2) - 0.5 * logvar - 0.5 * log_2pi).sum(dim=1)

        # log p(x|z): Gaussian centred on the decoder output with variance sigma_p^2
        x_flat = x.flatten()
        d = x_flat.numel()
        x_hat = model.decode(z).reshape(M, -1)
        log_p_x_given_z = (
            -0.5 * ((x_flat - x_hat) ** 2).sum(dim=1) / sigma_p ** 2
            - 0.5 * d * math.log(2 * math.pi * sigma_p ** 2)
        )

        # Importance weights, combined with logsumexp for numerical stability
        log_w = log_p_z + log_p_x_given_z - log_q_z
        log_p_x = torch.logsumexp(log_w, dim=0) - math.log(M)

    return log_p_x.item()

In [None]:
test_examples, train_examples = import_set_examples()
model = ConvVAEamortized().to(device)
# Load from the same relative path the Q1 cell saved to. map_location lets a
# GPU-trained checkpoint load on a CPU-only runtime (and vice versa).
model.load_state_dict(torch.load("ConvVAEamortized.pth", map_location=device))
model.eval()

In [None]:
# Log-probability table (written with help from ChatGPT).
# Rows: (digit, split); columns: example 0..4.
index = pd.MultiIndex.from_product([list(range(10)), ["train", "test"]], names=["digit", "data"])
# dtype=float so later groupby/mean/plot calls see numeric (not object) columns.
probabilitis = pd.DataFrame(columns=list(range(5)), index=index, dtype=float)
for num in range(10):
    for i in range(5):
        probabilitis.loc[(num, "test"), i] = estimate_log_probability(test_examples[num][i], model)
        probabilitis.loc[(num, "train"), i] = estimate_log_probability(train_examples[num][i], model)

## (a) Plot a single image from each digit, with its log-probability.

In [None]:
# First training example of each digit, titled with its estimated log-probability.
image_list = torch.stack([train_examples[i][0] for i in range(10)])
fig = px.imshow(image_list.squeeze(1), facet_col=0, color_continuous_scale='gray')
fig.update_layout(coloraxis_showscale=False)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
for i in range(10):
    # Use the "train" row: the images shown are train_examples[i][0], whose
    # log-probability is stored at (i, "train"), column 0.
    fig.layout.annotations[i]['text'] = f'{probabilitis.loc[(i, "train"), 0]:.3f}'
fig.show()

## (b) Present the average log-probability per digit.
Which digit is the most likely? Why do you think that is the case?

In [None]:
# Mean log-probability per digit (over both splits and all examples). Cast to
# float first: a table built from an empty DataFrame has object-dtype columns,
# which pandas cannot plot as numbers.
mean = probabilitis.astype(float).groupby("digit").mean().mean(axis=1)
# Bars are min-max scaled: 1 = most likely digit, 0 = least likely.
min_max_normailze(mean).plot(kind="bar")
mean


## (c) Present the average log-probability of the images from the (i) training set (ii) test set.
Are images from the
training set more or less likely? Explain your answer

In [None]:
# Mean log-probability of the training vs. test images. Cast to float first so
# object-dtype columns do not break the numeric mean/plot.
mean = probabilitis.astype(float).groupby("data").mean().mean(axis=1)
mean.plot(kind="bar")
mean
