In [2]:
from __future__ import print_function
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim

In [None]:
# -------------------------
# 1. Reproducibility
# -------------------------
# Seed every RNG this notebook imports so that Restart & Run All starts from
# the same random state each time.
manualSeed = 999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)        # Python stdlib RNG
np.random.seed(manualSeed)     # NumPy global RNG (was previously left unseeded)
torch.manual_seed(manualSeed)  # PyTorch RNG

In [None]:
# -------------------------
# 2. Hyperparameters
# -------------------------
# Training schedule
num_epochs = 20       # full passes over the training set
batch_size = 32       # images per mini-batch (train and validation loaders)
log_interval = 10     # print a progress line every `log_interval` batches

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# -------------------------
# 3. Data (CIFAR10, 64x64)
#    Pixels are left in [0,1] (ToTensor only, no further normalisation);
#    this makes BCE / MSE / SSIM cleaner.
# -------------------------
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),  # -> [0,1]
    ]
)

# Both splits share the same root directory, download flag and preprocessing.
_cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **_cifar_kwargs)
val_dataset = datasets.CIFAR10(train=False, **_cifar_kwargs)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")

# Output folder for the reconstruction / sample grids written during training.
os.makedirs("results", exist_ok=True)


Random Seed:  999
Files already downloaded and verified
Files already downloaded and verified
Train size: 50000
Validation size: 10000


