In [None]:
# ============== Cell 1: Import All Required Packages ==============
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import sys
import os
import gc
import time
import pickle
import subprocess
import random
import threading
import psutil
import cv2
from concurrent.futures import ThreadPoolExecutor

# Add project directory to sys.path
project_dir = "/workspace/kswgd"
if project_dir not in sys.path:
    sys.path.append(project_dir)

# Scientific computing
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import MiniBatchDictionaryLearning
from scipy.sparse.linalg import eigsh
from scipy.linalg import eig
from scipy import linalg

# PyTorch related
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm, trange

# Diffusers
from diffusers import DiffusionPipeline

# Datasets
from datasets import load_dataset

# Real-ESRGAN and GFPGAN
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from gfpgan import GFPGANer

# FID evaluation
try:
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-fid", "-q"])
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3

# LPIPS
try:
    import lpips
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lpips", "-q"])
    import lpips

# Custom kernel functions
from grad_ker1 import grad_ker1
from K_tar_eval import K_tar_eval

# Try GPU version
try:
    import cupy as cp
    from grad_ker1_gpu import grad_ker1 as grad_ker1_gpu
    from K_tar_eval_gpu import K_tar_eval as K_tar_eval_gpu
    GPU_KSWGD = True
    print("‚úì GPU KSWGD backend available (CuPy)")
except Exception as e:
    cp = None
    grad_ker1_gpu = None
    K_tar_eval_gpu = None
    GPU_KSWGD = False
    print(f"‚úó GPU KSWGD backend not available: {e}")

# ==================== GPU Selection: Use the least utilized GPU ====================
def get_gpu_utilization():
    """Ask nvidia-smi for the load and memory figures of every GPU it lists.

    Returns:
        A list with one dict per parsed output row (keys ``id``,
        ``util_percent``, ``mem_used_mb``, ``mem_total_mb``, ``mem_percent``),
        or ``None`` when nvidia-smi exits with a non-zero status or when
        running it / parsing its output raises.
    """
    query_cmd = [
        'nvidia-smi',
        '--query-gpu=index,utilization.gpu,memory.used,memory.total',
        '--format=csv,noheader,nounits',
    ]
    try:
        proc = subprocess.run(query_cmd, capture_output=True, text=True, timeout=10)
        if proc.returncode != 0:
            return None

        records = []
        for row in proc.stdout.strip().split('\n'):
            fields = [field.strip() for field in row.split(',')]
            if len(fields) < 4:
                # Blank or truncated row: nothing to record.
                continue
            gpu_index = int(fields[0])
            load = float(fields[1])
            used = float(fields[2])
            total = float(fields[3])
            records.append({
                'id': gpu_index,
                'util_percent': load,
                'mem_used_mb': used,
                'mem_total_mb': total,
                # Guard against a zero total so the ratio never divides by zero.
                'mem_percent': (used / total) * 100 if total > 0 else 0,
            })
        return records
    except Exception as e:
        # Best-effort probe: any failure (missing binary, timeout, unparsable
        # value) is reported and turned into "no information".
        print(f"  ‚ö†Ô∏è Could not get GPU utilization: {e}")
        return None

def select_best_gpu():
    """Return the id of the GPU with the lowest weighted load score.

    The score is ``0.6 * util_percent + 0.4 * mem_percent`` computed from
    ``get_gpu_utilization()``. When that probe yields nothing, 0 is returned.
    A small utilization table is printed along the way.
    """
    stats = get_gpu_utilization()

    if not stats:
        print("  ‚ö†Ô∏è Cannot detect GPU utilization, defaulting to GPU 0")
        return 0

    print("\nüîç GPU Utilization Check:")
    print(f"  {'GPU':<6} {'Util %':<10} {'Mem Used':<12} {'Mem Total':<12} {'Mem %':<10}")
    print("  " + "-" * 50)

    for info in stats:
        print(f"  GPU {info['id']:<3} {info['util_percent']:<10.1f} {info['mem_used_mb']:<12.0f} {info['mem_total_mb']:<12.0f} {info['mem_percent']:<10.1f}")

    def load_score(entry):
        """Combined score (weighted: 60% utilization, 40% memory); lower is better."""
        return 0.6 * entry['util_percent'] + 0.4 * entry['mem_percent']

    best_gpu = min(stats, key=load_score)

    print(f"\n  ‚úì Selected GPU {best_gpu['id']} (Util: {best_gpu['util_percent']:.1f}%, Mem: {best_gpu['mem_percent']:.1f}%)")
    return best_gpu['id']

# Check GPU availability
print(f"\nCUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

# Auto-select best GPU based on utilization
if torch.cuda.is_available():
    SELECTED_GPU = select_best_gpu()
    # The id comes from nvidia-smi, whose numbering is not guaranteed to match
    # the device ordinals torch exposes to this process (for example when
    # CUDA_VISIBLE_DEVICES hides or reorders GPUs). Reject ids torch cannot
    # address instead of failing inside set_device().
    if not 0 <= SELECTED_GPU < torch.cuda.device_count():
        print(f"  GPU {SELECTED_GPU} is not a valid CUDA device index for this process, falling back to GPU 0")
        SELECTED_GPU = 0
    torch.cuda.set_device(SELECTED_GPU)
    device = torch.device(f"cuda:{SELECTED_GPU}")
else:
    # Without CUDA, set_device() would raise; keep `device` defined so the
    # later cells fail (or run) at the step that actually needs a GPU.
    SELECTED_GPU = None
    device = torch.device("cpu")
    print("  CUDA is not available, falling back to CPU")
print(f"\nüéØ Global device set to: {device}")
if SELECTED_GPU is not None:
    print("   All subsequent cells will use this GPU.")

In [None]:
# ============== Cell 2: All Function Definitions ==============

from datetime import datetime

# ==================== Data Processing ====================
# Root folder for downloaded data.
DATA_DIR = "/workspace/kswgd/data"
# Passed as `cache_dir` to `load_dataset` in run_full_experiment.
CELEBAHQ_CACHE = os.path.join(DATA_DIR, "CelebA-HQ")
# Destination of the pickled latent/autoencoder caches (see get_mlp_cache_path).
CACHE_DIR = "/workspace/kswgd/cache"
# Create every directory up front; exist_ok makes re-running the cell harmless.
os.makedirs(CELEBAHQ_CACHE, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs('/workspace/kswgd/figures_test', exist_ok=True)

# VAE helper functions
def _to_vae_range(x):
    """Map values from the [0, 1] range onto [-1, 1]."""
    doubled = x * 2.0
    return doubled - 1.0

def _from_vae_range(x):
    """Map values from [-1, 1] back onto [0, 1], clamping anything outside."""
    rescaled = (x + 1.0) * 0.5
    return torch.clamp(rescaled, min=0.0, max=1.0)

# Image preprocessing transform
# Applied to each dataset image before VAE encoding in run_full_experiment:
# resize to 256x256, then convert to a tensor. The result is fed through
# _to_vae_range, which expects values in [0, 1].
transform_celebahq = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])

