In [None]:
# IPython shell escape: list the contents of the current working directory.
!ls

In [None]:
import os

from google.colab import drive

# Mount Google Drive at /content/drive; every output folder below lives on it.
drive.mount('/content/drive')

# One project folder on Drive, with a sub-folder per kind of artefact.
base_dir = "/content/drive/MyDrive/VAE_Project_lab3"
plots_dir = base_dir + "/plots"
samples_dir = base_dir + "/samples"
models_dir = base_dir + "/models"

# Create the folders; exist_ok=True makes re-running this cell harmless.
for output_dir in (plots_dir, samples_dir, models_dir):
    os.makedirs(output_dir, exist_ok=True)

print("Folders created successfully!")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Training configuration: everything a reader might want to tune lives here.
seed = 42             # seed for PyTorch's global random number generators
batch_size = 128      # images per optimisation step
learning_rate = 1e-3  # step size handed to the optimizer
epochs = 20           # full passes over the training set
latent_dim = 2   # keep 2 for visualization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weight initialisation, DataLoader shuffling and every torch.randn* call draw
# from PyTorch's global generators. Seeding them makes a Restart & Run All
# repeatable (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)


In [None]:
# ToTensor is the only preprocessing step applied to each image.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, stored under ./data (download=True fetches it if needed).
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Reshuffled mini-batches of `batch_size` images for every pass over the data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

print("Dataset loaded!")


In [None]:
class Encoder(nn.Module):
    """Single-hidden-layer MLP producing the latent mean and log-variance.

    Args:
        input_dim: Length of each flattened input vector (default 784).
        hidden_dim: Width of the hidden layer (default 400).
        z_dim: Size of the latent vector. When ``None`` the notebook-level
            ``latent_dim`` is read at construction time, so the original
            zero-argument ``Encoder()`` call keeps working unchanged.
    """

    def __init__(self, input_dim=784, hidden_dim=400, z_dim=None):
        super().__init__()
        if z_dim is None:
            # Backward-compatible fallback to the notebook's global setting.
            z_dim = latent_dim
        # Layers are created in the same order as before so that a seeded
        # run initialises them identically.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, input_dim).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, z_dim).
        """
        x = torch.relu(self.fc1(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


In [None]:
class Decoder(nn.Module):
    """Single-hidden-layer MLP mapping a latent vector to values in [0, 1].

    Args:
        z_dim: Size of the latent vector. When ``None`` the notebook-level
            ``latent_dim`` is read at construction time, so the original
            zero-argument ``Decoder()`` call keeps working unchanged.
        hidden_dim: Width of the hidden layer (default 400).
        output_dim: Length of each flattened output vector (default 784).
    """

    def __init__(self, z_dim=None, hidden_dim=400, output_dim=784):
        super().__init__()
        if z_dim is None:
            # Backward-compatible fallback to the notebook's global setting.
            z_dim = latent_dim
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """Decode ``z`` of shape (batch, z_dim).

        Returns:
            Tensor of shape (batch, output_dim); the final sigmoid keeps every
            entry between 0 and 1.
        """
        z = torch.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(z))
        return x_recon


In [None]:
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        The noise is sampled separately from ``mu`` and ``logvar``, so
        gradients can flow through both.
        """
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar


In [None]:
def loss_function(x_recon, x, mu, logvar, beta=1.0):
    """Negative ELBO for a Bernoulli decoder, summed over the whole batch.

    Args:
        x_recon: Decoder output with values in [0, 1], same shape as ``x``.
        x: Target tensor with values in [0, 1].
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.
        beta: Weight on the KL term. The default of 1.0 gives the standard
            VAE objective; other values re-balance reconstruction quality
            against closeness to the prior.

    Returns:
        Scalar tensor: summed binary cross-entropy plus ``beta`` times the KL
        divergence between N(mu, exp(logvar)) and N(0, I).
    """
    recon_loss = nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss


In [None]:
# Build the VAE and move its parameters to the selected device.
model = VAE()
model = model.to(device)

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# History of the average loss per epoch, appended to by the training loop.
losses = list()


In [None]:
# Later cells call model.eval(); switching back to training mode here keeps
# this cell correct when it is re-run after them.
model.train()

