In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy import linalg
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# Configuration
# ============================================================================

class Config:
    """Static paths and hyper-parameters for the VQ-VAE + PixelCNN pipeline.

    Every setting is a class attribute, so ``Config.x`` and ``Config().x``
    return the same value.
    """

    # --- filesystem layout ---------------------------------------------------
    data_dir: str = './emoji_data'
    checkpoint_dir: str = './checkpoints'
    results_dir: str = './results'

    # --- input pipeline --------------------------------------------------------
    image_size: int = 64
    batch_size: int = 64
    num_workers: int = 2

    # --- VQ-VAE architecture ---------------------------------------------------
    num_hiddens: int = 128
    num_residual_hiddens: int = 32
    num_residual_layers: int = 2
    embedding_dim: int = 64
    num_embeddings: int = 512
    commitment_cost: float = 0.25
    decay: float = 0.999  # EMA decay for the codebook updates

    # --- VQ-VAE optimisation ---------------------------------------------------
    num_epochs_vqvae: int = 100
    learning_rate_vqvae: float = 3e-4
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- PixelCNN prior ------------------------------------------------------
    num_epochs_prior: int = 50
    learning_rate_prior: float = 1e-4
    pixelcnn_layers: int = 12
    pixelcnn_hidden: int = 64
    grad_clip: float = 1.0

    # --- generation ------------------------------------------------------------
    num_samples: int = 64
    num_interpolation_steps: int = 10

config = Config()

# Make sure every directory the pipeline reads from or writes to exists up front.
for _directory in (config.checkpoint_dir, config.results_dir, config.data_dir):
    os.makedirs(_directory, exist_ok=True)

print(f"Using device: {config.device}")

# ============================================================================
# Dataset
# ============================================================================

