# Data Loading and Preprocessing

In [6]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt

# Step 1: Custom Dataset Class for CelebA
class CelebADataset(Dataset):
    def __init__(self, img_dir, attr_file, transform=None, selected_attrs=None):
        """
        Custom dataset for CelebA that combines images with attributes
        
        Args:
            img_dir: Directory containing images
            attr_file: Path to attributes CSV file (usually list_attr_celeba.csv)
            transform: Image transformations
            selected_attrs: List of attribute names to use (None = use all)
        """
        self.img_dir = img_dir
        self.transform = transform
        
        # Load attributes CSV
        # The CSV typically has image names in first column and attributes as other columns
        self.attr_df = pd.read_csv(attr_file)
        
        # Handle different CSV formats - some have image names as index, some as first column
        if 'image_id' not in self.attr_df.columns and self.attr_df.columns[0] != 'image_id':
            # If first column contains image names but isn't named 'image_id'
            self.attr_df.rename(columns={self.attr_df.columns[0]: 'image_id'}, inplace=True)
        
        # Get list of image files that exist
        available_images = set(os.listdir(img_dir))
        
        # Filter dataframe to only include images that exist
        self.attr_df = self.attr_df[self.attr_df['image_id'].isin(available_images)]
        
        # Select specific attributes if provided
        if selected_attrs:
            cols_to_keep = ['image_id'] + selected_attrs
            self.attr_df = self.attr_df[cols_to_keep]
        
        # Convert attribute values to 0/1 (they might be -1/1 in original)
        attr_cols = [col for col in self.attr_df.columns if col != 'image_id']
        for col in attr_cols:
            self.attr_df[col] = (self.attr_df[col] + 1) / 2  # Convert -1,1 to 0,1
        
        self.attr_names = attr_cols
        
        print(f"Loaded {len(self.attr_df)} images with {len(self.attr_names)} attributes")
        print(f"Attributes: {self.attr_names}")
    
    def __len__(self):
        return len(self.attr_df)
    
    def __getitem__(self, idx):
        # Get image path and attributes
        row = self.attr_df.iloc[idx]
        img_name = row['image_id']
        img_path = os.path.join(self.img_dir, img_name)
        
        # Load image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (178, 218), color='black')
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Get attributes (excluding image_id)
        attributes = torch.tensor(row[self.attr_names].astype(int).values, dtype=torch.float32)
        
        return image, attributes, img_name

# Step 2: Data Loading Setup
def setup_celeba_data(img_dir, attr_file, batch_size=32, img_size=64, selected_attrs=None):
    """
    Set up data loaders for CelebA dataset
    
    Args:
        img_dir: Path to images directory
        attr_file: Path to attributes CSV file
        batch_size: Batch size for training
        img_size: Size to resize images to
        selected_attrs: List of specific attributes to use
    """
    
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])
    
    # Create dataset
    dataset = CelebADataset(img_dir, attr_file, transform, selected_attrs)
    
    # Create data loader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    return dataset, dataloader

# Step 3: Example usage and data exploration
def explore_celeba_data(dataset, dataloader):
    """
    Explore the loaded dataset
    """
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of attributes: {len(dataset.attr_names)}")
    print(f"Attribute names: {dataset.attr_names}")
    
    # Get a sample batch
    sample_batch = next(iter(dataloader))
    images, attributes, img_names = sample_batch
    
    print(f"Batch image shape: {images.shape}")
    print(f"Batch attributes shape: {attributes.shape}")
    print(f"Sample image names: {img_names[:5]}")
    
    # Show attribute statistics
    print("\nAttribute statistics (first batch):")
    for i, attr_name in enumerate(dataset.attr_names):
        mean_val = attributes[:, i].mean().item()
        print(f"{attr_name}: {mean_val:.2f}")
    
    return sample_batch

# Usage example:

# Set paths according to your Kaggle dataset structure
IMG_DIR = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
ATTR_FILE = "/kaggle/input/celeba-dataset/list_attr_celeba.csv"

# You can select specific attributes for disentanglement experiments
SELECTED_ATTRS = [
    'Male', 'Young', 'Eyeglasses', 'Bald', 'Mustache', 
    'Smiling', 'Attractive', 'Blond_Hair', 'Heavy_Makeup'
]

# Or use all attributes by setting SELECTED_ATTRS = None

# Create dataset and dataloader
dataset, dataloader = setup_celeba_data(
    IMG_DIR, ATTR_FILE, 
    batch_size=32, 
    img_size=64, 
    selected_attrs=SELECTED_ATTRS
)

