In [None]:
"""
Comparing VAE, GAN, and Diffusion Models for Medical Image Synthesis
"""

import sys
import os
import random
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils import spectral_norm
from torchvision import transforms, datasets, models
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from skimage.metrics import structural_similarity as ssim
from scipy.linalg import sqrtm
from einops import rearrange

# ==============================================================================
# IMPORTANT: Add PathLDM to path BEFORE any ldm imports
# The user needs to set this path correctly.
# ==============================================================================
# Hard-coded, machine-specific location of the PathLDM checkout; edit for your environment.
PATH_LDM_PROJECT_ROOT = '/home/zihend1/Disentanglement/PathLDM'
# Prepend (index 0) so the `ldm` imports that follow can be resolved from this checkout;
# the membership test avoids adding a duplicate entry when the cell is re-run.
if PATH_LDM_PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PATH_LDM_PROJECT_ROOT)

# Diffusion model imports (must be after sys.path update)
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config

# NOTE: installs a blanket "ignore" filter, i.e. every warning category is silenced
# process-wide from this point on (including deprecation notices from imported libraries).
warnings.filterwarnings("ignore")

# ==============================================================================
# Centrally managed configuration
# ==============================================================================
class Config:
    """Class-level constants shared by every stage of this script.

    The class is used as a namespace only: values are read as ``Config.NAME``
    and several of them are captured as default arguments of the model
    constructors defined below, so they must be set before those classes are
    defined.
    """
    # --- System ---
    # First CUDA device when one is visible, otherwise CPU.
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Seed handed to set_seed() at module level.
    SEED = 42
    
    # --- Paths ---
    # Diffusion Model: config and checkpoint are resolved relative to the PathLDM checkout.
    DIFFUSION_CONFIG_PATH = Path(PATH_LDM_PROJECT_ROOT) / "plip_imagenet_finetune/configs/08-03T09-35-project.yaml"
    DIFFUSION_CKPT_PATH = Path(PATH_LDM_PROJECT_ROOT) / "plip_imagenet_finetune/checkpoints/epoch_3.ckpt"
    
    # VAE/GAN Training Data
    # The two .npy paths are relative, i.e. resolved against the current working directory.
    UNI_FEATURES_PATH = "uni_features.npy"
    UNI_LABELS_PATH = "uni_labels.npy"
    # Root folder handed to torchvision's ImageFolder (machine-specific absolute path).
    REAL_IMAGES_PATH = '/extra/zhanglab0/INDV/zihend1/Disentanglement/ICIAR2018_BACH_Challenge/Photos'

    # --- Training Hyperparameters ---
    BATCH_SIZE = 16  # used for both data loaders and for batched metric computation
    VAE_EPOCHS = 50
    GAN_EPOCHS = 100
    VAE_LR = 1e-4
    GAN_G_LR = 1e-4  # generator learning rate
    GAN_D_LR = 2e-5  # discriminator learning rate
    BETA = 4.0 # Weight of the KL term in the VAE loss (Beta-VAE)
    
    # --- Model & Generation Parameters ---
    LATENT_DIM = 32  # size of the latent vector z for both VAE and GAN
    CONDITION_DIM = 4  # number of classes; also the width of the one-hot condition vector
    IMG_CHANNELS = 3
    IMG_SIZE = 224 # Size real images are resized to; diffusion samples are upsampled to it as well
    # Default spatial size of the shape passed to the DDIM sampler.
    # NOTE(review): described in the original as the diffusion model's "native size" - confirm against the checkpoint.
    DIFFUSION_GEN_SIZE = 64
    NUM_SAMPLES_PER_CLASS = 4  # images generated per class and per model in main()

def set_seed(seed):
    """Seed every RNG this script relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU, plus all CUDA devices when CUDA is available).
    """
    # CPU-side generators share the same call shape, so seed them in one pass.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels on, autotuner off: trades speed for repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed at module level, before any model below is instantiated or any data is sampled.
set_seed(Config.SEED)

# ==============================================================================
# Model Definitions (VAE, GAN)
# ==============================================================================

