<center><img src='https://drive.google.com/uc?id=1_utx_ZGclmCwNttSe40kYA6VHzNocdET' height="60"></center>

AI TECH - Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (Academy of Innovative Applications of Digital Technologies). Operational Programme Digital Poland 2014-2020
<hr>

<center><img src='https://drive.google.com/uc?id=1BXZ0u3562N_MqCLcekI-Ens77Kk4LpPm'></center>

<center>
Project co-financed by the European Union under the European Regional Development Fund
Operational Programme Digital Poland 2014-2020,
Priority Axis 3 "Digital competences of society", Measure 3.2 "Innovative solutions for digital activation"
Project title: „Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)”
    </center>

# Latent Space Classifier


In this task, you will:
* train a Variational AutoEncoder on MNIST (the code is already provided)
* train a digit classifier on the latent space of the Variational AutoEncoder

## Variational AutoEncoder - review
Below is a quick reminder of how the Variational AutoEncoder works:

* Let $P^*$ be the true data distribution. We only have access to samples from it.
* Let $p(z)$ be a *prior* distribution over the latent space. In our model, it is the standard multivariate Gaussian distribution $\mathcal{N}(0,\mathbb{I})$.
* Let $E(x)$ be the encoder, which accepts data points as input and outputs distributions over the latent space $Z$. The produced distribution is denoted $q_\phi(z|x)$ and is the (approximate) *posterior* distribution. In our model, it is a multivariate Gaussian distribution $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$, where:
    1. $\phi$ are the weights of the encoder network.
    2. The encoder network accepts data points as input and outputs $\mu$ and $\sigma$, vectors whose length equals the dimension of the latent space. They are used to construct the approximate posterior distribution $q_\phi(z|x)$.
* Let $D(z)$ be the decoder, which accepts samples from the latent distribution and outputs the parameters of the likelihood distribution $p_\theta(x|z)$. In our model, this is an independent Bernoulli trial per pixel, $p_\theta(x|z) = \mathrm{Bern}(p)$, where:
    1. $\theta$ are the weights of the decoder network.
    2. The decoder network accepts a sample from the posterior distribution $q_\phi(z|x)$ and outputs $p$, a matrix with the shape of the input image. Each entry of $p$ is the parameter of the Bernoulli trial for the corresponding pixel.
    3. Data points are binarized so that every pixel is 0 or 1, as the Bernoulli likelihood requires: each pixel is sampled from a Bernoulli distribution whose parameter is its grayscale intensity in $[0, 1]$ (see `Binarize` in the code below).
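The decoder needs a sample $z \sim q_\phi(z|x)$. The VAE code below draws it with the reparameterization trick, $z = \mu + \sigma \odot \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, \mathbb{I})$, so that $z$ stays differentiable with respect to $\mu$ and $\sigma$. The following is a minimal pure-Python sketch (it runs without PyTorch; the helper `reparameterize` is a hypothetical name, not part of the notebook) that checks the samples have the intended per-dimension mean and standard deviation:

```python
import random
import statistics


def reparameterize(mu, sigma, rng):
    """Hypothetical helper: z = mu + sigma * eps with eps ~ N(0, 1), element-wise."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)
mu = [1.5, -2.0]
sigma = [0.5, 2.0]

# Draw many samples and inspect each latent dimension separately.
samples = [reparameterize(mu, sigma, rng) for _ in range(50_000)]
for j in range(len(mu)):
    column = [z[j] for z in samples]
    print(j, round(statistics.fmean(column), 2), round(statistics.stdev(column), 2))
```

The randomness lives only in $\varepsilon$, which does not depend on the encoder weights; this is what lets gradients of the loss reach $\mu$ and $\sigma$.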

The Variational AutoEncoder works by maximizing the Evidence Lower Bound (ELBO):

$$\mathrm{ELBO} = \mathbb{E}_{z \sim q_\phi(z|x)} \big[\log p_\theta(x|z)\big] - \mathbb{KL}\big(q_\phi(z | x) \,\|\, p(z)\big).$$

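The first term has no closed form: it is estimated by sampling $z$ from $q_\phi(z|x)$ and optimized with stochastic gradient methods (Adam in the code below). The second term can be computed analytically in our setup. For the diagonal Gaussian posterior and the standard Gaussian prior, it is a sum over the latent dimensions $j$:

$$ \mathbb{KL}\big( \mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, \mathbb{I}) \big) = \frac12 \sum_j \big(\sigma_j^2 - \log(\sigma_j^2) + \mu_j^2 - 1 \big).$$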
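To see that this closed form agrees with the definition $\mathbb{KL}(q \,\|\, p) = \mathbb{E}_{z \sim q}[\log q(z) - \log p(z)]$, the sketch below compares the per-dimension term with a Monte Carlo estimate for one latent dimension. It is pure Python, and the function names are illustrative assumptions, not part of the notebook:

```python
import math
import random


def kl_closed_form(mu, sigma):
    """KL(N(mu, sigma^2) || N(0, 1)) for a single latent dimension."""
    return 0.5 * (sigma**2 - math.log(sigma**2) + mu**2 - 1)


def log_normal_pdf(z, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at z."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - (z - mu) ** 2 / (2 * sigma**2)


def kl_monte_carlo(mu, sigma, n, rng):
    """Average of log q(z) - log p(z) over n samples z ~ q = N(mu, sigma^2)."""
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu, sigma)
        total += log_normal_pdf(z, mu, sigma) - log_normal_pdf(z, 0.0, 1.0)
    return total / n


rng = random.Random(0)
print(kl_closed_form(1.0, 0.5), kl_monte_carlo(1.0, 0.5, 100_000, rng))
```

For the full diagonal posterior, the per-dimension values are summed, which is what `KL_gaussian_loss` does (it also sums over the batch).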

You do not need to use the formulas above, as the Variational AutoEncoder is already implemented below.
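For reference, with the Bernoulli likelihood the reconstruction term is $\log p_\theta(x|z) = \sum_k \big[x_k \log p_k + (1 - x_k)\log(1 - p_k)\big]$, which is exactly the negative of the summed binary cross-entropy. This is why `ELBO()` in the code below returns `BCE + KL`: it is the negative ELBO, minimized as a loss. A small pure-Python check (the function names are illustrative, and PyTorch's internal clamping of the logarithm is ignored):

```python
import math


def bernoulli_log_likelihood(x, p):
    """log p(x) for independent Bernoulli pixels; assumes 0 < p_k < 1."""
    return sum(xk * math.log(pk) + (1 - xk) * math.log(1 - pk) for xk, pk in zip(x, p))


def binary_cross_entropy_sum(p, x):
    """Summed binary cross-entropy, the quantity F.binary_cross_entropy(p, x, reduction="sum") computes."""
    return -sum(xk * math.log(pk) + (1 - xk) * math.log(1 - pk) for xk, pk in zip(x, p))


x = [1, 0, 1, 1]           # binarized pixels
p = [0.9, 0.2, 0.6, 0.99]  # decoder outputs (Bernoulli parameters)
print(bernoulli_log_likelihood(x, p), binary_cross_entropy_sum(p, x))
```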

## Variational AutoEncoder - code
The VAE code is already complete and provided below. Run it to train the VAE.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms # type: ignore
from torch.utils.data import DataLoader

from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt

In [2]:
batch_size = 1024
test_batch_size = 1000
epochs = 5
lr = 5e-3
seed = 1
log_interval = 5
latent_size = 10

In [3]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

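# Shuffle the training set on every device, not only when CUDA is available
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size}
if use_cuda:
    # Loader options used only when a GPU is available
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)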

In [5]:
class Binarize:
    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        return torch.bernoulli(sample)


transform = transforms.Compose([transforms.ToTensor(), Binarize()])
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)
train_loader = DataLoader(dataset1, **train_kwargs)
test_loader = DataLoader(dataset2, **test_kwargs)

EncoderOutput = namedtuple("EncoderOutput", ["mu", "sigma"])


class Encoder(nn.Module):
    def __init__(self, linear_sizes: list[int], latent_size: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for in_layer_size, out_layer_size in zip(linear_sizes, linear_sizes[1:]):
            self.layers.append(nn.Linear(in_layer_size, out_layer_size))
            self.layers.append(nn.BatchNorm1d(out_layer_size))
            self.layers.append(nn.ReLU())

        self.last_layer_mu = nn.Linear(linear_sizes[-1], latent_size)
        self.last_layer_sigma = nn.Linear(linear_sizes[-1], latent_size)

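    def forward(self, x: torch.Tensor) -> EncoderOutput:
        x = torch.flatten(x, start_dim=1)
        for layer in self.layers:
            x = layer(x)

        mu = self.last_layer_mu(x)
        # Map the unconstrained output to a positive sigma with softplus, log(1 + e^a);
        # F.softplus computes it without overflowing for large inputs.
        sigma_raw = self.last_layer_sigma(x)
        return EncoderOutput(mu, F.softplus(sigma_raw))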


class Decoder(nn.Module):
    def __init__(self, linear_sizes: list[int], output_size: tuple[int]):
        super().__init__()
        self.layers = nn.ModuleList()
        for in_layer_size, out_layer_size in zip(linear_sizes, linear_sizes[1:]):
            self.layers.append(nn.Linear(in_layer_size, out_layer_size))
            self.layers.append(nn.BatchNorm1d(out_layer_size))
            self.layers.append(nn.ReLU())

        self.last_layer = nn.Sequential(
            nn.Linear(linear_sizes[-1], output_size[0] * output_size[1]), nn.Sigmoid()
        )
        self.output_size = output_size

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            z = layer(z)

        x = self.last_layer(z)

        x = x.view(-1, 1, *self.output_size)
        return x


VariationalAutoEncoderOutput = namedtuple(
    "VariationalAutoEncoderOutput", ["mu", "sigma", "p"]
)


class VariationalAutoEncoder(nn.Module):
    def __init__(
        self,
        encoder_linear_sizes: list[int],
        latent_size: int,
        decoder_linear_sizes: list[int],
        output_size: tuple[int],
    ):
        super().__init__()
        self.encoder = Encoder(encoder_linear_sizes, latent_size)
        self.decoder = Decoder(decoder_linear_sizes, output_size)
        self.latent_size = latent_size
        self.output_size = output_size

    def forward(self, x: torch.Tensor) -> VariationalAutoEncoderOutput:
        encoded = self.encoder(x)

        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # drawn on the same device as mu, so gradients flow through mu and sigma.
        eps = torch.randn_like(encoded.mu)
        z = encoded.mu + encoded.sigma * eps

        decoded = self.decoder(z)
        return VariationalAutoEncoderOutput(encoded.mu, encoded.sigma, decoded)

    def sample_latent(self, x: torch.Tensor) -> torch.Tensor:
        encoded = self.encoder(x)
        eps = torch.randn_like(encoded.mu)
        return encoded.mu + encoded.sigma * eps

    def sample(self, sample_size: int, samples=None) -> torch.Tensor:
        if samples is None:
            # Draw latent vectors from the prior N(0, I) on the model's device
            model_device = next(self.parameters()).device
            samples = torch.randn(sample_size, self.latent_size, device=model_device)

        decoded = self.decoder(samples)
        return decoded


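def KL_gaussian_loss(mu, sigma):
    # Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions and the batch
    return torch.sum(((sigma * sigma) - (2 * torch.log(sigma)) + (mu * mu) - 1) / 2)


def ELBO(x, p, mu, sigma):
    # Returns the *negative* ELBO (reconstruction BCE + KL), summed over the batch,
    # so that it can be minimized as a loss.
    BCE = F.binary_cross_entropy(p, x, reduction="sum")
    KL = KL_gaussian_loss(mu, sigma)
    return BCE + KL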


def train(
    model: nn.Module,
    device: torch.device,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    epoch: int,
    log_interval: int,
):
    model.train()
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = ELBO(data, output.p, output.mu, output.sigma)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model: nn.Module, device: torch.device, test_loader: DataLoader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            output = model(data)
            loss = ELBO(data, output.p, output.mu, output.sigma)
            # ELBO() already sums over the batch, so accumulate the batch sum directly
            test_loss += loss.item()

    # Average negative ELBO per test image
    test_loss /= len(test_loader.dataset)

    print("\nTest set: Average loss: {:.4f}\n".format(test_loss))


vae = VariationalAutoEncoder(
    [28 * 28, 500, 350], latent_size, [latent_size, 350, 500], (28, 28)
)
vae.to(device)
optimizer = optim.Adam(vae.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(vae, device, train_loader, optimizer, epoch, log_interval)
    test(vae, device, test_loader)

RuntimeError: DataLoader worker (pid(s) 1620) exited unexpectedly
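The `Encoder` turns the unconstrained output of `last_layer_sigma` into a positive standard deviation with the softplus function $\log(1 + e^{a})$. Evaluated literally, as in `torch.log(1 + torch.exp(a))`, it overflows to `inf` once $e^{a}$ exceeds the float32 range (around $a > 88$). A common stable rewrite is $\max(a, 0) + \log(1 + e^{-|a|})$, which is mathematically identical. The pure-Python sketch below (function names are illustrative) compares both forms; note that Python's `math.exp` raises `OverflowError` instead of returning `inf`:

```python
import math


def softplus_naive(a):
    """log(1 + e^a) written literally; math.exp(a) overflows for large a."""
    return math.log(1 + math.exp(a))


def softplus_stable(a):
    """Same function as max(a, 0) + log1p(e^{-|a|}), so exp never overflows."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


for a in [-5.0, 0.0, 3.0]:
    print(a, softplus_naive(a), softplus_stable(a))

print(softplus_stable(1000.0))  # → 1000.0; softplus_naive(1000.0) would raise OverflowError
```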

## Training the Latent Classifier - subtasks:
Below are all graded subtasks associated with this exam task:

1. Complete the implementation of `ClassificationHead`, which, given a latent vector generated by an `Encoder` for an image $i$ (just the $\mu$ part), predicts to which class (digit) the image $i$ belongs.
2. Complete the implementation of `Classifier`, which, given an input image, first encodes it with a frozen pre-trained `Encoder` and then passes it through the `ClassificationHead` to generate logits for classification.
3. Complete the implementation of `train_classifier`, which trains the `Classifier` module. More precisely, it trains only the `ClassificationHead` of the `Classifier`: for an image $i$, the `ClassificationHead` receives the output of the pre-trained `Encoder` on $i$ (just the $\mu$ part) and predicts which class (digit) the image belongs to.

Remarks:
* To earn all points, your model should achieve an accuracy greater than 90% (see the test at the end).
* Not all parameters should be trained; in particular, no gradients should be propagated through the `Encoder`.
* Use an appropriate loss function for training the `Classifier` and choose training parameters that bring the final accuracy above 90%.
* Do not change the code outside blocks delimited as follows:
```python3
#### TODO ####

##############
```
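A standard loss for multi-class classification from logits is cross-entropy. `nn.CrossEntropyLoss` takes unnormalized logits $z$ and a target class $y$ and computes $-z_y + \log\sum_k e^{z_k}$, averaged over the batch by default (`reduction="mean"`). A pure-Python sketch of that formula, using the log-sum-exp trick for numerical stability (the function names are illustrative, not PyTorch APIs):

```python
import math


def log_sum_exp(values):
    """Stable log(sum(exp(v))) computed by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cross_entropy(batch_logits, targets):
    """Mean over the batch of -log softmax(logits)[target]."""
    losses = [log_sum_exp(logits) - logits[y] for logits, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Two examples with three classes each (e.g. digits 0-2).
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
print(cross_entropy(logits, targets))
```

Because the loss applies log-softmax itself, the head should output raw logits without a final softmax layer.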


In [None]:
class ClassificationHead(nn.Module):
    def __init__(self, latent_size, num_classes):
        super().__init__()
        #### TODO ####
        # Small MLP: latent vector (mu) -> hidden layer -> class logits
        self.f1 = nn.Linear(latent_size, 64)
        self.f2 = nn.Linear(64, num_classes)
        ##############

    def forward(self, x):
        #### TODO ####
        x = self.f1(x)
        x = torch.relu(x)
        x = self.f2(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
        return x
        ##############


class Classifier(nn.Module):  # 1pt
    def __init__(self, vae, head):
        super().__init__()
        #### TODO ####
        self.vae = vae
        self.head = head
        ##############

    def forward(self, x):
        #### TODO ####
        # The pre-trained encoder is frozen: no gradients flow into it
        with torch.no_grad():
            mu = self.vae.encoder(x).mu
        return self.head(mu)
        ##############


def train_classifier(train_loader, epochs=10, **kwargs):
    #### TODO ####
    num_classes = len(train_loader.dataset.classes)
    head = ClassificationHead(latent_size, num_classes)
    head.to(device)

    # Only the head's parameters are optimized; the encoder stays fixed
    optimizer = optim.Adam(head.parameters(), lr=kwargs["lr"])
    cross_entropy_loss = nn.CrossEntropyLoss()

    # Eval mode keeps the encoder's BatchNorm running statistics fixed
    vae.eval()
    head.train()

    for epoch in range(1, epochs + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()

            # Encode without tracking gradients (just the mu part)
            with torch.no_grad():
                z = vae.encoder(data).mu

            output = head(z)
            loss = cross_entropy_loss(output, target)

            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    classifier = Classifier(vae, head)
    classifier.to(device)

    return classifier
    ##############

#### TODO ####
# Adjust kwargs if needed
train_function_kwargs = {"lr": 4e-3, "epochs": 20}
##############

classifier = train_classifier(train_loader, **train_function_kwargs) 




In [None]:
def test_classifier(classifier):
    classifier.eval()
    test_loss = 0
    with torch.no_grad():
        for (data, label) in test_loader:
            data = data.to(device)
            label = label.to(device)
            output = classifier(data)
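            # Fraction of correct predictions in this batch (named "loss", but it is an accuracy)
            loss = torch.mean((torch.argmax(output, dim=-1) == label).to(float))
            test_loss = test_loss + (loss * data.size(0))

    # Overall test accuracy, despite the variable name and the printed label
    test_loss /= len(test_loader.dataset)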

    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))
    return test_loss

assert test_classifier(classifier) > 0.90, 'Classifier not trained well enough'



Test set: Average loss: 0.9516
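The value printed as "Average loss" by `test_classifier` is the test accuracy: per-batch accuracies are weighted by batch size and divided by the dataset size. With MNIST's 10,000 test images and `test_batch_size = 1000` all batches are full, but the weighting also keeps the result exact when the last batch is smaller. A pure-Python check on synthetic labels and predictions (the data and the helper name are made up for illustration):

```python
import random


def accuracy_by_batches(predictions, labels, batch_size):
    """Hypothetical helper: batch-size-weighted mean of per-batch accuracies."""
    total = 0.0
    for start in range(0, len(labels), batch_size):
        preds = predictions[start:start + batch_size]
        labs = labels[start:start + batch_size]
        batch_acc = sum(p == y for p, y in zip(preds, labs)) / len(labs)
        total += batch_acc * len(labs)  # weight by the (possibly smaller) batch size
    return total / len(labels)


rng = random.Random(0)
labels = [rng.randrange(10) for _ in range(1050)]
# Synthetic predictions that are correct about 95% of the time.
predictions = [y if rng.random() < 0.95 else (y + 1) % 10 for y in labels]

overall = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
print(overall, accuracy_by_batches(predictions, labels, batch_size=1000))
```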

