# Image-GS Quick Start

Minimal setup for an NVIDIA RTX 4090 (PyTorch with CUDA 12.1). Tested on runpod.io.

## Usage
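1. Copy this notebook to your workspace (e.g., `/workspace/quick-start.ipynb`).
2. Run Steps 1–10 to install dependencies, clone the repository, and create the `input/` and `output/` folders.
3. Place your images in `/workspace/input/`.
4. Set `INPUT_FILENAME` and the training parameters in Step 11, then run Steps 11–13 to train and view the results.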

## Workspace Structure
```
/workspace/
├── quick-start.ipynb  (this notebook)
├── input/            (your images)
├── output/           (results)
└── image-gs/         (repository)
```

## Step 1: Install System Dependencies

In [None]:
import sys
import subprocess

print("Installing system dependencies...\n")

commands = [
    "apt-get update -qq",
    "apt-get install -y -qq build-essential git wget curl",
]

# Output is captured to keep the cell quiet, so the return codes must be checked
# explicitly; otherwise a failed apt-get would still print the success message.
failed = []
for cmd in commands:
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        failed.append((cmd, result.returncode, result.stderr.strip()))

if failed:
    for cmd, code, err in failed:
        print(f"⚠️  '{cmd}' failed (exit code {code})")
        if err:
            # Last stderr line is usually the actual apt error.
            print(f"   {err.splitlines()[-1]}")
else:
    print("✓ System dependencies installed")

## Step 2: Upgrade pip

In [None]:
# Upgrade packaging tools for the kernel's own interpreter (sys.executable) before
# the later steps build packages from source (fused-ssim, gsplat).
!{sys.executable} -m pip install --upgrade pip setuptools wheel -q

## Step 3: Install PyTorch (CUDA 12.1 for RTX 4090)

In [None]:
print("Installing PyTorch 2.4.1 with CUDA 12.1...\n")

# Pinned torch / torchvision / torchaudio versions from the CUDA 12.1 wheel index.
!{sys.executable} -m pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

# Sanity check that the installed build can see the GPU.
# NOTE: if torch was already imported in this kernel, the version printed here is
# the previously loaded one until the kernel is restarted.
import torch
print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## Step 4: Install Python Dependencies

In [None]:
print("Installing Python dependencies...\n")

dependencies = [
    "flip-evaluator",
    "lpips==0.1.4",
    "matplotlib==3.9.2",
    "numpy<2.1",
    "opencv-python==4.12.0.88",
    "pytorch-msssim==1.0.0",
    "scikit-image==0.24.0",
    "scipy==1.13.1",
    "torchmetrics==1.5.2",
    "jaxtyping",
    "rich>=12",
    "pyyaml==6.0",
    "ninja",
]

# Install one package at a time so a failure can be attributed to a specific
# package. Output is captured, so check each return code instead of assuming success.
failed_deps = []
for dep in dependencies:
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", dep, "-q"],
        capture_output=True,
        text=True,
        errors="replace",
    )
    if result.returncode != 0:
        failed_deps.append((dep, result.stderr.strip()))

if failed_deps:
    print(f"⚠️  {len(failed_deps)} package(s) failed to install:")
    for dep, err in failed_deps:
        last_line = err.splitlines()[-1] if err else "no error output"
        print(f"  {dep}: {last_line}")
else:
    print("✓ Python dependencies installed")

## Step 5: Clone Repository

In [None]:
import os

# Determine root workspace directory (where this notebook is located)
# On runpod.io this would be /workspace
ROOT_WORKSPACE = os.getcwd()

# Repository will be cloned here
REPO_DIR = os.path.join(ROOT_WORKSPACE, "image-gs")

if os.path.exists(REPO_DIR):
    print(f"Repository already exists at {REPO_DIR}")
    !cd {REPO_DIR} && git pull
else:
    print(f"Cloning repository to {REPO_DIR}...\n")
    !git clone https://github.com/NYU-ICL/image-gs.git {REPO_DIR}

print(f"\n✓ Repository: {REPO_DIR}")

## Step 6: Install fused-ssim

In [None]:
print("Installing fused-ssim...\n")

# Built from the GitHub source; --no-build-isolation makes the build use the
# packages already installed in this environment (notably torch from Step 3)
# instead of an isolated build environment.
!{sys.executable} -m pip install git+https://github.com/rahul-goel/fused-ssim.git --no-build-isolation -q

print("✓ fused-ssim installed")

## Step 7: Install gsplat (regular, non-editable install)

In [None]:
print("Installing gsplat CUDA extension...\n")
print("This will take 5-10 minutes.\n")

gsplat_dir = os.path.join(REPO_DIR, "gsplat")
os.chdir(gsplat_dir)

# Uninstall any existing installation
!{sys.executable} -m pip uninstall -y gsplat -q
!{sys.executable} -m pip cache purge -q

# Regular install (not editable) - this is the fix from step 9
!{sys.executable} -m pip install . --no-build-isolation

print("\n✓ gsplat installed")

## Step 8: Verify Installation

In [None]:
print("Verifying installation...\n")

# Drop cached gsplat modules so the import below picks up the build from Step 7.
# Only the package and its submodules are matched, not any module whose name
# merely contains "gsplat".
# NOTE: a compiled extension that is already loaded cannot be unloaded from a
# running process; restart the kernel if gsplat was imported before reinstalling.
for mod in [name for name in sys.modules if name == "gsplat" or name.startswith("gsplat.")]:
    del sys.modules[mod]

# Each check records its failure instead of stopping, so all problems are listed together.
errors = []

try:
    import torch
    if not torch.cuda.is_available():
        # Explicit message: a bare `assert` would produce an empty error string.
        raise RuntimeError("CUDA is not available")
    print("✓ PyTorch with CUDA")
except Exception as e:
    errors.append(f"PyTorch: {e}")

try:
    from fused_ssim import fused_ssim
    print("✓ fused_ssim")
except Exception as e:
    errors.append(f"fused_ssim: {e}")

try:
    from gsplat import (
        project_gaussians_2d_scale_rot,
        rasterize_gaussians_no_tiles,
        rasterize_gaussians_sum,
    )
    print("✓ gsplat CUDA extensions")
except Exception as e:
    errors.append(f"gsplat: {e}")

try:
    os.chdir(REPO_DIR)
    # Avoid stacking duplicate entries when this cell is re-run.
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)
    from model import GaussianSplatting2D
    from utils.misc_utils import load_cfg
    print("✓ Image-GS modules")
except Exception as e:
    errors.append(f"Image-GS: {e}")

if errors:
    print(f"\n⚠️  {len(errors)} error(s):")
    for err in errors:
        print(f"  {err}")
else:
    print("\n✅ All components verified!")

## Step 9: Setup Workspace

Creates the `input/` and `output/` directories in the workspace root.

In [None]:
import glob
import shutil
import numpy as np
from datetime import datetime
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

# Use matplotlib's built-in text rendering (no external LaTeX installation needed).
matplotlib.rcParams['text.usetex'] = False

# input/ and output/ sit at the workspace root, next to the cloned repository.
INPUT_DIR, OUTPUT_DIR = (os.path.join(ROOT_WORKSPACE, sub) for sub in ("input", "output"))
for workspace_dir in (INPUT_DIR, OUTPUT_DIR):
    os.makedirs(workspace_dir, exist_ok=True)

workspace_tree = [
    "✅ Workspace structure created:",
    f"   📁 {ROOT_WORKSPACE}/",
    "      ├── quick-start.ipynb",
    "      ├── input/   (place your images here)",
    "      ├── output/  (results will be saved here)",
    "      └── image-gs/  (repository)",
]
for tree_line in workspace_tree:
    print(tree_line)

## Step 10: Function Definitions

Helper functions for training, viewing, and analyzing results.

In [None]:
def get_paths():
    """Return ``(ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR)``.

    Any of the four module-level path globals that are not yet defined are
    filled in with defaults: ROOT_WORKSPACE falls back to the current working
    directory and the others are ``image-gs``, ``input`` and ``output`` beneath
    it. Values that already exist are left untouched. INPUT_DIR and OUTPUT_DIR
    are created on disk if missing; REPO_DIR is not.
    """
    namespace = globals()

    # The root must be settled first because the other defaults derive from it.
    if 'ROOT_WORKSPACE' not in namespace:
        namespace['ROOT_WORKSPACE'] = os.getcwd()
    root = namespace['ROOT_WORKSPACE']

    fallback_subdirs = (
        ('REPO_DIR', "image-gs"),
        ('INPUT_DIR', "input"),
        ('OUTPUT_DIR', "output"),
    )
    for var_name, subdir in fallback_subdirs:
        if var_name not in namespace:
            namespace[var_name] = os.path.join(root, subdir)

    for var_name in ('INPUT_DIR', 'OUTPUT_DIR'):
        os.makedirs(namespace[var_name], exist_ok=True)

    return (
        namespace['ROOT_WORKSPACE'],
        namespace['REPO_DIR'],
        namespace['INPUT_DIR'],
        namespace['OUTPUT_DIR'],
    )

In [None]:
def _setup_training_environment(input_filename, num_gaussians, max_steps):
    """Setup training environment: validate input, create directories, copy files."""
    ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR = get_paths()
    
    # Change to repository directory
    os.chdir(REPO_DIR)
    
    # Validate input
    input_path = os.path.join(INPUT_DIR, input_filename)
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input image not found: {input_path}")
    
    # Create timestamped output folder
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_name = os.path.splitext(input_filename)[0]
    output_folder = f"{base_name}-{num_gaussians}-{max_steps}-{timestamp}"
    output_path = os.path.join(OUTPUT_DIR, output_folder)
    
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "other"), exist_ok=True)
    
    # Copy input to media/images/
    media_input_path = os.path.join(REPO_DIR, "media", "images", input_filename)
    os.makedirs(os.path.join(REPO_DIR, "media", "images"), exist_ok=True)
    shutil.copy2(input_path, media_input_path)
    
    return output_folder, output_path

In [None]:
def _build_training_command(input_filename, output_folder, num_gaussians, max_steps, use_progressive):
    """Build the training command string."""
    prog_flag = "" if use_progressive else "--disable_prog_optim"
    temp_exp_name = f"temp/{output_folder}"
    
    cmd = f"""
    {sys.executable} main.py \
      --input_path="images/{input_filename}" \
      --exp_name="{temp_exp_name}" \
      --num_gaussians={num_gaussians} \
      --max_steps={max_steps} \
      --quantize \
      {prog_flag} \
      --device="cuda:0"
    """
    
    return cmd

In [None]:
def _run_training(cmd, input_filename, num_gaussians, max_steps, use_progressive, output_folder):
    """Print training info and execute training command."""
    ROOT_WORKSPACE, _, _, _ = get_paths()
    
    print("=" * 80)
    print(f"🚀 TRAINING: {input_filename}")
    print("=" * 80)
    print(f"Gaussians:   {num_gaussians}")
    print(f"Steps:       {max_steps}")
    print(f"Progressive: {use_progressive}")
    print(f"Output:      output/{output_folder}/")
    print(f"Time est:    ~{max_steps * 0.002:.1f}-{max_steps * 0.005:.1f} minutes")
    print("=" * 80)
    print()
    
    os.system(cmd)
    
    print()
    print("=" * 80)
    print("✅ TRAINING COMPLETE")
    print("=" * 80)
    print(f"📁 Output folder: output/{output_folder}/")
    print()

In [None]:
def _organize_training_outputs(output_folder, output_path):
    """Copy training outputs to organized output folder."""
    _, REPO_DIR, _, _ = get_paths()
    
    temp_exp_name = f"temp/{output_folder}"
    result_base = os.path.join(REPO_DIR, "results", temp_exp_name)
    run_dirs = [d for d in os.listdir(result_base) if os.path.isdir(os.path.join(result_base, d))]
    latest_run = sorted(run_dirs)[-1]
    result_dir = os.path.join(result_base, latest_run)
    
    # 1. Model checkpoint
    ckpt_dir = os.path.join(result_dir, "checkpoints")
    ckpt_files = glob.glob(os.path.join(ckpt_dir, "ckpt_step-*.pt"))
    if ckpt_files:
        latest_ckpt = sorted(ckpt_files)[-1]
        shutil.copy2(latest_ckpt, os.path.join(output_path, "model.pt"))
    
    # 2. Rendered image
    renders = glob.glob(os.path.join(result_dir, "render_res-*.jpg"))
    if renders:
        shutil.copy2(renders[0], os.path.join(output_path, "rendered.jpg"))
    
    # 3. Ground truth image
    gts = glob.glob(os.path.join(result_dir, "gt_res-*.jpg"))
    if gts:
        shutil.copy2(gts[0], os.path.join(output_path, "other", "ground_truth.jpg"))
    
    # 4. Training log
    log_file = os.path.join(result_dir, "log_train.txt")
    if os.path.exists(log_file):
        shutil.copy2(log_file, os.path.join(output_path, "other", "log_train.txt"))
    
    # 5. Metrics CSV
    metrics_csv = os.path.join(result_dir, "metrics.csv")
    if os.path.exists(metrics_csv):
        shutil.copy2(metrics_csv, os.path.join(output_path, "metrics.csv"))
    
    # 6. Copy all other files to "other" subdirectory
    for item in os.listdir(result_dir):
        item_path = os.path.join(result_dir, item)
        if os.path.isfile(item_path):
            # Skip files we already copied
            if item not in ["log_train.txt", "metrics.csv"] and not item.startswith("render_res-") and not item.startswith("gt_res-"):
                shutil.copy2(item_path, os.path.join(output_path, "other", item))

In [None]:
def train_image_gs(input_filename, num_gaussians, max_steps, use_progressive=True):
    """
    Train one Image-GS model and collect its outputs in output/.

    Runs four stages in order: prepare directories, build the command,
    execute training, and copy the results into the run's output folder.

    Args:
        input_filename: Image file name (e.g. "cat.png") located in input/
        num_gaussians: Number of Gaussians
        max_steps: Training steps
        use_progressive: Enable progressive optimization (default: True)

    Returns:
        Name of the timestamped folder created under output/
    """
    run_args = (input_filename, num_gaussians, max_steps)

    folder_name, folder_path = _setup_training_environment(*run_args)
    command = _build_training_command(input_filename, folder_name, num_gaussians, max_steps, use_progressive)
    _run_training(command, *run_args, use_progressive, folder_name)
    _organize_training_outputs(folder_name, folder_path)

    return folder_name

In [None]:
def _load_training_results(output_folder):
    """Load a run's images and file sizes, and parse its configuration.

    Args:
        output_folder: folder name inside OUTPUT_DIR, normally
            "<base>-<num_gaussians>-<max_steps>-<timestamp>".

    Returns:
        dict with output_path, output_folder, base_name, num_gaussians and
        max_steps ("?" when the name cannot be parsed), gt_img / render_img as
        float32 (H, W, 3) arrays scaled to [0, 1], and file sizes in bytes
        (model_size is None when model.pt is absent).

    Raises:
        FileNotFoundError: if the folder or either image is missing.
    """
    _, _, _, OUTPUT_DIR = get_paths()

    output_path = os.path.join(OUTPUT_DIR, output_folder)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Output folder not found: {output_path}")

    # Parse config from folder name; fall back to placeholders when the name does
    # not follow the scheme, instead of failing on int().
    base_name, num_gaussians, max_steps = output_folder, "?", "?"
    parts = output_folder.rsplit("-", 3)
    if len(parts) >= 3:
        try:
            num_gaussians, max_steps = int(parts[1]), int(parts[2])
            base_name = parts[0]
        except ValueError:
            pass

    # File paths
    model_path = os.path.join(output_path, "model.pt")
    rendered_path = os.path.join(output_path, "rendered.jpg")
    gt_path = os.path.join(output_path, "other", "ground_truth.jpg")

    def _load_rgb(path):
        # `with` closes the file handle. convert("RGB") yields three channels
        # even for grayscale JPEGs, because the metrics and summary code index
        # images as (H, W, C).
        with Image.open(path) as im:
            return np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0

    gt_img = _load_rgb(gt_path)
    render_img = _load_rgb(rendered_path)

    # File sizes
    model_size = os.path.getsize(model_path) if os.path.exists(model_path) else None
    gt_size = os.path.getsize(gt_path)
    render_size = os.path.getsize(rendered_path)

    return {
        'output_path': output_path,
        'output_folder': output_folder,
        'base_name': base_name,
        'num_gaussians': num_gaussians,
        'max_steps': max_steps,
        'gt_img': gt_img,
        'render_img': render_img,
        'model_size': model_size,
        'gt_size': gt_size,
        'render_size': render_size
    }

In [None]:
def _calculate_quality_metrics(gt_img, render_img):
    """Calculate quality metrics between ground truth and rendered images."""
    # Calculate difference
    diff = np.abs(gt_img - render_img)
    diff_gray = np.mean(diff, axis=2)
    
    # Basic statistics
    mean_diff = np.mean(diff_gray)
    max_diff = np.max(diff_gray)
    std_diff = np.std(diff_gray)
    
    # Pixel-level analysis
    pix_1pct = np.sum(diff_gray > 0.01) / diff_gray.size * 100
    pix_5pct = np.sum(diff_gray > 0.05) / diff_gray.size * 100
    pix_10pct = np.sum(diff_gray > 0.10) / diff_gray.size * 100
    
    return {
        'diff': diff,
        'diff_gray': diff_gray,
        'mean_diff': mean_diff,
        'max_diff': max_diff,
        'std_diff': std_diff,
        'pix_1pct': pix_1pct,
        'pix_5pct': pix_5pct,
        'pix_10pct': pix_10pct
    }

In [None]:
def _create_results_visualization(gt_img, render_img, diff_gray, mean_diff):
    """Create 3-panel comparison visualization."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    axes[0].imshow(gt_img)
    axes[0].set_title("Ground Truth", fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    axes[1].imshow(render_img)
    axes[1].set_title("2D Gaussians (Rendered)", fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    im = axes[2].imshow(diff_gray, cmap='hot', vmin=0, vmax=0.2)
    axes[2].set_title(f"Difference Map\nMean: {mean_diff:.4f} ({mean_diff*100:.2f}%)", fontsize=14, fontweight='bold')
    axes[2].axis('off')
    
    cbar = plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
    cbar.set_label('Absolute Difference', rotation=270, labelpad=20)
    
    plt.tight_layout()
    
    return fig

In [None]:
def _format_summary_text(results_data, metrics_data):
    """Format summary text with tables."""
    # Helper functions
    def fmt_size(size_bytes):
        if size_bytes is None:
            return "N/A"
        if size_bytes < 1024:
            return f"{size_bytes} B"
        elif size_bytes < 1024 * 1024:
            return f"{size_bytes / 1024:.2f} KB"
        else:
            return f"{size_bytes / (1024 * 1024):.2f} MB"
    
    def fmt_ratio(num, denom):
        if num is None or denom is None:
            return "N/A"
        return f"{num / denom:.2f}x"
    
    # Extract data
    output_folder = results_data['output_folder']
    base_name = results_data['base_name']
    num_gaussians = results_data['num_gaussians']
    max_steps = results_data['max_steps']
    gt_img = results_data['gt_img']
    model_size = results_data['model_size']
    gt_size = results_data['gt_size']
    render_size = results_data['render_size']
    
    mean_diff = metrics_data['mean_diff']
    max_diff = metrics_data['max_diff']
    std_diff = metrics_data['std_diff']
    pix_1pct = metrics_data['pix_1pct']
    pix_5pct = metrics_data['pix_5pct']
    pix_10pct = metrics_data['pix_10pct']
    
    # Image info
    height, width, channels = gt_img.shape
    total_pixels = width * height
    
    # Build summary
    summary_lines = []
    summary_lines.append("=" * 100)
    summary_lines.append("IMAGE-GS TRAINING SUMMARY")
    summary_lines.append("=" * 100)
    summary_lines.append("")
    
    # Table 1: Training Configuration
    summary_lines.append("TRAINING CONFIGURATION")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Output Folder':<25} {'Base Name':<20} {'Gaussians':<15} {'Steps':<15}")
    summary_lines.append(f"{output_folder:<25} {base_name:<20} {num_gaussians:<15} {max_steps:<15}")
    summary_lines.append("")
    
    # Table 2: Image & File Information
    summary_lines.append("IMAGE & FILE INFORMATION")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Metric':<30} {'Value':<30} {'Metric':<30} {'Value':<30}")
    summary_lines.append(f"{'Resolution':<30} {f'{width} x {height} px':<30} {'Total Pixels':<30} {f'{total_pixels:,}':<30}")
    summary_lines.append(f"{'Channels':<30} {channels:<30} {'Ground Truth Size':<30} {fmt_size(gt_size):<30}")
    summary_lines.append(f"{'Model Size':<30} {fmt_size(model_size):<30} {'Rendered Size':<30} {fmt_size(render_size):<30}")
    summary_lines.append("")
    
    # Table 3: Compression Analysis
    if model_size:
        uncompressed_size = total_pixels * channels
        bpp = (model_size * 8) / total_pixels
        compression_vs_gt = gt_size / model_size
        compression_vs_raw = uncompressed_size / model_size
        
        summary_lines.append("COMPRESSION ANALYSIS")
        summary_lines.append("-" * 100)
        summary_lines.append(f"{'Metric':<40} {'Value':<30} {'Note':<30}")
        summary_lines.append(f"{'Compression vs Original (JPG)':<40} {fmt_ratio(gt_size, model_size):<30} {'Model is {:.1f}% of original'.format((model_size/gt_size)*100):<30}")
        summary_lines.append(f"{'Compression vs Raw (uncompressed)':<40} {fmt_ratio(uncompressed_size, model_size):<30} {f'{fmt_size(uncompressed_size)} -> {fmt_size(model_size)}':<30}")
        summary_lines.append(f"{'Bits Per Pixel (bpp)':<40} {f'{bpp:.4f} bpp':<30} {'Lower is better':<30}")
        summary_lines.append("")
    
    # Table 4: Quality Metrics
    summary_lines.append("QUALITY METRICS")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Metric':<30} {'Value':<30} {'Metric':<30} {'Value':<30}")
    summary_lines.append(f"{'Mean Difference':<30} {f'{mean_diff:.6f} ({mean_diff*100:.2f}%)':<30} {'Max Difference':<30} {f'{max_diff:.6f} ({max_diff*100:.2f}%)':<30}")
    summary_lines.append(f"{'Std Deviation':<30} {f'{std_diff:.6f}':<30} {'Pixels > 1% diff':<30} {f'{pix_1pct:.2f}%':<30}")
    summary_lines.append(f"{'Pixels > 5% diff':<30} {f'{pix_5pct:.2f}%':<30} {'Pixels > 10% diff':<30} {f'{pix_10pct:.2f}%':<30}")
    summary_lines.append("")
    
    # Files saved
    summary_lines.append("FILES SAVED")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'File':<25} {'Description':<75}")
    summary_lines.append(f"{'📄 summary.txt':<25} {'This summary (text format)':<75}")
    summary_lines.append(f"{'📊 summary.png':<25} {'Visual comparison (3-panel image)':<75}")
    summary_lines.append(f"{'📈 metrics.csv':<25} {'Training metrics over iterations (CSV format)':<75}")
    summary_lines.append(f"{'📉 metrics_plot.png':<25} {'Training metrics visualization (6-panel plot)':<75}")
    summary_lines.append(f"{'🧠 model.pt':<25} {'Trained 2D Gaussian model (PyTorch checkpoint)':<75}")
    summary_lines.append(f"{'🖼️  rendered.jpg':<25} {'Rendered output from model':<75}")
    summary_lines.append(f"{'📁 other/':<25} {'Training logs and additional files':<75}")
    summary_lines.append("=" * 100)
    
    return "\n".join(summary_lines)

In [None]:
def _save_summary_files(output_path, fig, summary_text):
    """Save visualization and summary text to files."""
    # Save visualization
    summary_img_path = os.path.join(output_path, "summary.png")
    fig.savefig(summary_img_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    # Save summary text
    summary_txt_path = os.path.join(output_path, "summary.txt")
    with open(summary_txt_path, 'w') as f:
        f.write(summary_text)
    
    print(summary_text)
    print()
    print(f"💾 Saved: summary.txt")
    print(f"💾 Saved: summary.png")
    print()
    print(f"📁 Full path: {output_path}")

In [None]:
def view_results(output_folder):
    """
    Analyze a finished training run and save/display its summary.

    Loads the run, computes difference metrics, builds the 3-panel
    comparison figure, and writes summary.png / summary.txt into the
    run's folder.

    Args:
        output_folder: Name of folder in output/ (e.g., "cat-10000-5000-20251027_143052")
    """
    run = _load_training_results(output_folder)
    gt, rendered = run['gt_img'], run['render_img']

    metrics = _calculate_quality_metrics(gt, rendered)
    figure = _create_results_visualization(gt, rendered, metrics['diff_gray'], metrics['mean_diff'])

    _save_summary_files(run['output_path'], figure, _format_summary_text(run, metrics))

In [None]:
def upscale_render(output_folder, render_height):
    """
    Render at higher resolution using trained model.

    Re-runs ``main.py --eval`` for the experiment behind ``output_folder`` with
    ``--render_height=render_height``, streams its output into the cell, and
    copies the first (sorted) ``eval/render_*.jpg`` of the last run directory
    to ``output/<output_folder>/rendered_<render_height>px.jpg``.

    Args:
        output_folder: Name of folder in output/ with trained model
            ("<base>-<gaussians>-<steps>-<timestamp>")
        render_height: Target height in pixels

    Raises:
        FileNotFoundError: output folder, original input image, or results
            directory is missing.
        ValueError: ``output_folder`` does not contain a Gaussian count.
        RuntimeError: the render command exits with a non-zero status.
    """
    import shlex
    import subprocess

    _, REPO_DIR, INPUT_DIR, OUTPUT_DIR = get_paths()
    # main.py is invoked by relative path, so run from the repository root.
    os.chdir(REPO_DIR)

    output_path = os.path.join(OUTPUT_DIR, output_folder)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Output folder not found: {output_path}")

    # Parse config: "<base>-<gaussians>-<steps>-<timestamp>"
    parts = output_folder.rsplit("-", 3)
    try:
        base_name = parts[0]
        num_gaussians = int(parts[1])
    except (IndexError, ValueError):
        raise ValueError(
            f"Cannot parse Gaussian count from output folder name: {output_folder}"
        ) from None

    # Find original input. Escape so "[", "*" or "?" in the path match literally.
    pattern = glob.escape(os.path.join(INPUT_DIR, base_name)) + ".*"
    input_files = sorted(glob.glob(pattern))
    if not input_files:
        raise FileNotFoundError(f"Original input not found for: {base_name}")
    input_filename = os.path.basename(input_files[0])

    # Ensure input is in media/images/ (create the directory if needed)
    media_dir = os.path.join(REPO_DIR, "media", "images")
    os.makedirs(media_dir, exist_ok=True)
    media_input_path = os.path.join(media_dir, input_filename)
    if not os.path.exists(media_input_path):
        shutil.copy2(input_files[0], media_input_path)

    # Build the upscale command; every argument is shell-quoted.
    temp_exp_name = f"temp/{output_folder}"
    args = [
        sys.executable,
        "main.py",
        f"--input_path=images/{input_filename}",
        f"--exp_name={temp_exp_name}",
        f"--num_gaussians={num_gaussians}",
        "--quantize",
        "--eval",
        f"--render_height={render_height}",
        "--device=cuda:0",
    ]
    cmd = " ".join(shlex.quote(arg) for arg in args)

    print("=" * 80)
    print(f"🔍 UPSCALE RENDER: {output_folder}")
    print("=" * 80)
    print(f"Target height: {render_height}px")
    print("=" * 80)
    print()

    # Stream combined stdout/stderr into the cell and fail loudly on error.
    with subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",
    ) as proc:
        for line in proc.stdout:
            print(line, end="", flush=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"Upscale render failed with exit code {proc.returncode}: {output_folder}"
        )

    # Copy upscaled render from the last run directory (sorted order)
    result_base = os.path.join(REPO_DIR, "results", temp_exp_name)
    run_dirs = []
    if os.path.isdir(result_base):
        run_dirs = sorted(
            d for d in os.listdir(result_base) if os.path.isdir(os.path.join(result_base, d))
        )
    if not run_dirs:
        raise FileNotFoundError(f"No results found in: {result_base}")
    eval_dir = os.path.join(result_base, run_dirs[-1], "eval")

    if os.path.exists(eval_dir):
        upscaled_renders = sorted(glob.glob(os.path.join(eval_dir, "render_*.jpg")))
        if upscaled_renders:
            upscaled_name = f"rendered_{render_height}px.jpg"
            shutil.copy2(upscaled_renders[0], os.path.join(output_path, upscaled_name))
            print()
            print("=" * 80)
            print("✅ UPSCALE COMPLETE")
            print("=" * 80)
            print(f"💾 Saved: {upscaled_name}")
            print(f"📁 Location: output/{output_folder}/")
            print("=" * 80)
        else:
            print("⚠️  Warning: Could not find upscaled render")
    else:
        print("⚠️  Warning: Eval directory not found")

In [None]:
def load_metrics_csv(output_folder):
    """
    Load training metrics from CSV file.

    Args:
        output_folder: Name of folder in output/ with training results

    Returns:
        pandas.DataFrame: Training metrics as stored in metrics.csv (expected
            columns: step, total_loss, l1_loss, l2_loss, ssim_loss, psnr, ssim,
            num_gaussians, num_bytes, render_time_accum, total_time_accum).
            The four loss columns are always present and numeric: empty
            cells become NaN, and a loss column missing from the file is
            added as all-NaN.

    Raises:
        FileNotFoundError: if the output folder or its metrics.csv is missing.
    """
    import pandas as pd

    _, _, _, OUTPUT_DIR = get_paths()

    output_path = os.path.join(OUTPUT_DIR, output_folder)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Output folder not found: {output_path}")

    metrics_csv_path = os.path.join(output_path, "metrics.csv")
    if not os.path.exists(metrics_csv_path):
        raise FileNotFoundError(f"Metrics CSV not found: {metrics_csv_path}")

    # Load CSV
    df = pd.read_csv(metrics_csv_path)

    # Coerce loss columns to numbers (empty strings -> NaN). Absent columns are
    # added as NaN so callers (e.g. the plotting code) can index them unconditionally.
    for col in ('total_loss', 'l1_loss', 'l2_loss', 'ssim_loss'):
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        else:
            df[col] = float("nan")

    return df

In [None]:
def plot_training_metrics(output_folder, save_plot=True, show_plot=True):
    """
    Plot comprehensive training metrics over iterations.

    Creates a 3x2 grid of panels showing:
    - Loss curves (total, plus any L1 / L2 / SSIM component present with data)
    - Quality metrics (PSNR, SSIM) with final values annotated
    - Model growth (size in KB, Gaussian count)
    - Accumulated total and render time

    Args:
        output_folder: Name of folder in output/ with training results
        save_plot: Save plot as metrics_plot.png in that folder (default: True)
        show_plot: Display plot inline (default: True); otherwise the figure is closed

    Raises:
        FileNotFoundError: if the folder or its metrics.csv is missing.
        ValueError: if metrics.csv contains no rows.
    """
    # Load metrics
    df = load_metrics_csv(output_folder)
    if df.empty:
        raise ValueError(f"No metrics rows found for: {output_folder}")

    _, _, _, OUTPUT_DIR = get_paths()
    output_path = os.path.join(OUTPUT_DIR, output_folder)

    # Parse "<base>-<gaussians>-<steps>-<timestamp>". Use "?" placeholders for
    # names that do not follow the scheme instead of failing on int().
    base_name, num_gaussians, max_steps = output_folder, "?", "?"
    parts = output_folder.rsplit("-", 3)
    if len(parts) >= 3:
        try:
            num_gaussians, max_steps = int(parts[1]), int(parts[2])
            base_name = parts[0]
        except ValueError:
            pass

    # Create figure with subplots
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.25)

    # Title
    fig.suptitle(f'Training Metrics: {base_name} (G={num_gaussians}, Steps={max_steps})',
                 fontsize=16, fontweight='bold', y=0.995)

    # 1. Loss curves (top left); component columns are optional in the CSV
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(df['step'], df['total_loss'], 'k-', linewidth=2, label='Total Loss')
    for col, label in (('l1_loss', 'L1 Loss'), ('l2_loss', 'L2 Loss'), ('ssim_loss', 'SSIM Loss')):
        if col in df.columns and df[col].notna().any():
            ax1.plot(df['step'], df[col], '--', alpha=0.7, label=label)
    ax1.set_xlabel('Training Step', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Loss Components', fontsize=14, fontweight='bold')
    ax1.legend(loc='best')
    ax1.grid(True, alpha=0.3)

    # 2. PSNR (top right)
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(df['step'], df['psnr'], 'b-', linewidth=2)
    ax2.set_xlabel('Training Step', fontsize=12)
    ax2.set_ylabel('PSNR (dB)', fontsize=12)
    ax2.set_title('Peak Signal-to-Noise Ratio', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    # Add final value annotation
    final_psnr = df['psnr'].iloc[-1]
    ax2.axhline(y=final_psnr, color='b', linestyle='--', alpha=0.3)
    ax2.text(0.98, 0.02, f'Final: {final_psnr:.2f} dB',
             transform=ax2.transAxes, ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # 3. SSIM (middle left)
    ax3 = fig.add_subplot(gs[1, 0])
    ax3.plot(df['step'], df['ssim'], 'g-', linewidth=2)
    ax3.set_xlabel('Training Step', fontsize=12)
    ax3.set_ylabel('SSIM', fontsize=12)
    ax3.set_title('Structural Similarity Index', fontsize=14, fontweight='bold')
    ax3.grid(True, alpha=0.3)
    # Add final value annotation
    final_ssim = df['ssim'].iloc[-1]
    ax3.axhline(y=final_ssim, color='g', linestyle='--', alpha=0.3)
    ax3.text(0.98, 0.02, f'Final: {final_ssim:.4f}',
             transform=ax3.transAxes, ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # 4. Model size (middle right), bytes converted to KB
    ax4 = fig.add_subplot(gs[1, 1])
    size_kb = df['num_bytes'] / 1024
    ax4.plot(df['step'], size_kb, 'r-', linewidth=2)
    ax4.set_xlabel('Training Step', fontsize=12)
    ax4.set_ylabel('Model Size (KB)', fontsize=12)
    ax4.set_title('Model Size Growth', fontsize=14, fontweight='bold')
    ax4.grid(True, alpha=0.3)
    # Add final value annotation
    final_size = size_kb.iloc[-1]
    ax4.text(0.98, 0.98, f'Final: {final_size:.2f} KB',
             transform=ax4.transAxes, ha='right', va='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # 5. Gaussian count (bottom left)
    ax5 = fig.add_subplot(gs[2, 0])
    ax5.plot(df['step'], df['num_gaussians'], 'm-', linewidth=2)
    ax5.set_xlabel('Training Step', fontsize=12)
    ax5.set_ylabel('Number of Gaussians', fontsize=12)
    ax5.set_title('Gaussian Count', fontsize=14, fontweight='bold')
    ax5.grid(True, alpha=0.3)
    # Add final value annotation
    final_gaussians = df['num_gaussians'].iloc[-1]
    ax5.text(0.98, 0.98, f'Final: {int(final_gaussians):,}',
             transform=ax5.transAxes, ha='right', va='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # 6. Timing (bottom right)
    ax6 = fig.add_subplot(gs[2, 1])
    ax6.plot(df['step'], df['total_time_accum'], 'c-', linewidth=2, label='Total Time')
    ax6.plot(df['step'], df['render_time_accum'], 'orange', linewidth=2, label='Render Time')
    ax6.set_xlabel('Training Step', fontsize=12)
    ax6.set_ylabel('Time (seconds)', fontsize=12)
    ax6.set_title('Accumulated Time', fontsize=14, fontweight='bold')
    ax6.legend(loc='best')
    ax6.grid(True, alpha=0.3)
    # Add final value annotation
    final_time = df['total_time_accum'].iloc[-1]
    ax6.text(0.98, 0.02, f'Total: {final_time:.1f}s ({final_time/60:.1f}m)',
             transform=ax6.transAxes, ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Save plot
    if save_plot:
        plot_path = os.path.join(output_path, "metrics_plot.png")
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        print("💾 Saved: metrics_plot.png")

    # Show plot, or release the figure when it is not displayed
    if show_plot:
        plt.show()
    else:
        plt.close(fig)

    if save_plot:
        print(f"📁 Location: {output_path}")

In [None]:
# Confirmation marker: under "Run All" this cell only runs if every definition
# cell above executed without raising.
print("✅ Functions loaded!")

## Step 11: Configuration

Set your input image and training parameters.

In [None]:
# Place your input image in: /workspace/input/
# Example: /workspace/input/cat.png

INPUT_FILENAME = "cat.png"  # Just the filename from input/

# Training parameters - use lists for batch training
# Single value: GAUSSIANS = [5000]
# Multiple values: GAUSSIANS = [500, 1000, 2000, 5000, 10000]
GAUSSIANS = [5000]          # Number of Gaussians (5k-30k recommended)
STEPS = [3500]              # Training steps (3k-10k recommended)
USE_PROGRESSIVE = True      # Enable progressive optimization (recommended)

# Resolve paths the same way the training helpers do, so the printed locations
# match the ones actually used.
_, _, config_input_dir, config_output_dir = get_paths()
input_image_path = os.path.join(config_input_dir, INPUT_FILENAME)

# Calculate total training runs
total_runs = len(GAUSSIANS) * len(STEPS)

print("=" * 80)
print("CONFIGURATION")
print("=" * 80)
print(f"Input:       {input_image_path}")
print(f"Output:      {config_output_dir}/")
print(f"Gaussians:   {GAUSSIANS}")
print(f"Steps:       {STEPS}")
print(f"Progressive: {USE_PROGRESSIVE}")
print(f"Total runs:  {total_runs} ({len(GAUSSIANS)}x{len(STEPS)} combinations)")
print("=" * 80)

# Surface configuration problems before the long training step starts.
if not os.path.exists(input_image_path):
    print(f"⚠️  Input image not found: {input_image_path} (add it before running Step 12)")
if total_runs == 0:
    print("⚠️  GAUSSIANS and STEPS must each contain at least one value")

## Step 12: Training

Train one model per (Gaussians, Steps) combination and save each run to its own timestamped folder in `output/`.

In [None]:
import itertools
from datetime import datetime

# Generate all (num_gaussians, max_steps) combinations
combinations = list(itertools.product(GAUSSIANS, STEPS))
total_combinations = len(combinations)

print("=" * 80)
print(f"BATCH TRAINING: {total_combinations} combination(s)")
print("=" * 80)
print()

# Output folders of successful runs, for viewing in Step 13. Failed runs are
# recorded separately so one failure (e.g. a CUDA out-of-memory error at a large
# Gaussian count) does not abort the rest of the batch. Ctrl-C / interrupt still
# stops the loop because KeyboardInterrupt is not an Exception subclass.
output_folders = []
failed_runs = []

# Train each combination
for idx, (num_gaussians, max_steps) in enumerate(combinations, 1):
    print(f"{'═' * 80}")
    print(f"TRAINING {idx}/{total_combinations}")
    print(f"{'═' * 80}")
    print(f"Parameters: Gaussians={num_gaussians}, Steps={max_steps}")
    print()

    start_time = datetime.now()
    try:
        output_folder = train_image_gs(
            input_filename=INPUT_FILENAME,
            num_gaussians=num_gaussians,
            max_steps=max_steps,
            use_progressive=USE_PROGRESSIVE
        )
    except Exception as exc:
        failed_runs.append((num_gaussians, max_steps, exc))
        print(f"❌ Run failed (Gaussians={num_gaussians}, Steps={max_steps}): {exc}")
        print()
        continue
    elapsed = (datetime.now() - start_time).total_seconds() / 60

    output_folders.append(output_folder)

    print(f"⏱️  Elapsed time: {elapsed:.2f} minutes")
    print(f"📁 Output: {output_folder}")
    print()

print("=" * 80)
if failed_runs:
    print(f"⚠️  TRAINING FINISHED WITH {len(failed_runs)} FAILURE(S)")
else:
    print("✅ ALL TRAINING COMPLETE")
print("=" * 80)
print(f"Total runs:     {total_combinations}")
print(f"Output folders: {len(output_folders)}")
print()
print("Output folders list:")
for i, folder in enumerate(output_folders, 1):
    print(f"  {i}. {folder}")
if failed_runs:
    print()
    print("Failed runs:")
    for num_gaussians, max_steps, exc in failed_runs:
        print(f"  Gaussians={num_gaussians}, Steps={max_steps}: {exc}")
print("=" * 80)

## Step 13: View Results

Display comparison visualization and save summary files.

In [None]:
# View results from batch training
# Options:
# 1. View all results: set VIEW_ALL = True
# 2. View specific result: set VIEW_ALL = False and specify VIEW_INDEX

VIEW_ALL = True           # View all results from batch training
VIEW_INDEX = 0            # Index to view (0-based, used when VIEW_ALL=False)

if not output_folders:
    # Guard first: with no folders the range message below would read "(0--1)".
    print("No batch results available. Run the training cell first.")
elif VIEW_ALL:
    print("=" * 80)
    print(f"VIEWING ALL RESULTS ({len(output_folders)} total)")
    print("=" * 80)
    print()

    for idx, folder in enumerate(output_folders, 1):
        print(f"\n{'═' * 80}")
        print(f"RESULT {idx}/{len(output_folders)}: {folder}")
        print(f"{'═' * 80}\n")
        view_results(folder)
        print("\n")
elif 0 <= VIEW_INDEX < len(output_folders):
    folder = output_folders[VIEW_INDEX]
    print(f"Viewing result {VIEW_INDEX + 1}/{len(output_folders)}: {folder}\n")
    view_results(folder)
else:
    print(f"❌ Error: VIEW_INDEX={VIEW_INDEX} is out of range (0-{len(output_folders)-1})")
    print("Available folders:")
    for i, folder in enumerate(output_folders):
        print(f"  {i}. {folder}")

# Or manually specify a folder:
# view_results("cat-10000-5000-20251027_143052")

## Step 13b: Compare Batch Results (Optional)

Compare metrics across all trained models.

In [None]:
def _collect_batch_metrics(output_folders):
    """Collect size and pixel-difference metrics for each batch training result.

    Folder names are parsed from the right as
    ``<base>-<num_gaussians>-<max_steps>-<suffix>`` (e.g.
    ``cat-10000-5000-20251027_143052``, where the suffix is presumably a run
    timestamp). Because the split is done from the right, the base image name
    may itself contain hyphens.

    A folder is skipped, with a printed warning, when any of these holds:
      * its name does not parse;
      * ``model.pt``, ``rendered.jpg`` or ``other/ground_truth.jpg`` is missing;
      * the model file is empty;
      * the rendered and ground-truth images have different resolutions.

    Args:
        output_folders: Folder names relative to the output directory returned
            by ``get_paths()``.

    Returns:
        List of dicts sorted by (gaussians, steps), one per usable folder, with
        keys:
          * ``folder``, ``gaussians``, ``steps``;
          * ``model_size_kb`` (size of model.pt in KiB);
          * ``compression`` (ground_truth.jpg bytes / model.pt bytes; relative
            to the already-JPEG-compressed ground truth, not raw pixels);
          * ``bpp`` (model.pt bits per image pixel);
          * ``mean_diff`` / ``max_diff`` (mean / max of the per-pixel absolute
            RGB difference averaged over channels, on a 0-1 scale).
    """
    _, _, _, OUTPUT_DIR = get_paths()

    results = []

    for folder in output_folders:
        output_path = os.path.join(OUTPUT_DIR, folder)

        # Parse config from the right so hyphens in the base name survive.
        parts = folder.rsplit("-", 3)
        if len(parts) < 3:
            print(f"⚠️  Skipping {folder}: unrecognised folder name")
            continue
        try:
            num_gaussians = int(parts[1])
            max_steps = int(parts[2])
        except ValueError:
            # One odd folder should not abort the whole comparison.
            print(f"⚠️  Skipping {folder}: cannot parse Gaussians/steps from name")
            continue

        # File paths
        model_path = os.path.join(output_path, "model.pt")
        rendered_path = os.path.join(output_path, "rendered.jpg")
        gt_path = os.path.join(output_path, "other", "ground_truth.jpg")

        if not all(os.path.exists(p) for p in (model_path, rendered_path, gt_path)):
            print(f"⚠️  Skipping {folder}: missing model.pt, rendered.jpg or ground_truth.jpg")
            continue

        # Get sizes
        model_size = os.path.getsize(model_path)
        gt_size = os.path.getsize(gt_path)
        if model_size == 0:
            # An empty checkpoint would make the ratios below divide by zero.
            print(f"⚠️  Skipping {folder}: model.pt is empty")
            continue

        # Force RGB so grayscale/RGBA files still give (H, W, 3) arrays, and use
        # context managers so the image file handles are closed promptly.
        with Image.open(gt_path) as gt_file:
            gt_img = np.array(gt_file.convert("RGB")).astype(np.float32) / 255.0
        with Image.open(rendered_path) as render_file:
            render_img = np.array(render_file.convert("RGB")).astype(np.float32) / 255.0

        if gt_img.shape != render_img.shape:
            # A pixel-wise difference is undefined for mismatched resolutions.
            print(f"⚠️  Skipping {folder}: rendered {render_img.shape} != ground truth {gt_img.shape}")
            continue

        # Per-pixel absolute error, averaged over the RGB channels.
        diff_gray = np.mean(np.abs(gt_img - render_img), axis=2)

        # Calculate compression
        height, width = gt_img.shape[:2]
        total_pixels = width * height

        results.append({
            'folder': folder,
            'gaussians': num_gaussians,
            'steps': max_steps,
            'model_size_kb': model_size / 1024,
            'compression': gt_size / model_size,
            'bpp': (model_size * 8) / total_pixels,
            'mean_diff': np.mean(diff_gray),
            'max_diff': np.max(diff_gray),
        })

    # Sort by gaussians, then steps
    results.sort(key=lambda x: (x['gaussians'], x['steps']))

    return results

In [None]:
def _print_comparison_table(results):
    """Print comparison table for batch results."""
    print("=" * 120)
    print("BATCH TRAINING COMPARISON")
    print("=" * 120)
    print(f"{'Gaussians':<12} {'Steps':<8} {'Model Size':<14} {'Compression':<14} {'BPP':<10} {'Mean Diff':<14} {'Max Diff':<14}")
    print("-" * 120)
    
    for r in results:
        print(f"{r['gaussians']:<12} {r['steps']:<8} {r['model_size_kb']:>10.2f} KB  {r['compression']:>10.2f}x  {r['bpp']:>8.4f}  {r['mean_diff']:>10.6f}  {r['max_diff']:>10.6f}")
    
    print("=" * 120)

In [None]:
def _print_highlights(results):
    """Print highlights showing best compression and quality."""
    best_compression = max(results, key=lambda x: x['compression'])
    best_quality = min(results, key=lambda x: x['mean_diff'])
    
    print()
    print("HIGHLIGHTS:")
    print(f"  Best Compression: G={best_compression['gaussians']}, S={best_compression['steps']} -> {best_compression['compression']:.2f}x")
    print(f"  Best Quality:     G={best_quality['gaussians']}, S={best_quality['steps']} -> Mean Diff={best_quality['mean_diff']:.6f}")
    print("=" * 120)

In [None]:
def compare_batch_results(output_folders):
    """Print a side-by-side comparison of every batch-trained model.

    Args:
        output_folders: Folder names produced by the training cell.

    Returns:
        The per-run metric dicts from ``_collect_batch_metrics`` after printing
        the table and highlights, or ``None`` (after printing a notice) when no
        folder produced usable metrics.
    """
    collected = _collect_batch_metrics(output_folders)

    if collected:
        _print_comparison_table(collected)
        _print_highlights(collected)
        return collected

    print("No results to compare")
    return None

In [None]:
# Compare runs only when the training cell produced at least one output folder.
if not output_folders:
    print("No batch results available. Run the training cell first.")
else:
    comparison_results = compare_batch_results(output_folders)

## Step 13c: View Training Metrics (Optional)

Plot detailed training metrics showing loss, quality, and size evolution over iterations.

In [None]:
# Plot training metrics for batch results
# Options:
# 1. Plot all results: set PLOT_ALL = True
# 2. Plot specific result: set PLOT_ALL = False and specify PLOT_INDEX

PLOT_ALL = True           # Plot all results from batch training
PLOT_INDEX = 0            # Index to plot (0-based, used when PLOT_ALL=False)

if not output_folders:
    # Guard first: with no folders the range message below would read "(0--1)".
    print("No batch results available. Run the training cell first.")
elif PLOT_ALL:
    print("=" * 80)
    print(f"PLOTTING METRICS FOR ALL RESULTS ({len(output_folders)} total)")
    print("=" * 80)
    print()

    for idx, folder in enumerate(output_folders, 1):
        print(f"\n{'═' * 80}")
        print(f"METRICS {idx}/{len(output_folders)}: {folder}")
        print(f"{'═' * 80}\n")
        try:
            plot_training_metrics(folder)
        except FileNotFoundError as e:
            # Presumably a run without a metrics file; keep plotting the others.
            print(f"⚠️  Warning: {e}")
        print("\n")
elif 0 <= PLOT_INDEX < len(output_folders):
    folder = output_folders[PLOT_INDEX]
    print(f"Plotting metrics {PLOT_INDEX + 1}/{len(output_folders)}: {folder}\n")
    # Handle a missing file the same way as the plot-all branch.
    try:
        plot_training_metrics(folder)
    except FileNotFoundError as e:
        print(f"⚠️  Warning: {e}")
else:
    print(f"❌ Error: PLOT_INDEX={PLOT_INDEX} is out of range (0-{len(output_folders)-1})")
    print("Available folders:")
    for i, folder in enumerate(output_folders):
        print(f"  {i}. {folder}")

# Or manually specify a folder:
# plot_training_metrics("cat-10000-5000-20251027_143052")

## Step 14: Optional - Upscale Render

Render at higher resolution using the trained model.

In [None]:
# Example: re-render a trained model at a larger output height.
# `output_folder` is left over from the Step 12 training loop, so it names the
# LAST run trained. Pass another entry from `output_folders` (or a literal
# folder name) to upscale a different run.
# NOTE(review): assumes render_height is the target height in pixels, so 1024
# is "2x" only for a 512-px-tall input. Confirm against upscale_render.
# upscale_render(
#     output_folder=output_folder,  # or specify manually: "cat-10000-5000-20251027_143052"
#     render_height=1024
# )

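## Optional: Upscale Render (Alternative Example)

An alternative form of the Step 14 upscale example. It passes `INPUT_IMAGE`, `EXP_NAME` and `NUM_GAUSSIANS`, which the Step 11 configuration cell does not define (that cell uses `INPUT_FILENAME`, `GAUSSIANS` and `STEPS`). Adapt these names before running it.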

In [None]:
# NOTE(review): this example looks stale. It passes INPUT_IMAGE / EXP_NAME /
# NUM_GAUSSIANS, which the Step 11 configuration cell does not define (that
# cell uses INPUT_FILENAME, GAUSSIANS and STEPS). It also uses a different
# upscale_render signature from the Step 14 example (output_folder=...,
# render_height=...). Confirm which signature upscale_render accepts before
# uncommenting.
# Example: Render at 2x resolution
# upscale_render(
#     input_image=INPUT_IMAGE,
#     exp_name=EXP_NAME,
#     num_gaussians=NUM_GAUSSIANS,
#     render_height=1024  # Adjust as needed
# )