# ==================== MLP Latent AutoEncoder ====================
# Tag stored on every model and embedded in cache file names; bump it whenever
# the layer layout below changes so stale caches are not reused.
MLP_ARCHITECTURE_VERSION = "v3_configurable_depth"


class LatentAutoEncoder(torch.nn.Module):
    """Symmetric MLP autoencoder compressing ``input_dim`` vectors to ``latent_dim``.

    Both halves consist of an input projection, ``num_layers - 2`` hidden
    blocks (Linear -> LayerNorm -> GELU, with Dropout after every even-indexed
    block) and an output projection. Only the encoder applies Dropout right
    after its input projection.
    """

    def __init__(self, input_dim=1024, latent_dim=64, hidden_dim=512, num_layers=5, dropout_rate=0.1):
        super().__init__()
        self.architecture_version = MLP_ARCHITECTURE_VERSION
        self.input_dim, self.latent_dim = input_dim, latent_dim
        self.hidden_dim, self.num_layers = hidden_dim, num_layers
        self.dropout_rate = dropout_rate

        nn = torch.nn

        def hidden_stack():
            """Return a fresh list of the ``num_layers - 2`` hidden blocks."""
            stack = []
            for depth in range(num_layers - 2):
                stack += [
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.GELU(),
                ]
                if depth % 2 == 0:
                    stack.append(nn.Dropout(dropout_rate))
            return stack

        # Arguments are evaluated left to right, so the Linear layers are
        # created in the same order as a sequential list build would do.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            *hidden_stack(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            *hidden_stack(),
            nn.Linear(hidden_dim, input_dim),
        )

    def encode(self, x):
        """Project ``x`` into the latent space."""
        return self.encoder(x)

    def decode(self, z):
        """Reconstruct an input-space vector from latent code ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for ``x``."""
        code = self.encode(x)
        return self.decode(code), code

def get_mlp_cache_path(config):
    """Return the cache file path for the latents/autoencoder of ``config``.

    The file name encodes the sample count, reduced dimension, hidden width,
    depth, LPIPS weight, epoch count and the architecture version, so
    different settings map to different files.
    """
    depth = config.get('mlp_num_layers', 5)
    perceptual_w = config.get('lpips_weight', 0.7)
    filename = (
        f"vae_n{config['max_samples']}"
        f"_mlp{config['reduced_dim']}"
        f"_h{config['mlp_hidden_dim']}"
        f"_L{depth}"
        f"_lpips{perceptual_w}"
        f"_e{config['mlp_epochs']}"
        f"_{MLP_ARCHITECTURE_VERSION}.pkl"
    )
    return os.path.join(CACHE_DIR, filename)

def validate_cached_model(cache_data, config):
    """Tell whether a cached autoencoder matches the requested configuration.

    A cache entry without a model is accepted as-is. Otherwise the stored
    architecture version, latent size, hidden width and depth must all agree
    with ``config`` (depth defaults to 5 on both sides).
    """
    model = cache_data.get('latent_ae', None)
    if model is None:
        return True
    if cache_data.get('architecture_version', 'unknown') != MLP_ARCHITECTURE_VERSION:
        return False
    if model.latent_dim != config['reduced_dim']:
        return False
    if model.hidden_dim != config['mlp_hidden_dim']:
        return False
    expected_depth = config.get('mlp_num_layers', 5)
    return not (getattr(model, 'num_layers', 5) != expected_depth)

def train_latent_autoencoder(Z_flat, latent_dim=64, hidden_dim=512, num_layers=5, epochs=100, batch_size=512, lr=1e-3, use_perceptual_loss=True, lpips_weight=0.7, latent_shape=(4, 16, 16)):
    """Fit a LatentAutoEncoder on flattened latents and return the reduced codes.

    Args:
        Z_flat: 2-D NumPy array, one flattened latent per row.
        latent_dim: Size of the bottleneck.
        hidden_dim: Width of the hidden layers.
        num_layers: Depth passed to ``LatentAutoEncoder``.
        epochs: Number of passes over the data (also the cosine schedule length).
        batch_size: Mini-batch size.
        lr: AdamW learning rate.
        use_perceptual_loss: Add an LPIPS (VGG) term to the MSE loss.
        lpips_weight: Weight of the LPIPS term; 0 disables it.
        latent_shape: ``(C, H, W)`` used to view each row as an image for
            LPIPS; only the first three channels are compared. The default
            reproduces the historical hard-coded reshape.
            NOTE(review): when a row does not hold exactly C*H*W values the
            view still succeeds if the batch size divides evenly, but it then
            splits rows into several pseudo-images — confirm this matches the
            latent shape actually produced by the VAE.

    Returns:
        ``(model, Z_reduced, MLP_ARCHITECTURE_VERSION)`` where ``model`` is
        left in eval mode and ``Z_reduced`` is the NumPy encoding of ``Z_flat``.
    """
    print(f"\n=== Training Latent AutoEncoder ({Z_flat.shape[1]} -> {latent_dim}) ===")
    model = LatentAutoEncoder(input_dim=Z_flat.shape[1], latent_dim=latent_dim, hidden_dim=hidden_dim, num_layers=num_layers).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    lpips_loss_fn = None
    if use_perceptual_loss and lpips_weight > 0:
        try:
            lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)
            lpips_loss_fn.eval()
        except Exception as exc:
            # Best-effort: keep training with MSE only, but say so. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit through.
            print(f"  LPIPS unavailable ({exc}); training with MSE loss only")
            use_perceptual_loss = False

    Z_tensor = torch.from_numpy(Z_flat).float().to(device)
    dataloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(Z_tensor), batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for (batch,) in dataloader:
            optimizer.zero_grad()
            recon, z = model(batch)
            mse_loss = torch.nn.functional.mse_loss(recon, batch)
            if use_perceptual_loss and lpips_loss_fn is not None:
                batch_3ch = batch.view(-1, *latent_shape)[:, :3, :, :]
                recon_3ch = recon.view(-1, *latent_shape)[:, :3, :, :]
                # Rescale each tensor to [-1, 1] using its own batch-wide min/max.
                batch_norm = 2.0 * (batch_3ch - batch_3ch.min()) / (batch_3ch.max() - batch_3ch.min() + 1e-8) - 1.0
                recon_norm = 2.0 * (recon_3ch - recon_3ch.min()) / (recon_3ch.max() - recon_3ch.min() + 1e-8) - 1.0
                loss = mse_loss + lpips_weight * lpips_loss_fn(batch_norm, recon_norm).mean()
            else:
                loss = mse_loss
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            total_loss += loss.item() * batch.size(0)
        scheduler.step()
        if (epoch + 1) % 50 == 0:
            print(f"  Epoch {epoch+1:3d}/{epochs}: Loss = {total_loss/len(Z_tensor):.6f}")

    model.eval()
    with torch.no_grad():
        Z_reduced = model.encode(Z_tensor).cpu().numpy()
    return model, Z_reduced, MLP_ARCHITECTURE_VERSION