class EmojiDataset(Dataset):
    """Flat directory of emoji images returned as transformed tensors.

    Args:
        data_dir: directory scanned (non-recursively) for image files.
        image_size: side length used by the default resize transform.
        transform: optional callable applied to each RGB PIL image. When None,
            the image is resized to ``image_size`` x ``image_size``, converted
            to a tensor and normalised to [-1, 1].
    """

    # Matched case-insensitively so files such as 'smile.PNG' are not skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, image_size=64, transform=None):
        self.data_dir = data_dir
        self.image_size = image_size
        # os.listdir order is filesystem-dependent; sorting keeps indices (and
        # therefore any seeded random_split) reproducible across machines.
        self.image_files = sorted(
            f for f in os.listdir(data_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.data_dir, self.image_files[idx])
        # Close the underlying file handle explicitly; convert() returns a new,
        # fully loaded image, so it stays valid after the context exits.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        return self.transform(image)

def download_emoji_dataset():
    """Fetch the 'valhalla/emoji-dataset' train split and write each image as
    an RGB PNG named ``emoji_XXXXX.png`` into ``config.data_dir``.

    Best effort: any failure (missing ``datasets`` package, network error, ...)
    is reported on stdout instead of being raised.
    """
    print("Downloading emoji dataset...")
    try:
        from datasets import load_dataset
        emojis = load_dataset("valhalla/emoji-dataset", split="train")
        count = len(emojis)

        print(f"Downloaded {count} emojis")
        print("Saving images to disk...")

        for position, record in enumerate(tqdm(emojis)):
            picture = record['image']
            picture = picture if picture.mode == 'RGB' else picture.convert('RGB')
            picture.save(os.path.join(config.data_dir, f'emoji_{position:05d}.png'))

        print(f"Saved {count} emoji images to {config.data_dir}")
    except Exception as e:
        print(f"Error downloading dataset: {e}")
        print("Please manually download emojis to the data directory")

# ============================================================================
# VQ-VAE Components
# ============================================================================

class VectorQuantizerEMA(nn.Module):
    """Nearest-neighbour vector quantiser whose codebook is learned by
    exponential moving averages rather than by gradient descent.

    The codebook lives in the ``embed`` buffer (num_embeddings x embedding_dim).
    ``cluster_size`` and ``embed_avg`` hold the EMA of per-code assignment
    counts and of the summed assigned vectors.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        codebook = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer("embed", codebook)
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", codebook.clone())

    def forward(self, inputs):
        """Quantise ``inputs``, whose channel dim (dim 1) must equal embedding_dim.

        Returns:
            quantized: same shape as ``inputs``, with straight-through gradients.
            loss: commitment term only (the codebook itself is EMA-updated).
            perplexity: exp-entropy of average code usage in this batch.
            encoding_indices: (B*H*W, 1) long tensor of chosen code ids.
        """
        batch, _, height, width = inputs.shape
        # (B, C, H, W) -> (B*H*W, C): one row per spatial position.
        flat_input = inputs.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Squared Euclidean distance expanded as |x|^2 + |e|^2 - 2 x.e
        input_sq = torch.sum(flat_input**2, dim=1, keepdim=True)
        codebook_sq = torch.sum(self.embed**2, dim=1)
        cross = torch.matmul(flat_input, self.embed.t())
        distances = input_sq + codebook_sq - 2 * cross

        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        num_rows = encoding_indices.shape[0]
        encodings = torch.zeros(num_rows, self.num_embeddings, device=inputs.device).scatter_(1, encoding_indices, 1)

        # One-hot rows times the codebook select each row's nearest code.
        quantized = encodings @ self.embed

        if self.training:
            decay, keep = self.decay, 1 - self.decay
            self.cluster_size.data.mul_(decay).add_(encodings.sum(0), alpha=keep)
            self.embed_avg.data.mul_(decay).add_(torch.matmul(encodings.t(), flat_input), alpha=keep)

            # Laplace-smoothed counts, rescaled so they still sum to n.
            # NOTE(review): a code with zero EMA count gets a smoothed size of
            # roughly epsilon, so its normalised vector can grow by ~1/epsilon
            # and never be selected again (dead code); consider re-initialising
            # unused codes if perplexity stays low.
            total = self.cluster_size.sum()
            smoothed = (self.cluster_size + self.epsilon) / (total + self.num_embeddings * self.epsilon) * total
            self.embed.data.copy_(self.embed_avg / smoothed.unsqueeze(1))

        loss = self.commitment_cost * F.mse_loss(quantized.detach(), flat_input)
        # Straight-through estimator: forward uses codes, backward sees identity.
        quantized = flat_input + (quantized - flat_input).detach()
        quantized = quantized.view(batch, height, width, self.embedding_dim).permute(0, 3, 1, 2).contiguous()

        usage = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        return quantized, loss, perplexity, encoding_indices

class ResidualBlock(nn.Module):
    """Pre-activation bottleneck residual unit: x + BN(1x1(ReLU(BN(3x3(ReLU(x)))))).

    The skip connection requires ``in_channels == num_hiddens``.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super().__init__()
        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, num_residual_hiddens, 3, padding=1, bias=False),
            nn.BatchNorm2d(num_residual_hiddens),
            nn.ReLU(),
            nn.Conv2d(num_residual_hiddens, num_hiddens, 1, bias=False),
            nn.BatchNorm2d(num_hiddens),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResidualStack(nn.Module):
    """``num_residual_layers`` ResidualBlocks applied in sequence, followed by a final ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        blocks = [
            ResidualBlock(in_channels, num_hiddens, num_residual_hiddens)
            for _ in range(num_residual_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return F.relu(out)

class Encoder(nn.Module):
    """Two stride-2 4x4 convolutions (each halves H and W for even sizes),
    a 3x3 convolution, then a residual stack.

    Output has ``num_hiddens`` channels.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half = num_hiddens // 2
        self.conv1 = nn.Conv2d(in_channels, half, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(half, num_hiddens, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(num_hiddens, num_hiddens, 3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        return self.residual_stack(self.conv3(h))

class Decoder(nn.Module):
    """Mirror of the encoder: 3x3 conv, residual stack, then two stride-2 4x4
    transposed convolutions (each doubles H and W) to a tanh 3-channel image.
    """

    def __init__(self, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half = num_hiddens // 2
        self.conv1 = nn.Conv2d(embedding_dim, num_hiddens, 3, padding=1)
        self.residual_stack = ResidualStack(num_hiddens, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.conv_trans1 = nn.ConvTranspose2d(num_hiddens, half, 4, stride=2, padding=1)
        self.conv_trans2 = nn.ConvTranspose2d(half, 3, 4, stride=2, padding=1)

    def forward(self, x):
        h = self.residual_stack(self.conv1(x))
        h = F.relu(self.conv_trans1(h))
        return torch.tanh(self.conv_trans2(h))

class VQVAE(nn.Module):
    """Encoder -> 1x1 projection -> EMA vector quantiser -> decoder."""

    def __init__(self, config):
        super().__init__()
        hid = config.num_hiddens
        res_layers = config.num_residual_layers
        res_hid = config.num_residual_hiddens
        # Registration order is kept fixed: it determines parameter order,
        # which optimizer checkpoints rely on.
        self.encoder = Encoder(3, hid, res_layers, res_hid)
        self.pre_vq_conv = nn.Conv2d(hid, config.embedding_dim, 1)
        self.vq = VectorQuantizerEMA(config.num_embeddings, config.embedding_dim, config.commitment_cost, decay=config.decay)
        self.decoder = Decoder(config.embedding_dim, hid, res_layers, res_hid)

    def forward(self, x):
        """Return (reconstruction, vq_loss, perplexity, encoding_indices)."""
        latents = self.pre_vq_conv(self.encoder(x))
        quantized, vq_loss, perplexity, encoding_indices = self.vq(latents)
        return self.decoder(quantized), vq_loss, perplexity, encoding_indices

    def encode(self, x):
        """Return codebook indices with shape (B, H'*W') for a batch of images.

        This runs the quantiser's forward, so in training mode it also applies
        an EMA codebook update.
        """
        latents = self.pre_vq_conv(self.encoder(x))
        encoding_indices = self.vq(latents)[3]
        return encoding_indices.view(x.shape[0], -1)

    def decode_codes(self, encoding_indices, spatial_size):
        """Decode index grids (anything reshapeable to (B, S, S)) back to images."""
        grid = encoding_indices.view(-1, spatial_size, spatial_size)
        vectors = F.embedding(grid, self.vq.embed)  # (B, S, S, embedding_dim)
        return self.decoder(vectors.permute(0, 3, 1, 2).contiguous())

# ============================================================================
# PixelCNN Prior
# ============================================================================

class MaskedConv2d(nn.Conv2d):
    """Conv2d with a raster-scan causal mask on its kernel.

    Each output position only sees input rows above it and, in its own row,
    columns to its left. Type 'A' also hides the centre tap, so a position
    never sees its own value (used for the first layer). Type 'B' keeps the
    centre tap. The mask is built from ``kernel_size[0]``, so a square,
    odd-sized kernel is assumed.
    """

    def __init__(self, mask_type, *args, **kwargs):
        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self.create_mask(mask_type)

    def create_mask(self, mask_type):
        k = self.kernel_size[0]
        centre = k // 2
        self.mask[:, :, :centre, :] = 1          # all rows above the centre
        self.mask[:, :, centre, :centre] = 1     # left part of the centre row
        if mask_type == 'B':
            self.mask[:, :, centre, centre] = 1  # type B may see its own position

    def forward(self, x):
        # Apply the mask functionally instead of overwriting self.weight.data on
        # every call. That way masked taps get zero gradient (the optimizer never
        # moves them) and the stored weights equal the weights actually used.
        return self._conv_forward(x, self.weight * self.mask, self.bias)

class PixelCNNResidualBlock(nn.Module):
    """Residual unit of the PixelCNN prior: x + BN(conv(ReLU(BN(conv(ReLU(x)))))).

    Both convolutions are 1x1 type-'B' masked convs. A 1x1 type-'B' mask is all
    ones, so these layers only mix channels at the same position.
    """

    def __init__(self, h):
        super().__init__()
        layers = []
        for _ in range(2):
            layers += [nn.ReLU(), MaskedConv2d('B', h, h, 1), nn.BatchNorm2d(h)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        update = self.conv(x)
        return x + update

class PixelCNN(nn.Module):
    """Autoregressive PixelCNN prior over a grid of VQ codebook indices.

    Input is a (B, H, W) long tensor of code ids. It is one-hot encoded and
    passed through masked convolutions; the output is (B, num_embeddings, H, W)
    logits.

    NOTE(review): only ``input_conv`` (7x7) is spatial; every later layer is
    1x1, so each position conditions on a limited causal neighbourhood.
    """

    def __init__(self, num_embeddings, num_layers=12, hidden_dim=64):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.input_conv = MaskedConv2d('A', num_embeddings, hidden_dim, 7, padding=3)
        self.residual_blocks = nn.ModuleList([PixelCNNResidualBlock(hidden_dim) for _ in range(num_layers)])
        self.output = nn.Sequential(
            nn.ReLU(),
            MaskedConv2d('B', hidden_dim, hidden_dim, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, num_embeddings, 1)
        )
        self._init_weights()

    def _init_weights(self):
        """Small Xavier init for convs (gain 0.1), identity-like init for BatchNorm."""
        for m in self.modules():
            # MaskedConv2d subclasses nn.Conv2d, so this branch covers both.
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x_onehot = F.one_hot(x, self.num_embeddings).float()
        x_onehot = x_onehot.permute(0, 3, 1, 2).contiguous()
        x = self.input_conv(x_onehot)
        for block in self.residual_blocks:
            x = block(x)
        return self.output(x)

    @torch.no_grad()
    def sample(self, batch_size, spatial_size, device, temperature=1.0):
        """Draw (batch_size, spatial_size, spatial_size) code grids in raster order.

        The model is switched to eval mode for the duration of sampling and then
        restored to its previous mode. In train mode, BatchNorm would normalise
        with statistics of the partially-filled (mostly zero) grid and overwrite
        its running statistics.
        """
        was_training = self.training
        self.eval()
        try:
            samples = torch.zeros(batch_size, spatial_size, spatial_size, dtype=torch.long, device=device)
            for i in range(spatial_size):
                for j in range(spatial_size):
                    logits = self(samples)
                    probs = F.softmax(logits[:, :, i, j] / temperature, dim=1)
                    samples[:, i, j] = torch.multinomial(probs, 1).squeeze(-1)
        finally:
            self.train(was_training)
        return samples

# ============================================================================
# Metrics
# ============================================================================

def calculate_mse(original, reconstructed):
    """Mean squared error between two same-shaped tensors, as a Python float."""
    error = F.mse_loss(reconstructed, original)
    return error.item()

def calculate_psnr(original, reconstructed, max_val=2.0):
    """Peak signal-to-noise ratio in dB.

    ``max_val`` is the signal range; the default 2.0 matches images in [-1, 1].
    """
    error = F.mse_loss(reconstructed, original)
    ratio = max_val**2 / error
    return (10 * torch.log10(ratio)).item()

def calculate_ssim(original, reconstructed):
    """Mean SSIM over a batch of (B, C, H, W) tensors in [-1, 1].

    Each image is rescaled to [0, 1], moved to HWC layout and scored with
    skimage's SSIM; the per-image scores are averaged.
    """
    def to_unit_hwc(batch):
        # (B, C, H, W) in [-1, 1]  ->  (B, H, W, C) numpy in [0, 1]
        return ((batch.detach().cpu().numpy() + 1) / 2).transpose(0, 2, 3, 1)

    ref = to_unit_hwc(original)
    rec = to_unit_hwc(reconstructed)
    scores = [
        ssim(ref[i], rec[i], channel_axis=2, data_range=1.0)
        for i in range(ref.shape[0])
    ]
    return np.mean(scores)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    Args:
        real_features, fake_features: (N, D) numpy arrays, one row per sample.
        eps: diagonal offset added to both covariances if the matrix square
            root of their product is not finite. This happens when N is small
            relative to D and the covariances are singular.

    Returns:
        ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 sqrt(S1 S2)) as a Python float.
    """
    mu1 = np.atleast_1d(real_features.mean(axis=0))
    mu2 = np.atleast_1d(fake_features.mean(axis=0))
    # atleast_2d so a single feature column still yields a 1x1 covariance.
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2)
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Numerical error can introduce a tiny imaginary component.
        covmean = covmean.real
    return float(ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))

def get_inception_features(images, model, device, batch_size=64):
    """Extract pooled Inception-v3 features for a batch of images.

    Args:
        images: (N, 3, H, W) tensor. These are the [-1, 1] images used in this
            file; with ``transform_input=False`` they are fed to Inception as-is.
        model: feature extractor to reuse, or None to build a pretrained
            Inception-v3 with its classifier replaced by ``nn.Identity``.
        device: device the model runs on; each chunk is moved there.
        batch_size: number of images resized to 299x299 and forwarded at once,
            so the full set never has to be upsampled in memory simultaneously.

    Returns:
        (features as an (N, F) numpy array, the model used), so callers can
        reuse the loaded model.
    """
    if model is None:
        # Imported lazily: only needed when we actually have to build the model.
        from torchvision.models import inception_v3
        model = inception_v3(pretrained=True, transform_input=False)
        model.fc = nn.Identity()
        model = model.to(device)
        model.eval()

    chunks = []
    with torch.no_grad():
        for start in range(0, images.shape[0], batch_size):
            chunk = images[start:start + batch_size].to(device)
            chunk = F.interpolate(chunk, size=(299, 299), mode='bilinear')
            chunks.append(model(chunk).cpu())
    return torch.cat(chunks, dim=0).numpy(), model

# ============================================================================
# Trainers - FIXED CHECKPOINT LOADING
# ============================================================================

class Trainer:
    """Adam training loop for the VQ-VAE.

    Records per-epoch train/val metrics in ``self.history`` and writes
    checkpoints to ``config.checkpoint_dir`` every 5 epochs and at the end.
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate_vqvae)
        metric_names = ('recon_loss', 'vq_loss', 'total_loss', 'perplexity', 'mse', 'psnr', 'ssim')
        self.history = {name: [] for name in metric_names}
        self.start_epoch = 0

    def save_checkpoint(self, epoch, filepath):
        """Persist model, optimizer and metric history for ``epoch``."""
        state = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
        }
        torch.save(state, filepath)
        print(f"Checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath):
        """Restore state from ``filepath``.

        Returns True on success. Returns False if the file is missing or cannot
        be loaded (the error is printed).
        """
        if not os.path.exists(filepath):
            return False
        try:
            state = torch.load(filepath, map_location=self.config.device)
            self.model.load_state_dict(state['model_state_dict'])
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.history = state['history']
            self.start_epoch = state['epoch'] + 1
            print(f"Checkpoint loaded: {filepath}, resuming from epoch {self.start_epoch}")
            return True
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return False

    def train_epoch(self, dataloader):
        """One optimisation pass; returns batch-averaged recon/VQ loss and perplexity."""
        self.model.train()
        totals = {'recon_loss': 0, 'vq_loss': 0, 'perplexity': 0}
        progress = tqdm(dataloader, desc="Training")
        for images in progress:
            images = images.to(self.config.device)
            recon, vq_loss, perplexity, _ = self.model(images)
            recon_loss = F.mse_loss(recon, images)
            self.optimizer.zero_grad()
            (recon_loss + vq_loss).backward()
            self.optimizer.step()
            step = {'recon_loss': recon_loss.item(), 'vq_loss': vq_loss.item(), 'perplexity': perplexity.item()}
            for name, value in step.items():
                totals[name] += value
            progress.set_postfix(step)
        num_batches = len(dataloader)
        return {name: total / num_batches for name, total in totals.items()}

    @torch.no_grad()
    def evaluate(self, dataloader):
        """Reconstruct the whole loader; MSE/PSNR on everything, SSIM on the first 64."""
        self.model.eval()
        originals, recons = [], []
        for images in dataloader:
            images = images.to(self.config.device)
            originals.append(images)
            recons.append(self.model(images)[0])
        original = torch.cat(originals, dim=0)
        reconstructed = torch.cat(recons, dim=0)
        return {
            'mse': calculate_mse(original, reconstructed),
            'psnr': calculate_psnr(original, reconstructed),
            'ssim': calculate_ssim(original[:64], reconstructed[:64]),
        }

    def train(self, train_loader, val_loader):
        """Run epochs ``start_epoch .. num_epochs_vqvae-1`` and return the history."""
        total_epochs = self.config.num_epochs_vqvae
        ckpt_dir = self.config.checkpoint_dir
        print(f"Starting training from epoch {self.start_epoch}")
        for epoch in range(self.start_epoch, total_epochs):
            print(f"\nEpoch {epoch+1}/{total_epochs}")
            metrics = self.train_epoch(train_loader)
            metrics.update(self.evaluate(val_loader))
            for name, value in metrics.items():
                self.history[name].append(value)
            print(f"Recon: {metrics['recon_loss']:.4f}, VQ: {metrics['vq_loss']:.4f}, Perplexity: {metrics['perplexity']:.2f}")
            print(f"MSE: {metrics['mse']:.4f}, PSNR: {metrics['psnr']:.2f}, SSIM: {metrics['ssim']:.4f}")
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, os.path.join(ckpt_dir, f'vqvae_epoch_{epoch+1}.pt'))
        self.save_checkpoint(total_epochs - 1, os.path.join(ckpt_dir, 'vqvae_final.pt'))
        return self.history

class PriorTrainer:
    """Train the PixelCNN prior on code grids produced by a frozen VQ-VAE.

    Checkpoints go to ``config.checkpoint_dir`` every 5 epochs and, only when
    training completes, to ``prior_final.pt``.
    """

    def __init__(self, prior, vqvae, config):
        self.prior = prior
        self.vqvae = vqvae
        self.config = config
        self.optimizer = torch.optim.Adam(prior.parameters(), lr=config.learning_rate_prior)
        self.history = {'loss': []}
        self.start_epoch = 0

    def save_checkpoint(self, epoch, filepath):
        """Persist prior weights, optimizer state and loss history for ``epoch``."""
        torch.save({'epoch': epoch, 'model_state_dict': self.prior.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(), 'history': self.history}, filepath)
        print(f"Prior checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath):
        """Restore state from ``filepath``.

        Returns True on success. Returns False if the file is missing or cannot
        be loaded (the error is printed).
        """
        if not os.path.exists(filepath):
            return False
        try:
            checkpoint = torch.load(filepath, map_location=self.config.device)
            self.prior.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.history = checkpoint['history']
            self.start_epoch = checkpoint['epoch'] + 1
        except Exception as e:
            print(f"Error loading prior checkpoint: {e}")
            return False
        print(f"Prior checkpoint loaded: {filepath}, resuming from epoch {self.start_epoch}")
        return True

    def train_epoch(self, dataloader):
        """One pass over ``dataloader``.

        Returns the mean cross-entropy over batches with a finite loss, or
        ``inf`` if no batch produced one.
        """
        self.prior.train()
        self.vqvae.eval()  # frozen encoder: no EMA codebook / BatchNorm updates
        epoch_loss = num_valid = 0
        pbar = tqdm(dataloader, desc="Training Prior")
        for batch in pbar:
            batch = batch.to(self.config.device)
            with torch.no_grad():
                codes = self.vqvae.encode(batch)
                spatial_size = int(np.sqrt(codes.shape[1]))
                codes = codes.view(-1, spatial_size, spatial_size)
            logits = self.prior(codes)
            loss = F.cross_entropy(logits, codes)
            # Skip NaN *and* inf losses; either would poison the Adam moments.
            if not torch.isfinite(loss):
                continue
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.prior.parameters(), self.config.grad_clip)
            self.optimizer.step()
            epoch_loss += loss.item()
            num_valid += 1
            pbar.set_postfix({'loss': loss.item()})
        return epoch_loss / num_valid if num_valid > 0 else float('inf')

    def train(self, train_loader):
        """Train epochs ``start_epoch .. num_epochs_prior-1`` and return the history.

        If an epoch yields no finite loss, training stops without writing
        ``prior_final.pt``, so a diverged prior is not later loaded as finished.
        """
        for epoch in range(self.start_epoch, self.config.num_epochs_prior):
            loss = self.train_epoch(train_loader)
            if loss == float('inf'):
                print(f"Prior training produced no finite loss at epoch {epoch+1}; stopping without a final checkpoint")
                return self.history
            self.history['loss'].append(loss)
            print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, os.path.join(self.config.checkpoint_dir, f'prior_epoch_{epoch+1}.pt'))
        self.save_checkpoint(self.config.num_epochs_prior - 1, os.path.join(self.config.checkpoint_dir, 'prior_final.pt'))
        return self.history

# ============================================================================
# Visualization
# ============================================================================

def plot_training_history(history, save_path):
    """Save a 2x3 grid of per-epoch curves (losses, perplexity, MSE, PSNR, SSIM)."""
    panels = [
        ('recon_loss', 'Reconstruction Loss'),
        ('vq_loss', 'VQ Loss'),
        ('perplexity', 'Perplexity'),
        ('mse', 'MSE'),
        ('psnr', 'PSNR'),
        ('ssim', 'SSIM'),
    ]
    _fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    # axes.flat walks the grid row by row, matching the panel order above.
    for ax, (key, title) in zip(axes.flat, panels):
        ax.plot(history[key])
        ax.set_title(title)
        ax.grid(True)
        ax.set_xlabel('Epoch')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_reconstructions(model, dataloader, device, save_path, num_images=8):
    """Save originals (top row) above their reconstructions (bottom row).

    Uses the first ``num_images`` of the loader's first batch; puts the model in
    eval mode.
    """
    model.eval()
    originals = next(iter(dataloader))[:num_images].to(device)
    with torch.no_grad():
        recons = model(originals)[0]
    # Map both rows from [-1, 1] back to [0, 1] for display.
    stacked = torch.cat([(originals + 1) / 2, (recons + 1) / 2])
    grid = make_grid(stacked, nrow=num_images, padding=2)
    plt.figure(figsize=(16, 4))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title('Top: Original, Bottom: Reconstructed')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_generated_samples(samples, save_path, title="Generated Samples"):
    """Save a grid (8 images per row) of [-1, 1] image tensors under ``title``."""
    grid = make_grid((samples + 1) / 2, nrow=8, padding=2)
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(title)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# ============================================================================
# Main
# ============================================================================

def main():
    """End-to-end pipeline.

    Steps: prepare data, train (or load) the VQ-VAE, train (or load) the
    PixelCNN prior, sample new emojis, interpolate between two images, analyse
    code usage and report FID.
    """
    print("="*80)
    print("VQ-VAE Emoji Generation Project")
    print("="*80)

    # Dataset preparation
    if not os.listdir(config.data_dir):
        download_emoji_dataset()

    dataset = EmojiDataset(config.data_dir, config.image_size)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    # Train VQ-VAE
    vqvae = VQVAE(config).to(config.device)
    trainer = Trainer(vqvae, config)

    if not trainer.load_checkpoint(os.path.join(config.checkpoint_dir, 'vqvae_final.pt')):
        history = trainer.train(train_loader, val_loader)
        plot_training_history(history, os.path.join(config.results_dir, 'training_history.png'))

    # Evaluate
    plot_reconstructions(vqvae, val_loader, config.device, os.path.join(config.results_dir, 'reconstructions.png'))

    # Infer the side of the latent grid from one batch. eval() + no_grad so the
    # probe neither updates the EMA codebook / BatchNorm stats nor builds a graph.
    vqvae.eval()
    with torch.no_grad():
        sample_batch = next(iter(train_loader))
        num_positions = vqvae.encode(sample_batch.to(config.device)).shape[1]
    spatial_size = int(np.sqrt(num_positions))

    # Train Prior
    prior = PixelCNN(config.num_embeddings, config.pixelcnn_layers, config.pixelcnn_hidden).to(config.device)
    prior_trainer = PriorTrainer(prior, vqvae, config)

    if not prior_trainer.load_checkpoint(os.path.join(config.checkpoint_dir, 'prior_final.pt')):
        prior_trainer.train(train_loader)

    # Inference only from here on. PriorTrainer leaves the prior in train mode,
    # where its BatchNorm layers would use per-batch statistics while sampling.
    prior.eval()
    vqvae.eval()

    # Generate new samples
    print("Generating new samples...")
    with torch.no_grad():
        generated_codes = prior.sample(config.num_samples, spatial_size, config.device)
        generated_images = vqvae.decode_codes(generated_codes, spatial_size)

    plot_generated_samples(generated_images, os.path.join(config.results_dir, 'generated_samples.png'), "Generated Emojis")

    # Interpolation
    print("Performing interpolation...")
    with torch.no_grad():
        real_batch = next(iter(val_loader))[:2].to(config.device)
        codes = vqvae.encode(real_batch).view(2, spatial_size, spatial_size)

        # Codebook indices are categorical labels: averaging index 10 and 300
        # gives an unrelated code. Interpolate the looked-up embedding vectors
        # instead and decode the blended latents directly.
        endpoints = F.embedding(codes, vqvae.vq.embed).permute(0, 3, 1, 2).contiguous()

        interpolations = []
        for alpha in np.linspace(0, 1, config.num_interpolation_steps):
            alpha = float(alpha)
            mixed = (1 - alpha) * endpoints[0:1] + alpha * endpoints[1:2]
            interpolations.append(vqvae.decoder(mixed))

        interpolation_grid = torch.cat(interpolations)
        plot_generated_samples(interpolation_grid, os.path.join(config.results_dir, 'interpolation.png'),
                               "Latent Space Interpolation")

    # Codebook analysis
    print("Analyzing codebook...")
    with torch.no_grad():
        all_codes = []
        for batch in val_loader:
            codes = vqvae.encode(batch.to(config.device))
            all_codes.append(codes.cpu())
        all_codes = torch.cat(all_codes, dim=0).numpy()

    # Euclidean distance between raw index vectors is meaningless (indices are
    # categories), so describe each image by its histogram of codebook usage.
    code_features = np.stack([
        np.bincount(row, minlength=config.num_embeddings) for row in all_codes
    ]).astype(np.float32)[:1000]

    # t-SNE visualization
    if code_features.shape[1] > 2:
        tsne = TSNE(n_components=2, random_state=42)
        codes_2d = tsne.fit_transform(code_features)
    else:
        codes_2d = code_features

    plt.figure(figsize=(10, 8))
    plt.scatter(codes_2d[:, 0], codes_2d[:, 1], alpha=0.6)
    plt.title('t-SNE Visualization of Latent Codes')
    plt.savefig(os.path.join(config.results_dir, 'tsne_codes.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Clustering analysis
    kmeans = KMeans(n_clusters=10, random_state=42)
    cluster_labels = kmeans.fit_predict(code_features)

    plt.figure(figsize=(10, 8))
    plt.scatter(codes_2d[:, 0], codes_2d[:, 1], c=cluster_labels, alpha=0.6, cmap='tab10')
    plt.title('K-means Clustering of Latent Codes')
    plt.savefig(os.path.join(config.results_dir, 'clustering.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Quantitative evaluation
    print("Performing quantitative evaluation...")
    with torch.no_grad():
        real_images = []
        gen_images = []

        # Get real images (at most 4 validation batches)
        for i, batch in enumerate(val_loader):
            if i >= 4:
                break
            real_images.append(batch.cpu())
        real_images = torch.cat(real_images, dim=0)

        # Generate a matching number of fake images
        for i in range(0, len(real_images), config.num_samples):
            batch_size = min(config.num_samples, len(real_images) - i)
            codes = prior.sample(batch_size, spatial_size, config.device)
            gen_batch = vqvae.decode_codes(codes, spatial_size)
            gen_images.append(gen_batch.cpu())
        gen_images = torch.cat(gen_images, dim=0)

    # Calculate metrics
    real_images = real_images.to(config.device)
    gen_images = gen_images.to(config.device)

    # FID calculation
    inception_model = None
    real_features, inception_model = get_inception_features(real_images, inception_model, config.device)
    gen_features, _ = get_inception_features(gen_images, inception_model, config.device)
    fid_score = calculate_fid(real_features, gen_features)

    print("\n" + "="*50)
    print("FINAL RESULTS")
    print("="*50)
    print(f"FID Score: {fid_score:.2f}")

    # Save final report
    report = {
        'fid_score': float(fid_score),
        'num_embeddings': config.num_embeddings,
        'embedding_dim': config.embedding_dim,
        'num_training_samples': len(train_dataset),
        'final_vq_loss': trainer.history['vq_loss'][-1] if trainer.history['vq_loss'] else None,
        'final_perplexity': trainer.history['perplexity'][-1] if trainer.history['perplexity'] else None
    }

    with open(os.path.join(config.results_dir, 'final_report.json'), 'w') as f:
        json.dump(report, f, indent=2)

    print("\nGeneration completed successfully!")
    print(f"Results saved to: {config.results_dir}")
    print(f"Checkpoints saved to: {config.checkpoint_dir}")

if __name__ == "__main__":
    main()

Using device: cuda
VQ-VAE Emoji Generation Project
Train: 2474, Val: 275
Starting training from epoch 0

Epoch 1/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.54it/s, recon_loss=0.149, vq_loss=0.0786, perplexity=1.84]


Recon: 0.2776, VQ: 0.0667, Perplexity: 2.09
MSE: 0.2015, PSNR: 12.98, SSIM: 0.3305

Epoch 2/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.97it/s, recon_loss=0.14, vq_loss=0.126, perplexity=1.92]


Recon: 0.1557, VQ: 0.1102, Perplexity: 1.89
MSE: 0.1542, PSNR: 14.14, SSIM: 0.4654

Epoch 3/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.70it/s, recon_loss=0.107, vq_loss=0.0676, perplexity=2.59]


Recon: 0.1323, VQ: 0.1107, Perplexity: 2.27
MSE: 0.1382, PSNR: 14.62, SSIM: 0.4718

Epoch 4/100


Training: 100%|██████████| 39/39 [00:05<00:00,  7.03it/s, recon_loss=0.0894, vq_loss=0.052, perplexity=3.42]


Recon: 0.1018, VQ: 0.0659, Perplexity: 2.96
MSE: 0.0984, PSNR: 16.09, SSIM: 0.5652

Epoch 5/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.46it/s, recon_loss=0.0844, vq_loss=0.0567, perplexity=3.35]


Recon: 0.0866, VQ: 0.0580, Perplexity: 3.28
MSE: 0.0873, PSNR: 16.61, SSIM: 0.5935
Checkpoint saved: ./checkpoints/vqvae_epoch_5.pt

Epoch 6/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.96it/s, recon_loss=0.0811, vq_loss=0.0652, perplexity=3.37]


Recon: 0.0806, VQ: 0.0579, Perplexity: 3.31
MSE: 0.0971, PSNR: 16.15, SSIM: 0.5738

Epoch 7/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.40it/s, recon_loss=0.0706, vq_loss=0.0698, perplexity=3.23]