class UNIBetaVAE(nn.Module):
    """Conditional Beta-VAE that decodes images from a precomputed embedding.

    The encoder is a pair of linear heads (mean and log-variance) applied to an
    embedding vector; the decoder maps ``[z, condition]`` through a linear layer
    and four stride-2 transposed convolutions to an image in ``[0, 1]``.

    Args:
        uni_emb_dim: Width of the input embedding vector.
        latent_dim: Width of the latent vector ``z``.
        condition_dim: Width of the condition vector concatenated to ``z``.
        img_channels: Number of channels of the decoded image.
        img_size: Side length of the decoded (square) image. Must be a positive
            multiple of 16 because the decoder doubles the resolution four times.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """
    def __init__(self, uni_emb_dim=1536, latent_dim=Config.LATENT_DIM, condition_dim=Config.CONDITION_DIM, img_channels=Config.IMG_CHANNELS, img_size=Config.IMG_SIZE):
        super().__init__()
        # Four stride-2 upsampling stages => the decoder starts at img_size / 2**4.
        # (Previously the start size was hard-coded to 14, silently ignoring img_size.)
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 16, got {img_size}"
            )
        start_size = img_size // 16
        self.img_size = img_size

        # Encoder
        self.fc_mu = nn.Linear(uni_emb_dim, latent_dim)
        self.fc_logvar = nn.Linear(uni_emb_dim, latent_dim)

        # Decoder (layer names and order are kept stable so existing state_dicts still load)
        self.decoder_input = nn.Linear(latent_dim + condition_dim, 1024 * start_size * start_size)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (1024, start_size, start_size)),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # x4
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # x8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # x16 -> img_size
            nn.ReLU(),
            nn.Conv2d(64, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Output in [0, 1]
        )

    def encode(self, uni_emb):
        """Return ``(mu, logvar)`` of the approximate posterior for ``uni_emb``."""
        mu = self.fc_mu(uni_emb)
        logvar = self.fc_logvar(uni_emb)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, cond):
        """Decode latent ``z`` and condition ``cond`` (concatenated on dim 1) into an image batch."""
        z_cond = torch.cat([z, cond], dim=1)
        return self.decoder(self.decoder_input(z_cond))

    def forward(self, uni_emb, cond):
        """Encode, sample and decode; returns ``(recon_img, mu, logvar)``."""
        mu, logvar = self.encode(uni_emb)
        z = self.reparameterize(mu, logvar)
        recon_img = self.decode(z, cond)
        return recon_img, mu, logvar

class Generator(nn.Module):
    """Conditional GAN generator mirroring the VAE decoder layout.

    A linear projection of ``[z, condition]`` is reshaped to a 1024x14x14 map,
    upsampled by four stride-2 transposed convolutions and squashed with Tanh,
    so outputs lie in ``[-1, 1]``.
    """
    def __init__(self, latent_dim=Config.LATENT_DIM, condition_dim=Config.CONDITION_DIM, img_channels=Config.IMG_CHANNELS):
        super().__init__()
        base = 14  # spatial size before upsampling
        widths = (1024, 512, 256, 128, 64)

        # Layers are appended in the same order as a literal Sequential would list
        # them, so parameter names (model.0, model.2, ...) stay the same.
        stages = [
            nn.Linear(latent_dim + condition_dim, widths[0] * base * base),
            nn.Unflatten(1, (widths[0], base, base)),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        stages.append(nn.Conv2d(widths[-1], img_channels, kernel_size=3, stride=1, padding=1))
        stages.append(nn.Tanh())  # Output in [-1, 1]

        self.model = nn.Sequential(*stages)

    def forward(self, z, cond):
        """Generate an image batch from latent ``z`` and condition ``cond``."""
        return self.model(torch.cat((z, cond), dim=1))

class Discriminator(nn.Module):
    """Conditional GAN discriminator built from spectrally-normalised convolutions.

    The condition vector is broadcast to a per-pixel map and stacked onto the
    image channels; four stride-2 convolutions then reduce the resolution by 16
    before a linear + sigmoid head produces one probability per sample.
    """
    def __init__(self, img_channels=Config.IMG_CHANNELS, condition_dim=Config.CONDITION_DIM, img_size=Config.IMG_SIZE):
        super().__init__()
        self.img_size = img_size

        channel_plan = (img_channels + condition_dim, 64, 128, 256, 512)
        blocks = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            # Each stage halves the spatial resolution.
            blocks.append(spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)))
            blocks.append(nn.LeakyReLU(0.2))
        blocks.append(nn.Flatten())
        self.model = nn.Sequential(*blocks)

        # Four halvings => feature map side is img_size // 16.
        side = img_size // (2 ** 4)
        self.output_layer = nn.Sequential(
            nn.Linear(channel_plan[-1] * side * side, 1),
            nn.Sigmoid(),  # Probability output
        )

    def forward(self, img, cond):
        """Return the per-sample probability for ``img`` given ``cond``."""
        n, c = cond.size(0), cond.size(1)
        # Tile the condition vector over the spatial grid so it can be stacked as extra channels.
        cond_map = cond.view(n, c, 1, 1).expand(n, c, self.img_size, self.img_size)
        stacked = torch.cat((img, cond_map), dim=1)
        return self.output_layer(self.model(stacked))
        
