In [None]:
# ============== Cell 1: Import All Required Packages ==============
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import sys
import os
import gc
import time
import pickle
import subprocess
import random
import threading
import psutil
import cv2
from concurrent.futures import ThreadPoolExecutor

# Add project directory to sys.path
project_dir = "/workspace/kswgd"
if project_dir not in sys.path:
    sys.path.append(project_dir)

# Scientific computing
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import MiniBatchDictionaryLearning
from scipy.sparse.linalg import eigsh
from scipy.linalg import eig
from scipy import linalg

# PyTorch related
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm, trange

# Diffusers
from diffusers import DiffusionPipeline

# Datasets
from datasets import load_dataset

# Real-ESRGAN and GFPGAN
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from gfpgan import GFPGANer

# FID evaluation
try:
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-fid", "-q"])
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3

# LPIPS
try:
    import lpips
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lpips", "-q"])
    import lpips

# Custom kernel functions
from grad_ker1 import grad_ker1
from K_tar_eval import K_tar_eval

# Try GPU version
try:
    import cupy as cp
    from grad_ker1_gpu import grad_ker1 as grad_ker1_gpu
    from K_tar_eval_gpu import K_tar_eval as K_tar_eval_gpu
    GPU_KSWGD = True
    print("‚úì GPU KSWGD backend available (CuPy)")
except Exception as e:
    cp = None
    grad_ker1_gpu = None
    K_tar_eval_gpu = None
    GPU_KSWGD = False
    print(f"‚úó GPU KSWGD backend not available: {e}")

# ==================== GPU Selection: Use the least utilized GPU ====================
def get_gpu_utilization():
    """Query ``nvidia-smi`` for per-GPU load and memory figures.

    Returns:
        A list with one dict per reported GPU, holding the keys ``id``,
        ``util_percent``, ``mem_used_mb``, ``mem_total_mb`` and
        ``mem_percent``; or ``None`` when ``nvidia-smi`` exits with a
        non-zero status or anything goes wrong while running/parsing it.
    """
    query_cmd = [
        'nvidia-smi',
        '--query-gpu=index,utilization.gpu,memory.used,memory.total',
        '--format=csv,noheader,nounits',
    ]
    try:
        proc = subprocess.run(query_cmd, capture_output=True, text=True, timeout=10)
        if proc.returncode != 0:
            return None

        records = []
        for row in proc.stdout.strip().split('\n'):
            fields = [field.strip() for field in row.split(',')]
            # Rows with fewer than four CSV fields (e.g. an empty line) are ignored.
            if len(fields) < 4:
                continue
            index = int(fields[0])
            load = float(fields[1])
            used = float(fields[2])
            total = float(fields[3])
            # Guard against a zero total so the percentage never divides by zero.
            used_share = (used / total) * 100 if total > 0 else 0
            records.append({
                'id': index,
                'util_percent': load,
                'mem_used_mb': used,
                'mem_total_mb': total,
                'mem_percent': used_share,
            })
        return records
    except Exception as e:
        # Best effort: any failure (missing binary, timeout, unparsable value)
        # is reported and signalled to the caller as None.
        print(f"  ‚ö†Ô∏è Could not get GPU utilization: {e}")
        return None

def select_best_gpu():
    """Pick the GPU index with the lowest combined load score.

    The score is ``0.6 * util_percent + 0.4 * mem_percent`` computed from the
    records returned by ``get_gpu_utilization()``. A table of all GPUs is
    printed along the way.

    Returns:
        The ``id`` of the lowest-scoring GPU, or ``0`` when no utilization
        data is available.
    """
    gpu_info = get_gpu_utilization()

    # None (query failed) and an empty list (nothing parsed) are treated alike.
    if not gpu_info:
        print("  ‚ö†Ô∏è Cannot detect GPU utilization, defaulting to GPU 0")
        return 0

    def _load_score(entry):
        # Weighted score: 60% compute utilization, 40% memory occupancy.
        return 0.6 * entry['util_percent'] + 0.4 * entry['mem_percent']

    print("\nüîç GPU Utilization Check:")
    print(f"  {'GPU':<6} {'Util %':<10} {'Mem Used':<12} {'Mem Total':<12} {'Mem %':<10}")
    print("  " + "-" * 50)
    for info in gpu_info:
        print(f"  GPU {info['id']:<3} {info['util_percent']:<10.1f} {info['mem_used_mb']:<12.0f} {info['mem_total_mb']:<12.0f} {info['mem_percent']:<10.1f}")

    # min() keeps the first entry on ties.
    best_gpu = min(gpu_info, key=_load_score)

    print(f"\n  ‚úì Selected GPU {best_gpu['id']} (Util: {best_gpu['util_percent']:.1f}%, Mem: {best_gpu['mem_percent']:.1f}%)")
    return best_gpu['id']

