<a href="https://www.kaggle.com/code/fellahabdelnour13/autoencoders-with-mnist-dataset?scriptVersionId=192371761" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

<h1 style="font-family:verdana;"> Overview </h1>

1. [What are auto-encoders ?](#what_are_auto_encoders)
2. [Necessary Packages](#necessary_packages)
3. [Constants](#constants)
4. [Reproducibility](#reproducibility)
5. [Data Loading](#data_loading)
6. [Utils](#utils)
7. [Visualization](#visualization)
8. [Simple Auto-encoders](#simple_auto_encoders)
     1. [The Architecture](#1_the_architecture)
     2. [Training](#1_training)
     3. [Evaluation](#1_evaluation)
     4. [Auto-encoders Vs PCA (Visualization)](#1_autoencoders_vs_pca_visualization)
     5. [Auto-encoders Vs PCA (Features Quality)](#1_autoencoders_vs_pca_features_quality)
     6. [Retraining With a bigger bottleneck size](#1_retraining_with_bigger_bottleneck_size)
9. [Convolutional Auto-Encoders](#convolutional_auto_encoders)
    1. [The architecture](#2_the_architecture)
    2. [Training](#2_training)
    3. [Evaluation](#2_evaluation)
10. [Denoising Auto-Encoders](#denoising_auto_encoders)
    1. [Adding noise to the input](#3_adding_noise_to_the_input)
    2. [Training](#3_training)
    3. [Evaluation](#3_evaluation)
11. [Variational Autoencoders](#variational_auto_encoders)
    1. [The architecture](#4_the_architecture)
    2. [The loss function](#4_the_loss_function)
    3. [Training](#4_training)
    4. [Evaluation](#4_evaluation)
    5. [Generating Synthetic Images](#4_generating_synthetic_data)
12. [Conditional Variational Autoencoders](#conditional_variational_auto_encoders)
    1. [The architecture](#5_the_architecture)
    2. [Training](#5_training)
    3. [Evaluation](#5_evaluation)
    4. [Generating Synthetic Images](#5_generating_synthetic_data)
13. [Conclusion](#conclusion)
14. [Thank you :)](#thank_you)

<div id="what_are_auto_encoders" >
    <h1 style="font-family:verdana;"> What are auto-encoders ? </h1>
</div>

<div style="font-size: 1rem;font-family:verdana;" >
    Autoencoders are a type of artificial neural network used to learn efficient codings of unlabeled data. They are primarily used for dimensionality reduction and feature learning. The structure of an autoencoder consists of two main parts: the encoder, which compresses the input into a latent-space representation, and the decoder, which reconstructs the input from this representation.
</div> <br/>


<div style="font-size:1.125rem;font-family:verdana;font-weight:bold;" >
Types of Autoencoders:
</div><br/>
    
1. **Simple Autoencoders** : <br/>
    * *Architecture* : Consist of fully connected layers.   
    * *Use cases* : Dimensionality reduction, noise reduction, anomaly detection.
2. **Convolutional Autoencoders** : <br/>
    * *Architecture* : Use convolutional layers for encoding and decoding. 
    * *Use cases* : Image processing tasks such as denoising, inpainting, and super-resolution.
3. **Denoising Autoencoders**
    * *Architecture* : Can be either fully connected or convolutional, but trained to remove noise from input data.   
    * *Use cases* : Image denoising, improving robustness of representations.
4. **Variational Autoencoders (VAEs)**
    * *Architecture* : Probabilistic approach to autoencoders that imposes a distribution on the latent space.  
    * *Use cases* : Generative tasks, semi-supervised learning, anomaly detection.
5. **Conditional Variational Autoencoders (CVAEs)**
    * *Architecture* : Extend VAEs by conditioning the generation process on additional information, such as labels.
    * *Use cases* : Conditional image generation, data augmentation.
    
<div style="font-size: 1rem;font-family:verdana;" >
    For more details you can refer to :
</div>

* [Introduction To Autoencoders.](https://towardsdatascience.com/introduction-to-autoencoders-7a47cf4ef14b)
* [Understanding Variational Autoencoders (VAEs).](https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73)
* [Conditional Variational Autoencoders with Learnable Conditional Embeddings.](https://towardsdatascience.com/conditional-variational-autoencoders-with-learnable-conditional-embeddings-e22ee5359a2a)

<div id="necessary_packages" >
    <h1 style="font-family:verdana;"> Necessary Packages </h1>
</div>

In [None]:
import torch
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
import seaborn as sns
import random
import warnings
from torch import nn,Tensor,optim
from torch.utils.data import DataLoader,Dataset,default_collate
from typing import Callable,Optional,Any,Literal
from tqdm.notebook import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline 
from functools import reduce
from torchinfo import summary
from torch.nn import functional as F

In [None]:
warnings.simplefilter(action="ignore",category=FutureWarning)
sns.set_style("darkgrid")

<div id="constants" >
    <h1 style="font-family:verdana;"> Constants </h1>
</div>

In [None]:
DATA_DIR = '/kaggle/input/mnist-in-csv'

In [None]:
class CONFIG:
    SEED = 42
    LEARNING_RATE : float = 1e-3
    BATCH_SIZE : int = 64
    WEIGHT_DECAY : float = 0.0
    EPOCHS : int = 10
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_CLASSES = 10
    NOISE_RATE : float = 0.2

<div id="reproducibility" >
    <h1 style="font-family:verdana;"> Reproducibility </h1>
</div>

In [None]:
def seed_everything(seed : int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
seed_everything(CONFIG.SEED)

<div id="data_loading" >
    <h1 style="font-family:verdana;"> Data Loading </h1>
</div>

In [None]:
class MnistDataset(Dataset):

    def __init__(self, 
        df_path : str,
        target_col : str,
        features_transform : Optional[Callable] = None,
        copy_transform : Optional[Callable] = None,
        target_transform : Optional[Callable] = None,
    ) -> None:

        super().__init__()

        self.df_path = df_path
        self.target_col = target_col
        self.features_transform = features_transform
        self.copy_transform = copy_transform
        self.target_transform = target_transform

        self.data = pd.read_csv(df_path)

    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx : int) -> Any:

        row = self.data.iloc[idx]
        features = np.array(row.drop(self.target_col).to_list()).reshape(28,28).astype(np.uint8)
        features_copy = features.copy()
        target = row[self.target_col]

        if self.features_transform is not None:
            features = self.features_transform(features)
            
        if self.copy_transform is not None:
            features_copy = self.copy_transform(features_copy)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return features,features_copy,target

In [None]:
train_dataset = MnistDataset(df_path=os.path.join(DATA_DIR,'mnist_train.csv'),target_col='label')
test_dataset = MnistDataset(df_path=os.path.join(DATA_DIR,'mnist_test.csv'),target_col='label')

<div id="utils" >
    <h1 style="font-family:verdana;"> Utils </h1>
</div>

In [None]:
def create_dataloader(
    train_features_transfrom : Callable,
    train_copy_transfrom : Callable,
    test_features_transfrom : Callable,
    test_copy_transfrom : Callable,
) -> tuple[DataLoader,DataLoader]:
    
    train_dataset = MnistDataset(
        df_path=os.path.join(DATA_DIR,'mnist_train.csv'),
        target_col='label',
        features_transform=train_features_transfrom,
        copy_transform=train_copy_transfrom
    )
    
    test_dataset = MnistDataset(
        df_path=os.path.join(DATA_DIR,'mnist_test.csv'),
        target_col='label',
        features_transform=test_features_transfrom,
        copy_transform=test_copy_transfrom
    )
    
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        prefetch_factor=2
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        prefetch_factor=2
    )
    
    return train_loader,test_loader

In [None]:
def train(
    model : nn.Module,
    loss_fn : nn.Module,
    train_loader : DataLoader,
    val_loader : DataLoader,
    learning_rate : float,
    weight_decay : float,
    epochs : int,
    device : torch.device,
    use_label : bool = False,
    show_progress : bool = True,
    verbose : bool = True,
) -> pd.DataFrame:
    
    ### Initialize history
    history = {
        "epoch" : [],
        "split" : [],
        "loss" : []
    }
        
    ### Move the model to the right device
    model = model.to(device)
        
    ### Loss function
    loss_fn = loss_fn.to(device)
    
    ### Optimizer
    optimizer = optim.AdamW(params=model.parameters(),lr=learning_rate,weight_decay=weight_decay)
    
    ### Training    
    for epoch in range(epochs):
        
        ### Training Loop
        iterator = tqdm(enumerate(train_loader),total=len(train_loader)) if show_progress else enumerate(train_loader)
        
        if show_progress:
            iterator.set_description(f"Epoch : {(epoch+1)}/{epochs}")
        
        model = model.train()
        running_loss = 0.0
        
        for i,(x,y,z) in iterator:
            
            ### Move the data to the right device
            x,y,z = x.to(device),y.to(device),z.to(device)
            
            ### Forward pass
            y_hat = model(x) if not use_label else model((x,z))
            
            ### Loss
            loss = loss_fn(y_hat,y)
            running_loss += loss.item()
            
            ### Zero gradient
            optimizer.zero_grad()
            
            ### Backward pass
            loss.backward()
            
            ### Update the network's weights
            optimizer.step()
            
            ### Update the progress bar
            if show_progress:
                iterator.set_postfix(running_mean_loss=(running_loss / (i+1)))
            
        ### Update the history
        history["epoch"].append(epoch)
        history["split"].append("train")
        history["loss"].append(running_loss/len(train_loader))
            
        if verbose:
            message = f"Epoch : {(epoch+1)}/{epochs},Split : train,loss={history['loss'][-1]},"
            print(message)
                                
        ### Validation Loop
        iterator = tqdm(enumerate(val_loader),total=len(val_loader)) if show_progress else enumerate(val_loader)
        
        if show_progress:
            iterator.set_description(f"Epoch : {(epoch+1)}/{epochs}")
        
        model = model.eval()
        running_loss = 0.0
        
        for i,(x,y,z) in iterator:
            
            ### Move the data to the right device
            x,y,z = x.to(device),y.to(device),z.to(device)
            
            ### Forward pass and loss (no gradients are needed for validation)
            with torch.no_grad():
                y_hat = model(x) if not use_label else model((x,z))
                loss = loss_fn(y_hat,y)
                
            running_loss += loss.item()
        
        ### Update the history
        history["epoch"].append(epoch)
        history["split"].append("val")
        history["loss"].append(running_loss/len(val_loader))
            
        if verbose:
            message = f"Epoch : {(epoch+1)}/{epochs},Split : val,loss={history['loss'][-1]},"
            print(message)
                
    history = pd.DataFrame(history)
    
    return history

In [None]:
class Reshape(nn.Module):
    
    def __init__(self, new_shape : tuple):
        super().__init__()
        
        self.new_shape = new_shape
        
    def forward(self, x : Tensor) -> Tensor:
        B = x.size(0)
        x = torch.reshape(x, (B,)+self.new_shape)
        return x

In [None]:
def knn_score(
    train_data : tuple[np.ndarray,np.ndarray],
    test_data : tuple[np.ndarray,np.ndarray],
) -> float:
    
    x_train, y_train = train_data
    x_test, y_test = test_data

    knn = Pipeline([
        ('scaler',StandardScaler()),
        ('knn',KNeighborsClassifier(n_neighbors=3))
    ])
    knn.fit(x_train,y_train)

    return knn.score(x_test,y_test)

In [None]:
def extract_features(model : nn.Module, dataloader : DataLoader) -> tuple[np.ndarray,np.ndarray]:

    features = []
    targets = []
    
    model = model.eval()

    for x,_,y in tqdm(dataloader):

        x = x.to(CONFIG.DEVICE)

        with torch.no_grad():
            x = model.encoder(x)

        features.append(x.cpu())
        targets.append(y)

    features = torch.cat(features).numpy()
    targets = torch.cat(targets).numpy()

    return features, targets

In [None]:
def test_autoencoder(
    img : np.ndarray,
    transform : Callable,
    model : nn.Module
):
    fig,(ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize=(10,5))
    img = transform(img)
    img = img.unsqueeze(0).to(CONFIG.DEVICE)
    
    with torch.inference_mode():
        decoded_img = model(img).squeeze().cpu().numpy()
        
    img = img[0].detach().cpu()
    
    if len(img.shape) == 3:
        img = img.permute((1,2,0))
    
    ax1.imshow(img, cmap='gray')
    ax2.imshow(decoded_img, cmap='gray')

In [None]:
def plot_images(images : np.ndarray,targets : np.ndarray | None = None,nrows : int = 4,ncols : int = 4) -> None:
    
    fig, axes = plt.subplots(nrows=nrows,ncols=ncols,figsize=(10,10))
    
    k = 0

    for i in range(nrows):
        for j in range(ncols):

            image = images[k]
            
            axes[i,j].imshow(image,cmap='gray')
            axes[i,j].axis('off')
            
            if targets is not None:
                axes[i,j].set_title(f"Target {targets[k]}")
            
            k += 1

<div id="visualization" >
    <h1 style="font-family:verdana;"> Visualization </h1>
</div>

In [None]:
sample = train_dataset.data.sample(36)
images = sample.drop(columns="label").values.reshape(-1,28,28)
labels = sample["label"].values
plot_images(images,labels,6,6)

<div id="simple_auto_encoders" >
    <h1 style="font-family:verdana;"> Simple Auto-Encoders </h1>
</div>

<div id="1_the_architecture" >
    <h2 style="font-family:verdana;"> The architecture </h2>
</div>

In [None]:
class SimpleAutoEncoder(nn.Module):
    
    def __init__(self, 
        input_dim : tuple,
        hidden_dim : int,
        bottleneck_dim : int
    ) -> None:
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bottleneck_dim = bottleneck_dim
        
        input_dim = reduce(lambda a,b : a * b,input_dim)
        
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
            Reshape(self.input_dim)
        )
        
    def forward(self, x : Tensor) -> Tensor:
        
        x = self.encoder(x)
        x = self.decoder(x)
        
        return x

In [None]:
model = SimpleAutoEncoder(
    input_dim=(28,28),
    hidden_dim=128,
    bottleneck_dim=2
)

In [None]:
summary(model=model,input_size=(32,28,28),device="cpu")

<div id="1_training" >
    <h2 style="font-family:verdana;"> Training </h2>
</div>

In [None]:
transform = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.squeeze)
])

In [None]:
train_loader,test_loader = create_dataloader(transform,transform,transform,transform)

In [None]:
history = train(
    model,
    nn.MSELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=5,
    device=CONFIG.DEVICE,
)

<div id="1_evaluation" >
    <h2 style="font-family:verdana;"> Evaluation </h2>
</div>

<div style="font-family:verdana;font-size:18px;"> Learning curve </div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div style="font-family:verdana;font-size:18px;"> Original Image Vs Decoded Image </div>

In [None]:
test_autoencoder(
    img=test_dataset[0][0],
    transform=transform,
    model=model
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 The reconstructed image is blurry and inaccurate because the bottleneck layer is very small (only 2 dimensions).
</div>
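
<div style="font-size: 1rem;font-family:verdana;" >
    To put the bottleneck size in perspective, the short sketch below computes how strongly a flattened 28x28 image is compressed for the two bottleneck sizes used in this notebook (2 and 64). The helper <code>compression_ratio</code> is a hypothetical name introduced only for this illustration.
</div>

```python
from math import prod


def compression_ratio(input_shape: tuple[int, ...], bottleneck_dim: int) -> float:
    """Number of input values per latent value (illustrative helper, not part of the notebook)."""
    return prod(input_shape) / bottleneck_dim


# Bottleneck sizes taken from the SimpleAutoEncoder experiments in this notebook
for dim in (2, 64):
    ratio = compression_ratio((28, 28), dim)
    print(f"bottleneck_dim={dim}: {ratio:.2f} input values per latent value")
```

<div style="font-size: 1rem;font-family:verdana;" >
    With only 2 latent values the decoder has to rebuild 784 pixels from very little information, which is consistent with the blurry reconstruction.
</div>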

<div id="1_autoencoders_vs_pca_visualization" >
    <h2 style="font-family:verdana;"> Auto-Encoders Vs PCA (Visualization) </h2>
</div>

In [None]:
autoencoder_train_features,autoencoder_train_targets = extract_features(model, train_loader)
autoencoder_test_features,autoencoder_test_targets = extract_features(model, test_loader)
autoencoder_train_features.shape

In [None]:
pca_model = Pipeline([
    ("scaler",StandardScaler()),
    ("pca",PCA(n_components=2))
]).fit(train_dataset.data.drop(columns="label"))

pca_train_features = pca_model.transform(train_dataset.data.drop(columns="label"))
pca_test_features = pca_model.transform(test_dataset.data.drop(columns="label"))

pca_train_targets = train_dataset.data["label"]
pca_test_targets = test_dataset.data["label"]

pca_train_features.shape

In [None]:
sns.scatterplot(x=autoencoder_train_features[:,0],y=autoencoder_train_features[:,1],hue=autoencoder_train_targets,palette="Set2")

In [None]:
sns.scatterplot(x=pca_train_features[:,0],y=pca_train_features[:,1],hue=pca_train_targets,palette="Set2")

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 The autoencoder-based visualization is better because the classes are more clearly separated.
</div>

<div id="1_autoencoders_vs_pca_features_quality" >
    <h2 style="font-family:verdana;"> Auto-Encoders Vs PCA (Features Quality) </h2>
</div>

In [None]:
knn_score(
    train_data=(autoencoder_train_features,autoencoder_train_targets),
    test_data=(autoencoder_test_features,autoencoder_test_targets)
)

In [None]:
knn_score(
    train_data=(pca_train_features,pca_train_targets),
    test_data=(pca_test_features,pca_test_targets)
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 The autoencoder was able to capture more information than PCA.
</div>

<div id="1_retraining_with_bigger_bottleneck_size" >
    <h2 style="font-family:verdana;"> Training with a bigger bottleneck size </h2>
</div>

<div style="font-family:verdana;font-size:18px;"> Training </div>

In [None]:
model = SimpleAutoEncoder(
    input_dim=(28,28),
    hidden_dim=128,
    bottleneck_dim=64
)

In [None]:
summary(model=model,input_size=(32,28,28),device="cpu")

In [None]:
history = train(
    model,
    nn.MSELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=5,
    device=CONFIG.DEVICE,
)

<div style="font-family:verdana;font-size:18px;"> Learning curve </div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div style="font-family:verdana;font-size:18px;"> Original Image Vs Decoded Image </div>

In [None]:
test_autoencoder(
    img=test_dataset[0][0],
    transform=transform,
    model=model
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 This time the reconstructed image has a much better quality.
</div>

<div style="font-family:verdana;font-size:18px;"> The quality of the features </div>

In [None]:
autoencoder_train_features,autoencoder_train_targets = extract_features(model, train_loader)
autoencoder_test_features,autoencoder_test_targets = extract_features(model, test_loader)
autoencoder_train_features.shape

In [None]:
knn_score(
    train_data=(autoencoder_train_features,autoencoder_train_targets),
    test_data=(autoencoder_test_features,autoencoder_test_targets)
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 With a bottleneck size of 64, KNN reached over 97% accuracy on the test set.
</div>

<div id="convolutional_auto_encoders" >
    <h1 style="font-family:verdana;"> Convolutional Auto-Encoders </h1>
</div>

<div id="2_the_architecture" >
    <h2 style="font-family:verdana;"> The architecture </h2>
</div>

In [None]:
class ConvAutoEncoder(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
model = ConvAutoEncoder()

In [None]:
summary(model=model,input_size=(32,1,28,28),device="cpu")
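
<div style="font-size: 1rem;font-family:verdana;" >
    The sketch below traces the spatial size of a 28x28 input through the layers of <code>ConvAutoEncoder</code>, using the standard output-size formulas for convolutions and transposed convolutions (dilation 1). The helper names <code>conv_out</code>, <code>conv_transpose_out</code> and <code>trace</code> are hypothetical; the layer hyperparameters are copied from the class defined above.
</div>

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Output size of a transposed convolution along one spatial dimension (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace(size: int, layers: list[tuple], fn) -> list[int]:
    """Apply `fn` layer by layer and record every intermediate size."""
    sizes = [size]
    for params in layers:
        sizes.append(fn(sizes[-1], *params))
    return sizes


# (kernel, stride, padding) for each encoder layer
encoder_layers = [(3, 2, 1), (3, 2, 1), (3, 1, 0), (3, 1, 0), (3, 1, 0)]
# (kernel, stride, padding, output_padding) for each decoder layer
decoder_layers = [(3, 1, 0, 0), (3, 1, 0, 0), (3, 1, 0, 0), (3, 2, 1, 1), (3, 2, 1, 1)]

encoder_sizes = trace(28, encoder_layers, conv_out)
decoder_sizes = trace(encoder_sizes[-1], decoder_layers, conv_transpose_out)

print("encoder:", encoder_sizes)
print("decoder:", decoder_sizes)
```

<div style="font-size: 1rem;font-family:verdana;" >
    The encoder shrinks each image to a 1x1 map with 64 channels, so the latent code is effectively a 64-dimensional vector, and the decoder mirrors the path back to 28x28.
</div>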

<div id="2_training" >
    <h2 style="font-family:verdana;"> Training </h2>
</div>

In [None]:
transform = T.ToTensor()

In [None]:
train_loader,test_loader = create_dataloader(transform,transform,transform,transform)

In [None]:
history = train(
    model,
    nn.MSELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=5,
    device=CONFIG.DEVICE,
)

<div id="2_evaluation" >
    <h2 style="font-family:verdana;"> Evaluation </h2>
</div>

<div style="font-family:verdana;font-size:18px;"> Learning curve </div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div style="font-family:verdana;font-size:18px;"> Original Image Vs Decoded Image </div>

In [None]:
test_autoencoder(
    img=test_dataset[0][0],
    transform=transform,
    model=model
)

<div style="font-family:verdana;font-size:18px;"> Features Quality </div>

In [None]:
autoencoder_train_features,autoencoder_train_targets = extract_features(model, train_loader)
autoencoder_test_features,autoencoder_test_targets = extract_features(model, test_loader)
autoencoder_train_features.shape

In [None]:
autoencoder_train_features = autoencoder_train_features.squeeze()
autoencoder_test_features = autoencoder_test_features.squeeze()

In [None]:
knn_score(
    train_data=(autoencoder_train_features,autoencoder_train_targets),
    test_data=(autoencoder_test_features,autoencoder_test_targets)
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 With fewer parameters, we obtained another auto-encoder with similar performance.
</div>
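
<div style="font-size: 1rem;font-family:verdana;" >
    As a rough check of the parameter comparison, the sketch below counts weights and biases by hand. It assumes the comparison is against the fully connected <code>SimpleAutoEncoder</code> with <code>hidden_dim=128</code> and <code>bottleneck_dim=64</code>; the helper names are hypothetical, and the <code>summary</code> output remains the authoritative count.
</div>

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def conv_params(c_in: int, c_out: int, kernel: int) -> int:
    """Weights plus biases of a (transposed) 2D convolution with a square kernel."""
    return c_in * c_out * kernel * kernel + c_out


# SimpleAutoEncoder(input_dim=(28,28), hidden_dim=128, bottleneck_dim=64)
simple_layers = [(784, 128), (128, 64), (64, 128), (128, 784)]
simple_total = sum(linear_params(a, b) for a, b in simple_layers)

# ConvAutoEncoder: (in_channels, out_channels) per layer, all with 3x3 kernels
conv_encoder = [(1, 8), (8, 16), (16, 32), (32, 64), (64, 64)]
conv_decoder = [(64, 64), (64, 32), (32, 16), (16, 8), (8, 1)]
conv_total = sum(conv_params(a, b, 3) for a, b in conv_encoder + conv_decoder)

print(f"SimpleAutoEncoder : {simple_total:,} parameters")
print(f"ConvAutoEncoder   : {conv_total:,} parameters")
```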

<div id="denoising_auto_encoders" >
    <h1 style="font-family:verdana;"> Denoising Auto-Encoders </h1>
</div>

<div id="3_adding_noise_to_the_input" >
    <h2 style="font-family:verdana;"> Adding noise to the input </h2>
</div>

In [None]:
class RandomNoise(nn.Module):

    def __init__(self, noise_rate : float = 0.3, copy : bool = False) -> None:

        super().__init__()

        self.noise_rate = noise_rate
        self.copy = copy

    def forward(self, x : Tensor) -> Tensor:

        if self.copy:
            x = x.clone()
            
        ### Flatten the input
        og_input = x.shape
        x = x.flatten()
        
        ### Get the pixels that will be noised
        mask = torch.rand(x.shape[0]) < self.noise_rate
        pixels = x[mask]

        ### Randomly shuffle the pixels
        pixels = pixels[torch.randperm(pixels.shape[0])]

        ### Replace the pixels in the original tensor
        x[mask] = pixels
        
        ### Reshape to the original shape
        x = x.reshape(og_input)

        return x

In [None]:
def test_random_noise(image : Tensor, noise_rate : float = 0.3) -> None:

    noise = RandomNoise(noise_rate=noise_rate,copy=True)
    x_noised = noise(image)

    fig, axes = plt.subplots(1,2,figsize=(10,5))

    axes[0].imshow(image,cmap='gray')
    axes[0].set_title('Original Image')

    axes[1].imshow(x_noised,cmap='gray')
    axes[1].set_title('Noised Image')

In [None]:
test_random_noise(
    image=torch.tensor(train_dataset[0][0]),
    noise_rate=0.3
)
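
<div style="font-size: 1rem;font-family:verdana;" >
    Note that <code>RandomNoise</code> does not add new values to the image: it selects a random subset of pixels and shuffles them among themselves. The plain-Python sketch below mirrors that idea on a small list so the behaviour is easy to inspect. <code>shuffle_noise</code> is a hypothetical helper written for this illustration and is not used elsewhere in the notebook.
</div>

```python
import random


def shuffle_noise(pixels: list[int], noise_rate: float, rng: random.Random) -> list[int]:
    """Pick each position with probability `noise_rate` and permute the picked values."""
    noised = list(pixels)  # work on a copy, like copy=True in RandomNoise

    # Positions that will take part in the shuffle
    selected = [i for i in range(len(noised)) if rng.random() < noise_rate]

    # Shuffle only the selected values
    values = [noised[i] for i in selected]
    rng.shuffle(values)

    # Write the shuffled values back to the selected positions
    for i, value in zip(selected, values):
        noised[i] = value

    return noised


original = list(range(20))
noised = shuffle_noise(original, noise_rate=0.3, rng=random.Random(0))
untouched = shuffle_noise(original, noise_rate=0.0, rng=random.Random(0))

print("original:", original)
print("noised  :", noised)
```

<div style="font-size: 1rem;font-family:verdana;" >
    Because values are only moved around, the multiset of pixel intensities is preserved, and a noise rate of 0 leaves the image unchanged.
</div>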

<div id="3_training" >
    <h2 style="font-family:verdana;"> Training </h2>
</div>

In [None]:
transform = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.squeeze)
])

train_features_transform = T.Compose([
    T.ToTensor(),
    RandomNoise(noise_rate=CONFIG.NOISE_RATE,copy=True),
    T.Lambda(torch.squeeze)
])

In [None]:
train_loader,test_loader = create_dataloader(train_features_transform,transform,transform,transform)

In [None]:
model = SimpleAutoEncoder(
    input_dim=(28,28),
    hidden_dim=128,
    bottleneck_dim=64
)

In [None]:
history = train(
    model,
    nn.MSELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=10,
    device=CONFIG.DEVICE,
)

<div id="3_evaluation" >
    <h2 style="font-family:verdana;"> Evaluation </h2>
</div>

<div style="font-family:verdana;font-size:18px;"> Learning curve </div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div style="font-family:verdana;font-size:18px;"> Noised Vs Reconstructed Image </div>

In [None]:
test_autoencoder(
    img=test_dataset[0][0],
    transform=train_features_transform,
    model=model
)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 The autoencoder was able to remove the noise introduced on the input image.
</div>

<div id="variational_auto_encoders" >
    <h1 style="font-family:verdana;"> Variational Autoencoder </h1>
</div>

<div id="4_the_architecture" >
    <h2 style="font-family:verdana;"> The Architecture </h2>
</div>

In [None]:
class VariationalAutoEncoder(nn.Module):
    
    def __init__(self, 
        input_dim : tuple,
        hidden_dim : int,
        bottleneck_dim : int
    ) -> None:
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bottleneck_dim = bottleneck_dim
        
        input_dim = reduce(lambda a,b : a * b,input_dim)
        
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * bottleneck_dim),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
            Reshape(self.input_dim)
        )
        
    def forward(self, x : Tensor) -> Tensor:
        
        x = self.encoder(x)
        mean,log_var = x[:,:self.bottleneck_dim],x[:,self.bottleneck_dim:]
        x = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x = self.decoder(x)
        
        return x,mean,log_var
    
    def generate(self,
        batch_size : int,
    ) -> Tensor:
        
        x = torch.randn(batch_size, self.bottleneck_dim).to(CONFIG.DEVICE)

        with torch.inference_mode():
            generated_images = self.decoder(x).detach().cpu()
        
        return generated_images
    
    def reparameterization(self, mean : Tensor, std : Tensor) -> Tensor:
        ### `std` is the standard deviation, i.e. exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z

In [None]:
model = VariationalAutoEncoder(
    input_dim=(28,28),
    hidden_dim=384,
    bottleneck_dim=64
).to(CONFIG.DEVICE)

In [None]:
summary(model, input_size=(32,28,28),device="cpu")
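
<div style="font-size: 1rem;font-family:verdana;" >
    The reparameterization step in the forward pass draws a latent sample as z = mean + std * epsilon, with epsilon sampled from a standard normal distribution and std = exp(0.5 * log_var). The scalar sketch below (plain Python, with a hypothetical <code>reparameterize</code> helper) checks empirically that such samples have the requested mean and standard deviation.
</div>

```python
import math
import random


def reparameterize(mean: float, log_var: float, rng: random.Random) -> float:
    """Sample z ~ N(mean, exp(log_var)) as a deterministic function of external noise."""
    std = math.exp(0.5 * log_var)   # log-variance -> standard deviation
    epsilon = rng.gauss(0.0, 1.0)   # noise that does not depend on the parameters
    return mean + std * epsilon


rng = random.Random(0)
target_mean, target_std = 2.0, 0.5
samples = [reparameterize(target_mean, math.log(target_std ** 2), rng) for _ in range(10_000)]

sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))

print(f"sample mean = {sample_mean:.3f}, sample std = {sample_std:.3f}")
```

<div style="font-size: 1rem;font-family:verdana;" >
    Since the randomness lives entirely in epsilon, gradients can flow through the mean and the log-variance during training.
</div>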

<div id="4_the_loss_function" >
    <h2 style="font-family:verdana;"> The loss function </h2>
</div>

In [None]:
class VAELoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(self,inputs : tuple[Tensor,Tensor,Tensor],targets : Tensor) -> Tensor:
        
        x,mean,log_var = inputs
        
        reconstruction_loss = F.binary_cross_entropy(x,targets,reduction='sum')
        kld_loss = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        
        return reconstruction_loss + kld_loss
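
<div style="font-size: 1rem;font-family:verdana;" >
    The second term of <code>VAELoss</code> is the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior: $-\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The sketch below evaluates the same expression on plain Python lists; <code>kl_to_standard_normal</code> is a hypothetical helper used only to illustrate the formula.
</div>

```python
import math


def kl_to_standard_normal(mean: list[float], log_var: list[float]) -> float:
    """KL( N(mean, exp(log_var)) || N(0, 1) ), summed over the latent dimensions."""
    return -0.5 * sum(
        1 + lv - m ** 2 - math.exp(lv)
        for m, lv in zip(mean, log_var)
    )


# A posterior equal to the prior incurs no penalty
kl_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Moving the mean away from 0 (unit variance) costs mean^2 / 2 per dimension
kl_shifted = kl_to_standard_normal([1.0], [0.0])

# Shrinking the variance is penalized as well
kl_narrow = kl_to_standard_normal([0.0], [math.log(0.01)])

print(f"shifted mean: {kl_shifted:.3f}, narrow variance: {kl_narrow:.3f}")
```

<div style="font-size: 1rem;font-family:verdana;" >
    This term pulls every encoded distribution towards the prior, which is what makes decoding samples drawn from a standard normal meaningful at generation time.
</div>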

<div id="4_training" >
    <h2 style="font-family:verdana;"> Training </h2>
</div>

In [None]:
transform = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.squeeze)
])

In [None]:
train_loader,test_loader = create_dataloader(transform,transform,transform,transform)

In [None]:
history = train(
    model,
    VAELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=30,
    device=CONFIG.DEVICE,
)

<div id="4_evaluation" >
    <h2 style="font-family:verdana;"> Evaluation </h2>
</div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div id="4_generating_synthetic_data" >
    <h2 style="font-family:verdana;"> Generating Synthetic Images </h2>
</div>

In [None]:
generated_images = model.generate(batch_size=16)
plot_images(generated_images)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 We were able to generate good quality synthetic images but we have no control over the number that the generated images represent.
</div>

<div id="conditional_variational_auto_encoders" >
    <h1 style="font-family:verdana;"> Conditional Variational Autoencoder </h1>
</div>

<div id="5_the_architecture" >
    <h2 style="font-family:verdana;"> The architecture </h2>
</div>

In [None]:
class ConditionalVariationalAutoEncoder(nn.Module):
    
    def __init__(self, 
        input_dim : tuple,
        hidden_dim : int,
        bottleneck_dim : int,
        num_classes : int
    ) -> None:
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bottleneck_dim = bottleneck_dim
        self.num_classes = num_classes
        self.input_dim_flat = reduce(lambda a,b : a * b,input_dim)
        
        self.one_hot = nn.Parameter(data=torch.eye(num_classes), requires_grad=False)
        
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim_flat + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * bottleneck_dim),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.input_dim_flat),
            nn.Sigmoid(),
            Reshape(self.input_dim)
        )
        
    def forward(self, x : tuple[Tensor,Tensor]) -> tuple[Tensor,Tensor,Tensor]:
        
        img,label = x
        
        ### Encoder
        label = self.one_hot[label]
        img = torch.flatten(img,start_dim=1)
        x = torch.cat([img,label],dim=-1)
        x = self.encoder(x)
        mean,log_var = x[:,:self.bottleneck_dim],x[:,self.bottleneck_dim:]
        
        ### Reparameterization (the standard deviation is exp(0.5 * log_var))
        x = self.reparameterization(mean, torch.exp(0.5 * log_var))
        
        ### Decoding
        x = torch.cat((x, label), dim=-1)
        x = self.decoder(x)
        
        return x,mean,log_var
    
    def generate(self,
        batch_size : int,
        labels : int | Tensor
    ) -> Tensor:
        
        if isinstance(labels,int):
            # Repeat the same class index for every sample in the batch
            labels = torch.full((batch_size,), labels, dtype=torch.long)
        else:
            if labels.size(0) != batch_size:
                raise ValueError(f"labels.size(0) = {labels.size(0)} is different from batch_size = {batch_size}")
                
        # Use the device the model lives on rather than a global setting
        device = self.one_hot.device
        noise = torch.randn(batch_size, self.bottleneck_dim, device=device)
        labels = self.one_hot[labels.long().to(device)]
        x = torch.cat([noise,labels],dim=-1)
        
        with torch.inference_mode():
            generated_images = self.decoder(x).detach().cpu()
        
        return generated_images
    
    def reparameterization(self, mean : Tensor, std : Tensor) -> Tensor:
        # z = mean + std * epsilon with epsilon ~ N(0, I), so gradients flow through mean and std
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z

In [None]:
model = ConditionalVariationalAutoEncoder(
    input_dim=(28,28),
    hidden_dim=384,
    bottleneck_dim=64,
    num_classes=CONFIG.NUM_CLASSES
).to(CONFIG.DEVICE)

In [None]:
images = torch.zeros(32,28,28)
labels = torch.zeros(32).long()
summary(model,input_data={ "x" : (images,labels) },device="cpu")

<div id="5_training" >
    <h2 style="font-family:verdana;"> Training </h2>
</div>
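
Training uses the `VAELoss` defined earlier in the notebook. As a reminder of what such an objective typically computes, here is a minimal sketch that combines a reconstruction term with the closed-form KL divergence between N(mean, exp(log_var)) and N(0, I). The function name `vae_loss_sketch`, the choice of binary cross-entropy, and the sum-then-batch-average reduction are assumptions for illustration and may differ from the notebook's `VAELoss`.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def vae_loss_sketch(recon: Tensor, target: Tensor, mean: Tensor, log_var: Tensor) -> Tensor:
    # Hypothetical helper, not the notebook's VAELoss
    batch_size = target.size(0)
    # Reconstruction term: pixel-wise binary cross-entropy, summed over all pixels
    recon_loss = F.binary_cross_entropy(recon, target, reduction="sum")
    # KL(N(mean, exp(log_var)) || N(0, I)) in closed form, summed over latent dimensions
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    # Average the total over the batch
    return (recon_loss + kl) / batch_size

# Dummy tensors standing in for a batch of reconstructions and encoder outputs
target = torch.rand(8, 28, 28)
recon = torch.rand(8, 28, 28).clamp(1e-4, 1 - 1e-4)
mean = torch.zeros(8, 64)
log_var = torch.zeros(8, 64)

loss = vae_loss_sketch(recon, target, mean, log_var)
print(loss.item())
```

With `mean = 0` and `log_var = 0` the posterior equals the prior, so the KL term vanishes and only the reconstruction term remains.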

In [None]:
history = train(
    model,
    VAELoss(),
    train_loader,
    test_loader,
    learning_rate=CONFIG.LEARNING_RATE,
    weight_decay=CONFIG.WEIGHT_DECAY,
    epochs=20,
    device=CONFIG.DEVICE,
    use_label=True
)

<div id="5_evaluation" >
    <h2 style="font-family:verdana;"> Evaluation </h2>
</div>

In [None]:
sns.lineplot(data=history,x="epoch",y="loss",hue="split",palette="Set2")

<div id="5_generating_synthetic_data" >
    <h2 style="font-family:verdana;"> Generating Synthetic Images </h2>
</div>

In [None]:
generated_images = model.generate(16,9).numpy()
plot_images(generated_images)

In [None]:
labels = torch.randint(low=0,high=CONFIG.NUM_CLASSES,size=(16,))
generated_images = model.generate(16,labels).numpy()
plot_images(generated_images)

<div class="alert alert-block alert-info" style="font-size:16px; font-family:verdana; line-height: 1.7em;">
    📌 By conditioning on the label, we can now choose which digit each generated image represents.
</div>
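
One way to use this control is to request a fixed number of samples per digit, which gives one group of images per class. The sketch below only builds the label tensor. `samples_per_digit` is a hypothetical parameter, and the commented-out lines assume the `model` and `plot_images` objects defined earlier in this notebook, and that `plot_images` accepts a batch of this size.

```python
import torch

num_classes = 10         # MNIST digits 0-9
samples_per_digit = 4    # hypothetical choice for illustration

# repeat_interleave repeats each class index in place: 0,0,0,0,1,1,1,1,...,9,9,9,9
labels = torch.arange(num_classes).repeat_interleave(samples_per_digit)
print(labels.shape)

# Assumed usage with the notebook's trained model:
# generated_images = model.generate(labels.size(0), labels).numpy()
# plot_images(generated_images)
```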

<div id="conclusion" >
    <h1 style="font-family:verdana;"> Conclusion </h1>
</div>

<div style="font-size: 1rem;font-family:verdana;" >
In this notebook, we explored various types of autoencoders and applied them to the MNIST dataset. We began with simple autoencoders and gradually progressed to more complex architectures like convolutional, denoising, variational, and conditional variational autoencoders. Each type of autoencoder has unique strengths and applications:
</div><br/>

1. **Simple Autoencoders** : Effective for basic feature extraction and dimensionality reduction.
2. **Convolutional Autoencoders** : Excellent for image-related tasks due to their ability to capture spatial hierarchies.
3. **Denoising Autoencoders** : Learn representations that are robust to input noise and can reconstruct clean data from noisy inputs.
4. **Variational Autoencoders** : Enable generation of new data samples and facilitate probabilistic interpretation.
5. **Conditional Variational Autoencoders** : Allow for controlled generation based on additional context or labels.


<div id="thank_you" >
    <h1 style="font-family:verdana;"> Thank you :) </h1>
</div>