In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

import torchvision
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from tqdm.autonotebook import tqdm 

  from tqdm.autonotebook import tqdm


In [7]:
from dataclasses import dataclass

@dataclass
class Params:
    """Hyperparameters and runtime settings shared by the notebook.

    Every attribute is an annotated dataclass field, so individual values can
    be overridden per instance (for example ``Params(NUM_EPOCHS=1)``), while
    ``Params()`` with no arguments yields the defaults listed below.
    """

    # Runtime: CUDA when available, otherwise CPU. The default is evaluated
    # once, when the class body runs.
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model
    INPUT_DIM: int = 28 * 28  # encoder input width / decoder output width
    Z_DIM: int = 20           # width of the latent vector
    H_DIM: int = 200          # width of the hidden layer

    # Training
    NUM_EPOCHS: int = 10
    BATCH_SIZE: int = 32
    LR_RATE: float = 3e-4


# Default configuration used by the cells that follow.
params = Params()

In [3]:
class Encoder(nn.Module):
    """Fully connected encoder: one shared hidden layer feeding two output heads."""

    def __init__(self, params):
        """Create the layers from ``params.INPUT_DIM``, ``params.H_DIM`` and ``params.Z_DIM``."""
        super().__init__()

        in_features = params.INPUT_DIM
        hidden_features = params.H_DIM
        latent_features = params.Z_DIM

        # Layers are created trunk first, then the mean head, then the std head.
        self.img2hidden = nn.Linear(in_features, hidden_features)
        self.hidden2mean = nn.Linear(hidden_features, latent_features)
        self.hidden2std = nn.Linear(hidden_features, latent_features)

    def forward(self, x):
        """Return ``(mu, std)``: both heads applied to the ReLU-activated hidden layer."""
        activated = F.relu(self.img2hidden(x))
        # Both heads are plain linear layers reading the same activations.
        return self.hidden2mean(activated), self.hidden2std(activated)

In [4]:
class Decoder(nn.Module):
    """Fully connected decoder: latent vector -> hidden layer -> sigmoid output."""

    def __init__(self, params):
        """Create the layers from ``params.Z_DIM``, ``params.H_DIM`` and ``params.INPUT_DIM``."""
        super().__init__()
        latent_features, hidden_features = params.Z_DIM, params.H_DIM
        self.z2hidden = nn.Linear(latent_features, hidden_features)
        self.hidden2img = nn.Linear(hidden_features, params.INPUT_DIM)

    def forward(self, z):
        """Map ``z`` through the hidden layer and squash the result with a sigmoid."""
        activated = F.relu(self.z2hidden(z))
        logits = self.hidden2img(activated)
        # The sigmoid keeps every output element strictly between 0 and 1.
        return torch.sigmoid(logits)
    

In [5]:
class VAE(nn.Module):
    """Variational autoencoder that wires an encoder and a decoder together.

    The encoder is expected to return a ``(mu, std)`` pair and the decoder to
    map a latent sample back to an output tensor.

    Args:
        params: Configuration object. Accepted for interface compatibility;
            it is not read by this class.
        encoder: Module called as ``encoder(x) -> (mu, std)``.
        decoder: Module called as ``decoder(z) -> img``.
    """

    def __init__(self, params, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def encode(self, x):
        """Run only the encoder and return its ``(mu, std)`` pair."""
        return self.encoder(x)

    def decode(self, z):
        """Run only the decoder on the latent tensor ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, draw a latent sample, decode it.

        Returns:
            Tuple ``(img, mu, std)``: the decoder output plus the two encoder
            outputs.
        """
        mu, std = self.encode(x)

        # Reparameterization: the noise is drawn independently of the network,
        # so the sample stays differentiable with respect to mu and std.
        epsilon = torch.randn_like(std)
        z_reparametrized = mu + std * epsilon

        # Uses the decode() wrapper defined above; the attribute holding the
        # module itself is `decoder`.
        img = self.decode(z_reparametrized)

        return img, mu, std

In [6]:
def make_model(params):
    """Build a VAE whose encoder and decoder are both configured from ``params``."""
    # Call arguments are evaluated left to right, so the encoder is
    # constructed before the decoder.
    return VAE(params, Encoder(params), Decoder(params))

# Build Dataset

In [None]:

# Training split of MNIST stored under dataset/ (download enabled); every
# sample is passed through ToTensor.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of params.BATCH_SIZE samples.
train_loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=params.BATCH_SIZE,
)

# Train Model

In [None]:
# Binary cross-entropy, summed (not averaged) over all elements.
loss_fn = nn.BCELoss(reduction="sum")

# Fresh VAE and an Adam optimizer over all of its parameters.
model = make_model(params)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params.LR_RATE,
)

# Inference 

In [None]:
def inference(digit,dataset, num_examples=1, ):
    images= []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break  
    
    
    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images)