# In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy import linalg
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Configuration
# ============================================================================

class Config:
    """All tunable settings of the pipeline, stored as plain class attributes.

    Because every value is a class attribute, ``Config.x`` and ``Config().x``
    read the same setting.
    """

    # ---- filesystem -------------------------------------------------------
    data_dir = './emoji_data'            # folder of raw emoji images
    checkpoint_dir = './checkpoints'     # where trainer checkpoints are written
    results_dir = './results'            # plots and the JSON report

    # ---- data pipeline ----------------------------------------------------
    image_size = 64                      # images are resized to image_size x image_size
    batch_size = 64
    num_workers = 2                      # DataLoader worker processes

    # ---- VQ-VAE architecture ----------------------------------------------
    num_hiddens = 128                    # encoder/decoder channel width
    num_residual_hiddens = 32            # bottleneck width inside residual blocks
    num_residual_layers = 2
    embedding_dim = 64                   # size of each codebook vector
    num_embeddings = 512                 # number of codebook vectors
    commitment_cost = 0.25               # weight of the commitment loss
    decay = 0.999                        # EMA decay for the codebook statistics

    # ---- VQ-VAE optimisation ----------------------------------------------
    num_epochs_vqvae = 100
    learning_rate_vqvae = 3e-4
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # ---- PixelCNN prior ---------------------------------------------------
    num_epochs_prior = 50
    learning_rate_prior = 1e-4
    pixelcnn_layers = 12                 # number of residual blocks
    pixelcnn_hidden = 64                 # channel width of the prior
    grad_clip = 1.0                      # max gradient norm while training the prior

    # ---- generation -------------------------------------------------------
    num_samples = 64
    num_interpolation_steps = 10

config = Config()

# Ensure the checkpoint, results and data folders exist before anything
# reads from or writes to them.
for _required_dir in (config.checkpoint_dir, config.results_dir, config.data_dir):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

print(f"Using device: {config.device}")

# ============================================================================
# Dataset
# ============================================================================