In [None]:
# -------------------------
# 4. VAE model
# -------------------------
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into a single one.

    Input of shape (B, d1, d2, ...) becomes (B, d1 * d2 * ...).
    """

    def forward(self, input):
        # reshape returns a view when it can and copies otherwise, so this also
        # works for non-contiguous tensors (e.g. after permute), where view
        # would raise a RuntimeError.
        return input.reshape(input.size(0), -1)


class UnFlatten(nn.Module):
    """Turn a flat feature vector per sample into a (size, 1, 1) feature map.

    Args (forward):
        input: tensor whose first dimension is the batch dimension.
        size: number of channels of the output map. When omitted it is
            inferred from the input (all non-batch elements), so the layer is
            no longer tied to a 1024-wide hidden vector when it is used inside
            nn.Sequential, which never passes ``size``.

    Returns:
        A view of ``input`` with shape (B, size, 1, 1).
    """

    def forward(self, input, size=None):
        if size is None:
            # Number of elements per sample; stays well-defined for an empty batch.
            size = input.shape[1:].numel()
        return input.view(input.size(0), size, 1, 1)


class VAE(nn.Module):
    """Convolutional variational auto-encoder.

    The encoder is four stride-2 convolutions followed by a flatten; two
    linear heads map the flattened features (``h_dim`` wide) to the latent
    mean and log-variance (``z_dim`` wide). A third linear layer lifts a latent
    sample back to ``h_dim`` before four transposed convolutions and a sigmoid
    produce the reconstruction.
    """

    def __init__(self, image_channels=3, h_dim=1024, z_dim=32):
        super().__init__()

        # Encoder: Conv(k=4, s=2) + ReLU for each consecutive channel pair.
        enc_channels = [image_channels, 32, 64, 128, 256]
        enc_layers = []
        for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2))
            enc_layers.append(nn.ReLU())
        enc_layers.append(Flatten())
        self.encoder = nn.Sequential(*enc_layers)

        # Latent heads and the projection back to the hidden width.
        self.fc1 = nn.Linear(h_dim, z_dim)  # mu
        self.fc2 = nn.Linear(h_dim, z_dim)  # logvar
        self.fc3 = nn.Linear(z_dim, h_dim)

        # Decoder: (in_channels, out_channels, kernel_size) per transposed conv.
        dec_specs = [
            (h_dim, 128, 5),
            (128, 64, 5),
            (64, 32, 6),
            (32, image_channels, 6),
        ]
        dec_layers = [UnFlatten()]
        last = len(dec_specs) - 1
        for idx, (c_in, c_out, k) in enumerate(dec_specs):
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=2))
            # Sigmoid after the final layer keeps the output in [0,1].
            dec_layers.append(nn.Sigmoid() if idx == last else nn.ReLU())
        self.decoder = nn.Sequential(*dec_layers)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(logvar / 2)."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def bottleneck(self, h):
        """Map encoder features to ``(z, mu, logvar)``."""
        mu = self.fc1(h)
        logvar = self.fc2(h)
        return self.reparameterize(mu, logvar), mu, logvar

    def representation(self, x):
        """Return a sampled latent code for ``x`` (not the mean)."""
        z, _, _ = self.bottleneck(self.encoder(x))
        return z

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        z, mu, logvar = self.bottleneck(self.encoder(x))
        recon = self.decoder(self.fc3(z))
        return recon, mu, logvar


# Build the network with its default sizes and move it to the selected device.
model = VAE()
model = model.to(device)
print(model)

# Adam over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# -------------------------
# 5. Loss and metrics
# -------------------------

def vae_loss(recon_x, x, mu, logvar):
    """
    Summed VAE objective: reconstruction BCE plus KL divergence.

    recon_x and x are expected to hold values in [0,1]. Both terms are summed
    (not averaged) over every element of the batch, so the result scales with
    the batch size.
    """
    # Reconstruction term, summed over all pixels of all images.
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    return reconstruction + divergence


def compute_mse(x, recon_x):
    """
    Mean squared error between a batch and its reconstruction.

    The mean is taken over every element of the batch at once, so the result
    is a single scalar tensor. Inputs are expected in [0,1].
    """
    # 'mean' is F.mse_loss's default reduction.
    return F.mse_loss(input=recon_x, target=x)


def compute_ssim(x, recon_x):
    """
    Batch-averaged SSIM between originals and reconstructions.

    x, recon_x: (B, C, H, W) tensors with values in [0,1].
    SSIM is evaluated per image (channels treated as the last axis) and the
    per-image scores are averaged into one Python float; higher means more
    similar.
    """

    def _as_channels_last(batch):
        # (B, C, H, W) tensor -> (B, H, W, C) NumPy array on the CPU.
        return batch.detach().permute(0, 2, 3, 1).cpu().numpy()

    originals = _as_channels_last(x)
    reconstructions = _as_channels_last(recon_x)

    per_image = [
        ssim(
            originals[idx],
            reconstructions[idx],
            data_range=1.0,
            channel_axis=-1,  # last axis is channels
        )
        for idx in range(originals.shape[0])
    ]
    return float(np.mean(per_image))

In [None]:
# -------------------------
# 6. Train / Test loops
# -------------------------
def train(epoch):
    """
    Run one optimisation pass over ``train_loader``.

    Returns (avg_loss, avg_mse, avg_ssim), each averaged per sample over the
    whole epoch. A progress line is printed every ``log_interval`` batches.
    """
    model.train()

    dataset_size = len(train_loader.dataset)
    num_batches = len(train_loader)

    # Running sums; the MSE / SSIM sums are weighted by batch size.
    loss_sum, mse_sum, ssim_sum = 0.0, 0.0, 0.0
    seen = 0

    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)  # [0,1]
        optimizer.zero_grad()

        recon, mu, logvar = model(images)
        loss = vae_loss(recon, images, mu, logvar)

        # Metrics are taken from the same forward pass, before the update.
        mse_value = compute_mse(images, recon).item()
        ssim_value = compute_ssim(images, recon)

        loss.backward()
        optimizer.step()

        count = images.size(0)
        seen += count

        loss_sum += loss.item()
        mse_sum += mse_value * count
        ssim_sum += ssim_value * count

        if step % log_interval != 0:
            continue

        done = step * len(images)
        percent = 100. * step / num_batches
        per_sample_loss = loss.item() / count
        print(
            f"Train Epoch: {epoch} "
            f"[{done}/{dataset_size} "
            f"({percent:.0f}%)]\t"
            f"Loss: {per_sample_loss:.6f}"
        )

    avg_loss = loss_sum / seen
    avg_mse = mse_sum / seen
    avg_ssim = ssim_sum / seen

    print(
        f"====> Epoch: {epoch} "
        f"Train loss: {avg_loss:.4f} "
        f"Train MSE: {avg_mse:.6f} "
        f"Train SSIM: {avg_ssim:.6f}"
    )

    return avg_loss, avg_mse, avg_ssim


def test(epoch):
    """
    Evaluate the model on ``val_loader`` without gradient tracking.

    Returns (avg_loss, avg_mse, avg_ssim), each averaged per sample. For the
    first batch, a grid of up to 8 originals (top row) and their
    reconstructions (bottom row) is written to results/reconstruction_<epoch>.png.
    """
    model.eval()

    # Running sums; the MSE / SSIM sums are weighted by batch size.
    loss_sum, mse_sum, ssim_sum = 0.0, 0.0, 0.0
    seen = 0

    with torch.no_grad():
        for batch_no, (images, _labels) in enumerate(val_loader):
            images = images.to(device)  # [0,1]
            recon, mu, logvar = model(images)

            loss = vae_loss(recon, images, mu, logvar)
            mse_value = compute_mse(images, recon).item()
            ssim_value = compute_ssim(images, recon)

            count = images.size(0)
            seen += count

            loss_sum += loss.item()
            mse_sum += mse_value * count
            ssim_sum += ssim_value * count

            if batch_no != 0:
                continue

            # Originals stacked on top of reconstructions, one column each.
            shown = min(images.size(0), 8)
            grid = torch.cat([images[:shown], recon[:shown]])
            save_image(
                grid.cpu(),
                f"results/reconstruction_{epoch}.png",
                nrow=shown
            )

    avg_loss = loss_sum / seen
    avg_mse = mse_sum / seen
    avg_ssim = ssim_sum / seen

    print(
        f"====> Test Epoch: {epoch} "
        f"Test loss: {avg_loss:.4f} "
        f"Test MSE: {avg_mse:.6f} "
        f"Test SSIM: {avg_ssim:.6f}"
    )

    return avg_loss, avg_mse, avg_ssim

In [None]:
# -------------------------
# 7. Main training loop
# -------------------------
if __name__ == "__main__":
    # Per-epoch history. These are module-level names on purpose: the plotting
    # cell reads them after training finishes.
    train_losses = []
    test_losses = []

    train_mse_list = []
    train_ssim_list = []
    test_mse_list = []
    test_ssim_list = []

    epochs_list = []

    epoch_times = []        # wall-clock seconds spent in each epoch (train + test)
    cumulative_times = []   # wall-clock seconds since training started

    # Prior samples are drawn in the latent space, whose width is the input
    # size of the latent -> hidden projection.
    n_prior_samples = 64
    latent_dim = model.fc3.in_features

    start_time = time.time()

    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()

        train_loss, train_mse, train_ssim = train(epoch)
        test_loss, test_mse, test_ssim = test(epoch)

        # One-line summary of the reconstruction metrics for this epoch.
        print(
            f"[Epoch {epoch:02d}] "
            f"Train MSE: {train_mse:.6f}, Train SSIM: {train_ssim:.6f} | "
            f"Test MSE: {test_mse:.6f}, Test SSIM: {test_ssim:.6f}"
        )

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        train_mse_list.append(train_mse)
        train_ssim_list.append(train_ssim)
        test_mse_list.append(test_mse)
        test_ssim_list.append(test_ssim)
        epochs_list.append(epoch)

        epoch_duration = time.time() - epoch_start
        epoch_times.append(epoch_duration)
        cumulative_times.append(time.time() - start_time)

        # Sample from the prior and save a grid of generated images.
        # Draw z ~ N(0, I) in the latent space and send it through fc3 before
        # the decoder, exactly as forward() does. Feeding h_dim-sized noise
        # straight into the decoder (the previous behaviour) bypasses fc3 and
        # gives the decoder inputs unlike anything it saw during training.
        with torch.no_grad():
            z = torch.randn(n_prior_samples, latent_dim, device=device)
            sample = model.decoder(model.fc3(z)).cpu()
            save_image(sample, f"results/sample_{epoch}.png")

    total_training_time = time.time() - start_time

    print("\nFinal metrics:")
    print(
        f"Train -> MSE: {train_mse_list[-1]:.6f}, "
        f"SSIM: {train_ssim_list[-1]:.6f}"
    )
    print(
        f"Test  -> MSE: {test_mse_list[-1]:.6f}, "
        f"SSIM: {test_ssim_list[-1]:.6f}"
    )
    print(f"Total training time: {total_training_time:.2f}s")


VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): ReLU()
    (8): Flatten()
  )
  (fc1): Linear(in_features=1024, out_features=32, bias=True)
  (fc2): Linear(in_features=1024, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=1024, bias=True)
  (decoder): Sequential(
    (0): UnFlatten()
    (1): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
    (2): ReLU()
    (3): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
    (4): ReLU()
    (5): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))
    (6): ReLU()
    (7): ConvTranspose2d(32, 3, kernel_size=(6, 6), stride=(2, 2))
    (8): Sigmoid()
  )
)
====> Epoch: 1 Train loss: 7452.7634 Train MSE: 0.019650 Train SS

In [None]:
# ==========================
# 8. PLOTS
# ==========================

# savefig does not create missing directories, and only "results" was created
# earlier in the notebook, so make sure the plot folder exists before saving.
plot_dir = "result_VAE"
os.makedirs(plot_dir, exist_ok=True)


def _save_line_plot(x_values, curves, xlabel, ylabel, title, filename):
    """Draw one line plot, save it under ``plot_dir`` and display it.

    Args:
        x_values: shared x coordinates for every curve.
        curves: iterable of (y_values, marker, label) tuples; a label of None
            leaves that curve out of the legend.
        xlabel, ylabel, title: axis labels and figure title.
        filename: file name of the PNG written inside ``plot_dir``.
    """
    fig, ax = plt.subplots()
    has_legend = False
    for y_values, marker, label in curves:
        ax.plot(x_values, y_values, marker=marker, label=label)
        has_legend = has_legend or label is not None
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if has_legend:
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(f"{plot_dir}/{filename}")
    # Show before closing: calling show() after close() displays nothing.
    plt.show()
    plt.close(fig)


# 1) Training error (MSE) vs training time
_save_line_plot(
    cumulative_times,
    [(train_mse_list, 'o', None)],
    xlabel="Training time (s)",
    ylabel="Training error (MSE)",
    title="Training Error (MSE) vs Training Time",
    filename="train_error_vs_time.png",
)

# 2) Training & testing errors vs training time
_save_line_plot(
    cumulative_times,
    [(train_mse_list, 'o', "Train MSE"), (test_mse_list, 's', "Test MSE")],
    xlabel="Training time (s)",
    ylabel="Error (MSE)",
    title="Training and Testing Error vs Training Time",
    filename="train_test_error_vs_time.png",
)

# 3) Training & testing loss vs epochs
_save_line_plot(
    epochs_list,
    [(train_losses, 'o', "Train loss"), (test_losses, 's', "Test loss")],
    xlabel="Epoch",
    ylabel="Loss",
    title="Training and Testing Loss vs Epochs",
    filename="loss_vs_epochs.png",
)