# ==================== Real-ESRGAN + GFPGAN ====================
# Checkpoint files read by create_upscaler(): the first is handed to
# RealESRGANer, the second to GFPGANer.
model_path = '/workspace/kswgd/weights/RealESRGAN_x4plus.pth'
gfpgan_path = '/workspace/kswgd/weights/GFPGANv1.3.pth'

def create_upscaler(gpu_id):
    """Build the 4x Real-ESRGAN upsampler and the GFPGAN enhancer that wraps it.

    Args:
        gpu_id: Forwarded to ``RealESRGANer`` as ``gpu_id``.

    Returns:
        ``(upsampler, face_enhancer)``; the face enhancer uses the upsampler
        as its background upsampler.
    """
    rrdb_backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=rrdb_backbone,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=gpu_id,
    )
    face_enhancer = GFPGANer(
        model_path=gfpgan_path,
        upscale=4,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )
    return upsampler, face_enhancer

def process_single_image(args):
    """Upscale one image and return it as an RGB array.

    Args:
        args: Tuple ``(img, gpu_id, use_face_enhance, face_enhancers,
            upsamplers)``. ``img`` is a PIL image or an array that is either
            HWC or CHW with three channels; ``gpu_id`` indexes into the two
            sequences of models.

    Returns:
        The enhanced image as returned by the selected model, converted from
        BGR back to RGB.
    """
    img, gpu_id, use_face_enhance, face_enhancers, upsamplers = args
    face_enhancer = face_enhancers[gpu_id]
    upsampler = upsamplers[gpu_id]

    if isinstance(img, Image.Image):
        # Force three channels so grayscale/palette images survive RGB->BGR.
        img_np = np.array(img.convert("RGB"))
    else:
        # CHW (three leading channels) is moved to HWC; anything else is kept.
        img_np = np.transpose(img, (1, 2, 0)) if img.ndim == 3 and img.shape[0] == 3 else img
        if img_np.max() <= 1.0:
            img_np = (img_np * 255).astype(np.uint8)
        elif np.issubdtype(img_np.dtype, np.floating):
            # Float data above 1.0 is taken to be on a 0-255 scale already.
            # It must still become uint8: cv2.cvtColor does not accept
            # float64 input.
            img_np = np.clip(img_np, 0, 255).astype(np.uint8)

    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    if use_face_enhance:
        _, _, output_bgr = face_enhancer.enhance(img_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    else:
        output_bgr, _ = upsampler.enhance(img_bgr, outscale=4)
    return cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)

def upscale_images(images_list, face_enhancers, upsamplers, use_face_enhance=True, desc="Upscaling"):
    """Run ``process_single_image`` over ``images_list`` with a progress bar.

    Every image is processed with the model pair at index 0 of
    ``face_enhancers`` / ``upsamplers``. Returns the list of results in input
    order.
    """
    results = []
    for image in tqdm(images_list, desc=desc):
        job = (image, 0, use_face_enhance, face_enhancers, upsamplers)
        results.append(process_single_image(job))
    return results

# ==================== Decode Latents to Images ====================
def decode_latents_to_images(flat_latents_std, Z_std, Z_mean, latent_ae, vae, full_latent_shape, vae_scaling, target_device):
    """Turn standardized latent codes back into RGB images.

    Steps: undo the standardization with ``Z_std`` / ``Z_mean``, expand
    through ``latent_ae.decode`` when an autoencoder is given, reshape to
    ``full_latent_shape``, decode with ``vae`` and map the result to [0, 1].

    Returns:
        A CPU tensor of decoded images.
    """
    restored = flat_latents_std * Z_std + Z_mean

    if latent_ae is None:
        # No MLP stage: the codes already live in the VAE latent space.
        vae_inputs_np = restored
    else:
        latent_ae.eval()
        latent_ae = latent_ae.to(target_device)
        with torch.no_grad():
            reduced = torch.from_numpy(restored).float().to(target_device)
            vae_inputs_np = latent_ae.decode(reduced).cpu().numpy()

    vae_inputs = torch.from_numpy(vae_inputs_np).float()
    vae_inputs = vae_inputs.view(-1, *full_latent_shape).to(target_device)

    vae.eval()
    with torch.no_grad():
        raw = vae.decode(vae_inputs / vae_scaling).sample
        rgb = _from_vae_range(raw)
    return rgb.cpu()

# ==================== Eigendecomposition ====================
def compute_eigendecomposition(rw_kernel, k_eig):
    """Return the leading eigenpairs of a symmetric kernel matrix.

    Args:
        rw_kernel: Square NumPy array, expected to be symmetric.
        k_eig: Number of eigenpairs requested.

    Returns:
        ``(eigenvalues, eigenvectors)`` sorted by descending eigenvalue, with
        eigenvectors as columns. ``k_eig`` pairs are returned when
        ``k_eig < n``; otherwise all ``n`` pairs are returned.
    """
    n_samples = rw_kernel.shape[0]

    # torch.lobpcg rejects problems whose row count is below 3 * k, so the
    # iterative solver is only usable when the matrix is large enough
    # relative to the request.
    if k_eig < n_samples and n_samples >= 3 * k_eig:
        rw_kernel_torch = torch.from_numpy(rw_kernel).float().to(device)
        X0 = torch.randn(n_samples, k_eig, device=device)
        eigenvalues, eigenvectors = torch.lobpcg(rw_kernel_torch, k=k_eig, X=X0, largest=True)
        return eigenvalues.cpu().numpy(), eigenvectors.cpu().numpy()

    # Dense fallback: eigh sorts ascending, so reverse to get largest first.
    lambda_ns_raw, phi_raw = np.linalg.eigh(rw_kernel)
    lambda_ns, phi = lambda_ns_raw[::-1], phi_raw[:, ::-1]
    if k_eig < n_samples:
        # Honour the requested count on the path lobpcg could not serve.
        lambda_ns, phi = lambda_ns[:k_eig], phi[:, :k_eig]
    return lambda_ns, phi