class EmojiDataset(Dataset):
    """Flat directory of emoji images served as tensors.

    Args:
        data_dir: directory scanned (non-recursively) for ``.png``/``.jpg``/``.jpeg``
            files; the extension match is case-insensitive.
        image_size: side length used by the default transform's resize.
        transform: callable applied to each RGB PIL image. When ``None``, images
            are resized to ``image_size`` x ``image_size``, converted to tensors and
            normalised from [0, 1] to [-1, 1].
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, image_size=64, transform=None):
        self.data_dir = data_dir
        self.image_size = image_size
        # Sorted so that index -> file is stable across runs and filesystems
        # (os.listdir returns entries in arbitrary order); lower() so files such
        # as "smile.PNG" are not silently skipped; isfile() drops directories.
        self.image_files = sorted(
            f for f in os.listdir(data_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(data_dir, f))
        )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _to_rgb(image):
        """Return *image* as RGB, flattening any transparency onto white.

        A plain ``convert('RGB')`` just drops the alpha channel, exposing whatever
        colour the fully transparent pixels happen to store (often black).
        """
        has_alpha = image.mode in ('RGBA', 'LA', 'PA') or (
            image.mode == 'P' and 'transparency' in image.info)
        if not has_alpha:
            return image.convert('RGB')
        rgba = image.convert('RGBA')
        background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert('RGB')

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, self.image_files[idx])
        # The context manager closes the file; convert() inside _to_rgb loads the
        # pixel data first, so the returned image stays valid.
        with Image.open(img_path) as image:
            rgb = self._to_rgb(image)
        return self.transform(rgb)

def download_emoji_dataset(data_dir=None):
    """Download the ``valhalla/emoji-dataset`` train split and save it as PNG files.

    Args:
        data_dir: destination folder; defaults to ``config.data_dir``.

    Best effort: failures are printed (with a hint to add images manually)
    rather than raised, so the caller decides how to handle an empty folder.
    Images with transparency are flattened onto a white background before
    saving, because ``convert('RGB')`` alone drops alpha and exposes whatever
    colour the transparent pixels store.
    """
    target_dir = config.data_dir if data_dir is None else data_dir
    print("Downloading emoji dataset...")
    try:
        from datasets import load_dataset
    except ImportError:
        print("Error downloading dataset: the `datasets` package is not installed "
              "(pip install datasets)")
        print("Please manually download emojis to the data directory")
        return

    try:
        dataset = load_dataset("valhalla/emoji-dataset", split="train")

        print(f"Downloaded {len(dataset)} emojis")
        print("Saving images to disk...")
        os.makedirs(target_dir, exist_ok=True)

        for idx, item in enumerate(tqdm(dataset)):
            img = item['image']
            if img.mode in ('RGBA', 'LA', 'PA') or (img.mode == 'P' and 'transparency' in img.info):
                rgba = img.convert('RGBA')
                white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
                img = Image.alpha_composite(white, rgba).convert('RGB')
            elif img.mode != 'RGB':
                img = img.convert('RGB')
            img.save(os.path.join(target_dir, f'emoji_{idx:05d}.png'))

        print(f"Saved {len(dataset)} emoji images to {target_dir}")
    except Exception as e:  # deliberate best effort: report and let the caller continue
        print(f"Error downloading dataset: {e}")
        print("Please manually download emojis to the data directory")

# ============================================================================
# VQ-VAE Components
# ============================================================================

class VectorQuantizerEMA(nn.Module):
    """Nearest-neighbour vector quantiser whose codebook is learned by EMA.

    The codebook ``embed`` is a buffer, not a parameter: it gets no gradients.
    In training mode it is overwritten on every forward pass with exponential
    moving averages of the input vectors assigned to each code.

    Args:
        num_embeddings: number of codebook entries K.
        embedding_dim: size D of each entry (must equal the input channel count).
        commitment_cost: weight of the commitment loss returned by ``forward``.
        decay: EMA decay factor for the running statistics.
        epsilon: Laplace-smoothing constant for the per-code counts.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        initial_codebook = torch.randn(num_embeddings, embedding_dim)
        # Registration order fixes the state_dict key order: embed, cluster_size, embed_avg.
        self.register_buffer("embed", initial_codebook)
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", initial_codebook.clone())

    def _ema_update(self, one_hot, flat):
        """Fold this batch's assignments into the EMA statistics and refresh ``embed``."""
        keep = self.decay
        self.cluster_size.data.mul_(keep).add_(one_hot.sum(0), alpha=1 - keep)
        assigned_sum = one_hot.t() @ flat
        self.embed_avg.data.mul_(keep).add_(assigned_sum, alpha=1 - keep)

        total = self.cluster_size.sum()
        # Laplace smoothing so codes with a zero count do not divide by zero.
        smoothed = (self.cluster_size + self.epsilon) / (total + self.num_embeddings * self.epsilon) * total
        self.embed.data.copy_(self.embed_avg / smoothed.unsqueeze(1))

    def forward(self, inputs):
        """Quantise a (B, D, H, W) feature map.

        Returns:
            quantized: (B, D, H, W) codebook vectors, with straight-through
                gradients back to ``inputs``.
            loss: ``commitment_cost * MSE(stop_grad(quantized), inputs)``.
            perplexity: exp-entropy of the batch's code usage (roughly 1..K).
            encoding_indices: (B*H*W, 1) chosen code ids in (B, H, W) raster order.
        """
        batch, _, height, width = inputs.shape
        flat = inputs.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Squared Euclidean distance to every code: |x|^2 + |e|^2 - 2 x.e
        sq_dist = (flat.pow(2).sum(dim=1, keepdim=True)
                   + self.embed.pow(2).sum(dim=1)
                   - 2 * (flat @ self.embed.t()))
        nearest = sq_dist.argmin(dim=1).unsqueeze(1)

        one_hot = torch.zeros(nearest.shape[0], self.num_embeddings, device=inputs.device)
        one_hot.scatter_(1, nearest, 1)
        snapped = one_hot @ self.embed

        if self.training:
            self._ema_update(one_hot, flat)

        # The EMA owns the codebook, so only the commitment term is returned.
        loss = self.commitment_cost * F.mse_loss(snapped.detach(), flat)

        # Straight-through estimator: forward value is `snapped`, gradient goes to `flat`.
        straight_through = flat + (snapped - flat).detach()
        quantized = (straight_through
                     .view(batch, height, width, self.embedding_dim)
                     .permute(0, 3, 1, 2)
                     .contiguous())

        usage = one_hot.mean(dim=0)
        perplexity = torch.exp(-(usage * torch.log(usage + 1e-10)).sum())
        return quantized, loss, perplexity, nearest

class ResidualBlock(nn.Module):
    """Pre-activation bottleneck residual unit.

    Computes ``x + BN(Conv1x1(ReLU(BN(Conv3x3(ReLU(x))))))``; the skip
    connection requires ``in_channels == num_hiddens``.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super().__init__()
        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, num_residual_hiddens, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_residual_hiddens),
            nn.ReLU(),
            nn.Conv2d(num_residual_hiddens, num_hiddens, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_hiddens),
        ]
        # Attribute name `block` is part of the state_dict keys; keep it stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResidualStack(nn.Module):
    """Apply ``num_residual_layers`` residual blocks in sequence, then a final ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        blocks = []
        for _ in range(num_residual_layers):
            blocks.append(ResidualBlock(in_channels, num_hiddens, num_residual_hiddens))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return F.relu(out)

