In [2]:
import torch
from q1_train_vae import VAE
from q1_vae import log_likelihood_bernoulli, kl_gaussian_gaussian_analytic
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model.
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle a whole nn.Module, so weights_only=False is passed explicitly.
# NOTE(security): this runs arbitrary pickle code -- only load a model.pt
# that you produced yourself.
checkpoint = torch.load('model.pt', map_location=device, weights_only=False)
if isinstance(checkpoint, dict):
    # A state_dict (e.g. saved with torch.save(model.state_dict(), ...)):
    # rebuild the architecture and copy the weights into it.
    model = VAE()
    model.load_state_dict(checkpoint)
else:
    # A whole pickled module.
    model = checkpoint
model.to(device).eval()

# Define the loss function
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a batch, summed (not averaged) over the batch.

    recon_loss: negated ``log_likelihood_bernoulli(recon, x)``, summed.
    kl:         ``kl_gaussian_gaussian_analytic`` between (mu, logvar) and a
                prior with zero mean and zero log-variance, summed.
                Assumes the argument order is (mu_q, logvar_q, mu_p, logvar_p)
                -- TODO confirm against q1_vae.

    Both ``recon_x`` and ``x`` are flattened to (batch, -1) so they line up
    element-wise even if the decoder returns image-shaped output; without
    this a (B, 1, 28, 28) reconstruction would broadcast against a (B, 784)
    target instead of matching it.

    Returns:
        0-dim tensor: reconstruction term + KL term.
    """
    batch = x.size(0)
    # reshape (rather than view) also accepts non-contiguous tensors.
    x_flat = x.reshape(batch, -1)
    recon_flat = recon_x.reshape(batch, -1)
    recon_loss = -log_likelihood_bernoulli(recon_flat, x_flat).sum()
    kl = kl_gaussian_gaussian_analytic(
        mu, logvar,
        torch.zeros_like(mu), torch.zeros_like(logvar)
    ).sum()
    return recon_loss + kl

# Prepare validation data loader: the MNIST test split, unshuffled so the
# evaluation order is fixed.
batch_size = 128
transform = transforms.ToTensor()
mnist_test = datasets.MNIST('../data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

# Compute final validation loss: accumulate the per-batch summed loss, then
# divide by the number of examples to get a per-example value.
running_val = 0.0
with torch.no_grad():
    for images, _labels in val_loader:
        images = images.to(device)
        recon_images, mu, logvar = model(images)
        batch_loss = loss_function(recon_images, images, mu, logvar)
        running_val += batch_loss.item()

final_val_loss = running_val / len(val_loader.dataset)
print(f"Final Validation Loss: {final_val_loss:.4f}")

# --------------------------------------------------
# Plot 1: Samples from prior
# --------------------------------------------------
with torch.no_grad():
    # 20 is assumed to be the VAE latent size -- TODO confirm against VAE.
    z = torch.randn(64, 20).to(device)
    samples = model.decode(z).cpu().view(64, 1, 28, 28)
grid = make_grid(samples, nrow=8)
# make_grid replicates single-channel images to 3 channels, so `grid` is
# (3, H, W). imshow cannot display a channels-first array, and squeeze() does
# not remove a size-3 dim. The channels are identical, so plot channel 0 as a
# 2-D grayscale image.
plt.figure(figsize=(4,4))
plt.imshow(grid[0], cmap='gray')
plt.axis('off')
plt.title("Samples from VAE Prior")
plt.savefig("vae_samples.png", bbox_inches='tight')
plt.close()

# --------------------------------------------------
# Plot 2: Latent traversals (20 dims × 5 steps)
# --------------------------------------------------
# Start from one random latent code. For each coordinate in turn, add each
# offset from eps_vals while leaving every other coordinate untouched.
# Rows = latent coordinate, columns = offset.
eps_vals = torch.linspace(-3, 3, 5)
fig, axes = plt.subplots(20, 5, figsize=(10, 40))
with torch.no_grad():
    base_z = torch.randn(1, 20).to(device)
    for dim in range(20):
        for col, eps in enumerate(eps_vals):
            z_shifted = base_z.clone()
            z_shifted[0, dim] += eps
            decoded = model.decode(z_shifted).cpu().view(28, 28)
            cell_ax = axes[dim, col]
            cell_ax.imshow(decoded, cmap='gray')
            cell_ax.axis('off')
fig.suptitle("Latent Traversals")
fig.tight_layout()
fig.savefig("latent_traversals.png", bbox_inches='tight')
plt.close(fig)

# --------------------------------------------------
# Plot 3: Interpolation in latent vs data space
# --------------------------------------------------
# Top row: decode points on the straight line between two random latent codes.
# Bottom row: pixel-wise blend of the two decoded endpoint images.
# alpha runs 0 -> 1, so both rows go from decode(z1) to decode(z0).
alphas = torch.linspace(0, 1, 11)
with torch.no_grad():
    z0 = torch.randn(1, 20).to(device)
    z1 = torch.randn(1, 20).to(device)
    latent_imgs = [model.decode(alpha*z0 + (1-alpha)*z1).cpu().view(28, 28) for alpha in alphas]
    x0 = model.decode(z0).cpu().view(28, 28)
    x1 = model.decode(z1).cpu().view(28, 28)
    data_imgs = [(alpha*x0 + (1-alpha)*x1) for alpha in alphas]

fig, axes = plt.subplots(2, len(alphas), figsize=(22, 4))
for row, row_imgs in enumerate((latent_imgs, data_imgs)):
    for idx, img in enumerate(row_imgs):
        ax = axes[row, idx]
        ax.imshow(img, cmap='gray')
        # Strip ticks and frame instead of calling ax.axis('off'): axis('off')
        # also hides the y-label, so the row labels set below would not render.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
axes[0, 0].set_ylabel("Latent Interpolation")
axes[1, 0].set_ylabel("Data Interpolation")
plt.tight_layout()
plt.savefig("interpolations.png", bbox_inches='tight')
plt.close()

print("Saved plots: vae_samples.png, latent_traversals.png, interpolations.png")


usage: ipykernel_launcher.py [-h] [--batch-size N] [--epochs N] [--no-cuda]
                             [--no-mps] [--seed S] [--log-interval N]
ipykernel_launcher.py: error: unrecognized arguments: --f=c:\Users\akobe\AppData\Roaming\jupyter\runtime\kernel-v34884347e61239f845f94a96b1c6449840b122f3d.json


SystemExit: 2

In [None]:


# --- model, optimizer & training data ---
model = VAE().to(device)
# Reach Adam through the already-imported `torch` package; `optim` itself is
# not imported, so `optim.Adam` would raise NameError on a fresh kernel.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Build the training loader here so the loop below does not depend on a
# `train_loader` left over from some other session. It uses the MNIST
# training split, is shuffled for SGD, and uses the same batch size as
# val_loader.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True
)

# --- loss function using your q1_vae code ---
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of a batch, summed (not averaged) over the batch.

    recon_loss: negated ``log_likelihood_bernoulli(recon, x)``, summed.
    kl:         ``kl_gaussian_gaussian_analytic`` between (mu, logvar) and a
                prior with zero mean and zero log-variance, summed.
                Assumes the argument order is (mu_q, logvar_q, mu_p, logvar_p)
                -- TODO confirm against q1_vae.

    Both ``recon_x`` and ``x`` are flattened to (batch, -1) so they line up
    element-wise even if the decoder returns image-shaped output; without
    this a (B, 1, 28, 28) reconstruction would broadcast against a (B, 784)
    target instead of matching it.

    Returns:
        0-dim tensor: reconstruction term + KL term.
    """
    batch = x.size(0)
    # reshape (rather than view) also accepts non-contiguous tensors.
    x_flat = x.reshape(batch, -1)
    recon_flat = recon_x.reshape(batch, -1)
    recon_loss = -log_likelihood_bernoulli(recon_flat, x_flat).sum()
    kl = kl_gaussian_gaussian_analytic(mu, logvar,
                                       torch.zeros_like(mu),
                                       torch.zeros_like(logvar)).sum()
    return recon_loss + kl

