In [18]:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from models import BaseVAE
from torch.nn import functional as F
from models.simple_vae import SimpleVAE  #The model that we simulate

In [19]:
CONFIG = {
    "n_samples": 4000,
    "latent_dim": 2,
    "hidden_dims": [64, 64, 64, 64], 
    "batch_size": 64,
    "num_epochs": 100,
    "learning_rate": 1e-3,
    "beta_values": [0.0, 0.5, 1.0, 5.0],
    "plot_dir": "plots",
    "seed": 42,
}


In [None]:
# Data
def generate_swiss_roll(n_samples):
    """
    Generate Swiss Roll(3D)
    """
    s = np.random.uniform(1.5 * np.pi, 4.5 * np.pi, n_samples).astype(np.float32)
    t = np.random.uniform(0, 15, n_samples).astype(np.float32)
    
    x = s * np.cos(s)
    y = t
    z = s * np.sin(s)
    coords = np.stack([x, y, z], axis=1)
    
    # Normalize to [-1, 1]
    mins, maxs = coords.min(axis=0), coords.max(axis=0)
    data = 2.0 * (coords - mins) / (maxs - mins + 1e-8) - 1.0
    
    s_norm = (s - s.min()) / (s.max() - s.min() + 1e-8)
    t_norm = (t - t.min()) / (t.max() - t.min() + 1e-8)
    
    return data.astype(np.float32), s_norm, t_norm, s, t


def generate_three_circles(n_samples):
    """
    Generate three nested circles (2D).
    """
    n_per_circle = n_samples // 3
    
    radii = [0.3, 0.6, 0.9]
    data_list = []
    labels = []
    
    for i, r in enumerate(radii):
        theta = np.random.uniform(0, 2 * np.pi, n_per_circle)
        noise = np.random.normal(0, 0.02, n_per_circle)
        x = (r + noise) * np.cos(theta)
        y = (r + noise) * np.sin(theta)
        data_list.append(np.stack([x, y], axis=1))
        labels.extend([i] * n_per_circle)
    
    data = np.vstack(data_list).astype(np.float32)
    labels = np.array(labels)
    
    # Normalize theta for coloring (angle around circle)
    theta_all = np.arctan2(data[:, 1], data[:, 0])
    theta_norm = (theta_all + np.pi) / (2 * np.pi)
    
    return data, theta_norm, labels


def generate_smile_face(n_samples=500):
    """
    Generate a simple smile face (2D).
    """
    n_face = n_samples // 2
    n_eye = n_samples // 8
    n_smile = n_samples - n_face - 2 * n_eye
    
    data_list = []
    labels = []
    
    # Face outline (circle)
    theta = np.random.uniform(0, 2 * np.pi, n_face)
    noise = np.random.normal(0, 0.02, n_face)
    x = (0.9 + noise) * np.cos(theta)
    y = (0.9 + noise) * np.sin(theta)
    data_list.append(np.stack([x, y], axis=1))
    labels.extend([0] * n_face)
    
    # Left eye (small filled circle)
    r = np.random.uniform(0, 0.1, n_eye)
    theta = np.random.uniform(0, 2 * np.pi, n_eye)
    x = -0.35 + r * np.cos(theta)
    y = 0.3 + r * np.sin(theta)
    data_list.append(np.stack([x, y], axis=1))
    labels.extend([1] * n_eye)
    
    # Right eye (small filled circle)
    r = np.random.uniform(0, 0.1, n_eye)
    theta = np.random.uniform(0, 2 * np.pi, n_eye)
    x = 0.35 + r * np.cos(theta)
    y = 0.3 + r * np.sin(theta)
    data_list.append(np.stack([x, y], axis=1))
    labels.extend([2] * n_eye)
    
    # Smile (arc)
    theta = np.random.uniform(-0.8 * np.pi, -0.2 * np.pi, n_smile)
    noise = np.random.normal(0, 0.02, n_smile)
    x = (0.5 + noise) * np.cos(theta)
    y = (0.5 + noise) * np.sin(theta) + 0.1
    data_list.append(np.stack([x, y], axis=1))
    labels.extend([3] * n_smile)
    
    data = np.vstack(data_list).astype(np.float32)
    labels = np.array(labels)
    
    # Color by angle from center
    theta_all = np.arctan2(data[:, 1], data[:, 0])
    theta_norm = (theta_all + np.pi) / (2 * np.pi)
    
    return data, theta_norm, labels


