In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Beta
%matplotlib inline
import numpy as np
import pandas as pd
import sys
import os
import random
import time
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import utils
import evaluate
import cs236781.plot as plot

In [2]:
def ELBO_loss(recon_x, x, mu, logvar, beta=1, b1 = 255, a1 = 0, distr='normal'):
    recon_error = nn.functional.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    #recon_error = evaluate.reconstruction_loss(recon_x,x)
    #recon_error = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum') / x.size(0)

    if distr == "normal":
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), axis=1) / x.size(0)
    elif distr == "lognormal":
        kl = 0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar) / x.size(0)
    elif distr == "uniform":
        #Assuming mu, logvar = b2,a2 represent the learned bounds
        kl = torch.log((mu - logvar) / (b1 - a1))
    else:
        raise ValueError(f"Distribution {distr} not recognized.")

    return recon_error + beta * kl.sum()

In [3]:
def elbo2_loss(x, x_rec, mu, sigma):
    # Reconstruction loss summed over elements and averaged over batch
    reconstruction_loss = torch.nn.functional.mse_loss(x_rec, x, reduction='sum') / x.size(0)
    
    # KL divergence per sample summed over latent dimensions
    kl_divergence = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2), dim=1)
    kl_divergence = kl_divergence.mean()

    elbo = reconstruction_loss + kl_divergence
    
    return elbo

In [4]:
def plot_tensors_as_images(images):
    fig, axs = plt.subplots(1, len(images), figsize=(15, 3))  # Create 1x5 grid of subplots
    for i in range(5):
        axs[i].imshow(images[i].cpu().detach().numpy(), cmap='Grays')  # Convert tensor to numpy for plotting
        axs[i].axis('off')  # Hide axes
    
    plt.show()

In [5]:
def get_device():
    if torch.cuda.is_available(): 
     dev = "cuda:0" 
    else: 
     dev = "cpu" 
    return torch.device(dev) 

In [6]:
device = get_device()
print(device)
torch.autograd.set_detect_anomaly(True)

cpu


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f0b18d03a60>

In [7]:
class VariationalAutoDecoder(nn.Module):
    def __init__(self, x_dim, z_dim, mu, sigma, distr = "normal", device=torch.device("cpu")):
        super().__init__()
        self.device = device
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.distr = distr
        self.mu = nn.parameter.Parameter(mu,True) 
        self.sigma = nn.parameter.Parameter(sigma,True) 
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, kernel_size=7, stride=1, padding=0),  
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Output in range [0, 1]
        )
        
        self.mlp =  nn.Linear(28*28,28*28,bias=True)
        
    def sample_vectors(self, mu, sigma):
        Z = torch.randn_like(mu, requires_grad=True).to(self.device)
        X = sigma*Z + mu
        return X
        
    def forward(self, mu, sigma):        
        z = self.sample_vectors(mu, sigma)
        z = z.view(-1, self.z_dim, 1, 1)
        print(z.shape)
        print(self.decoder(z).shape)
        reconstructed_images = 255 * self.decoder(z).view(-1,28*28) # Output in range [0, 255]
        reconstructed_images = self.mlp(reconstructed_images)
        return reconstructed_images

In [8]:
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 500
X_DIM = 28 * 28  
Z_DIM = 100
NUM_CLASSES = 10

config = {
    "BATCH_SIZE": BATCH_SIZE,
    "LEARNING_RATE": LEARNING_RATE,
    "NUM_EPOCHS": NUM_EPOCHS,
    "X_DIM": X_DIM,
    "Z_DIM": Z_DIM,
    "NUM_CLASSES": NUM_CLASSES
}

In [9]:
train_ds, train_dl, test_ds, test_dl = utils.create_dataloaders("dataset", device, BATCH_SIZE)

In [10]:
mu = torch.randn(NUM_CLASSES, Z_DIM, requires_grad=True)
sigma = torch.randn(NUM_CLASSES, Z_DIM, requires_grad=True)

In [11]:
vad = VariationalAutoDecoder(x_dim=X_DIM, 
                             z_dim=Z_DIM, 
                             mu=mu, 
                             sigma=sigma, 
                             device=device).to(device)
print(vad)

VariationalAutoDecoder(
  (decoder): Sequential(
    (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): Sigmoid()
  )
  (mlp): Linear(in_features=784, out_features=784, bias=True)
)


In [12]:
optim = torch.optim.Adam(params=vad.parameters(), lr=LEARNING_RATE)

In [13]:
train_losses = []
print("TRAINING with ", config)
for epoch in range(NUM_EPOCHS):
    batch_losses = []
    
    for batch_i, batch in enumerate(train_dl):
        idx, x = batch
        labels = train_ds.y[idx]
        mu_batch = vad.mu[labels]
        sigma_batch = vad.sigma[labels] 
        
        x_reconstruction = vad.forward(mu_batch, sigma_batch)
        x_reconstruction = x_reconstruction.view(x.size(0), 28, 28)
        
        loss = elbo2_loss(x_reconstruction, x.float(), mu_batch, sigma_batch)
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        batch_losses.append(loss.data.cpu().item())
    
    train_losses.append(np.mean(batch_losses))

    if(epoch % 50 == 0):
        print("epoch: {} training loss: {:.5f}".format(epoch, train_losses[-1]))

TRAINING with  {'BATCH_SIZE': 64, 'LEARNING_RATE': 0.001, 'NUM_EPOCHS': 500, 'X_DIM': 784, 'Z_DIM': 100, 'NUM_CLASSES': 10}
torch.Size([64, 100, 1, 1])
torch.Size([64, 1, 8, 8])


RuntimeError: shape '[-1, 784]' is invalid for input of size 4096

In [None]:
random_indices = random.sample(range(1000), 5)
distr_params = vad.distr_params[random_indices]
initial = train_ds[random_indices][1]

x = vad(distr_params)
restored = x.view(-1, 28, 28)

plot_tensors_as_images(restored)
plot_tensors_as_images(initial)

In [None]:
latent_vectors = vad.sample_vectors(vad.distr_params)
utils.plot_tsne(train_ds, latent_vectors, "VAD_TSNE_train")

In [None]:
"""
test_latents = torch.randn(len(test_ds), Z_DIM, requires_grad=True, device = device)
vad.latent_vectors = nn.parameter.Parameter(test_latents,True)
opt = torch.optim.Adam([vad.latent_vectors], lr=LEARNING_RATE)

evaluate.evaluate_model(vad, test_dl, opt, vad.latent_vectors, 50 , device)
utils.plot_tsne(train_ds, vad.latent_vectors, "VAD_TSNE_test")
"""