class Encoder(nn.Module):
    """Two stride-2 convolutions (4x spatial downsampling), a 3x3 conv, then a residual stack."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_width = num_hiddens // 2
        self.conv1 = nn.Conv2d(in_channels, half_width, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(half_width, num_hiddens, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(num_hiddens, num_hiddens, kernel_size=3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        h = self.conv1(x).relu()
        h = self.conv2(h).relu()
        return self.residual_stack(self.conv3(h))

class Decoder(nn.Module):
    """Mirror of the encoder: 3x3 conv, residual stack, two stride-2 transposed convs, tanh output in [-1, 1]."""

    def __init__(self, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_width = num_hiddens // 2
        self.conv1 = nn.Conv2d(embedding_dim, num_hiddens, kernel_size=3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.conv_trans1 = nn.ConvTranspose2d(num_hiddens, half_width, kernel_size=4, stride=2, padding=1)
        self.conv_trans2 = nn.ConvTranspose2d(half_width, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        h = self.residual_stack(self.conv1(x))
        h = self.conv_trans1(h).relu()
        return self.conv_trans2(h).tanh()

class VQVAE(nn.Module):
    """Encoder -> 1x1 projection to ``embedding_dim`` -> EMA vector quantiser -> decoder."""

    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(3, config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens)
        self.pre_vq_conv = nn.Conv2d(config.num_hiddens, config.embedding_dim, 1)
        self.vq = VectorQuantizerEMA(config.num_embeddings, config.embedding_dim, config.commitment_cost, decay=config.decay)
        self.decoder = Decoder(config.embedding_dim, config.num_hiddens, config.num_residual_layers, config.num_residual_hiddens)

    def forward(self, x):
        """Return ``(x_recon, vq_loss, perplexity, encoding_indices)`` for an image batch."""
        z = self.encoder(x)
        z = self.pre_vq_conv(z)
        quantized, vq_loss, perplexity, encoding_indices = self.vq(z)
        x_recon = self.decoder(quantized)
        return x_recon, vq_loss, perplexity, encoding_indices

    def encode(self, x):
        """Map images to code indices of shape (B, H*W).

        The quantiser is forced into eval mode for the lookup so that encoding
        never updates the EMA codebook, even if the model is in training mode.
        Its previous mode is restored afterwards.
        """
        z = self.pre_vq_conv(self.encoder(x))
        vq_was_training = self.vq.training
        self.vq.eval()
        try:
            _, _, _, encoding_indices = self.vq(z)
        finally:
            self.vq.train(vq_was_training)
        return encoding_indices.view(x.shape[0], -1)

    def decode_codes(self, encoding_indices, spatial_size):
        """Decode integer code indices back to images.

        Args:
            encoding_indices: code ids with B*H*W elements, e.g. (B, H*W) from
                ``encode`` or (B, H, W) from a prior.
            spatial_size: side of a square latent grid, or an ``(H, W)`` pair.
        """
        if isinstance(spatial_size, (tuple, list)):
            height, width = spatial_size
        else:
            height = width = int(spatial_size)
        codes = encoding_indices.view(-1, height, width)
        quantized = F.embedding(codes, self.vq.embed)           # (B, H, W, D)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()  # (B, D, H, W)
        return self.decoder(quantized)

# ============================================================================
# PixelCNN Prior
# ============================================================================

class MaskedConv2d(nn.Conv2d):
    """Conv2d whose kernel is causally masked in raster-scan order (PixelCNN).

    Rows above the kernel centre and taps left of the centre in the centre row
    are visible. Type 'A' also hides the centre tap (the current position);
    type 'B' keeps it visible.
    """

    def __init__(self, mask_type, *args, **kwargs):
        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self.create_mask(mask_type)

    def create_mask(self, mask_type):
        """Fill ``self.mask`` with ones on the taps visible for *mask_type*."""
        # Use both kernel dimensions so non-square kernels get the correct centre.
        kh, kw = self.kernel_size
        ch, cw = kh // 2, kw // 2
        self.mask.zero_()
        self.mask[:, :, :ch, :] = 1
        self.mask[:, :, ch, :cw] = 1
        if mask_type == 'B':
            self.mask[:, :, ch, cw] = 1

    def forward(self, x):
        # Mask functionally rather than overwriting self.weight.data: the
        # parameter is left untouched, masked taps receive zero gradient (so
        # the optimizer never updates them), and the op is autograd-safe.
        return self._conv_forward(x, self.weight * self.mask, self.bias)

class PixelCNNResidualBlock(nn.Module):
    """Residual unit of two (ReLU -> 1x1 type-'B' masked conv -> BatchNorm) stages on *h* channels."""

    def __init__(self, h):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [nn.ReLU(), MaskedConv2d('B', h, h, 1), nn.BatchNorm2d(h)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x) + x

class PixelCNN(nn.Module):
    """Autoregressive prior over a 2-D grid of VQ code indices.

    ``forward`` takes a LongTensor of code ids shaped (B, H, W), one-hot
    encodes it, and returns logits shaped (B, num_embeddings, H, W).

    NOTE(review): every layer after ``input_conv`` is 1x1, so the receptive
    field is the causal 7x7 neighbourhood of the input convolution. Also,
    BatchNorm in training mode pools statistics over all spatial positions,
    including those later in raster order.
    """

    def __init__(self, num_embeddings, num_layers=12, hidden_dim=64):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.input_conv = MaskedConv2d('A', num_embeddings, hidden_dim, 7, padding=3)
        self.residual_blocks = nn.ModuleList([PixelCNNResidualBlock(hidden_dim) for _ in range(num_layers)])
        self.output = nn.Sequential(
            nn.ReLU(),
            MaskedConv2d('B', hidden_dim, hidden_dim, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, num_embeddings, 1)
        )
        self._init_weights()

    def _init_weights(self):
        """Small Xavier init for convolutions (gain 0.1), identity init for BatchNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, MaskedConv2d)):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x_onehot = F.one_hot(x, self.num_embeddings).float()  # (B, H, W, K)
        x_onehot = x_onehot.permute(0, 3, 1, 2).contiguous()  # (B, K, H, W)
        x = self.input_conv(x_onehot)
        for block in self.residual_blocks:
            x = block(x)
        return self.output(x)

    @torch.no_grad()
    def sample(self, batch_size, spatial_size, device, temperature=1.0):
        """Draw ``batch_size`` code grids of size ``spatial_size`` x ``spatial_size``.

        Positions are filled one at a time in raster order, re-running the full
        network per position. Sampling runs in eval mode so BatchNorm uses its
        running statistics (and does not update them from the partially filled
        canvas); the module's previous train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        was_training = self.training
        self.eval()
        try:
            samples = torch.zeros(batch_size, spatial_size, spatial_size, dtype=torch.long, device=device)
            for i in range(spatial_size):
                for j in range(spatial_size):
                    logits = self(samples)
                    probs = F.softmax(logits[:, :, i, j] / temperature, dim=1)
                    samples[:, i, j] = torch.multinomial(probs, 1).squeeze(-1)
        finally:
            self.train(was_training)
        return samples

# ============================================================================
# Metrics
# ============================================================================

def calculate_mse(original, reconstructed):
    """Mean squared error between two equally shaped tensors, as a Python float."""
    squared_error = F.mse_loss(reconstructed, original)
    return squared_error.item()

def calculate_psnr(original, reconstructed, max_val=2.0):
    """Peak signal-to-noise ratio in dB.

    ``max_val`` is the signal's dynamic range; the default 2.0 matches images
    normalised to [-1, 1]. The MSE is pooled over every element of the inputs.
    """
    mse = F.mse_loss(reconstructed, original)
    peak_squared = max_val ** 2
    psnr = 10 * torch.log10(peak_squared / mse)
    return psnr.item()

def calculate_ssim(original, reconstructed):
    """Mean per-image SSIM for two (N, C, H, W) batches assumed to lie in [-1, 1].

    Images are mapped to [0, 1] and compared channel-last with ``data_range=1``.
    """
    def _to_unit_numpy(batch):
        return (batch.detach().cpu().numpy() + 1) / 2

    refs = _to_unit_numpy(original)
    candidates = _to_unit_numpy(reconstructed)
    scores = []
    for i, ref in enumerate(refs):
        candidate = candidates[i]
        scores.append(ssim(ref.transpose(1, 2, 0), candidate.transpose(1, 2, 0),
                           channel_axis=2, data_range=1.0))
    return np.mean(scores)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Frechet distance between Gaussians fitted to two feature sets.

    Args:
        real_features: array of shape (N1, D).
        fake_features: array of shape (N2, D).
        eps: diagonal offset added to both covariances only if the matrix square
            root of their product is not finite, which happens when the
            covariances are singular (e.g. fewer samples than feature dims).

    Returns:
        ``|mu1 - mu2|^2 + Tr(S1 + S2 - 2 sqrt(S1 S2))`` as a Python float.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)
    mu1, mu2 = real_features.mean(axis=0), fake_features.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array for single-feature inputs.
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))

    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a tiny imaginary component; it is discarded.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))

def get_inception_features(images, model, device, batch_size=64):
    """Embed images with a feature extractor, batching to bound memory.

    Args:
        images: (N, 3, H, W) tensor; it is resized to 299x299 per batch. With the
            default model it is passed without further normalisation
            (``transform_input=False``), i.e. presumably already in [-1, 1].
        model: feature extractor, or ``None`` to build a pretrained Inception-v3
            with its classifier replaced by identity (pooled features).
        device: device the model runs on; each batch is moved there.
        batch_size: images per forward pass (299x299 inputs are memory-heavy).

    Returns:
        ``(features, model)``: features as a NumPy array of shape (N, F) and the
        model, so callers can reuse it without rebuilding.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if images.shape[0] == 0:
        raise ValueError("get_inception_features needs at least one image")
    if model is None:
        # Imported lazily: only needed when the default model must be built.
        from torchvision.models import inception_v3
        model = inception_v3(pretrained=True, transform_input=False)
        model.fc = nn.Identity()
        model = model.to(device)
        model.eval()

    chunks = []
    with torch.no_grad():
        for start in range(0, images.shape[0], batch_size):
            batch = images[start:start + batch_size].to(device)
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            chunks.append(model(batch).cpu())
    return torch.cat(chunks, dim=0).numpy(), model

