# PixelCNN architecture did not match the VQ-VAE

# In [None]:
"""
VQ-VAE Emoji Generation - Phase 2: Prior Training, Generation & Analysis
"""

import os
import json
import csv
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy import linalg
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Phase 2 Configuration (Extends Phase 1)
# ============================================================================

# Phase-2 hyper-parameters; merged with the Phase-1 settings and read through
# ``Config`` (which coerces string values for the numeric keys below).
EXPERIMENT_CONFIGS_PHASE2 = dict(
    # --- Prior (PixelCNN) ---
    num_epochs_prior=200,
    learning_rate_prior=1e-4,
    pixelcnn_layers=15,
    pixelcnn_hidden=128,
    grad_clip=1.0,
    # --- Generation ---
    num_samples=64,
    num_interpolation_steps=10,
    temperature=1.0,
    # --- Experiment metadata for Phase 2 ---
    phase2_experiment_name="prior_baseline",
    phase2_notes="PixelCNN autoregressive prior training",
)

# ============================================================================
# Configuration Management
# ============================================================================

class Config:
    """Attribute-style wrapper around a configuration dictionary.

    Every entry of ``config_dict`` becomes an attribute. String values whose
    key appears in ``type_mappings`` are converted with the listed type
    (``int``/``float``); if that conversion fails the string is kept. Entries
    with a ``None`` or empty-string key are ignored. ``device`` is always set
    last from CUDA availability, overriding any ``device`` entry in the dict.
    """
    def __init__(self, config_dict):
        # Known numeric keys -> the type their string values are coerced to.
        self.type_mappings = dict(
            image_size=int,
            batch_size=int,
            num_workers=int,
            num_hiddens=int,
            num_residual_hiddens=int,
            num_residual_layers=int,
            embedding_dim=int,
            num_embeddings=int,
            commitment_cost=float,
            decay=float,
            num_epochs_vqvae=int,
            learning_rate_vqvae=float,
            min_codebook_usage=float,
            check_usage_every=int,
            num_epochs_prior=int,
            learning_rate_prior=float,
            pixelcnn_layers=int,
            pixelcnn_hidden=int,
            grad_clip=float,
            num_samples=int,
            num_interpolation_steps=int,
            temperature=float,
        )

        for name, raw in config_dict.items():
            # Rows read via csv.DictReader can carry None / '' keys; drop them.
            if name in (None, ''):
                continue

            if name in self.type_mappings and isinstance(raw, str):
                caster = self.type_mappings[name]
                try:
                    raw = caster(raw)
                except (ValueError, TypeError):
                    pass  # unconvertible -> keep the original string
            setattr(self, name, raw)

        # System-dependent, so never taken from the stored config.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def to_dict(self):
        """Return public attributes (minus ``type_mappings``) for CSV export."""
        exported = {}
        for attr_name, attr_value in vars(self).items():
            if attr_name.startswith('_') or attr_name == 'type_mappings':
                continue
            exported[attr_name] = attr_value
        return exported

# ============================================================================
# Dataset (Same as Phase 1)
# ============================================================================