for epoch in range(epochs):
    total_loss = 0.0  # summed loss over every batch of this epoch
    for data, _ in train_loader:
        # Flatten each image to a 784-vector before feeding the model.
        data = data.view(-1, 784).to(device)

        optimizer.zero_grad()
        x_recon, mu, logvar = model(data)
        loss = loss_function(x_recon, data, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # loss_function sums over the batch, so dividing by the dataset size
    # gives the average loss per training image.
    avg_loss = total_loss / len(train_loader.dataset)
    losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


In [None]:
# Line plot of the per-epoch average loss recorded during training.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")

# Save before plt.show(), while the figure is still open.
plot_path = plots_dir + "/training_loss.png"
fig.savefig(plot_path)
plt.show()

print("Saved:", plot_path)


In [None]:
model.eval()

# Encode the training set batch by batch, keeping the latent mean of every
# image together with its digit label.
mu_batches, label_batches = [], []
with torch.no_grad():
    for data, target in train_loader:
        data = data.view(-1, 784).to(device)
        mu, logvar = model.encoder(data)
        mu_batches.append(mu.cpu())
        label_batches.append(target)

latent_vectors = torch.cat(mu_batches)
labels = torch.cat(label_batches)

# Scatter the two latent coordinates, coloured by digit label.
fig, ax = plt.subplots(figsize=(6, 6))
points = ax.scatter(
    latent_vectors[:, 0],
    latent_vectors[:, 1],
    c=labels,
    cmap='tab10',
    s=5,
)
fig.colorbar(points, ax=ax)
ax.set_title("Latent Space")

latent_plot_path = plots_dir + "/latent_space.png"
fig.savefig(latent_plot_path)
plt.show()

print("Saved:", latent_plot_path)


In [None]:
# Decode 16 latent vectors drawn from a standard normal into 28x28 images.
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    samples = model.decoder(z).cpu().view(16, 1, 28, 28)

# Show the 16 decoded images on a 4x4 grid without axis decorations.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for ax, digit in zip(axes.flat, samples):
    ax.imshow(digit[0], cmap="gray")
    ax.axis("off")

sample_path = samples_dir + "/generated_samples.png"
fig.savefig(sample_path)
plt.show()

print("Saved:", sample_path)


In [None]:
# Persist only the learned parameters (the state dict), not the module object.
model_path = models_dir + "/vae_model.pth"
state_dict = model.state_dict()
torch.save(state_dict, model_path)
print("Model saved at:", model_path)


In [None]:
model.eval()

# Take one batch from the training loader and flatten it for the model.
data_iter = iter(train_loader)
images, _ = next(data_iter)
images = images.view(-1, 784).to(device)

with torch.no_grad():
    recon_images, mu, logvar = model(images)

# Back on the CPU as (batch, 1, 28, 28) image tensors for plotting.
images = images.cpu().view(-1, 1, 28, 28)
recon_images = recon_images.cpu().view(-1, 1, 28, 28)

n = 8  # number of samples to display

fig, axes = plt.subplots(2, n, figsize=(15, 4))

# Top row: inputs. Bottom row: the model's reconstruction of each input.
rows = (("Original", images), ("Reconstructed", recon_images))
for row, (title, batch) in enumerate(rows):
    for i in range(n):
        panel = axes[row, i]
        panel.imshow(batch[i][0], cmap="gray")
        panel.set_title(title)
        panel.axis("off")

save_path = samples_dir + "/original_vs_reconstruction.png"
fig.savefig(save_path)
plt.show()

print("Saved:", save_path)


In [None]:
import imageio

model.eval()

frames = []

# 10 x 10 grid of latent coordinates spanning [-3, 3] on both axes.
grid_x = np.linspace(-3, 3, 10)
grid_y = np.linspace(-3, 3, 10)
latent_points = [(xi, yi) for xi in grid_x for yi in grid_y]

# Every frame is written to this one path and read straight back in.
frame_path = "/content/frame.png"

for xi, yi in latent_points:
    with torch.no_grad():
        z = torch.tensor([[xi, yi]]).float().to(device)
        sample = model.decoder(z).cpu().view(28, 28)

    fig, ax = plt.subplots()
    ax.imshow(sample, cmap='gray')
    ax.axis('off')

    # Save frame temporarily, then close the figure so 100 of them do not
    # pile up in memory.
    fig.savefig(frame_path)
    plt.close(fig)

    frames.append(imageio.imread(frame_path))

gif_path = samples_dir + "/latent_space_animation.gif"
imageio.mimsave(gif_path, frames, fps=5)

print("GIF saved at:", gif_path)
