In [None]:
# Mount Google Drive at /content/drive using the google.colab helper.
# The checkpoint and project paths used later in this notebook live under this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Switch the working directory to the project folder on the mounted Drive
# (presumably so the local `models`, `losses` and `trainer` modules can be imported — confirm).
%cd /content/drive/MyDrive/master/ivae

/content/drive/MyDrive/master/ivae


In [None]:
# ============================================
# analysis.py — Compare VAE vs iVAE-trained
# Reconstruction error (MSE) over iterative inference
# ============================================

import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# --- Imports from your project ---
# Adjust if you keep everything in one file
from models import VAE, Encoder, Decoder, iVAETrainer
from losses import elbo_loss   # or import from your training script
from trainer import get_dataloaders, device

ModuleNotFoundError: No module named 'models'

# MSE

In [None]:
# ============================================
# Utility: Mean Squared Error (MSE) metric
# ============================================
def recon_mse(recon, target):
    """
    Per-sample and batch-mean squared error between reconstructions and targets.

    Args:
        recon:  tensor [B, D] of reconstructions (callers pass values in [0,1]).
        target: tensor [B, D] of reference values, same shape as `recon`.

    Returns:
        mse_per_sample: numpy array [B], squared error averaged over dim 1.
        mean_mse: Python float, mean of the per-sample values.
    """
    # detach() so the metric also works on tensors that still carry autograd
    # history; Tensor.numpy() raises on tensors that require grad.
    mse_per_sample = F.mse_loss(recon, target, reduction='none').mean(dim=1).detach()
    return mse_per_sample.cpu().numpy(), float(mse_per_sample.mean())

# Model

In [None]:
# ============================================
# Model loader
# ============================================
def load_models(cfg, device):
    """
    Rebuild the baseline VAE and the iVAE encoder/decoder and restore their weights.

    Reads from `cfg`: 'z_dim', 'save_dir', 'beta', 'lr_model', 'lr_inf'.
    Checkpoints are expected in cfg['save_dir']:
        - "vae_baseline.pth": state dict for the whole VAE
        - "ivae.pth": dict with "encoder" and "decoder" state dicts

    Returns:
        vae, ivae_trainer, ivae_encoder, ivae_decoder
    """
    latent_size = cfg['z_dim']
    ckpt_dir = cfg['save_dir']

    # Baseline VAE: the checkpoint is the state dict itself.
    vae = VAE(z_dim=latent_size).to(device)
    vae.load_state_dict(
        torch.load(os.path.join(ckpt_dir, "vae_baseline.pth"), map_location=device)
    )

    # iVAE: one checkpoint file holding both sub-module state dicts.
    ivae_state = torch.load(os.path.join(ckpt_dir, "ivae.pth"), map_location=device)
    ivae_enc = Encoder(z_dim=latent_size).to(device)
    ivae_dec = Decoder(z_dim=latent_size).to(device)
    for module, key in ((ivae_enc, "encoder"), (ivae_dec, "decoder")):
        module.load_state_dict(ivae_state[key])

    # Trainer wrapper around the restored modules; AMP is switched off here.
    trainer_kwargs = dict(
        latent_dim=latent_size,
        beta=cfg['beta'],
        lr_model=cfg['lr_model'],
        lr_inf=cfg['lr_inf'],
        device=device,
        use_amp=False,
    )
    ivae_trainer = iVAETrainer(ivae_enc, ivae_dec, **trainer_kwargs)

    return vae, ivae_trainer, ivae_enc, ivae_dec

# Iterative inference evaluation

