In [1]:
import torch
from torch.utils.data import DataLoader
from auto_encoder_classes import VAE
from auto_encoder_functions import vae_loss

# --------------------------
# Minimal Dataset Definition
# --------------------------
import os
import numpy as np
from glob import glob
from PIL import Image

class MRISliceDataset(torch.utils.data.Dataset):
    """Simple dataset for MRI slices (no masks)."""
    def __init__(self, image_dir):
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.png")))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("L")
        img = np.array(img, dtype=np.float32) / 255.0
        img = torch.tensor(img).unsqueeze(0)  # shape (1,H,W)
        return img

# --------------------------
# Config
# --------------------------
TRAIN_DIR = "../4.2/keras_png_slices_train"
VAL_DIR = "../4.2/keras_png_slices_validate"
BATCH_SIZE = 32
LATENT_DIM = 32
EPOCHS = 10
LR = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = MRISliceDataset(TRAIN_DIR)
val_dataset = MRISliceDataset(VAL_DIR)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

vae = VAE(latent_dim=LATENT_DIM, input_shape=(1, 256, 256)).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=LR)

best_val_loss = float("inf")
patience, max_patience = 0, 3  # early stopping

for epoch in range(EPOCHS):
    # --------------------------
    # Training
    # --------------------------
    vae.train()
    total_train_loss = 0
    for imgs in train_loader:
        imgs = imgs.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = vae(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    # --------------------------
    # Validation
    # --------------------------
    vae.eval()
    total_val_loss = 0
    with torch.no_grad():
        for imgs in val_loader:
            imgs = imgs.to(device)
            recon, mu, logvar = vae(imgs)
            total_val_loss += vae_loss(recon, imgs, mu, logvar).item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # --------------------------
    # Checkpoint Saving
    # --------------------------
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(vae.state_dict(), f"vae_epoch{epoch+1}_loss{best_val_loss:.4f}.pth")
        print(f"✅ Saved new best model (val loss {best_val_loss:.4f})")
        patience = 0
    else:
        patience += 1
        if patience >= max_patience:
            print("⏹ Early stopping triggered.")
            break


Epoch 1/10 | Train Loss: 19674.9568 | Val Loss: 17796.1710
✅ Saved new best model (val loss 17796.1710)
Epoch 2/10 | Train Loss: 17513.4383 | Val Loss: 17485.0215
✅ Saved new best model (val loss 17485.0215)
Epoch 3/10 | Train Loss: 19312.5051 | Val Loss: 18487.5304
Epoch 4/10 | Train Loss: 2024285.7002 | Val Loss: 18373.2919
Epoch 5/10 | Train Loss: 18622.7894 | Val Loss: 18250.3958
⏹ Early stopping triggered.


In [None]:
import os
import numpy as np
from glob import glob
from PIL import Image
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader

from auto_encoder_classes import VAE
from auto_encoder_functions import vae_loss

# --------------------------
# Define Dataset Here
# --------------------------
class MRISliceDataset(torch.utils.data.Dataset):
    """
    Simple dataset for MRI slices (no masks).
    Loads all .png files in the given directory.
    """
    def __init__(self, image_dir):
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.png")))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("L")
        img = np.array(img, dtype=np.float32) / 255.0
        img = torch.tensor(img).unsqueeze(0)  # shape (1,H,W)
        return img

# --------------------------
# Config
# --------------------------
TEST_DIR = "../4.2/keras_png_slices_test"
BATCH_SIZE = 16
LATENT_DIM = 32  # must match training

# --------------------------
# Find Latest Model Checkpoint
# --------------------------
checkpoints = sorted(glob.glob("vae_epoch*.pth"), key=os.path.getmtime, reverse=True)
if not checkpoints:
    raise FileNotFoundError("❌ No VAE checkpoint found in this folder. Train the model first!")
MODEL_PATH = checkpoints[0]
print(f"🔎 Loading latest checkpoint: {MODEL_PATH}")

# --------------------------
# Load Dataset + Model
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_dataset = MRISliceDataset(TEST_DIR)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

vae = VAE(latent_dim=LATENT_DIM, input_shape=(1, 256, 256)).to(device)
vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
vae.eval()

# --------------------------
# 1. Compute Average Test Loss
# --------------------------
total_loss = 0.0
with torch.no_grad():
    for imgs in test_loader:
        imgs = imgs.to(device)
        recon, mu, logvar = vae(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)
        total_loss += loss.item()

avg_loss = total_loss / len(test_loader)
print(f"Average Test Loss: {avg_loss:.4f}")

# --------------------------
# 2. Visualize Reconstructions
# --------------------------
imgs = next(iter(test_loader)).to(device)
with torch.no_grad():
    recon, mu, logvar = vae(imgs)

n_show = min(8, imgs.size(0))
plt.figure(figsize=(12, 4))
for i in range(n_show):
    # Input
    plt.subplot(2, n_show, i + 1)
    plt.imshow(imgs[i, 0].cpu(), cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.ylabel("Input", fontsize=12)

    # Reconstruction
    plt.subplot(2, n_show, n_show + i + 1)
    plt.imshow(recon[i, 0].cpu(), cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.ylabel("Recon", fontsize=12)

plt.suptitle("VAE Reconstructions", fontsize=14)
plt.tight_layout()
plt.show()

# --------------------------
# 3. Latent Space Visualization
# --------------------------
if LATENT_DIM > 2:
    latents = []
    with torch.no_grad():
        for imgs in test_loader:
            imgs = imgs.to(device)
            mu, _ = vae.encoder(imgs)
            latents.append(mu.cpu().numpy())
    latents = np.concatenate(latents, axis=0)

    z_2d = PCA(n_components=2).fit_transform(latents)
    plt.figure(figsize=(6, 6))
    plt.scatter(z_2d[:, 0], z_2d[:, 1], s=8, alpha=0.7)
    plt.title("Latent Space (PCA Projection)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.show()
else:
    all_mu = []
    with torch.no_grad():
        for imgs in test_loader:
            imgs = imgs.to(device)
            mu, _ = vae.encoder(imgs)
            all_mu.append(mu.cpu().numpy())
    all_mu = np.concatenate(all_mu, axis=0)
    plt.figure(figsize=(6, 6))
    plt.scatter(all_mu[:, 0], all_mu[:, 1], s=8, alpha=0.7)
    plt.title("Latent Space")
    plt.xlabel("z1")
    plt.ylabel("z2")
    plt.show()


AttributeError: 'function' object has no attribute 'glob'

In [None]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader

from auto_encoder_classes import VAE, MRISliceDataset
from auto_encoder_functions import vae_loss

# --------------------------
# Config
# --------------------------
MODEL_PATH = "best_vae.pth"
TEST_DIR = "../4.2/keras_png_slices_test"  # points to data in part_4_scirpts/4.2
BATCH_SIZE = 16
LATENT_DIM = 32  # must match what you trained with

# --------------------------
# Load Dataset + Model
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_dataset = MRISliceDataset(TEST_DIR)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

vae = VAE(latent_dim=LATENT_DIM).to(device)
vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
vae.eval()

# --------------------------
# 1. Compute Average Test Loss
# --------------------------
total_loss = 0.0
with torch.no_grad():
    for imgs in test_loader:
        imgs = imgs.to(device)
        recon, mu, logvar = vae(imgs)
        loss = vae_loss(recon, imgs, mu, logvar)
        total_loss += loss.item()

avg_loss = total_loss / len(test_loader)
print(f"Average Test Loss: {avg_loss:.4f}")

# --------------------------
# 2. Visualize Reconstructions
# --------------------------
imgs = next(iter(test_loader)).to(device)
with torch.no_grad():
    recon, mu, logvar = vae(imgs)

n_show = min(8, imgs.size(0))
plt.figure(figsize=(12, 4))
for i in range(n_show):
    # Input
    plt.subplot(2, n_show, i + 1)
    plt.imshow(imgs[i, 0].cpu(), cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.ylabel("Input", fontsize=12)
    # Reconstruction
    plt.subplot(2, n_show, n_show + i + 1)
    plt.imshow(recon[i, 0].cpu(), cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.ylabel("Recon", fontsize=12)

plt.suptitle("VAE Reconstructions", fontsize=14)
plt.tight_layout()
plt.show()

# --------------------------
# 3. Latent Space Visualization
# --------------------------
if LATENT_DIM > 2:
    latents = []
    with torch.no_grad():
        for imgs in test_loader:
            imgs = imgs.to(device)
            mu, _ = vae.encoder(imgs)
            latents.append(mu.cpu().numpy())
    latents = np.concatenate(latents, axis=0)

    z_2d = PCA(n_components=2).fit_transform(latents)
    plt.figure(figsize=(6, 6))
    plt.scatter(z_2d[:, 0], z_2d[:, 1], s=8, alpha=0.7)
    plt.title("Latent Space (PCA Projection)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.show()
else:
    # Direct plot if latent_dim == 2
    all_mu = []
    with torch.no_grad():
        for imgs in test_loader:
            imgs = imgs.to(device)
            mu, _ = vae.encoder(imgs)
            all_mu.append(mu.cpu().numpy())
    all_mu = np.concatenate(all_mu, axis=0)
    plt.figure(figsize=(6, 6))
    plt.scatter(all_mu[:, 0], all_mu[:, 1], s=8, alpha=0.7)
    plt.title("Latent Space")
    plt.xlabel("z1")
    plt.ylabel("z2")
    plt.show()


  vae.load_state_dict(torch.load(MODEL_PATH, map_location=device))


FileNotFoundError: [Errno 2] No such file or directory: 'best_vae.pth'