# ==================== GPU Accelerated Functions ====================
def pairwise_distances_gpu(X, metric="euclidean", batch_size=5000):
    """Compute the full pairwise distance matrix of the rows of ``X`` with torch.

    Args:
        X: 2-D NumPy array of points.
        metric: ``"sqeuclidean"`` squares the distances; every other value
            yields plain Euclidean distances.
        batch_size: Tile edge length; inputs with more rows than this are
            processed tile by tile into a float32 NumPy matrix.

    Returns:
        An ``(n, n)`` NumPy array of distances.
    """
    n = X.shape[0]
    points = torch.from_numpy(X).float().to(device)
    want_squared = metric == "sqeuclidean"

    if n <= batch_size:
        # Small enough for a single cdist call.
        whole = torch.cdist(points, points, p=2)
        if want_squared:
            whole = whole ** 2
        return whole.cpu().numpy()

    dist_matrix = np.zeros((n, n), dtype=np.float32)
    for row_start in range(0, n, batch_size):
        row_stop = min(row_start + batch_size, n)
        for col_start in range(0, n, batch_size):
            col_stop = min(col_start + batch_size, n)
            tile = torch.cdist(points[row_start:row_stop], points[col_start:col_stop], p=2)
            if want_squared:
                tile = tile ** 2
            dist_matrix[row_start:row_stop, col_start:col_stop] = tile.cpu().numpy()
    return dist_matrix

def compute_kernel_matrix_gpu(X_tar, batch_size=5000):
    """Build the Gaussian kernel quantities used by the sampler.

    The bandwidth is ``median(d^2) / (2 * log(n + 1))`` floored at 1e-6.

    Returns:
        ``(eps_kswgd, rw_kernel, p_tar, D_vec)``: the bandwidth as a float and
        three NumPy arrays — the symmetrized normalized kernel, the column
        sums of the raw kernel, and the row sums of the raw kernel divided by
        ``sqrt(p_tar)`` on both sides.
    """
    print("  üöÄ GPU Kernel Matrix...")
    dist_np = pairwise_distances_gpu(X_tar, metric="euclidean", batch_size=batch_size)

    n_points = X_tar.shape[0]
    bandwidth = np.median(dist_np**2) / (2.0 * np.log(n_points + 1))
    eps_kswgd = float(max(bandwidth, 1e-6))

    dist_t = torch.from_numpy(dist_np).float().to(device)
    gauss = torch.exp(-dist_t**2 / (2.0 * eps_kswgd))

    # Normalize by the square roots of the row sums on both sides.
    row_root = torch.sqrt(torch.sum(gauss, dim=1))
    gauss_norm = gauss / (row_root[:, None] * row_root[None, :] + 1e-12)

    # Average of the column-normalized and row-normalized matrices.
    col_mass = torch.sum(gauss_norm, dim=0)
    by_cols = gauss_norm / (col_mass + 1e-12)
    by_rows = gauss_norm / (col_mass[:, None] + 1e-12)
    rw_kernel = 0.5 * (by_cols + by_rows)

    p_tar = torch.sum(gauss, dim=0)
    sqrt_p = torch.sqrt(p_tar + 1e-12)
    D_vec = torch.sum(gauss / sqrt_p[:, None] / sqrt_p[None, :], dim=1)
    return eps_kswgd, rw_kernel.cpu().numpy(), p_tar.cpu().numpy(), D_vec.cpu().numpy()

def compute_edmd_gpu(X_tar, dt_edmd, n_dict_components, batch_size=5000):
    """Estimate eigenfunctions/eigenvalues from one simulated step of the data.

    The samples are advanced by a single noisy step, a sparse dictionary is
    learned on the original samples, and a generalized eigenproblem is solved
    between the dictionary features before and after the step.

    Args:
        X_tar: 2-D NumPy array of samples.
        dt_edmd: Step size of the simulated update.
        n_dict_components: Number of dictionary atoms to learn.
        batch_size: Forwarded to ``pairwise_distances_gpu``.

    Returns:
        ``(efuns, lambda_gen_inv, n_valid)``: real parts of the retained
        eigenfunctions evaluated on ``X_tar``, the matching values of
        ``dt_edmd / (lambda - 1)``, and the number of retained modes. The
        leading mode (largest real eigenvalue) is always dropped.
    """
    print("  üöÄ GPU EDMD...")
    dist2_edmd = pairwise_distances_gpu(X_tar, metric="sqeuclidean", batch_size=batch_size)
    # Kernel width: square root of the median squared distance.
    h_edmd = np.sqrt(np.median(dist2_edmd) + 1e-12)
    dist2_torch = torch.from_numpy(dist2_edmd).float().to(device)
    W_edmd = torch.exp(-dist2_torch / (2.0 * h_edmd ** 2))
    X_tar_torch = torch.from_numpy(X_tar).float().to(device)
    # Kernel-weighted mean of the samples minus the sample itself, over h^2.
    score_edmd = (W_edmd @ X_tar_torch / (torch.sum(W_edmd, dim=1, keepdim=True) + 1e-12) - X_tar_torch) / (h_edmd ** 2)
    # One step along the score plus Gaussian noise scaled by sqrt(2 * dt).
    # The noise is not seeded here, so results vary between calls unless the
    # caller fixes torch's RNG state.
    X_tar_next = X_tar_torch + dt_edmd * score_edmd + torch.randn_like(X_tar_torch) * np.sqrt(2.0 * dt_edmd)
    X_tar_next_np = X_tar_next.cpu().numpy()
    
    dict_model = MiniBatchDictionaryLearning(n_components=n_dict_components, alpha=1e-3, batch_size=256, max_iter=500, random_state=42)
    dict_model.fit(X_tar)
    # Feature matrices with a leading constant column, before and after the step.
    Phi_X = np.hstack([np.ones((X_tar.shape[0], 1)), dict_model.transform(X_tar)])
    Phi_Y = np.hstack([np.ones((X_tar_next_np.shape[0], 1)), dict_model.transform(X_tar_next_np)])
    # Gram matrix with a 1e-3 ridge, and the cross-covariance between steps.
    G_edmd = (Phi_X.T @ Phi_X) / Phi_X.shape[0] + 1e-3 * np.eye(Phi_X.shape[1])
    A_edmd = (Phi_X.T @ Phi_Y) / Phi_X.shape[0]
    # Generalized eigenproblem A v = lambda G v; results may be complex.
    eigvals_edmd, eigvecs_edmd = eig(A_edmd, G_edmd)
    # Sort by descending real part.
    idx = np.argsort(-eigvals_edmd.real)
    efuns_edmd = Phi_X @ eigvecs_edmd[:, idx]
    lambda_ns_edmd = eigvals_edmd.real[idx]
    lambda_gen_inv = np.zeros_like(lambda_ns_edmd)
    # Keep modes (after the first) whose eigenvalue exceeds 1e-6.
    valid = lambda_ns_edmd[1:] > 1e-6
    # NOTE(review): eigenvalues very close to 1.0 make this quotient blow up
    # (or divide by zero); nothing here guards against that — confirm it
    # cannot occur for the modes kept by `valid`.
    lambda_gen_inv[1:][valid] = dt_edmd / (lambda_ns_edmd[1:][valid] - 1.0)
    # Imaginary parts of the eigenfunctions are discarded on return.
    return np.real(efuns_edmd[:, 1:][:, valid]), lambda_gen_inv[1:][valid], np.sum(valid)