def generate_single_circle(n_samples=500):
    """
    Generate a single circle (2D).
    """
    theta = np.random.uniform(0, 2 * np.pi, n_samples).astype(np.float32)
    noise = np.random.normal(0, 0.02, n_samples).astype(np.float32)
    
    x = (1.0 + noise) * np.cos(theta)
    y = (1.0 + noise) * np.sin(theta)
    data = np.stack([x, y], axis=1).astype(np.float32)
    
    theta_norm = (theta - theta.min()) / (theta.max() - theta.min() + 1e-8)
    
    return data, theta_norm, theta


In [21]:
# Training
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_vae(data, beta, config):
    device = get_device()

    # Turn numpy data to TensorDataset
    dataset = TensorDataset(torch.from_numpy(data).float())
    loader = DataLoader(dataset,
                        batch_size=config["batch_size"],
                        shuffle=True)

    # Initialize model
    model = SimpleVAE(
        input_dim=data.shape[1],
        latent_dim=config["latent_dim"],
        hidden_dims=config.get("hidden_dims", None), 
        beta=beta,
    ).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=config["learning_rate"])
    history = {"loss": [], "recon": [], "kld": [], "beta": []}

    dataset_size = len(dataset)

    for epoch in range(config["num_epochs"]):
        current_beta = beta
        model.beta = current_beta  # keep internal beta

        model.train()
        epoch_loss, epoch_recon, epoch_kld = 0.0, 0.0, 0.0

        for (batch,) in loader:
            batch = batch.to(device)

            # Forward: [recons, input, mu, log_var]
            recons, x_in, mu, log_var = model(batch)

            # Minibatch weight for KL term (M_N ≈ B / N)
            B = batch.size(0)
            kld_weight = B / dataset_size

            loss_dict = model.loss_function(
                recons, x_in, mu, log_var, M_N=kld_weight
            )
            loss = loss_dict["loss"]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate epoch stats
            epoch_loss += loss.item() * B
            epoch_recon += loss_dict["Reconstruction_Loss"].item() * B
            epoch_kld += (-loss_dict["KLD"].item()) * B  

        n_total = dataset_size
        history["loss"].append(epoch_loss / n_total)
        history["recon"].append(epoch_recon / n_total)
        history["kld"].append(epoch_kld / n_total)
        history["beta"].append(current_beta)

    return model, history

def get_latent_embeddings(model, data):
    """Get latent means for all data points."""
    device = get_device()
    model.eval()
    with torch.no_grad():
        mu, _ = model.encode(torch.from_numpy(data).to(device))
        return mu.cpu().numpy()

In [None]:
def ensure_plot_dir(plot_dir):
    os.makedirs(plot_dir, exist_ok=True)


def _scatter_2d(ax, data, labels=None, colors=None, cmap='viridis',
                s=10, alpha=0.7, add_legend=True):
    if labels is not None and len(np.unique(labels)) <= 10:
        for lab in np.unique(labels):
            mask = labels == lab
            ax.scatter(data[mask, 0], data[mask, 1], s=s, alpha=alpha, label=f'Class {lab}')
        if add_legend:
            ax.legend()
    else:
        ax.scatter(data[:, 0], data[:, 1], c=colors, cmap=cmap, s=s, alpha=alpha)


def _train_across_betas(data, config):
    models, histories = [], []
    for beta in config["beta_values"]:
        model, history = train_vae(data, beta, config)
        models.append(model)
        histories.append(history)
    return models, histories


# 2D Plots #
def plot_2d_data(data, colors, labels, title, plot_dir, filename):
    """Plot 2D dataset."""
    ensure_plot_dir(plot_dir)
    fig, ax = plt.subplots(figsize=(6, 6))

    _scatter_2d(ax, data, labels=labels, colors=colors)
    if labels is None or len(np.unique(labels)) > 10:
        sc = ax.collections[0]
        plt.colorbar(sc, ax=ax)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title)
    ax.set_aspect('equal')

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, filename), dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_2d_latent_comparison(models, beta_values, data, colors, labels, plot_dir, prefix):
    """Compare latent spaces for 2D input data."""
    ensure_plot_dir(plot_dir)
    n_betas = len(beta_values)
    fig, axes = plt.subplots(1, n_betas + 1, figsize=(4 * (n_betas + 1), 4))

    # Original data
    _scatter_2d(axes[0], data, labels=labels, colors=colors, add_legend=False)
    axes[0].set_title('Original Data')
    axes[0].set_aspect('equal')

    # Latent embeddings for each beta
    for i, (model, beta) in enumerate(zip(models, beta_values), start=1):
        mu = get_latent_embeddings(model, data)
        _scatter_2d(axes[i], mu, labels=labels, colors=colors, add_legend=False)
        axes[i].set_title(f'β={beta}')
        axes[i].set_xlabel('$z_1$')
        axes[i].set_ylabel('$z_2$')
        axes[i].set_aspect('equal')

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"{prefix}_latent_comparison.png"),
                dpi=150, bbox_inches="tight")
    plt.close(fig)