# Explore the data
sample_batch = explore_celeba_data(dataset, dataloader)


Loaded 202599 images with 9 attributes
Attributes: ['Male', 'Young', 'Eyeglasses', 'Bald', 'Mustache', 'Smiling', 'Attractive', 'Blond_Hair', 'Heavy_Makeup']
Dataset size: 202599
Number of attributes: 9
Attribute names: ['Male', 'Young', 'Eyeglasses', 'Bald', 'Mustache', 'Smiling', 'Attractive', 'Blond_Hair', 'Heavy_Makeup']
Batch image shape: torch.Size([32, 3, 64, 64])
Batch attributes shape: torch.Size([32, 9])
Sample image names: ('085276.jpg', '075760.jpg', '186410.jpg', '189671.jpg', '169024.jpg')

Attribute statistics (first batch):
Male: 0.53
Young: 0.72
Eyeglasses: 0.09
Bald: 0.03
Mustache: 0.06
Smiling: 0.44
Attractive: 0.44
Blond_Hair: 0.12
Heavy_Makeup: 0.25


# VAE architecture with CNNs

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Step 2: VAE Architecture with CNNs

class Encoder(nn.Module):
    def __init__(self, latent_dim=64, img_channels=3):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        
        # CNN layers for feature extraction
        self.conv_layers = nn.Sequential(
            # Input: 3 x 64 x 64
            nn.Conv2d(img_channels, 32, kernel_size=4, stride=2, padding=1),  # 32 x 32 x 32
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 64 x 16 x 16
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128 x 8 x 8
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 256 x 4 x 4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Calculate flattened size
        self.flattened_size = 256 * 4 * 4
        
        # Fully connected layers for latent space
        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)
        
    def forward(self, x):
        # CNN feature extraction
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # Flatten
        
        # Get latent parameters
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim=64, img_channels=3):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.img_channels = img_channels
        
        # Project latent to feature map
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        
        # Transpose CNN layers for image generation
        self.deconv_layers = nn.Sequential(
            # 256 x 4 x 4
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 128 x 8 x 8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 64 x 16 x 16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 32 x 32 x 32
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),  # 3 x 64 x 64
            nn.Tanh()  # Output in [-1, 1] range
        )
        
    def forward(self, z):
        # Project and reshape
        x = self.fc(z)
        x = x.view(x.size(0), 256, 4, 4)
        
        # Generate image
        x = self.deconv_layers(x)
        
        return x

class VAE(nn.Module):
    def __init__(self, latent_dim=64, img_channels=3):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        self.encoder = Encoder(latent_dim, img_channels)
        self.decoder = Decoder(latent_dim, img_channels)
        
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick for VAE
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu
    
    def forward(self, x):
        # Encode
        mu, logvar = self.encoder(x)
        
        # Reparameterize
        z = self.reparameterize(mu, logvar)
        
        # Decode
        recon_x = self.decoder(z)
        
        return recon_x, mu, logvar, z
    
    def generate(self, num_samples, device):
        """
        Generate new samples from the latent space
        """
        self.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim).to(device)
            generated = self.decoder(z)
        return generated
    
    def encode(self, x):
        """
        Encode input to latent space
        """
        self.eval()
        with torch.no_grad():
            mu, logvar = self.encoder(x)
            z = self.reparameterize(mu, logvar)
        return z, mu, logvar

# Test the architecture
def test_vae_architecture():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Create VAE
    vae = VAE(latent_dim=64, img_channels=3).to(device)
    
    # Test with dummy input
    batch_size = 8
    dummy_input = torch.randn(batch_size, 3, 64, 64).to(device)
    
    print("Testing VAE architecture...")
    print(f"Input shape: {dummy_input.shape}")
    
    # Forward pass
    recon_x, mu, logvar, z = vae(dummy_input)
    
    print(f"Reconstructed shape: {recon_x.shape}")
    print(f"Latent z shape: {z.shape}")
    print(f"Mu shape: {mu.shape}")
    print(f"Logvar shape: {logvar.shape}")
    
    # Test generation
    generated = vae.generate(num_samples=4, device=device)
    print(f"Generated samples shape: {generated.shape}")
    
    # Count parameters
    total_params = sum(p.numel() for p in vae.parameters())
    trainable_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    return vae

# Usage:
vae_model = test_vae_architecture()