# ==============================================================================
# Diffusion Model Wrapper
# ==============================================================================
class DiffusionModelWrapper:
    """Thin wrapper around a pre-trained latent diffusion model and a DDIM sampler.

    Construction never raises for a missing or broken model: on any failure the
    instance is left with ``available == False`` and ``generate_images`` returns
    ``(None, None)``.

    Attributes:
        device: Device the model is moved to.
        available: True only after config, checkpoint and sampler were set up.
        model: The instantiated diffusion model, or None.
        sampler: ``DDIMSampler`` bound to ``model``, or None.
        label_text: Mapping from integer class label to the text prompt used
            as conditioning.
    """
    def __init__(self, config_path, ckpt_path, device):
        """Load the model described by ``config_path`` with weights from ``ckpt_path``.

        Args:
            config_path: ``pathlib.Path`` to the OmegaConf YAML config.
            ckpt_path: ``pathlib.Path`` to the checkpoint file.
            device: Device to place the model on.
        """
        self.device = device
        self.available = False
        self.model = None
        self.sampler = None
        self.label_text = {
            0: "Normal tissue",
            1: "Benign tumor tissue",
            2: "In-situ carcinoma",
            3: "Invasive carcinoma"
        }

        if not config_path.exists() or not ckpt_path.exists():
            print("Warning: Diffusion model paths not found.")
            print(f"  Config path: {config_path}")
            print(f"  Checkpoint path: {ckpt_path}")
            print("Diffusion model will be unavailable.")
            return

        # Broad catch is deliberate: the comparison should still run for the
        # other models when the diffusion model cannot be set up.
        try:
            print(f"Loading diffusion model from {ckpt_path}...")
            config = OmegaConf.load(config_path)

            # Drop nested ckpt_path entries so sub-modules do not try to load
            # their own checkpoints during instantiation.
            config.model.params.first_stage_config.params.pop('ckpt_path', None)
            config.model.params.unet_config.params.pop('ckpt_path', None)

            # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
            pl_sd = torch.load(ckpt_path, map_location="cpu")
            # Lightning checkpoints nest the weights under "state_dict"; plain ones do not.
            sd = pl_sd.get("state_dict", pl_sd)

            self.model = instantiate_from_config(config.model)
            # strict=False tolerates key differences, so report them instead of
            # silently sampling from a partially initialised model.
            incompatible = self.model.load_state_dict(sd, strict=False)
            missing = list(getattr(incompatible, "missing_keys", []))
            unexpected = list(getattr(incompatible, "unexpected_keys", []))
            if missing or unexpected:
                print(
                    f"Warning: checkpoint does not fully match the model "
                    f"({len(missing)} missing, {len(unexpected)} unexpected keys)."
                )
                if missing:
                    print(f"  First missing keys: {missing[:5]}")
                if unexpected:
                    print(f"  First unexpected keys: {unexpected[:5]}")
            self.model.to(device)
            self.model.eval()

            self.sampler = DDIMSampler(self.model)
            self.available = True
            print("Diffusion model loaded successfully!")

        except Exception as e:
            print(f"Error loading diffusion model: {e}")
            print("Diffusion model will be unavailable.")
            self.available = False

    def generate_images(self, labels, num_samples_per_class, steps=50, scale=1.5, image_size=Config.DIFFUSION_GEN_SIZE):
        """Generate images using the diffusion model.

        Args:
            labels: Iterable of integer class labels (keys of ``label_text``).
            num_samples_per_class: Number of images to sample for each label.
            steps: Number of DDIM sampling steps.
            scale: Classifier-free guidance scale passed to the sampler.
            image_size: Spatial size of the shape handed to the sampler.

        Returns:
            ``(images, expanded_labels)`` where ``images`` is a tensor clamped to
            ``[0, 1]`` and resized to ``Config.IMG_SIZE``, and ``expanded_labels``
            lists the label of each image in order; ``(None, None)`` when the
            model is unavailable or nothing was requested.

        Raises:
            KeyError: If a label has no entry in ``label_text``.
        """
        if not self.available:
            print("Cannot generate images: Diffusion model is not available.")
            return None, None

        expanded_labels = [label for label in labels for _ in range(num_samples_per_class)]
        batch_size = len(expanded_labels)
        if batch_size == 0:
            # Nothing requested; avoid calling the sampler with an empty batch.
            print("Cannot generate images: no labels/samples were requested.")
            return None, None
        prompts = [self.label_text[i] for i in expanded_labels]

        # NOTE(review): this is the shape the sampler denoises; it reuses
        # Config.IMG_CHANNELS as channel count - confirm it matches the checkpoint.
        shape = [Config.IMG_CHANNELS, image_size, image_size]

        with torch.no_grad():
            # Empty prompts give the unconditional branch for classifier-free guidance.
            uc = self.model.get_learned_conditioning([""] * batch_size)
            cc = self.model.get_learned_conditioning(prompts)

            # Sample from the model (eta=0.0 => deterministic DDIM)
            samples_ddim, _ = self.sampler.sample(
                S=steps,
                batch_size=batch_size,
                shape=shape,
                conditioning=cc,
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                eta=0.0
            )

            # Decode to pixel space and map from [-1, 1] to [0, 1]
            x_samples = self.model.decode_first_stage(samples_ddim)
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            # Resize to the VAE/GAN output size so all models are compared at one resolution
            x_samples = F.interpolate(x_samples, size=(Config.IMG_SIZE, Config.IMG_SIZE), mode='bilinear', align_corners=False)

        return x_samples, expanded_labels

