# PokéGAN Training on Google Colab

This notebook trains a DCGAN to generate Pokémon images using Google Colab's GPU.


## Step 1: Install Dependencies


In [None]:
# Install the notebook's Python dependencies into the kernel's environment.
# NOTE(review): no versions are pinned, so a fresh runtime may resolve different
# releases over time - pin them if reproducibility matters.
%pip install torch torchvision torchmetrics[image] pyyaml matplotlib scipy tensorboard datasets torch-fidelity kaggle


## Step 2: Clone Repository

This step automatically clones the project repository from GitHub.


In [None]:
import os

repo_name = 'CSC487-Project'
repo_url = 'https://github.com/BraedenAlonge/CSC487-Project.git'

# Clone or pull repository
if not os.path.exists(repo_name):
    print("Cloning repository...")
    !git clone {repo_url}
else:
    print("Repository already exists. Updating...")
    %cd {repo_name}
    !git pull
    %cd ..

# Move into project directory
if repo_name in os.listdir('.'):
    %cd {repo_name}
    
print(f"Current directory: {os.getcwd()}")


## Step 3: Verify GPU


In [None]:
import torch

# This notebook is meant to train on Colab's GPU, so report what the runtime
# offers and tell the user how to attach a GPU when none is visible.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; divide by 1e9 to display gigabytes.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠ No GPU detected. In Colab: Runtime -> Change runtime type -> select a GPU, then re-run from the top.")


## Step 4: Mount Google Drive (Optional - for saving checkpoints)


In [None]:
from google.colab import drive

# Path inside the Colab VM at which Google Drive is mounted.
DRIVE_MOUNT_POINT = '/content/drive'

drive.mount(DRIVE_MOUNT_POINT)


In [None]:
# Step 5: Download Dataset (Kaggle Method)
import os
import shutil
import random
from google.colab import files

print("--- Upload Kaggle JSON ---")
print("Please upload your kaggle.json file (from Kaggle Account -> API):")
uploaded = files.upload()

# Setup Kaggle Auth
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/ 2>/dev/null
!chmod 600 ~/.kaggle/kaggle.json

print("Downloading from Kaggle...")
!kaggle datasets download -d noodulz/pokemon-dataset-1000 --force

# Clean previous data
if os.path.exists('data/pokemon-dataset-1000'):
    shutil.rmtree('data/pokemon-dataset-1000')
!mkdir -p data

print("Extracting dataset...")
if os.path.exists('pokemon-dataset-1000.zip'):
    # Unzip to a temporary location first to inspect structure
    temp_extract_dir = 'data/temp_extract'
    if os.path.exists(temp_extract_dir): shutil.rmtree(temp_extract_dir)
    !unzip -q pokemon-dataset-1000.zip -d {temp_extract_dir}
    
    print("Organizing dataset...")
    # Target directories
    base_data_dir = 'data/pokemon-dataset-1000'
    train_dir = os.path.join(base_data_dir, 'train')
    val_dir = os.path.join(base_data_dir, 'val')
    test_dir = os.path.join(base_data_dir, 'test')
    
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    
    # Find ALL images recursively
    all_images = []
    for root, dirs, files in os.walk(temp_extract_dir):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                all_images.append(os.path.join(root, file))
    
    print(f"Found {len(all_images)} images total.")
    
    # Shuffle and Split
    random.shuffle(all_images)
    train_split = int(0.9 * len(all_images))
    val_split = int(0.05 * len(all_images))
    
    train_imgs = all_images[:train_split]
    val_imgs = all_images[train_split:train_split+val_split]
    test_imgs = all_images[train_split+val_split:]
    
    print("Moving files to train/val/test folders...")
    # Helper to move files
    def move_files(file_list, target_folder):
        for src in file_list:
            dst = os.path.join(target_folder, os.path.basename(src))
            # Handle duplicate filenames if flattened
            if os.path.exists(dst):
                base, ext = os.path.splitext(os.path.basename(src))
                dst = os.path.join(target_folder, f"{base}_{random.randint(0,9999)}{ext}")
            shutil.move(src, dst)
            
    move_files(train_imgs, train_dir)
    move_files(val_imgs, val_dir)
    move_files(test_imgs, test_dir)
    
    # Cleanup temp
    shutil.rmtree(temp_extract_dir)
    
    print(f"✓ Dataset prepared!")
    print(f"  Train: {len(os.listdir(train_dir))}")
    print(f"  Val: {len(os.listdir(val_dir))}")
    print(f"  Test: {len(os.listdir(test_dir))}")