# ==================== KSWGD Sampler ====================
def compute_repulsive_force_gpu(particles, h=None):
    """Compute the kernel-gradient interaction term between particles.

    Uses CuPy when the GPU backend is flagged available and CUDA is present,
    NumPy otherwise. With ``h=None`` the bandwidth is derived from the median
    squared pairwise distance.

    NOTE(review): each row accumulates ``K_ij * (x_j - x_i) / h^2``, i.e. a
    vector pointing towards the other particles; check this sign against how
    the caller combines the result.

    Returns:
        An array shaped like ``particles`` (all zeros for 0 or 1 particles).
    """
    on_gpu = GPU_KSWGD and torch.cuda.is_available()
    xp = cp if on_gpu else np

    count = particles.shape[0]
    if count <= 1:
        return xp.zeros_like(particles)

    # Squared distances via |a|^2 + |b|^2 - 2 a.b, clipped at zero.
    sq_norms = xp.sum(particles**2, axis=1, keepdims=True)
    dist_sq = xp.maximum(sq_norms + sq_norms.T - 2 * xp.dot(particles, particles.T), 0)

    if h is None:
        h = xp.sqrt(0.5 * xp.median(dist_sq) / xp.log(count + 1) + 1e-8)

    kernel_matrix = xp.exp(-dist_sq / (2 * h**2 + 1e-8))
    scale = h**2 + 1e-8

    repulsion = xp.zeros_like(particles)
    for row in range(count):
        offsets = particles[row:row+1] - particles
        row_grad = -kernel_matrix[row, :, None] * offsets / scale
        repulsion[row] = xp.sum(row_grad, axis=0)
    return repulsion / count

def run_particle_sampler(X_tar, p_tar, sq_tar, D_vec, eps_kswgd, phi_use, lambda_use, num_particles=16, num_iters=200, step_size=0.05, alpha=0.1, rng_seed=42, method="kswgd"):
    """Transport particles drawn from a standard normal and return their final positions.

    Each iteration evaluates the kernel gradient and the cross kernel between
    the targets and the current particles, projects through ``phi_use`` /
    ``lambda_use``, and moves the particles by
    ``-step_size * (attr - alpha * rep)``. ``method`` is only echoed in the
    log line.

    Returns:
        NumPy array of shape ``(num_particles, X_tar.shape[1])``.
    """
    on_gpu = GPU_KSWGD and torch.cuda.is_available()
    if on_gpu:
        xp, grad_fn, K_eval_fn = cp, grad_ker1_gpu, K_tar_eval_gpu
    else:
        xp, grad_fn, K_eval_fn = np, grad_ker1, K_tar_eval
    print(f"Method: {method.upper()}, Alpha: {alpha}")

    dim = X_tar.shape[1]
    # Full trajectory buffer; slice t holds the particles at iteration t.
    x_hist = xp.zeros((num_particles, dim, num_iters))
    start_positions = np.random.default_rng(rng_seed).normal(0, 1, (num_particles, dim))
    x_hist[:, :, 0] = xp.asarray(start_positions)

    # Move all inputs to the active array backend once, outside the loop.
    X_tar_d = xp.asarray(X_tar)
    p_tar_d = xp.asarray(p_tar)
    sq_tar_d = xp.asarray(sq_tar)
    D_vec_d = xp.asarray(D_vec)
    phi_d = xp.asarray(phi_use)
    lambda_d = xp.asarray(lambda_use)

    for t in trange(num_iters - 1, desc="Transport"):
        curr = x_hist[:, :, t]
        grad_m = grad_fn(curr, X_tar_d, p_tar_d, sq_tar_d, D_vec_d, eps_kswgd)
        cross_m = K_eval_fn(X_tar_d, curr, p_tar_d, sq_tar_d, D_vec_d, eps_kswgd)
        push = phi_d @ ((phi_d.T @ cross_m) * lambda_d[:, None])

        attr = xp.zeros_like(curr)
        for d in range(dim):
            attr[:, d] = xp.sum(grad_m[:, :, d] @ push, axis=1) / num_particles

        if alpha > 0:
            rep = compute_repulsive_force_gpu(curr)
        else:
            rep = 0
        x_hist[:, :, t+1] = curr - step_size * (attr - alpha * rep)

    final = x_hist[:, :, -1]
    if on_gpu:
        final = xp.asnumpy(final)
    return np.asarray(final)