Recon: 0.0777, VQ: 0.0595, Perplexity: 3.31
MSE: 0.0847, PSNR: 16.74, SSIM: 0.6264

Epoch 8/100


Training: 100%|██████████| 39/39 [00:05<00:00,  7.01it/s, recon_loss=0.0657, vq_loss=0.0618, perplexity=3.3]


Recon: 0.0740, VQ: 0.0634, Perplexity: 3.33
MSE: 0.0985, PSNR: 16.09, SSIM: 0.6200

Epoch 9/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.52it/s, recon_loss=0.072, vq_loss=0.0467, perplexity=3.7]


Recon: 0.0720, VQ: 0.0576, Perplexity: 3.48
MSE: 0.0693, PSNR: 17.61, SSIM: 0.6388

Epoch 10/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.93it/s, recon_loss=0.0644, vq_loss=0.0486, perplexity=3.95]


Recon: 0.0681, VQ: 0.0471, Perplexity: 3.70
MSE: 0.0633, PSNR: 18.00, SSIM: 0.6564
Checkpoint saved: ./checkpoints/vqvae_epoch_10.pt

Epoch 11/100


Training: 100%|██████████| 39/39 [00:08<00:00,  4.74it/s, recon_loss=0.0658, vq_loss=0.046, perplexity=3.77]