# ============================================================================
# Trainers - FIXED CHECKPOINT LOADING
# ============================================================================

class Trainer:
    """Train and evaluate a VQ-VAE with Adam, with resumable checkpoints.

    ``history`` maps each metric name to one value per finished epoch:
    training averages (``recon_loss``, ``vq_loss``, ``total_loss``,
    ``perplexity``) and validation metrics (``mse``, ``psnr``, ``ssim``).
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate_vqvae)
        self.history = {'recon_loss': [], 'vq_loss': [], 'total_loss': [], 'perplexity': [], 'mse': [], 'psnr': [], 'ssim': []}
        self.start_epoch = 0

    def save_checkpoint(self, epoch, filepath):
        """Write model/optimizer state, history and the last finished *epoch* to *filepath*."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history
        }
        torch.save(checkpoint, filepath)
        print(f"Checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath):
        """Resume from *filepath*.

        Returns:
            True when loaded (``start_epoch`` becomes the epoch after the saved
            one); False if the file is missing or cannot be loaded, in which case
            the error is printed and training starts from scratch.
        """
        if not os.path.exists(filepath):
            return False
        try:
            checkpoint = torch.load(filepath, map_location=self.config.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Start from the full key set so checkpoints written before a metric
            # existed (e.g. without 'total_loss') can still be resumed.
            history = {key: [] for key in self.history}
            history.update(checkpoint['history'])
            self.history = history
            self.start_epoch = checkpoint['epoch'] + 1
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return False
        print(f"Checkpoint loaded: {filepath}, resuming from epoch {self.start_epoch}")
        return True

    def train_epoch(self, dataloader):
        """Run one optimisation epoch; return per-batch averages of the training metrics.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        n = len(dataloader)
        if n == 0:
            raise ValueError("Cannot train on an empty dataloader")
        self.model.train()
        epoch_recon_loss = epoch_vq_loss = epoch_total_loss = epoch_perplexity = 0.0
        pbar = tqdm(dataloader, desc="Training")
        for batch in pbar:
            batch = batch.to(self.config.device)
            recon, vq_loss, perplexity, _ = self.model(batch)
            recon_loss = F.mse_loss(recon, batch)
            loss = recon_loss + vq_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            epoch_recon_loss += recon_loss.item()
            epoch_vq_loss += vq_loss.item()
            epoch_total_loss += loss.item()
            epoch_perplexity += perplexity.item()
            pbar.set_postfix({'recon_loss': recon_loss.item(), 'vq_loss': vq_loss.item(), 'perplexity': perplexity.item()})
        return {'recon_loss': epoch_recon_loss / n, 'vq_loss': epoch_vq_loss / n,
                'total_loss': epoch_total_loss / n, 'perplexity': epoch_perplexity / n}

    @torch.no_grad()
    def evaluate(self, dataloader):
        """Reconstruct the whole loader in eval mode.

        Returns pooled MSE/PSNR over all images and SSIM over the first 64.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        self.model.eval()
        all_original, all_recon = [], []
        for batch in dataloader:
            batch = batch.to(self.config.device)
            recon, _, _, _ = self.model(batch)
            all_original.append(batch)
            all_recon.append(recon)
        if not all_original:
            raise ValueError("Cannot evaluate on an empty dataloader")
        original = torch.cat(all_original, dim=0)
        reconstructed = torch.cat(all_recon, dim=0)
        return {'mse': calculate_mse(original, reconstructed),
                'psnr': calculate_psnr(original, reconstructed),
                'ssim': calculate_ssim(original[:64], reconstructed[:64])}

    def train(self, train_loader, val_loader):
        """Train from ``start_epoch`` to ``num_epochs_vqvae``.

        Checkpoints are written every 5 epochs plus a final ``vqvae_final.pt``.
        Returns the metric history.
        """
        print(f"Starting training from epoch {self.start_epoch}")
        for epoch in range(self.start_epoch, self.config.num_epochs_vqvae):
            print(f"\nEpoch {epoch+1}/{self.config.num_epochs_vqvae}")
            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.evaluate(val_loader)
            metrics = {**train_metrics, **val_metrics}
            for key, value in metrics.items():
                self.history.setdefault(key, []).append(value)
            print(f"Recon: {metrics['recon_loss']:.4f}, VQ: {metrics['vq_loss']:.4f}, Perplexity: {metrics['perplexity']:.2f}")
            print(f"MSE: {metrics['mse']:.4f}, PSNR: {metrics['psnr']:.2f}, SSIM: {metrics['ssim']:.4f}")
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, os.path.join(self.config.checkpoint_dir, f'vqvae_epoch_{epoch+1}.pt'))
        self.save_checkpoint(self.config.num_epochs_vqvae - 1, os.path.join(self.config.checkpoint_dir, 'vqvae_final.pt'))
        return self.history

class PriorTrainer:
    """Fit a PixelCNN prior to the code grids of a VQ-VAE kept in eval mode.

    Only ``prior`` parameters are optimised; codes come from ``vqvae.encode``
    under ``torch.no_grad``. ``history['loss']`` holds one mean loss per epoch.
    """

    def __init__(self, prior, vqvae, config):
        self.prior = prior
        self.vqvae = vqvae
        self.config = config
        self.optimizer = torch.optim.Adam(prior.parameters(), lr=config.learning_rate_prior)
        self.history = {'loss': []}
        self.start_epoch = 0

    def save_checkpoint(self, epoch, filepath):
        """Write prior/optimizer state, history and the last finished *epoch* to *filepath*."""
        torch.save({'epoch': epoch, 'model_state_dict': self.prior.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(), 'history': self.history}, filepath)
        print(f"Checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath):
        """Resume from *filepath*; return True on success, False if missing or unreadable."""
        if not os.path.exists(filepath):
            return False
        try:
            checkpoint = torch.load(filepath, map_location=self.config.device)
            self.prior.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.history = checkpoint['history']
            self.start_epoch = checkpoint['epoch'] + 1
        except Exception as e:
            print(f"Error loading prior checkpoint: {e}")
            return False
        print(f"Prior checkpoint loaded: {filepath}, resuming from epoch {self.start_epoch}")
        return True

    def train_epoch(self, dataloader):
        """Run one epoch of next-code cross-entropy training.

        Batches whose loss is NaN or infinite are skipped without an update.
        Returns the mean loss over the remaining batches, or ``float('inf')``
        when no batch produced a finite loss (including an empty loader).
        """
        self.prior.train()
        self.vqvae.eval()
        epoch_loss = 0.0
        num_valid = 0
        pbar = tqdm(dataloader, desc="Training Prior")
        for batch in pbar:
            batch = batch.to(self.config.device)
            with torch.no_grad():
                codes = self.vqvae.encode(batch)
                spatial_size = int(np.sqrt(codes.shape[1]))
                codes = codes.view(-1, spatial_size, spatial_size)
            logits = self.prior(codes)
            loss = F.cross_entropy(logits, codes)
            if not torch.isfinite(loss):
                continue
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.prior.parameters(), self.config.grad_clip)
            self.optimizer.step()
            epoch_loss += loss.item()
            num_valid += 1
            pbar.set_postfix({'loss': loss.item()})
        return epoch_loss / num_valid if num_valid > 0 else float('inf')

    def train(self, train_loader):
        """Train from ``start_epoch`` to ``num_epochs_prior``; return the history.

        Checkpoints every 5 epochs. ``prior_final.pt`` is written only if every
        epoch completed; an aborted run must not leave a "final" checkpoint
        behind, or later runs would load it and skip prior training.
        """
        for epoch in range(self.start_epoch, self.config.num_epochs_prior):
            loss = self.train_epoch(train_loader)
            if loss == float('inf'):
                print(f"Epoch {epoch+1}: no batch produced a finite loss; stopping prior training early")
                return self.history
            self.history['loss'].append(loss)
            print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, os.path.join(self.config.checkpoint_dir, f'prior_epoch_{epoch+1}.pt'))
        self.save_checkpoint(self.config.num_epochs_prior - 1, os.path.join(self.config.checkpoint_dir, 'prior_final.pt'))
        return self.history