# ==============================================================================
# Training and Data Functions
# ==============================================================================

def get_dataloaders(device):
    """Build the two training loaders, substituting random data when files are missing.

    Args:
        device: Device the embedding/condition tensors (and any synthetic
            images) are moved to.

    Returns:
        ``(uni_loader, loader_img)``: a loader of ``(embedding, one-hot label)``
        pairs and a loader of ``(image, label)`` pairs, both shuffled and using
        ``Config.BATCH_SIZE``.
    """
    n_classes = Config.CONDITION_DIM
    batch = Config.BATCH_SIZE

    # ---- Embedding / condition stream ----
    try:
        features = torch.tensor(np.load(Config.UNI_FEATURES_PATH)).float()
        class_ids = torch.tensor(np.load(Config.UNI_LABELS_PATH)).long()
        onehot = F.one_hot(class_ids, num_classes=n_classes).float()
        print(f"Successfully loaded UNI features ({features.shape}) and labels ({class_ids.shape}).")
        features = features.to(device)
        onehot = onehot.to(device)
    except FileNotFoundError:
        # Only a missing file triggers the fallback; any other error propagates.
        print("Warning: UNI feature/label files not found. Creating synthetic data for VAE/GAN training.")
        features = torch.randn(1000, 1536).to(device)
        class_ids = torch.randint(0, n_classes, (1000,)).to(device)
        onehot = F.one_hot(class_ids, num_classes=n_classes).float().to(device)
    uni_loader = DataLoader(TensorDataset(features, onehot), batch_size=batch, shuffle=True)

    # ---- Image stream ----
    try:
        to_model_input = transforms.Compose([
            transforms.Resize((Config.IMG_SIZE, Config.IMG_SIZE)),
            transforms.ToTensor(),
        ])
        folder = datasets.ImageFolder(root=Config.REAL_IMAGES_PATH, transform=to_model_input)
        loader_img = DataLoader(folder, batch_size=batch, shuffle=True)
        print(f"Successfully loaded real images from {Config.REAL_IMAGES_PATH}.")
    except Exception as e:
        # Any failure (missing folder, empty folder, ...) falls back to random images.
        print(f"Warning: Could not load real images from path. Reason: {e}")
        print("Creating synthetic images for VAE/GAN training.")
        fake_pixels = torch.rand(1000, Config.IMG_CHANNELS, Config.IMG_SIZE, Config.IMG_SIZE).to(device)
        placeholder_labels = torch.zeros(1000)  # Dummy labels
        loader_img = DataLoader(
            TensorDataset(fake_pixels, placeholder_labels),
            batch_size=batch,
            shuffle=True,
        )

    return uni_loader, loader_img

def train_vae(uni_loader, loader_img, device):
    """Train a ``UNIBetaVAE`` and return it.

    The loss per batch is ``mse(recon, image) + Config.BETA * KL``, where KL is
    summed over latent dimensions and averaged over the batch.

    Args:
        uni_loader: Loader yielding ``(embedding, condition)`` batches.
        loader_img: Loader yielding ``(image, label)`` batches; the label is unused.
        device: Device used for the model and the batches.

    Returns:
        The trained ``UNIBetaVAE``.

    Raises:
        ValueError: If the loaders produce no batches in an epoch.
    """
    print(f"--- Training VAE for {Config.VAE_EPOCHS} epochs ---")
    model = UNIBetaVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.VAE_LR)

    for epoch in range(Config.VAE_EPOCHS):
        model.train()
        total_loss, batch_count = 0.0, 0

        # NOTE(review): the two loaders are stepped in lockstep but, as built by
        # get_dataloaders, each shuffles on its own, so an embedding/condition is
        # paired with an arbitrary image - confirm this is intended.
        for (uni_batch, cond_batch), (img_batch, _) in zip(uni_loader, loader_img):
            # The final batches of the two loaders can differ in size; trim to
            # the common length so the reconstruction loss shapes always agree.
            n = min(uni_batch.size(0), img_batch.size(0))
            uni_batch = uni_batch[:n].to(device)
            cond_batch = cond_batch[:n].to(device)
            img_batch = img_batch[:n].to(device)

            recon_img, mu, logvar = model(uni_batch, cond_batch)

            recon_loss = F.mse_loss(recon_img, img_batch)
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_div = kl_div / n  # Per-sample KL
            loss = recon_loss + Config.BETA * kl_div

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        if batch_count == 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError("train_vae received no batches: uni_loader or loader_img is empty.")

        avg_loss = total_loss / batch_count
        if (epoch + 1) % 10 == 0:
            print(f"VAE Epoch {epoch+1}/{Config.VAE_EPOCHS}, Average Loss: {avg_loss:.4f}")

    return model