class EmojiDataset(Dataset):
    """Images from a flat directory, returned as transformed RGB tensors.

    Args:
        data_dir: directory scanned (non-recursively) for image files. The
            ``.png``/``.jpg``/``.jpeg`` match is case-insensitive, only regular
            files are kept, and the list is sorted so that index -> file is
            stable across runs and filesystems (``os.listdir`` order is not).
        image_size: side length used by the default resize transform.
        transform: optional callable applied to each PIL image. When ``None``,
            images are resized to ``(image_size, image_size)``, converted to a
            tensor and normalised with mean/std 0.5 per channel (range [-1, 1]).
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, image_size=64, transform=None):
        self.data_dir = data_dir
        self.image_size = image_size
        self.image_files = sorted(
            name for name in os.listdir(data_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(data_dir, name))
        )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx``, force 3-channel RGB and apply ``self.transform``."""
        img_path = os.path.join(self.data_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        return self.transform(image)

# ============================================================================
# VQ-VAE Components (Must match Phase 1 exactly)
# ============================================================================

class VectorQuantizerEMA(nn.Module):
    """Vector quantiser whose codebook is maintained by exponential moving averages.

    Each input vector is replaced by its nearest codebook row (squared
    Euclidean distance). The codebook ``embed`` is a buffer, not a trainable
    parameter: in training mode it is recomputed on every forward pass from the
    EMA statistics ``cluster_size`` and ``embed_avg``. For that reason the
    returned loss contains only the commitment term.

    Must stay identical to the Phase-1 definition so saved checkpoints load.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.decay = decay  # EMA decay for cluster_size / embed_avg
        self.epsilon = epsilon  # additive smoothing for per-code counts

        # Codebook and EMA statistics are buffers: saved in state_dict, not optimised.
        embed = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, inputs):
        """Quantise ``inputs`` to their nearest codebook vectors.

        Args:
            inputs: tensor of shape (B, embedding_dim, H, W). Dim 1 is moved
                last and flattened, so it must equal ``embedding_dim``.

        Returns:
            quantized: (B, embedding_dim, H, W). Its values are the codebook
                vectors, but gradients pass straight through to ``inputs``.
            loss: ``commitment_cost * MSE(codebook vectors, inputs)``. The
                codebook side is detached, so only the encoder is pulled.
            perplexity: exp of the entropy of this batch's average code usage.
            encoding_indices: int64 tensor (B*H*W, 1) of chosen code indices.

        Side effects:
            In training mode, updates ``cluster_size``, ``embed_avg`` and
            ``embed`` in place.
        """
        input_shape = inputs.shape
        # (B, D, H, W) -> (B*H*W, D)
        flat_input = inputs.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Squared L2 distance ||x||^2 + ||e||^2 - 2 x.e for every (input, code) pair.
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self.embed**2, dim=1)
                    - 2 * torch.matmul(flat_input, self.embed.t()))

        # Nearest code per vector, then a one-hot (B*H*W, num_embeddings) matrix.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Uses the codebook as it was before this step's EMA update below.
        quantized = torch.matmul(encodings, self.embed)

        if self.training:
            # EMA of how many vectors were assigned to each code.
            self.cluster_size.data.mul_(self.decay).add_(encodings.sum(0), alpha=1 - self.decay)
            # EMA of the per-code sum of assigned input vectors.
            dw = torch.matmul(encodings.t(), flat_input)
            self.embed_avg.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

            # Smoothed counts (epsilon keeps unused codes from dividing by zero),
            # rescaled so they still sum to n.
            n = self.cluster_size.sum()
            cluster_size = ((self.cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n)
            # New codebook = EMA mean of the vectors assigned to each code.
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        # Commitment loss: pull encoder outputs toward their (fixed) codes.
        e_latent_loss = F.mse_loss(quantized.detach(), flat_input)
        loss = self.commitment_cost * e_latent_loss
        # Straight-through estimator: forward uses codes, backward is identity.
        quantized = flat_input + (quantized - flat_input).detach()

        # (B*H*W, D) -> (B, H, W, D) -> (B, D, H, W)
        quantized = quantized.view(input_shape[0], input_shape[2], input_shape[3], self.embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        # Perplexity of average code usage: ~num codes effectively in use.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, loss, perplexity, encoding_indices

class ResidualBlock(nn.Module):
    """Pre-activation residual unit: ReLU-conv3x3-BN-ReLU-conv1x1-BN plus identity skip.

    The input is added back unchanged, so ``in_channels`` must equal
    ``num_hiddens`` for the shapes to match. The layer order inside
    ``self.block`` is fixed so that Phase-1 state_dict keys keep loading.
    """
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super().__init__()
        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, num_residual_hiddens, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_residual_hiddens),
            nn.ReLU(),
            nn.Conv2d(num_residual_hiddens, num_hiddens, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_hiddens),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResidualStack(nn.Module):
    """``num_residual_layers`` ResidualBlocks applied in sequence, then a final ReLU."""
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        self.layers = nn.ModuleList(
            ResidualBlock(in_channels, num_hiddens, num_residual_hiddens)
            for _ in range(num_residual_layers)
        )

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return F.relu(out)

class Encoder(nn.Module):
    """Image -> ``num_hiddens`` feature maps at 1/4 of the input resolution.

    Two 4x4 stride-2 convolutions (padding 1) each halve even spatial sizes.
    A 3x3 convolution and a ResidualStack follow.
    """
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2
        self.conv1 = nn.Conv2d(in_channels, half_hiddens, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(half_hiddens, num_hiddens, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(num_hiddens, num_hiddens, kernel_size=3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        return self.residual_stack(self.conv3(h))

class Decoder(nn.Module):
    """Quantised latents ``(B, embedding_dim, h, w)`` -> 3-channel images in [-1, 1].

    A 3x3 conv lifts the latents to ``num_hiddens`` channels and a
    ResidualStack refines them. Two 4x4 stride-2 transposed convs (padding 1)
    each double height and width. ``tanh`` bounds the output.
    """
    def __init__(self, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2
        self.conv1 = nn.Conv2d(embedding_dim, num_hiddens, kernel_size=3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.conv_trans1 = nn.ConvTranspose2d(num_hiddens, half_hiddens, kernel_size=4, stride=2, padding=1)
        self.conv_trans2 = nn.ConvTranspose2d(half_hiddens, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        h = self.residual_stack(self.conv1(x))
        h = F.relu(self.conv_trans1(h))
        return torch.tanh(self.conv_trans2(h))

class VQVAE(nn.Module):
    """Encoder -> 1x1 projection to ``embedding_dim`` -> EMA vector quantiser -> decoder.

    ``forward`` returns ``(x_recon, vq_loss, perplexity, encoding_indices)``.
    ``encode`` returns only the code indices, reshaped to ``(B, h*w)``.

    NOTE: ``encode`` runs the quantiser's forward pass, which updates the EMA
    codebook whenever the module is in training mode. Call ``eval()`` before
    using it for inference.
    """
    def __init__(self, config):
        super().__init__()
        hiddens = config.num_hiddens
        res_layers = config.num_residual_layers
        res_hiddens = config.num_residual_hiddens
        dim = config.embedding_dim
        self.encoder = Encoder(3, hiddens, res_layers, res_hiddens)
        self.pre_vq_conv = nn.Conv2d(hiddens, dim, 1)
        self.vq = VectorQuantizerEMA(config.num_embeddings, dim, config.commitment_cost, decay=config.decay)
        self.decoder = Decoder(dim, hiddens, res_layers, res_hiddens)

    def _pre_quantize(self, x):
        """Encoder features projected to ``embedding_dim`` channels (quantiser input)."""
        return self.pre_vq_conv(self.encoder(x))

    def forward(self, x):
        quantized, vq_loss, perplexity, encoding_indices = self.vq(self._pre_quantize(x))
        return self.decoder(quantized), vq_loss, perplexity, encoding_indices

    def encode(self, x):
        *_, encoding_indices = self.vq(self._pre_quantize(x))
        return encoding_indices.view(x.shape[0], -1)


# ============================================================================
# FIXED PixelCNN Implementation
# ============================================================================

class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel is masked to enforce raster-scan causality.

    Relative to the kernel centre, the mask keeps every row strictly above the
    centre and the taps to its left on the centre row. Type ``'A'`` also hides
    the centre tap, so a position never sees its own input; use it for the
    first layer. Type ``'B'`` keeps the centre tap; use it for later layers.

    The mask is applied functionally in ``forward``, so the stored weight is
    never modified in place and gradients flow normally (masked taps get zero
    gradient). Non-square kernels are supported.

    Raises:
        ValueError: if ``mask_type`` is not ``'A'`` or ``'B'``.
    """
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self.create_mask(mask_type)

    def create_mask(self, mask_type):
        """(Re)build ``self.mask`` for ``mask_type``; safe to call repeatedly."""
        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        # kernel_size is a (height, width) tuple; use both so rectangular
        # kernels get their centre in the right place.
        kh, kw = self.kernel_size
        ch, cw = kh // 2, kw // 2
        self.mask.zero_()
        self.mask[:, :, :ch, :] = 1        # all rows above the centre row
        self.mask[:, :, ch, :cw] = 1       # centre row, strictly left of centre
        if mask_type == 'B':
            self.mask[:, :, ch, cw] = 1    # the centre tap itself

    def forward(self, x):
        # Multiply instead of editing self.weight in place, so autograd sees
        # the masking and the parameter itself is left untouched.
        masked_weight = self.weight * self.mask
        return F.conv2d(x, masked_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class GatedResidualBlock(nn.Module):
    """Residual block with a tanh * sigmoid gated activation.

    ``conv1`` is a masked ('B') 3x3 convolution producing ``2 * hidden_dim``
    channels. They are split into a content half (tanh) and a gate half
    (sigmoid). ``conv2`` is a 1x1 'B'-masked convolution whose mask is all
    ones; it projects the gated features back before the skip is added.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        doubled = 2 * hidden_dim
        self.conv1 = MaskedConv2d('B', hidden_dim, doubled, 3, padding=1)
        self.conv2 = MaskedConv2d('B', hidden_dim, hidden_dim, 1)

    def forward(self, x):
        content, gate = self.conv1(F.relu(x)).chunk(2, dim=1)
        gated = torch.tanh(content) * torch.sigmoid(gate)
        return x + self.conv2(gated)


class ImprovedPixelCNN(nn.Module):
    """Autoregressive PixelCNN prior over a grid of discrete VQ-VAE codes.

    ``forward`` takes int64 code indices of shape ``(B, H, W)`` and one-hot
    encodes them to ``(B, num_embeddings, H, W)``. It then applies a 7x7
    type-'A' masked convolution, ``num_layers`` gated residual blocks and a
    1x1 output head. The result is logits of shape ``(B, num_embeddings, H, W)``.

    NOTE(review): stacking square masked convolutions like this is known to
    leave a "blind spot" above-right of the current position. A
    vertical/horizontal-stack Gated PixelCNN would remove it; consider this
    if the prior under-fits.
    """
    def __init__(self, num_embeddings, spatial_h, spatial_w,
                 num_layers=15, hidden_dim=128):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.spatial_h = spatial_h
        self.spatial_w = spatial_w

        # Type-'A' mask: the first layer must not see the code it predicts.
        self.input_conv = nn.Sequential(
            MaskedConv2d('A', num_embeddings, hidden_dim, 7, padding=3),
            nn.ReLU()
        )

        # Gated residual blocks
        self.residual_blocks = nn.ModuleList([
            GatedResidualBlock(hidden_dim) for _ in range(num_layers)
        ])

        # 1x1 output head. A 1x1 'B' mask is all ones, so causality is preserved.
        self.output = nn.Sequential(
            nn.ReLU(),
            MaskedConv2d('B', hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, num_embeddings, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) for every conv weight; zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, MaskedConv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x: (B, H, W) integer codes -> one-hot (B, H, W, K) -> (B, K, H, W)
        x_onehot = F.one_hot(x, self.num_embeddings).float()
        x_onehot = x_onehot.permute(0, 3, 1, 2).contiguous()

        h = self.input_conv(x_onehot)

        for block in self.residual_blocks:
            h = block(h)

        logits = self.output(h)
        return logits

    @torch.no_grad()
    def sample(self, batch_size, device, temperature=1.0):
        """Generate code maps autoregressively in raster-scan order.

        Runs one full forward pass per position (``spatial_h * spatial_w``
        passes). The module's previous train/eval mode is restored afterwards.

        Args:
            batch_size: number of code maps to generate.
            device: device the samples are created on.
            temperature: softmax temperature. For ``temperature > 0`` the logits
                are divided by it before multinomial sampling: ``< 1``
                sharpens, ``> 1`` flattens. ``temperature <= 0`` selects greedy
                (argmax) decoding.

        Returns:
            int64 tensor of shape ``(batch_size, spatial_h, spatial_w)``.
        """
        samples = torch.zeros(batch_size, self.spatial_h, self.spatial_w,
                              dtype=torch.long, device=device)

        was_training = self.training
        self.eval()
        try:
            for i in range(self.spatial_h):
                for j in range(self.spatial_w):
                    logits = self(samples)[:, :, i, j]
                    if temperature <= 0:
                        # Greedy decoding (also avoids dividing by zero).
                        samples[:, i, j] = logits.argmax(dim=1)
                    else:
                        probs = F.softmax(logits / temperature, dim=1)
                        samples[:, i, j] = torch.multinomial(probs, 1).squeeze(-1)
        finally:
            self.train(was_training)

        return samples


# ============================================================================
# IMPROVED Prior Trainer
# ============================================================================

class ImprovedPriorTrainer:
    """Fit a PixelCNN prior to the code maps produced by a frozen VQ-VAE.

    Each image batch is encoded with ``vqvae.encode`` under ``no_grad``. The
    prior is trained with cross-entropy to predict every code. Training uses
    AdamW (weight decay 0.01), a cosine-annealed learning rate stepped once per
    epoch, and gradient-norm clipping.

    ``config`` must provide ``learning_rate_prior``, ``num_epochs_prior``,
    ``device``, ``grad_clip`` and ``checkpoint_dir``.
    """

    def __init__(self, prior, vqvae, config):
        self.prior = prior
        self.vqvae = vqvae
        self.config = config

        self.optimizer = torch.optim.AdamW(
            prior.parameters(),
            lr=config.learning_rate_prior,
            weight_decay=0.01,
        )

        # LR follows a cosine from learning_rate_prior down to 10% of it over
        # num_epochs_prior scheduler steps (one step per epoch).
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs_prior,
            eta_min=config.learning_rate_prior * 0.1,
        )

        self.history = {'loss': [], 'accuracy': [], 'perplexity': [], 'epoch': []}
        self.start_epoch = 0
        self.best_loss = float('inf')

    def save_checkpoint(self, epoch, filepath):
        """Save model/optimizer/scheduler state, history and best loss.

        The parent directory of ``filepath`` is created if it does not exist.
        """
        parent = os.path.dirname(filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.prior.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'history': self.history,
            'best_loss': self.best_loss
        }
        torch.save(checkpoint, filepath)
        print(f"‚úì Prior checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath):
        """Restore state written by ``save_checkpoint``.

        Returns:
            True on success; training then resumes at ``checkpoint['epoch'] + 1``.
            False if the file is missing or cannot be loaded. Load errors are
            printed, not raised (deliberate best-effort resume).
        """
        if not os.path.exists(filepath):
            return False
        try:
            checkpoint = torch.load(filepath, map_location=self.config.device)
            self.prior.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Older checkpoints may predate the scheduler.
            if 'scheduler_state_dict' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.history = checkpoint['history']
            self.start_epoch = checkpoint['epoch'] + 1
            self.best_loss = checkpoint.get('best_loss', float('inf'))

            print(f"‚úì Prior checkpoint loaded, resuming from epoch {self.start_epoch}")
            return True
        except Exception as e:
            print(f"‚úó Error loading prior checkpoint: {e}")
            return False

    def train_epoch(self, dataloader, spatial_h, spatial_w):
        """Run one pass over ``dataloader``, then step the LR scheduler once.

        Batches whose loss is NaN or infinite are skipped without an optimizer
        step, because backpropagating either would corrupt the weights.

        Returns:
            dict with the mean ``loss``, ``accuracy`` and ``perplexity`` over
            the batches actually used. All three are NaN if no batch was usable,
            e.g. an empty loader.
        """
        self.prior.train()
        self.vqvae.eval()

        totals = {'loss': 0.0, 'accuracy': 0.0, 'perplexity': 0.0}
        num_batches = 0

        pbar = tqdm(dataloader, desc="Training Prior")

        for batch in pbar:
            batch = batch.to(self.config.device)

            # Target codes come from the frozen VQ-VAE.
            with torch.no_grad():
                codes = self.vqvae.encode(batch).view(-1, spatial_h, spatial_w)

            # cross_entropy expects logits shaped (B, num_embeddings, H, W).
            logits = self.prior(codes)
            loss = F.cross_entropy(logits, codes)

            if not torch.isfinite(loss):
                print("Warning: non-finite loss detected, skipping batch")
                continue

            with torch.no_grad():
                # Perplexity of the batch-averaged predictive distribution:
                # roughly how many codes the prior spreads its mass over.
                avg_probs = F.softmax(logits, dim=1).mean(dim=[0, 2, 3])
                perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))).item()
                accuracy = (logits.argmax(dim=1) == codes).float().mean().item()

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.prior.parameters(), self.config.grad_clip)
            self.optimizer.step()

            loss_value = loss.item()
            totals['loss'] += loss_value
            totals['accuracy'] += accuracy
            totals['perplexity'] += perplexity
            num_batches += 1

            pbar.set_postfix({
                'loss': loss_value,
                'acc': f'{accuracy:.3f}',
                'ppl': f'{perplexity:.1f}'
            })

        # One scheduler step per epoch, matching T_max=num_epochs_prior.
        self.scheduler.step()

        if num_batches == 0:
            print("Warning: no usable batches this epoch; metrics set to NaN")
            return {key: float('nan') for key in totals}
        return {key: total / num_batches for key, total in totals.items()}

    def train(self, train_loader, spatial_h, spatial_w):
        """Train from ``start_epoch`` to ``num_epochs_prior`` and return the history.

        Saves ``prior_best.pt`` whenever the epoch loss improves, and
        ``prior_epoch_<n>.pt`` every 10 epochs. ``prior_final.pt`` is saved at
        the end. All files go under ``config.checkpoint_dir``.
        """
        print(f"\n{'='*80}")
        print("TRAINING IMPROVED PIXELCNN PRIOR")
        print(f"{'='*80}\n")

        for epoch in range(self.start_epoch, self.config.num_epochs_prior):
            print(f"\nEpoch {epoch+1}/{self.config.num_epochs_prior}")
            print(f"Learning Rate: {self.scheduler.get_last_lr()[0]:.6f}")

            metrics = self.train_epoch(train_loader, spatial_h, spatial_w)

            self.history['loss'].append(metrics['loss'])
            self.history['accuracy'].append(metrics['accuracy'])
            self.history['perplexity'].append(metrics['perplexity'])
            self.history['epoch'].append(epoch + 1)

            print(f"Loss: {metrics['loss']:.4f}, "
                  f"Acc: {metrics['accuracy']:.3f}, "
                  f"Perplexity: {metrics['perplexity']:.1f}")

            # NaN never compares smaller, so an unusable epoch is never "best".
            if metrics['loss'] < self.best_loss:
                self.best_loss = metrics['loss']
                self.save_checkpoint(
                    epoch,
                    os.path.join(self.config.checkpoint_dir, 'prior_best.pt')
                )
                print(f"‚úì New best model saved!")

            if (epoch + 1) % 10 == 0:
                self.save_checkpoint(
                    epoch,
                    os.path.join(self.config.checkpoint_dir, f'prior_epoch_{epoch+1}.pt')
                )

        self.save_checkpoint(
            self.config.num_epochs_prior - 1,
            os.path.join(self.config.checkpoint_dir, 'prior_final.pt')
        )

        return self.history




# ============================================================================
# Decoding Helper
# ============================================================================

def decode_codes(vqvae, encoding_indices, spatial_h, spatial_w):
    """Decode a batch of discrete code maps back through the VQ-VAE decoder.

    Args:
        vqvae: model exposing ``vq.embed`` (codebook of shape
            ``(num_embeddings, embedding_dim)``) and a ``decoder`` module.
        encoding_indices: integer code indices, either ``(B, H, W)`` or
            flattened ``(B, H*W)``. The tensor need not be contiguous.
        spatial_h, spatial_w: latent grid size; ``H * W`` must equal
            ``spatial_h * spatial_w``.

    Returns:
        ``vqvae.decoder`` applied to the looked-up latents, which have shape
        ``(B, embedding_dim, spatial_h, spatial_w)``.
    """
    batch_size = encoding_indices.shape[0]

    # reshape (not view) so sliced/transposed code tensors are accepted too.
    flat_codes = encoding_indices.reshape(-1)                       # (B*H*W,)
    quantized = F.embedding(flat_codes, vqvae.vq.embed)             # (B*H*W, D)

    quantized = quantized.view(batch_size, spatial_h, spatial_w, -1)  # (B, H, W, D)
    quantized = quantized.permute(0, 3, 1, 2).contiguous()            # (B, D, H, W)

    return vqvae.decoder(quantized)




"""
Diagnostic Tools for VQ-VAE Code Analysis
Run this to understand why PixelCNN is struggling
"""

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import entropy

def _collect_code_maps(vqvae, dataloader, device, spatial_h, spatial_w, max_batches=20):
    """Encode up to ``max_batches`` batches and return the code maps stacked on the CPU.

    Returns:
        Tensor of shape (N, spatial_h, spatial_w).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    code_maps = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
            codes = vqvae.encode(batch.to(device))
            code_maps.append(codes.view(-1, spatial_h, spatial_w).cpu())
    if not code_maps:
        raise ValueError("dataloader yielded no batches; nothing to analyse")
    return torch.cat(code_maps, dim=0)


def _horizontal_transition_counts(all_codes, num_embeddings):
    """Return a float (K, K) matrix whose [a, b] entry counts code b directly right of code a."""
    left = all_codes[:, :, :-1].reshape(-1).numpy()
    right = all_codes[:, :, 1:].reshape(-1).numpy()
    # Encode each (left, right) pair as one integer so a single bincount tallies them all.
    pair_ids = left * num_embeddings + right
    counts = np.bincount(pair_ids, minlength=num_embeddings * num_embeddings)
    return counts.reshape(num_embeddings, num_embeddings).astype(np.float64)


def analyze_code_spatial_structure(vqvae, dataloader, device, spatial_h, spatial_w):
    """
    Comprehensive analysis of VQ-VAE codes to diagnose PixelCNN issues.

    Encodes up to 20 batches with ``vqvae.encode`` and prints:
      1. codebook usage and entropy,
      2. neighbour agreement along rows, columns and diagonals,
      3. left-neighbour "oracle" accuracy,
      4. (saves 8 random code maps to ./results/code_visualization.png,
         creating ./results if needed),
      5. entropy of left->right code transitions,
      6. a weighted learnability score with recommendations.

    Args:
        vqvae: model with ``encode`` and ``vq.num_embeddings``; switched to eval().
        dataloader: iterable of image batches (tensors).
        device: device batches are moved to before encoding.
        spatial_h, spatial_w: latent grid size each code map is reshaped to.

    Returns:
        dict with ``spatial_correlation``, ``oracle_accuracy``,
        ``transition_predictability``, ``learnability_score``,
        ``unique_codes`` and ``code_entropy``. ``code_entropy`` is the entropy
        divided by the log of the number of codes actually used.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    print("\n" + "="*80)
    print("DIAGNOSING VQ-VAE CODE STRUCTURE")
    print("="*80 + "\n")

    vqvae.eval()
    num_embeddings = vqvae.vq.num_embeddings

    # Collect codes from dataset (first 20 batches)
    print("Collecting codes from dataset...")
    all_codes = _collect_code_maps(vqvae, dataloader, device, spatial_h, spatial_w)  # (N, H, W)
    N, H, W = all_codes.shape

    print(f"‚úì Collected {N} code maps of size {H}√ó{W}\n")

    # =======================================================================
    # 1. CODE DISTRIBUTION ANALYSIS
    # =======================================================================
    print("1. CODE DISTRIBUTION ANALYSIS")
    print("-" * 80)

    flat_codes = all_codes.view(-1).numpy()
    unique_codes, counts = np.unique(flat_codes, return_counts=True)

    print(f"Unique codes used: {len(unique_codes)}")
    print(f"Total possible codes: {num_embeddings}")
    print(f"Usage: {len(unique_codes)/num_embeddings*100:.1f}%")

    # Entropy of code usage relative to a uniform spread over the used codes
    code_probs = counts / counts.sum()
    code_entropy = entropy(code_probs)
    max_entropy = np.log(len(unique_codes))

    print(f"\nDistribution entropy: {code_entropy:.2f}")
    print(f"Max possible entropy: {max_entropy:.2f}")
    print(f"Entropy ratio: {code_entropy/max_entropy:.2%}")

    if code_entropy/max_entropy < 0.7:
        print("‚ö† WARNING: Code distribution is skewed (some codes dominate)")
    else:
        print("‚úì Code distribution is relatively uniform")

    # =======================================================================
    # 2. SPATIAL AUTOCORRELATION ANALYSIS
    # =======================================================================
    print("\n2. SPATIAL AUTOCORRELATION ANALYSIS")
    print("-" * 80)

    # Fraction of adjacent positions holding the same code
    horizontal_same = (all_codes[:, :, :-1] == all_codes[:, :, 1:]).float().mean().item()
    vertical_same = (all_codes[:, :-1, :] == all_codes[:, 1:, :]).float().mean().item()
    diagonal_same = (all_codes[:, :-1, :-1] == all_codes[:, 1:, 1:]).float().mean().item()

    print(f"Horizontal neighbor similarity: {horizontal_same:.2%}")
    print(f"Vertical neighbor similarity: {vertical_same:.2%}")
    print(f"Diagonal neighbor similarity: {diagonal_same:.2%}")
    print(f"Average spatial correlation: {(horizontal_same + vertical_same)/2:.2%}")

    if (horizontal_same + vertical_same)/2 < 0.3:
        print("\n‚ö† CRITICAL: Very low spatial correlation!")
        print("  This explains why PixelCNN struggles - codes are too 'random'")
        print("  Recommendation: Reduce codebook size or increase commitment cost")
    elif (horizontal_same + vertical_same)/2 < 0.5:
        print("\n‚ö† WARNING: Moderate spatial correlation")
        print("  PixelCNN will struggle. Consider architectural changes.")
    else:
        print("\n‚úì Good spatial correlation - suitable for PixelCNN")

    # =======================================================================
    # 3. PREDICTABILITY ANALYSIS
    # =======================================================================
    print("\n3. PREDICTABILITY ANALYSIS (Oracle Test)")
    print("-" * 80)

    # Oracle: predict each code (except the first column) from its left
    # neighbour. It is right exactly when the two codes are equal, so count the
    # equal horizontal pairs in one vectorised comparison.
    oracle_correct = (all_codes[:, :, 1:] == all_codes[:, :, :-1]).sum().item()
    oracle_total = N * H * (W - 1)
    oracle_accuracy = oracle_correct / oracle_total if oracle_total else float('nan')

    print(f"Left-neighbor oracle accuracy: {oracle_accuracy:.2%}")
    print(f"Random chance: {1/num_embeddings:.2%}")
    print(f"PixelCNN current accuracy: ~42%")

    if oracle_accuracy < 0.4:
        print("\n‚ö† CRITICAL: Even simple oracle fails!")
        print("  Codes lack autoregressive structure. PixelCNN won't work well.")
    elif oracle_accuracy < 0.6:
        print("\n‚ö† WARNING: Oracle accuracy is low")
        print("  PixelCNN will need many layers to capture patterns")
    else:
        print("\n‚úì Codes have predictable structure")

    # =======================================================================
    # 4. VISUALIZE CODE PATTERNS
    # =======================================================================
    print("\n4. GENERATING VISUALIZATIONS...")
    print("-" * 80)

    os.makedirs('./results', exist_ok=True)
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    # Show 8 random code maps
    for idx in range(8):
        ax = axes[idx // 4, idx % 4]
        sample_codes = all_codes[np.random.randint(N)]
        im = ax.imshow(sample_codes.numpy(), cmap='tab20', interpolation='nearest')
        ax.set_title(f'Sample {idx+1}')
        ax.axis('off')

    plt.colorbar(im, ax=axes.ravel().tolist(), label='Code Index')
    plt.tight_layout()
    plt.savefig('./results/code_visualization.png', dpi=150, bbox_inches='tight')
    plt.close()

    print("‚úì Saved: ./results/code_visualization.png")

    # =======================================================================
    # 5. CODE TRANSITION MATRIX
    # =======================================================================
    print("\n5. CODE TRANSITION ANALYSIS")
    print("-" * 80)

    transition_counts = _horizontal_transition_counts(all_codes, num_embeddings)

    # Row-normalise to P(next | current); rows of never-seen codes stay zero
    row_sums = transition_counts.sum(axis=1, keepdims=True)
    transition_probs = np.divide(transition_counts, row_sums,
                                 where=row_sums>0, out=np.zeros_like(transition_counts))

    # Transition entropy per observed code (how predictable is the next code?)
    transition_entropy = [
        entropy(row + 1e-10) for row in transition_probs if row.sum() > 0
    ]

    avg_transition_entropy = np.mean(transition_entropy) if transition_entropy else 0
    max_transition_entropy = np.log(num_embeddings)

    print(f"Average transition entropy: {avg_transition_entropy:.2f}")
    print(f"Max transition entropy: {max_transition_entropy:.2f}")
    print(f"Predictability: {1 - avg_transition_entropy/max_transition_entropy:.2%}")

    if avg_transition_entropy/max_transition_entropy > 0.8:
        print("\n‚ö† CRITICAL: Transitions are nearly random!")
        print("  Each code can be followed by almost any other code equally")

    # =======================================================================
    # 6. RECOMMENDATIONS
    # =======================================================================
    print("\n" + "="*80)
    print("RECOMMENDATIONS")
    print("="*80 + "\n")

    # Weighted "learnability score" (heuristic weights 0.5 / 0.3 / 0.2)
    spatial_corr = (horizontal_same + vertical_same) / 2
    predictability = 1 - avg_transition_entropy/max_transition_entropy

    score = (spatial_corr * 0.5 + oracle_accuracy * 0.3 + predictability * 0.2) * 100

    print(f"Code Learnability Score: {score:.1f}/100")

    if score < 30:
        print("\nüî¥ POOR - PixelCNN will struggle significantly")
        print("\nSuggested fixes (in order of impact):")
        print("1. REDUCE codebook size: 256 ‚Üí 128 or 64")
        print("2. INCREASE commitment cost: Try 0.25, 0.5, or even 1.0")
        print("3. DECREASE spatial compression: Change encoder strides")
        print("4. Consider Transformer prior instead of PixelCNN")

    elif score < 50:
        print("\nüü° MODERATE - PixelCNN needs help")
        print("\nSuggested improvements:")
        print("1. Use multi-scale PixelCNN (condition on lower resolution)")
        print("2. Increase model capacity: 256-512 hidden dims, 20+ layers")
        print("3. Train much longer: 200-300 epochs")
        print("4. Try hierarchical VQ-VAE (2-level)")

    else:
        print("\nüü¢ GOOD - PixelCNN should work")
        print("\nFine-tuning suggestions:")
        print("1. Increase training epochs to 150-200")
        print("2. Use larger model: 256 hidden, 20 layers")
        print("3. Try different learning rate schedules")

    return {
        'spatial_correlation': spatial_corr,
        'oracle_accuracy': oracle_accuracy,
        'transition_predictability': predictability,
        'learnability_score': score,
        'unique_codes': len(unique_codes),
        'code_entropy': code_entropy / max_entropy
    }





# # ============================================================================
# # PixelCNN Prior
# # ============================================================================

# class MaskedConv2d(nn.Conv2d):
#     def __init__(self, mask_type, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.register_buffer('mask', torch.zeros_like(self.weight))
#         self.create_mask(mask_type)

#     def create_mask(self, mask_type):
#         k = self.kernel_size[0]
#         self.mask[:, :, :k//2, :] = 1
#         self.mask[:, :, k//2, :k//2] = 1
#         if mask_type == 'B':
#             self.mask[:, :, k//2, k//2] = 1

#     def forward(self, x):
#         self.weight.data *= self.mask
#         return super().forward(x)

# class PixelCNNResidualBlock(nn.Module):
#     def __init__(self, h):
#         super().__init__()
#         self.conv = nn.Sequential(
#             nn.ReLU(),
#             MaskedConv2d('B', h, h, 1),
#             nn.BatchNorm2d(h),
#             nn.ReLU(),
#             MaskedConv2d('B', h, h, 1),
#             nn.BatchNorm2d(h)
#         )

#     def forward(self, x):
#         return x + self.conv(x)

# class PixelCNN(nn.Module):
#     def __init__(self, num_embeddings, spatial_h, spatial_w, num_layers=12, hidden_dim=64):
#         super().__init__()
#         self.num_embeddings = num_embeddings
#         self.spatial_h = spatial_h
#         self.spatial_w = spatial_w

#         # Input projection
#         self.input_conv = MaskedConv2d('A', num_embeddings, hidden_dim, 7, padding=3)

#         # Residual blocks
#         self.residual_blocks = nn.ModuleList([
#             PixelCNNResidualBlock(hidden_dim) for _ in range(num_layers)
#         ])

#         # Output projection
#         self.output = nn.Sequential(
#             nn.ReLU(),
#             MaskedConv2d('B', hidden_dim, hidden_dim, 1),
#             nn.BatchNorm2d(hidden_dim),
#             nn.ReLU(),
#             nn.Conv2d(hidden_dim, num_embeddings, 1)
#         )

#         self._init_weights()

#     def _init_weights(self):
#         for m in self.modules():
#             if isinstance(m, (nn.Conv2d, MaskedConv2d)):
#                 nn.init.xavier_uniform_(m.weight, gain=0.1)
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)
#             elif isinstance(m, nn.BatchNorm2d):
#                 nn.init.constant_(m.weight, 1)
#                 nn.init.constant_(m.bias, 0)

#     def forward(self, x):
#         # x shape: (B, H, W)
#         x_onehot = F.one_hot(x, self.num_embeddings).float()  # (B, H, W, num_embeddings)
#         x_onehot = x_onehot.permute(0, 3, 1, 2).contiguous()  # (B, num_embeddings, H, W)

#         x = self.input_conv(x_onehot)
#         for block in self.residual_blocks:
#             x = block(x)
#         logits = self.output(x)  # (B, num_embeddings, H, W)

#         return logits

#     @torch.no_grad()
#     def sample(self, batch_size, device, temperature=1.0):
#         samples = torch.zeros(batch_size, self.spatial_h, self.spatial_w,
#                             dtype=torch.long, device=device)

#         # Sample pixel by pixel
#         for i in range(self.spatial_h):
#             for j in range(self.spatial_w):
#                 logits = self(samples)  # (B, num_embeddings, H, W)
#                 probs = F.softmax(logits[:, :, i, j] / temperature, dim=1)
#                 samples[:, i, j] = torch.multinomial(probs, 1).squeeze(-1)

#         return samples

# # ============================================================================
# # Prior Trainer
# # ============================================================================

# class PriorTrainer:
#     def __init__(self, prior, vqvae, config):
#         self.prior = prior
#         self.vqvae = vqvae
#         self.config = config
#         self.optimizer = torch.optim.Adam(prior.parameters(), lr=config.learning_rate_prior)
#         self.history = {'loss': [], 'epoch': []}
#         self.start_epoch = 0

#     def save_checkpoint(self, epoch, filepath):
#         checkpoint = {
#             'epoch': epoch,
#             'model_state_dict': self.prior.state_dict(),
#             'optimizer_state_dict': self.optimizer.state_dict(),
#             'history': self.history
#         }
#         torch.save(checkpoint, filepath)
#         print(f"‚úì Prior checkpoint saved: {filepath}")

#     def load_checkpoint(self, filepath):
#         if os.path.exists(filepath):
#             try:
#                 checkpoint = torch.load(filepath, map_location=self.config.device, weights_only=False)
#                 self.prior.load_state_dict(checkpoint['model_state_dict'])
#                 self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#                 self.history = checkpoint['history']
#                 self.start_epoch = checkpoint['epoch'] + 1
#                 print(f"‚úì Prior checkpoint loaded, resuming from epoch {self.start_epoch}")
#                 return True
#             except Exception as e:
#                 print(f"‚úó Error loading prior checkpoint: {e}")
#                 return False
#         return False

#     def train_epoch(self, dataloader, spatial_h, spatial_w):
#         self.prior.train()
#         self.vqvae.eval()
#         epoch_loss = 0
#         num_valid = 0

#         pbar = tqdm(dataloader, desc="Training Prior")
#         for batch in pbar:
#             batch = batch.to(self.config.device)

#             # Get codes from VQ-VAE with correct reshaping
#             with torch.no_grad():
#                 codes = self.vqvae.encode(batch)  # (B, H*W)
#                 codes = codes.view(-1, spatial_h, spatial_w)  # (B, H, W)

#             # Train PixelCNN
#             logits = self.prior(codes)  # (B, num_embeddings, H, W)
#             loss = F.cross_entropy(logits, codes)

#             if torch.isnan(loss):
#                 print("Warning: NaN loss detected, skipping batch")
#                 continue

#             self.optimizer.zero_grad()
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(self.prior.parameters(), self.config.grad_clip)
#             self.optimizer.step()

#             epoch_loss += loss.item()
#             num_valid += 1
#             pbar.set_postfix({'loss': loss.item()})

#         return epoch_loss / num_valid if num_valid > 0 else float('inf')

#     def train(self, train_loader, spatial_h, spatial_w):  # FIXED: Added spatial_h, spatial_w parameters
#         print(f"\n{'='*80}")
#         print(f"TRAINING PIXELCNN PRIOR")
#         print(f"{'='*80}\n")

#         for epoch in range(self.start_epoch, self.config.num_epochs_prior):
#             print(f"\nEpoch {epoch+1}/{self.config.num_epochs_prior}")

#             loss = self.train_epoch(train_loader, spatial_h, spatial_w)

#             if loss == float('inf'):
#                 print("Training stopped due to invalid loss")
#                 break

#             self.history['loss'].append(loss)
#             self.history['epoch'].append(epoch + 1)

#             print(f"Loss: {loss:.4f}")

#             # Save checkpoints
#             if (epoch + 1) % 10 == 0:
#                 self.save_checkpoint(epoch, os.path.join(self.config.checkpoint_dir, f'prior_epoch_{epoch+1}.pt'))

#         # Save final checkpoint
#         self.save_checkpoint(self.config.num_epochs_prior - 1, os.path.join(self.config.checkpoint_dir, 'prior_final.pt'))

#         return self.history




# ============================================================================
# FID Calculation
# ============================================================================

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet Inception Distance between two feature sets.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))

    Args:
        real_features: array of shape (N_real, D), one feature vector per row.
        fake_features: array of shape (N_fake, D), same D as ``real_features``.
        eps: diagonal offset added to both covariances when the matrix square
            root is not finite (typical when N < D, e.g. 256 images vs 2048-d
            Inception features, which makes the covariances singular).

    Returns:
        float: the FID value.

    Raises:
        ValueError: if the inputs are not 2-D, have mismatched feature
            dimensions, or contain fewer than two samples.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)

    if real_features.ndim != 2 or fake_features.ndim != 2:
        raise ValueError("Feature arrays must be 2-D (num_samples, num_features)")
    if real_features.shape[1] != fake_features.shape[1]:
        raise ValueError(
            f"Feature dimension mismatch: {real_features.shape[1]} vs {fake_features.shape[1]}"
        )
    if real_features.shape[0] < 2 or fake_features.shape[0] < 2:
        raise ValueError("At least two samples per set are needed to estimate a covariance")

    mu1 = real_features.mean(axis=0)
    mu2 = fake_features.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))

    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # Singular covariance products can make sqrtm return inf/nan; regularise
    # both covariances slightly and retry (standard FID practice).
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can introduce a tiny imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)

def get_inception_features(images, model, device, batch_size=64):
    """Extract Inception-v3 pooled features for FID computation.

    Args:
        images: float tensor (N, 3, H, W). Assumed to be in [-1, 1], the same
            convention plot_generated_samples assumes. Values are clamped to
            that range.
        model: feature extractor to reuse, or None to build torchvision's
            ImageNet-pretrained Inception-v3 with its final fc replaced by
            an identity (so the output is the pooled feature vector).
        device: device the model runs on; each chunk is moved there.
        batch_size: number of images resized and forwarded per step, to bound
            memory (299x299 activations for hundreds of images at once are large).

    Returns:
        tuple: (features as a NumPy array of shape (N, D), model) so the same
        model can be passed back in for the next call.

    Raises:
        ValueError: if ``images`` is empty or ``batch_size`` is not positive.
    """
    if images.shape[0] == 0:
        raise ValueError("images is empty; cannot extract features")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    if model is None:
        # Imported lazily so callers that supply their own model don't need torchvision.
        from torchvision.models import inception_v3
        try:
            model = inception_v3(weights="IMAGENET1K_V1", transform_input=False)
        except TypeError:
            # torchvision < 0.13 has no `weights` argument
            model = inception_v3(pretrained=True, transform_input=False)
        model.fc = nn.Identity()
        model = model.to(device)
        model.eval()

    features = []
    with torch.no_grad():
        for start in range(0, images.shape[0], batch_size):
            chunk = images[start:start + batch_size].to(device).float()
            # torchvision's Inception-v3 weights consume inputs in [-1, 1]: with
            # transform_input=True the model itself maps ImageNet-normalised input
            # to (x - 0.5) / 0.5. With transform_input=False (as here) the [-1, 1]
            # images must be passed as-is, not rescaled to [0, 1].
            chunk = chunk.clamp(-1, 1)
            chunk = F.interpolate(chunk, size=(299, 299), mode='bilinear', align_corners=False)
            features.append(model(chunk).cpu())

    return torch.cat(features, dim=0).numpy(), model

# ============================================================================
# Visualization Functions
# ============================================================================

def plot_generated_samples(samples, save_path, title="Generated Samples", nrow=8):
    """Save a grid of images to ``save_path``.

    Args:
        samples: image tensor (N, C, H, W) assumed to be in [-1, 1]; it is
            mapped to [0, 1] and clamped before plotting.
        save_path: destination image file.
        title: figure title.
        nrow: number of images per grid row (default 8).
    """
    images = torch.clamp((samples.detach() + 1) / 2, 0, 1).cpu()
    grid = make_grid(images, nrow=nrow, padding=2)

    fig = plt.figure(figsize=(12, 12))
    try:
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.title(title, fontsize=16)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails
        plt.close(fig)
    print(f"‚úì Saved: {save_path}")

def visualize_latent_interpolation(vqvae, dataloader, device, spatial_h, spatial_w, save_path, num_steps=10, seed=0):
    """Visualise a transition between the latent code maps of two real images.

    VQ-VAE codes are categorical codebook indices with no meaningful ordering,
    so blending the index values arithmetically yields arbitrary codes that
    belong to neither image. Instead, every latent position is given a fixed
    random threshold in [0, 1). At step ``alpha`` the positions whose threshold
    is below ``alpha`` take image B's code and the rest keep image A's code.
    Every frame therefore contains only codes from the two endpoints. A
    position that has switched to B stays switched, and alpha=0 / alpha=1
    reproduce A / B exactly.

    Args:
        vqvae: model exposing ``encode(images)``; its output is reshaped to
            (2, spatial_h, spatial_w).
        dataloader: the first two images of its first batch are the endpoints.
        device: device the images are moved to.
        spatial_h, spatial_w: latent grid size.
        save_path: destination of the plotted grid.
        num_steps: number of frames including both endpoints.
        seed: seed for the per-position thresholds, for reproducible plots.

    Raises:
        ValueError: if the first batch contains fewer than two images.
    """
    vqvae.eval()
    real_batch = next(iter(dataloader))[:2].to(device)
    if real_batch.shape[0] < 2:
        raise ValueError("Need at least two images in the first batch to interpolate")

    with torch.no_grad():
        codes = vqvae.encode(real_batch).view(2, spatial_h, spatial_w)
        code_a, code_b = codes[0], codes[1]

        generator = torch.Generator().manual_seed(seed)
        thresholds = torch.rand(spatial_h, spatial_w, generator=generator).to(codes.device)

        interpolations = []
        for alpha in np.linspace(0, 1, num_steps):
            take_b = thresholds < float(alpha)
            interp_code = torch.where(take_b, code_b, code_a).unsqueeze(0)  # (1, H, W)
            interp_img = decode_codes(vqvae, interp_code, spatial_h, spatial_w)
            interpolations.append(interp_img)

        interpolation_grid = torch.cat(interpolations)

    plot_generated_samples(interpolation_grid, save_path, "Latent Space Interpolation")

def visualize_tsne(vqvae, dataloader, device, save_path, perplexity=30):
    """Embed per-image latent code maps in 2-D with t-SNE and save a scatter plot.

    Up to 10 batches are encoded and at most 1000 images are embedded. Each
    image's code map is flattened into one row. Note that the raw codebook
    indices are used as Euclidean coordinates.

    Args:
        vqvae: model exposing ``encode(images)``.
        dataloader: source of image batches.
        device: device the batches are moved to.
        save_path: destination image file.
        perplexity: requested t-SNE perplexity. It is capped at
            n_samples - 1, since sklearn requires perplexity < n_samples.
    """
    print("Creating t-SNE visualization...")
    vqvae.eval()
    all_codes = []

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= 10:  # Limit samples for speed
                break
            codes = vqvae.encode(batch.to(device))
            # One row per image whether encode returns (B, H*W) or (B, H, W)
            all_codes.append(codes.reshape(codes.shape[0], -1).cpu())

    if not all_codes:
        print("Warning: no batches available, skipping t-SNE")
        return

    features = torch.cat(all_codes, dim=0).numpy()[:1000]
    num_points = features.shape[0]
    if num_points < 2:
        print("Warning: need at least 2 samples for t-SNE, skipping")
        return

    tsne = TSNE(n_components=2, random_state=42, perplexity=min(perplexity, num_points - 1))
    codes_2d = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    plt.scatter(codes_2d[:, 0], codes_2d[:, 1], alpha=0.5, s=10)
    plt.title('t-SNE Visualization of Latent Codes', fontsize=14)
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True, alpha=0.3)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")

def visualize_clustering(vqvae, dataloader, device, save_path, n_clusters=10):
    """K-means-cluster per-image latent code maps and plot them in t-SNE space.

    Up to 10 batches are encoded and at most 1000 images are used. Each image's
    code map is flattened into one row, and raw codebook indices are treated as
    Euclidean coordinates by both K-means and t-SNE.

    Args:
        vqvae: model exposing ``encode(images)``.
        dataloader: source of image batches.
        device: device the batches are moved to.
        save_path: destination image file.
        n_clusters: requested number of clusters. It is capped at the number
            of samples, which KMeans requires.
    """
    print("Performing clustering analysis...")
    vqvae.eval()
    all_codes = []

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= 10:
                break
            codes = vqvae.encode(batch.to(device))
            # One row per image whether encode returns (B, H*W) or (B, H, W)
            all_codes.append(codes.reshape(codes.shape[0], -1).cpu())

    if not all_codes:
        print("Warning: no batches available, skipping clustering")
        return

    features = torch.cat(all_codes, dim=0).numpy()[:1000]
    num_points = features.shape[0]
    if num_points < 2:
        print("Warning: need at least 2 samples for clustering, skipping")
        return

    # K-means clustering (k cannot exceed the number of samples)
    k = min(n_clusters, num_points)
    kmeans = KMeans(n_clusters=k, random_state=42)
    cluster_labels = kmeans.fit_predict(features)

    # t-SNE for visualization (sklearn requires perplexity < n_samples)
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, num_points - 1))
    codes_2d = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(codes_2d[:, 0], codes_2d[:, 1], c=cluster_labels,
                         alpha=0.6, s=10, cmap='tab10')
    plt.colorbar(scatter, label='Cluster')
    plt.title(f'K-means Clustering (k={k}) of Latent Codes', fontsize=14)
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True, alpha=0.3)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")

def plot_prior_training_history(history, save_path):
    """Save a line chart of the prior's training loss against epoch.

    Args:
        history: dict with parallel lists under 'epoch' and 'loss'.
        save_path: destination image file.
    """
    epochs = history['epoch']
    losses = history['loss']

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, losses, linewidth=2)
    for text, set_label in (('Epoch', plt.xlabel), ('Cross-Entropy Loss', plt.ylabel)):
        set_label(text, fontsize=12)
    plt.title('PixelCNN Prior Training Loss', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úì Saved: {save_path}")

# ============================================================================
# CSV Export
# ============================================================================

def save_results_to_csv(config, results, save_path):
    """Append one experiment row (timestamp + config + results) to a CSV log.

    The log is shared by phases whose rows carry different columns. If the new
    row only uses columns already in the header, it is appended in place,
    ordered by that header. Otherwise the file is rewritten through a temporary
    file with the union of columns: existing columns keep their order and new
    ones are appended. Older rows get empty cells for the new columns, so every
    value stays under its own header. Values from earlier rows that had no
    header (DictReader's overflow) are kept in an '_unlabeled_values' column
    rather than dropped.

    Args:
        config: object exposing ``to_dict()`` with the experiment configuration.
        results: mapping of result names to values; keys override config keys.
        save_path: CSV file path (created if missing).
    """
    row_data = {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        **config.to_dict(),
        **results
    }

    existing_fields = []
    existing_rows = []
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        with open(save_path, 'r', newline='') as f:
            reader = csv.DictReader(f)
            existing_fields = list(reader.fieldnames or [])
            existing_rows = list(reader)

    new_fields = [key for key in row_data if key not in existing_fields]

    if existing_fields and not new_fields:
        # Header already covers every column: append aligned to that header.
        with open(save_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=existing_fields, restval='')
            writer.writerow(row_data)
    else:
        fieldnames = existing_fields + new_fields
        overflow_field = '_unlabeled_values'
        for row in existing_rows:
            extra = row.pop(None, None)  # values beyond the old header
            if extra:
                row[overflow_field] = ';'.join(extra)
                if overflow_field not in fieldnames:
                    fieldnames.append(overflow_field)

        # Write to a temp file and swap it in so a failure can't truncate the log.
        tmp_path = save_path + '.tmp'
        with open(tmp_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, restval='')
            writer.writeheader()
            writer.writerows(existing_rows)
            writer.writerow(row_data)
        os.replace(tmp_path, save_path)

    print(f"‚úì Results appended to: {save_path}")

def verify_pipeline(vqvae, prior, dataloader, device, spatial_h, spatial_w):
    """
    Smoke-test that the VQ-VAE and the prior agree on latent layout.

    The function encodes up to 4 real images and runs the prior on their codes.
    It then samples 4 code maps from the prior and decodes them, printing
    shapes and the decoded value range. It checks three things:
      - the prior's logits match the batch and the (spatial_h, spatial_w) grid;
      - every VQ-VAE code index is a class the prior can predict;
      - sampled code maps have the (spatial_h, spatial_w) grid.
    Mismatches are printed rather than raised, so the caller's flow is unchanged.

    Returns:
        bool: True if all checks passed, False otherwise.
    """
    print("\nüîç VERIFYING PIPELINE COMPATIBILITY...")

    vqvae.eval()
    prior.eval()
    problems = []

    with torch.no_grad():
        # Test with real data
        real_batch = next(iter(dataloader))[:4].to(device)

        # Encode with VQ-VAE
        real_codes = vqvae.encode(real_batch)
        real_codes = real_codes.view(-1, spatial_h, spatial_w)
        print(f"Real codes shape: {real_codes.shape}")

        # Test prior forward pass.
        # Assumes channel-first logits (B, num_classes, H, W), the layout
        # F.cross_entropy(logits, codes) needs during prior training.
        logits = prior(real_codes)
        print(f"Prior logits shape: {logits.shape}")
        if (logits.dim() != 4 or logits.shape[0] != real_codes.shape[0]
                or tuple(logits.shape[2:]) != (spatial_h, spatial_w)):
            problems.append(
                f"Prior logits {tuple(logits.shape)} do not match "
                f"(batch={real_codes.shape[0]}, classes, {spatial_h}, {spatial_w})"
            )
        elif int(real_codes.max()) >= logits.shape[1]:
            problems.append(
                f"VQ-VAE produced code index {int(real_codes.max())} but the prior "
                f"only predicts {logits.shape[1]} classes"
            )

        # Test sampling
        sample_codes = prior.sample(batch_size=4, device=device)
        print(f"Sampled codes shape: {sample_codes.shape}")
        if tuple(sample_codes.shape[-2:]) != (spatial_h, spatial_w):
            problems.append(
                f"Sampled codes grid {tuple(sample_codes.shape[-2:])} != ({spatial_h}, {spatial_w})"
            )

        # Test decoding
        reconstructed = decode_codes(vqvae, sample_codes, spatial_h, spatial_w)
        print(f"Decoded images shape: {reconstructed.shape}")

        # Check value ranges
        print(f"Decoded range: [{reconstructed.min():.3f}, {reconstructed.max():.3f}]")

    if problems:
        for problem in problems:
            print(f"‚úó {problem}")
        print("‚úó Pipeline verification found mismatches!")
        return False

    print("‚úì Pipeline verification complete!")
    return True

# ============================================================================
# Main Pipeline
# ============================================================================

def main():
    """Run Phase 2 of the pipeline.

    Phase 2 loads the Phase 1 VQ-VAE, diagnoses its codes, and trains (or
    loads) the autoregressive prior. It then generates samples, runs the
    latent-space analyses (interpolation, t-SNE, clustering), computes FID and
    writes CSV/JSON/text reports.

    Reads ./results/phase1_summary.json and, if present, the shared
    ./results/experiment_results.csv written by Phase 1. Prompts on stdin (and
    returns early on anything but 'y') if Phase 1 missed its codebook-usage
    target or if the code learnability score is below 30.
    """
    print("="*80)
    print("PHASE 2: PRIOR TRAINING & COMPLETE ANALYSIS")
    print("="*80)

    # ========================================================================
    # Load Phase 1 Results and Config
    # ========================================================================

    # Check if Phase 1 completed
    phase1_summary_path = './results/phase1_summary.json'
    if not os.path.exists(phase1_summary_path):
        print("\n‚úó ERROR: Phase 1 not completed. Run phase 1 first!")
        print("  Expected file: ./results/phase1_summary.json")
        return

    with open(phase1_summary_path, 'r') as f:
        phase1_summary = json.load(f)

    print(f"\n{'='*80}")
    print("PHASE 1 SUMMARY")
    print(f"{'='*80}")
    print(f"Codebook usage: {phase1_summary['codebook_usage_percent']:.2f}%")
    print(f"Active codes: {phase1_summary['active_codes']}/{phase1_summary['total_codes']}")
    print(f"Target achieved: {phase1_summary['target_achieved']}")

    # Load Phase 1 config from CSV
    csv_path = './results/experiment_results.csv'
    phase1_config_dict = {}

    if os.path.exists(csv_path):
        with open(csv_path, 'r', newline='') as f:
            rows = list(csv.DictReader(f))
        # This CSV is shared with Phase 2, which appends rows tagged
        # phase='phase2'. Taking rows[-1] blindly would reload a previous
        # Phase 2 run as the "Phase 1" config. Rows with overflow values
        # (stored by DictReader under the None key) were written against a
        # mismatched header, so well-formed Phase 1 rows are preferred.
        phase1_rows = [r for r in rows if r.get('phase') != 'phase2']
        well_formed_rows = [r for r in phase1_rows if None not in r]
        candidates = well_formed_rows or phase1_rows or rows
        if candidates:
            # Drop None/empty keys and empty values left by header mismatches
            phase1_config_dict = {k: v for k, v in candidates[-1].items()
                                  if k is not None and k != '' and v is not None and v != ''}
            print(f"\n‚úì Loaded Phase 1 config: {phase1_config_dict.get('experiment_name', 'unknown')}")
    else:
        print("\n‚ö† Warning: No CSV found, using default Phase 1 config")

    # Warn if codebook usage is low
    if not phase1_summary['target_achieved']:
        print(f"\n‚ö† WARNING: Phase 1 codebook usage was {phase1_summary['codebook_usage_percent']:.2f}%")
        min_usage = float(phase1_config_dict.get('min_codebook_usage', 50.0))
        print(f"  Target is {min_usage}%. Generation quality may be affected.")
        response = input("\nContinue anyway? (y/n): ")
        if response.lower() != 'y':
            return

    # Merge configs (Phase 1 + Phase 2)
    merged_config = {**phase1_config_dict, **EXPERIMENT_CONFIGS_PHASE2}
    config = Config(merged_config)

    print(f"\n‚úì Using device: {config.device}")

    # ========================================================================
    # Load Dataset
    # ========================================================================

    dataset = EmojiDataset(config.data_dir, config.image_size)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    # NOTE(review): this split is unseeded, so it need not match Phase 1's
    # split. "Validation" images used below for FID/interpolation may have
    # been seen while training the VQ-VAE. Confirm against Phase 1.
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                             shuffle=True, num_workers=config.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                           shuffle=False, num_workers=config.num_workers, pin_memory=True)

    print(f"\n‚úì Dataset: {len(train_dataset)} train, {len(val_dataset)} val")

    # ========================================================================
    # Load Pre-trained VQ-VAE and Calculate ACTUAL Spatial Size
    # ========================================================================

    vqvae = VQVAE(config).to(config.device)
    vqvae_checkpoint_path = os.path.join(config.checkpoint_dir, 'vqvae_final.pt')

    if not os.path.exists(vqvae_checkpoint_path):
        print(f"\n‚úó ERROR: VQ-VAE checkpoint not found!")
        print(f"  Expected: {vqvae_checkpoint_path}")
        return

    vqvae_checkpoint = torch.load(vqvae_checkpoint_path, map_location=config.device, weights_only=False)
    vqvae.load_state_dict(vqvae_checkpoint['model_state_dict'])
    vqvae.eval()

    print(f"‚úì VQ-VAE loaded successfully from epoch {vqvae_checkpoint['epoch'] + 1}")

    # Calculate ACTUAL spatial size by running a dummy image through the encoder
    def calculate_spatial_size(vqvae, config, device):
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, config.image_size, config.image_size).to(device)
            encoded = vqvae.encoder(dummy_input)
            spatial_h, spatial_w = encoded.shape[2], encoded.shape[3]
            print(f"‚úì Actual latent spatial size: {spatial_h}x{spatial_w}")
            return spatial_h, spatial_w

    spatial_h, spatial_w = calculate_spatial_size(vqvae, config, config.device)

    # Diagnose why PixelCNN is struggling
    diagnostics = analyze_code_spatial_structure(
        vqvae, train_loader, config.device, spatial_h, spatial_w
    )

    # Save diagnostics (results_dir may differ from ./results used by Phase 1)
    os.makedirs(config.results_dir, exist_ok=True)
    with open(os.path.join(config.results_dir, 'code_diagnostics.json'), 'w') as f:
        json.dump(diagnostics, f, indent=2)

    # Decide whether to continue
    if diagnostics['learnability_score'] < 30:
        response = input("\n‚ö† Codes have poor learnability. Continue anyway? (y/n): ")
        if response.lower() != 'y':
            return

    # ========================================================================
    # Train PixelCNN Prior with CORRECT Spatial Dimensions
    # ========================================================================

    print(f"\n{'='*80}")
    print("PIXELCNN PRIOR TRAINING")
    print(f"{'='*80}")

    # prior = PixelCNN(
    #     num_embeddings=config.num_embeddings,
    #     spatial_h=spatial_h,
    #     spatial_w=spatial_w,
    #     num_layers=config.pixelcnn_layers,
    #     hidden_dim=config.pixelcnn_hidden
    # ).to(config.device)

    # prior_trainer = PriorTrainer(prior, vqvae, config)

    prior = ImprovedPixelCNN(
        num_embeddings=config.num_embeddings,
        spatial_h=spatial_h,
        spatial_w=spatial_w,
        num_layers=config.pixelcnn_layers,
        hidden_dim=config.pixelcnn_hidden
    ).to(config.device)

    prior_trainer = ImprovedPriorTrainer(prior, vqvae, config)

    # Verify pipeline compatibility with correct spatial dimensions
    verify_pipeline(vqvae, prior, train_loader, config.device, spatial_h, spatial_w)

    # Load the final prior if present, otherwise train from scratch
    if not prior_trainer.load_checkpoint(os.path.join(config.checkpoint_dir, 'prior_final.pt')):
        print("\n‚úì Training prior from scratch...")
        prior_history = prior_trainer.train(train_loader, spatial_h, spatial_w)

        # Plot training history
        plot_prior_training_history(prior_history, os.path.join(config.results_dir, 'prior_training.png'))
    else:
        print("\n‚úì Using pre-trained prior")
        prior_history = prior_trainer.history

    # ========================================================================
    # Generate Novel Emojis with CORRECT Decoding
    # ========================================================================

    print(f"\n{'='*80}")
    print("GENERATING NOVEL EMOJIS")
    print(f"{'='*80}")

    prior.eval()
    vqvae.eval()

    print(f"\nGenerating {config.num_samples} samples...")
    with torch.no_grad():
        generated_codes = prior.sample(
            batch_size=config.num_samples,
            device=config.device,
            temperature=config.temperature
        )
        generated_images = decode_codes(vqvae, generated_codes, spatial_h, spatial_w)

    plot_generated_samples(
        generated_images,
        os.path.join(config.results_dir, 'generated_samples.png'),
        f"Generated Emojis (Temperature={config.temperature})"
    )

    # ========================================================================
    # Latent Space Interpolation
    # ========================================================================

    print(f"\n{'='*80}")
    print("LATENT SPACE ANALYSIS")
    print(f"{'='*80}\n")

    print("Performing interpolation...")
    visualize_latent_interpolation(
        vqvae, val_loader, config.device, spatial_h, spatial_w,
        os.path.join(config.results_dir, 'interpolation.png'),
        num_steps=config.num_interpolation_steps
    )

    # ========================================================================
    # t-SNE Visualization
    # ========================================================================

    visualize_tsne(
        vqvae, val_loader, config.device,
        os.path.join(config.results_dir, 'tsne_codes.png')
    )

    # ========================================================================
    # Clustering Analysis
    # ========================================================================

    visualize_clustering(
        vqvae, val_loader, config.device,
        os.path.join(config.results_dir, 'clustering.png'),
        n_clusters=10
    )

    # ========================================================================
    # FID Score Calculation
    # ========================================================================

    print(f"\n{'='*80}")
    print("CALCULATING FID SCORE")
    print(f"{'='*80}\n")

    # Collect real images by count, independent of the loader's batch size
    print("Collecting real images...")
    num_fid_images = 256
    real_images = []
    num_collected = 0
    for batch in val_loader:
        real_images.append(batch)
        num_collected += batch.shape[0]
        if num_collected >= num_fid_images:
            break
    real_images = torch.cat(real_images, dim=0)[:num_fid_images].to(config.device)

    print(f"‚úì Real images: {real_images.shape[0]}")

    # Generate matching number of fake images
    print("Generating images for FID calculation...")
    gen_images = []
    num_batches = (len(real_images) + 63) // 64

    with torch.no_grad():
        for i in range(num_batches):
            batch_size = min(64, len(real_images) - i * 64)
            if batch_size <= 0:
                break
            codes = prior.sample(batch_size, config.device, temperature=config.temperature)
            gen_batch = decode_codes(vqvae, codes, spatial_h, spatial_w)
            gen_images.append(gen_batch)

    gen_images = torch.cat(gen_images, dim=0)[:len(real_images)].to(config.device)
    print(f"‚úì Generated images: {gen_images.shape[0]}")

    # Extract Inception features
    print("\nExtracting Inception features...")
    inception_model = None
    real_features, inception_model = get_inception_features(real_images, inception_model, config.device)
    gen_features, _ = get_inception_features(gen_images, inception_model, config.device)

    print(f"‚úì Real features: {real_features.shape}")
    print(f"‚úì Generated features: {gen_features.shape}")

    # Calculate FID
    fid_score = calculate_fid(real_features, gen_features)

    print(f"\n{'='*80}")
    print(f"FID SCORE: {fid_score:.2f}")
    print(f"{'='*80}")

    # ========================================================================
    # Save Comprehensive Results
    # ========================================================================

    print(f"\n{'='*80}")
    print("SAVING RESULTS")
    print(f"{'='*80}\n")

    # Prepare comprehensive results for CSV
    results = {
        'phase': 'phase2',
        'fid_score': float(fid_score),
        'final_prior_loss': prior_history['loss'][-1] if prior_history['loss'] else None,
        'num_generated_samples': config.num_samples,
        'spatial_h': spatial_h,
        'spatial_w': spatial_w,
        'temperature': config.temperature,
        'phase1_codebook_usage': phase1_summary['codebook_usage_percent'],
        'phase1_active_codes': phase1_summary['active_codes'],
        'phase1_target_achieved': phase1_summary['target_achieved'],
        'generation_completed': True,
        'num_prior_epochs': config.num_epochs_prior,
        'pixelcnn_layers': config.pixelcnn_layers,
        'pixelcnn_hidden': config.pixelcnn_hidden
    }

    # Save to CSV (appends to same file as Phase 1)
    csv_path = os.path.join(config.results_dir, 'experiment_results.csv')
    save_results_to_csv(config, results, csv_path)

    # Save JSON report for backward compatibility
    final_report = {
        'fid_score': float(fid_score),
        'num_embeddings': config.num_embeddings,
        'embedding_dim': config.embedding_dim,
        'num_training_samples': len(train_dataset),
        'phase1_codebook_usage': phase1_summary['codebook_usage_percent'],
        'phase1_active_codes': phase1_summary['active_codes'],
        'spatial_h': spatial_h,
        'spatial_w': spatial_w,
        'temperature': config.temperature,
        'prior_layers': config.pixelcnn_layers,
        'prior_hidden': config.pixelcnn_hidden,
        'final_prior_loss': prior_history['loss'][-1] if prior_history['loss'] else None
    }

    with open(os.path.join(config.results_dir, 'final_report.json'), 'w') as f:
        json.dump(final_report, f, indent=2)

    print("‚úì Final report saved: final_report.json")

    # Create comprehensive summary document.
    # Everything inside this f-string ends up in the report file verbatim,
    # so no code-style comments may appear inside it.
    summary_text = f"""
{'='*80}
VQ-VAE EMOJI GENERATION - COMPLETE PIPELINE RESULTS
{'='*80}

EXPERIMENT: {config.phase2_experiment_name}
TIMESTAMP: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

{'='*80}
PHASE 1: VQ-VAE TRAINING
{'='*80}
Codebook Size: {config.num_embeddings}
Embedding Dim: {config.embedding_dim}
Commitment Cost: {config.commitment_cost}
Decay: {config.decay}

Results:
- Codebook Usage: {phase1_summary['codebook_usage_percent']:.2f}%
- Active Codes: {phase1_summary['active_codes']}/{phase1_summary['total_codes']}
- Target Achieved: {'‚úì' if phase1_summary['target_achieved'] else '‚úó'}

{'='*80}
PHASE 2: PRIOR TRAINING & GENERATION
{'='*80}
Prior Architecture: PixelCNN
- Layers: {config.pixelcnn_layers}
- Hidden Dim: {config.pixelcnn_hidden}
- Learning Rate: {config.learning_rate_prior}
- Epochs: {config.num_epochs_prior}

Generation:
- Spatial Size: {spatial_h}x{spatial_w}
- Temperature: {config.temperature}
- Samples Generated: {config.num_samples}

Results:
- FID Score: {fid_score:.2f}

{'='*80}
OUTPUTS GENERATED
{'='*80}
‚úì generated_samples.png - Novel emoji generations
‚úì interpolation.png - Latent space interpolations
‚úì tsne_codes.png - t-SNE visualization
‚úì clustering.png - K-means clustering analysis
‚úì prior_training.png - Prior training curves
‚úì final_report.json - Detailed results
‚úì experiment_results.csv - Complete experiment log

{'='*80}
NOTES
{'='*80}
Phase 1: {phase1_config_dict.get('notes', 'N/A')}
Phase 2: {config.phase2_notes}

{'='*80}
END OF REPORT
{'='*80}
"""

    with open(os.path.join(config.results_dir, 'comprehensive_summary.txt'), 'w') as f:
        f.write(summary_text)

    print("‚úì Comprehensive summary saved: comprehensive_summary.txt")

    # ========================================================================
    # Final Output
    # ========================================================================

    print(summary_text)

    print(f"\n{'='*80}")
    print("PHASE 2 COMPLETE!")
    print(f"{'='*80}")
    print(f"‚úì All results saved to: {config.results_dir}")
    print(f"‚úì FID Score: {fid_score:.2f}")
    print(f"‚úì Codebook Usage: {phase1_summary['codebook_usage_percent']:.2f}%")
    print(f"‚úì Latent Spatial Size: {spatial_h}x{spatial_w}")
    print(f"\n{'='*80}")
    print("EXPERIMENT TRACKING")
    print(f"{'='*80}")
    print(f"CSV Log: {csv_path}")
    print(f"Phase 1 Experiment: {phase1_config_dict.get('experiment_name', 'N/A')}")
    print(f"Phase 2 Experiment: {config.phase2_experiment_name}")
    print(f"\nCopy the CSV row to your Excel tracking sheet for comparison!")
    print(f"{'='*80}\n")

# Script / notebook entry point: run the full Phase 2 pipeline.
if __name__ == "__main__":
    main()