else:
    print("Error: pokemon-dataset-1000.zip not found! Upload failed?")


In [None]:
# Step 6: Verify project structure and write the Colab config.
#
# This cell verifies and repairs the project structure to prevent ImportErrors.
# It automatically creates 'data/__init__.py' and 'data/pokemon_dataset.py' if they are missing,
# then derives configs/colab.yaml from configs/baseline.yaml with the dataset
# paths pointed at the directories prepared by the dataset download cell.

import importlib
import yaml
import os
import sys

# Get current directory and ensure it's in Python path
project_dir = os.getcwd()
if project_dir not in sys.path:
    sys.path.insert(0, project_dir)

print(f"Project directory: {project_dir}")

# 1. FIX: Ensure data package has __init__.py
data_init_path = os.path.join(project_dir, 'data', '__init__.py')
os.makedirs(os.path.join(project_dir, 'data'), exist_ok=True)
if not os.path.exists(data_init_path):
    print("Creating missing data/__init__.py...")
    with open(data_init_path, 'w') as f:
        f.write("from .pokemon_dataset import PokemonDataset\n")
        f.write("__all__ = ['PokemonDataset']\n")

# 2. FIX: Ensure pokemon_dataset.py exists (in case clone failed)
dataset_py_path = os.path.join(project_dir, 'data', 'pokemon_dataset.py')
if not os.path.exists(dataset_py_path):
    print("⚠ data/pokemon_dataset.py missing! Creating default version...")
    # Fallback source for data/pokemon_dataset.py, written verbatim to disk.
    code = """
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class PokemonDataset(Dataset):
    def __init__(self, root_dir, transform=None, augment=False):
        self.root_dir = root_dir
        self.image_paths = []
        for root, dirs, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(root, file))
        print(f'Found {len(self.image_paths)} images in {root_dir}')
        
        base_transforms = [
            transforms.Resize((64, 64)),
            transforms.Lambda(self._rgba_to_rgb),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        if augment:
            aug = [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05)
            ]
            self.transform = transforms.Compose(aug + base_transforms)
        else:
            self.transform = transforms.Compose(base_transforms)
            
        if transform: self.transform = transform

    def _rgba_to_rgb(self, img):
        if img.mode == 'RGBA':
            bg = Image.new('RGB', img.size, (255, 255, 255))
            bg.paste(img, mask=img.split()[3])
            return bg
        return img.convert('RGB')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            img = Image.open(self.image_paths[idx])
            return self.transform(img)
        except Exception as e:
            print(f'Error loading {self.image_paths[idx]}: {e}')
            return torch.zeros(3, 64, 64)
"""
    with open(dataset_py_path, 'w') as f:
        f.write(code)
    print("✓ Created data/pokemon_dataset.py")

# Module files may have been created just above, after the interpreter started.
# The import system caches directory contents, so flush those caches to make
# sure the freshly written files are noticed by the import below.
importlib.invalidate_caches()

# Verify imports
try:
    import data
    from data import PokemonDataset
    print("✓ Imported PokemonDataset successfully")
except ImportError as e:
    # Report only; the config is still written so later cells can proceed.
    print(f"⚠ Import Error: {e}")