# ============================================================================
# Visualization
# ============================================================================

def plot_training_history(history, save_path):
    """Save a 2x3 grid of per-epoch curves (losses, perplexity, MSE, PSNR, SSIM) to *save_path*."""
    panels = [
        ('recon_loss', 'Reconstruction Loss'),
        ('vq_loss', 'VQ Loss'),
        ('perplexity', 'Perplexity'),
        ('mse', 'MSE'),
        ('psnr', 'PSNR'),
        ('ssim', 'SSIM'),
    ]
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    # axes.flat walks row by row, matching the panel order above.
    for ax, (metric, title) in zip(axes.flat, panels):
        ax.plot(history[metric])
        ax.set_title(title)
        ax.grid(True)
        ax.set_xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_reconstructions(model, dataloader, device, save_path, num_images=8):
    """Save originals (top row) above their reconstructions (bottom row).

    Uses the first *num_images* images of the loader's first batch and leaves
    *model* in eval mode.
    """
    model.eval()
    first_batch = next(iter(dataloader))
    originals = first_batch[:num_images].to(device)
    with torch.no_grad():
        recon, _vq_loss, _perplexity, _codes = model(originals)
    # Undo the [-1, 1] normalisation so imshow receives [0, 1] floats.
    panel = torch.cat([(originals + 1) / 2, (recon + 1) / 2])
    grid = make_grid(panel, nrow=num_images, padding=2)
    plt.figure(figsize=(16, 4))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title('Top: Original, Bottom: Reconstructed')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_generated_samples(samples, save_path, title="Generated Samples", nrow=8):
    """Save a grid of images to *save_path*.

    Args:
        samples: image batch (N, C, H, W), assumed in [-1, 1] (decoder output).
        save_path: output image path.
        title: figure title.
        nrow: images per grid row, e.g. the number of steps for a one-row
            interpolation strip.
    """
    # Map to [0, 1] and clamp so imshow always gets a valid float RGB range;
    # detach in case the caller passes tensors that track gradients.
    images = ((samples.detach() + 1) / 2).clamp(0, 1)
    grid = make_grid(images, nrow=nrow, padding=2)
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(title)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# ============================================================================
# Main
# ============================================================================