# Report what torch can see.
print(f"\nCUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"    Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

# Auto-select the least-loaded GPU and make it the global default device.
if torch.cuda.is_available():
    SELECTED_GPU = select_best_gpu()
    # nvidia-smi reports its own indices; they can fall outside the range torch
    # exposes (for example when CUDA_VISIBLE_DEVICES restricts visibility).
    # Fall back to index 0 instead of letting set_device raise.
    if not 0 <= SELECTED_GPU < torch.cuda.device_count():
        print(f"  GPU {SELECTED_GPU} is not visible to torch, using GPU 0 instead")
        SELECTED_GPU = 0
    torch.cuda.set_device(SELECTED_GPU)
    device = torch.device(f"cuda:{SELECTED_GPU}")
else:
    # Without CUDA, torch.cuda.set_device would raise; keep the names defined
    # so that later cells can still inspect them.
    SELECTED_GPU = None
    device = torch.device("cpu")
print(f"\nüéØ Global device set to: {device}")
print("   All subsequent cells will use this GPU.")

In [None]:
# ============== Cell 2: All Function Definitions ==============

from datetime import datetime

# ==================== Data Processing ====================
# Filesystem layout used by the experiment: dataset cache, latent cache and
# figure output. All three directories are created up front if missing.
DATA_DIR = "/workspace/kswgd/data"
CELEBAHQ_CACHE = os.path.join(DATA_DIR, "CelebA-HQ")
CACHE_DIR = "/workspace/kswgd/cache"
for _required_dir in (CELEBAHQ_CACHE, CACHE_DIR, '/workspace/kswgd/figures'):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# VAE helper functions
def _to_vae_range(x):
    """[0,1] ‚Üí [-1,1]"""
    return (x * 2.0) - 1.0

def _from_vae_range(x):
    """[-1,1] ‚Üí [0,1]"""
    return torch.clamp((x + 1.0) * 0.5, 0.0, 1.0)

# Image preprocessing transform
# Resize to 256x256, then convert to a tensor via torchvision's ToTensor.
transform_celebahq = T.Compose([T.Resize((256, 256)), T.ToTensor()])

# ==================== MLP Latent AutoEncoder ====================
class LatentAutoEncoder(torch.nn.Module):
    """MLP autoencoder mapping ``input_dim`` vectors to ``latent_dim`` codes and back.

    Both halves are stacks of ``Linear -> LayerNorm -> GELU`` stages (some
    followed by ``Dropout(0.1)``) closed by a plain ``Linear`` projection.
    """

    def __init__(self, input_dim=1024, latent_dim=64, hidden_dim=512):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        nn = torch.nn
        half_dim = hidden_dim // 2

        def stage(n_in, n_out, dropout=False):
            # One Linear -> LayerNorm -> GELU stage, optionally with Dropout(0.1).
            layers = [nn.Linear(n_in, n_out), nn.LayerNorm(n_out), nn.GELU()]
            if dropout:
                layers.append(nn.Dropout(0.1))
            return layers

        # input_dim -> hidden -> hidden -> hidden -> hidden/2 -> latent_dim
        self.encoder = nn.Sequential(
            *stage(input_dim, hidden_dim, dropout=True),
            *stage(hidden_dim, hidden_dim, dropout=True),
            *stage(hidden_dim, hidden_dim),
            *stage(hidden_dim, half_dim),
            nn.Linear(half_dim, latent_dim),
        )

        # latent_dim -> hidden/2 -> hidden -> hidden -> hidden -> input_dim
        self.decoder = nn.Sequential(
            *stage(latent_dim, half_dim),
            *stage(half_dim, hidden_dim, dropout=True),
            *stage(hidden_dim, hidden_dim, dropout=True),
            *stage(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        )

    def encode(self, x):
        """Run the encoder stack on ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Run the decoder stack on ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, code)`` for input ``x``."""
        code = self.encode(x)
        reconstruction = self.decode(code)
        return reconstruction, code


def train_latent_autoencoder(Z_flat, latent_dim=64, hidden_dim=512, epochs=100, batch_size=512, lr=1e-3, use_perceptual_loss=True):
    """Train a ``LatentAutoEncoder`` on flattened latent codes.

    Args:
        Z_flat: 2-D NumPy array, one flattened latent code per row.
        latent_dim: Size of the bottleneck code.
        hidden_dim: Width of the hidden layers.
        epochs: Number of passes over ``Z_flat``.
        batch_size: Mini-batch size.
        lr: Initial AdamW learning rate (cosine-annealed over ``epochs``).
        use_perceptual_loss: When true, add ``0.7 * LPIPS`` computed on the
            first three channels of the codes reshaped to ``(-1, 4, 16, 16)``.
            The term is skipped (with a message) if LPIPS cannot be created or
            if the feature width is not a multiple of ``4 * 16 * 16``.

    Returns:
        ``(model, Z_reduced)``: the trained model in eval mode and the encoder
        output for all rows of ``Z_flat`` as a NumPy array.
    """
    print(f"\n=== Training Latent AutoEncoder ({Z_flat.shape[1]} -> {latent_dim}) ===")
    print(f"  Hidden dim: {hidden_dim}, Epochs: {epochs}")

    model = LatentAutoEncoder(input_dim=Z_flat.shape[1], latent_dim=latent_dim, hidden_dim=hidden_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # The perceptual term reshapes each code to (4, 16, 16) blocks; that is only
    # well defined when the feature width is a multiple of 4*16*16.
    pseudo_image_size = 4 * 16 * 16
    if use_perceptual_loss and Z_flat.shape[1] % pseudo_image_size != 0:
        print(f"  Perceptual loss disabled: feature width {Z_flat.shape[1]} is not a multiple of {pseudo_image_size}")
        use_perceptual_loss = False

    lpips_loss_fn = None
    if use_perceptual_loss:
        try:
            lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)
            lpips_loss_fn.eval()
            for param in lpips_loss_fn.parameters():
                param.requires_grad = False
        except Exception as exc:
            # Best effort: fall back to plain MSE, but say why.
            print(f"  LPIPS unavailable ({exc}); training with MSE loss only")
            lpips_loss_fn = None
            use_perceptual_loss = False

    # The whole dataset is moved to the device once and batched from there.
    Z_tensor = torch.from_numpy(Z_flat).float().to(device)
    dataset = torch.utils.data.TensorDataset(Z_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    best_loss = float('inf')

    for epoch in range(epochs):
        total_loss = 0.0
        for (batch,) in dataloader:
            optimizer.zero_grad()
            recon, z = model(batch)
            mse_loss = torch.nn.functional.mse_loss(recon, batch)

            if use_perceptual_loss and lpips_loss_fn is not None:
                batch_spatial = batch.view(-1, 4, 16, 16)
                recon_spatial = recon.view(-1, 4, 16, 16)
                # LPIPS takes 3-channel input, so only the first three channels are compared.
                batch_3ch = batch_spatial[:, :3, :, :]
                recon_3ch = recon_spatial[:, :3, :, :]
                # Min-max rescale each tensor (over the whole batch) to [-1, 1].
                batch_norm = 2.0 * (batch_3ch - batch_3ch.min()) / (batch_3ch.max() - batch_3ch.min() + 1e-8) - 1.0
                recon_norm = 2.0 * (recon_3ch - recon_3ch.min()) / (recon_3ch.max() - recon_3ch.min() + 1e-8) - 1.0
                lpips_loss = lpips_loss_fn(batch_norm, recon_norm).mean()
                loss = mse_loss + 0.7 * lpips_loss
            else:
                loss = mse_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.size(0)

        avg_loss = total_loss / len(Z_tensor)
        scheduler.step()
        # best_loss is only reported; no checkpoint of the best epoch is kept.
        if avg_loss < best_loss:
            best_loss = avg_loss
        if (epoch + 1) % 20 == 0 or epoch == 0:
            print(f"  Epoch {epoch+1:3d}/{epochs}: Loss = {avg_loss:.6f}")

    print(f"‚úì Training complete! Best loss: {best_loss:.6f}")
    model.eval()
    with torch.no_grad():
        Z_reduced = model.encode(Z_tensor).cpu().numpy()
    return model, Z_reduced


# ==================== Real-ESRGAN + GFPGAN ====================
# Pretrained weight files for the Real-ESRGAN x4 upsampler and the GFPGAN face restorer.
model_path = '/workspace/kswgd/weights/RealESRGAN_x4plus.pth'
gfpgan_path = '/workspace/kswgd/weights/GFPGANv1.3.pth'

def create_upscaler(gpu_id):
    """Build the Real-ESRGAN upsampler and a GFPGAN enhancer that uses it for backgrounds.

    Args:
        gpu_id: GPU index handed to ``RealESRGANer``.

    Returns:
        ``(upsampler, face_enhancer)``.
    """
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=backbone,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=gpu_id,
    )
    face_enhancer = GFPGANer(
        model_path=gfpgan_path,
        upscale=4,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )
    return upsampler, face_enhancer

def preprocess_image(img_bgr):
    """Smooth an image and pull each channel mean part-way towards mid-grey.

    Steps: 3x3 Gaussian blur, bilateral filter on the uint8 cast, then every
    channel is shifted by 30% of its distance to 127.5. The input is not
    modified.

    Returns:
        A ``uint8`` array clipped to [0, 255].
    """
    blurred = cv2.GaussianBlur(img_bgr.copy().astype(np.float32), (3, 3), sigmaX=0.5)
    filtered = cv2.bilateralFilter(blurred.astype(np.uint8), d=5, sigmaColor=30, sigmaSpace=30)
    balanced = filtered.astype(np.float32)

    target_mean, alpha = 127.5, 0.3
    # Channel means are all measured before any channel is shifted.
    channel_means = [np.mean(balanced[:, :, channel]) for channel in range(3)]
    for channel, channel_mean in enumerate(channel_means):
        balanced[:, :, channel] += alpha * (target_mean - channel_mean)

    return np.clip(balanced, 0, 255).astype(np.uint8)

def process_single_image(args):
    """Upscale one image with GFPGAN (face enhancement) or Real-ESRGAN.

    Args:
        args: Tuple ``(img, gpu_id, use_face_enhance, use_preprocess,
            face_enhancers, upsamplers)``. ``img`` is a PIL image or a NumPy
            array (channel-first arrays with 3 channels are transposed to
            channel-last); ``gpu_id`` indexes into the two model lists.

    Returns:
        The enhanced image as an RGB array.
    """
    img, gpu_id, use_face_enhance, use_preprocess, face_enhancers, upsamplers = args
    face_enhancer = face_enhancers[gpu_id]
    upsampler = upsamplers[gpu_id]

    if isinstance(img, Image.Image):
        # Force 3-channel RGB so grayscale/palette/alpha images survive the colour conversion below.
        img_np = np.array(img.convert("RGB"))
    else:
        if img.ndim == 3 and img.shape[0] == 3:
            img_np = np.transpose(img, (1, 2, 0))
        else:
            img_np = img
        if img_np.max() <= 1.0:
            # Values in [0, 1] are interpreted as normalised intensities.
            img_np = (img_np * 255).astype(np.uint8)
        elif img_np.dtype != np.uint8:
            # Non-uint8 data already on a 0-255 scale: the original passed it on
            # un-cast, so clip and cast to the 8-bit format used downstream.
            img_np = np.clip(img_np, 0, 255).astype(np.uint8)

    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    if use_preprocess:
        img_bgr = preprocess_image(img_bgr)

    if use_face_enhance:
        _, _, output_bgr = face_enhancer.enhance(img_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    else:
        output_bgr, _ = upsampler.enhance(img_bgr, outscale=4)
    return cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)

def upscale_images(images_list, face_enhancers, upsamplers, use_face_enhance=True, use_preprocess=True, desc="Upscaling"):
    """Run ``process_single_image`` over every image, sequentially, with a progress bar.

    Always uses entry 0 of ``face_enhancers`` / ``upsamplers``.

    Returns:
        A list with one enhanced image per input, in input order.
    """
    enhanced = []
    for img in tqdm(images_list, desc=desc):
        job = (img, 0, use_face_enhance, use_preprocess, face_enhancers, upsamplers)
        enhanced.append(process_single_image(job))
    return enhanced


# ==================== Eigendecomposition ====================
def compute_eigendecomposition(rw_kernel, k_eig, USE_GPU_EIGSH=True):
    """Compute leading eigenpairs of a square matrix, sorted by descending eigenvalue.

    When ``k_eig`` is smaller than the matrix size, only the top ``k_eig``
    pairs are computed (``torch.lobpcg`` on GPU when available and enabled,
    otherwise ``scipy.sparse.linalg.eigsh``). Otherwise a full ``eigh`` is run
    on GPU if CUDA is available, else with NumPy. All solvers used here are
    symmetric eigensolvers, so ``rw_kernel`` is assumed to be symmetric.

    Args:
        rw_kernel: Square NumPy array.
        k_eig: Number of eigenpairs requested.
        USE_GPU_EIGSH: Allow the GPU solver for the truncated case.

    Returns:
        ``(lambda_ns, phi)``: eigenvalues in descending order and the matching
        eigenvectors as columns.
    """
    start_time = time.time()
    n_samples = rw_kernel.shape[0]

    if k_eig < n_samples:
        print(f"  Using TRUNCATED eigendecomposition (top {k_eig} of {n_samples})...")
        eigvals = eigvecs = None
        if USE_GPU_EIGSH and torch.cuda.is_available():
            rw_kernel_torch = X0 = None
            try:
                rw_kernel_torch = torch.from_numpy(rw_kernel).float().to(device)
                X0 = torch.randn(n_samples, k_eig, device=device, dtype=torch.float32)
                vals_t, vecs_t = torch.lobpcg(rw_kernel_torch, k=k_eig, X=X0, largest=True, niter=100)
                eigvals = vals_t.cpu().numpy()
                eigvecs = vecs_t.cpu().numpy()
            except Exception as exc:
                # Report why the GPU solver was abandoned instead of hiding it;
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                print(f"  GPU LOBPCG failed ({exc}); falling back to CPU eigsh")
                eigvals = eigvecs = None
            finally:
                # Release the GPU buffers whether or not the solve succeeded.
                del rw_kernel_torch, X0
                torch.cuda.empty_cache()

        if eigvals is None:
            eigvals, eigvecs = eigsh(rw_kernel, k=k_eig, which='LM')

        # Neither solver's output order is relied on; sort descending explicitly.
        sort_idx = np.argsort(eigvals)[::-1]
        lambda_ns = eigvals[sort_idx]
        phi = eigvecs[:, sort_idx]
    else:
        print(f"  Using FULL eigendecomposition ({n_samples} x {n_samples})...")
        if torch.cuda.is_available():
            rw_kernel_torch = torch.from_numpy(rw_kernel).to(device)
            lambda_ns_torch, phi_torch = torch.linalg.eigh(rw_kernel_torch)
            # eigh returns ascending eigenvalues; reverse to descending.
            lambda_ns = lambda_ns_torch.cpu().numpy()[::-1].copy()
            phi = phi_torch.cpu().numpy()[:, ::-1].copy()
            del rw_kernel_torch
            torch.cuda.empty_cache()
        else:
            lambda_ns_raw, phi_raw = np.linalg.eigh(rw_kernel)
            lambda_ns = lambda_ns_raw[::-1]
            phi = phi_raw[:, ::-1]

    print(f"‚úì Eigendecomposition complete! Time: {time.time() - start_time:.1f}s")
    return lambda_ns, phi


# ==================== KSWGD Sampler ====================
def run_particle_sampler(X_tar, p_tar, sq_tar, D_vec, eps_kswgd, phi_use, lambda_use, 
                        num_particles=16, num_iters=200, step_size=0.05, rng_seed=42, method="kswgd"):
    """Transport Gaussian-initialised particles with the kernel update rule.

    Each iteration evaluates the kernel gradient and the cross kernel between
    the targets and the current particles, forms
    ``phi_use @ ((phi_use.T @ cross) * lambda_use[:, None])`` and moves every
    coordinate by ``step_size / num_particles`` times the row sums of
    ``grad[:, :, dim] @ push``. CuPy and the GPU kernels are used when
    ``GPU_KSWGD`` is set and CUDA is available, NumPy otherwise.

    Args:
        X_tar, p_tar, sq_tar, D_vec, eps_kswgd: Target data and kernel
            quantities, forwarded unchanged to the kernel helpers.
        phi_use, lambda_use: Mode matrix and per-mode weights for the push term.
        num_particles: Number of particles.
        num_iters: Length of the stored trajectory (``num_iters - 1`` updates).
        step_size: Update step size.
        rng_seed: Seed for the standard-normal initial particles.
        method: Only used as a label in printed/progress output.

    Returns:
        Final particle positions, a float64 NumPy array of shape
        ``(num_particles, X_tar.shape[1])``.
    """
    n_dims = X_tar.shape[1]
    generator = np.random.default_rng(rng_seed)

    # Choose array module and kernel helpers for the active backend.
    on_gpu = GPU_KSWGD and torch.cuda.is_available()
    if on_gpu:
        xp, gradient_fn, cross_kernel_fn = cp, grad_ker1_gpu, K_tar_eval_gpu
    else:
        xp, gradient_fn, cross_kernel_fn = np, grad_ker1, K_tar_eval

    print(f"Method: {method.upper()}, Backend: {'GPU' if on_gpu else 'CPU'}")

    # Whole trajectory is stored: (particle, coordinate, iteration).
    trajectory = xp.zeros((num_particles, n_dims, num_iters), dtype=xp.float64)
    trajectory[:, :, 0] = xp.asarray(generator.normal(0.0, 1.0, size=(num_particles, n_dims)))

    host_inputs = (X_tar, p_tar, sq_tar, D_vec, phi_use, lambda_use)
    if on_gpu:
        targets, density, sq_norms, degrees, modes, weights = [cp.asarray(arr) for arr in host_inputs]
    else:
        targets, density, sq_norms, degrees, modes, weights = host_inputs

    for step in trange(num_iters - 1, desc=f"{method.upper()} Transport"):
        particles = trajectory[:, :, step]
        grads = gradient_fn(particles, targets, density, sq_norms, degrees, eps_kswgd)
        cross = cross_kernel_fn(targets, particles, density, sq_norms, degrees, eps_kswgd)
        projected = (modes.T @ cross) * weights[:, None]
        push = modes @ projected
        for axis in range(n_dims):
            drift = grads[:, :, axis] @ push
            trajectory[:, axis, step + 1] = trajectory[:, axis, step] - (step_size / num_particles) * xp.sum(drift, axis=1)

    final_state = trajectory[:, :, -1]
    if on_gpu:
        final_state = xp.asnumpy(final_state)
    return np.asarray(final_state, dtype=np.float64)


def decode_latents_to_images(flat_latents_std, Z_std, Z_mean, latent_ae, vae, full_latent_shape, vae_scaling, target_device):
    """Turn standardised latent samples back into RGB image tensors.

    The samples are de-standardised with ``Z_std`` / ``Z_mean``, optionally
    expanded through ``latent_ae.decode``, reshaped to ``full_latent_shape``,
    divided by ``vae_scaling`` and decoded by ``vae``.

    Returns:
        A CPU tensor of decoded images with values clamped to [0, 1].
    """
    unstandardised = flat_latents_std * Z_std + Z_mean

    if latent_ae is not None:
        latent_ae.eval()
        with torch.no_grad():
            codes = torch.from_numpy(unstandardised).float().to(target_device)
            vae_latents = latent_ae.decode(codes).cpu().numpy()
    else:
        # latent_ae is None: no MLP compression was used, so the samples
        # already are raw VAE latents.
        vae_latents = unstandardised

    latent_batch = torch.from_numpy(vae_latents).float()
    latent_batch = latent_batch.view(-1, *full_latent_shape).to(target_device)

    vae.eval()
    with torch.no_grad():
        raw_output = vae.decode(latent_batch / vae_scaling).sample
        images = _from_vae_range(raw_output)
    return images.cpu()


# ==================== FID Evaluation ====================
def get_inception_features_single_gpu(images, gpu_id, batch_size=256, desc="Features"):
    """Extract 2048-d pool features from pytorch-fid's InceptionV3 on one GPU.

    Args:
        images: Sequence of PIL images, NumPy arrays (channel-last arrays are
            permuted to channel-first; arrays with values above 1 are divided
            by 255) or anything ``torch.as_tensor`` accepts.
        gpu_id: CUDA device index to run on.
        batch_size: Number of images per forward pass.
        desc: Progress-bar label.

    Returns:
        NumPy array with one feature row per image.
    """
    device_local = torch.device(f"cuda:{gpu_id}")
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    inception = InceptionV3([block_idx]).to(device_local)
    inception.eval()
    # Only resize here. pytorch-fid's InceptionV3 (default normalize_input=True)
    # rescales [0, 1] inputs to the range the network expects, so the ImageNet
    # mean/std normalisation that used to be applied here distorted the
    # features and has been removed.
    preprocess = T.Resize((299, 299))

    features_list = []
    with torch.no_grad():
        for i in tqdm(range(0, len(images), batch_size), desc=f"{desc} (GPU {gpu_id})"):
            batch_samples = images[i:i+batch_size]
            tensors = []
            for img in batch_samples:
                if isinstance(img, Image.Image):
                    t = T.ToTensor()(img)
                elif isinstance(img, np.ndarray):
                    t = torch.from_numpy(img).float()
                    if t.ndim == 3 and t.shape[2] == 3 and t.shape[0] != 3:
                        t = t.permute(2, 0, 1)
                    if t.max() > 1.0:
                        # Out-of-place on purpose: for float32 input the tensor
                        # shares memory with the caller's array, and an
                        # in-place divide would silently rescale it.
                        t = t / 255.0
                else:
                    t = torch.as_tensor(img).float()
                tensors.append(t)
            batch = torch.stack(tensors).to(device_local)
            batch = preprocess(batch)
            feat = inception(batch)[0].squeeze(-1).squeeze(-1)
            features_list.append(feat.cpu().numpy())
    del inception
    torch.cuda.empty_cache()
    return np.concatenate(features_list, axis=0)

def get_inception_features(images, batch_size=256, desc="Features"):
    """Extract Inception features on the globally selected device (index 0 if it has none)."""
    gpu_index = 0 if device.index is None else device.index
    return get_inception_features_single_gpu(images, gpu_index, batch_size, desc)

def calculate_fid(real_features, gen_features):
    """Frechet distance between Gaussians fitted to two feature sets.

    Computes ``||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 * sqrt(S_r S_g))`` where
    a ``1e-6`` diagonal offset is added to both covariances before the matrix
    square root.

    Args:
        real_features: 2-D array, one feature vector per row.
        gen_features: 2-D array with the same number of columns.

    Returns:
        The distance as a Python float.
    """
    mu_real, mu_gen = np.mean(real_features, axis=0), np.mean(gen_features, axis=0)
    sigma_real, sigma_gen = np.cov(real_features, rowvar=False), np.cov(gen_features, rowvar=False)
    diff = mu_real - mu_gen
    offset = np.eye(sigma_real.shape[0]) * 1e-6
    # Called without the `disp` argument: the default call returns just the
    # square root on every SciPy version, whereas `disp=False` (which also
    # returned an error estimate) is deprecated in recent releases.
    covmean = linalg.sqrtm((sigma_real + offset) @ (sigma_gen + offset))
    # Numerical noise can leave a small imaginary part; keep the real component.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma_real + sigma_gen - 2 * covmean))

def load_real_images_for_fid(dataset, n_samples, size=256, desc="Loading"):
    """Load the first ``n_samples`` images of ``dataset`` for FID evaluation.

    For ``size == 1024`` images are kept at their stored resolution and
    returned as ``ToTensor`` arrays; for any other size they are resized to
    ``(size, size)`` and returned via ``np.array`` of the PIL image.

    Returns:
        A list of NumPy arrays, one per image.
    """
    keep_native = size == 1024
    collected = []
    for index in tqdm(range(n_samples), desc=desc):
        picture = dataset[index]["image"]
        if keep_native:
            collected.append(T.ToTensor()(picture).numpy())
        else:
            collected.append(np.array(picture.resize((size, size))))
    return collected


# ==================== Main Experiment Function ====================
def run_full_experiment(config, experiment_id=1):
    """Run the full KSWGD vs LDM comparison experiment.

    Pipeline: load CelebA-HQ and the LDM pipeline, encode images to latents
    (optionally compressed by an MLP autoencoder, cached on disk), build the
    kernel matrix, run the eigendecomposition and EDMD, sample with KSWGD,
    decode/upscale/save an image grid, then sample the LDM for the same number
    of images and save its grid. FID computation is currently disabled and
    reported as ``-1.0``.

    Args:
        config: Dict with the keys ``max_samples``, ``reduced_dim``,
            ``mlp_hidden_dim``, ``mlp_epochs``, ``k_eig``, ``dt_edmd``,
            ``n_dict_components``, ``kswgd_num_particles``,
            ``kswgd_num_iters`` and ``kswgd_step_size``.
        experiment_id: Used for the output directory name and figure titles.

    Returns:
        Dict with ``config``, ``experiment_id``, ``output_dir``,
        ``kswgd_time``, ``ldm_time`` and the four (placeholder) FID entries.

    Raises:
        RuntimeError: If none of the dataset sources can be loaded.
    """
    print("=" * 80)
    print(f"RUNNING EXPERIMENT {experiment_id} WITH CONFIG:")
    for k, v in config.items():
        print(f"  {k}: {v}")
    print("=" * 80)
    
    exp_dir = f"/workspace/kswgd/figures/test{experiment_id}"
    os.makedirs(exp_dir, exist_ok=True)
    print(f"Output directory: {exp_dir}")
    
    results = {'config': config.copy(), 'experiment_id': experiment_id, 'output_dir': exp_dir}
    
    # Step 1: Load Dataset (first source that loads wins)
    print("\n[1/12] Loading CelebA-HQ dataset...")
    celebahq_dataset = None
    for source in ["mattymchen/celeba-hq", "datasets-community/CelebA-HQ", "xinrongzhang2022/celeba-hq"]:
        try:
            celebahq_dataset = load_dataset(source, split="train", cache_dir=CELEBAHQ_CACHE, trust_remote_code=True)
            print(f"‚úì Loaded: {source}")
            break
        except Exception as exc:
            # Try the next mirror, but keep a trace of why this one failed.
            print(f"  Could not load {source}: {exc}")
            continue
    if celebahq_dataset is None:
        raise RuntimeError("Unable to load CelebA-HQ dataset")
    
    # Step 2: Load LDM
    print("\n[2/12] Loading LDM model...")
    print(f"  Using GPU: {device}")
    ldm_pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")
    ldm_pipe = ldm_pipe.to(device)
    ldm_pipe.vqvae.config.scaling_factor = 1.0
    vae = ldm_pipe.vqvae
    vae_scaling = 1.0
    
    # Probe the latent shape by encoding a dummy 256x256 image.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 256, 256, device=device)
        dummy_latent = vae.encode(_to_vae_range(dummy))
        full_latent_shape = (dummy_latent.latents if hasattr(dummy_latent, 'latents') else dummy_latent[0]).shape[1:]
    print(f"‚úì LDM loaded! Latent shape: {full_latent_shape}")
    
    # Step 3: Initialize Upscalers
    print("\n[3/12] Initializing GFPGAN upscalers...")
    gpu_id = device.index if device.index is not None else 0
    u, f = create_upscaler(gpu_id)
    upsamplers, face_enhancers = [u], [f]
    print(f"‚úì Upscalers initialized on GPU {gpu_id}")
    
    # Step 4: Encode Images (MLP compression only when reduced_dim < 1024)
    print(f"\n[4/12] Encoding {config['max_samples']} images...")
    use_mlp = config['reduced_dim'] < 1024
    
    if use_mlp:
        cache_path = os.path.join(CACHE_DIR, f"vae_n{config['max_samples']}_mlp{config['reduced_dim']}_h{config['mlp_hidden_dim']}_e{config['mlp_epochs']}.pkl")
    else:
        cache_path = os.path.join(CACHE_DIR, f"vae_n{config['max_samples']}_nomlp.pkl")
        print("  ‚ö†Ô∏è reduced_dim >= 1024, ‰∏ç‰ΩøÁî®MLPÂéãÁº©ÔºåÁõ¥Êé•Áî®VAE latent")
    
    if os.path.exists(cache_path):
        print(f"  Loading from cache: {cache_path}")
        # NOTE: pickle executes code on load; only use cache files written by this pipeline.
        with open(cache_path, 'rb') as f_cache:
            cache_data = pickle.load(f_cache)
        Z_all = cache_data['Z_all']
        latent_ae = cache_data.get('latent_ae', None)
        full_latent_shape = cache_data.get('full_latent_shape', full_latent_shape)
    else:
        all_latents = []
        vae.eval()
        with torch.no_grad():
            for i in tqdm(range(min(config['max_samples'], len(celebahq_dataset))), desc="Encoding"):
                img = celebahq_dataset[i]["image"]
                img_tensor = transform_celebahq(img).unsqueeze(0).to(device)
                latent = vae.encode(_to_vae_range(img_tensor))
                latent_code = latent.latents if hasattr(latent, 'latents') else latent[0]
                all_latents.append((latent_code * vae_scaling).view(1, -1).cpu().numpy())
        Z_flat = np.concatenate(all_latents, axis=0)
        
        if use_mlp:
            latent_ae, Z_all = train_latent_autoencoder(Z_flat, latent_dim=config['reduced_dim'], hidden_dim=config['mlp_hidden_dim'], epochs=config['mlp_epochs'])
        else:
            latent_ae = None
            Z_all = Z_flat
            print(f"  ‚úì ‰ΩøÁî®ÂéüÂßãVAE latentÔºåÁª¥Â∫¶: {Z_all.shape[1]}")
        
        with open(cache_path, 'wb') as f_cache:
            pickle.dump({'Z_all': Z_all, 'latent_ae': latent_ae, 'full_latent_shape': full_latent_shape}, f_cache)
    
    print(f"  ‚úì Z_all shape: {Z_all.shape}")
    
    # Step 5: Build Kernel Matrix on per-dimension standardised latents
    print("\n[5/12] Building kernel matrix...")
    Z_mean = np.mean(Z_all, axis=0, keepdims=True).astype(np.float64)
    Z_std = (np.std(Z_all, axis=0, keepdims=True) + 1e-8).astype(np.float64)
    X_tar = ((Z_all - Z_mean) / Z_std).astype(np.float64)
    sq_tar = np.sum(X_tar ** 2, axis=1)
    dists = pairwise_distances(X_tar, metric="euclidean")
    # Bandwidth from the median squared distance, floored at 1e-6.
    eps_kswgd = float(max(np.median(dists**2) / (2.0 * np.log(X_tar.shape[0] + 1)), 1e-6))
    data_kernel = np.exp(-dists**2 / (2.0 * eps_kswgd))
    p_x = np.sqrt(np.sum(data_kernel, axis=1))
    data_kernel_norm = data_kernel / (p_x[:, None] * p_x[None, :] + 1e-12)
    D_y = np.sum(data_kernel_norm, axis=0)
    # Average of the column- and row-normalised kernels.
    rw_kernel = 0.5 * (data_kernel_norm / (D_y + 1e-12) + data_kernel_norm / (D_y[:, None] + 1e-12))
    rw_kernel = np.nan_to_num(rw_kernel)
    p_tar = np.sum(data_kernel, axis=0)
    sqrt_p = np.sqrt(p_tar + 1e-12)
    D_vec = np.sum(data_kernel / sqrt_p[:, None] / sqrt_p[None, :], axis=1)
    print(f"‚úì Kernel matrix built: {rw_kernel.shape}")
    
    # Step 6: Eigendecomposition
    # NOTE(review): phi_trunc and lambda_ns_s_ns computed here are not used
    # later in this function; the sampler in step 8 receives the EDMD
    # quantities from step 7 instead.
    print(f"\n[6/12] Computing eigendecomposition (k={config['k_eig']})...")
    lambda_ns, phi = compute_eigendecomposition(rw_kernel, config['k_eig'])
    tol, reg = 1e-6, 1e-3
    above_tol = int(np.sum(lambda_ns >= tol))
    lambda_ = lambda_ns - 1.0
    inv_lambda = np.zeros_like(lambda_)
    inv_lambda[1:][lambda_[1:] > tol] = 1.0 / (np.abs(lambda_[1:][lambda_[1:] > tol]) + reg)
    inv_lambda *= eps_kswgd
    lambda_ns_inv = np.zeros_like(lambda_ns)
    lambda_ns_inv[lambda_ns >= tol] = eps_kswgd / (lambda_ns[lambda_ns >= tol] + reg)
    phi_trunc = phi[:, :above_tol]
    lambda_ns_s_ns = np.nan_to_num((lambda_ns_inv * inv_lambda * lambda_ns_inv)[:above_tol])
    print(f"‚úì {above_tol} modes retained")
    
    # Step 7: EDMD
    print(f"\n[7/12] Computing EDMD with dt={config['dt_edmd']}...")
    dist2_edmd = pairwise_distances(X_tar, metric="sqeuclidean")
    h_edmd = np.sqrt(np.median(dist2_edmd) + 1e-12)
    W_edmd = np.exp(-dist2_edmd / (2.0 * h_edmd ** 2))
    score_edmd = (W_edmd @ X_tar / (np.sum(W_edmd, axis=1, keepdims=True) + 1e-12) - X_tar) / (h_edmd ** 2)
    # One drift-plus-noise step of size dt. The noise comes from the global
    # NumPy RNG, which is not seeded here, so this step differs between runs.
    X_tar_next = X_tar + config['dt_edmd'] * score_edmd + np.sqrt(2.0 * config['dt_edmd']) * np.random.normal(0, 1, X_tar.shape)
    
    dict_model = MiniBatchDictionaryLearning(n_components=config['n_dict_components'], alpha=1e-3, batch_size=256, max_iter=500, random_state=42, fit_algorithm="lars")
    dict_model.fit(X_tar)
    # Dictionary features with a constant column prepended.
    Phi_X = np.hstack([np.ones((X_tar.shape[0], 1)), dict_model.transform(X_tar)])
    Phi_Y = np.hstack([np.ones((X_tar_next.shape[0], 1)), dict_model.transform(X_tar_next)])
    N_edmd, m_edmd = Phi_X.shape
    G_edmd = (Phi_X.T @ Phi_X) / N_edmd + 1e-3 * np.eye(m_edmd)
    A_edmd = (Phi_X.T @ Phi_Y) / N_edmd
    # Generalised eigenproblem A v = lambda G v, sorted by descending real part.
    eigvals_edmd, eigvecs_edmd = eig(A_edmd, G_edmd)
    idx_edmd = np.argsort(-eigvals_edmd.real)
    eigvals_edmd, eigvecs_edmd = eigvals_edmd[idx_edmd], eigvecs_edmd[:, idx_edmd]
    efuns_edmd = Phi_X @ eigvecs_edmd
    
    lambda_ns_edmd = eigvals_edmd.real
    lambda_gen_edmd = (lambda_ns_edmd - 1.0) / config['dt_edmd']
    # Skip the leading mode and keep those with eigenvalue above 1e-6.
    valid_idx = np.arange(1, lambda_ns_edmd.shape[0])[lambda_ns_edmd[1:] > 1e-6]
    phi_trunc_edmd = np.real(efuns_edmd[:, valid_idx])
    lambda_gen_inv = np.zeros_like(lambda_gen_edmd)
    lambda_gen_inv[np.abs(lambda_gen_edmd) > 1e-6] = 1.0 / lambda_gen_edmd[np.abs(lambda_gen_edmd) > 1e-6]
    lambda_ns_s_ns_edmd = lambda_gen_inv[valid_idx].real
    print(f"‚úì EDMD complete: {valid_idx.size} Koopman modes")
    
    # Step 8: Run KSWGD
    print(f"\n[8/12] Running KSWGD...")
    start_time = time.time()
    Z_kswgd_std = run_particle_sampler(X_tar, p_tar, sq_tar, D_vec, eps_kswgd, phi_trunc_edmd, lambda_ns_s_ns_edmd,
                                       num_particles=config['kswgd_num_particles'], num_iters=config['kswgd_num_iters'],
                                       step_size=config['kswgd_step_size'], rng_seed=42, method="kswgd")
    kswgd_time = time.time() - start_time
    results['kswgd_time'] = kswgd_time
    print(f"‚úì KSWGD complete! Time: {kswgd_time:.1f}s")
    
    # Step 9: Decode KSWGD Images (in batches of 128)
    print("\n[9/12] Decoding KSWGD images...")
    vae = vae.to(device)
    if latent_ae is not None:
        latent_ae = latent_ae.to(device)
    all_kswgd_images = []
    for i in tqdm(range(0, Z_kswgd_std.shape[0], 128), desc="Decoding"):
        batch = decode_latents_to_images(Z_kswgd_std[i:i+128], Z_std, Z_mean, latent_ae, vae, full_latent_shape, vae_scaling, device)
        all_kswgd_images.append(batch.numpy())
    kswgd_images_np = np.concatenate(all_kswgd_images, axis=0)
    
    kswgd_for_upscale = [np.clip(np.transpose(kswgd_images_np[i], (1, 2, 0)), 0, 1) for i in range(kswgd_images_np.shape[0])]
    kswgd_upscaled = upscale_images(kswgd_for_upscale, face_enhancers, upsamplers, desc="KSWGD GFPGAN")
    
    # 4x4 grid of at most 16 images; unused cells stay blank.
    n_grid = min(16, len(kswgd_upscaled))
    print(f"  Saving KSWGD grid ({n_grid} of {len(kswgd_upscaled)} images)...")
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    for idx, ax in enumerate(axes.flat):
        if idx < n_grid:
            ax.imshow(kswgd_upscaled[idx])
            ax.set_title(f"KSWGD #{idx+1}")
        ax.axis('off')
    plt.suptitle(f"KSWGD Enhanced - Test {experiment_id}", fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "kswgd_grid.png"), dpi=150, bbox_inches='tight')
    plt.close()
    
    # Step 10: KSWGD FID (DISABLED FOR DEBUG)
    # print("\n[10/12] Computing KSWGD FID...")
    # n_real = min(10000, len(celebahq_dataset))
    # real_256 = load_real_images_for_fid(celebahq_dataset, n_real, 256, "Real (256)")
    # real_feat_256 = get_inception_features(real_256, desc="Real (256)")
    # kswgd_feat_256 = get_inception_features(kswgd_images_np, desc="KSWGD (256)")
    # fid_kswgd_raw = calculate_fid(real_feat_256, kswgd_feat_256)
    # del real_256; gc.collect()
    # 
    # real_1024 = load_real_images_for_fid(celebahq_dataset, n_real, 1024, "Real (1024)")
    # real_feat_1024 = get_inception_features(real_1024, desc="Real (1024)")
    # kswgd_feat_1024 = get_inception_features(kswgd_upscaled, desc="KSWGD (1024)")
    # fid_kswgd_enhanced = calculate_fid(real_feat_1024, kswgd_feat_1024)
    # results['fid_kswgd_raw'], results['fid_kswgd_enhanced'] = fid_kswgd_raw, fid_kswgd_enhanced
    # print(f"‚úì KSWGD FID: Raw={fid_kswgd_raw:.2f}, Enhanced={fid_kswgd_enhanced:.2f}")
    fid_kswgd_raw, fid_kswgd_enhanced = -1.0, -1.0  # Placeholder for debug mode
    results['fid_kswgd_raw'], results['fid_kswgd_enhanced'] = fid_kswgd_raw, fid_kswgd_enhanced
    print("[10/12] KSWGD FID calculation SKIPPED (debug mode)")
    
    del kswgd_images_np, kswgd_upscaled; gc.collect(); torch.cuda.empty_cache()
    
    # Step 11: Run LDM
    print(f"\n[11/12] Running LDM...")
    # Reuse the pipeline loaded in step 2 (already on `device`, scaling factor
    # already set to 1.0) instead of loading a second copy of the model.
    ldm_pipe.set_progress_bar_config(disable=True)
    
    # Generate as many LDM images as there are KSWGD particles, 64 per batch.
    start_time = time.time()
    ldm_images = []
    for batch_idx in tqdm(range((config['kswgd_num_particles'] + 63) // 64), desc="LDM"):
        bs = min(64, config['kswgd_num_particles'] - batch_idx * 64)
        ldm_images.extend(ldm_pipe(batch_size=bs, num_inference_steps=200).images)
    ldm_time = time.time() - start_time
    results['ldm_time'] = ldm_time
    print(f"‚úì LDM complete! Time: {ldm_time:.1f}s")
    
    ldm_upscaled = upscale_images(ldm_images, face_enhancers, upsamplers, desc="LDM GFPGAN")
    
    n_grid_ldm = min(16, len(ldm_upscaled))
    print(f"  Saving LDM grid ({n_grid_ldm} of {len(ldm_upscaled)} images)...")
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    for idx, ax in enumerate(axes.flat):
        if idx < n_grid_ldm:
            ax.imshow(ldm_upscaled[idx])
            ax.set_title(f"LDM #{idx+1}")
        ax.axis('off')
    plt.suptitle(f"LDM Enhanced - Test {experiment_id}", fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(exp_dir, "ldm_grid.png"), dpi=150, bbox_inches='tight')
    plt.close()
    
    # Step 12: LDM FID (DISABLED FOR DEBUG)
    # print("\n[12/12] Computing LDM FID...")
    # real_256 = load_real_images_for_fid(celebahq_dataset, n_real, 256, "Real (256)")
    # real_feat_256 = get_inception_features(real_256, desc="Real (256)")
    # ldm_feat_256 = get_inception_features(ldm_images, desc="LDM (256)")
    # fid_ldm_raw = calculate_fid(real_feat_256, ldm_feat_256)
    # del real_256; gc.collect()
    # 
    # real_1024 = load_real_images_for_fid(celebahq_dataset, n_real, 1024, "Real (1024)")
    # real_feat_1024 = get_inception_features(real_1024, desc="Real (1024)")
    # ldm_feat_1024 = get_inception_features(ldm_upscaled, desc="LDM (1024)")
    # fid_ldm_enhanced = calculate_fid(real_feat_1024, ldm_feat_1024)
    # results['fid_ldm_raw'], results['fid_ldm_enhanced'] = fid_ldm_raw, fid_ldm_enhanced
    # print(f"‚úì LDM FID: Raw={fid_ldm_raw:.2f}, Enhanced={fid_ldm_enhanced:.2f}")
    fid_ldm_raw, fid_ldm_enhanced = -1.0, -1.0  # Placeholder for debug mode
    results['fid_ldm_raw'], results['fid_ldm_enhanced'] = fid_ldm_raw, fid_ldm_enhanced
    print("[12/12] LDM FID calculation SKIPPED (debug mode)")
    
    del ldm_images, ldm_upscaled; gc.collect(); torch.cuda.empty_cache()
    
    # Print & Save Results
    print("\n" + "=" * 80)
    print("EXPERIMENT RESULTS")
    print("=" * 80)
    print(f"{'Method':<25} {'Resolution':<15} {'FID':<10}")
    print("-" * 60)
    print(f"{'KSWGD':<25} {'256x256':<15} {fid_kswgd_raw:.2f}")
    print(f"{'KSWGD + GFPGAN':<25} {'1024x1024':<15} {fid_kswgd_enhanced:.2f}")
    print(f"{'LDM':<25} {'256x256':<15} {fid_ldm_raw:.2f}")
    print(f"{'LDM + GFPGAN':<25} {'1024x1024':<15} {fid_ldm_enhanced:.2f}")
    print("=" * 80)
    
    config_path = os.path.join(exp_dir, "experiment_config.txt")
    with open(config_path, 'w') as f_out:
        f_out.write(f"{'='*60}\nEXPERIMENT {experiment_id} CONFIG & RESULTS\n{'='*60}\n")
        f_out.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f_out.write("PARAMETERS:\n")
        for k, v in config.items():
            f_out.write(f"  {k}: {v}\n")
        f_out.write(f"\nRESULTS:\n")
        f_out.write(f"  KSWGD Raw FID: {fid_kswgd_raw:.4f}\n")
        f_out.write(f"  KSWGD Enhanced FID: {fid_kswgd_enhanced:.4f}\n")
        f_out.write(f"  LDM Raw FID: {fid_ldm_raw:.4f}\n")
        f_out.write(f"  LDM Enhanced FID: {fid_ldm_enhanced:.4f}\n")
        f_out.write(f"\nTIMING:\n")
        f_out.write(f"  KSWGD: {kswgd_time:.1f}s\n")
        f_out.write(f"  LDM: {ldm_time:.1f}s\n")
    print(f"‚úì Config saved to: {config_path}")

    return results


# Printed when this cell finishes executing, i.e. once the definitions above are in scope.
print("‚úì All functions defined!")

In [None]:
# ============== Cell 3: Parameter Grid and Iteration Loop ==============
import itertools
from datetime import datetime
import torch.multiprocessing as mp

# Echo the device and GPU index carried over from Cell 1 so a wrong
# selection is visible before the sweep starts.
print(
    f"üéØ Using device from Cell 1: {device}",
    f"   SELECTED_GPU = {SELECTED_GPU}",
    sep="\n",
)

# ==================== Define Parameter Grid ====================
# DEBUG_MODE selects the banner printed below. The constants that follow are
# assigned unconditionally, so flipping the flag does not by itself shrink
# the run; lower them by hand when a quick test is wanted.
DEBUG_MODE = False
if DEBUG_MODE:
    print("DEBUG MODE - parameters below are expected to be reduced for quick testing")
else:
    print("‚úì PRODUCTION MODE - Using full parameters")

# Values shared by every configuration in the sweep (the DEBUG_ prefix is
# historical; the names are kept so existing references keep working).
DEBUG_MAX_SAMPLES = 28000
DEBUG_NUM_PARTICLES = 16
DEBUG_MLP_EPOCHS = 300
DEBUG_K_EIG = 1000
DEBUG_N_DICT = 300
DEBUG_KSWGD_ITERS = 300

# ==================== Simplified Experiment Design ====================
# Swept: dt_edmd x kswgd_step_size (full Cartesian product).
# Fixed: reduced_dim = 8, mlp_hidden_dim = 512.

dt_edmd_values = [0.05, 0.1, 0.5]
kswgd_step_size_values = [0.05, 0.1, 0.5]

# One config dict per (dt_edmd, kswgd_step_size) pair; dt_edmd varies slowest,
# matching the order of the nested loops this replaces.
configs_to_run = []
for dt, step_size in itertools.product(dt_edmd_values, kswgd_step_size_values):
    configs_to_run.append({
        'dt_edmd': dt,
        'mlp_epochs': DEBUG_MLP_EPOCHS,
        'mlp_hidden_dim': 512,
        'reduced_dim': 8,
        'k_eig': DEBUG_K_EIG,
        'n_dict_components': DEBUG_N_DICT,
        'kswgd_num_particles': DEBUG_NUM_PARTICLES,
        'kswgd_num_iters': DEBUG_KSWGD_ITERS,
        'kswgd_step_size': step_size,
        'max_samples': DEBUG_MAX_SAMPLES,
    })

# The grid size is derived from the lists so the message cannot go stale
# when a value is added or removed.
print(
    f"üìã SELECTED CONFIGS: {len(configs_to_run)} experiments "
    f"({len(dt_edmd_values)} x {len(kswgd_step_size_values)} = {len(configs_to_run)})"
)
print(f"\nüîß Parameter settings:")
print(f"  max_samples: {DEBUG_MAX_SAMPLES}")
print(f"  num_particles: {DEBUG_NUM_PARTICLES}")
print(f"  mlp_epochs: {DEBUG_MLP_EPOCHS}")
print(f"  k_eig: {DEBUG_K_EIG}")
print(f"  n_dict_components: {DEBUG_N_DICT}")
print(f"  kswgd_num_iters: {DEBUG_KSWGD_ITERS}")
print(f"\nüìä Experiment design:")
print(f"  Fixed: reduced_dim=8, mlp_hidden_dim=512")
print(f"  Variable: dt_edmd={dt_edmd_values}")
print(f"  Variable: kswgd_step_size={kswgd_step_size_values}")

# ==================== Run Experiments ====================
import traceback

all_results = []
results_save_path = f"/workspace/kswgd/cache/iteration_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl"
# Create the cache directory up front so the first checkpoint cannot fail
# merely because the folder is missing.
os.makedirs(os.path.dirname(results_save_path), exist_ok=True)

# Sequential execution on the selected GPU
for i, config in enumerate(configs_to_run):
    print(f"\n{'#'*80}")
    print(f"# EXPERIMENT {i+1}/{len(configs_to_run)} on GPU {SELECTED_GPU}")
    # Show both swept parameters so runs sharing a dt_edmd can be told apart.
    print(f"#   dt_edmd={config['dt_edmd']}, kswgd_step_size={config['kswgd_step_size']}, reduced_dim={config['reduced_dim']}, mlp_hidden={config['mlp_hidden_dim']}")
    print(f"{'#'*80}")

    # Only the experiment itself is guarded: a checkpoint failure must not be
    # recorded as an experiment failure (which also appended the result twice).
    try:
        result = run_full_experiment(config, experiment_id=i+1)
    except Exception as e:
        print(f"\n‚úó Experiment failed with error: {e}")
        traceback.print_exc()
        # Keep the id on failure records so the summary can label and order them.
        result = {'experiment_id': i+1, 'config': config, 'error': str(e)}
    all_results.append(result)

    # Checkpoint after every experiment, successful or not, so a later crash
    # loses at most the run in progress.
    with open(results_save_path, 'wb') as f:
        pickle.dump(all_results, f)
    print(f"\n‚úì Results saved to {results_save_path}")

    # Memory cleanup between experiments
    gc.collect()
    torch.cuda.empty_cache()

# ==================== Summary ====================
# Order rows by experiment id; records without one sort first (key 0).
all_results_sorted = sorted(all_results, key=lambda x: x.get('experiment_id', 0))

print("\n" + "=" * 110)
print("FINAL SUMMARY - ALL EXPERIMENTS")
print("=" * 110)
# 'step' is kswgd_step_size: together with dt_edmd it is what the sweep
# varies, so it must be shown for rows to be distinguishable.
print(f"{'#':<3} {'dt_edmd':<8} {'step':<8} {'reduced':<8} {'hidden':<8} {'epochs':<8} {'KSWGD Raw':<12} {'KSWGD Enh':<12} {'LDM Raw':<12} {'LDM Enh':<12}")
print("-" * 110)

for result in all_results_sorted:
    exp_id = result.get('experiment_id', '?')
    if 'error' in result:
        # Failure records may carry only 'config' and 'error'; show the swept
        # parameters when available so the failing run can be identified.
        cfg = result.get('config') or {}
        dt_text = cfg.get('dt_edmd', '?')
        step_text = cfg.get('kswgd_step_size', '?')
        message = result['error']
        # Add the ellipsis only when the message was actually cut.
        shown = message if len(message) <= 60 else message[:60] + "..."
        print(f"{exp_id:<3} {dt_text:<8} {step_text:<8} ERROR: {shown}")
    else:
        cfg = result['config']
        print(f"{exp_id:<3} {cfg['dt_edmd']:<8.2f} {cfg['kswgd_step_size']:<8.2f} {cfg['reduced_dim']:<8} {cfg['mlp_hidden_dim']:<8} {cfg['mlp_epochs']:<8} {result['fid_kswgd_raw']:<12.2f} {result['fid_kswgd_enhanced']:<12.2f} {result['fid_ldm_raw']:<12.2f} {result['fid_ldm_enhanced']:<12.2f}")

print("=" * 110)
print(f"\nResults saved to: {results_save_path}")
# The reduced-parameter reminder only applies when the debug flag is set.
if DEBUG_MODE:
    print("\n‚ö†Ô∏è DEBUG MODE parameters are significantly reduced. Restore normal values after testing passes.")