def plot_2d_reconstruction(models, beta_values, data, colors, labels, plot_dir, prefix):
    """Show reconstruction quality for 2D data."""
    ensure_plot_dir(plot_dir)
    device = get_device()
    n_betas = len(beta_values)

    fig, axes = plt.subplots(1, n_betas + 1, figsize=(4 * (n_betas + 1), 4))

    # Original data
    _scatter_2d(axes[0], data, labels=labels, colors=colors, add_legend=False)
    axes[0].set_title('Original')
    axes[0].set_aspect('equal')

    for i, (model, beta) in enumerate(zip(models, beta_values), start=1):
        model.eval()
        with torch.no_grad():
            x = torch.from_numpy(data).float().to(device)
            x_recon, _, _, _ = model(x)   # unpack 4 outputs
            x_recon = x_recon.cpu().numpy()

        # Reconstruction
        _scatter_2d(axes[i], x_recon, labels=labels, colors=colors, add_legend=False)
        axes[i].set_title(f'Recon β={beta}')
        axes[i].set_aspect('equal')

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"{prefix}_reconstruction.png"),
                dpi=150, bbox_inches="tight")
    plt.close(fig)


# 3D Plot
def plot_swiss_roll_3d(data, s_norm, t_norm, plot_dir):
    """Plot the 3D Swiss roll."""
    ensure_plot_dir(plot_dir)
    fig = plt.figure(figsize=(12, 5))

    ax1 = fig.add_subplot(121, projection='3d')
    ax1.scatter(data[:, 0], data[:, 1], data[:, 2],
                c=s_norm, cmap='viridis', s=5, alpha=0.7)
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax1.set_title('Colored by s (angle)')

    ax2 = fig.add_subplot(122, projection='3d')
    ax2.scatter(data[:, 0], data[:, 1], data[:, 2],
                c=t_norm, cmap='plasma', s=5, alpha=0.7)
    ax2.set_xlabel('X')
    ax2.set_ylabel('Y')
    ax2.set_zlabel('Z')
    ax2.set_title('Colored by t (height)')

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, "swiss_roll_3d.png"),
                dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_latent_comparison(models, beta_values, data, s_norm, t_norm, plot_dir):
    """Compare latent spaces for different beta values"""
    ensure_plot_dir(plot_dir)
    n_betas = len(beta_values)
    fig, axes = plt.subplots(n_betas, 2, figsize=(10, 4 * n_betas))

    if n_betas == 1:
        axes = axes.reshape(1, -1)

    for i, (model, beta) in enumerate(zip(models, beta_values)):
        mu = get_latent_embeddings(model, data)

        sc1 = axes[i, 0].scatter(mu[:, 0], mu[:, 1],
                                 c=s_norm, cmap='viridis', s=3, alpha=0.7)
        axes[i, 0].set_xlabel('$z_1$')
        axes[i, 0].set_ylabel('$z_2$')
        axes[i, 0].set_title(f'β={beta} (colored by s)')
        axes[i, 0].set_aspect('equal', adjustable='box')
        fig.colorbar(sc1, ax=axes[i, 0])

        sc2 = axes[i, 1].scatter(mu[:, 0], mu[:, 1],
                                 c=t_norm, cmap='plasma', s=3, alpha=0.7)
        axes[i, 1].set_xlabel('$z_1$')
        axes[i, 1].set_ylabel('$z_2$')
        axes[i, 1].set_title(f'β={beta} (colored by t)')
        axes[i, 1].set_aspect('equal', adjustable='box')
        fig.colorbar(sc2, ax=axes[i, 1])

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, "latent_comparison.png"),
                dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_reconstruction_quality(model, data, s_norm, beta, plot_dir):
    """Show reconstructed data for Swiss roll (3D), no original / prior samples."""
    ensure_plot_dir(plot_dir)
    device = get_device()
    model.eval()

    with torch.no_grad():
        x = torch.from_numpy(data).float().to(device)
        # assuming model(x) returns [recons, input, mu, log_var]
        x_recon, _, _, _ = model(x)
        x_recon = x_recon.cpu().numpy()

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x_recon[:, 0], x_recon[:, 1], x_recon[:, 2],
               c=s_norm, cmap='viridis', s=3, alpha=0.7)
    ax.set_title(f'Reconstructed (β={beta})')

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"reconstruction_beta_{beta}.png"),
                dpi=150, bbox_inches="tight")
    plt.close(fig)

    # still compute and print MSE vs original
    mse = np.mean((data - x_recon) ** 2)