def main():
    """Run the emoji VQ-VAE pipeline end to end.

    Stages:
    - fetch data when ``config.data_dir`` is empty;
    - train or resume the VQ-VAE, then the PixelCNN prior over its code grids;
    - sample new emojis and render a latent interpolation;
    - visualise per-image latents with t-SNE and K-means;
    - compute FID.

    Plots and a JSON report go to ``config.results_dir``; checkpoints go to
    ``config.checkpoint_dir``.

    Raises:
        RuntimeError: if the data directory holds too few images for a
            train/validation split (e.g. because the download failed).
    """
    print("="*80)
    print("VQ-VAE Emoji Generation Project")
    print("="*80)

    # Dataset preparation
    if not os.listdir(config.data_dir):
        download_emoji_dataset()

    dataset = EmojiDataset(config.data_dir, config.image_size)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    # Fail early with a clear message instead of a ZeroDivisionError or an empty
    # torch.cat deep inside training; the interpolation needs 2 validation images.
    if train_size == 0 or val_size < 2:
        raise RuntimeError(
            f"Found {len(dataset)} images in {config.data_dir}; need enough for a "
            f"90/10 train/val split with at least 2 validation images")
    # A seeded split keeps the held-out set fixed across runs, so a resumed
    # checkpoint is not validated on images it trained on (given a stable file order).
    split_generator = torch.Generator().manual_seed(getattr(config, 'seed', 42))
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=split_generator)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    # Train VQ-VAE
    vqvae = VQVAE(config).to(config.device)
    trainer = Trainer(vqvae, config)

    if not trainer.load_checkpoint(os.path.join(config.checkpoint_dir, 'vqvae_final.pt')):
        history = trainer.train(train_loader, val_loader)
        plot_training_history(history, os.path.join(config.results_dir, 'training_history.png'))

    # Evaluate
    plot_reconstructions(vqvae, val_loader, config.device, os.path.join(config.results_dir, 'reconstructions.png'))

    # Train Prior: infer the latent grid size from one encoded batch.
    vqvae.eval()
    with torch.no_grad():
        sample_batch = next(iter(train_loader))
        num_positions = vqvae.encode(sample_batch.to(config.device)).shape[1]
    spatial_size = int(np.sqrt(num_positions))

    prior = PixelCNN(config.num_embeddings, config.pixelcnn_layers, config.pixelcnn_hidden).to(config.device)
    prior_trainer = PriorTrainer(prior, vqvae, config)

    if not prior_trainer.load_checkpoint(os.path.join(config.checkpoint_dir, 'prior_final.pt')):
        prior_trainer.train(train_loader)

    # Sample in eval mode: BatchNorm must use running statistics, not the
    # statistics of the partially generated canvas.
    prior.eval()
    vqvae.eval()

    # Generate new samples
    print("Generating new samples...")
    with torch.no_grad():
        generated_codes = prior.sample(config.num_samples, spatial_size, config.device)
        generated_images = vqvae.decode_codes(generated_codes, spatial_size)

    plot_generated_samples(generated_images, os.path.join(config.results_dir, 'generated_samples.png'), "Generated Emojis")

    # Interpolation
    print("Performing interpolation...")
    with torch.no_grad():
        real_batch = next(iter(val_loader))[:2].to(config.device)
        codes = vqvae.encode(real_batch).view(2, spatial_size, spatial_size)
        # Code indices are categorical: averaging them selects unrelated codebook
        # entries. Blend the looked-up codebook vectors and decode those instead.
        endpoints = F.embedding(codes, vqvae.vq.embed).permute(0, 3, 1, 2).contiguous()  # (2, D, H, W)

        interpolations = []
        for alpha in np.linspace(0, 1, config.num_interpolation_steps):
            alpha = float(alpha)
            blended = (1 - alpha) * endpoints[0:1] + alpha * endpoints[1:2]
            interpolations.append(vqvae.decoder(blended))

        interpolation_grid = torch.cat(interpolations)
        plot_generated_samples(interpolation_grid, os.path.join(config.results_dir, 'interpolation.png'),
                               "Latent Space Interpolation")

    # Codebook analysis
    print("Analyzing codebook...")
    with torch.no_grad():
        all_codes = [vqvae.encode(batch.to(config.device)).cpu() for batch in val_loader]
    all_codes = torch.cat(all_codes, dim=0).numpy()  # (N, H*W) code indices

    # Raw index vectors are not meaningful Euclidean features (code 3 is not
    # "closer" to code 4 than to code 500), so represent each image by its
    # flattened codebook vectors for t-SNE and K-means.
    n_vis = min(len(all_codes), 1000)
    codebook = vqvae.vq.embed.detach().cpu().numpy()
    code_features = codebook[all_codes[:n_vis]].reshape(n_vis, -1)

    # t-SNE visualization (perplexity must be smaller than the number of points)
    if code_features.shape[1] > 2:
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_vis - 1))
        codes_2d = tsne.fit_transform(code_features)
    else:
        codes_2d = code_features

    plt.figure(figsize=(10, 8))
    plt.scatter(codes_2d[:, 0], codes_2d[:, 1], alpha=0.6)
    plt.title('t-SNE Visualization of Latent Codes')
    plt.savefig(os.path.join(config.results_dir, 'tsne_codes.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Clustering analysis (never more clusters than points)
    kmeans = KMeans(n_clusters=min(10, n_vis), random_state=42)
    cluster_labels = kmeans.fit_predict(code_features)

    plt.figure(figsize=(10, 8))
    plt.scatter(codes_2d[:, 0], codes_2d[:, 1], c=cluster_labels, alpha=0.6, cmap='tab10')
    plt.title('K-means Clustering of Latent Codes')
    plt.savefig(os.path.join(config.results_dir, 'clustering.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Quantitative evaluation
    print("Performing quantitative evaluation...")
    with torch.no_grad():
        real_images = []
        gen_images = []

        # Get real images (up to 4 validation batches)
        for i, batch in enumerate(val_loader):
            if i >= 4:
                break
            real_images.append(batch.cpu())
        real_images = torch.cat(real_images, dim=0)

        # Generate a matching number of fake images
        for i in range(0, len(real_images), config.num_samples):
            batch_size = min(config.num_samples, len(real_images) - i)
            codes = prior.sample(batch_size, spatial_size, config.device)
            gen_batch = vqvae.decode_codes(codes, spatial_size)
            gen_images.append(gen_batch.cpu())
        gen_images = torch.cat(gen_images, dim=0)

    # Calculate metrics
    real_images = real_images.to(config.device)
    gen_images = gen_images.to(config.device)

    # FID calculation
    inception_model = None
    real_features, inception_model = get_inception_features(real_images, inception_model, config.device)
    gen_features, _ = get_inception_features(gen_images, inception_model, config.device)
    fid_score = float(calculate_fid(real_features, gen_features))

    print("\n" + "="*50)
    print("FINAL RESULTS")
    print("="*50)
    print(f"FID Score: {fid_score:.2f}")

    # Save final report
    report = {
        'fid_score': fid_score,
        'num_embeddings': config.num_embeddings,
        'embedding_dim': config.embedding_dim,
        'num_training_samples': len(train_dataset),
        'final_vq_loss': trainer.history['vq_loss'][-1] if trainer.history['vq_loss'] else None,
        'final_perplexity': trainer.history['perplexity'][-1] if trainer.history['perplexity'] else None
    }

    with open(os.path.join(config.results_dir, 'final_report.json'), 'w') as f:
        json.dump(report, f, indent=2)

    print("\nGeneration completed successfully!")
    print(f"Results saved to: {config.results_dir}")
    print(f"Checkpoints saved to: {config.checkpoint_dir}")

if __name__ == "__main__":
    main()