In [None]:
def iterative_recon_mse(encoder, decoder, x, n_steps, lr_eval, device, beta,
                        update_decoder=False, save_path=None, save_latent_traj=False):
    """
    Refine the encoder's (mu, logvar) by gradient descent on `elbo_loss` and
    track the reconstruction MSE after every step.

    Args:
        encoder, decoder: modules; `encoder(x)` must return (mu, logvar) and
            `decoder(z)` logits that are passed through a sigmoid.
        x: input batch; flattened to [B, -1] for the MSE computation.
        n_steps: number of refinement steps.
        lr_eval: step size for the latent updates (and the decoder's Adam
            learning rate when `update_decoder` is True).
        device: device the batch and latents are moved to.
        beta: forwarded to `elbo_loss`.
        update_decoder: if True, also take one Adam step on the decoder per
            iteration, using the gradient of the same loss.
        save_path: directory for "decoder_test.pth"; only used together with
            `update_decoder=True`.
        save_latent_traj: if True, additionally return the trajectory of mu.

    Returns:
        (mses, recons) when `save_latent_traj` is False, otherwise
        (mses, recons, z_traj), where
            mses:   numpy array [n_steps+1], batch-mean MSE per step (index 0 is
                    the amortized encoder output)
            recons: list of n_steps+1 CPU tensors with the reconstructions
            z_traj: numpy array [n_steps+1, B, z_dim] of latent means
        NOTE: the tuple length depends on `save_latent_traj`; callers must
        unpack accordingly.
    """
    encoder.eval()
    if update_decoder:
        decoder.train()
    else:
        decoder.eval()
    x = x.to(device)
    x_rows = x.view(x.size(0), -1)

    opt_dec = torch.optim.Adam(decoder.parameters(), lr=lr_eval) if update_decoder else None
    # Decoder parameters are differentiated together with the latents so that a
    # single graph traversal yields every gradient needed in one step.
    dec_params = [p for p in decoder.parameters() if p.requires_grad] if update_decoder else []

    with torch.no_grad():
        mu0, logvar0 = encoder(x)

    mu = mu0.clone().detach().to(device).requires_grad_(True)
    logvar = logvar0.clone().detach().to(device).requires_grad_(True)

    mses, recons = [], []
    z_traj = [] if save_latent_traj else None

    def _record():
        # Decode the current mean and log reconstruction, MSE and (optionally) mu.
        with torch.no_grad():
            x_recon = torch.sigmoid(decoder(mu))
            recons.append(x_recon.detach().cpu())
            _, mean_mse = recon_mse(x_recon.view(x_recon.size(0), -1), x_rows)
            mses.append(mean_mse)
        if save_latent_traj:
            # clone(): on CPU, .numpy() shares memory with `mu`, so without a
            # copy the in-place updates below would overwrite earlier entries.
            z_traj.append(mu.detach().clone().cpu().numpy())

    # Step 0: reconstruction from the amortized encoder output.
    _record()

    for _step in range(1, n_steps + 1):
        loss, _ = elbo_loss(x, decoder, mu, logvar, beta=beta)
        # All gradients are taken before anything is modified: the graph is
        # freed by this call and the latents are updated in place right after,
        # so a later loss.backward() on the same loss would fail.
        grads = torch.autograd.grad(
            loss, [mu, logvar, *dec_params],
            retain_graph=False, create_graph=False,
            allow_unused=update_decoder,
        )

        with torch.no_grad():
            # In-place updates under no_grad keep mu/logvar as leaves that
            # require grad, so no re-flagging is needed.
            mu -= lr_eval * grads[0]
            logvar -= lr_eval * grads[1]

        # optional decoder fine-tuning
        if update_decoder:
            opt_dec.zero_grad()
            for param, grad in zip(dec_params, grads[2:]):
                param.grad = grad
            opt_dec.step()

        _record()

    if update_decoder and save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        torch.save(decoder.state_dict(), os.path.join(save_path, "decoder_test.pth"))
        print(f"[✓] Saved test-adapted decoder to {os.path.join(save_path, 'decoder_test.pth')}")

    if save_latent_traj:
        return np.array(mses), recons, np.stack(z_traj)  # [steps+1, B, z_dim]
    else:
        return np.array(mses), recons


# Plots and analytics