def train_gan(uni_loader, loader_img, device):
    """Train the conditional GAN and return ``(generator, discriminator)``.

    Uses BCE loss with smoothed targets (0.9 for real, 0.1 for fake) and one
    discriminator step followed by one generator step per batch.

    Args:
        uni_loader: Loader yielding ``(embedding, condition)`` batches; only the
            condition is used here.
        loader_img: Loader yielding ``(image, label)`` batches with images in
            ``[0, 1]``; the label is unused.
        device: Device used for the models and the batches.

    Returns:
        Tuple ``(gen, disc)`` of the trained ``Generator`` and ``Discriminator``.

    Raises:
        ValueError: If the loaders produce no batches in an epoch.
    """
    print(f"--- Training GAN for {Config.GAN_EPOCHS} epochs ---")
    gen = Generator().to(device)
    disc = Discriminator().to(device)

    opt_g = optim.Adam(gen.parameters(), lr=Config.GAN_G_LR, betas=(0.5, 0.999))
    opt_d = optim.Adam(disc.parameters(), lr=Config.GAN_D_LR, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    for epoch in range(Config.GAN_EPOCHS):
        epoch_g_loss, epoch_d_loss, batch_count = 0.0, 0.0, 0

        # NOTE(review): conditions come from uni_loader and images from loader_img,
        # which (as built by get_dataloaders) shuffle independently - confirm the
        # condition is meant to be unrelated to the real image's own class.
        for (_, cond_batch), (real_img, _) in zip(uni_loader, loader_img):
            # Final batches of the two loaders can differ in size; trim to the
            # common length so image and condition batches always line up.
            batch_size = min(cond_batch.size(0), real_img.size(0))
            cond_batch = cond_batch[:batch_size].to(device)
            real_img = real_img[:batch_size].to(device) * 2.0 - 1.0  # [0, 1] -> [-1, 1] to match Tanh

            real_labels = torch.full((batch_size, 1), 0.9, device=device)  # Label smoothing
            fake_labels = torch.full((batch_size, 1), 0.1, device=device)  # Label smoothing

            # --- Train Discriminator ---
            opt_d.zero_grad()
            pred_real = disc(real_img, cond_batch)
            loss_d_real = criterion(pred_real, real_labels)

            z = torch.randn(batch_size, Config.LATENT_DIM).to(device)
            fake_img = gen(z, cond_batch)
            # detach: the discriminator step must not push gradients into the generator
            pred_fake = disc(fake_img.detach(), cond_batch)
            loss_d_fake = criterion(pred_fake, fake_labels)

            loss_d = (loss_d_real + loss_d_fake) / 2
            loss_d.backward()
            opt_d.step()

            # --- Train Generator ---
            opt_g.zero_grad()
            pred_g = disc(fake_img, cond_batch)
            loss_g = criterion(pred_g, real_labels)  # Generator wants fakes scored as real
            loss_g.backward()
            opt_g.step()

            epoch_g_loss += loss_g.item()
            epoch_d_loss += loss_d.item()
            batch_count += 1

        if batch_count == 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError("train_gan received no batches: uni_loader or loader_img is empty.")

        avg_g_loss = epoch_g_loss / batch_count
        avg_d_loss = epoch_d_loss / batch_count
        if (epoch + 1) % 20 == 0:
            print(f"GAN Epoch [{epoch+1}/{Config.GAN_EPOCHS}] | D_loss: {avg_d_loss:.4f} | G_loss: {avg_g_loss:.4f}")

    return gen, disc


# ==============================================================================
# Image Quality Evaluation
# ==============================================================================
class ImageQualityEvaluator:
    """Compute Inception Score, FID and SSIM for batches of images in ``[0, 1]``.

    A single pre-trained Inception-v3 is shared by IS and FID: ``self.inception``
    returns the pooled feature vector (its ``fc`` is replaced by Identity) and
    ``self.classifier`` is the original ``fc`` head, applied on top of those
    features to obtain class logits for the Inception Score.
    """

    def __init__(self, device='cuda'):
        """Load Inception-v3 onto ``device`` in eval mode."""
        self.device = device
        self.inception = models.inception_v3(pretrained=True, transform_input=False)
        # Keep the classification head: IS needs class probabilities, whereas FID
        # needs the pooled features that feed it.
        self.classifier = self.inception.fc
        self.inception.fc = nn.Identity()
        self.inception.to(device).eval()
        self.classifier.to(device).eval()

    def _prepare_batch(self, batch):
        """Move a batch to the device, resize to 299x299 and normalise it.

        Shared by IS and FID so both metrics see identically preprocessed inputs.
        """
        # NOTE(review): ImageNet mean/std is used together with transform_input=False;
        # confirm this is the input convention the pretrained weights expect.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        batch = batch.to(self.device)
        batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
        return normalize(batch)

    @staticmethod
    def _inception_score_from_probs(probs, splits=10, eps=1e-12):
        """Return ``(mean, std)`` of the Inception Score over splits of ``probs``.

        Per split the score is ``exp(mean_x KL(p(y|x) || p(y)))`` with ``p(y)``
        the mean of the split's rows. The number of splits is reduced so every
        split holds at least two rows; a single-row split always scores 1.0.

        Args:
            probs: Array of shape (num_images, num_classes) of class probabilities.
            splits: Requested number of splits.
            eps: Added inside the logarithms to avoid ``log(0)``.

        Raises:
            ValueError: If ``probs`` has no rows.
        """
        probs = np.asarray(probs, dtype=np.float64)
        n_images = probs.shape[0]
        if n_images == 0:
            raise ValueError("Inception Score needs at least one image.")
        n_splits = max(1, min(int(splits), n_images // 2))
        scores = []
        for part in np.array_split(probs, n_splits):
            marginal = part.mean(axis=0, keepdims=True)
            kl = (part * (np.log(part + eps) - np.log(marginal + eps))).sum(axis=1)
            scores.append(np.exp(kl.mean()))
        return np.mean(scores), np.std(scores)

    def calculate_fid(self, real_features, fake_features, eps=1e-6):
        """Return the Frechet distance between two sets of feature vectors.

        Args:
            real_features: Array of shape (num_real, dim).
            fake_features: Array of shape (num_fake, dim).
            eps: Diagonal offset added to both covariances when the matrix
                square root of their product is not finite.

        Raises:
            ValueError: If either set has fewer than two samples (covariance undefined).
        """
        real = np.atleast_2d(np.asarray(real_features, dtype=np.float64))
        fake = np.atleast_2d(np.asarray(fake_features, dtype=np.float64))
        if real.shape[0] < 2 or fake.shape[0] < 2:
            raise ValueError("FID needs at least two samples in each feature set.")

        mu1, sigma1 = real.mean(axis=0), np.atleast_2d(np.cov(real, rowvar=False))
        mu2, sigma2 = fake.mean(axis=0), np.atleast_2d(np.cov(fake, rowvar=False))
        diff = mu1 - mu2

        covmean = sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            # Few samples relative to the feature size make the product singular;
            # regularise the diagonals and retry.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
        if np.iscomplexobj(covmean):
            # Numerical error can leave a small imaginary part; keep the real part.
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
        return float(fid)

    def calculate_is(self, images, splits=10):
        """Return ``(mean, std)`` of the Inception Score for ``images``.

        Args:
            images: Tensor of images in ``[0, 1]``, batch on dim 0.
            splits: Requested number of splits (reduced for small batches).

        Raises:
            ValueError: If ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("Inception Score needs at least one image.")
        probs = []
        with torch.no_grad():
            for start in range(0, len(images), Config.BATCH_SIZE):
                batch = self._prepare_batch(images[start:start + Config.BATCH_SIZE])
                # Features -> class logits via the saved head.
                logits = self.classifier(self.inception(batch))
                probs.append(F.softmax(logits, dim=1).cpu().numpy())
        return self._inception_score_from_probs(np.concatenate(probs, 0), splits)

    def calculate_ssim(self, real_images, fake_images):
        """Return the mean grayscale SSIM over index-aligned image pairs.

        Both inputs are iterated in lockstep; each image is expected as a
        channel-first RGB tensor in ``[0, 1]``.

        Raises:
            ValueError: If there are no pairs to compare.
        """
        ssim_scores = []
        for real, fake in zip(real_images, fake_images):
            real_np = real.detach().cpu().numpy().transpose(1, 2, 0)
            fake_np = fake.detach().cpu().numpy().transpose(1, 2, 0)
            # Clip before the uint8 cast so out-of-range values cannot wrap around.
            real_u8 = (np.clip(real_np, 0, 1) * 255).astype(np.uint8)
            fake_u8 = (np.clip(fake_np, 0, 1) * 255).astype(np.uint8)
            real_gray = cv2.cvtColor(real_u8, cv2.COLOR_RGB2GRAY)
            fake_gray = cv2.cvtColor(fake_u8, cv2.COLOR_RGB2GRAY)
            ssim_scores.append(ssim(real_gray, fake_gray, data_range=255))
        if not ssim_scores:
            raise ValueError("SSIM needs at least one image pair.")
        return np.mean(ssim_scores)

    def extract_features(self, images):
        """Return pooled Inception features for ``images`` as a NumPy array."""
        features = []
        with torch.no_grad():
            for start in range(0, len(images), Config.BATCH_SIZE):
                batch = self._prepare_batch(images[start:start + Config.BATCH_SIZE])
                features.append(self.inception(batch).cpu().numpy())
        return np.concatenate(features, 0)

    def evaluate(self, generated_images, real_images=None):
        """Compute all available metrics, substituting ``'N/A'`` for any that fail.

        Args:
            generated_images: Tensor of generated images.
            real_images: Optional tensor of reference images; FID and SSIM are
                only computed when it is given.

        Returns:
            Dict with keys ``IS_mean`` and ``IS_std``, plus ``FID`` and ``SSIM``
            when ``real_images`` is provided.
        """
        results = {}
        generated_images = torch.clamp(generated_images, 0, 1)
        # Each metric is best-effort: a failure is reported and recorded as 'N/A'
        # so the remaining metrics still run.
        try:
            is_mean, is_std = self.calculate_is(generated_images)
            results['IS_mean'], results['IS_std'] = is_mean, is_std
        except Exception as e:
            print(f"Could not calculate IS: {e}")
            results['IS_mean'], results['IS_std'] = 'N/A', 'N/A'

        if real_images is not None:
            real_images = torch.clamp(real_images, 0, 1)
            try:
                real_features = self.extract_features(real_images)
                fake_features = self.extract_features(generated_images)
                results['FID'] = self.calculate_fid(real_features, fake_features)
            except Exception as e:
                print(f"Could not calculate FID: {e}")
                results['FID'] = 'N/A'
            try:
                # Pairs are formed by index, so compare only the common prefix.
                min_len = min(len(real_images), len(generated_images))
                results['SSIM'] = self.calculate_ssim(real_images[:min_len], generated_images[:min_len])
            except Exception as e:
                print(f"Could not calculate SSIM: {e}")
                results['SSIM'] = 'N/A'
        return results

# ==============================================================================
# Visualization Functions
# ==============================================================================
def visualize_comparison(image_dict):
    """Plot one example image per class (columns) for each method (rows).

    Args:
        image_dict: Mapping ``method name -> {'images': tensor or None,
            'labels': sequence of int class labels or None}``. For each class
            the first image carrying that label is shown; cells without an
            image display ``N/A``.
    """
    label_text = {0: "Normal", 1: "Benign", 2: "In-situ", 3: "Invasive"}
    num_methods = len(image_dict)
    num_classes = Config.CONDITION_DIM

    if num_methods == 0:
        print("No generated images to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(
        num_methods, num_classes, figsize=(12, 3 * num_methods), squeeze=False
    )
    fig.suptitle('Pathology Generation: Model Comparison', fontsize=16)

    for row, (method, data) in enumerate(image_dict.items()):
        images, labels = data['images'], data['labels']
        # Accept lists as well as tensors/arrays of labels.
        class_labels = [] if labels is None else [int(label) for label in labels]

        axes[row, 0].set_ylabel(method, fontsize=14, rotation=90, labelpad=20)

        for col in range(num_classes):
            ax = axes[row, col]
            if row == 0:
                ax.set_title(label_text.get(col, f"Class {col}"))

            # Hide ticks and frame rather than calling axis('off'), which would
            # also hide the method name set as the y-label above.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            idx = class_labels.index(col) if col in class_labels else None
            if images is None or idx is None or idx >= len(images):
                # No image for this method/class combination.
                ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
                continue

            img = images[idx].detach().cpu().permute(1, 2, 0).numpy()
            ax.imshow(np.clip(img, 0, 1))

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# ==============================================================================
# Main Pipeline
# ==============================================================================
def main():
    """Run the whole comparison: load data, train VAE/GAN, sample all models, plot and score."""
    device = Config.DEVICE
    class_count = Config.CONDITION_DIM
    per_class = Config.NUM_SAMPLES_PER_CLASS
    banner = "=" * 70

    print(banner)
    print("PATHOLOGY IMAGE GENERATION COMPARISON: VAE vs GAN vs Diffusion")
    print(f"Running on device: {Config.DEVICE}")
    print(banner)

    # --- Data ---
    print("\n[1/6] Loading data...")
    uni_loader, loader_img = get_dataloaders(device)
    if not (uni_loader and loader_img):
        print("Fatal: Could not load or create data. Exiting.")
        return

    # --- Training ---
    print("\n[2/6] Training VAE model...")
    vae_model = train_vae(uni_loader, loader_img, device)

    print("\n[3/6] Training GAN model...")
    gan_generator = train_gan(uni_loader, loader_img, device)[0]

    # --- Pre-trained diffusion model ---
    print("\n[4/6] Loading pre-trained Diffusion model...")
    diffusion_wrapper = DiffusionModelWrapper(
        Config.DIFFUSION_CONFIG_PATH,
        Config.DIFFUSION_CKPT_PATH,
        device,
    )

    # --- Sampling ---
    print("\n[5/6] Generating images from all models...")
    all_generated_images = {}

    # One latent batch and one label layout (class-major) shared by VAE and GAN
    # so their outputs are directly comparable.
    vae_model.eval()
    with torch.no_grad():
        z = torch.randn(per_class * class_count, Config.LATENT_DIM).to(device)
        labels = torch.arange(class_count).repeat_interleave(per_class)
        condition = F.one_hot(labels, num_classes=class_count).float().to(device)
        vae_images = vae_model.decode(z, condition)
    all_generated_images['VAE'] = {'images': vae_images, 'labels': labels.tolist()}
    print(f"Generated {len(vae_images)} images from VAE.")

    gan_generator.eval()
    with torch.no_grad():
        # Tanh output lives in [-1, 1]; shift and scale into [0, 1].
        gan_images = (gan_generator(z, condition) + 1.0) / 2.0
    all_generated_images['GAN'] = {'images': gan_images, 'labels': labels.tolist()}
    print(f"Generated {len(gan_images)} images from GAN.")

    diffusion_images, diffusion_labels = diffusion_wrapper.generate_images(
        labels=list(range(class_count)),
        num_samples_per_class=per_class,
    )
    all_generated_images['Diffusion'] = {'images': diffusion_images, 'labels': diffusion_labels}
    if diffusion_images is not None:
        print(f"Generated {len(diffusion_images)} images from Diffusion model.")

    # --- Visualisation and metrics ---
    print("\n[6/6] Visualizing results and running evaluation...")
    visualize_comparison(all_generated_images)

    # Collect whole batches of real images until roughly 100 have been gathered.
    reference_batches = []
    for img_batch, _ in loader_img:
        reference_batches.append(img_batch)
        if len(reference_batches) * Config.BATCH_SIZE >= 100:
            break
    if reference_batches:
        real_images_for_eval = torch.cat(reference_batches, dim=0)
    else:
        real_images_for_eval = None

    evaluator = ImageQualityEvaluator(device)

    for method, data in all_generated_images.items():
        print(f"\n--- Evaluating {method} ---")
        generated = data['images']
        if generated is None:
            print("No images to evaluate.")
            continue

        results = evaluator.evaluate(generated, real_images_for_eval)
        print(f"  Inception Score (IS): {results.get('IS_mean', 'N/A')}")
        print(f"  Fréchet Inception Distance (FID): {results.get('FID', 'N/A')}")
        print(f"  Structural Similarity (SSIM): {results.get('SSIM', 'N/A')}")

# Run the pipeline only when this module is the entry point, not when it is imported.
if __name__ == '__main__':
    main()

PATHOLOGY IMAGE GENERATION COMPARISON: VAE vs GAN vs Diffusion
Running on device: cuda:0

[1/6] Loading data...
Successfully loaded UNI features (torch.Size([400, 1536])) and labels (torch.Size([400])).
Successfully loaded real images from /extra/zhanglab0/INDV/zihend1/Disentanglement/ICIAR2018_BACH_Challenge/Photos.

[2/6] Training VAE model...
--- Training VAE for 50 epochs ---
VAE Epoch 10/50, Average Loss: 0.6011
VAE Epoch 20/50, Average Loss: 0.2787
VAE Epoch 30/50, Average Loss: 0.1596
VAE Epoch 40/50, Average Loss: 0.1029
VAE Epoch 50/50, Average Loss: 0.0720

[3/6] Training GAN model...
--- Training GAN for 100 epochs ---
