<a href="https://colab.research.google.com/github/Tejnu/CSET-419---Generative-artificial-intelligence/blob/main/GENAI_LAB3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import zipfile

# Seed PyTorch's global RNG so weight initialisation, DataLoader shuffling and
# latent sampling repeat across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [2]:
# Values shared by both dataset splits / both loaders.
DATA_ROOT = "./data"
BATCH_SIZE = 128

# Each image is converted to a tensor when it is fetched from the dataset.
transform = transforms.ToTensor()

# Training and test splits of Fashion-MNIST (downloaded into DATA_ROOT if missing).
train_data = datasets.FashionMNIST(
    DATA_ROOT, train=True, transform=transform, download=True
)
test_data = datasets.FashionMNIST(
    DATA_ROOT, train=False, transform=transform, download=True
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.9MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 204kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.81MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 24.8MB/s]


In [3]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder maps a latent vector
    back to ``input_dim`` values squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: Length of each flattened input vector (default ``28 * 28``).
        hidden_dim: Width of the single hidden layer used by both the encoder
            and the decoder (default 400).
        latent_dim: Size of the latent code (default 20).
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=400, latent_dim=20):
        super().__init__()

        # Keep the sizes around so callers can query them (e.g. for sampling).
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Encoder: shared hidden layer followed by two heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder.
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Inputs with more than two dimensions (e.g. ``(batch, 1, 28, 28)``
        image batches) are flattened to ``(batch, -1)`` first; 1-D and 2-D
        inputs are passed through unchanged.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [4]:
def vae_loss(recon_x, x, mu, logvar, use_kl=True):
    """Compute the VAE training objective for one batch.

    Args:
        recon_x: Reconstructions produced by the decoder, values in [0, 1].
        x: Target inputs, same shape as ``recon_x``.
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.
        use_kl: When false, only the reconstruction term is returned.

    Returns:
        Scalar tensor: binary cross-entropy summed over every element of the
        batch, plus (if ``use_kl``) the KL divergence between N(mu, sigma^2)
        and the standard normal, also summed over the batch.
    """
    bce = nn.functional.binary_cross_entropy(recon_x, x, reduction="sum")

    # Reconstruction-only objective: skip the regulariser entirely.
    if not use_kl:
        return bce

    # Closed-form KL term: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return bce + kl_divergence