Recon: 0.0651, VQ: 0.0487, Perplexity: 3.78
MSE: 0.0645, PSNR: 17.92, SSIM: 0.6561

Epoch 12/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.97it/s, recon_loss=0.0598, vq_loss=0.0476, perplexity=3.85]


Recon: 0.0627, VQ: 0.0471, Perplexity: 3.84
MSE: 0.0657, PSNR: 17.85, SSIM: 0.6619

Epoch 13/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.32it/s, recon_loss=0.0818, vq_loss=0.0483, perplexity=3.87]


Recon: 0.0618, VQ: 0.0467, Perplexity: 3.85
MSE: 0.0606, PSNR: 18.20, SSIM: 0.6736

Epoch 14/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.19it/s, recon_loss=0.0628, vq_loss=0.0461, perplexity=4.06]


Recon: 0.0601, VQ: 0.0470, Perplexity: 3.87
MSE: 0.0568, PSNR: 18.48, SSIM: 0.6847

Epoch 15/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.25it/s, recon_loss=0.0553, vq_loss=0.0447, perplexity=3.82]


Recon: 0.0596, VQ: 0.0463, Perplexity: 3.88
MSE: 0.0571, PSNR: 18.45, SSIM: 0.6800
Checkpoint saved: ./checkpoints/vqvae_epoch_15.pt

