# Main

This notebook covers the full pipeline: preprocessing, training, and evaluation.

## Loading Libraries

Library | Version | Channel
--- | --- | ---
NumPy | 1.26.4 | default
PyTorch | 2.2.2 | pytorch
Torchvision | 0.17.2 | pytorch
Tensorboard | / | conda-forge

In [None]:
# Built-in libraries
from dataclasses import dataclass
from datetime import datetime

# Third-party libraries
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import v2

## Loading Data

The data is loaded from an `.h5ad` file into an [AnnData](https://anndata.readthedocs.io/) object. It holds a normalized sample of 30,000 observations (cells) and 10,000 variables (genes), together with per-cell metadata such as `donor` and `cell_type`. Two normalized layers are available, `cpm_normalized` and `min_max_normalized`; the model in this notebook is trained on the latter.

In [None]:
# file_path = "../data/adata_normalized_sample.h5ad"
file_path = "../data/adata_30kx10k_normalized_sample.h5ad"

adata = ad.read_h5ad(filename=file_path)

In [None]:
adata

AnnData object with n_obs × n_vars = 30000 × 10000
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45'
    uns: 'cell_type_colors'
    obsm: 'X_pca', 'X_umap'
    layers: 'cpm_normalized', 'min_max_normalized'

In [None]:
# Does not work yet --> Data type error
full_data = adata.layers["min_max_normalized"]

In [None]:
full_data

<30000x10000 sparse matrix of type '<class 'numpy.float64'>'
	with 12490183 stored elements in Compressed Sparse Row format>
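
The summary above suggests why the matrix is stored in sparse form. As a rough back-of-the-envelope check, the sketch below uses only the shape and the stored-element count printed above. The variable names are illustrative, and the memory figure assumes a dense `float32` copy with no container overhead.

```python
# Numbers taken from the printed sparse matrix summary above
n_rows, n_cols = 30_000, 10_000
stored_elements = 12_490_183

total_entries = n_rows * n_cols
density = stored_elements / total_entries

# Assumption: a dense float32 copy needs 4 bytes per entry
dense_bytes_float32 = total_entries * 4

print(f"Density: {density:.2%}")
print(f"Dense float32 size: {dense_bytes_float32 / 1e9:.1f} GB")
```

Only about 4% of the entries are stored, so converting a single row to a dense tensor on demand, as the dataset class in the next section does, avoids materializing the full dense matrix.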

## Data Split

Split the data into a training set and a test set (80/20).

In [None]:
from torch.utils.data import Dataset

class SparseDataset(Dataset):
    """
    Custom dataset class for sparse data.
    """

    def __init__(self, sparse_data):
        self.sparse_data = sparse_data

    def __len__(self):
        return self.sparse_data.shape[0]

    def __getitem__(self, index: int):
        if index >= len(self):
            raise IndexError("Index out of range")
        
        # Extract the row as a dense numpy array
        row = self.sparse_data.getrow(index).toarray().squeeze()
        
        # Convert the row to a PyTorch tensor
        return torch.tensor(row, dtype=torch.float32)

In [None]:
train_size = int(0.8 * full_data.shape[0])
test_size = full_data.shape[0] - train_size

torch.manual_seed(2406)
perm = torch.randperm(full_data.shape[0])
train_split, test_split = perm[:train_size], perm[train_size:]

In [None]:
# Index with NumPy arrays, which SciPy sparse matrices accept for row selection
train_data = SparseDataset(full_data[train_split.numpy(), :])
test_data = SparseDataset(full_data[test_split.numpy(), :])

In [None]:
# Dataloaders
batch_size = 128

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_data, 
    batch_size=batch_size, 
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data, 
    batch_size=batch_size, 
    shuffle=False,
)

## Model Structure

The **autoencoder** consists of two primary components: the **encoder** and the **decoder**. The encoder reduces the dimensionality of the input tensor. The decoder, in turn, attempts to reconstruct the original input from the reduced (latent) representation produced by the encoder.
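
To make the hourglass shape concrete, the following sketch derives the `(in_features, out_features)` pairs of the fully connected layers from the layer widths used in this notebook's hyperparameter cell, and counts the resulting weights and biases. It is plain Python for illustration only; `count_linear_params` is a hypothetical helper, not part of the model code.

```python
# Layer widths as listed in this notebook's hyperparameter cell
sizes = [10000, 6000, 3000, 1000, 200]

# Encoder: consecutive (in_features, out_features) pairs, narrowing to the bottleneck
encoder_shapes = list(zip(sizes[:-1], sizes[1:]))

# Decoder: the same widths in reverse, widening back to the input size
reversed_sizes = sizes[::-1]
decoder_shapes = list(zip(reversed_sizes[:-1], reversed_sizes[1:]))


def count_linear_params(shapes):
    """Weights (in * out) plus biases (out) of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


n_params = count_linear_params(encoder_shapes) + count_linear_params(decoder_shapes)

print(encoder_shapes)
print(decoder_shapes)
print(f"{n_params:,} trainable parameters")
```

With these widths the first and last layers dominate the parameter count, which is worth keeping in mind when choosing the batch size and device.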

### Device Specification

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
cuda = device == "cuda"

### Hyperparameters

Hyperparameters used for training the model.

In [None]:
size_layers = [
    (10000, nn.ReLU()), 
    (6000, nn.ReLU()), 
    (3000, nn.ReLU()), 
    (1000, nn.ReLU()), 
    (200, nn.ReLU())
]
# Optimizer
learning_rate = 1e-1
weight_decay = 1e-8
# Training
folds = 5
epochs = 15

For cross-validation, the dataset is divided into two parts: one for training and one for validation after training. Each such division is referred to as a **"fold"**. A fold is created by holding out all cells from a single donor, so that the validation results are not influenced by batch effects specific to that donor.
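
The donor-based split described above can be sketched in plain Python as a leave-one-donor-out scheme. The donor labels below are hypothetical toy data rather than the real metadata, and `leave_one_donor_out` is an illustrative helper name.

```python
# Hypothetical toy data: one donor label per cell (not the real metadata)
toy_donors = ["D1", "D2", "D1", "D3", "D2", "D3", "D1"]


def leave_one_donor_out(donor_labels):
    """Yield (held_out_donor, train_indices, val_indices) for each donor."""
    for held_out in sorted(set(donor_labels)):
        # Training cells: every cell that does not belong to the held-out donor
        train_idx = [i for i, d in enumerate(donor_labels) if d != held_out]
        # Validation cells: all cells of the held-out donor
        val_idx = [i for i, d in enumerate(donor_labels) if d == held_out]
        yield held_out, train_idx, val_idx


toy_folds = list(leave_one_donor_out(toy_donors))
for held_out, train_idx, val_idx in toy_folds:
    print(held_out, train_idx, val_idx)
```

Each donor appears in exactly one validation set, and no cell is shared between the training and validation part of the same fold.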

In [None]:
@dataclass
class AEOutput:
    """
    Dataclass for AE output.

    Attributes:
        z_sample (torch.Tensor): The latent representation z of the input.
        x_recon (torch.Tensor): The reconstructed output from the AE.
        loss (torch.Tensor): The reconstruction loss of the AE (None if not computed).
    """

    z_sample: torch.Tensor
    x_recon: torch.Tensor
    loss: torch.Tensor

In [None]:
class Autoencoder(nn.Module):
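    """
    Autoencoder (AE) class.

    Args:
        size_layers (list): Layer sizes paired with their activation functions.
        criterion (nn.Module): Loss function used for the reconstruction loss.
        learning_rate (float): Learning rate passed to the optimizer.
        weight_decay (float): Weight decay passed to the optimizer.
        optimizer (type): Optimizer class (not an instance), e.g. torch.optim.Adam.
    """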

    def __init__(
        self,
        size_layers: list[tuple[int, nn.Module]],
        criterion: nn.modules.loss._Loss,
        learning_rate: float = 1e-1,
        weight_decay: float = 1e-8,
        optimizer: torch.optim.Optimizer = torch.optim.Adam,
    ):
        super().__init__()

        ## Encoder architecture
        self.encoder_layers = []
        # Iterate only up to the second-to-last element,
        # because each step also reads the size at idx + 1
        for idx, (size, activation) in enumerate(size_layers[:-1]):
            # Every layer except the final (bottleneck) one
            # is followed by its activation function
            if idx < len(size_layers[:-1]) - 1:
                self.encoder_layers.append(nn.Linear(size, size_layers[idx + 1][0]))

                # Checks if activation is viable
                if activation is not None:
                    assert isinstance(
                        activation, nn.Module
                    ), f"Activation should be of type {nn.Module}"
                    self.encoder_layers.append(activation)
            else:
                self.encoder_layers.append(nn.Linear(size, size_layers[idx + 1][0]))

        self.encoder = nn.Sequential(*self.encoder_layers)

        print("Constructed encoder...")

        ## Decoder architecture
        # Reverse the layer list to build the decoder (hourglass shape)
        reversed_layers = list(reversed(size_layers))
        self.decoder_layers = []
        for idx, (size, activation) in enumerate(reversed_layers[:-1]):
            # Every layer except the final (output) one
            # is followed by its activation function
            if idx < len(reversed_layers[:-1]) - 1:
                self.decoder_layers.append(nn.Linear(size, reversed_layers[idx + 1][0]))

                # Checks if activation is viable
                if activation is not None:
                    assert isinstance(
                        activation, nn.Module
                    ), f"Activation should be of type {nn.Module}"
                    self.decoder_layers.append(activation)
            else:
                self.decoder_layers.append(nn.Linear(size, reversed_layers[idx + 1][0]))

        self.decoder = nn.Sequential(*self.decoder_layers)

        print("Constructed decoder...")

        self.criterion = criterion
        self.learning_rate = learning_rate
        self.optimizer = optimizer(
            params=self.parameters(), lr=learning_rate, weight_decay=weight_decay
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def encode(self, x):
        """
        Encodes the input data into the latent space.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Input data compressed to latent space.
        """
        return self.encoder(x)

    def decode(self, z):
        """
        Decodes the data from the latent space to the original input space.

        Args:
            z (torch.Tensor): Data in the latent space.

        Returns:
            torch.Tensor: Reconstructed data in the original input space.
        """
        return self.decoder(z)

    def forward(self, x, compute_loss: bool = True):
        """
        Performs a forward pass of the AE.

        Args:
            x (torch.Tensor): Input data.
            compute_loss (bool): Whether to compute the loss or not.

        Returns:
            AEOutput: AE output dataclass.
        """
        z = self.encode(x)
        recon_x = self.decode(z)

        if not compute_loss:
            return AEOutput(z_sample=z, x_recon=recon_x, loss=None)

        # compute loss terms
        loss_recon = self.criterion(recon_x, x)

        return AEOutput(z_sample=z, x_recon=recon_x, loss=loss_recon)

In [None]:
model = Autoencoder(
    size_layers,
    criterion=nn.MSELoss(),
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

Constructed encoder...
Constructed decoder...


In [None]:
model

Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=10000, out_features=6000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6000, out_features=3000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=3000, out_features=1000, bias=True)
    (5): ReLU()
    (6): Linear(in_features=1000, out_features=200, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=200, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=3000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=3000, out_features=6000, bias=True)
    (5): ReLU()
    (6): Linear(in_features=6000, out_features=10000, bias=True)
  )
  (criterion): MSELoss()
)

## Training

In [None]:
from tqdm import tqdm

def train(model, dataloader, optimizer, prev_updates, writer=None, device="cpu"):
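    """
    Trains the model on the given data for one epoch.

    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer: The optimizer.
        prev_updates (int): Number of updates performed in previous epochs.
        writer: The TensorBoard writer.
        device: Device the batches are moved to.

    Returns:
        int: Total number of updates performed so far.
    """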
    model.train()  # Set the model to training mode

    for batch_idx, data in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx

        data = data.to(device)

        optimizer.zero_grad()  # Zero the gradients

        output = model(data)  # Forward pass
        loss = output.loss

        loss.backward()  # Backward pass: compute gradients for the optimizer step

        if n_upd % 100 == 0:
            # Calculate and log gradient norms
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1.0 / 2)

            print(
                f"Step {n_upd:,} (N samples: {n_upd * dataloader.batch_size:,}), Loss: {loss.item():.4f} Grad: {total_norm:.4f}"
            )

            if writer is not None:
                global_step = n_upd
                writer.add_scalar("Loss/Train", loss.item(), global_step)
                writer.add_scalar("GradNorm/Train", total_norm, global_step)

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)
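
The training loop above logs the global gradient norm and then clips it to a maximum of 1.0 with `torch.nn.utils.clip_grad_norm_`. The sketch below reproduces that computation in plain Python to show what the two steps do numerically. It is an illustration, not PyTorch's implementation: `clip_by_global_norm` is a hypothetical helper, the gradients are toy values, and the small `eps` term is an assumption that mirrors the constant PyTorch adds to avoid division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale a list of gradient vectors so their global L2 norm is at most max_norm."""
    # Global L2 norm over all parameters, as logged in the training loop
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # The scale factor is capped at 1.0, so small gradients are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


# Hypothetical gradients for two parameter tensors (flattened)
toy_grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_by_global_norm(toy_grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

print(f"norm before: {norm_before:.4f}, norm after: {norm_after:.4f}")
```

Because all gradients are scaled by the same factor, clipping changes the step size but not the update direction.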

In [None]:
def test(model, dataloader, cur_step, writer=None, device="cpu"):
    """
    Tests the model on the given data.

    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The current step.
        writer: The TensorBoard writer.
        device: Device the batches are moved to.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0

    with torch.no_grad():
        for data in tqdm(dataloader, desc="Testing"):
            data = data.to(device)

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()

    test_loss /= len(dataloader)
    print(f"====> Test set loss: {test_loss:.4f}")

    if writer is not None:
        writer.add_scalar("Loss/Test", test_loss, global_step=cur_step)

In [None]:
from datetime import datetime

writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')

prev_updates = 0
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
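    # Pass the model's device so that batches are moved to where the model lives
    prev_updates = train(model, train_loader, model.optimizer, prev_updates, writer=writer, device=model.device)
    test(model, test_loader, prev_updates, writer=writer, device=model.device)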

In [None]:
from anndata.experimental import AnnLoader

def create_folds(adata):
    """Yield one (train_loader, val_loader) pair per donor (leave-one-donor-out)."""
    donors = adata.obs["donor"].unique()

    for donor in donors:
        # Create training data
        # Remove cells from chosen donor
        train_data = adata[adata.obs["donor"] != donor]
        # Create validation data
        val_data = adata[adata.obs["donor"] == donor]

        # Load data (built inside the loop so that every donor gets its own fold)
        train_loader = AnnLoader(train_data, batch_size=batch_size, shuffle=True, use_cuda=cuda)
        val_loader = AnnLoader(val_data, batch_size=batch_size, shuffle=False, use_cuda=cuda)

        yield train_loader, val_loader