In [5]:
def train_vae(model, loader, optimizer, epochs, save_dir, use_kl=True):
    """Train ``model`` and checkpoint the weights of its best epoch.

    After every epoch the average per-sample training loss is printed. Whenever
    it improves on the best value so far, the model's ``state_dict`` is written
    to ``<save_dir>/models/best_model.pth``.

    Args:
        model: Module whose forward pass returns ``(recon, mu, logvar)``.
        loader: Iterable yielding ``(inputs, labels)`` batches; labels are ignored.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``loader``.
        save_dir: Directory under which the ``models`` folder is created.
        use_kl: Forwarded to ``vae_loss``; disables the KL term when false.

    Raises:
        ValueError: If ``loader`` yields no samples during an epoch.
    """
    model_dir = os.path.join(save_dir, "models")
    os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, "best_model.pth")

    # Follow the model's own device rather than a notebook-level global, so the
    # function works regardless of which cells have been run.
    model_device = next(model.parameters()).device
    best_loss = float("inf")

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        samples_seen = 0

        for x, _ in loader:
            # Flatten each sample to a vector, whatever the image size is.
            x = x.view(x.size(0), -1).to(model_device)

            recon, mu, logvar = model(x)
            loss = vae_loss(recon, x, mu, logvar, use_kl)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            samples_seen += x.size(0)

        if samples_seen == 0:
            raise ValueError("loader yielded no samples; cannot compute an average loss")

        # The loss is summed over samples, so divide by the number of samples
        # actually processed (stays correct with drop_last or custom samplers).
        avg_loss = total_loss / samples_seen
        print(f"Epoch {epoch} | Loss: {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), checkpoint_path)

    print("Best loss:", best_loss)

In [6]:
# Run 1: reconstruction loss only (KL term switched off).
no_kl_dir = "vae_results/no_kl"
os.makedirs(no_kl_dir, exist_ok=True)

model_no_kl = VAE().to(device)
opt_no_kl = optim.Adam(params=model_no_kl.parameters(), lr=1e-3)

train_vae(
    model=model_no_kl,
    loader=train_loader,
    optimizer=opt_no_kl,
    epochs=10,
    save_dir=no_kl_dir,
    use_kl=False,
)

Epoch 1 | Loss: 256.9139
Epoch 2 | Loss: 225.8123
Epoch 3 | Loss: 220.7764
Epoch 4 | Loss: 218.2734
Epoch 5 | Loss: 216.7959
Epoch 6 | Loss: 215.7972
Epoch 7 | Loss: 215.0489
Epoch 8 | Loss: 214.4521
Epoch 9 | Loss: 213.9573
Epoch 10 | Loss: 213.5453
Best loss: 213.5452830078125


In [7]:
# Run 2: full objective (reconstruction loss plus KL term).
kl_dir = "vae_results/with_kl"
os.makedirs(kl_dir, exist_ok=True)

model_kl = VAE().to(device)
opt_kl = optim.Adam(params=model_kl.parameters(), lr=1e-3)

train_vae(
    model=model_kl,
    loader=train_loader,
    optimizer=opt_kl,
    epochs=10,
    save_dir=kl_dir,
    use_kl=True,
)

Epoch 1 | Loss: 284.3014
Epoch 2 | Loss: 256.5790
Epoch 3 | Loss: 250.8859
Epoch 4 | Loss: 248.1424
Epoch 5 | Loss: 246.4109
Epoch 6 | Loss: 245.2805
Epoch 7 | Loss: 244.4145
Epoch 8 | Loss: 243.7057
Epoch 9 | Loss: 243.2751
Epoch 10 | Loss: 242.7872
Best loss: 242.78718040364583


In [8]:
def save_recon(model, loader, path, n=10):
    """Save a grid comparing inputs (top row) with reconstructions (bottom row).

    The first batch from ``loader`` is passed through ``model`` and the first
    ``n`` samples are drawn as 28x28 grayscale images into
    ``<path>/reconstruction.png``.

    Args:
        model: Module whose forward pass returns ``(recon, mu, logvar)``.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        path: Output directory; created if it does not exist.
        n: Maximum number of samples to plot (default 10). Clamped to the
            size of the first batch.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    model.eval()
    os.makedirs(path, exist_ok=True)

    # Use the model's own device rather than a notebook-level global.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        x, _ = next(iter(loader))
        x = x.view(-1, 28*28).to(model_device)
        recon, _, _ = model(x)

    # Never index past the end of a batch that holds fewer than n samples.
    n = min(n, x.size(0))

    # squeeze=False keeps axs two-dimensional even when n == 1.
    fig, axs = plt.subplots(2, n, figsize=(n, 2), squeeze=False)

    for i in range(n):
        axs[0, i].imshow(x[i].view(28, 28).cpu(), cmap="gray")
        axs[0, i].axis("off")

        axs[1, i].imshow(recon[i].view(28, 28).cpu(), cmap="gray")
        axs[1, i].axis("off")

    fig.savefig(os.path.join(path, "reconstruction.png"))
    # Close this specific figure so repeated calls do not accumulate memory.
    plt.close(fig)

In [9]:
# Write a reconstruction grid for each trained model into its own results folder.
for trained_model, out_dir in ((model_no_kl, no_kl_dir), (model_kl, kl_dir)):
    save_recon(trained_model, test_loader, out_dir)

In [10]:
def generate_images(model, path, n=10, latent_dim=None):
    """Decode random latent vectors and save them as a row of images.

    ``n`` latent codes are drawn from a standard normal distribution, decoded
    with ``model.decode`` and drawn as 28x28 grayscale images into
    ``<path>/generated.png``.

    Args:
        model: Module exposing ``decode(z)``.
        path: Output directory; created if it does not exist.
        n: Number of images to generate (default 10).
        latent_dim: Size of each latent vector. When ``None``, the model's
            ``latent_dim`` attribute is used if it has one, otherwise 20.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    model.eval()
    os.makedirs(path, exist_ok=True)

    # Use the model's own device rather than a notebook-level global.
    model_device = next(model.parameters()).device

    if latent_dim is None:
        latent_dim = getattr(model, "latent_dim", 20)

    with torch.no_grad():
        z = torch.randn(n, latent_dim).to(model_device)
        samples = model.decode(z)

    # squeeze=False keeps axs two-dimensional even when n == 1.
    fig, axs = plt.subplots(1, n, figsize=(n, 1), squeeze=False)

    for i in range(n):
        axs[0, i].imshow(samples[i].view(28, 28).cpu(), cmap="gray")
        axs[0, i].axis("off")

    fig.savefig(os.path.join(path, "generated.png"))
    # Close this specific figure so repeated calls do not accumulate memory.
    plt.close(fig)

In [11]:
# Sample from the prior with each trained model and save the result in its folder.
for trained_model, out_dir in ((model_no_kl, no_kl_dir), (model_kl, kl_dir)):
    generate_images(trained_model, out_dir)

In [12]:
zip_path = "vae_results.zip"


def zip_folders(folders, archive_path):
    """Write every file found under ``folders`` into a compressed zip archive.

    Args:
        folders: Iterable of directory paths to walk recursively. Directories
            that do not exist contribute no entries.
        archive_path: Destination ``.zip`` file; overwritten if it exists.

    Files are stored under the same relative path they were found at, and are
    visited in sorted order so the archive layout is the same on every run.
    """
    # ZIP_DEFLATED compresses entries; the ZipFile default stores them raw.
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        for folder in folders:
            for root, dirs, files in os.walk(folder):
                dirs.sort()  # in-place sort steers os.walk's traversal order
                for file in sorted(files):
                    zipf.write(os.path.join(root, file))


zip_folders(["vae_results/no_kl", "vae_results/with_kl"], zip_path)

In [13]:
# The browser download helper only exists on Google Colab. Guard the import so
# the notebook still runs to completion in any other Jupyter environment.
try:
    from google.colab import files
except ImportError:
    IN_COLAB = False
    print("Not running in Colab - skipping browser download of vae_results.zip")
else:
    IN_COLAB = True
    files.download("vae_results.zip")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>