In [None]:
import torch
import torch.nn as nn
import wandb

In [None]:
from torchvision import datasets, transforms

# Preprocessing pipeline: images are only converted to tensors. The Normalize
# step is kept below for reference but is commented out, so it is not applied.
preprocessing_steps = [
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,))
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split stored under ./data (download=True requests a download).
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled loader over the whole training split, 64 images per batch.
mnist_loader = torch.utils.data.DataLoader(
    mnist_dataset,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Alternative (interactive) authentication, kept for reference:
# !wandb login
from kaggle_secrets import UserSecretsClient
# Read the secret stored under the name "API_KEY" so the key never has to be
# written into the notebook itself. It is passed to wandb.login() in the next cell.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("API_KEY")

# NOTE(review): wandb is already imported in the first cell; this repeat is redundant.
import wandb

In [None]:
# Authenticate with Weights & Biases using the key fetched from the secret store above.
wandb.login(key=secret_value_0)

In [None]:
# Initialize wandb
# The config records the hyperparameters of this run so that logged curves can
# be traced back to the settings that produced them.
wandb.init(
    project="vae-mnist",
    config={
        "learning_rate": 0.0005,
        "epochs": 200,
        "batch_size": 64,
        "input_dim": 1,
        "hidden_dim": 64,
        "latent_dim": 2,
        "dataset": "MNIST",
        "architecture": "Variational Autoencoder",
        # Loss weights as actually applied by the training objective
        # (total_loss = kl_loss + 1.0 * recon_loss); the previously logged
        # 0.1 / 0.5 did not match what the training loop optimizes.
        "reconstruction_weight": 1.0,
        "kl_weight": 1.0
    }
)

In [None]:
class Encoder(nn.Module):
    """Convolutional feature extractor that returns a flattened feature vector.

    Four 3x3 convolutions (padding 1), each followed by a LeakyReLU, are
    applied in sequence; the second and third use stride 2, so the spatial
    size is reduced twice before the result is flattened.

    Args:
        input_dim: Number of channels of the input images.
        hidden_dim: Number of channels produced by every convolution.
        output_dim: Accepted for interface compatibility; not used by this
            module (no projection layer is applied after flattening).
        leaky: Negative slope of the LeakyReLU activations.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, leaky = 0.01):
        super().__init__()

        # (in_channels, stride) of each convolution, in application order.
        stage_specs = (
            (input_dim, 1),
            (hidden_dim, 2),
            (hidden_dim, 2),
            (hidden_dim, 1),
        )

        stages = []
        for in_channels, stride in stage_specs:
            stages.append(
                nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=stride, padding=1)
            )
            stages.append(nn.LeakyReLU(leaky))
        stages.append(nn.Flatten())

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the convolution stack and return one flat feature row per sample."""
        return self.conv(x)
    
# class Decoder(nn.Module):

In [None]:
from torchinfo import summary

# Stand-alone encoder on the GPU, built with the same sizes used for MNIST below.
enc = Encoder(input_dim=1, hidden_dim=64, output_dim=2)
enc = enc.to('cuda')

# Example input tensor (one single-channel 28x28 image) for manual experiments.
x = torch.randn(1, 1, 28, 28)
x = x.to('cuda')
# output = enc(x)
# print("Output shape:", output.shape)  # Should print the shape of the output tensor

# Layer-by-layer summary of the encoder for a (1, 1, 28, 28) input.
summary(enc, input_size=(1, 1, 28, 28), device='cuda')

In [None]:
# x.shape

In [None]:


class Decoder(nn.Module):
    """Map latent vectors to single-channel 28x28 images with values in (0, 1).

    A linear layer expands the latent vector to an ``input_dim x 7 x 7``
    feature map, which four transposed convolutions upsample to 28x28
    (7 -> 7 -> 7 -> 14 -> 28). A sigmoid is applied to the result.

    Args:
        input_dim: Number of channels of the 7x7 feature map fed to the first
            transposed convolution.
        hidden_dim: Number of channels of the intermediate feature maps.
        output_dim: Size of the latent vector this decoder consumes (despite
            the name, this is the decoder's *input* width).
        leaky: Negative slope of the LeakyReLU activations.
    """

    # Spatial side length of the feature map produced by the linear layer.
    _SEED_SIZE = 7

    def __init__(self, input_dim, hidden_dim, output_dim, leaky = 0.01):
        super().__init__()
        # Remember the channel count so forward() can reshape without relying
        # on a hard-coded 64 (the previous code only worked for input_dim=64).
        self._seed_channels = input_dim
        self.linear = nn.Linear(output_dim, input_dim * self._SEED_SIZE * self._SEED_SIZE)
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(input_dim, hidden_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(leaky),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(leaky),
            # stride 2 with output_padding 1 doubles the spatial size exactly.
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(leaky),
            nn.ConvTranspose2d(hidden_dim, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # NOTE(review): this LeakyReLU sits directly before the sigmoid in
            # forward(); it is kept to preserve the existing model behavior.
            nn.LeakyReLU(leaky),
        )

    def forward(self, x):
        """Decode latent vectors ``x`` of shape (batch, output_dim) into images."""
        x = self.linear(x)
        x = x.view(-1, self._seed_channels, self._SEED_SIZE, self._SEED_SIZE)
        x = self.conv(x)
        return torch.sigmoid(x)

In [None]:
# Build a throwaway decoder on the GPU and summarize it for a single
# 2-dimensional latent vector (input size (1, 2)).
summary(Decoder(input_dim=64, hidden_dim=64, output_dim=2).to('cuda'), (1, 2), device='cuda')  # Print the model summary

In [None]:
class Autoencoder(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    The encoder's flattened features are projected by two bias-free linear
    heads to the mean and log-variance of a diagonal Gaussian; a latent sample
    drawn with the reparameterization trick is then decoded into an image.

    Args:
        input_dim: Number of channels of the input images.
        hidden_dim: Channel width used inside the encoder and decoder.
        output_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, output_dim)
        self.decoder = Decoder(hidden_dim , hidden_dim , output_dim)
        # The encoder halves 28x28 inputs twice, giving hidden_dim x 7 x 7
        # features (3136 for hidden_dim=64). Deriving the sizes from the
        # arguments keeps the heads consistent with the latent width the
        # decoder expects instead of silently assuming 3136 and 2.
        flat_features = hidden_dim * 7 * 7
        self.z_mean = nn.Linear(flat_features, output_dim, bias=False)
        self.z_log_var = nn.Linear(flat_features, output_dim, bias=False)

    def reparametrize(self, encoded, mean_sampled, log_var_sampled):
        """Draw ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

        ``encoded`` is accepted for interface compatibility and is not used.
        The noise is created with ``randn_like`` so it always matches the
        device and dtype of the inputs rather than a hard-coded ``'cuda'``.
        """
        epsilon = torch.randn_like(log_var_sampled)
        return mean_sampled + torch.exp(log_var_sampled / 2.0) * epsilon

    def forward(self, x):
        """Return ``(reconstruction, mean, log_var, z)`` for the batch ``x``."""
        encoded = self.encoder(x)
        mean = self.z_mean(encoded)
        log_var = self.z_log_var(encoded)
        z = self.reparametrize(encoded, mean, log_var)
        decoded = self.decoder(z)
        return decoded, mean, log_var, z

In [None]:
# Instantiate the VAE used for training (1 input channel, 64 hidden channels,
# 2-dimensional latent space) and move it to the GPU.
autoencoder = Autoencoder(input_dim=1, hidden_dim=64, output_dim=2).to('cuda')
# Summarize the full model for a single 1x28x28 input.
summary(autoencoder, (1, 1, 28, 28), device='cuda')  # Print the model summary

In [None]:
from torch.utils.data import random_split, DataLoader

# Define the split sizes: 80% of the samples for training, the rest for validation.
train_size = int(0.8 * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size

# Split the dataset with a fixed seed so that the same samples end up in the
# validation set on every run; otherwise validation numbers from different
# runs are computed on different data and cannot be compared.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    mnist_dataset,
    [train_size, val_size],
    generator=split_generator,
)

# Create DataLoaders: training batches are reshuffled every epoch, validation
# batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
def vae_loss_terms(output, data, mu, log_var):
    """Return ``(total, reconstruction, kl)`` losses averaged over the batch.

    Args:
        output: Reconstructed batch produced by the model.
        data: Input batch the reconstruction is compared against.
        mu: Latent means, one row per sample.
        log_var: Latent log-variances, one row per sample.

    Returns:
        A tuple of scalar tensors ``(total_loss, recon_loss, kl_loss)`` where
        ``total_loss = kl_loss + 1.0 * recon_loss``.
    """
    # Squared error summed over all pixels of a sample, then averaged over the batch.
    per_element_error = nn.functional.mse_loss(output, data, reduction='none')
    recon_loss = per_element_error.view(output.size(0), -1).sum(dim=1).mean()
    # KL divergence between N(mu, exp(log_var)) and N(0, I), summed over the
    # latent dimensions and averaged over the batch.
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    kl_loss = kl_loss.mean()
    total_loss = kl_loss + 1.0 * recon_loss
    return total_loss, recon_loss, kl_loss


optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.0005)
# Send batches to whatever device the model already lives on.
device = next(autoencoder.parameters()).device
# Training the autoencoder
epochs = 50 * 4  # Number of epochs for training

for epoch in range(epochs):
    # Training phase
    autoencoder.train()
    train_loss = 0.0
    train_recon_loss = 0.0
    train_kl_loss = 0.0
    num_batches = 0

    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, mu, log_var, z = autoencoder(data)

        total_loss, recon_loss, kl_loss = vae_loss_terms(output, data, mu, log_var)

        total_loss.backward()
        optimizer.step()

        train_loss += total_loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        num_batches += 1

    # Calculate average training losses
    avg_train_loss = train_loss / num_batches
    avg_train_recon = train_recon_loss / num_batches
    avg_train_kl = train_kl_loss / num_batches

    # Validation phase
    autoencoder.eval()
    val_loss = 0.0
    val_recon_loss = 0.0
    val_kl_loss = 0.0
    val_batches = 0

    with torch.no_grad():
        for data, _ in val_loader:
            data = data.to(device)
            output, mu, log_var, z = autoencoder(data)

            # Same objective as in training. The previous validation total was
            # 0.1 * recon + recon, which left out the KL term and counted the
            # reconstruction term twice, so train and val losses were not comparable.
            total_loss, recon_loss, kl_loss = vae_loss_terms(output, data, mu, log_var)

            val_loss += total_loss.item()
            val_recon_loss += recon_loss.item()
            val_kl_loss += kl_loss.item()
            val_batches += 1

    avg_val_loss = val_loss / val_batches
    avg_val_recon = val_recon_loss / val_batches
    avg_val_kl = val_kl_loss / val_batches

    # Log to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "train_reconstruction_loss": avg_train_recon,
        "train_kl_loss": avg_train_kl,
        "val_loss": avg_val_loss,
        "val_reconstruction_loss": avg_val_recon,
        "val_kl_loss": avg_val_kl
    })

    print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:} (Recon: {avg_train_recon:}, KL: {avg_train_kl:}), Val Loss: {avg_val_loss:} (Recon: {avg_val_recon:}, KL: {avg_val_kl:})')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Reconstruct the first validation batch on the device the model lives on.