Epoch 16/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.41it/s, recon_loss=0.0547, vq_loss=0.043, perplexity=3.84]


Recon: 0.0585, VQ: 0.0456, Perplexity: 3.89
MSE: 0.0564, PSNR: 18.51, SSIM: 0.6844

Epoch 17/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.47it/s, recon_loss=0.0543, vq_loss=0.0423, perplexity=4]


Recon: 0.0568, VQ: 0.0443, Perplexity: 3.89
MSE: 0.0552, PSNR: 18.60, SSIM: 0.6835

Epoch 18/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.14it/s, recon_loss=0.0549, vq_loss=0.0499, perplexity=3.94]


Recon: 0.0561, VQ: 0.0446, Perplexity: 3.90
MSE: 0.0543, PSNR: 18.67, SSIM: 0.6883

Epoch 19/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.54it/s, recon_loss=0.0593, vq_loss=0.0479, perplexity=4.15]


Recon: 0.0566, VQ: 0.0459, Perplexity: 3.91
MSE: 0.0615, PSNR: 18.13, SSIM: 0.6886

Epoch 20/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.29it/s, recon_loss=0.0544, vq_loss=0.0478, perplexity=3.95]


Recon: 0.0544, VQ: 0.0455, Perplexity: 3.89
MSE: 0.0540, PSNR: 18.70, SSIM: 0.6905
Checkpoint saved: ./checkpoints/vqvae_epoch_20.pt

Epoch 21/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.39it/s, recon_loss=0.0586, vq_loss=0.0458, perplexity=3.97]


Recon: 0.0533, VQ: 0.0460, Perplexity: 3.93
MSE: 0.0512, PSNR: 18.93, SSIM: 0.6976

Epoch 22/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.27it/s, recon_loss=0.0522, vq_loss=0.0484, perplexity=3.92]


Recon: 0.0529, VQ: 0.0458, Perplexity: 3.94
MSE: 0.0519, PSNR: 18.87, SSIM: 0.7017

Epoch 23/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.30it/s, recon_loss=0.0621, vq_loss=0.0421, perplexity=3.99]


Recon: 0.0529, VQ: 0.0457, Perplexity: 3.95
MSE: 0.0544, PSNR: 18.67, SSIM: 0.6986

Epoch 24/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.43it/s, recon_loss=0.0502, vq_loss=0.0426, perplexity=3.95]


Recon: 0.0510, VQ: 0.0417, Perplexity: 3.95
MSE: 0.0482, PSNR: 19.19, SSIM: 0.7058

Epoch 25/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.29it/s, recon_loss=0.0504, vq_loss=0.038, perplexity=3.91]


Recon: 0.0515, VQ: 0.0412, Perplexity: 3.96
MSE: 0.0502, PSNR: 19.01, SSIM: 0.7017
Checkpoint saved: ./checkpoints/vqvae_epoch_25.pt

Epoch 26/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.42it/s, recon_loss=0.052, vq_loss=0.0407, perplexity=3.93]


Recon: 0.0499, VQ: 0.0406, Perplexity: 3.97
MSE: 0.0510, PSNR: 18.95, SSIM: 0.7053

Epoch 27/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.16it/s, recon_loss=0.0497, vq_loss=0.0411, perplexity=3.96]


Recon: 0.0488, VQ: 0.0410, Perplexity: 3.99
MSE: 0.0485, PSNR: 19.16, SSIM: 0.7124

Epoch 28/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.47it/s, recon_loss=0.0414, vq_loss=0.0419, perplexity=3.96]


Recon: 0.0488, VQ: 0.0410, Perplexity: 4.01
MSE: 0.0477, PSNR: 19.24, SSIM: 0.7129

Epoch 29/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.24it/s, recon_loss=0.0487, vq_loss=0.0464, perplexity=3.91]


Recon: 0.0493, VQ: 0.0408, Perplexity: 4.02
MSE: 0.0550, PSNR: 18.61, SSIM: 0.7044

Epoch 30/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.76it/s, recon_loss=0.0447, vq_loss=0.0385, perplexity=4.07]


Recon: 0.0478, VQ: 0.0415, Perplexity: 4.02
MSE: 0.0481, PSNR: 19.20, SSIM: 0.7131
Checkpoint saved: ./checkpoints/vqvae_epoch_30.pt

Epoch 31/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.01it/s, recon_loss=0.0521, vq_loss=0.0426, perplexity=4.01]


Recon: 0.0482, VQ: 0.0421, Perplexity: 4.04
MSE: 0.0503, PSNR: 19.00, SSIM: 0.7084

Epoch 32/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.45it/s, recon_loss=0.0497, vq_loss=0.0422, perplexity=4.04]


Recon: 0.0478, VQ: 0.0399, Perplexity: 4.04
MSE: 0.0478, PSNR: 19.23, SSIM: 0.7143

Epoch 33/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.07it/s, recon_loss=0.0471, vq_loss=0.0404, perplexity=4.08]


Recon: 0.0470, VQ: 0.0403, Perplexity: 4.05
MSE: 0.0468, PSNR: 19.32, SSIM: 0.7206

Epoch 34/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.57it/s, recon_loss=0.0486, vq_loss=0.0385, perplexity=4.01]


Recon: 0.0456, VQ: 0.0404, Perplexity: 4.08
MSE: 0.0455, PSNR: 19.44, SSIM: 0.7203

Epoch 35/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.96it/s, recon_loss=0.0413, vq_loss=0.0429, perplexity=4.23]


Recon: 0.0458, VQ: 0.0397, Perplexity: 4.09
MSE: 0.0455, PSNR: 19.44, SSIM: 0.7172
Checkpoint saved: ./checkpoints/vqvae_epoch_35.pt

Epoch 36/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.52it/s, recon_loss=0.0481, vq_loss=0.0392, perplexity=4.23]


Recon: 0.0452, VQ: 0.0404, Perplexity: 4.09
MSE: 0.0467, PSNR: 19.33, SSIM: 0.7177

Epoch 37/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.99it/s, recon_loss=0.0432, vq_loss=0.04, perplexity=4.18]


Recon: 0.0461, VQ: 0.0414, Perplexity: 4.10
MSE: 0.0467, PSNR: 19.33, SSIM: 0.7172

Epoch 38/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.35it/s, recon_loss=0.0445, vq_loss=0.0405, perplexity=4.18]


Recon: 0.0452, VQ: 0.0430, Perplexity: 4.11
MSE: 0.0501, PSNR: 19.02, SSIM: 0.7116

Epoch 39/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.05it/s, recon_loss=0.0612, vq_loss=0.0447, perplexity=4.21]


Recon: 0.0444, VQ: 0.0404, Perplexity: 4.12
MSE: 0.0448, PSNR: 19.50, SSIM: 0.7273

Epoch 40/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.98it/s, recon_loss=0.0386, vq_loss=0.0412, perplexity=4.17]


Recon: 0.0439, VQ: 0.0415, Perplexity: 4.13
MSE: 0.0471, PSNR: 19.29, SSIM: 0.7251
Checkpoint saved: ./checkpoints/vqvae_epoch_40.pt

Epoch 41/100