In [None]:
# ============================================
# Main evaluation function
# ============================================
def run_analysis(cfg, test_loader, device, n_examples=256, n_steps=50,
                 lr_eval_factor=0.1, show_plot=True):
    """
    Compare reconstruction MSE between:
        1. VAE baseline (single amortized encoder/decoder pass)
        2. iVAE-trained model under iterative latent refinement

    Args:
        cfg: config dict; forwarded to `load_models`, and cfg['beta'] is used
            for the refinement loss.
        test_loader: iterable yielding (images, labels) batches.
        device: device to evaluate on.
        n_examples: number of test samples to use (fewer if the loader runs out).
        n_steps: number of refinement steps.
        lr_eval_factor: the refinement step size is `ivae_trainer.lr_inf * lr_eval_factor`.
        show_plot: if True, plot both curves against the iteration index.

    Returns:
        mse_curves = [vae_mse_constant, ivae_mse_curve], two numpy arrays of
        length n_steps+1 (the VAE value is repeated to match the curve length).

    Raises:
        ValueError: if `test_loader` yields no batches.
    """
    # Load models
    vae, ivae_trainer, ivae_enc, ivae_dec = load_models(cfg, device)

    # Collect test samples; keep a running count instead of re-concatenating
    # everything gathered so far on every iteration.
    imgs, n_collected = [], 0
    for x, _ in test_loader:
        imgs.append(x)
        n_collected += x.size(0)
        if n_collected >= n_examples:
            break
    if not imgs:
        raise ValueError("test_loader yielded no batches; nothing to evaluate.")
    imgs = torch.cat(imgs, dim=0)[:n_examples].to(device)
    x_flat = imgs.view(imgs.size(0), -1)

    # --- (1) VAE baseline amortized ---
    vae.eval()
    with torch.no_grad():
        mu_vae, _ = vae.encoder(x_flat)
        x_logit = vae.decoder(mu_vae)  # decode the posterior mean, no sampling
        x_recon = torch.sigmoid(x_logit)
        _, vae_mse = recon_mse(x_recon.view(x_recon.size(0), -1),
                               x_flat)
    print(f"VAE amortized reconstruction MSE: {vae_mse:.4f}")

    # --- (2) iVAE-trained iterative inference ---
    lr_eval = ivae_trainer.lr_inf * lr_eval_factor
    ivae_mses, _ = iterative_recon_mse(
        ivae_enc, ivae_dec, x_flat, n_steps, lr_eval, device, beta=cfg['beta']
    )
    print(f"iVAE MSE after {n_steps} steps: {ivae_mses[-1]:.4f}")

    # --- Aggregate results ---
    vae_constant = np.ones_like(ivae_mses) * vae_mse
    mse_curves = [vae_constant, ivae_mses]

    # --- Optional Plot ---
    if show_plot:
        iters = np.arange(0, n_steps + 1)
        plt.figure(figsize=(7, 5))
        plt.plot(iters, vae_constant, '--', label='VAE (amortized baseline)')
        plt.plot(iters, ivae_mses, label='iVAE (iterative refinement)')
        plt.xlabel("SVI Iteration")
        plt.ylabel("Reconstruction MSE (↓)")
        plt.title("Reconstruction Error vs SVI Iterations")
        plt.legend()
        plt.grid(True)
        plt.gca().invert_yaxis()  # lower is better, so smaller values are drawn higher
        plt.show()

    return mse_curves