# ==================== Main Experiment Function ====================
def run_full_experiment(config, experiment_id=1):
    """Run one end-to-end experiment and write its figure and config snapshot.

    Pipeline: load the dataset and the LDM pipeline, obtain reduced latents
    (from the pickle cache or by encoding images and training the MLP
    autoencoder), standardize them, build the kernel, run the
    eigendecomposition and EDMD, transport particles, decode them to images,
    upscale, and save a 4x4 grid plus ``config.txt`` under
    ``/workspace/kswgd/figures_test/test{experiment_id}``.

    Args:
        config: Dict of hyperparameters. Keys read here: ``max_samples``,
            ``reduced_dim``, ``mlp_hidden_dim``, ``mlp_epochs``, ``k_eig``,
            ``dt_edmd``, ``n_dict_components``, ``kswgd_num_particles``,
            ``kswgd_num_iters``, ``kswgd_step_size`` and the optional
            ``mlp_num_layers``, ``lpips_weight``, ``repulsive_alpha``,
            ``figure_title``.
        experiment_id: Number used for the output folder and the log banner.

    Returns:
        Dict with ``config``, ``experiment_id`` and ``output_dir``.
    """
    print(f"\n{'='*60}\nEXPERIMENT {experiment_id}\n{'='*60}")
    for k, v in config.items():
        print(f"  {k}: {v}")
    
    exp_dir = f"/workspace/kswgd/figures_test/test{experiment_id}"
    os.makedirs(exp_dir, exist_ok=True)
    print(f"Output: {exp_dir}")
    
    # Step 1: Load Dataset
    # SECURITY NOTE: trust_remote_code=True lets the dataset repository run
    # its own loading script on this machine.
    print("\n[1/9] Loading CelebA-HQ...")
    celebahq_dataset = load_dataset("mattymchen/celeba-hq", split="train", cache_dir=CELEBAHQ_CACHE, trust_remote_code=True)
    
    # Step 2: Load LDM
    print("[2/9] Loading LDM...")
    ldm_pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256").to(device)
    # The scaling factor is pinned to 1.0 both on the model config and locally.
    ldm_pipe.vqvae.config.scaling_factor = 1.0
    vae = ldm_pipe.vqvae
    vae_scaling = 1.0
    
    # Get latent shape
    # A dummy 256x256 image is encoded once just to read off the latent shape;
    # the encoder output is accessed via `.latents` when present, else index 0.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 256, 256, device=device)
        dummy_latent = vae.encode(_to_vae_range(dummy))
        full_latent_shape = (dummy_latent.latents if hasattr(dummy_latent, 'latents') else dummy_latent[0]).shape[1:]
    print(f"  Latent shape: {full_latent_shape}")
    
    # Step 3: Encode or Load from Cache
    print(f"[3/9] Encoding {config['max_samples']} images...")
    cache_path = get_mlp_cache_path(config)
    
    if os.path.exists(cache_path):
        print(f"  Loading from cache: {cache_path}")
        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file; only load caches this notebook produced itself.
        # NOTE(review): the loaded cache is used without calling
        # validate_cached_model.
        with open(cache_path, 'rb') as f: 
            data = pickle.load(f)
        Z_all, latent_ae = data['Z_all'], data.get('latent_ae', None)
        if latent_ae is not None:
            latent_ae = latent_ae.to(device)
    else:
        print(f"  Encoding images...")
        all_l = []
        vae.eval()
        with torch.no_grad():
            # Images are encoded one at a time and flattened to a single row each.
            for i in tqdm(range(min(config['max_samples'], len(celebahq_dataset))), desc="Encoding"):
                img = transform_celebahq(celebahq_dataset[i]["image"]).unsqueeze(0).to(device)
                latent = vae.encode(_to_vae_range(img))
                latent_code = latent.latents if hasattr(latent, 'latents') else latent[0]
                all_l.append((latent_code * vae_scaling).view(1, -1).cpu().numpy())
        Z_flat = np.concatenate(all_l, axis=0)
        
        # Train the MLP autoencoder; Z_all holds the reduced codes it returns.
        latent_ae, Z_all, _ = train_latent_autoencoder(
            Z_flat, config['reduced_dim'], config['mlp_hidden_dim'], 
            config.get('mlp_num_layers', 5), config['mlp_epochs'],
            lpips_weight=config.get('lpips_weight', 0.7)
        )
        # The whole model object is pickled together with the codes.
        with open(cache_path, 'wb') as f: 
            pickle.dump({'Z_all': Z_all, 'latent_ae': latent_ae, 'architecture_version': MLP_ARCHITECTURE_VERSION}, f)
    
    print(f"  Z_all shape: {Z_all.shape}")
    
    # Step 4-5: Kernel & EDMD
    print("[4/9] Building kernel matrix...")
    # Standardize per dimension; the 1e-8 keeps constant dimensions finite.
    Z_mean = np.mean(Z_all, axis=0, keepdims=True)
    Z_std = np.std(Z_all, axis=0, keepdims=True) + 1e-8
    X_tar = ((Z_all - Z_mean) / Z_std).astype(np.float64)
    sq_tar = np.sum(X_tar ** 2, axis=1)
    
    eps, rw, p_tar, D_vec = compute_kernel_matrix_gpu(X_tar)
    print(f"  eps_kswgd: {eps:.6f}, kernel shape: {rw.shape}")
    
    # NOTE(review): l_ns and phi are only printed below; the sampler is driven
    # by the EDMD outputs (phi_edmd, l_edmd) instead.
    print(f"[5/9] Eigendecomposition (k={config['k_eig']})...")
    l_ns, phi = compute_eigendecomposition(rw, config['k_eig'])
    print(f"  Top eigenvalues: {l_ns[:5]}")
    
    print(f"[6/9] EDMD (dt={config['dt_edmd']})...")
    phi_edmd, l_edmd, n_modes = compute_edmd_gpu(X_tar, config['dt_edmd'], config['n_dict_components'])
    print(f"  Koopman modes: {n_modes}")
    
    # Step 6: KSWGD
    # (logged as step 7 of 9)
    print(f"[7/9] KSWGD (particles={config['kswgd_num_particles']}, iters={config['kswgd_num_iters']})...")
    Z_kswgd = run_particle_sampler(
        X_tar, p_tar, sq_tar, D_vec, eps, phi_edmd, l_edmd, 
        config['kswgd_num_particles'], config['kswgd_num_iters'], 
        config['kswgd_step_size'], config.get('repulsive_alpha', 0.1)
    )
    
    # Step 7: Decode
    # (logged as step 8 of 9)
    print("[8/9] Decoding images...")
    imgs = decode_latents_to_images(Z_kswgd, Z_std.flatten(), Z_mean.flatten(), latent_ae, vae, full_latent_shape, vae_scaling, device)
    
    # Step 8: Upscale
    # (logged as step 9 of 9)
    print("[9/9] Upscaling with GFPGAN...")
    gpu_id = device.index if device.index is not None else 0
    u, f = create_upscaler(gpu_id)
    # Convert each CHW tensor to an HWC float array clipped to [0, 1].
    imgs_np = [np.clip(np.transpose(imgs[i].numpy(), (1, 2, 0)), 0, 1) for i in range(imgs.shape[0])]
    upscaled = upscale_images(imgs_np, [f], [u])
    
    # Save
    # At most 16 images fit the 4x4 grid; unused cells are left blank.
    n_grid = min(16, len(upscaled))
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    for i, ax in enumerate(axes.flat):
        if i < n_grid: ax.imshow(upscaled[i])
        ax.axis('off')
    title_text = str(config.get('figure_title', "")).strip()
    if title_text:
        # Leave headroom above the grid for the title.
        fig.suptitle(title_text, fontsize=16, y=0.96)
        fig.tight_layout(rect=[0.02, 0.02, 0.98, 0.93])
    else:
        fig.tight_layout(rect=[0.02, 0.02, 0.98, 0.98])
    plt.savefig(os.path.join(exp_dir, "kswgd_grid.png"), dpi=150)
    plt.close()
    
    # Save config snapshot
    # `f` is rebound here from the face enhancer to the file handle; the
    # enhancer is no longer needed at this point.
    config_txt_path = os.path.join(exp_dir, "config.txt")
    with open(config_txt_path, "w", encoding="utf-8") as f:
        f.write(f"Experiment {experiment_id} configuration\n")
        for k in sorted(config.keys()):
            f.write(f"{k}: {config[k]}\n")
    print(f"‚úì Saved hyperparameters to {config_txt_path}")
    
    print(f"‚úì Saved to {exp_dir}/kswgd_grid.png")
    
    # Clean up
    # Drop the large objects and release cached GPU memory before the next run.
    del ldm_pipe, vae, imgs, upscaled
    gc.collect()
    torch.cuda.empty_cache()
    
    return {'config': config, 'experiment_id': experiment_id, 'output_dir': exp_dir}