Testing VAE architecture...
Input shape: torch.Size([8, 3, 64, 64])
Reconstructed shape: torch.Size([8, 3, 64, 64])
Latent z shape: torch.Size([8, 64])
Mu shape: torch.Size([8, 64])
Logvar shape: torch.Size([8, 64])
Generated samples shape: torch.Size([4, 3, 64, 64])
Total parameters: 2,172,099
Trainable parameters: 2,172,099


# Disentangled VAE Loss Function

In [8]:
import torch
import torch.nn.functional as F
import numpy as np

# Step 3: Loss Functions for Disentangled VAE

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    Standard VAE loss with β-VAE modification
    
    Args:
        recon_x: Reconstructed images
        x: Original images
        mu: Mean of latent distribution
        logvar: Log variance of latent distribution
        beta: Weight for KL divergence (β-VAE)
    """
    # Reconstruction loss (MSE or BCE)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Total loss
    total_loss = recon_loss + beta * kl_loss
    
    return total_loss, recon_loss, kl_loss

def factor_vae_loss(recon_x, x, mu, logvar, z, discriminator=None, gamma=10.0, beta=1.0):
    """
    Factor-VAE loss for disentanglement
    
    Args:
        recon_x: Reconstructed images
        x: Original images
        mu: Mean of latent distribution
        logvar: Log variance of latent distribution
        z: Sampled latent codes
        discriminator: Discriminator network (if available)
        gamma: Weight for total correlation penalty
        beta: Weight for KL divergence
    """
    # Standard VAE loss
    vae_loss_val, recon_loss, kl_loss = vae_loss(recon_x, x, mu, logvar, beta)
    
    # Total Correlation penalty (if discriminator is available)
    tc_loss = 0
    if discriminator is not None:
        # Permute latent codes to break dependencies
        z_perm = permute_latent_codes(z)
        
        # Discriminator scores
        d_z = discriminator(z)
        d_z_perm = discriminator(z_perm)
        
        # Total correlation loss
        tc_loss = torch.mean(d_z) - torch.mean(d_z_perm)
    
    total_loss = vae_loss_val + gamma * tc_loss
    
    return total_loss, recon_loss, kl_loss, tc_loss

def beta_tcvae_loss(recon_x, x, mu, logvar, z, alpha=1.0, beta=6.0, gamma=1.0):
    batch_size, latent_dim = z.shape
    
    # Reconstruction loss
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / batch_size
    
    # log q(z|x) = sum over latent dims
    log_qz_cond = log_density_gaussian(z, mu, logvar).sum(1)  # [batch]
    log_pz = log_density_standard_gaussian(z).sum(1)          # [batch]
    
    # Pairwise comparisons for marginal q(z)
    z_expand = z.unsqueeze(1)          # [B, 1, L]
    mu_expand = mu.unsqueeze(0)        # [1, B, L]
    logvar_expand = logvar.unsqueeze(0)# [1, B, L]

    # log q(z_j | x_i) for all pairs
    log_qz_matrix = log_density_gaussian(z_expand, mu_expand, logvar_expand)  # [B, B, L]
    log_qz = torch.logsumexp(log_qz_matrix.sum(2), dim=1) - np.log(batch_size)  # [B]

    # MI, TC, DW (see Chen et al., β-TCVAE)
    mi_loss = (log_qz_cond - log_qz).mean()
    tc_loss = (log_qz - log_qz_matrix.sum(2).mean(0)).mean()
    dw_kl_loss = (log_qz_matrix.sum(2).mean(0) - log_pz.mean()).mean()

    total_loss = recon_loss + alpha * mi_loss + beta * tc_loss + gamma * dw_kl_loss
    return total_loss, recon_loss, mi_loss, tc_loss, dw_kl_loss
    

def log_density_gaussian(z, mu, logvar):
    """
    Log density of Gaussian N(mu, sigma^2) for each dimension.
    Returns shape [batch, latent_dim].
    """
    norm_const = -0.5 * np.log(2 * np.pi)
    log_density = norm_const - 0.5 * logvar - 0.5 * ((z - mu) ** 2) / torch.exp(logvar)
    return log_density  # [batch, latent_dim]

def log_density_standard_gaussian(z):
    """
    Log density of standard Gaussian N(0, I).
    Returns shape [batch, latent_dim].
    """
    norm_const = -0.5 * np.log(2 * np.pi)
    log_density = norm_const - 0.5 * (z ** 2)
    return log_density  # [batch, latent_dim]

def permute_latent_codes(z):
    """
    Randomly permute latent codes across batch dimension
    """
    z_perm = z.clone()
    for i in range(z.size(1)):  # For each latent dimension
        perm_idx = torch.randperm(z.size(0))
        z_perm[:, i] = z_perm[perm_idx, i]
    return z_perm

def supervised_disentanglement_loss(z, attributes, attribute_weights=None):
    """
    Supervised loss to encourage specific latent dimensions to correspond to attributes
    
    Args:
        z: Latent codes [batch_size, latent_dim]
        attributes: Ground truth attributes [batch_size, num_attributes]
        attribute_weights: Weights for each attribute
    """
    batch_size, latent_dim = z.shape
    batch_size, num_attributes = attributes.shape
    
    if attribute_weights is None:
        attribute_weights = torch.ones(num_attributes)
    
    # Simple approach: use first few latent dimensions for attributes
    num_supervised = min(latent_dim, num_attributes)
    
    # Supervised loss: encourage z[:, :num_supervised] to predict attributes
    z_attr = z[:, :num_supervised]
    attr_target = attributes[:, :num_supervised]
    
    # Use sigmoid to get probabilities and BCE loss
    attr_pred = torch.sigmoid(z_attr)
    supervised_loss = F.binary_cross_entropy(attr_pred, attr_target, weight=attribute_weights[:num_supervised])
    
    return supervised_loss

# Complete loss function combining different approaches
class DisentangledVAELoss:
    def __init__(self, loss_type='beta_vae', beta=4.0, gamma=10.0, alpha=1.0, 
                 supervised_weight=0.1):
        self.loss_type = loss_type
        self.beta = beta
        self.gamma = gamma
        self.alpha = alpha
        self.supervised_weight = supervised_weight
        
    def __call__(self, recon_x, x, mu, logvar, z, attributes=None):
        """
        Compute disentangled VAE loss
        """
        losses = {}
        
        if self.loss_type == 'beta_vae':
            total_loss, recon_loss, kl_loss = vae_loss(recon_x, x, mu, logvar, self.beta)
            losses.update({
                'total': total_loss,
                'reconstruction': recon_loss,
                'kl': kl_loss
            })
            
        elif self.loss_type == 'beta_tcvae':
            total_loss, recon_loss, mi_loss, tc_loss, dw_kl_loss = beta_tcvae_loss(
                recon_x, x, mu, logvar, z, self.alpha, self.beta, self.gamma
            )
            losses.update({
                'total': total_loss,
                'reconstruction': recon_loss,
                'mi': mi_loss,
                'tc': tc_loss,
                'dw_kl': dw_kl_loss
            })
        
        # Add supervised disentanglement if attributes provided
        if attributes is not None and self.supervised_weight > 0:
            supervised_loss = supervised_disentanglement_loss(z, attributes)
            losses['supervised'] = supervised_loss
            losses['total'] = losses['total'] + self.supervised_weight * supervised_loss
        
        return losses

# Usage example:

# Create loss function
# loss_fn = DisentangledVAELoss(
#     loss_type='beta_tcvae',
#     beta=6.0,
#     gamma=1.0,
#     alpha=1.0,
#     supervised_weight=0.1
# )

# # During training:
# recon_x, mu, logvar, z = vae_model(images)
# losses = loss_fn(recon_x, images, mu, logvar, z, attributes)

# total_loss = losses['total']
# total_loss.backward()


# Complete Training Loop

In [None]:
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from datetime import datetime

# Step 4: Complete Training Loop

class VAETrainer:
    def __init__(self, model, train_loader, val_loader=None, device='cuda', 
                 loss_type='beta_tcvae', lr=1e-4, save_dir='./vae_checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.save_dir = save_dir
        
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        
        # Optimizer
        self.optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
        
        # Loss function
        self.loss_fn = DisentangledVAELoss(
            loss_type=loss_type,
            beta=6.0 if loss_type == 'beta_tcvae' else 4.0,
            gamma=1.0,
            alpha=1.0,
            supervised_weight=0.1
        )
        
        # Tensorboard writer
        self.writer = SummaryWriter(f'runs/vae_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
        
        # Training history
        self.train_losses = []
        self.val_losses = []
        
    def train_epoch(self, epoch):
        self.model.train()
        total_losses = {}
        num_batches = len(self.train_loader)
        
        with tqdm(self.train_loader, desc=f'Epoch {epoch}') as pbar:
            for batch_idx, (images, attributes, _) in enumerate(pbar):
                images = images.to(self.device)
                attributes = attributes.to(self.device)
                
                # Forward pass
                recon_x, mu, logvar, z = self.model(images)
                
                # Compute loss
                losses = self.loss_fn(recon_x, images, mu, logvar, z, attributes)
                
                # Backward pass
                self.optimizer.zero_grad()
                losses['total'].backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                
                self.optimizer.step()
                
                # Update running losses
                for key, value in losses.items():
                    if key not in total_losses:
                        total_losses[key] = 0
                    total_losses[key] += value.item()
                
                # Update progress bar
                pbar.set_postfix({
                    'Total': f"{losses['total'].item():.4f}",
                    'Recon': f"{losses['reconstruction'].item():.4f}",
                    'KL': f"{losses.get('kl', losses.get('tc', 0)):.4f}"
                })
                
                # Log to tensorboard
                global_step = epoch * num_batches + batch_idx
                if batch_idx % 100 == 0:
                    for key, value in losses.items():
                        self.writer.add_scalar(f'Train/{key}', value.item(), global_step)
        
        # Average losses
        avg_losses = {key: value / num_batches for key, value in total_losses.items()}
        self.train_losses.append(avg_losses)
        
        return avg_losses
    
    def validate(self, epoch):
        if self.val_loader is None:
            return None
            
        self.model.eval()
        total_losses = {}
        num_batches = len(self.val_loader)
        
        with torch.no_grad():
            for images, attributes, _ in self.val_loader:
                images = images.to(self.device)
                attributes = attributes.to(self.device)
                
                # Forward pass
                recon_x, mu, logvar, z = self.model(images)
                
                # Compute loss
                losses = self.loss_fn(recon_x, images, mu, logvar, z, attributes)
                
                # Update running losses
                for key, value in losses.items():
                    if key not in total_losses:
                        total_losses[key] = 0
                    total_losses[key] += value.item()
        
        # Average losses
        avg_losses = {key: value / num_batches for key, value in total_losses.items()}
        self.val_losses.append(avg_losses)
        
        # Log to tensorboard
        for key, value in avg_losses.items():
            self.writer.add_scalar(f'Val/{key}', value, epoch)
        
        return avg_losses
    
    def save_samples(self, epoch, num_samples=8):
        """Save reconstructed and generated samples"""
        self.model.eval()
        
        with torch.no_grad():
            # Get a batch for reconstruction
            images, _, _ = next(iter(self.train_loader))
            images = images[:num_samples].to(self.device)
            
            # Reconstruct
            recon_images, _, _, _ = self.model(images)
            
            # Generate new samples
            generated_images = self.model.generate(num_samples, self.device)
            
            # Denormalize images for visualization
            def denormalize(x):
                return (x + 1) / 2  # From [-1, 1] to [0, 1]
            
            # Create comparison plot
            fig, axes = plt.subplots(3, num_samples, figsize=(num_samples * 2, 6))
            
            for i in range(num_samples):
                # Original
                orig_img = denormalize(images[i]).cpu().permute(1, 2, 0).numpy()
                axes[0, i].imshow(orig_img)
                axes[0, i].set_title('Original')
                axes[0, i].axis('off')
                
                # Reconstructed
                recon_img = denormalize(recon_images[i]).cpu().permute(1, 2, 0).numpy()
                axes[1, i].imshow(recon_img)
                axes[1, i].set_title('Reconstructed')
                axes[1, i].axis('off')
                
                # Generated
                gen_img = denormalize(generated_images[i]).cpu().permute(1, 2, 0).numpy()
                axes[2, i].imshow(gen_img)
                axes[2, i].set_title('Generated')
                axes[2, i].axis('off')
            
            plt.tight_layout()
            plt.savefig(f'{self.save_dir}/samples_epoch_{epoch}.png')
            plt.close()
    
    def save_checkpoint(self, epoch, best_loss=None):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_loss': best_loss
        }
        
        torch.save(checkpoint, f'{self.save_dir}/checkpoint_epoch_{epoch}.pth')
        torch.save(checkpoint, f'{self.save_dir}/latest_checkpoint.pth')
        
        if best_loss is not None:
            torch.save(checkpoint, f'{self.save_dir}/best_checkpoint.pth')
    
    def load_checkpoint(self, checkpoint_path):
        """Load model checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        
        return checkpoint['epoch'], checkpoint.get('best_loss')
    
    def train(self, num_epochs, save_every=5, sample_every=5):
        """Main training loop"""
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        best_val_loss = float('inf')
        
        for epoch in range(1, num_epochs + 1):
            # Train
            train_losses = self.train_epoch(epoch)
            
            # Validate
            val_losses = self.validate(epoch)
            
            # Print losses
            print(f"\nEpoch {epoch}/{num_epochs}")
            print(f"Train - Total: {train_losses['total']:.4f}, "
                  f"Recon: {train_losses['reconstruction']:.4f}")
            
            if val_losses:
                print(f"Val   - Total: {val_losses['total']:.4f}, "
                      f"Recon: {val_losses['reconstruction']:.4f}")
                
                # Check for best model
                if val_losses['total'] < best_val_loss:
                    best_val_loss = val_losses['total']
                    self.save_checkpoint(epoch, best_val_loss)
                    print("New best model saved!")
            
            # Save samples
            if epoch % sample_every == 0:
                self.save_samples(epoch)
            
            # Save checkpoint
            if epoch % save_every == 0:
                self.save_checkpoint(epoch)
        
        print(f"Training completed! Best validation loss: {best_val_loss:.4f}")
        self.writer.close()

