# WrapsRL v5: StyleGAN2-ADA for Rocket League Decal Textures

This notebook implements a texture generation model for Rocket League decals using StyleGAN2-ADA PyTorch. The model is designed to generate high-quality 1024×1024 decal textures that respect the constraints of texture mapping in the game.

## Overview

Rocket League decal textures present unique challenges for image generation models:
- Specific regions must be preserved (large black spaces that should be left blank)
- Texture mapping constraints must be respected
- The model must be creative within these bounds

This notebook addresses these challenges by:
1. Setting up the environment for training on Google Colab with an A100 GPU
2. Preparing data from the provided Google Drive link in the dataset format StyleGAN2-ADA PyTorch expects
3. Training a StyleGAN2-ADA PyTorch model from a pretrained checkpoint with the parameters listed in Section 3
4. Evaluating the model using FID and precision/recall metrics
5. Generating sample 1024×1024 decal textures

The notebook also includes integration with Weights & Biases (wandb.ai) for remote monitoring of model performance.
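
The blank-region constraint mentioned above can be checked numerically. The sketch below is a minimal illustration, not part of the training pipeline: the helper name `black_fraction` and the threshold of 8 are assumptions chosen for the example, and the input is a synthetic array rather than a real decal.

```python
import numpy as np

def black_fraction(image: np.ndarray, threshold: int = 8) -> float:
    """Return the fraction of pixels whose channels are all <= threshold."""
    if image.ndim == 2:
        # Treat a grayscale image as a single-channel image.
        image = image[..., None]
    dark = np.all(image <= threshold, axis=-1)
    return float(dark.mean())

# Synthetic 4x4 RGB example: left half black, right half white.
demo = np.zeros((4, 4, 3), dtype=np.uint8)
demo[:, 2:, :] = 255
print(black_fraction(demo))
```

Comparing this fraction between training decals and generated samples gives a rough signal of whether the model keeps the blank regions blank.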

## 1. Environment Setup for Google Colab A100

This section sets up the environment for training on Google Colab with an A100 GPU. It includes:
- Mounting Google Drive
- Installing required dependencies
- Setting up system monitoring
- Configuring GPU settings
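
One way to structure the system-monitoring step is a background thread that can be stopped cleanly. The sketch below uses `threading.Event` instead of a shared boolean flag; `start_monitor` and `sample_fn` are hypothetical names, and the sampling callback here only records timestamps so the example runs without `psutil` or `GPUtil`.

```python
import threading
import time

def start_monitor(sample_fn, interval_s=10.0):
    """Call sample_fn repeatedly in a daemon thread until the event is set."""
    stop_event = threading.Event()

    def _loop():
        while not stop_event.is_set():
            sample_fn()
            # wait() returns early when stop_event is set, so shutdown is prompt.
            stop_event.wait(interval_s)

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return stop_event, thread

# Usage: record a few timestamps, then stop the monitor.
samples = []
stop_event, thread = start_monitor(lambda: samples.append(time.time()), interval_s=0.01)
time.sleep(0.1)
stop_event.set()
thread.join(timeout=1)
```

In the notebook, `sample_fn` would be replaced by a function that prints CPU, memory, and GPU usage.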

In [None]:
# Check if running in Colab
import os  # needed in both branches below
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab, setting up environment...")
    
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Set project paths
    import os
    PROJECT_ROOT = '/content/WrapsRL_v5'
    DATA_DIR = f"{PROJECT_ROOT}/data"
    MODELS_DIR = f"{PROJECT_ROOT}/models"
    OUTPUTS_DIR = f"{PROJECT_ROOT}/outputs"
    
    # Create directories if they don't exist
    os.makedirs(PROJECT_ROOT, exist_ok=True)
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    
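    # Clone the StyleGAN2-ADA PyTorch repository into the current directory.
    # Later cells call its scripts via the relative path stylegan2-ada-pytorch/,
    # so the working directory is left unchanged.
    !git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git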
    
    # Install dependencies
    !pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3
    !pip install wandb  # For remote monitoring
    
    # Show the attached GPU; the A100 itself is selected in Colab's runtime settings
    !nvidia-smi
else:
    print("Not running in Colab. Please run this notebook in Google Colab with an A100 GPU for optimal performance.")
    PROJECT_ROOT = os.getcwd()
    DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
    MODELS_DIR = os.path.join(PROJECT_ROOT, 'models')
    OUTPUTS_DIR = os.path.join(PROJECT_ROOT, 'outputs')

In [None]:
# Set up system monitoring
if IN_COLAB:
    # Install and import libraries for monitoring
    !pip install psutil gputil
    import psutil
    import GPUtil
    import time
    from IPython.display import display, HTML
    import threading
    
    # Function to monitor system resources
    def monitor_resources():
        while monitoring_active:
            # CPU usage
            cpu_percent = psutil.cpu_percent(interval=1)
            
            # Memory usage
            memory = psutil.virtual_memory()
            memory_percent = memory.percent
            
            # GPU usage
            gpus = GPUtil.getGPUs()
            if gpus:
                gpu = gpus[0]
                gpu_name = gpu.name
                gpu_load = gpu.load * 100
                gpu_memory_used = gpu.memoryUsed
                gpu_memory_total = gpu.memoryTotal
                gpu_memory_percent = (gpu_memory_used / gpu_memory_total) * 100
                
                print(f"CPU: {cpu_percent:.1f}% | Memory: {memory_percent:.1f}% | "
                      f"GPU: {gpu_name} | GPU Load: {gpu_load:.1f}% | "
                      f"GPU Memory: {gpu_memory_used:.0f}/{gpu_memory_total:.0f} MB ({gpu_memory_percent:.1f}%)")
            else:
                print(f"CPU: {cpu_percent:.1f}% | Memory: {memory_percent:.1f}% | GPU: Not available")
                
            time.sleep(10)  # Update every 10 seconds
    
    # Start monitoring in a separate thread
    monitoring_active = True
    monitor_thread = threading.Thread(target=monitor_resources)
    monitor_thread.daemon = True
    monitor_thread.start()
    
    print("System monitoring started. Updates will appear every 10 seconds.")

In [None]:
# Initialize Weights & Biases for remote monitoring
import wandb

# Initialize wandb - you'll need to log in with your API key
!wandb login

# Initialize a new wandb project
wandb.init(
    project="WrapsRL-v5",
    name="stylegan2-ada-rl-decals",
    config={
        "architecture": "StyleGAN2-ADA",
        "dataset": "Rocket League Decals",
        "batch_size": 4,
        "augmentation": "ada",
        "fp32": True,
        "kimg": 25000,
        "image_size": 1024,
    }
)

## 2. Data Ingestion and Dataset Preparation

This section downloads the data from the provided Google Drive link and prepares it for training. Unlike the original TensorFlow release, StyleGAN2-ADA PyTorch does not read TFRecords; its `dataset_tool.py` packs images into a folder or ZIP archive. It includes:
- Downloading data from Google Drive
- Processing and preparing the dataset
- Converting to the StyleGAN2-ADA PyTorch dataset format
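
Before conversion, it helps to confirm that every image has a resolution the dataset tool is expected to accept. The sketch below assumes the tool requires square images with a power-of-two side length, which 1024×1024 satisfies; `is_valid_resolution` is a hypothetical helper name.

```python
def is_valid_resolution(width: int, height: int) -> bool:
    """Check for a square image whose side length is a power of two."""
    square = width == height
    # A positive integer is a power of two when it has exactly one bit set.
    power_of_two = width > 0 and (width & (width - 1)) == 0
    return square and power_of_two

print(is_valid_resolution(1024, 1024))
print(is_valid_resolution(1000, 1000))
```

Running such a check over the processed directory catches odd-sized files before the conversion step rejects them.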

In [None]:
# Download data from Google Drive
if IN_COLAB:
    # The provided Google Drive link
    DRIVE_LINK = "https://drive.google.com/drive/folders/1--hKMnum6Y6vmzkDVzLvE44eYjRswvXG?usp=drive_link"
    
    # Install gdown for downloading from Google Drive
    !pip install gdown
    
    # Extract folder ID from the link
    import re
    folder_id = re.search(r'folders/([^?]+)', DRIVE_LINK).group(1)
    
    # Download the folder contents (recent gdown releases take the folder URL directly)
    !gdown --folder https://drive.google.com/drive/folders/{folder_id} -O {DATA_DIR}/raw
    
    print(f"Downloaded data to {DATA_DIR}/raw")
    !ls -la {DATA_DIR}/raw

In [None]:
# Process and prepare the dataset
import os
import shutil
from PIL import Image
import numpy as np
from tqdm import tqdm

# Create processed data directory
PROCESSED_DIR = os.path.join(DATA_DIR, 'processed')
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Function to process images
def process_images(input_dir, output_dir, target_size=(1024, 1024)):
    os.makedirs(output_dir, exist_ok=True)
    
    # Get all image files
    image_files = [f for f in os.listdir(input_dir) 
                  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]
    
    print(f"Processing {len(image_files)} images...")
    
    for img_file in tqdm(image_files):
        try:
            # Open and process image
            img_path = os.path.join(input_dir, img_file)
            # dataset_tool.py expects RGB or grayscale input, so drop alpha/palette modes
            img = Image.open(img_path).convert('RGB')
            
            # Resize to target size if needed
            if img.size != target_size:
                img = img.resize(target_size, Image.LANCZOS)
            
            # Save as PNG
            output_path = os.path.join(output_dir, os.path.splitext(img_file)[0] + '.png')
            img.save(output_path, 'PNG')
            
        except Exception as e:
            print(f"Error processing {img_file}: {e}")
    
    print(f"Processed images saved to {output_dir}")

# Process the raw images
RAW_DIR = os.path.join(DATA_DIR, 'raw')
if os.path.exists(RAW_DIR):
    process_images(RAW_DIR, PROCESSED_DIR)
else:
    print(f"Raw data directory {RAW_DIR} not found. Please ensure data is downloaded correctly.")

In [None]:
# Convert processed images to StyleGAN2-ADA dataset format
if IN_COLAB and os.path.exists(PROCESSED_DIR):
    # Create dataset directory
    DATASET_DIR = os.path.join(DATA_DIR, 'dataset')
    os.makedirs(DATASET_DIR, exist_ok=True)
    
    # Use StyleGAN2-ADA PyTorch's dataset_tool.py, which takes --source and --dest
    # (the destination folder must not already contain files)
    !python stylegan2-ada-pytorch/dataset_tool.py \
        --source={PROCESSED_DIR} \
        --dest={DATASET_DIR}/rocket_league_decals
    
    print(f"Dataset created at {DATASET_DIR}/rocket_league_decals")
    
    # Generate dataset statistics
    import json
    
    # Count images
    num_images = len([f for f in os.listdir(PROCESSED_DIR) 
                     if f.lower().endswith('.png')])
    
    # Calculate average file size
    total_size = sum(os.path.getsize(os.path.join(PROCESSED_DIR, f)) 
                     for f in os.listdir(PROCESSED_DIR) 
                     if f.lower().endswith('.png'))
    avg_size = total_size / num_images if num_images > 0 else 0
    
    # Save statistics
    stats = {
        "num_images": num_images,
        "image_size": "1024x1024",
        "total_size_mb": total_size / (1024 * 1024),
        "avg_size_kb": avg_size / 1024
    }
    
    with open(os.path.join(DATA_DIR, 'dataset_stats.json'), 'w') as f:
        json.dump(stats, f, indent=4)
    
    print("Dataset statistics:")
    print(json.dumps(stats, indent=4))
    
    # Log to wandb
    wandb.log({"dataset_stats": stats})

## 3. Training with StyleGAN2-ADA PyTorch

This section sets up and runs the training process using StyleGAN2-ADA PyTorch. It includes:
- Selecting an appropriate pretrained model as a starting point
- Configuring training parameters
- Running the training process
- Monitoring training progress with wandb
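
To get a feel for the training budget, the `kimg` setting can be converted into a number of training iterations. The sketch below assumes that `kimg` counts thousands of real images shown to the discriminator and that each iteration consumes one batch; `training_iterations` is a hypothetical helper name.

```python
def training_iterations(kimg: int, batch_size: int) -> int:
    """Convert a kimg budget into a count of batches processed."""
    total_images = kimg * 1000
    return total_images // batch_size

# Values used in this notebook: 25000 kimg at batch size 4.
print(training_iterations(25000, 4))  # 6250000
```

At batch size 4, a 25000 kimg run is therefore a long job, and it is worth checking early snapshots before committing to the full budget.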

In [None]:
# Download pretrained models
if IN_COLAB:
    # Create models directory
    os.makedirs(MODELS_DIR, exist_ok=True)
    
    # Download recommended pretrained models
    pretrained_models = {
        "ffhq": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl",
        "metfaces": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl",
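        # afhqv2 checkpoints are distributed with StyleGAN3; this repository's AFHQ models are afhqcat/afhqdog/afhqwild
        "afhqcat": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl",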
        "brecahad": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/brecahad.pkl"
    }
    
    for name, url in pretrained_models.items():
        output_path = os.path.join(MODELS_DIR, f"{name}.pkl")
        !wget {url} -O {output_path}
        print(f"Downloaded {name} model to {output_path}")
    
    # List available models
    print("\nAvailable pretrained models:")
    !ls -la {MODELS_DIR}

### Pretrained Model Selection

We consider the following pretrained models as starting points for Rocket League decal textures:

1. **BreCaHAD (512×512)**: This model was trained on a medical imaging dataset. The hypothesis is that the structure it has learned transfers to the constraints of texture mapping in Rocket League decals. Its resolution is half the 1024×1024 target, so this checkpoint cannot fully initialize a 1024×1024 network.

2. **FFHQ (1024×1024)**: Although trained on faces, this model has learned fine detail that could transfer to decal generation, and its resolution matches our 1024×1024 target.

3. **MetFaces (1024×1024)**: This model was trained by transfer learning from FFHQ with ADA, which shows that ADA works in limited-data settings.

This notebook uses **BreCaHAD** as its primary starting point, on the untested assumption that it handles the structural constraints of decal textures well. If resuming from it fails because of the resolution mismatch, switch `PRETRAINED_MODEL` to `ffhq.pkl` or `metfaces.pkl`. The other models are worth comparing in any case.
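
The resolution comparison can be made explicit with a small lookup. The sketch below is illustrative only: the table copies the resolutions quoted in the list above, and `matching_checkpoints` is a hypothetical helper, not part of StyleGAN2-ADA.

```python
# Resolutions as quoted in the list above.
CHECKPOINT_RESOLUTION = {"brecahad": 512, "ffhq": 1024, "metfaces": 1024}

def matching_checkpoints(target: int, table: dict) -> list:
    """Return the checkpoint names whose resolution equals the target."""
    return sorted(name for name, res in table.items() if res == target)

print(matching_checkpoints(1024, CHECKPOINT_RESOLUTION))  # ['ffhq', 'metfaces']
```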

In [None]:
# Set up training parameters
if IN_COLAB:
    # Select pretrained model to use
    PRETRAINED_MODEL = os.path.join(MODELS_DIR, "brecahad.pkl")
    
    # Set up training parameters
    DATASET_PATH = os.path.join(DATA_DIR, 'dataset/rocket_league_decals')
    OUTPUT_DIR = os.path.join(OUTPUTS_DIR, 'training-runs')
    
    # Training parameters as specified in requirements
    GPUS = 1
    BATCH_SIZE = 4
    AUG = "ada"
    KIMG = 25000
    FP32 = True
    
    # Create a custom training script with wandb integration
    with open('train_with_wandb.py', 'w') as f:
        f.write("""
import os
import sys
import json
import wandb
import subprocess
import time

# Function to parse training stats
def parse_training_stats(stats_file):
    if not os.path.exists(stats_file):
        return None
    
    with open(stats_file, 'r') as f:
        lines = f.readlines()
    
    if not lines:
        return None
    
    # Parse the latest stats; a partially written last line is treated as missing
    try:
        latest_stats = json.loads(lines[-1])
        return latest_stats
    except json.JSONDecodeError:
        return None

# Start the training process
cmd = [
    'python', 'stylegan2-ada-pytorch/train.py',
    f'--outdir={sys.argv[1]}',
    f'--data={sys.argv[2]}',
    f'--gpus={sys.argv[3]}',
    f'--batch={sys.argv[4]}',
    f'--aug={sys.argv[5]}',
    f'--resume={sys.argv[6]}',
    f'--kimg={sys.argv[7]}',
]

# train.py declares --fp32 as a boolean-valued option, so pass an explicit value
if sys.argv[8] == 'True':
    cmd.append('--fp32=true')

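# This script runs in its own process, so it needs its own wandb run before logging
wandb.init(project="WrapsRL-v5", name="stylegan2-ada-rl-decals-training")

# Start the process
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)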

# Monitor the output directory for training stats
output_dir = sys.argv[1]
last_log_time = 0

while process.poll() is None:
    # Stream training output; readline() blocks until the next line arrives,
    # so no extra sleep is needed in this loop
    output = process.stdout.readline()
    if output:
        print(output.strip())
    
    # Check for training stats every 60 seconds
    current_time = time.time()
    if current_time - last_log_time > 60:
        # Look for the training_stats.jsonl file in subdirectories
        for root, dirs, files in os.walk(output_dir):
            if 'training_stats.jsonl' in files:
                stats_file = os.path.join(root, 'training_stats.jsonl')
                stats = parse_training_stats(stats_file)
                
                if stats:
                    # Entries are assumed to be keyed by names such as 'Progress/kimg'
                    # or 'Loss/G/loss', each holding a dict with a 'mean' field
                    wandb_stats = {}
                    for key, value in stats.items():
                        if key.startswith(('Progress/', 'Loss/', 'Metrics/')):
                            wandb_stats[key] = value['mean'] if isinstance(value, dict) else value
                    
                    if wandb_stats:
                        wandb.log(wandb_stats)
                    
                    # Log the most recently written image grid (fakes_init.png, fakes000000.png, ...)
                    fakes = [os.path.join(root, f) for f in files
                             if f.startswith('fakes') and f.endswith('.png')]
                    if fakes:
                        latest_fakes = max(fakes, key=os.path.getmtime)
                        wandb.log({"generated_samples": wandb.Image(latest_fakes)})
                
                last_log_time = current_time
                break

# Process any remaining output
for output in process.stdout:
    print(output.strip())

# Final status
exit_code = process.wait()
print(f"Training process exited with code {exit_code}")
        """)
    
    print("Created custom training script with wandb integration")

In [None]:
# Run the training process
if IN_COLAB and os.path.exists(DATASET_PATH) and os.path.exists(PRETRAINED_MODEL):
    print("Starting StyleGAN2-ADA training...")
    
    # Run the custom training script
    !python train_with_wandb.py \
        {OUTPUT_DIR} \
        {DATASET_PATH} \
        {GPUS} \
        {BATCH_SIZE} \
        {AUG} \
        {PRETRAINED_MODEL} \
        {KIMG} \
        {str(FP32)}
    
    print("Training complete!")
else:
    print("Cannot start training. Please ensure dataset and pretrained model are available.")

## 4. Evaluation

This section evaluates the trained model using FID and precision/recall metrics. It includes:
- Computing FID score
- Computing precision and recall
- Visualizing the results
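
The metric files are read as JSON Lines, taking the last line as the most recent result. The sketch below shows that parsing step in isolation; `last_metric_results` is a hypothetical helper, and the example text uses made-up values purely to demonstrate the format this notebook assumes (a `results` object per line).

```python
import json

def last_metric_results(jsonl_text: str) -> dict:
    """Return the 'results' object from the last non-empty JSONL line."""
    lines = [line for line in jsonl_text.splitlines() if line.strip()]
    if not lines:
        return {}
    return json.loads(lines[-1]).get("results", {})

# Made-up example content, not real measurements.
example = (
    '{"results": {"fid50k_full": 12.5}, "metric": "fid50k_full"}\n'
    '{"results": {"fid50k_full": 9.75}, "metric": "fid50k_full"}\n'
)
print(last_metric_results(example))
```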

In [None]:
# Evaluate the trained model
if IN_COLAB:
    # Find the latest network pickle
    import glob
    import json  # used below to parse the metric files
    
    # Look for the latest network pickle in the output directory
    network_pickles = glob.glob(f"{OUTPUT_DIR}/**/network-snapshot-*.pkl", recursive=True)
    if network_pickles:
        # Sort by modification time (newest first)
        network_pickles.sort(key=os.path.getmtime, reverse=True)
        latest_pickle = network_pickles[0]
        print(f"Found latest network pickle: {latest_pickle}")
        
        # Compute metrics
        print("Computing FID and precision/recall metrics...")
        !python stylegan2-ada-pytorch/calc_metrics.py \
            --metrics=fid50k_full,pr50k3_full \
            --network={latest_pickle} \
            --data={DATASET_PATH} \
            --gpus={GPUS}
        
        # Parse and log metrics
        metrics_file = os.path.join(os.path.dirname(latest_pickle), 'metric-fid50k_full.jsonl')
        pr_metrics_file = os.path.join(os.path.dirname(latest_pickle), 'metric-pr50k3_full.jsonl')
        
        if os.path.exists(metrics_file):
            with open(metrics_file, 'r') as f:
                lines = f.readlines()
                if lines:
                    fid_metric = json.loads(lines[-1])
                    print(f"FID score: {fid_metric['results']['fid50k_full']}")
                    wandb.log({"final_fid": fid_metric['results']['fid50k_full']})
        
        if os.path.exists(pr_metrics_file):
            with open(pr_metrics_file, 'r') as f:
                lines = f.readlines()
                if lines:
                    pr_metric = json.loads(lines[-1])
                    precision = pr_metric['results']['pr50k3_full_precision']
                    recall = pr_metric['results']['pr50k3_full_recall']
                    print(f"Precision: {precision}, Recall: {recall}")
                    wandb.log({"final_precision": precision, "final_recall": recall})
    else:
        print("No trained network pickles found.")

## 5. Sample Generation

This section generates sample 1024×1024 decal textures using the trained model. It includes:
- Generating random samples
- Visualizing the generated textures
- Saving the results
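
Sample generation in this notebook uses a truncation value of 0.7. As a rough illustration of what truncation does, the sketch below pulls a latent vector toward the average latent: a value of 1 leaves it unchanged and 0 collapses it to the average. The function name `truncate` and the toy vectors are assumptions for the example; the real generator applies this inside its mapping network.

```python
import numpy as np

def truncate(w: np.ndarray, w_avg: np.ndarray, psi: float = 0.7) -> np.ndarray:
    """Interpolate a latent toward the average latent by factor psi."""
    return w_avg + psi * (w - w_avg)

# Toy 4-dimensional latents, not taken from a trained model.
w_avg = np.zeros(4)
w = np.ones(4)
print(truncate(w, w_avg, psi=0.7))
```

Lower values trade diversity for samples closer to the "typical" decal, which may help keep blank regions intact.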

In [None]:
# Generate samples
if IN_COLAB and 'latest_pickle' in locals() and os.path.exists(latest_pickle):
    # Create samples directory
    SAMPLES_DIR = os.path.join(OUTPUTS_DIR, 'samples')
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    
    # Generate random samples
    print("Generating random samples...")
    !python stylegan2-ada-pytorch/generate.py \
        --outdir={SAMPLES_DIR} \
        --trunc=0.7 \
        --seeds=0-9 \
        --network={latest_pickle}
    
    # Display the generated samples
    # Import only display so PIL's Image, imported earlier, is not shadowed
    from IPython.display import display
    import matplotlib.pyplot as plt
    
    sample_files = sorted(glob.glob(f"{SAMPLES_DIR}/seed*.png"))
    
    if sample_files:
        plt.figure(figsize=(20, 10))
        for i, sample_file in enumerate(sample_files[:10]):
            plt.subplot(2, 5, i+1)
            img = plt.imread(sample_file)
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"Sample {i+1}")
        plt.tight_layout()
        plt.savefig(os.path.join(SAMPLES_DIR, 'samples_grid.png'))
        plt.show()
        
        # Log samples to wandb
        wandb.log({"generated_samples_grid": wandb.Image(os.path.join(SAMPLES_DIR, 'samples_grid.png'))})
        
        # Log individual samples
        for i, sample_file in enumerate(sample_files[:10]):
            wandb.log({f"sample_{i+1}": wandb.Image(sample_file)})
    else:
        print("No samples generated.")
else:
    print("Cannot generate samples. Please ensure trained model is available.")

## 6. Summary and Reporting

This section generates a markdown summary of the project, including data statistics, training curves, and final metric scores.
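
Each part of the report follows the same pattern: a heading followed by a bullet list of values. A minimal sketch of that pattern is shown below; `markdown_section` is a hypothetical helper and is not used by the cell that follows.

```python
def markdown_section(title: str, fields: dict) -> str:
    """Render a level-2 heading followed by one bullet per field."""
    lines = [f"## {title}", ""]
    lines += [f"- {key}: {value}" for key, value in fields.items()]
    return "\n".join(lines) + "\n"

print(markdown_section("Training Configuration", {"GPUs": 1, "Batch size": 4}))
```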

In [None]:
# Generate summary report
if IN_COLAB:
    # Create report directory
    REPORT_DIR = os.path.join(OUTPUTS_DIR, 'report')
    os.makedirs(REPORT_DIR, exist_ok=True)
    
    # Generate markdown report
    report_path = os.path.join(REPORT_DIR, 'summary_report.md')
    
    with open(report_path, 'w') as f:
        f.write("# WrapsRL v5 Project Summary Report\n\n")
        
        # Dataset statistics
        f.write("## Dataset Statistics\n\n")
        stats_file = os.path.join(DATA_DIR, 'dataset_stats.json')
        if os.path.exists(stats_file):
            with open(stats_file, 'r') as sf:
                stats = json.load(sf)
                f.write(f"- Number of images: {stats['num_images']}\n")
                f.write(f"- Image size: {stats['image_size']}\n")
                f.write(f"- Total dataset size: {stats['total_size_mb']:.2f} MB\n")
                f.write(f"- Average image size: {stats['avg_size_kb']:.2f} KB\n\n")
        else:
            f.write("Dataset statistics not available.\n\n")
        
        # Training configuration
        f.write("## Training Configuration\n\n")
        f.write(f"- Pretrained model: {os.path.basename(PRETRAINED_MODEL)}\n")
        f.write(f"- GPUs: {GPUS}\n")
        f.write(f"- Batch size: {BATCH_SIZE}\n")
        f.write(f"- Augmentation: {AUG}\n")
        f.write(f"- Training duration: {KIMG} kimgs\n")
        f.write(f"- FP32: {FP32}\n\n")
        
        # Evaluation metrics
        f.write("## Evaluation Metrics\n\n")
        if 'latest_pickle' in locals() and os.path.exists(latest_pickle):
            metrics_file = os.path.join(os.path.dirname(latest_pickle), 'metric-fid50k_full.jsonl')
            pr_metrics_file = os.path.join(os.path.dirname(latest_pickle), 'metric-pr50k3_full.jsonl')
            
            if os.path.exists(metrics_file):
                with open(metrics_file, 'r') as mf:
                    lines = mf.readlines()
                    if lines:
                        fid_metric = json.loads(lines[-1])
                        f.write(f"- FID score: {fid_metric['results']['fid50k_full']}\n")
            
            if os.path.exists(pr_metrics_file):
                with open(pr_metrics_file, 'r') as pf:
                    lines = pf.readlines()
                    if lines:
                        pr_metric = json.loads(lines[-1])
                        precision = pr_metric['results']['pr50k3_full_precision']
                        recall = pr_metric['results']['pr50k3_full_recall']
                        f.write(f"- Precision: {precision}\n")
                        f.write(f"- Recall: {recall}\n\n")
        else:
            f.write("Evaluation metrics not available.\n\n")
        
        # Sample images
        f.write("## Sample Images\n\n")
        # Build the path from OUTPUTS_DIR so this works even if the sample cell was skipped
        samples_grid = os.path.join(OUTPUTS_DIR, 'samples', 'samples_grid.png')
        if os.path.exists(samples_grid):
            f.write(f"![Generated Samples]({samples_grid})\n\n")
        else:
            f.write("Sample images not available.\n\n")
        
        # Wandb link
        f.write("## Remote Monitoring\n\n")
        f.write(f"Training progress and metrics can be monitored remotely at: {wandb.run.get_url()}\n")
    
    print(f"Summary report generated at {report_path}")
    
    # Display the report
    with open(report_path, 'r') as f:
        report_content = f.read()
    
    from IPython.display import Markdown
    display(Markdown(report_content))
    
    # Log report to wandb
    wandb.save(report_path)

In [None]:
# Clean up and finish
if IN_COLAB:
    # Stop monitoring
    if 'monitoring_active' in globals():
        monitoring_active = False
        if 'monitor_thread' in globals() and monitor_thread.is_alive():
            monitor_thread.join(timeout=1)
    
    # Finish wandb run
    wandb.finish()
    
    print("WrapsRL v5 project completed successfully!")