Training: 100%|██████████| 39/39 [00:06<00:00,  6.49it/s, recon_loss=0.0409, vq_loss=0.0403, perplexity=4.09]


Recon: 0.0437, VQ: 0.0416, Perplexity: 4.13
MSE: 0.0430, PSNR: 19.68, SSIM: 0.7278

Epoch 42/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.80it/s, recon_loss=0.0422, vq_loss=0.0449, perplexity=4.14]


Recon: 0.0431, VQ: 0.0414, Perplexity: 4.14
MSE: 0.0441, PSNR: 19.58, SSIM: 0.7243

Epoch 43/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.62it/s, recon_loss=0.0367, vq_loss=0.0408, perplexity=4.02]


Recon: 0.0436, VQ: 0.0439, Perplexity: 4.13
MSE: 0.0432, PSNR: 19.66, SSIM: 0.7293

Epoch 44/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.63it/s, recon_loss=0.0423, vq_loss=0.0414, perplexity=4.18]


Recon: 0.0432, VQ: 0.0413, Perplexity: 4.14
MSE: 0.0431, PSNR: 19.68, SSIM: 0.7194

Epoch 45/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.84it/s, recon_loss=0.0409, vq_loss=0.0424, perplexity=4.23]


Recon: 0.0428, VQ: 0.0427, Perplexity: 4.14
MSE: 0.0420, PSNR: 19.79, SSIM: 0.7277
Checkpoint saved: ./checkpoints/vqvae_epoch_45.pt

Epoch 46/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.71it/s, recon_loss=0.0441, vq_loss=0.0448, perplexity=4.28]


Recon: 0.0417, VQ: 0.0430, Perplexity: 4.16
MSE: 0.0417, PSNR: 19.82, SSIM: 0.7313

Epoch 47/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.95it/s, recon_loss=0.0432, vq_loss=0.0407, perplexity=4.24]


Recon: 0.0423, VQ: 0.0438, Perplexity: 4.14
MSE: 0.0427, PSNR: 19.72, SSIM: 0.7266

Epoch 48/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.46it/s, recon_loss=0.0394, vq_loss=0.0423, perplexity=3.96]


Recon: 0.0412, VQ: 0.0433, Perplexity: 4.15
MSE: 0.0412, PSNR: 19.87, SSIM: 0.7316

Epoch 49/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.88it/s, recon_loss=0.0503, vq_loss=0.0447, perplexity=4.11]


Recon: 0.0410, VQ: 0.0442, Perplexity: 4.15
MSE: 0.0418, PSNR: 19.81, SSIM: 0.7272

Epoch 50/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.55it/s, recon_loss=0.0393, vq_loss=0.0487, perplexity=4.24]


Recon: 0.0403, VQ: 0.0453, Perplexity: 4.16
MSE: 0.0411, PSNR: 19.88, SSIM: 0.7316
Checkpoint saved: ./checkpoints/vqvae_epoch_50.pt

Epoch 51/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.82it/s, recon_loss=0.0402, vq_loss=0.0447, perplexity=4.23]


Recon: 0.0414, VQ: 0.0465, Perplexity: 4.15
MSE: 0.0412, PSNR: 19.88, SSIM: 0.7296

Epoch 52/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.44it/s, recon_loss=0.0439, vq_loss=0.0528, perplexity=4.2]


Recon: 0.0403, VQ: 0.0470, Perplexity: 4.17
MSE: 0.0409, PSNR: 19.90, SSIM: 0.7306

Epoch 53/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.98it/s, recon_loss=0.0466, vq_loss=0.0485, perplexity=4.29]


Recon: 0.0409, VQ: 0.0473, Perplexity: 4.17
MSE: 0.0419, PSNR: 19.79, SSIM: 0.7291

Epoch 54/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.52it/s, recon_loss=0.0412, vq_loss=0.0466, perplexity=4.12]


Recon: 0.0410, VQ: 0.0487, Perplexity: 4.15
MSE: 0.0425, PSNR: 19.74, SSIM: 0.7256

Epoch 55/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.93it/s, recon_loss=0.041, vq_loss=0.0535, perplexity=4.32]


Recon: 0.0402, VQ: 0.0484, Perplexity: 4.17
MSE: 0.0426, PSNR: 19.72, SSIM: 0.7308
Checkpoint saved: ./checkpoints/vqvae_epoch_55.pt

Epoch 56/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.49it/s, recon_loss=0.0367, vq_loss=0.0437, perplexity=4.09]


Recon: 0.0396, VQ: 0.0467, Perplexity: 4.18
MSE: 0.0403, PSNR: 19.97, SSIM: 0.7340

Epoch 57/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.84it/s, recon_loss=0.0338, vq_loss=0.0456, perplexity=4.11]


Recon: 0.0402, VQ: 0.0470, Perplexity: 4.18
MSE: 0.0421, PSNR: 19.78, SSIM: 0.7302

Epoch 58/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.59it/s, recon_loss=0.037, vq_loss=0.0468, perplexity=4.05]


Recon: 0.0399, VQ: 0.0479, Perplexity: 4.17
MSE: 0.0395, PSNR: 20.05, SSIM: 0.7340

Epoch 59/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.89it/s, recon_loss=0.0365, vq_loss=0.0467, perplexity=4.21]


Recon: 0.0396, VQ: 0.0474, Perplexity: 4.18
MSE: 0.0408, PSNR: 19.91, SSIM: 0.7318

Epoch 60/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.62it/s, recon_loss=0.0379, vq_loss=0.0452, perplexity=4.18]


Recon: 0.0386, VQ: 0.0486, Perplexity: 4.18
MSE: 0.0423, PSNR: 19.75, SSIM: 0.7308
Checkpoint saved: ./checkpoints/vqvae_epoch_60.pt

Epoch 61/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.98it/s, recon_loss=0.0384, vq_loss=0.0465, perplexity=4.12]


Recon: 0.0393, VQ: 0.0491, Perplexity: 4.19
MSE: 0.0380, PSNR: 20.23, SSIM: 0.7403

Epoch 62/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.58it/s, recon_loss=0.0416, vq_loss=0.0443, perplexity=4.18]


Recon: 0.0382, VQ: 0.0478, Perplexity: 4.19
MSE: 0.0397, PSNR: 20.03, SSIM: 0.7304

Epoch 63/100


Training: 100%|██████████| 39/39 [00:05<00:00,  7.08it/s, recon_loss=0.0379, vq_loss=0.0403, perplexity=4.14]


Recon: 0.0377, VQ: 0.0471, Perplexity: 4.20
MSE: 0.0387, PSNR: 20.14, SSIM: 0.7407

Epoch 64/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.71it/s, recon_loss=0.037, vq_loss=0.0468, perplexity=4.18]


Recon: 0.0381, VQ: 0.0493, Perplexity: 4.19
MSE: 0.0392, PSNR: 20.09, SSIM: 0.7345

Epoch 65/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.79it/s, recon_loss=0.0359, vq_loss=0.0472, perplexity=4.27]


Recon: 0.0382, VQ: 0.0475, Perplexity: 4.19
MSE: 0.0427, PSNR: 19.71, SSIM: 0.7322
Checkpoint saved: ./checkpoints/vqvae_epoch_65.pt

Epoch 66/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.56it/s, recon_loss=0.0333, vq_loss=0.0467, perplexity=4.13]


Recon: 0.0376, VQ: 0.0494, Perplexity: 4.20
MSE: 0.0379, PSNR: 20.24, SSIM: 0.7378

Epoch 67/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.90it/s, recon_loss=0.0385, vq_loss=0.0443, perplexity=4.19]


Recon: 0.0378, VQ: 0.0490, Perplexity: 4.20
MSE: 0.0391, PSNR: 20.10, SSIM: 0.7371

Epoch 68/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.65it/s, recon_loss=0.0329, vq_loss=0.0424, perplexity=4.18]


Recon: 0.0371, VQ: 0.0485, Perplexity: 4.21
MSE: 0.0385, PSNR: 20.16, SSIM: 0.7430

Epoch 69/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.84it/s, recon_loss=0.0415, vq_loss=0.0503, perplexity=4.24]


Recon: 0.0375, VQ: 0.0476, Perplexity: 4.21
MSE: 0.0389, PSNR: 20.12, SSIM: 0.7310

Epoch 70/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.60it/s, recon_loss=0.0392, vq_loss=0.0648, perplexity=4.25]


Recon: 0.0368, VQ: 0.0484, Perplexity: 4.21
MSE: 0.0384, PSNR: 20.18, SSIM: 0.7391
Checkpoint saved: ./checkpoints/vqvae_epoch_70.pt

Epoch 71/100


Training: 100%|██████████| 39/39 [00:05<00:00,  7.03it/s, recon_loss=0.0362, vq_loss=0.0452, perplexity=4.13]


Recon: 0.0370, VQ: 0.0483, Perplexity: 4.20
MSE: 0.0411, PSNR: 19.88, SSIM: 0.7330

Epoch 72/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.69it/s, recon_loss=0.0363, vq_loss=0.0434, perplexity=4.09]