def train_celeba_vae():
    """
    Complete training pipeline for CelebA VAE
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data paths (adjust for your Kaggle setup)
    IMG_DIR = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
    ATTR_FILE = "/kaggle/input/celeba-dataset/list_attr_celeba.csv"
    
    # Selected attributes for disentanglement
    SELECTED_ATTRS = [
        'Male', 'Young', 'Eyeglasses', 'Bald', 'Mustache', 
        'Smiling', 'Attractive', 'Blond_Hair', 'Heavy_Makeup'
    ]
    
    # Hyperparameters
    BATCH_SIZE = 32
    IMG_SIZE = 64
    LATENT_DIM = 64
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 1
    
    # Create dataset and dataloader
    print("Loading dataset...")
    dataset, train_loader = setup_celeba_data(
        IMG_DIR, ATTR_FILE,
        batch_size=BATCH_SIZE,
        img_size=IMG_SIZE,
        selected_attrs=SELECTED_ATTRS
    )
    
    # Create validation split (optional)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    
    # Create model
    print("Creating VAE model...")
    vae = VAE(latent_dim=LATENT_DIM, img_channels=3).to(device)
    
    # Create trainer
    trainer = VAETrainer(
        model=vae,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        loss_type='beta_tcvae',  # or 'beta_vae'
        lr=LEARNING_RATE,
        save_dir='/kaggle/working/vae_checkpoints'
    )
    
    # Start training
    trainer.train(
        num_epochs=NUM_EPOCHS,
        save_every=10,
        sample_every=5
    )
    
    return trainer, vae

# Usage in Kaggle:

# Run this to train the complete VAE
trainer, trained_vae = train_celeba_vae()


Using device: cpu
Loading dataset...
Loaded 202599 images with 9 attributes
Attributes: ['Male', 'Young', 'Eyeglasses', 'Bald', 'Mustache', 'Smiling', 'Attractive', 'Blond_Hair', 'Heavy_Makeup']
Creating VAE model...
Starting training for 1 epochs...
Device: cpu
Model parameters: 2,172,099


Epoch 1: 100%|██████████| 5699/5699 [20:36<00:00,  4.61it/s, Total=1762.1620, Recon=1437.5869, KL=59.3584]


# Attribute Manipulation and Evaluation

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import seaborn as sns
from scipy.stats import spearmanr

# Step 5: Attribute Manipulation and Evaluation

class AttributeManipulator:
    def __init__(self, vae_model, device='cuda'):
        self.vae = vae_model
        self.device = device
        self.vae.eval()
        
        # Store attribute classifiers
        self.attribute_classifiers = {}
    
    def encode_dataset(self, dataloader, max_samples=1000):
        """
        Encode a dataset to get latent representations and attributes
        """
        latent_codes = []
        attributes_list = []
        
        with torch.no_grad():
            sample_count = 0
            for images, attributes, _ in dataloader:
                if sample_count >= max_samples:
                    break
                    
                images = images.to(self.device)
                z, mu, _ = self.vae.encode(images)
                
                latent_codes.append(z.cpu().numpy())
                attributes_list.append(attributes.cpu().numpy())
                
                sample_count += len(images)
        
        latent_codes = np.vstack(latent_codes)
        attributes_array = np.vstack(attributes_list)
        
        return latent_codes[:max_samples], attributes_array[:max_samples]
    
    def train_attribute_classifiers(self, latent_codes, attributes, attribute_names):
        """
        Train linear classifiers to predict attributes from latent codes
        """
        print("Training attribute classifiers...")
        
        for i, attr_name in enumerate(attribute_names):
            # Train logistic regression classifier
            clf = LogisticRegression(random_state=42, max_iter=1000)
            clf.fit(latent_codes, attributes[:, i])
            
            # Evaluate accuracy
            predictions = clf.predict(latent_codes)
            accuracy = accuracy_score(attributes[:, i], predictions)
            
            self.attribute_classifiers[attr_name] = {
                'classifier': clf,
                'accuracy': accuracy
            }
            
            print(f"{attr_name}: Accuracy = {accuracy:.3f}")
    
    def find_attribute_directions(self, latent_codes, attributes, attribute_names):
        """
        Find directions in latent space corresponding to attributes
        """
        directions = {}
        
        for i, attr_name in enumerate(attribute_names):
            # Split samples by attribute value
            pos_samples = latent_codes[attributes[:, i] == 1]
            neg_samples = latent_codes[attributes[:, i] == 0]
            
            if len(pos_samples) > 0 and len(neg_samples) > 0:
                # Compute direction as difference in means
                direction = np.mean(pos_samples, axis=0) - np.mean(neg_samples, axis=0)
                direction = direction / np.linalg.norm(direction)  # Normalize
                
                directions[attr_name] = direction
                
                print(f"{attr_name}: Found direction vector")
            else:
                print(f"{attr_name}: Insufficient samples")
        
        return directions
    
    def manipulate_attributes(self, input_image, directions, attribute_names, 
                            manipulation_strengths=None):
        """
        Manipulate specific attributes in an input image
        
        Args:
            input_image: Input image tensor [1, C, H, W]
            directions: Dictionary of attribute direction vectors
            attribute_names: List of attributes to manipulate
            manipulation_strengths: Strength of manipulation for each attribute
        """
        if manipulation_strengths is None:
            manipulation_strengths = [3.0] * len(attribute_names)
        
        # Encode input image
        with torch.no_grad():
            input_image = input_image.to(self.device)
            z_original, _, _ = self.vae.encode(input_image)
            z_original = z_original.cpu().numpy()
        
        results = {'original': input_image}
        
        # Manipulate each attribute
        for attr_name, strength in zip(attribute_names, manipulation_strengths):
            if attr_name in directions:
                direction = directions[attr_name]
                
                # Add direction to latent code
                z_modified = z_original + strength * direction.reshape(1, -1)
                
                # Decode modified latent code
                with torch.no_grad():
                    z_tensor = torch.FloatTensor(z_modified).to(self.device)
                    modified_image = self.vae.decoder(z_tensor)
                
                results[f'{attr_name}_+{strength}'] = modified_image
                
                # Also try negative direction
                z_modified_neg = z_original - strength * direction.reshape(1, -1)
                
                with torch.no_grad():
                    z_tensor_neg = torch.FloatTensor(z_modified_neg).to(self.device)
                    modified_image_neg = self.vae.decoder(z_tensor_neg)
                
                results[f'{attr_name}_-{strength}'] = modified_image_neg
        
        return results
    
    def interpolate_attributes(self, image1, image2, num_steps=10):
        """
        Interpolate between two images in latent space
        """
        with torch.no_grad():
            image1 = image1.to(self.device)
            image2 = image2.to(self.device)
            
            z1, _, _ = self.vae.encode(image1)
            z2, _, _ = self.vae.encode(image2)
            
            interpolated_images = []
            
            for step in range(num_steps):
                alpha = step / (num_steps - 1)
                z_interp = (1 - alpha) * z1 + alpha * z2
                
                img_interp = self.vae.decoder(z_interp)
                interpolated_images.append(img_interp)
        
        return interpolated_images
    
    def visualize_manipulations(self, manipulation_results, save_path=None):
        """
        Visualize attribute manipulation results
        """
        num_results = len(manipulation_results)
        fig, axes = plt.subplots(1, num_results, figsize=(num_results * 3, 3))
        
        if num_results == 1:
            axes = [axes]
        
        def denormalize(x):
            return torch.clamp((x + 1) / 2, 0, 1)
        
        for i, (title, image) in enumerate(manipulation_results.items()):
            if isinstance(image, torch.Tensor):
                img_np = denormalize(image[0]).cpu().permute(1, 2, 0).numpy()
                axes[i].imshow(img_np)
            axes[i].set_title(title)
            axes[i].axis('off')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        
        plt.show()

def evaluate_disentanglement(vae_model, dataloader, attribute_names, device='cuda'):
    """
    Comprehensive evaluation of disentanglement quality
    """
    manipulator = AttributeManipulator(vae_model, device)
    
    # Encode dataset
    print("Encoding dataset...")
    latent_codes, attributes = manipulator.encode_dataset(dataloader, max_samples=2000)
    
    # Train attribute classifiers
    manipulator.train_attribute_classifiers(latent_codes, attributes, attribute_names)
    
    # Find attribute directions
    directions = manipulator.find_attribute_directions(latent_codes, attributes, attribute_names)
    
    # Compute disentanglement metrics
    metrics = compute_disentanglement_metrics(latent_codes, attributes, attribute_names)
    
    return manipulator, directions, metrics

def compute_disentanglement_metrics(latent_codes, attributes, attribute_names):
    """
    Compute disentanglement metrics (SAP, MIG, etc.)
    """
    metrics = {}
    
    # Compute mutual information between latent dimensions and attributes
    mutual_info_matrix = np.zeros((latent_codes.shape[1], len(attribute_names)))
    
    for i in range(latent_codes.shape[1]):  # For each latent dimension
        for j, attr_name in enumerate(attribute_names):  # For each attribute
            # Discretize latent dimension
            latent_discrete = np.digitize(latent_codes[:, i], 
                                        bins=np.percentile(latent_codes[:, i], 
                                                         [20, 40, 60, 80]))
            
            # Compute mutual information (simplified)
            mi = compute_mutual_information(latent_discrete, attributes[:, j].astype(int))
            mutual_info_matrix[i, j] = mi
    
    # SAP Score (Separated Attribute Predictability)
    sap_scores = []
    for j in range(len(attribute_names)):
        mi_j = mutual_info_matrix[:, j]
        if len(mi_j) > 1:
            sorted_mi = np.sort(mi_j)[::-1]  # Sort descending
            if sorted_mi[1] > 0:
                sap = (sorted_mi[0] - sorted_mi[1]) / sorted_mi[0]
            else:
                sap = sorted_mi[0] if sorted_mi[0] > 0 else 0
            sap_scores.append(sap)
    
    metrics['SAP'] = np.mean(sap_scores) if sap_scores else 0
    metrics['MI_Matrix'] = mutual_info_matrix
    
    # Modularity score
    modularity_scores = []
    for i in range(latent_codes.shape[1]):
        mi_i = mutual_info_matrix[i, :]
        if np.sum(mi_i) > 0:
            modularity = np.max(mi_i) / np.sum(mi_i)
            modularity_scores.append(modularity)
    
    metrics['Modularity'] = np.mean(modularity_scores) if modularity_scores else 0
    
    return metrics

def compute_mutual_information(x, y):
    """
    Compute mutual information between discrete variables x and y
    """
    # Create contingency table
    unique_x = np.unique(x)
    unique_y = np.unique(y)
    
    if len(unique_x) <= 1 or len(unique_y) <= 1:
        return 0.0
    
    contingency = np.zeros((len(unique_x), len(unique_y)))
    
    for i, val_x in enumerate(unique_x):
        for j, val_y in enumerate(unique_y):
            contingency[i, j] = np.sum((x == val_x) & (y == val_y))
    
    # Normalize to get probabilities
    p_xy = contingency / np.sum(contingency)
    p_x = np.sum(p_xy, axis=1)
    p_y = np.sum(p_xy, axis=0)
    
    # Compute mutual information
    mi = 0.0
    for i in range(len(unique_x)):
        for j in range(len(unique_y)):
            if p_xy[i, j] > 0 and p_x[i] > 0 and p_y[j] > 0:
                mi += p_xy[i, j] * np.log(p_xy[i, j] / (p_x[i] * p_y[j]))
    
    return mi

# Example usage:

# After training your VAE
manipulator, directions, metrics = evaluate_disentanglement(
    trained_vae, val_loader, SELECTED_ATTRS, device
)

print("Disentanglement Metrics:")
print(f"SAP Score: {metrics['SAP']:.3f}")
print(f"Modularity: {metrics['Modularity']:.3f}")

# Test attribute manipulation
sample_batch = next(iter(val_loader))
test_image = sample_batch[0][:1]  # First image

manipulation_results = manipulator.manipulate_attributes(
    test_image, 
    directions, 
    ['Male', 'Smiling', 'Eyeglasses'], 
    manipulation_strengths=[3.0, 2.5, 3.5]
)

manipulator.visualize_manipulations(manipulation_results)