# Pokemon Sprite Generator - Three-Stage Training Pipeline

## Overview
This notebook implements a complete three-stage training pipeline for generating Pokemon sprites from text descriptions using a Stable Diffusion-like architecture.

### Architecture Components:
- **Stage 1**: VAE (Variational Autoencoder) with perceptual loss
- **Stage 2**: U-Net diffusion model for denoising
- **Stage 3**: Text encoder fine-tuning with frozen VAE and U-Net

### Key Features:
- 🎨 Text-to-image generation for Pokemon sprites
- 📊 TensorBoard logging for monitoring training progress
- 🔄 Modular three-stage training approach
- 🚀 Optimized for Kaggle environment with GPU acceleration

**Note**: This notebook expects the complete repository to be available on Kaggle. Make sure to upload the entire `pokemon-sprite-generator` folder as a Kaggle dataset.

# 1. Install Dependencies and Setup Environment

First, let's install the required packages and set up the Python environment for the training pipeline.

In [None]:
# Install required packages
import subprocess
import sys

def install_package(package):
    """Install ``package`` with pip, using the interpreter that runs this notebook.

    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install', package]
    subprocess.check_call(pip_command)

# Packages the training pipeline needs at runtime
packages = [
    'torch',
    'torchvision',
    'transformers',
    'accelerate',
    'tensorboard',
    'pillow',
    'matplotlib',
    'seaborn',
    'tqdm',
    'pyyaml',
    'numpy',
    'pandas',
]

print("Installing required packages...")
failed_packages = []  # names whose installation raised, reported in the summary
for package in packages:
    try:
        install_package(package)
    except Exception as e:
        # Installation is best effort: keep going so one failure does not
        # prevent the remaining packages from being installed.
        failed_packages.append(package)
        print(f"❌ Failed to install {package}: {e}")
    else:
        print(f"✅ {package} installed successfully")

# Only claim success when every installation actually succeeded.
if failed_packages:
    print(f"\n⚠️  {len(failed_packages)} package(s) failed to install: {', '.join(failed_packages)}")
else:
    print("\n🎉 All packages installed!")

# Report which accelerator (if any) PyTorch can see
import torch

print(f"\nGPU available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("Running on CPU")
else:
    # Describe the first CUDA device: its name and total memory in GiB
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_total_bytes / 1024**3:.1f} GB")

# 2. Mount Kaggle Dataset and Configure Paths

Set up the workspace by copying the uploaded repository from the Kaggle input directory into the working directory (so it can be modified) and configuring the paths used by the training pipeline.

In [None]:
import os
import shutil
from pathlib import Path

# Configure paths for Kaggle environment
KAGGLE_INPUT_PATH = "/kaggle/input"
KAGGLE_WORKING_PATH = "/kaggle/working"

# Repository path (adjust this based on your dataset name)
REPO_NAME = "pokemon-sprite-generator"  # Change this to your actual dataset name
REPO_PATH = Path(KAGGLE_INPUT_PATH) / REPO_NAME

# Location of the modifiable copy of the repository. Bound before the
# existence check because later statements reference it unconditionally; when
# it was only assigned inside the `if`, a missing dataset caused a NameError
# instead of the guidance printed below.
WORKING_REPO_PATH = Path(KAGGLE_WORKING_PATH) / REPO_NAME

# Check if repository exists
if REPO_PATH.exists():
    print(f"✅ Repository found at: {REPO_PATH}")

    # Copy repository to working directory for modifications,
    # replacing any copy left over from a previous run
    if WORKING_REPO_PATH.exists():
        shutil.rmtree(WORKING_REPO_PATH)

    shutil.copytree(REPO_PATH, WORKING_REPO_PATH)
    print(f"📂 Repository copied to: {WORKING_REPO_PATH}")

    # Change to working directory
    os.chdir(WORKING_REPO_PATH)
    print(f"📁 Current working directory: {os.getcwd()}")

    # Add src to Python path
    import sys
    sys.path.insert(0, str(WORKING_REPO_PATH / "src"))

    # List repository contents
    print("\n📋 Repository contents:")
    for item in sorted(WORKING_REPO_PATH.iterdir()):
        if item.is_dir():
            print(f"  📁 {item.name}/")
        else:
            print(f"  📄 {item.name}")

else:
    print(f"❌ Repository not found at: {REPO_PATH}")
    print("Please make sure you've uploaded the entire pokemon-sprite-generator repository as a Kaggle dataset.")
    print("\n💡 To upload the repository:")
    print("1. Zip the entire pokemon-sprite-generator folder")
    print("2. Upload it as a dataset on Kaggle")
    print("3. Update the REPO_NAME variable above with your dataset name")

# Directory that collects the outputs of every training stage
EXPERIMENT_DIR = Path(KAGGLE_WORKING_PATH, "experiments")
EXPERIMENT_DIR.mkdir(exist_ok=True)

print(f"\n🧪 Experiment directory: {EXPERIMENT_DIR}")

# Training configuration that ships with the copied repository
CONFIG_PATH = WORKING_REPO_PATH.joinpath("config", "train_config.yaml")
print(f"⚙️  Config file: {CONFIG_PATH}")

# Repository-relative files and folders that are verified below
required_files = [
    "train_3stage.py",
    "config/train_config.yaml",
    "data/pokemon.csv",
    "data/small_images",
]

print("\n🔍 Checking required files:")
for file_path in required_files:
    full_path = WORKING_REPO_PATH / file_path
    # One status line per entry: present or missing
    status_line = f"  ✅ {file_path}" if full_path.exists() else f"  ❌ {file_path} - NOT FOUND"
    print(status_line)

In [None]:
# Update configuration for Kaggle environment
import yaml

# Parse the repository's training configuration
with open(CONFIG_PATH, 'r') as config_in:
    config = yaml.safe_load(config_in)

# Point the experiment and data paths at the Kaggle working copy, and shrink
# the data-loading settings (batch size for memory efficiency, fewer workers)
data_root = WORKING_REPO_PATH / "data"
config['experiment_dir'] = str(EXPERIMENT_DIR)
config['data'].update({
    'csv_path': str(data_root / "pokemon.csv"),
    'image_dir': str(data_root / "small_images"),
    'batch_size': 8,
    'num_workers': 2,
})

# Shorter schedules for faster training on Kaggle
config['training'].update({
    'vae_epochs': 5,
    'diffusion_epochs': 5,
    'final_epochs': 3,
})

# Write the adjusted configuration back over the working copy
with open(CONFIG_PATH, 'w') as config_out:
    yaml.dump(config, config_out, default_flow_style=False)

data_settings = config['data']
training_settings = config['training']
print("✅ Configuration updated for Kaggle environment")
print(f"📊 Batch size: {data_settings['batch_size']}")
print(f"🔄 VAE epochs: {training_settings['vae_epochs']}")
print(f"🔄 Diffusion epochs: {training_settings['diffusion_epochs']}")
print(f"🔄 Final epochs: {training_settings['final_epochs']}")
print(f"📁 Experiment directory: {config['experiment_dir']}")
print(f"📁 Data directory: {data_settings['image_dir']}")

# Echo the full configuration as YAML
print("\n⚙️  Updated Configuration:")
print(yaml.dump(config, default_flow_style=False))

# 3. Load and Inspect Training Configuration

Load the training configuration and display key parameters for each training stage.

In [None]:
# Load training configuration
from src.training.vae_trainer import load_config
import pandas as pd

# Reload the configuration through the project's own loader
config = load_config(str(CONFIG_PATH))

# Each entry: (heading, config section, [(label, key), ...]).
# The report below is printed section by section in this order.
config_report = [
    ("🔧 Model Configuration:", 'model', [
        ("BERT model", 'bert_model'),
        ("Text embedding dimension", 'text_embedding_dim'),
        ("Latent dimension", 'latent_dim'),
        ("Number of timesteps", 'num_timesteps'),
    ]),
    ("\n📊 Data Configuration:", 'data', [
        ("Batch size", 'batch_size'),
        ("Image size", 'image_size'),
        ("Number of workers", 'num_workers'),
    ]),
    ("\n🎯 Training Configuration:", 'training', [
        ("VAE epochs", 'vae_epochs'),
        ("Diffusion epochs", 'diffusion_epochs'),
        ("Final epochs", 'final_epochs'),
    ]),
    ("\n⚖️  Loss Configuration:", 'training', [
        ("Reconstruction weight", 'reconstruction_weight'),
        ("Perceptual weight", 'perceptual_weight'),
        ("KL weight", 'kl_weight'),
    ]),
    ("\n🔄 KL Annealing Configuration:", 'training', [
        ("KL anneal start", 'kl_anneal_start'),
        ("KL anneal end", 'kl_anneal_end'),
        ("KL weight start", 'kl_weight_start'),
        ("KL weight end", 'kl_weight_end'),
    ]),
    ("\n🎲 Optimization Configuration:", 'optimization', [
        ("Learning rate", 'learning_rate'),
        ("Beta1", 'beta1'),
        ("Beta2", 'beta2'),
        ("Weight decay", 'weight_decay'),
    ]),
]

for heading, section, fields in config_report:
    print(heading)
    for label, key in fields:
        print(f"  • {label}: {config[section][key]}")

# Rough time estimate: ten minutes per epoch, summed over the three stages
training_cfg = config['training']
total_epochs = training_cfg['vae_epochs'] + training_cfg['diffusion_epochs'] + training_cfg['final_epochs']
print(f"\n⏱️  Estimated Total Training:")
print(f"  • Total epochs: {total_epochs}")
print(f"  • Estimated time: ~{total_epochs * 10} minutes (approximate)")

# Paths the training run will read from and write to
print(f"\n📁 File Paths:")
print(f"  • Config: {CONFIG_PATH}")
print(f"  • Data CSV: {config['data']['csv_path']}")
print(f"  • Images: {config['data']['image_dir']}")
print(f"  • Experiments: {config['experiment_dir']}")

# 4. Dataset Statistics and Visualization

Analyze dataset statistics, visualize Pokemon types distribution, and show sample images with descriptions.

In [None]:
# Analyze dataset statistics
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import numpy as np
from pathlib import Path
from src.data import get_dataset_statistics

# Compute summary statistics with the project's helper
print("📊 Computing dataset statistics...")
csv_location = config['data']['csv_path']
image_location = config['data']['image_dir']
stats = get_dataset_statistics(csv_location, image_location)

# Report sample count and description-length statistics
print("\n📈 Dataset Statistics:")
for stats_line in (
    f"  • Total samples: {stats['total_samples']}",
    f"  • Average description length: {stats['avg_description_length']:.1f} words",
    f"  • Description length std: {stats['description_length_std']:.1f}",
):
    print(stats_line)

# Load full dataset for visualization.
# The file's encoding is not known up front, so candidate encodings are tried
# in order (the original note says the dataset classes do the same).
# UnicodeError is caught rather than only UnicodeDecodeError: decoding a file
# without a byte-order mark as 'utf-16' raises a plain UnicodeError, which
# would otherwise abort the cell before the UTF-8 fallback is attempted.
csv_encodings = ('utf-16', 'utf-8', 'latin-1')
for attempt, csv_encoding in enumerate(csv_encodings, start=1):
    try:
        df = pd.read_csv(config['data']['csv_path'], sep='\t', encoding=csv_encoding)
    except UnicodeError:
        # Re-raise only when there is no further encoding left to try
        if attempt == len(csv_encodings):
            raise
    else:
        break

print(f"\n📝 Dataset columns: {list(df.columns)}")
print(f"📋 First few rows:")
print(df.head())

# Plot the 15 most frequent primary types as a bar chart and a pie chart
if 'primary_type' in df.columns:
    type_counts = df['primary_type'].value_counts().head(15)
    type_figure, (bar_axis, pie_axis) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: horizontal bars, one per type
    sns.barplot(x=type_counts.values, y=type_counts.index, ax=bar_axis)
    bar_axis.set_title('Top 15 Pokemon Types Distribution')
    bar_axis.set_xlabel('Count')

    # Right panel: share of each type among the top 15
    pie_axis.pie(type_counts.values, labels=type_counts.index, autopct='%1.1f%%')
    pie_axis.set_title('Pokemon Types Distribution')

    type_figure.tight_layout()
    plt.show()

# Show up to six randomly chosen Pokemon images with their descriptions
print("\n🖼️  Sample Pokemon Images:")
sample_indices = np.random.choice(len(df), min(6, len(df)), replace=False)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for ax, idx in zip(axes, sample_indices):
    try:
        row = df.iloc[idx]
        # Image files are named after the zero-padded national number
        image_path = Path(config['data']['image_dir']) / f"{row['national_number']:03d}.png"
        title = f"#{row['national_number']} - {row['english_name']}"
        if image_path.exists():
            # Close the file handle as soon as the pixels have been drawn
            with Image.open(image_path) as img:
                ax.imshow(img)
            ax.set_title(title)

            # Add description if available, truncated to 100 characters
            if 'description' in df.columns and pd.notna(row['description']):
                description = row['description']
                if len(description) > 100:
                    description = description[:100] + "..."
                ax.text(0.5, -0.1, description, ha='center', va='top',
                        transform=ax.transAxes, fontsize=8, wrap=True)
        else:
            ax.text(0.5, 0.5, f"Image not found\n{image_path}", ha='center', va='center')
            ax.set_title(title)
    except Exception as e:
        # Best effort: a broken sample is reported in its panel instead of
        # aborting the whole figure
        ax.text(0.5, 0.5, f"Error loading image: {e}", ha='center', va='center')
        ax.set_title(f"Error - Index {idx}")

# Hide the axes of every panel, including panels left unused when the dataset
# holds fewer than six samples (previously those kept their empty frames).
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

# Histogram of description lengths, measured in whitespace-separated words
if 'description' in df.columns:
    desc_lengths = df['description'].str.split().str.len()
    mean_words = desc_lengths.mean()

    plt.figure(figsize=(10, 6))
    plt.hist(desc_lengths, bins=30, alpha=0.7, edgecolor='black')
    plt.title('Distribution of Description Lengths')
    plt.xlabel('Number of Words')
    plt.ylabel('Frequency')
    # Dashed marker at the mean length
    plt.axvline(mean_words, color='red', linestyle='--', label=f'Mean: {mean_words:.1f}')
    plt.legend()
    plt.show()

print(f"\n✅ Dataset analysis complete!")
print(f"Ready to start training with {len(df)} Pokemon samples.")

# 5. Stage 1: VAE Training with Perceptual Loss

Execute the first stage of training - VAE (Variational Autoencoder) with perceptual loss for learning to encode/decode Pokemon images.

In [None]:
# Stage 1: VAE Training
import subprocess
import sys
import time

print("🚀 Starting Stage 1: VAE Training with Perceptual Loss")
print("=" * 60)

# Command line that runs only the VAE stage of the training script
train_cmd = [sys.executable, "train_3stage.py"]
train_cmd += ["--config", str(CONFIG_PATH)]
train_cmd += ["--stage", "1"]
train_cmd += ["--experiment-name", "kaggle_pokemon_3stage"]

print(f"💻 Training command: {' '.join(train_cmd)}")
print(f"📁 Working directory: {os.getcwd()}")

# Run the stage as a child process; its output is captured and printed once
# the process has exited
start_time = time.time()
try:
    result = subprocess.run(train_cmd, capture_output=True, text=True, cwd=os.getcwd())

    print("📋 Training Output:")
    print(result.stdout)

    if result.stderr:
        print("⚠️  Training Errors/Warnings:")
        print(result.stderr)

    if result.returncode != 0:
        print(f"❌ Stage 1 training failed with return code: {result.returncode}")
    else:
        print("✅ Stage 1 training completed successfully!")

        # Confirm that the best-model checkpoint was written
        vae_checkpoint_path = Path(config['experiment_dir']).joinpath(
            "kaggle_pokemon_3stage_vae", "checkpoints", "vae_best_model.pth"
        )
        if not vae_checkpoint_path.exists():
            print("❌ VAE checkpoint not found!")
        else:
            print(f"🎯 VAE checkpoint saved at: {vae_checkpoint_path}")
            print(f"📏 Checkpoint size: {vae_checkpoint_path.stat().st_size / 1024 / 1024:.1f} MB")

except Exception as e:
    print(f"❌ Error during Stage 1 training: {e}")

finally:
    # Wall-clock duration is reported whether or not the stage succeeded
    end_time = time.time()
    training_time = end_time - start_time
    print(f"⏱️  Stage 1 training time: {training_time:.1f} seconds ({training_time/60:.1f} minutes)")

# Point the reader at the stage's log directory
logs_dir = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_vae" / "logs"
if not logs_dir.exists():
    print("⚠️  No training logs found")
else:
    print(f"\n📊 Training logs available at: {logs_dir}")
    print("You can view tensorboard logs by running:")
    print(f"tensorboard --logdir {logs_dir}")

# 6. Stage 2: U-Net Diffusion Training

Execute the second stage of training - U-Net diffusion model for denoising using the trained VAE from Stage 1.

In [None]:
# Stage 2: U-Net Diffusion Training
print("🚀 Starting Stage 2: U-Net Diffusion Training")
print("=" * 60)

# Stage 2 is started only when the VAE checkpoint from Stage 1 is present
vae_checkpoint_path = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_vae" / "checkpoints" / "vae_best_model.pth"

if vae_checkpoint_path.exists():
    print(f"✅ VAE checkpoint found: {vae_checkpoint_path}")

    # Command line that runs only the diffusion stage, handing it the VAE weights
    train_cmd = [sys.executable, "train_3stage.py"]
    train_cmd += ["--config", str(CONFIG_PATH)]
    train_cmd += ["--stage", "2"]
    train_cmd += ["--experiment-name", "kaggle_pokemon_3stage"]
    train_cmd += ["--vae-checkpoint", str(vae_checkpoint_path)]

    print(f"💻 Training command: {' '.join(train_cmd)}")

    # Run the stage as a child process; its output is captured and printed
    # once the process has exited
    start_time = time.time()
    try:
        result = subprocess.run(train_cmd, capture_output=True, text=True, cwd=os.getcwd())

        print("📋 Training Output:")
        print(result.stdout)

        if result.stderr:
            print("⚠️  Training Errors/Warnings:")
            print(result.stderr)

        if result.returncode != 0:
            print(f"❌ Stage 2 training failed with return code: {result.returncode}")
        else:
            print("✅ Stage 2 training completed successfully!")

            # Confirm that the best-model checkpoint was written
            diffusion_checkpoint_path = Path(config['experiment_dir']).joinpath(
                "kaggle_pokemon_3stage_diffusion", "checkpoints", "diffusion_best_model.pth"
            )
            if not diffusion_checkpoint_path.exists():
                print("❌ Diffusion checkpoint not found!")
            else:
                print(f"🎯 Diffusion checkpoint saved at: {diffusion_checkpoint_path}")
                print(f"📏 Checkpoint size: {diffusion_checkpoint_path.stat().st_size / 1024 / 1024:.1f} MB")

    except Exception as e:
        print(f"❌ Error during Stage 2 training: {e}")

    finally:
        # Wall-clock duration is reported whether or not the stage succeeded
        end_time = time.time()
        training_time = end_time - start_time
        print(f"⏱️  Stage 2 training time: {training_time:.1f} seconds ({training_time/60:.1f} minutes)")

    # Point the reader at the stage's log directory
    logs_dir = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_diffusion" / "logs"
    if not logs_dir.exists():
        print("⚠️  No training logs found")
    else:
        print(f"\n📊 Training logs available at: {logs_dir}")
else:
    print("❌ VAE checkpoint not found! Please run Stage 1 first.")
    print(f"Expected path: {vae_checkpoint_path}")

# 7. Stage 3: Final Training (Text Encoder Fine-tuning)

Execute the final stage of training - fine-tune the text encoder with frozen VAE and U-Net models.

In [None]:
# Stage 3: Final Training (Text Encoder Fine-tuning)
print("🚀 Starting Stage 3: Final Training (Text Encoder Fine-tuning)")
print("=" * 60)

# Set to True only when the training process exits with status 0, so the
# closing banner is not shown after a skipped or failed run.
stage3_succeeded = False

# Check if required checkpoints exist
vae_checkpoint_path = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_vae" / "checkpoints" / "vae_best_model.pth"
diffusion_checkpoint_path = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_diffusion" / "checkpoints" / "diffusion_best_model.pth"

missing_checkpoints = []
if not vae_checkpoint_path.exists():
    missing_checkpoints.append(f"VAE checkpoint: {vae_checkpoint_path}")
if not diffusion_checkpoint_path.exists():
    missing_checkpoints.append(f"Diffusion checkpoint: {diffusion_checkpoint_path}")

if missing_checkpoints:
    print("❌ Required checkpoints not found:")
    for checkpoint in missing_checkpoints:
        print(f"  • {checkpoint}")
    print("Please run previous stages first.")
else:
    print(f"✅ VAE checkpoint found: {vae_checkpoint_path}")
    print(f"✅ Diffusion checkpoint found: {diffusion_checkpoint_path}")

    # Prepare training command: final stage, given both earlier checkpoints
    train_cmd = [
        sys.executable, "train_3stage.py",
        "--config", str(CONFIG_PATH),
        "--stage", "3",
        "--experiment-name", "kaggle_pokemon_3stage",
        "--vae-checkpoint", str(vae_checkpoint_path),
        "--diffusion-checkpoint", str(diffusion_checkpoint_path)
    ]

    print(f"💻 Training command: {' '.join(train_cmd)}")

    # Start training
    start_time = time.time()
    try:
        # Run training process; output is captured and printed after it exits
        result = subprocess.run(train_cmd, capture_output=True, text=True, cwd=os.getcwd())

        print("📋 Training Output:")
        print(result.stdout)

        if result.stderr:
            print("⚠️  Training Errors/Warnings:")
            print(result.stderr)

        if result.returncode == 0:
            stage3_succeeded = True
            print("✅ Stage 3 training completed successfully!")

            # Check if checkpoint was created
            final_checkpoint_path = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_final" / "checkpoints" / "final_best_model.pth"
            if final_checkpoint_path.exists():
                print(f"🎯 Final checkpoint saved at: {final_checkpoint_path}")
                print(f"📏 Checkpoint size: {final_checkpoint_path.stat().st_size / 1024 / 1024:.1f} MB")
            else:
                print("❌ Final checkpoint not found!")

        else:
            print(f"❌ Stage 3 training failed with return code: {result.returncode}")

    except Exception as e:
        print(f"❌ Error during Stage 3 training: {e}")

    finally:
        # Wall-clock duration is reported whether or not the stage succeeded
        end_time = time.time()
        training_time = end_time - start_time
        print(f"⏱️  Stage 3 training time: {training_time:.1f} seconds ({training_time/60:.1f} minutes)")

    # Display training logs location
    logs_dir = Path(config['experiment_dir']) / "kaggle_pokemon_3stage_final" / "logs"
    if logs_dir.exists():
        print(f"\n📊 Training logs available at: {logs_dir}")
    else:
        print("⚠️  No training logs found")

# The completion banner used to be printed unconditionally, even when the
# prerequisites were missing or the training process failed.
if stage3_succeeded:
    print("\n🎉 All training stages completed!")
    print("Your Pokemon sprite generator is ready for inference!")
else:
    print("\n⚠️  Stage 3 did not complete successfully - review the messages above before running inference.")

# 8. Monitor Training Progress and Logs

Summarize the checkpoints, log files, and most recent generated sample images found for each training stage.

In [None]:
# Monitor training progress and logs
import glob
import re
from datetime import datetime

# Summarize, per training stage, which artifacts exist on disk
experiment_base = Path(config['experiment_dir'])
stages = ['vae', 'diffusion', 'final']

print("📊 Training Progress Summary")
print("=" * 60)

for stage in stages:
    stage_dir = experiment_base / f"kaggle_pokemon_3stage_{stage}"

    if not stage_dir.exists():
        print(f"\n❌ {stage.upper()} Stage: Directory not found")
        continue

    print(f"\n🔍 {stage.upper()} Stage:")

    # Check for checkpoints
    checkpoint_dir = stage_dir / "checkpoints"
    if checkpoint_dir.exists():
        checkpoints = list(checkpoint_dir.glob("*.pth"))
        if checkpoints:
            print(f"  ✅ Checkpoints found: {len(checkpoints)}")
            for checkpoint in checkpoints:
                size_mb = checkpoint.stat().st_size / 1024 / 1024
                print(f"    • {checkpoint.name}: {size_mb:.1f} MB")
        else:
            print("  ❌ No checkpoints found")
    else:
        print("  ❌ No checkpoint directory found")

    # Check for logs. The recursive glob also yields directories, so keep only
    # regular files; otherwise sub-directories inflate the reported count.
    log_dir = stage_dir / "logs"
    if log_dir.exists():
        log_files = [entry for entry in log_dir.glob("**/*") if entry.is_file()]
        if log_files:
            print(f"  📋 Log files found: {len(log_files)}")
            print(f"    Location: {log_dir}")
        else:
            print("  ❌ No log files found")
    else:
        print("  ❌ No log directory found")

    # Check for sample images
    sample_dir = stage_dir / "samples"
    if sample_dir.exists():
        sample_files = list(sample_dir.glob("*.png"))
        if sample_files:
            print(f"  🖼️  Sample images found: {len(sample_files)}")

            # Display the (up to) three most recently modified samples
            sample_files.sort(key=lambda x: x.stat().st_mtime, reverse=True)
            latest_samples = sample_files[:3]

            print(f"  📸 Latest sample images:")
            fig, axes = plt.subplots(1, len(latest_samples), figsize=(12, 4))
            if len(latest_samples) == 1:
                # A single subplot is returned as a bare Axes, not an array
                axes = [axes]

            for ax, sample_path in zip(axes, latest_samples):
                try:
                    # Close the file handle as soon as the pixels are drawn
                    with Image.open(sample_path) as img:
                        ax.imshow(img)
                    ax.set_title(f"{sample_path.name}")
                except Exception as e:
                    # Best effort: report the failure inside the panel
                    ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center')
                    ax.set_title(f"Error loading {sample_path.name}")
                # Hide the axes for successful and failed panels alike
                ax.axis('off')

            plt.tight_layout()
            plt.show()
        else:
            print("  ❌ No sample images found")
    else:
        print("  ❌ No sample directory found")

# Display recent log entries for debugging
print("\n📋 Recent Training Logs:")
print("=" * 60)

for stage in stages:
    stage_dir = experiment_base / f"kaggle_pokemon_3stage_{stage}"
    log_dir = stage_dir / "logs"

    if log_dir.exists():
        # Keep only regular files BEFORE taking the first five. Slicing the
        # raw glob result first let directories use up slots, so fewer than
        # five files could be listed even when more existed.
        log_files = sorted(entry for entry in log_dir.glob("**/*") if entry.is_file())
        if log_files:
            print(f"\n📂 {stage.upper()} Stage Logs:")
            print(f"  Log directory: {log_dir}")

            # You can add more specific log parsing here if needed
            # For now, just show the first 5 log files (sorted by path)
            for log_file in log_files[:5]:
                size_kb = log_file.stat().st_size / 1024
                print(f"    • {log_file.name}: {size_kb:.1f} KB")

print("\n✅ Training monitoring complete!")
print("\n💡 To view detailed training logs, you can:")
print("1. Check the tensorboard logs in each stage's logs directory")
print("2. Examine the checkpoint files for model weights")
print("3. Look at sample images generated during training")

# 9. Save and Export Model Checkpoints

Save final model checkpoints and prepare them for download or further use in inference.

In [None]:
# Save and export model checkpoints
import zipfile
import shutil
from datetime import datetime

print("💾 Preparing Model Checkpoints for Export")
print("=" * 60)

# Folder that will hold everything bundled for download
export_dir = Path(KAGGLE_WORKING_PATH) / "pokemon_generator_export"
export_dir.mkdir(exist_ok=True)

# Individual repository files to carry over (paths relative to the repo root)
essential_files = [
    "train_3stage.py",
    "generate_pokemon_3stage.py",
    "config/train_config.yaml",
    "requirements.txt",
    "README.md",
]

print("📋 Copying essential files...")
for file_path in essential_files:
    src = WORKING_REPO_PATH / file_path
    if not src.exists():
        print(f"  ❌ {file_path} - not found")
        continue

    # Recreate the file's sub-folder inside the export before copying
    dst = export_dir / file_path
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    print(f"  ✅ {file_path}")

# Copy the source package, replacing any copy left by an earlier export
src_dir = WORKING_REPO_PATH / "src"
if src_dir.exists():
    dst_src = export_dir / "src"
    if dst_src.exists():
        shutil.rmtree(dst_src)
    shutil.copytree(src_dir, dst_src)
    print("  ✅ src/")

# Gather the best checkpoint of each stage into the export folder
checkpoint_info = []  # "<model>: <size> MB" lines for the checkpoints copied
model_checkpoints = [
    ("vae", "kaggle_pokemon_3stage_vae/checkpoints/vae_best_model.pth"),
    ("diffusion", "kaggle_pokemon_3stage_diffusion/checkpoints/diffusion_best_model.pth"),
    ("final", "kaggle_pokemon_3stage_final/checkpoints/final_best_model.pth"),
]

print("\n🎯 Copying trained model checkpoints...")
checkpoints_dir = export_dir / "checkpoints"
checkpoints_dir.mkdir(exist_ok=True)

experiment_root = Path(config['experiment_dir'])
for model_name, checkpoint_path in model_checkpoints:
    exported_name = f"{model_name}_best_model.pth"
    src_checkpoint = experiment_root / checkpoint_path
    if not src_checkpoint.exists():
        print(f"  ❌ {exported_name} - not found")
        continue

    # Copy, then report the size of the exported file in MB
    dst_checkpoint = checkpoints_dir / exported_name
    shutil.copy2(src_checkpoint, dst_checkpoint)
    size_mb = dst_checkpoint.stat().st_size / 1024 / 1024
    checkpoint_info.append(f"{model_name}: {size_mb:.1f} MB")
    print(f"  ✅ {exported_name} ({size_mb:.1f} MB)")

# Create inference script.
# The generated script locates its checkpoints and config relative to its own
# file (BASE_DIR) rather than the current working directory, so it can be
# started from any directory; previously only the `src` import path was
# anchored to the script location.
inference_script = export_dir / "inference_example.py"
inference_code = '''#!/usr/bin/env python3
"""
Example inference script for Pokemon sprite generation.
"""

import sys
from pathlib import Path
import torch
from PIL import Image

# Directory containing this script, the checkpoints and the config
BASE_DIR = Path(__file__).resolve().parent

# Add src to path
sys.path.append(str(BASE_DIR / "src"))

from generate_pokemon_3stage import PokemonGenerator

def generate_pokemon(description, output_path="generated_pokemon.png"):
    """Generate a Pokemon sprite from text description."""
    
    # Paths to trained models (resolved relative to this script)
    vae_checkpoint = str(BASE_DIR / "checkpoints" / "vae_best_model.pth")
    diffusion_checkpoint = str(BASE_DIR / "checkpoints" / "diffusion_best_model.pth")
    config_path = str(BASE_DIR / "config" / "train_config.yaml")
    
    # Check if files exist
    for path in [vae_checkpoint, diffusion_checkpoint, config_path]:
        if not Path(path).exists():
            print(f"Error: {path} not found!")
            return None
    
    # Initialize generator
    generator = PokemonGenerator(
        vae_checkpoint_path=vae_checkpoint,
        diffusion_checkpoint_path=diffusion_checkpoint,
        config_path=config_path,
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    
    # Generate image
    image = generator.generate([description], num_inference_steps=50)[0]
    
    # Save image
    generator.save_image(image, output_path)
    print(f"Generated Pokemon saved to: {output_path}")
    
    return output_path

if __name__ == "__main__":
    # Example usage
    descriptions = [
        "A fire-type Pokemon with orange flames and wings",
        "A water-type Pokemon with blue scales and fins",
        "An electric-type Pokemon with yellow fur and lightning bolts"
    ]
    
    for i, desc in enumerate(descriptions):
        output_file = f"generated_pokemon_{i+1}.png"
        generate_pokemon(desc, output_file)
'''

# Write with an explicit encoding so the result does not depend on the locale
with open(inference_script, 'w', encoding='utf-8') as f:
    f.write(inference_code)

print(f"  ✅ inference_example.py created")

# Create README for export.
# One bullet per exported checkpoint; fall back to a placeholder so the
# "Model Checkpoints" section is never left empty when nothing was exported.
if checkpoint_info:
    checkpoint_lines = "\n".join(f"- {info}" for info in checkpoint_info)
else:
    checkpoint_lines = "- (no trained checkpoints were found to export)"

readme_content = f'''# Pokemon Sprite Generator - Trained Models

This package contains the trained Pokemon sprite generator models and inference code.

## Training Information
- Training Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
- Training Platform: Kaggle
- Model Architecture: 3-Stage VAE + U-Net Diffusion

## Model Checkpoints
{checkpoint_lines}

## Quick Start

1. Install requirements:
```bash
pip install torch torchvision transformers pillow pyyaml numpy
```

2. Run inference:
```bash
python inference_example.py
```

3. Or use the generator directly:
```python
from generate_pokemon_3stage import PokemonGenerator

generator = PokemonGenerator(
    vae_checkpoint_path="checkpoints/vae_best_model.pth",
    diffusion_checkpoint_path="checkpoints/diffusion_best_model.pth",
    config_path="config/train_config.yaml"
)

image = generator.generate(["A fire-type Pokemon with orange flames"])
```

## Files Structure
- `checkpoints/`: Trained model weights
- `src/`: Source code for models and training
- `config/`: Configuration files
- `inference_example.py`: Example inference script
- `generate_pokemon_3stage.py`: Main generation script
- `train_3stage.py`: Training script

## Training Stages
1. **VAE Stage**: Variational Autoencoder with perceptual loss
2. **Diffusion Stage**: U-Net denoising model
3. **Final Stage**: Text encoder fine-tuning

Enjoy generating Pokemon sprites! 🎮✨
'''

readme_path = export_dir / "README.md"
# The README contains emoji, so the encoding is stated explicitly; relying on
# the locale default can raise UnicodeEncodeError on non-UTF-8 systems.
with open(readme_path, 'w', encoding='utf-8') as f:
    f.write(readme_content)

print(f"  ✅ README.md created")

# Bundle everything under export_dir into one downloadable archive
zip_path = Path(KAGGLE_WORKING_PATH) / "pokemon_generator_trained.zip"
print(f"\n📦 Creating zip file: {zip_path}")

with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as zipf:
    for current_dir, _subdirs, filenames in os.walk(export_dir):
        current_dir = Path(current_dir)
        for filename in filenames:
            source_file = current_dir / filename
            # Archive names are stored relative to export_dir
            zipf.write(source_file, source_file.relative_to(export_dir))

# Report the archive size in mebibytes
zip_size_bytes = zip_path.stat().st_size
zip_size_mb = zip_size_bytes / 1024 / 1024
print(f"✅ Zip file created: {zip_size_mb:.1f} MB")

# Summary
print("\n🎉 Export Complete!")
print("=" * 60)
print(f"📁 Export directory: {export_dir}")
print(f"📦 Zip file: {zip_path}")
print(f"💾 Total size: {zip_size_mb:.1f} MB")
print("\nYou can now download the zip file containing:")
print("• Trained model checkpoints")
print("• Complete source code")
print("• Configuration files")
print("• Example inference script")
print("• Documentation")

# Display final directory structure as an indented tree with per-file sizes
print("\n📋 Export Contents:")
for root, dirs, files in os.walk(export_dir):
    # Depth is the number of path components below export_dir. Using
    # relative_to() instead of str.replace() avoids miscounting when the
    # export path text happens to occur again deeper in the tree.
    level = len(Path(root).relative_to(export_dir).parts)
    indent = ' ' * 2 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = ' ' * 2 * (level + 1)
    for file in files:
        file_path = Path(root) / file
        size_kb = file_path.stat().st_size / 1024
        print(f"{subindent}{file} ({size_kb:.1f} KB)")

print("\n🚀 Ready for Pokemon generation!")

# 10. Test the Trained Model

Test the complete trained model by generating Pokemon sprites from text descriptions.

In [None]:
# Test the trained model
import textwrap


def to_displayable(image):
    """Return ``image`` in a form that ``Axes.imshow`` can draw.

    Torch tensors are detected with ``isinstance`` rather than by probing for
    a ``size`` attribute: tensors expose ``size`` as a method, so an
    attribute probe cannot tell them apart from PIL images.

    Args:
        image: A torch tensor or any object ``imshow`` accepts directly
            (e.g. a PIL image).

    Returns:
        For a 3-D tensor, a NumPy array with the axes permuted ``(1, 2, 0)``
        and values clipped to ``[0, 1]``; for any other tensor, its NumPy
        conversion; otherwise ``image`` itself, unchanged.
    """
    if isinstance(image, torch.Tensor):
        # detach() so tensors that still track gradients can be converted
        image = image.detach().cpu()
        if image.dim() == 3:
            # assumes channel-first (C, H, W) layout — TODO confirm with generator
            return np.clip(image.permute(1, 2, 0).numpy(), 0, 1)
        return image.numpy()
    return image


try:
    # Import the generation script
    from generate_pokemon_3stage import PokemonGenerator

    print("🧪 Testing Trained Pokemon Generator")
    print("=" * 60)

    # Check if all required checkpoints exist
    required_checkpoints = [
        Path(config['experiment_dir']) / "kaggle_pokemon_3stage_vae" / "checkpoints" / "vae_best_model.pth",
        Path(config['experiment_dir']) / "kaggle_pokemon_3stage_diffusion" / "checkpoints" / "diffusion_best_model.pth"
    ]

    all_checkpoints_exist = all(checkpoint.exists() for checkpoint in required_checkpoints)

    if not all_checkpoints_exist:
        print("❌ Some required checkpoints are missing:")
        for checkpoint in required_checkpoints:
            status = "✅" if checkpoint.exists() else "❌"
            print(f"  {status} {checkpoint}")
        print("Please ensure all training stages completed successfully.")
    else:
        print("✅ All required checkpoints found!")

        # Initialize the generator
        print("\n🚀 Initializing Pokemon Generator...")
        generator = PokemonGenerator(
            vae_checkpoint_path=str(required_checkpoints[0]),
            diffusion_checkpoint_path=str(required_checkpoints[1]),
            config_path=str(CONFIG_PATH),
            device="cuda" if torch.cuda.is_available() else "cpu"
        )

        # Test descriptions
        test_descriptions = [
            "A fire-type Pokemon with orange flames and wings",
            "A water-type Pokemon with blue scales and fins",
            "An electric-type Pokemon with yellow fur and lightning bolts",
            "A grass-type Pokemon with green leaves and vines",
            "A psychic-type Pokemon with purple aura and mystic powers"
        ]

        print(f"\n🎨 Generating {len(test_descriptions)} Pokemon sprites...")

        # One subplot per description, laid out in a single row
        fig, axes = plt.subplots(1, len(test_descriptions), figsize=(20, 4))
        if len(test_descriptions) == 1:
            # subplots() returns a bare Axes for a 1x1 grid; normalise to a list
            axes = [axes]

        for i, (ax, description) in enumerate(zip(axes, test_descriptions)):
            print(f"  🎯 Generating: {description}")

            # A failure on one prompt is drawn into its own panel so the
            # remaining prompts are still attempted.
            try:
                # Generate image
                images = generator.generate([description], num_inference_steps=25)

                # Explicit None/length check: plain truthiness would raise
                # for a multi-element tensor result.
                if images is not None and len(images) > 0:
                    ax.imshow(to_displayable(images[0]))
                    ax.set_title(f"Pokemon #{i+1}", fontsize=10)
                    ax.axis('off')

                    # Add description below image, wrapped on word boundaries
                    wrapped_desc = textwrap.fill(description, width=30)
                    ax.text(0.5, -0.15, wrapped_desc, ha='center', va='top',
                            transform=ax.transAxes, fontsize=8, wrap=True)

                    print("    ✅ Generated successfully")

                else:
                    ax.text(0.5, 0.5, "Generation\nFailed", ha='center', va='center')
                    ax.set_title(f"Pokemon #{i+1} - Error")
                    ax.axis('off')
                    print("    ❌ Generation failed")

            except Exception as e:
                ax.text(0.5, 0.5, f"Error:\n{str(e)[:50]}...", ha='center', va='center')
                ax.set_title(f"Pokemon #{i+1} - Error")
                ax.axis('off')
                print(f"    ❌ Error: {e}")

        plt.tight_layout()
        plt.show()

        print("\n🎉 Generation test completed!")
        print("The trained model is ready for use!")

        # Save one example image
        try:
            example_images = generator.generate(["A cute electric-type Pokemon with yellow fur"], num_inference_steps=50)
            if example_images is not None and len(example_images) > 0:
                example_image = example_images[0]
                # Only objects exposing save() (e.g. PIL images) can be written
                if hasattr(example_image, 'save'):
                    example_image.save("final_test_pokemon.png")
                    print("💾 Example image saved as: final_test_pokemon.png")
                else:
                    print("❌ Generated image is not in a saveable format")
        except Exception as e:
            print(f"❌ Error saving example image: {e}")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("This might be due to missing dependencies or incomplete training.")
    print("Please ensure all training stages completed successfully.")
except Exception as e:
    print(f"❌ Error during testing: {e}")
    import traceback
    traceback.print_exc()

print("\n🏁 Training and Testing Complete!")
print("=" * 60)
print("🎮 Your Pokemon Sprite Generator is ready to use!")
print("📦 Download the zip file created in the previous section")
print("🚀 Use the inference_example.py script for generating new Pokemon")
print("✨ Have fun creating your own Pokemon sprites!")