In [None]:
def plot_latent_evolution(decoder, z_traj, x_true=None, num_samples=5, num_steps_to_show=5,
                          img_shape=(28, 28)):
    """
    Visualize decoded reconstructions along the latent refinement trajectory.

    Args:
        decoder: module mapping latents to logits; put into eval mode here.
        z_traj: numpy array [steps+1, B, z_dim] of latent means.
        x_true: optional tensor/array of targets, indexable per sample; when
            given, an extra "target" column is appended.
        num_samples: number of rows (capped at B).
        num_steps_to_show: number of evenly spaced iterations shown as columns.
        img_shape: 2-D shape each decoded sample / target is reshaped to.

    Shows a grid of num_samples x num_steps_to_show decoded images
    (plus the target column when `x_true` is provided).

    Raises:
        ValueError: if there is nothing to draw (empty trajectory, empty batch,
            or non-positive `num_samples` / `num_steps_to_show`).
    """
    steps, B, zdim = z_traj.shape
    num_samples = min(num_samples, B)
    if steps < 1 or num_samples < 1 or num_steps_to_show < 1:
        raise ValueError(
            "plot_latent_evolution needs at least one step, one sample and one column "
            f"(got steps={steps}, samples={num_samples}, columns={num_steps_to_show})."
        )
    step_idxs = np.linspace(0, steps - 1, num_steps_to_show, dtype=int)

    # Only reserve the extra column when there is a target to put in it.
    show_target = x_true is not None
    n_cols = num_steps_to_show + (1 if show_target else 0)

    # squeeze=False keeps `axes` 2-D even for a single row or column, so the
    # [row, col] indexing below is always valid.
    fig, axes = plt.subplots(num_samples, n_cols,
                             figsize=(2.2 * n_cols, 2 * num_samples),
                             squeeze=False)

    decoder.eval()
    dec_device = next(decoder.parameters()).device
    for j, step in enumerate(step_idxs):
        # Decode all displayed samples of this iteration in one forward pass.
        z = torch.as_tensor(z_traj[step, :num_samples], dtype=torch.float32).to(dec_device)
        with torch.no_grad():
            decoded = torch.sigmoid(decoder(z)).cpu()
        for i in range(num_samples):
            axes[i, j].imshow(decoded[i].reshape(img_shape), cmap='gray', vmin=0, vmax=1)
            axes[i, j].axis('off')
            if i == 0:
                axes[i, j].set_title(f"iter {step}")

    # optionally show ground truth
    if show_target:
        for i in range(num_samples):
            # Move to CPU first so targets living on an accelerator can be drawn.
            target = torch.as_tensor(x_true[i]).detach().cpu().reshape(img_shape)
            axes[i, -1].imshow(target, cmap='gray', vmin=0, vmax=1)
            axes[i, -1].axis('off')
            if i == 0:
                axes[i, -1].set_title("target")

    plt.tight_layout()
    plt.show()


# Run

In [None]:
# ============================================
# Example usage (Colab cell)
# ============================================
# Baseline comparison driver. The names bound here (CONFIG, test_loader, curves)
# become notebook-level globals when this cell runs.
if __name__ == "__main__":
    from config import CONFIG  # or define inline
    # get_dataloaders returns a pair; only the second element (used as the test
    # loader) is kept — presumably the first is the training loader, confirm.
    _, test_loader = get_dataloaders(CONFIG)
    # curves = [constant VAE MSE array, iVAE MSE per refinement step]
    curves = run_analysis(CONFIG, test_loader, device, n_examples=256, n_steps=50)

In [None]:
# Run iterative inference and record latents.
# The models, config alias and evaluation batch are bound explicitly here:
# `run_analysis` keeps its models and data local, so on a fresh kernel these
# names would otherwise be undefined at notebook scope.
cfg = CONFIG
vae, ivae_trainer, ivae_enc, ivae_dec = load_models(cfg, device)
x_batch, _ = next(iter(test_loader))
x_flat = x_batch.view(x_batch.size(0), -1).to(device)

ivae_mses, _, z_traj = iterative_recon_mse(
    ivae_enc, ivae_dec, x_flat[:5], n_steps=cfg['svi_steps_eval']//10,
    lr_eval=ivae_trainer.lr_inf*0.1, device=device, beta=cfg['beta'],
    save_latent_traj=True
)

# Plot decoded trajectory evolution
plot_latent_evolution(ivae_dec, z_traj, x_true=x_flat[:5].cpu(), num_samples=5, num_steps_to_show=5)


# Out-of-distribution

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms import functional as TF
import random
import cv2

# ================================================
# 1. CORRUPTION FUNCTIONS
# ================================================
def add_white_noise(x, sigma=0.6):
    """Perturb `x` with zero-mean Gaussian noise of std `sigma` and clip the result to [0, 1]."""
    perturbed = x + sigma * torch.randn_like(x)
    return perturbed.clamp(min=0.0, max=1.0)

def add_gaussian_blur(x, sigma=2.0):
    """
    Blur channel 0 of every sample with OpenCV's GaussianBlur (fixed 5x5 kernel).

    Each image is scaled to 0..255 and cast to uint8 before blurring, then
    scaled back by 1/255. Returns a float32 tensor with a singleton channel
    dimension re-inserted at dim 1, on the same device as `x`.
    """
    frames = x.cpu().numpy()
    blurred = [
        cv2.GaussianBlur((frames[idx, 0] * 255).astype(np.uint8), (5, 5), sigma) / 255.0
        for idx in range(frames.shape[0])
    ]
    result = torch.tensor(np.stack(blurred), dtype=torch.float32).unsqueeze(1)
    return result.to(x.device)