# --- training & validation loops ---
num_epochs = 20
train_losses, val_losses = [], []


def _train_one_epoch():
    """One optimisation pass over ``train_loader``.

    Returns the sum of the per-batch losses; the caller divides by the
    dataset size to get a per-example value.
    """
    model.train()
    epoch_total = 0.0
    for batch, _labels in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon_batch, mu_batch, logvar_batch = model(batch)
        batch_loss = loss_function(recon_batch, batch, mu_batch, logvar_batch)
        batch_loss.backward()
        epoch_total += batch_loss.item()
        optimizer.step()
    return epoch_total


def _evaluate():
    """Sum of per-batch losses over ``val_loader``, with gradients disabled."""
    model.eval()
    epoch_total = 0.0
    with torch.no_grad():
        for batch, _labels in val_loader:
            batch = batch.to(device)
            recon_batch, mu_batch, logvar_batch = model(batch)
            epoch_total += loss_function(recon_batch, batch, mu_batch, logvar_batch).item()
    return epoch_total


for epoch in range(1, num_epochs + 1):
    train_losses.append(_train_one_epoch() / len(train_loader.dataset))
    val_losses.append(_evaluate() / len(val_loader.dataset))
    print(f"Epoch {epoch:2d}  Train loss: {train_losses[-1]:.4f}  Val loss: {val_losses[-1]:.4f}")