print("‚úì All functions defined!")

In [None]:
# ============== Cell 3: Timing Instrumentation for Steps 4-6 ==============
import time

# Timing registry shared by the wrappers below; every (re-)run of this cell
# starts from an empty dict, whether or not the name existed before.
STEP_TIMINGS = {}

# Capture the un-timed implementation. The check looks at a marker on the
# current function instead of at whether the *_original name exists, so that
# re-running Cell 2 and then this cell picks up the fresh definition, while
# re-running this cell alone never wraps the wrapper in itself.
if not getattr(compute_kernel_matrix_gpu, '_timing_wrapper', False):
    compute_kernel_matrix_gpu_original = compute_kernel_matrix_gpu

def compute_kernel_matrix_gpu(*args, **kwargs):
    """Call the original kernel-matrix builder and record its wall-clock time.

    The duration is stored in ``STEP_TIMINGS['kernel_matrix_seconds']`` and
    printed; the original return value is passed through unchanged.
    """
    start = time.perf_counter()
    result = compute_kernel_matrix_gpu_original(*args, **kwargs)
    duration = time.perf_counter() - start
    STEP_TIMINGS['kernel_matrix_seconds'] = duration
    print(f"  ‚è± Step 4 duration: {duration:.2f}s")
    return result

compute_kernel_matrix_gpu._timing_wrapper = True

# Capture the un-timed implementation. Keying on a marker carried by the
# wrapper (not on the existence of the *_original name) keeps the captured
# function fresh when Cell 2 is re-run, and prevents double wrapping when
# only this cell is re-run.
if not getattr(compute_eigendecomposition, '_timing_wrapper', False):
    compute_eigendecomposition_original = compute_eigendecomposition

def compute_eigendecomposition(*args, **kwargs):
    """Call the original eigendecomposition and record its wall-clock time.

    The duration is stored in ``STEP_TIMINGS['eigendecomposition_seconds']``
    and printed; the original return value is passed through unchanged.
    """
    start = time.perf_counter()
    result = compute_eigendecomposition_original(*args, **kwargs)
    duration = time.perf_counter() - start
    STEP_TIMINGS['eigendecomposition_seconds'] = duration
    print(f"  ‚è± Step 5 duration: {duration:.2f}s")
    return result

compute_eigendecomposition._timing_wrapper = True

# Capture the un-timed implementation. Keying on a marker carried by the
# wrapper (not on the existence of the *_original name) keeps the captured
# function fresh when Cell 2 is re-run, and prevents double wrapping when
# only this cell is re-run.
if not getattr(compute_edmd_gpu, '_timing_wrapper', False):
    compute_edmd_gpu_original = compute_edmd_gpu

def compute_edmd_gpu(*args, **kwargs):
    """Call the original EDMD routine and record its wall-clock time.

    The duration is stored in ``STEP_TIMINGS['edmd_seconds']`` and printed;
    the original return value is passed through unchanged.
    """
    start = time.perf_counter()
    result = compute_edmd_gpu_original(*args, **kwargs)
    duration = time.perf_counter() - start
    STEP_TIMINGS['edmd_seconds'] = duration
    print(f"  ‚è± Step 6 duration: {duration:.2f}s")
    return result

compute_edmd_gpu._timing_wrapper = True

# Capture the un-timed experiment driver. Keying on a marker carried by the
# wrapper (not on the existence of the *_original name) keeps the captured
# function fresh when Cell 2 is re-run, and prevents double wrapping when
# only this cell is re-run.
if not getattr(run_full_experiment, '_timing_wrapper', False):
    run_full_experiment_original = run_full_experiment

def run_full_experiment(config, experiment_id=1):
    """Run the original experiment and attach per-step and total timings.

    ``STEP_TIMINGS`` is reset before the run so figures from a previous
    experiment cannot leak in. When the original returns a dict, a copy with
    an added ``'timings'`` entry is returned; any other value is passed
    through unchanged.
    """
    global STEP_TIMINGS
    STEP_TIMINGS = {}
    experiment_start = time.perf_counter()
    result = run_full_experiment_original(config, experiment_id)
    total_seconds = time.perf_counter() - experiment_start
    STEP_TIMINGS['total_seconds'] = total_seconds
    print(f"‚è± Total experiment runtime: {total_seconds/60:.2f} min ({total_seconds:.1f}s)")
    if isinstance(result, dict):
        # Copy so the dict produced by the original function is not mutated.
        result = result.copy()
        result['timings'] = STEP_TIMINGS.copy()
    return result