device = next(autoencoder.parameters()).device
with torch.no_grad():
    data, _ = next(iter(val_loader))
    data = data.to(device)
    res, mu, log_var, z = autoencoder(data)  # Unpack VAE outputs

# Move tensors to CPU and convert to numpy for visualization
original_images = data.cpu().numpy()
reconstructed_images = res.cpu().numpy()

# Show at most 8 pairs, but never more than the batch actually contains.
num_shown = min(8, original_images.shape[0])

# Plot original vs reconstructed images; squeeze=False keeps `axes` 2-D even
# when only a single column is drawn.
fig, axes = plt.subplots(2, num_shown, figsize=(15, 4), squeeze=False)
fig.suptitle('VAE: Original (top) vs Reconstructed (bottom)')

for i in range(num_shown):
    # Original images
    axes[0, i].imshow(original_images[i].squeeze(), cmap='gray')
    axes[0, i].set_title(f'Original {i+1}')
    axes[0, i].axis('off')

    # Reconstructed images
    axes[1, i].imshow(reconstructed_images[i].squeeze(), cmap='gray')
    axes[1, i].set_title(f'Reconstructed {i+1}')
    axes[1, i].axis('off')

plt.tight_layout()

# Log sample images to wandb. This is done before plt.show() and with the
# figure object itself: once the figure has been shown it may already be
# closed, in which case wandb.Image(plt) would capture a new, empty figure.
wandb.log({
    "sample_reconstructions": wandb.Image(fig)
})

plt.show()