# --- plot losses ---
# One point per epoch: the per-example losses collected by the training loop.
epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, train_losses, label="Train")
ax.plot(epochs_axis, val_losses, label="Val")
ax.set_xlabel("Epoch")
ax.set_ylabel("Avg Loss")
ax.legend()
ax.set_title("Training vs Validation Loss")
plt.show()

print(f"\nFinal Validation Loss: {val_losses[-1]:.4f}")

# --- samples from prior ---
model.eval()
with torch.no_grad():
    # 20 is assumed to be the VAE latent size -- TODO confirm against VAE.
    z = torch.randn(64, 20).to(device)
    samples = model.decode(z).cpu().view(64,1,28,28)
grid = make_grid(samples, nrow=8)
# make_grid turns 1-channel images into a 3-channel (3, H, W) grid, which
# imshow rejects (squeeze() cannot drop a size-3 dim). All channels are equal,
# so plot channel 0 as a 2-D grayscale image.
plt.figure(figsize=(6,6))
plt.imshow(grid[0], cmap='gray')
plt.axis('off')
plt.title("64 Samples from VAE Prior")
plt.show()

# --- latent traversals ---
# Perturb one coordinate at a time of a single random latent code by each
# value in eps_vals; every other coordinate stays fixed.
eps_vals = torch.linspace(-3, 3, 5)
fig, axes = plt.subplots(20, 5, figsize=(10, 40))
with torch.no_grad():
    base_z = torch.randn(1, 20).to(device)
    for dim in range(20):
        for col, eps in enumerate(eps_vals):
            z_shifted = base_z.clone()
            z_shifted[0, dim] += eps
            decoded = model.decode(z_shifted).cpu().view(28, 28)
            cell_ax = axes[dim, col]
            cell_ax.imshow(decoded, cmap='gray')
            cell_ax.axis('off')
fig.suptitle("Latent Traversals (rows=latent dim, cols=eps values)")
fig.tight_layout()
plt.show()

# --- interpolation latent vs data ---
# Top row: decode points on the straight line between two random latent codes.
# Bottom row: pixel-wise blend of the two decoded endpoint images.
# alpha runs 0 -> 1, so both rows go from decode(z1) to decode(z0).
alphas = torch.linspace(0,1,11)
with torch.no_grad():
    z0, z1 = torch.randn(1,20).to(device), torch.randn(1,20).to(device)
    lat_imgs = [model.decode(alpha*z0 + (1-alpha)*z1).cpu().view(28,28) for alpha in alphas]
    x0 = model.decode(z0).cpu().view(28,28)
    x1 = model.decode(z1).cpu().view(28,28)
    data_imgs = [(alpha*x0 + (1-alpha)*x1) for alpha in alphas]

fig, axes = plt.subplots(2, len(alphas), figsize=(22,4))
for row, row_imgs in enumerate((lat_imgs, data_imgs)):
    for idx, img in enumerate(row_imgs):
        ax = axes[row, idx]
        ax.imshow(img, cmap='gray')
        # Strip ticks and frame instead of calling ax.axis('off'): axis('off')
        # also hides the y-label, so the row labels set below would not render.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
axes[0,0].set_ylabel("Latent interp")
axes[1,0].set_ylabel("Data interp")
plt.tight_layout()
# Plain statement rather than a tuple expression, so the cell does not echo
# `(None, None)` as its output.
plt.show()