run_full_experiment._timing_wrapper = True

print("‚úì Timing instrumentation active (Steps 4-6 + total runtime)")

In [None]:
# ============== Cell 4: Parameter Grid and Iteration Loop ==============
import itertools
from datetime import datetime
import torch.multiprocessing as mp

# Verify we're using the GPU selected in Cell 1
# (`device` and `SELECTED_GPU` are the globals defined there.)
print(f"üéØ Using device from Cell 1: {device}")
print(f"   SELECTED_GPU = {SELECTED_GPU}")

# ==================== Define Parameter Grid ====================
# NOTE(review): DEBUG_MODE is assigned here but not read anywhere in the
# visible code; the message below is printed unconditionally.
DEBUG_MODE = False
print("‚úì PRODUCTION MODE - Using full parameters")

# Values shared by every configuration in the sweep.
DEBUG_MAX_SAMPLES = 28000
DEBUG_NUM_PARTICLES = 16
DEBUG_MLP_EPOCHS = 300

# ==================== Extended Experiment Design ====================
dt_edmd_values = [0.05, 0.1]
kswgd_step_size_values = [0.01, 0.05, 0.1]
k_eig_values = [100, 300, 1000]
n_dict_values = [100, 300]
kswgd_iters_values = [500]
mlp_hidden_dim_values = [512]
mlp_num_layers_values = [6, 7]
lpips_weight_values = [0.0, 1.0]
repulsive_alpha_values = [0.0, 0.5] # Added repulsive force alpha

# Full Cartesian product of the value lists. itertools.product varies the
# right-most list fastest, which is the same order the equivalent nested
# loops (outermost = dt_edmd, innermost = repulsive_alpha) would produce.
configs_to_run = []
for dt, step_size, k_eig, n_dict, kswgd_iters, hidden_dim, num_layers, lpips_w, alpha in itertools.product(
    dt_edmd_values,
    kswgd_step_size_values,
    k_eig_values,
    n_dict_values,
    kswgd_iters_values,
    mlp_hidden_dim_values,
    mlp_num_layers_values,
    lpips_weight_values,
    repulsive_alpha_values,
):
    configs_to_run.append({
        'dt_edmd': dt,
        'mlp_epochs': DEBUG_MLP_EPOCHS,
        'mlp_hidden_dim': hidden_dim,
        'mlp_num_layers': num_layers,
        'lpips_weight': lpips_w,
        'reduced_dim': 8,
        'k_eig': k_eig,
        'n_dict_components': n_dict,
        'kswgd_num_particles': DEBUG_NUM_PARTICLES,
        'kswgd_num_iters': kswgd_iters,
        'kswgd_step_size': step_size,
        'max_samples': DEBUG_MAX_SAMPLES,
        'repulsive_alpha': alpha,
    })

# Echo the sweep to the notebook output: the number of configurations, the
# values held fixed, and the value lists that are varied.
print(f"üìã SELECTED CONFIGS: {len(configs_to_run)} experiments")
print(f"\nüîß Fixed parameters:")
print(f"  max_samples: {DEBUG_MAX_SAMPLES}, num_particles: {DEBUG_NUM_PARTICLES}, mlp_epochs: {DEBUG_MLP_EPOCHS}")
print(f"\nüìä Variable parameters:")
print(f"  dt_edmd: {dt_edmd_values}, step_size: {kswgd_step_size_values}, k_eig: {k_eig_values}")
print(f"  n_dict: {n_dict_values}, hidden: {mlp_hidden_dim_values}, layers: {mlp_num_layers_values}")
print(f"  lpips: {lpips_weight_values}, alpha: {repulsive_alpha_values}")

# ==================== Run Experiments ====================
import traceback

all_results = []
# One results file per sweep, named after the start time.
results_save_path = f"/workspace/kswgd/cache/iteration_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl"

for i, config in enumerate(configs_to_run):
    print(f"\n{'#'*80}\n# EXPERIMENT {i+1}/{len(configs_to_run)} on GPU {SELECTED_GPU}\n# alpha={config['repulsive_alpha']}, dt={config['dt_edmd']}, step={config['kswgd_step_size']}\n{'#'*80}")
    try:
        result = run_full_experiment(config, experiment_id=i+1)
    except Exception as e:
        # Keep the sweep going: log the failure and record it, tagged with
        # its experiment number so the summary can identify it.
        print(f"\n‚úó Experiment failed: {e}")
        traceback.print_exc()
        result = {'config': config, 'experiment_id': i+1, 'error': str(e)}
    all_results.append(result)

    # Persist after every experiment, failures included, so an interrupted
    # sweep leaves a complete record of what ran. A failed write is reported
    # but neither aborts the sweep nor adds a bogus error entry.
    try:
        with open(results_save_path, 'wb') as f:
            pickle.dump(all_results, f)
    except (OSError, pickle.PicklingError) as save_err:
        print(f"  Could not save results to {results_save_path}: {save_err}")

    gc.collect()
    torch.cuda.empty_cache()

# ==================== Summary ====================
# One line per experiment: its hyperparameters on success, or the first 50
# characters of the error message on failure.
print("\n" + "=" * 150)
print("FINAL SUMMARY")
print("=" * 150)
print(f"{'#':<4} {'alpha':<6} {'dt':<6} {'step':<8} {'k_eig':<7} {'n_dict':<7} {'hidden':<7} {'layers':<7} {'lpips':<6}")
print("-" * 150)
for r in all_results:
    if 'error' not in r:
        c = r['config']
        print(f"{r['experiment_id']:<4} {c['repulsive_alpha']:<6.2f} {c['dt_edmd']:<6.3f} {c['kswgd_step_size']:<8.4f} {c['k_eig']:<7} {c['n_dict_components']:<7} {c['mlp_hidden_dim']:<7} {c['mlp_num_layers']:<7} {c['lpips_weight']:<6.2f}")
        continue
    print(f"{r.get('experiment_id','?'):<4} ERROR: {r['error'][:50]}")
print(f"\nResults saved to: {results_save_path}")