In [11]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split, Dataset
import torchvision
from torchvision.transforms import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import os
from typing import Tuple


In [13]:
# Seed the RNGs this notebook draws from so results are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Device preference: Apple-silicon GPU (MPS), then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device: torch.device = torch.device(_device_name)
print(f'using device: {device}')

using device: mps


In [None]:
# Hyperparameters live together at the top so they are easy to find and tune.

learning_rate: float = 1e-3     # optimizer learning rate
batch_size: int = 128           # samples per mini-batch
num_epochs: int = 15            # number of training epochs
validation_split: float = 0.2   # fraction of the CIFAR-10 training set held out for validation
z_dim: int = 128                # latent dimensionality; CIFAR-10 gets a larger latent space

In [18]:
# Output directory for saved CIFAR models; exist_ok makes re-running this cell a no-op.
model_dir: str = 'CIFAR_model'
os.makedirs(model_dir, exist_ok=True)

In [20]:
# Data transformations: ToTensor converts each PIL image to a float tensor of
# shape (C, H, W) with pixel values scaled to [0, 1]. No normalization is applied,
# so pixel values stay in [0, 1].

transform: transforms.Compose = transforms.Compose([
    transforms.ToTensor()
])

# CIFAR-10 dataset: the official training split (train + validation) and the test split.

train_val_set: torchvision.datasets.CIFAR10 = torchvision.datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform
)

test_set: torchvision.datasets.CIFAR10 = torchvision.datasets.CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=transform
)

# Split the official training set into train and validation subsets.
# random_split(dataset, lengths) returns one Subset per requested length.
# A dedicated, seeded generator makes the split independent of how much of the
# global RNG earlier cells consumed, so re-running this cell does not reshuffle
# which images are held out for validation.
split_seed: int = 42
num_train: int = int((1 - validation_split) * len(train_val_set))
num_val: int = len(train_val_set) - num_train
train_set: Dataset
val_set: Dataset
train_set, val_set = random_split(
    train_val_set,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(split_seed),
)

print(f'train set size: {len(train_set)}, validation set size: {len(val_set)}, test set size: {len(test_set)}')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:06<00:00, 26822591.81it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
train set size: 40000, validation set size: 10000, test set size: 10000


In [26]:
# DataLoaders: the iterables that the training and evaluation loops consume.
#
# Pinned host memory only speeds up host-to-CUDA copies. On MPS/CPU, PyTorch
# ignores the flag (and may warn), so enable it only when CUDA is present.
_pin_memory: bool = torch.cuda.is_available()


def _make_loader(dataset: Dataset, shuffle: bool) -> DataLoader:
    """Build a DataLoader with the notebook-wide batch size and worker settings."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=_pin_memory,
    )


# Training batches are reshuffled every epoch; evaluation order does not matter.
train_loader: DataLoader = _make_loader(train_set, shuffle=True)
val_loader: DataLoader = _make_loader(val_set, shuffle=False)
test_loader: DataLoader = _make_loader(test_set, shuffle=False)

# Backward-compatible aliases for the originally misspelled names.
train_laoder: DataLoader = train_loader
val_loaer: DataLoader = val_loader

def one_hot(labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Encode integer class labels as float one-hot vectors.

    Args:
        labels: integer (int64) tensor of class indices in ``[0, num_classes)``.
        num_classes: width of the encoding; the default of 10 matches CIFAR-10.

    Returns:
        float32 tensor of shape ``labels.shape + (num_classes,)``.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.to(dtype=torch.float32)

In [None]:
# VAE Implementation


class ConvVAE(nn.Module):
    
    def __init__(self, z_dim: int = 128) -> None:
        super(ConvVAE, self).__init__()
        self.z_dim=z_dim
        
        # Encoder
        self.encoder:nn.Sequential = nn.Sequential(
            nn.Conv2d(in_channels=3,
                      out_channels=32,
                      kernel_size=4,
                      stride=2,
                      padding=1), # 32 x 16 x 16 (because of stride 2)
            nn.ReLU(),
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1
            ), # 64 x 8 x 8
            nn.ReLU(),
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=4,
                stride=2,
                padding=1
            ), # 128 x 4 x 4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128*4*4, 512), 
            nn.ReLU()
        )
        self.fc_mu: nn.Linear = nn.Linear(512, z_dim)
        self.fc_logvar: nn.Linear = nn.Linear(512, z_dim)
        
        # Decoder
        
        self.decoder_fc: nn.Linear = nn.Linear(z_dim, 512) # decoder takes in the the input of the encoder and outputs the image means
        self.decoder: nn.Sequential = nn.Sequential(
            nn.ReLU(),
            nn.Linear(512, 128*4*4),
            nn.ReLU(),
            nn.Unflatten(dim=1, unflattened_size=(128, 4, 4)),
            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1
            ), # 64 x 8 x 8
            nn.ReLU(),
            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                kernel_size=4,
                stride=2,
                padding=1
            ), # 32 x 16 x 16
            nn.ReLU(),
            nn.ConvTranspose2d(
                in_channels=32,
                out_channels=3,
                kernel_size=4,
                stride=2,
                padding=1
            ), # 3 x 32 x 32
            nn.Tanh()
        )
        
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h: torch.Tensor = self.encoder(x) # h is x x 512
        mu: torch.Tensor = self.fc_mu(h) # mu is x x z_dim
        logvar: torch.Tensor = self.fc_logvar(h) # logvar is x x z_dim
        return mu, logvar # we return the mean and the logvariance - not their samples
    
    def reparameterize(self, mu:torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std: torch.Tensor = torch.exp(0.5*logvar)
        eps: torch.Tensor = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        h: torch.Tensor = self.decoder_fc(z)
        x_recon: torch.Tensor = self.decoder(h)
        return x_recon