def add_salt_pepper_noise(x, p=0.4):
    """
    Corrupt channel 0 of every sample with salt & pepper noise.

    A uniform random mask is drawn per image: entries below p/2 become 0.0
    (pepper), entries above 1 - p/2 become 1.0 (salt); salt wins if both apply.
    Returns a float32 tensor with a singleton channel dimension at dim 1, on
    the same device as `x`.
    """
    source = x.cpu().numpy()
    pepper_below, salt_above = p / 2, 1 - p / 2
    corrupted = []
    for idx in range(source.shape[0]):
        plane = source[idx, 0]
        draw = np.random.rand(*plane.shape)
        peppered = np.where(draw < pepper_below, 0.0, plane)
        corrupted.append(np.where(draw > salt_above, 1.0, peppered))
    result = torch.tensor(np.stack(corrupted), dtype=torch.float32).unsqueeze(1)
    return result.to(x.device)

# ================================================
# 2. EVALUATION LOOP
# ================================================
def eval_corrupted_images(model_enc, model_dec, test_imgs, corruption_fn, corr_name,
                          n_steps, lr_eval, beta, device, return_recons=False):
    """
    Corrupt `test_imgs`, run iterative inference on the flattened result and
    report the MSE after the last step.

    `corr_name` is accepted but not used by this function.

    Returns:
        (final_mse, x_corr), or (final_mse, x_corr, recons) if `return_recons`.
    """
    x_corr = corruption_fn(test_imgs)
    mse_curve, recons = iterative_recon_mse(
        model_enc, model_dec, x_corr.view(x_corr.size(0), -1),
        n_steps=n_steps, lr_eval=lr_eval,
        device=device, beta=beta, save_latent_traj=False,
    )
    outcome = (mse_curve[-1], x_corr)
    if return_recons:
        outcome = outcome + (recons,)
    return outcome

# ================================================
# 3. VISUALIZATION HELPERS
# ================================================
def plot_reconstruction_timeline(decoder, z_traj, x_true, x_corr, corruption_label,
                                 num_steps_to_show=5):
    """
    Draw one row of panels for a single sample: the corrupted input, decoded
    reconstructions at evenly spaced refinement steps, and the clean target.

    `z_traj` holds one latent vector per step; images are reshaped to 28x28.
    """
    n_panels = num_steps_to_show + 2
    shown_steps = np.linspace(0, len(z_traj) - 1, num_steps_to_show, dtype=int)
    fig, axes = plt.subplots(1, n_panels, figsize=(2.5 * n_panels, 3))

    def _panel(ax, image, title):
        # Shared rendering for every panel: grayscale on a fixed [0, 1] range, no axes.
        ax.imshow(image, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(title)

    # Middle panels: decode the latent recorded at each selected step.
    decoder.eval()
    for col, step in enumerate(shown_steps, start=1):
        latent = torch.tensor(z_traj[step], dtype=torch.float32).unsqueeze(0)
        latent = latent.to(next(decoder.parameters()).device)
        with torch.no_grad():
            decoded = torch.sigmoid(decoder(latent)).cpu().view(28, 28)
        _panel(axes[col], decoded, f"iter {step}")

    # First and last panels: what the model saw vs. the uncorrupted image.
    _panel(axes[0], x_corr.cpu().view(28,28), f"Input ({corruption_label})")
    _panel(axes[-1], x_true.cpu().view(28,28), "Target")

    plt.tight_layout()
    plt.show()

def plot_ood_accuracy_bars(vae_scores, ivae_scores, labels):
    """
    Grouped bar chart of mean reconstruction MSE per corruption type
    (lower is better), VAE next to iVAE, with the std of each score list as
    the error bar.
    """
    positions = np.arange(len(labels))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(8,5))
    # One bar series per model, shifted half a bar width left / right of the tick.
    series = (
        (-bar_width / 2, vae_scores, 'VAE'),
        (bar_width / 2, ivae_scores, 'iVAE'),
    )
    for shift, score_lists, tag in series:
        heights = [np.mean(scores) for scores in score_lists]
        spreads = [np.std(scores) for scores in score_lists]
        ax.bar(positions + shift, heights, bar_width, yerr=spreads, label=tag, capsize=4)

    ax.set_ylabel('Reconstruction MSE ↓')
    ax.set_title('Reconstruction error across corruption types')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()
    plt.grid(True, linestyle='--', alpha=0.4)
    plt.show()