Recon: 0.0372, VQ: 0.0494, Perplexity: 4.20
MSE: 0.0371, PSNR: 20.32, SSIM: 0.7433

Epoch 73/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.76it/s, recon_loss=0.0406, vq_loss=0.0525, perplexity=4.23]


Recon: 0.0363, VQ: 0.0480, Perplexity: 4.21
MSE: 0.0381, PSNR: 20.21, SSIM: 0.7461

Epoch 74/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.57it/s, recon_loss=0.0344, vq_loss=0.0545, perplexity=4.27]


Recon: 0.0358, VQ: 0.0477, Perplexity: 4.21
MSE: 0.0384, PSNR: 20.18, SSIM: 0.7456

Epoch 75/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.96it/s, recon_loss=0.0431, vq_loss=0.0513, perplexity=4.28]


Recon: 0.0362, VQ: 0.0486, Perplexity: 4.22
MSE: 0.0368, PSNR: 20.37, SSIM: 0.7423
Checkpoint saved: ./checkpoints/vqvae_epoch_75.pt

Epoch 76/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.60it/s, recon_loss=0.0355, vq_loss=0.0519, perplexity=4.14]


Recon: 0.0357, VQ: 0.0487, Perplexity: 4.21
MSE: 0.0361, PSNR: 20.45, SSIM: 0.7480

Epoch 77/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.82it/s, recon_loss=0.0328, vq_loss=0.0475, perplexity=3.96]


Recon: 0.0362, VQ: 0.0491, Perplexity: 4.21
MSE: 0.0375, PSNR: 20.28, SSIM: 0.7469

Epoch 78/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.38it/s, recon_loss=0.0337, vq_loss=0.0515, perplexity=4.12]


Recon: 0.0357, VQ: 0.0482, Perplexity: 4.22
MSE: 0.0361, PSNR: 20.44, SSIM: 0.7490

Epoch 79/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.95it/s, recon_loss=0.0387, vq_loss=0.052, perplexity=4.15]


Recon: 0.0355, VQ: 0.0471, Perplexity: 4.22
MSE: 0.0359, PSNR: 20.47, SSIM: 0.7462

Epoch 80/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.50it/s, recon_loss=0.0316, vq_loss=0.0398, perplexity=4.07]


Recon: 0.0351, VQ: 0.0491, Perplexity: 4.22
MSE: 0.0355, PSNR: 20.52, SSIM: 0.7472
Checkpoint saved: ./checkpoints/vqvae_epoch_80.pt

Epoch 81/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.84it/s, recon_loss=0.034, vq_loss=0.0507, perplexity=4.1]


Recon: 0.0353, VQ: 0.0465, Perplexity: 4.23
MSE: 0.0354, PSNR: 20.53, SSIM: 0.7490

Epoch 82/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.36it/s, recon_loss=0.0345, vq_loss=0.0461, perplexity=4.19]


Recon: 0.0357, VQ: 0.0470, Perplexity: 4.22
MSE: 0.0392, PSNR: 20.09, SSIM: 0.7415

Epoch 83/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.81it/s, recon_loss=0.0397, vq_loss=0.0448, perplexity=4.26]


Recon: 0.0351, VQ: 0.0466, Perplexity: 4.23
MSE: 0.0361, PSNR: 20.44, SSIM: 0.7438

Epoch 84/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.35it/s, recon_loss=0.0357, vq_loss=0.0447, perplexity=4.22]


Recon: 0.0346, VQ: 0.0477, Perplexity: 4.23
MSE: 0.0348, PSNR: 20.61, SSIM: 0.7498

Epoch 85/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.82it/s, recon_loss=0.0341, vq_loss=0.0446, perplexity=4.2]


Recon: 0.0343, VQ: 0.0457, Perplexity: 4.23
MSE: 0.0354, PSNR: 20.53, SSIM: 0.7523
Checkpoint saved: ./checkpoints/vqvae_epoch_85.pt

Epoch 86/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.33it/s, recon_loss=0.0461, vq_loss=0.0435, perplexity=4.34]


Recon: 0.0346, VQ: 0.0459, Perplexity: 4.24
MSE: 0.0375, PSNR: 20.28, SSIM: 0.7486

Epoch 87/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.88it/s, recon_loss=0.0355, vq_loss=0.0436, perplexity=4.43]


Recon: 0.0344, VQ: 0.0469, Perplexity: 4.22
MSE: 0.0353, PSNR: 20.55, SSIM: 0.7538

Epoch 88/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.41it/s, recon_loss=0.0358, vq_loss=0.0521, perplexity=4.24]


Recon: 0.0337, VQ: 0.0455, Perplexity: 4.25
MSE: 0.0342, PSNR: 20.68, SSIM: 0.7514

Epoch 89/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.87it/s, recon_loss=0.0322, vq_loss=0.0409, perplexity=4.22]


Recon: 0.0336, VQ: 0.0456, Perplexity: 4.24
MSE: 0.0347, PSNR: 20.62, SSIM: 0.7559

Epoch 90/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.39it/s, recon_loss=0.0435, vq_loss=0.0577, perplexity=4.12]


Recon: 0.0346, VQ: 0.0485, Perplexity: 4.24
MSE: 0.0358, PSNR: 20.48, SSIM: 0.7527
Checkpoint saved: ./checkpoints/vqvae_epoch_90.pt

Epoch 91/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.89it/s, recon_loss=0.0354, vq_loss=0.0532, perplexity=4.2]


Recon: 0.0352, VQ: 0.0471, Perplexity: 4.23
MSE: 0.0349, PSNR: 20.59, SSIM: 0.7496

Epoch 92/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.48it/s, recon_loss=0.0377, vq_loss=0.0578, perplexity=4.43]


Recon: 0.0343, VQ: 0.0464, Perplexity: 4.24
MSE: 0.0347, PSNR: 20.62, SSIM: 0.7523

Epoch 93/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.75it/s, recon_loss=0.0323, vq_loss=0.0466, perplexity=4.1]


Recon: 0.0338, VQ: 0.0468, Perplexity: 4.24
MSE: 0.0339, PSNR: 20.72, SSIM: 0.7553

Epoch 94/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.41it/s, recon_loss=0.0387, vq_loss=0.0429, perplexity=4.35]


Recon: 0.0345, VQ: 0.0452, Perplexity: 4.24
MSE: 0.0345, PSNR: 20.64, SSIM: 0.7533

Epoch 95/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.87it/s, recon_loss=0.0374, vq_loss=0.052, perplexity=4.37]


Recon: 0.0338, VQ: 0.0454, Perplexity: 4.25
MSE: 0.0353, PSNR: 20.55, SSIM: 0.7548
Checkpoint saved: ./checkpoints/vqvae_epoch_95.pt

Epoch 96/100


Training: 100%|██████████| 39/39 [00:06<00:00,  5.65it/s, recon_loss=0.0271, vq_loss=0.0412, perplexity=4.12]


Recon: 0.0328, VQ: 0.0446, Perplexity: 4.24
MSE: 0.0333, PSNR: 20.79, SSIM: 0.7593

Epoch 97/100


Training: 100%|██████████| 39/39 [00:08<00:00,  4.60it/s, recon_loss=0.0296, vq_loss=0.0542, perplexity=4.18]


Recon: 0.0325, VQ: 0.0447, Perplexity: 4.25
MSE: 0.0345, PSNR: 20.64, SSIM: 0.7575

Epoch 98/100


Training: 100%|██████████| 39/39 [00:07<00:00,  5.16it/s, recon_loss=0.0327, vq_loss=0.0478, perplexity=4.17]


Recon: 0.0329, VQ: 0.0455, Perplexity: 4.24
MSE: 0.0347, PSNR: 20.62, SSIM: 0.7525

Epoch 99/100


Training: 100%|██████████| 39/39 [00:05<00:00,  6.85it/s, recon_loss=0.0341, vq_loss=0.0543, perplexity=4.16]


Recon: 0.0338, VQ: 0.0484, Perplexity: 4.24
MSE: 0.0351, PSNR: 20.57, SSIM: 0.7583

Epoch 100/100


Training: 100%|██████████| 39/39 [00:07<00:00,  4.90it/s, recon_loss=0.0371, vq_loss=0.0493, perplexity=4.27]


Recon: 0.0334, VQ: 0.0466, Perplexity: 4.24
MSE: 0.0352, PSNR: 20.55, SSIM: 0.7532
Checkpoint saved: ./checkpoints/vqvae_epoch_100.pt
Checkpoint saved: ./checkpoints/vqvae_final.pt


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.13it/s, loss=5.93]


Epoch 1, Loss: 6.0753


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.53it/s, loss=5.65]


Epoch 2, Loss: 5.7831


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.44it/s, loss=5.39]


Epoch 3, Loss: 5.5132


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.32it/s, loss=5.17]


Epoch 4, Loss: 5.2754


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.03it/s, loss=4.98]


Epoch 5, Loss: 5.0732


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.47it/s, loss=4.82]


Epoch 6, Loss: 4.8972


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.48it/s, loss=4.67]