# Read baseline config
with open('configs/baseline.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Update paths for Colab
config['data']['train_dir'] = f'{project_dir}/data/pokemon-dataset-1000/train'
config['data']['val_dir'] = f'{project_dir}/data/pokemon-dataset-1000/val'
config['data']['test_dir'] = f'{project_dir}/data/pokemon-dataset-1000/test'

# Save Colab config (configs/ is known to exist: baseline.yaml was just read from it)
with open('configs/colab.yaml', 'w') as f:
    yaml.dump(config, f)

print("Configuration updated for Colab!")


## Step 7: Fix Python Path (If Import Errors Occur)

**Run this if you get import errors:**


In [None]:
# Make the project root importable and confirm the dataset class can be loaded.
import os
import sys

# The working directory is expected to be the project root; put it first on
# the module search path unless it is already listed.
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Try the package-level import first, then fall back to the module itself.
try:
    from data import PokemonDataset
except ImportError:
    try:
        from data.pokemon_dataset import PokemonDataset
    except ImportError as e:
        # Both routes failed: print enough context to diagnose a wrong cwd.
        print(f"⚠ Import error: {e}")
        print("   Make sure you're in the project root directory!")
        print(f"   Current directory: {current_dir}")
        print(f"   Files in current dir: {os.listdir('.')[:10]}")
    else:
        print("✓ Successfully imported PokemonDataset (alternative method)")
else:
    print("✓ Successfully imported PokemonDataset")

print(f"\nPython path: {sys.path[:3]}...")


## Step 8: Test Setup (Optional)

**Note:** This step checks if everything is set up correctly. If the dataset check fails, that's OK - the dataset will be verified when training starts.


In [None]:
# CRITICAL: Fix Python path before running test_setup.py
import sys
import os
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Test import first
print("Testing import before running test_setup.py...")
try:
    from data import PokemonDataset
    print("✓ Import successful! Running test_setup.py...\n")
except ImportError as e:
    print(f"⚠ Import failed: {e}")
    print("Trying alternative...")
    try:
        from data.pokemon_dataset import PokemonDataset
        print("✓ Import successful (alternative)! Running test_setup.py...\n")
    except ImportError as e2:
        print(f"✗ Import still failing: {e2}")
        print(f"Current dir: {current_dir}")
        print(f"Files: {os.listdir('.')[:10]}")
        print("\nSkipping test_setup.py - will verify during training")
        import sys
        sys.exit(0)

# Run test setup
!python test_setup.py 2>&1 || echo "Test completed (some warnings OK)"


## Step 9: Run Training


In [None]:
# Launch training in a subprocess with the Colab-specific config, i.e. the
# configs/colab.yaml file written by the configuration cell earlier in this notebook.
!python train.py --config configs/colab.yaml


## Step 9b: View Training Progress (Optional)


In [None]:
# View TensorBoard
# Loads the TensorBoard notebook extension and opens it on the `logs` directory.
# NOTE(review): assumes train.py writes its event files under ./logs - confirm.
%load_ext tensorboard
%tensorboard --logdir logs


## Step 10: Evaluate Model


In [None]:
# Run the evaluation script on a saved checkpoint using the Colab config.
# NOTE(review): this evaluates checkpoints/baseline.pt, while the download cell
# fetches checkpoints/best_model.pt - confirm which file train.py actually writes.
!python eval.py --checkpoint checkpoints/baseline.pt --config configs/colab.yaml


## Step 11: Download Results (Optional)


In [None]:
import os
from google.colab import files

# Download best model
best_model_path = 'checkpoints/best_model.pt'
if os.path.exists(best_model_path):
    files.download(best_model_path)
else:
    # This step is optional, so explain what is missing instead of failing.
    print(f"⚠ {best_model_path} not found - run training first, or adjust the path.")
    if os.path.isdir('checkpoints'):
        print(f"  Available checkpoints: {sorted(os.listdir('checkpoints'))}")

# Download sample images (uncomment and adjust the path to fetch one)
# files.download('outputs/epoch_0_fake.png')