# ================================================
# 4. OOD TEST DRIVER
# ================================================
def run_ood_analysis(cfg, vae, ivae_enc, ivae_dec, test_loader, device):
    """
    Compare the VAE baseline and the iteratively refined iVAE on corrupted inputs.

    Takes the first batch of `test_loader`, applies each corruption (identity,
    blur, white noise, salt & pepper), and for each one records
        - the VAE's single-pass reconstruction MSE, and
        - the iVAE's MSE after cfg['svi_steps_eval'] refinement steps,
    both measured against the corrupted input. Then shows a bar plot of the
    scores and, for every non-identity corruption, a reconstruction timeline of
    the first sample.

    Reads from `cfg`: 'svi_steps_eval', 'lr_inf', 'beta'.
    """
    vae.eval(); ivae_enc.eval(); ivae_dec.eval()

    # sample a batch
    x, _ = next(iter(test_loader))
    x = x.to(device)

    n_steps = cfg['svi_steps_eval']
    lr_eval = cfg['lr_inf'] * 0.1

    # --- Corruptions ---
    corruption_tasks = [
        ("Vanilla", lambda img: img),
        ("Blur σ=2", lambda img: add_gaussian_blur(img, sigma=2.0)),
        ("White noise σ=0.6", lambda img: add_white_noise(img, sigma=0.6)),
        ("Salt&Pepper p=0.4", lambda img: add_salt_pepper_noise(img, p=0.4)),
    ]

    vae_scores = []
    ivae_scores = []

    for name, corr_fn in corruption_tasks:
        print(f"Testing corruption: {name}")

        # Corrupt once so both models are scored on the identical input.
        x_corr = corr_fn(x)
        x_corr_flat = x_corr.view(x_corr.size(0), -1)

        # --- VAE baseline ---
        with torch.no_grad():
            mu_vae, _ = vae.encoder(x_corr_flat)
            recon = torch.sigmoid(vae.decoder(mu_vae))
            _, mse_vae = recon_mse(recon, x_corr_flat)
        vae_scores.append([mse_vae])

        # --- iVAE iterative ---
        # With save_latent_traj=False the helper returns exactly two values
        # (mse curve, reconstructions); unpacking three raises ValueError.
        mse_ivae, _ = iterative_recon_mse(
            ivae_enc, ivae_dec, x_corr_flat,
            n_steps=n_steps, lr_eval=lr_eval, device=device, beta=cfg['beta'],
            save_latent_traj=False
        )
        ivae_scores.append([mse_ivae[-1]])

    # --- Visualization: comparison bar plot ---
    # Each score list holds a single batch-mean value, so the error bars are zero.
    plot_ood_accuracy_bars(vae_scores, ivae_scores,
                           ["Vanilla", "Blur", "WhiteNoise", "Salt&Pepper"])

    # --- Timeline reconstruction example ---
    print("\nVisualizing reconstruction timeline for one sample per corruption...")
    for name, corr_fn in corruption_tasks[1:]:  # skip the identity "corruption"
        x_corr = corr_fn(x[:1])
        _, _, z_traj = iterative_recon_mse(
            ivae_enc, ivae_dec, x_corr.view(1, -1),
            n_steps=n_steps//10, lr_eval=lr_eval,
            device=device, beta=cfg['beta'],
            save_latent_traj=True
        )
        plot_reconstruction_timeline(
            ivae_dec, z_traj[:,0,:], x_true=x[0], x_corr=x_corr[0],
            corruption_label=name, num_steps_to_show=5
        )