Epoch 7, Loss: 4.7437


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.08it/s, loss=4.52]


Epoch 8, Loss: 4.5901


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.57it/s, loss=4.36]


Epoch 9, Loss: 4.4380


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.06it/s, loss=4.21]


Epoch 10, Loss: 4.2902


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.51it/s, loss=4.05]


Epoch 11, Loss: 4.1358


Training Prior: 100%|██████████| 39/39 [00:08<00:00,  4.86it/s, loss=3.89]


Epoch 12, Loss: 3.9716


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.72it/s, loss=3.72]


Epoch 13, Loss: 3.8051


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.06it/s, loss=3.57]


Epoch 14, Loss: 3.6432


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  5.80it/s, loss=3.42]


Epoch 15, Loss: 3.4830


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.02it/s, loss=3.28]


Epoch 16, Loss: 3.3384


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.61it/s, loss=3.17]


Epoch 17, Loss: 3.2189


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.21it/s, loss=3.05]


Epoch 18, Loss: 3.1004


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.41it/s, loss=2.93]


Epoch 19, Loss: 2.9852


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.43it/s, loss=2.8]


Epoch 20, Loss: 2.8647


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.08it/s, loss=2.71]


Epoch 21, Loss: 2.7450


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.05it/s, loss=2.58]


Epoch 22, Loss: 2.6449


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.56it/s, loss=2.51]


Epoch 23, Loss: 2.5343


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.63it/s, loss=2.41]


Epoch 24, Loss: 2.4511


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.12it/s, loss=2.28]


Epoch 25, Loss: 2.3484


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.62it/s, loss=2.17]


Epoch 26, Loss: 2.2456


Training Prior: 100%|██████████| 39/39 [00:08<00:00,  4.75it/s, loss=2.11]


Epoch 27, Loss: 2.1448


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.61it/s, loss=1.95]


Epoch 28, Loss: 2.0198


Training Prior: 100%|██████████| 39/39 [00:08<00:00,  4.74it/s, loss=1.9]


Epoch 29, Loss: 1.9324


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.74it/s, loss=1.79]


Epoch 30, Loss: 1.8658


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.20it/s, loss=1.77]


Epoch 31, Loss: 1.7847


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.58it/s, loss=1.68]


Epoch 32, Loss: 1.7347


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  5.72it/s, loss=1.72]


Epoch 33, Loss: 1.6953


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  5.95it/s, loss=1.63]


Epoch 34, Loss: 1.6538


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.46it/s, loss=1.54]


Epoch 35, Loss: 1.5859


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.30it/s, loss=1.56]


Epoch 36, Loss: 1.5279


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.70it/s, loss=1.49]


Epoch 37, Loss: 1.4920


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.14it/s, loss=1.43]


Epoch 38, Loss: 1.4571


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.74it/s, loss=1.42]


Epoch 39, Loss: 1.4175


Training Prior: 100%|██████████| 39/39 [00:08<00:00,  4.63it/s, loss=1.42]


Epoch 40, Loss: 1.3764


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.68it/s, loss=1.34]


Epoch 41, Loss: 1.3493


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.17it/s, loss=1.33]


Epoch 42, Loss: 1.3206


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.56it/s, loss=1.37]


Epoch 43, Loss: 1.2998


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.21it/s, loss=1.28]


Epoch 44, Loss: 1.2829


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.53it/s, loss=1.2]


Epoch 45, Loss: 1.2574


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  5.64it/s, loss=1.22]


Epoch 46, Loss: 1.2379


Training Prior: 100%|██████████| 39/39 [00:06<00:00,  6.05it/s, loss=1.14]


Epoch 47, Loss: 1.2023


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.59it/s, loss=1.13]


Epoch 48, Loss: 1.1645


Training Prior: 100%|██████████| 39/39 [00:07<00:00,  5.29it/s, loss=1.1]


Epoch 49, Loss: 1.1374


Training Prior: 100%|██████████| 39/39 [00:05<00:00,  6.70it/s, loss=1.16]


Epoch 50, Loss: 1.1166
Generating new samples...
Performing interpolation...
Analyzing codebook...
Performing quantitative evaluation...

FINAL RESULTS
FID Score: 266.15

Generation completed successfully!
Results saved to: ./results
Checkpoints saved to: ./checkpoints


In [None]:
# Delete checkpoints saved by an older, incompatible model definition so the
# next training run starts from scratch.
# WARNING: this removes *everything* under config.checkpoint_dir.
import shutil

if os.path.exists(config.checkpoint_dir):
    print("Removing old incompatible checkpoints...")
    shutil.rmtree(config.checkpoint_dir)

# Recreate the directory unconditionally: if it was missing before this cell
# ran, later checkpoint saves into it would otherwise fail.
os.makedirs(config.checkpoint_dir, exist_ok=True)

Removing old incompatible checkpoints...


In [None]:
# List the current working directory. `ls` here is an IPython/Colab shell
# command, not valid plain Python, so this line only runs inside a notebook.
ls

clustering.png         interpolation.png           reconstructions.png
codebook_usage.png     latent_clusters.png         training_history.png
final_report.json      latent_interpolation.png    tsne_codes.png
generated_emojis.png   latent_space_tsne.png
generated_samples.png  prior_training_history.png


In [None]:
# IMPORTANT: Colab runtimes store files on an ephemeral disk that is wiped when
# the session disconnects. This cell mounts Google Drive and re-points the
# config paths into a Drive project folder, so files written *after* this cell
# runs are persisted. It does NOT copy files that already exist on local disk.
import os
from google.colab import drive

drive.mount('/content/drive')

# All persistent project folders live under a single Drive root.
_drive_project_root = '/content/drive/MyDrive/emoji_project'
config.data_dir = os.path.join(_drive_project_root, 'emoji_data')
config.checkpoint_dir = os.path.join(_drive_project_root, 'checkpoints')
config.results_dir = os.path.join(_drive_project_root, 'results')

# Make sure each target folder exists on Drive (no-op if already present).
for _target_dir in (config.data_dir, config.checkpoint_dir, config.results_dir):
    os.makedirs(_target_dir, exist_ok=True)

Mounted at /content/drive


In [None]:
from google.colab import drive
import os
import shutil

# Mount Google Drive. If it is already mounted, Colab prints a notice and
# keeps the existing mount.
drive.mount('/content/drive')

# Destination folders on Drive. These mirror the paths the Drive-mount cell
# assigns to config.data_dir / config.checkpoint_dir / config.results_dir
# (all under MyDrive/emoji_project), so copied results land in the same place
# the rest of the notebook reads from and writes to.
drive_project_dir = '/content/drive/MyDrive/emoji_project'
data_dir = os.path.join(drive_project_dir, 'emoji_data')
checkpoint_dir = os.path.join(drive_project_dir, 'checkpoints')
results_dir = os.path.join(drive_project_dir, 'results')

for target_dir in (data_dir, checkpoint_dir, results_dir):
    os.makedirs(target_dir, exist_ok=True)

# Local results folder produced during this session (ephemeral Colab disk).
source_dir = '/content/results/'

# Copy every top-level file and sub-directory from the local results folder
# into the Drive results folder, overwriting same-named entries.
if os.path.isdir(source_dir):
    copied_count = 0
    for item in sorted(os.listdir(source_dir)):
        source_path = os.path.join(source_dir, item)
        destination_path = os.path.join(results_dir, item)

        if os.path.isfile(source_path):
            shutil.copy2(source_path, destination_path)  # keeps metadata
            print(f"Copied file: {item}")
            copied_count += 1
        elif os.path.isdir(source_path):
            # dirs_exist_ok lets a re-run merge into an existing copy.
            shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
            print(f"Copied directory: {item}")
            copied_count += 1

    print(f"Transfer completed successfully! ({copied_count} item(s) copied)")
else:
    print(f"Source directory {source_dir} does not exist!")

# Verify: list everything now present in the Drive results folder. This may
# include entries from earlier runs, not only the ones copied above.
print(f"\nContents of {results_dir}:")
for item in sorted(os.listdir(results_dir)):
    item_path = os.path.join(results_dir, item)
    label = "File" if os.path.isfile(item_path) else "Directory"
    print(f"{label}: {item}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Copied file: training_history.png
Copied file: reconstructions.png
Copied file: latent_interpolation.png
Copied file: generated_samples.png
Copied file: tsne_codes.png
Copied file: latent_clusters.png
Copied file: clustering.png
Copied file: latent_space_tsne.png
Copied file: generated_emojis.png
Copied file: final_report.json
Copied file: codebook_usage.png
Copied file: prior_training_history.png
Copied file: interpolation.png
Transfer completed successfully!

Transferred files:
File: training_history.png
File: reconstructions.png
File: latent_interpolation.png
File: generated_samples.png
File: tsne_codes.png
File: latent_clusters.png
File: clustering.png
File: latent_space_tsne.png
File: generated_emojis.png
File: final_report.json
File: codebook_usage.png
File: prior_training_history.png
File: interpolation.png