# def plot_reconstruction_quality(model, data, s_norm, beta, plot_dir):
#     """Show original vs reconstructed data for Swiss roll."""
#     ensure_plot_dir(plot_dir)
#     device = get_device()
#     model.eval()

#     with torch.no_grad():
#         x = torch.from_numpy(data).float().to(device)
#         x_recon, _, _, _ = model(x)   # unpack 4 outputs
#         x_recon = x_recon.cpu().numpy()

#         z_prior = torch.randn(2000, model.latent_dim).to(device)
#         x_samples = model.decode(z_prior).cpu().numpy()

#     fig = plt.figure(figsize=(18, 5))

#     ax1 = fig.add_subplot(131, projection='3d')
#     ax1.scatter(data[:, 0], data[:, 1], data[:, 2],
#                 c=s_norm, cmap='viridis', s=3, alpha=0.7)
#     ax1.set_title('Original')

#     ax2 = fig.add_subplot(132, projection='3d')
#     ax2.scatter(x_recon[:, 0], x_recon[:, 1], x_recon[:, 2],
#                 c=s_norm, cmap='viridis', s=3, alpha=0.7)
#     ax2.set_title(f'Reconstructed (β={beta})')

#     ax3 = fig.add_subplot(133, projection='3d')
#     ax3.scatter(x_samples[:, 0], x_samples[:, 1], x_samples[:, 2],
#                 c='gray', s=3, alpha=0.5)
#     ax3.set_title('Prior Samples')

#     plt.tight_layout()
#     fig.savefig(os.path.join(plot_dir, f"reconstruction_beta_{beta}.png"),
#                 dpi=150, bbox_inches="tight")
#     plt.close(fig)

#     mse = np.mean((data - x_recon) ** 2)
#     print(f"  β={beta}: MSE={mse:.6f}")



def run_2d_experiment(name, data, colors, labels, config):
    """Run VAE experiment on a 2D dataset."""
    plot_dir = os.path.join(config["plot_dir"], name)

    # Plot original data
    plot_2d_data(data, colors, labels, f"{name} - Original", plot_dir, "original.png")

    # Train models across β
    models, histories = _train_across_betas(data, config)

    # Generate plots
    plot_2d_latent_comparison(models, config["beta_values"], data, colors, labels,
                              plot_dir, name)
    plot_2d_reconstruction(models, config["beta_values"], data, colors, labels,
                           plot_dir, name)
    return models, histories


def run_swiss_roll_experiment(config):
    data, s_norm, t_norm, s, t = generate_swiss_roll(config["n_samples"])

    plot_dir = os.path.join(config["plot_dir"], "swiss_roll")
    plot_swiss_roll_3d(data, s_norm, t_norm, plot_dir)

    # Train models across β
    models, histories = _train_across_betas(data, config)

    # Plots
    plot_latent_comparison(models, config["beta_values"], data, s_norm, t_norm, plot_dir)

    for model, beta in zip(models, config["beta_values"]):
        plot_reconstruction_quality(model, data, s_norm, beta, plot_dir)

    return models, histories

In [25]:
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

# 2D datasets
data, colors, theta = generate_single_circle(CONFIG["n_samples"])
run_2d_experiment("single_circle", data, colors, None, CONFIG)

data, colors, labels = generate_three_circles(CONFIG["n_samples"])
run_2d_experiment("three_circles", data, colors, labels, CONFIG)

data, colors, labels = generate_smile_face(CONFIG["n_samples"])
run_2d_experiment("smile_face", data, colors, labels, CONFIG)

# 3D dataset
run_swiss_roll_experiment(CONFIG)

  Shape: (4000, 3)
  β=0.0: MSE=0.007276
  β=0.5: MSE=0.008950
  β=1.0: MSE=0.018183
  β=5.0: MSE=0.060174


([SimpleVAE(
    (encoder): Sequential(
      (0): Sequential(
        (0): Linear(in_features=3, out_features=64, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
      )
      (2): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
      )
      (3): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
      )
    )
    (fc_mu): Linear(in_features=64, out_features=2, bias=True)
    (fc_logvar): Linear(in_features=64, out_features=2, bias=True)
    (decoder_input): Linear(in_features=2, out_features=64, bias=True)
    (decoder): Sequential(
      (0): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
        (1): ReLU()
      )
      (2): Sequential(
        (0): L