# Image-GS Quick Start

Minimal setup for RTX 4090. Tested on runpod.io.

## Usage
1. Copy this notebook to your workspace (e.g., `/workspace/quick-start.ipynb`)
2. Run the setup cells (Steps 1–9) to install dependencies, clone the repository, and load the helper functions
3. Place images in `/workspace/input/`
4. Configure and train!

## Workspace Structure
```
/workspace/
├── quick-start.ipynb  (this notebook)
├── input/            (your images)
├── output/           (results)
└── image-gs/         (repository)
```

## Step 1: Install System Dependencies

In [None]:
import sys
import subprocess

print("Installing system dependencies...\n")

# Shell commands run in order. Output is captured to keep the cell quiet and is
# only shown when a command fails.
commands = [
    "apt-get update -qq",
    "apt-get install -y -qq build-essential git wget curl",
]

failed_commands = []
for cmd in commands:
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        # Best-effort setup: report the failure instead of aborting, because
        # some images already ship these tools or lack apt-get entirely.
        failed_commands.append(cmd)
        print(f"✗ Command failed (exit {result.returncode}): {cmd}")
        if result.stderr:
            print(result.stderr.strip())

if failed_commands:
    print(f"⚠️  {len(failed_commands)} system command(s) failed; later steps may not work")
else:
    print("✓ System dependencies installed")

## Step 2: Upgrade pip

In [None]:
# Upgrade the packaging toolchain for the interpreter this kernel runs on.
# `sys` comes from the Step 1 cell, so that cell must have been run first.
!{sys.executable} -m pip install --upgrade pip setuptools wheel -q

## Step 3: Install PyTorch (CUDA 12.1 for RTX 4090)

In [None]:
print("Installing PyTorch 2.4.1 with CUDA 12.1...\n")

# Pinned torch / torchvision / torchaudio versions, fetched from the PyTorch
# cu121 wheel index rather than from PyPI.
!{sys.executable} -m pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

# Import immediately after installing to report the version and whether a CUDA
# device is visible to this kernel.
import torch
print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## Step 4: Install Python Dependencies

In [None]:
print("Installing Python dependencies...\n")

# pip requirement specifiers, installed one at a time so that a single failing
# package does not prevent the others from being installed.
dependencies = [
    "flip-evaluator",
    "lpips==0.1.4",
    "matplotlib==3.9.2",
    "numpy<2.1",
    "opencv-python==4.12.0.88",
    "pytorch-msssim==1.0.0",
    "scikit-image==0.24.0",
    "scipy==1.13.1",
    "torchmetrics==1.5.2",
    "jaxtyping",
    "rich>=12",
    "pyyaml==6.0",
    "ninja",
]

failed_dependencies = []
for dep in dependencies:
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", dep, "-q"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # Output is captured, so surface pip's error text explicitly.
        failed_dependencies.append(dep)
        print(f"✗ Failed to install {dep} (exit {result.returncode})")
        if result.stderr:
            print(result.stderr.strip())

if failed_dependencies:
    print(f"⚠️  {len(failed_dependencies)} dependency install(s) failed: {', '.join(failed_dependencies)}")
else:
    print("✓ Python dependencies installed")

## Step 5: Clone Repository

In [None]:
import os

# Determine root workspace directory (where this notebook is located)
# On runpod.io this would be /workspace
ROOT_WORKSPACE = os.getcwd()

# Repository will be cloned here
REPO_DIR = os.path.join(ROOT_WORKSPACE, "image-gs")

if os.path.exists(REPO_DIR):
    print(f"Repository already exists at {REPO_DIR}")
    !cd {REPO_DIR} && git pull
else:
    print(f"Cloning repository to {REPO_DIR}...\n")
    !git clone https://github.com/NYU-ICL/image-gs.git {REPO_DIR}

print(f"\n✓ Repository: {REPO_DIR}")

## Step 6: Install fused-ssim

In [None]:
print("Installing fused-ssim...\n")

# --no-build-isolation builds against the packages already installed in this
# environment instead of a fresh isolated build environment.
# NOTE(review): presumably required so the build can see the torch installed in
# Step 3 — confirm against the fused-ssim build instructions.
!{sys.executable} -m pip install git+https://github.com/rahul-goel/fused-ssim.git --no-build-isolation -q

print("✓ fused-ssim installed")

## Step 7: Install gsplat (with fix)

In [None]:
print("Installing gsplat CUDA extension...\n")
print("This will take 5-10 minutes.\n")

# The gsplat sources live inside the cloned repository (REPO_DIR from Step 5).
# os.chdir changes the kernel's working directory for every later cell as well.
gsplat_dir = os.path.join(REPO_DIR, "gsplat")
os.chdir(gsplat_dir)

# Uninstall any existing installation and clear pip's cache first
!{sys.executable} -m pip uninstall -y gsplat -q
!{sys.executable} -m pip cache purge -q

# Regular install (not editable), built from the current directory.
# NOTE(review): the original comment called this "the fix from step 9", which
# does not match any step in this notebook — presumably an editable install
# did not work here; confirm and record the actual reason.
!{sys.executable} -m pip install . --no-build-isolation

print("\n✓ gsplat installed")

## Step 8: Verify Installation

In [None]:
print("Verifying installation...\n")

# Force reimport.
# Packages were installed while this kernel was already running, so refresh the
# import finders' caches and drop any previously imported gsplat modules.
import importlib
importlib.invalidate_caches()
for mod in list(sys.modules.keys()):
    if 'gsplat' in mod:
        del sys.modules[mod]

# One human-readable message per component that failed to import.
errors = []

try:
    import torch
    # Raise with an explicit message; a bare assert would be reported as an
    # empty string below.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    print("✓ PyTorch with CUDA")
except Exception as e:
    errors.append(f"PyTorch: {e}")

try:
    from fused_ssim import fused_ssim
    print("✓ fused_ssim")
except Exception as e:
    errors.append(f"fused_ssim: {e}")

try:
    from gsplat import (
        project_gaussians_2d_scale_rot,
        rasterize_gaussians_no_tiles,
        rasterize_gaussians_sum,
    )
    print("✓ gsplat CUDA extensions")
except Exception as e:
    errors.append(f"gsplat: {e}")

try:
    # The repository modules are imported by path, so the repo must be both the
    # working directory and on sys.path.
    os.chdir(REPO_DIR)
    if REPO_DIR not in sys.path:
        # Guarded so re-running this cell does not pile up duplicate entries.
        sys.path.insert(0, REPO_DIR)
    from model import GaussianSplatting2D
    from utils.misc_utils import load_cfg
    print("✓ Image-GS modules")
except Exception as e:
    errors.append(f"Image-GS: {e}")

if errors:
    print(f"\n⚠️  {len(errors)} error(s):")
    for err in errors:
        print(f"  {err}")
else:
    print("\n✅ All components verified!")

## Step 9: Setup Workspace

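Defines the helper functions `train_image_gs`, `view_results`, and `upscale_render`. The `input/` and `output/` directories are created automatically the first time one of these helpers runs.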

In [None]:
import glob
import shutil
import numpy as np
from datetime import datetime
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams['text.usetex'] = False

# Workspace paths live in module-level globals. Earlier cells may already have
# defined some of them; anything missing is filled in lazily below.
def get_paths():
    """Resolve the workspace directories, defining any that are still unset.

    Missing globals are derived from ROOT_WORKSPACE (itself defaulting to the
    current working directory). The input and output directories are created
    on disk if they do not exist yet.

    Returns:
        Tuple of (ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR).
    """
    namespace = globals()

    # The root has to be settled first because the other defaults hang off it.
    if 'ROOT_WORKSPACE' not in namespace:
        namespace['ROOT_WORKSPACE'] = os.getcwd()

    derived_defaults = (
        ('REPO_DIR', "image-gs"),
        ('INPUT_DIR', "input"),
        ('OUTPUT_DIR', "output"),
    )
    for variable, subdirectory in derived_defaults:
        if variable not in namespace:
            namespace[variable] = os.path.join(namespace['ROOT_WORKSPACE'], subdirectory)

    # Only input/ and output/ are created here; the repository directory is not.
    for variable in ('INPUT_DIR', 'OUTPUT_DIR'):
        os.makedirs(namespace[variable], exist_ok=True)

    return (
        namespace['ROOT_WORKSPACE'],
        namespace['REPO_DIR'],
        namespace['INPUT_DIR'],
        namespace['OUTPUT_DIR'],
    )


def train_image_gs(input_filename, num_gaussians, max_steps, use_progressive=True):
    """
    Train Image-GS model with organized file management.

    Copies the input image into the repository's media/images/ folder, runs
    main.py in a subprocess, then collects the newest run's artifacts into a
    timestamped folder under output/.

    Args:
        input_filename: Just the filename (e.g., "cat.png") from input/
        num_gaussians: Number of Gaussians
        max_steps: Training steps
        use_progressive: Enable progressive optimization (default: True)

    Returns:
        output_folder: Name of the created output folder

    Raises:
        FileNotFoundError: If the input image does not exist in input/.
        RuntimeError: If the training process exits with a non-zero status or
            leaves no run directory behind.
    """
    # Imported locally so the function does not depend on which earlier
    # notebook cells happened to import these modules.
    import re
    import subprocess

    # Get workspace paths
    ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR = get_paths()

    # Change to repository directory for training
    os.chdir(REPO_DIR)

    # Validate input
    input_path = os.path.join(INPUT_DIR, input_filename)
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input image not found: {input_path}")

    # Create timestamped output folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_name = os.path.splitext(input_filename)[0]
    output_folder = f"{base_name}-{num_gaussians}-{max_steps}-{timestamp}"
    output_path = os.path.join(OUTPUT_DIR, output_folder)
    other_path = os.path.join(output_path, "other")
    os.makedirs(other_path, exist_ok=True)

    # Copy input to media/images/ (relative to repo)
    media_dir = os.path.join(REPO_DIR, "media", "images")
    os.makedirs(media_dir, exist_ok=True)
    shutil.copy2(input_path, os.path.join(media_dir, input_filename))

    # Setup temporary experiment name
    temp_exp_name = f"temp/{output_folder}"

    # Build the command as an argument list (no shell), so filenames containing
    # spaces or shell metacharacters reach main.py unchanged.
    cmd = [
        sys.executable,
        "main.py",
        f"--input_path=images/{input_filename}",
        f"--exp_name={temp_exp_name}",
        f"--num_gaussians={num_gaussians}",
        f"--max_steps={max_steps}",
        "--quantize",
        "--device=cuda:0",
    ]
    if not use_progressive:
        cmd.append("--disable_prog_optim")

    print("=" * 80)
    print(f"🚀 TRAINING: {input_filename}")
    print("=" * 80)
    print(f"Gaussians:   {num_gaussians}")
    print(f"Steps:       {max_steps}")
    print(f"Progressive: {use_progressive}")
    print(f"Output:      output/{output_folder}/")
    print(f"Time est:    ~{max_steps * 0.002:.1f}-{max_steps * 0.005:.1f} minutes")
    print("=" * 80)
    print()

    # Run training; output is inherited rather than captured so progress stays
    # visible while the job runs.
    completed = subprocess.run(cmd, cwd=REPO_DIR)
    if completed.returncode != 0:
        raise RuntimeError(
            f"Training failed with exit code {completed.returncode}; see the output above"
        )

    # Locate the newest run directory (names are compared lexicographically).
    result_base = os.path.join(REPO_DIR, "results", temp_exp_name)
    if not os.path.isdir(result_base):
        raise RuntimeError(f"Training produced no results directory: {result_base}")
    run_dirs = sorted(
        d for d in os.listdir(result_base) if os.path.isdir(os.path.join(result_base, d))
    )
    if not run_dirs:
        raise RuntimeError(f"Training produced no run directory in: {result_base}")
    result_dir = os.path.join(result_base, run_dirs[-1])

    def _checkpoint_step(path):
        """Return the step number embedded in a checkpoint filename, or -1."""
        match = re.search(r"ckpt_step-(\d+)", os.path.basename(path))
        return int(match.group(1)) if match else -1

    # Copy key files to output folder
    # 1. Model checkpoint — chosen by numeric step, because a plain string sort
    #    would rank "ckpt_step-5000" after "ckpt_step-10000".
    ckpt_dir = os.path.join(result_dir, "checkpoints")
    ckpt_files = glob.glob(os.path.join(ckpt_dir, "ckpt_step-*.pt"))
    if ckpt_files:
        latest_ckpt = max(ckpt_files, key=_checkpoint_step)
        shutil.copy2(latest_ckpt, os.path.join(output_path, "model.pt"))

    # 2. Rendered image (sorted so the pick is deterministic if several match)
    renders = sorted(glob.glob(os.path.join(result_dir, "render_res-*.jpg")))
    if renders:
        shutil.copy2(renders[0], os.path.join(output_path, "rendered.jpg"))

    # 3. Ground truth image
    gts = sorted(glob.glob(os.path.join(result_dir, "gt_res-*.jpg")))
    if gts:
        shutil.copy2(gts[0], os.path.join(other_path, "ground_truth.jpg"))

    # 4. Training log
    log_file = os.path.join(result_dir, "log_train.txt")
    if os.path.exists(log_file):
        shutil.copy2(log_file, os.path.join(other_path, "log_train.txt"))

    # 5. Copy all other top-level files to "other" subdirectory
    already_handled_prefixes = ("render_res-", "gt_res-")
    for item in os.listdir(result_dir):
        item_path = os.path.join(result_dir, item)
        if not os.path.isfile(item_path):
            continue
        # Skip files we already copied
        if item == "log_train.txt" or item.startswith(already_handled_prefixes):
            continue
        shutil.copy2(item_path, os.path.join(other_path, item))

    print()
    print("=" * 80)
    print("✅ TRAINING COMPLETE")
    print("=" * 80)
    print(f"📁 Output folder: output/{output_folder}/")
    print()

    return output_folder


def view_results(output_folder):
    """
    View and analyze results from a training run.

    Shows a three-panel comparison (ground truth, render, difference map),
    prints a text summary, and saves both as summary.png and summary.txt in
    the output folder.

    Args:
        output_folder: Name of folder in output/ (e.g., "cat-10000-5000-20251027_143052")

    Raises:
        FileNotFoundError: If the output folder, rendered.jpg, or
            other/ground_truth.jpg is missing.
        ValueError: If the ground truth and rendered images differ in size.
    """
    # Get workspace paths
    ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR = get_paths()

    output_path = os.path.join(OUTPUT_DIR, output_folder)

    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Output folder not found: {output_path}")

    # Load files
    model_path = os.path.join(output_path, "model.pt")
    rendered_path = os.path.join(output_path, "rendered.jpg")
    gt_path = os.path.join(output_path, "other", "ground_truth.jpg")

    for required_path in (gt_path, rendered_path):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"Required image not found: {required_path}")

    # Parse config from folder name: "<base>-<gaussians>-<steps>-<timestamp>".
    # Anything that does not match this shape falls back to placeholders
    # instead of raising, so manually named folders can still be viewed.
    base_name = output_folder
    num_gaussians = "?"
    max_steps = "?"
    parts = output_folder.rsplit("-", 3)
    if len(parts) == 4:
        try:
            parsed_gaussians = int(parts[1])
            parsed_steps = int(parts[2])
        except ValueError:
            pass
        else:
            base_name = parts[0]
            num_gaussians = parsed_gaussians
            max_steps = parsed_steps

    # File sizes
    model_size = os.path.getsize(model_path) if os.path.exists(model_path) else None

    # Load images. Both are converted to RGB so grayscale or RGBA files still
    # give (H, W, 3) arrays; the channel count reported below is the ground
    # truth file's own.
    with Image.open(gt_path) as gt_file:
        channels = len(gt_file.getbands())
        gt_img = np.array(gt_file.convert("RGB")).astype(np.float32) / 255.0
    with Image.open(rendered_path) as render_file:
        render_img = np.array(render_file.convert("RGB")).astype(np.float32) / 255.0

    if gt_img.shape != render_img.shape:
        raise ValueError(
            f"Image size mismatch: ground truth {gt_img.shape[:2]} vs rendered {render_img.shape[:2]}"
        )

    # Calculate difference
    diff = np.abs(gt_img - render_img)
    diff_gray = np.mean(diff, axis=2)

    # Statistics
    mean_diff = np.mean(diff_gray)
    max_diff = np.max(diff_gray)
    std_diff = np.std(diff_gray)

    # Image info
    height, width = gt_img.shape[:2]
    total_pixels = width * height
    gt_size = os.path.getsize(gt_path)
    render_size = os.path.getsize(rendered_path)

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(gt_img)
    axes[0].set_title("Ground Truth", fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(render_img)
    axes[1].set_title("2D Gaussians (Rendered)", fontsize=14, fontweight='bold')
    axes[1].axis('off')

    # The color scale is clipped at 0.2 so small errors remain visible.
    im = axes[2].imshow(diff_gray, cmap='hot', vmin=0, vmax=0.2)
    axes[2].set_title(f"Difference Map\nMean: {mean_diff:.4f} ({mean_diff*100:.2f}%)", fontsize=14, fontweight='bold')
    axes[2].axis('off')

    cbar = plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
    cbar.set_label('Absolute Difference', rotation=270, labelpad=20)

    plt.tight_layout()

    # Save visualization
    summary_img_path = os.path.join(output_path, "summary.png")
    plt.savefig(summary_img_path, dpi=150, bbox_inches='tight')

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Format helpers
    def fmt_size(size_bytes):
        """Format a byte count as B / KB / MB, or "N/A" when unknown."""
        if size_bytes is None:
            return "N/A"
        if size_bytes < 1024:
            return f"{size_bytes} B"
        elif size_bytes < 1024 * 1024:
            return f"{size_bytes / 1024:.2f} KB"
        else:
            return f"{size_bytes / (1024 * 1024):.2f} MB"

    def fmt_ratio(num, denom):
        """Format num/denom as a "1.23x" ratio, or "N/A" when either is unknown."""
        if num is None or denom is None:
            return "N/A"
        return f"{num / denom:.2f}x"

    # Build summary text with tables
    summary_lines = []
    summary_lines.append("=" * 100)
    summary_lines.append("IMAGE-GS TRAINING SUMMARY")
    summary_lines.append("=" * 100)
    summary_lines.append("")

    # Table 1: Training Configuration
    summary_lines.append("TRAINING CONFIGURATION")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Output Folder':<25} {'Base Name':<20} {'Gaussians':<15} {'Steps':<15}")
    summary_lines.append(f"{output_folder:<25} {base_name:<20} {num_gaussians:<15} {max_steps:<15}")
    summary_lines.append("")

    # Table 2: Image & File Information
    resolution_text = f"{width} x {height} px"
    pixels_text = f"{total_pixels:,}"
    summary_lines.append("IMAGE & FILE INFORMATION")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Metric':<30} {'Value':<30} {'Metric':<30} {'Value':<30}")
    summary_lines.append(f"{'Resolution':<30} {resolution_text:<30} {'Total Pixels':<30} {pixels_text:<30}")
    summary_lines.append(f"{'Channels':<30} {channels:<30} {'Ground Truth Size':<30} {fmt_size(gt_size):<30}")
    summary_lines.append(f"{'Model Size':<30} {fmt_size(model_size):<30} {'Rendered Size':<30} {fmt_size(render_size):<30}")
    summary_lines.append("")

    # Table 3: Compression Analysis (skipped when there is no usable model.pt)
    if model_size:
        # Raw size computed here as one byte per channel per pixel.
        uncompressed_size = total_pixels * channels
        bpp = (model_size * 8) / total_pixels
        share_note = f"Model is {(model_size / gt_size) * 100:.1f}% of original"
        raw_note = f"{fmt_size(uncompressed_size)} -> {fmt_size(model_size)}"
        bpp_text = f"{bpp:.4f} bpp"

        summary_lines.append("COMPRESSION ANALYSIS")
        summary_lines.append("-" * 100)
        summary_lines.append(f"{'Metric':<40} {'Value':<30} {'Note':<30}")
        summary_lines.append(f"{'Compression vs Original (JPG)':<40} {fmt_ratio(gt_size, model_size):<30} {share_note:<30}")
        summary_lines.append(f"{'Compression vs Raw (uncompressed)':<40} {fmt_ratio(uncompressed_size, model_size):<30} {raw_note:<30}")
        summary_lines.append(f"{'Bits Per Pixel (bpp)':<40} {bpp_text:<30} {'Lower is better':<30}")
        summary_lines.append("")

    # Table 4: Quality Metrics — share of pixels above each error threshold
    pix_1pct = np.sum(diff_gray > 0.01) / diff_gray.size * 100
    pix_5pct = np.sum(diff_gray > 0.05) / diff_gray.size * 100
    pix_10pct = np.sum(diff_gray > 0.10) / diff_gray.size * 100

    mean_text = f"{mean_diff:.6f} ({mean_diff*100:.2f}%)"
    max_text = f"{max_diff:.6f} ({max_diff*100:.2f}%)"
    std_text = f"{std_diff:.6f}"
    pix_1_text = f"{pix_1pct:.2f}%"
    pix_5_text = f"{pix_5pct:.2f}%"
    pix_10_text = f"{pix_10pct:.2f}%"

    summary_lines.append("QUALITY METRICS")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'Metric':<30} {'Value':<30} {'Metric':<30} {'Value':<30}")
    summary_lines.append(f"{'Mean Difference':<30} {mean_text:<30} {'Max Difference':<30} {max_text:<30}")
    summary_lines.append(f"{'Std Deviation':<30} {std_text:<30} {'Pixels > 1% diff':<30} {pix_1_text:<30}")
    summary_lines.append(f"{'Pixels > 5% diff':<30} {pix_5_text:<30} {'Pixels > 10% diff':<30} {pix_10_text:<30}")
    summary_lines.append("")

    # Files saved
    summary_lines.append("FILES SAVED")
    summary_lines.append("-" * 100)
    summary_lines.append(f"{'File':<25} {'Description':<75}")
    summary_lines.append(f"{'📄 summary.txt':<25} {'This summary (text format)':<75}")
    summary_lines.append(f"{'📊 summary.png':<25} {'Visual comparison (3-panel image)':<75}")
    summary_lines.append(f"{'🧠 model.pt':<25} {'Trained 2D Gaussian model (PyTorch checkpoint)':<75}")
    summary_lines.append(f"{'🖼️  rendered.jpg':<25} {'Rendered output from model':<75}")
    summary_lines.append(f"{'📁 other/':<25} {'Training logs and additional files':<75}")
    summary_lines.append("=" * 100)

    summary_text = "\n".join(summary_lines)
    print(summary_text)

    # Save summary text. UTF-8 is explicit because the summary contains emoji
    # that a non-UTF-8 default locale encoding cannot represent.
    summary_txt_path = os.path.join(output_path, "summary.txt")
    with open(summary_txt_path, 'w', encoding='utf-8') as f:
        f.write(summary_text)

    print()
    print(f"💾 Saved: summary.txt")
    print(f"💾 Saved: summary.png")
    print()
    print(f"📁 Full path: {output_path}")


def upscale_render(output_folder, render_height):
    """
    Render at higher resolution using trained model.

    Re-runs main.py in evaluation mode under the same experiment name that
    training used, then copies the first render found in the run's eval/
    directory into the output folder as rendered_<height>px.jpg.

    Args:
        output_folder: Name of folder in output/ with trained model
        render_height: Target height in pixels

    Raises:
        FileNotFoundError: If the output folder or the original input image
            cannot be found.
        ValueError: If output_folder does not follow the
            "<base>-<gaussians>-<steps>-<timestamp>" naming scheme.
        RuntimeError: If the render process exits with a non-zero status.
    """
    # Imported locally so the function does not depend on which earlier
    # notebook cells happened to import it.
    import subprocess

    # Get workspace paths
    ROOT_WORKSPACE, REPO_DIR, INPUT_DIR, OUTPUT_DIR = get_paths()

    # Change to repository directory
    os.chdir(REPO_DIR)

    output_path = os.path.join(OUTPUT_DIR, output_folder)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Output folder not found: {output_path}")

    # Parse config from folder name
    parts = output_folder.rsplit("-", 3)
    if len(parts) != 4:
        raise ValueError(
            f"Cannot parse training config from folder name: {output_folder!r} "
            "(expected <base>-<gaussians>-<steps>-<timestamp>)"
        )
    base_name = parts[0]
    try:
        num_gaussians = int(parts[1])
    except ValueError:
        raise ValueError(
            f"Cannot parse number of Gaussians from folder name: {output_folder!r}"
        ) from None

    # Find original input file. The base name is escaped so characters such as
    # "[" are matched literally; sorting makes the pick deterministic.
    input_files = sorted(glob.glob(os.path.join(INPUT_DIR, f"{glob.escape(base_name)}.*")))
    if not input_files:
        raise FileNotFoundError(f"Original input not found for: {base_name}")
    input_filename = os.path.basename(input_files[0])

    # Ensure input is in media/images/ (the directory itself may be missing)
    media_dir = os.path.join(REPO_DIR, "media", "images")
    os.makedirs(media_dir, exist_ok=True)
    media_input_path = os.path.join(media_dir, input_filename)
    if not os.path.exists(media_input_path):
        shutil.copy2(input_files[0], media_input_path)

    # Setup experiment name (should match training)
    temp_exp_name = f"temp/{output_folder}"

    # Build the command as an argument list (no shell), so filenames containing
    # spaces or shell metacharacters reach main.py unchanged.
    cmd = [
        sys.executable,
        "main.py",
        f"--input_path=images/{input_filename}",
        f"--exp_name={temp_exp_name}",
        f"--num_gaussians={num_gaussians}",
        "--quantize",
        "--eval",
        f"--render_height={render_height}",
        "--device=cuda:0",
    ]

    print("=" * 80)
    print(f"🔍 UPSCALE RENDER: {output_folder}")
    print("=" * 80)
    print(f"Target height: {render_height}px")
    print("=" * 80)
    print()

    completed = subprocess.run(cmd, cwd=REPO_DIR)
    if completed.returncode != 0:
        raise RuntimeError(
            f"Upscale render failed with exit code {completed.returncode}; see the output above"
        )

    # Find and copy upscaled render
    result_base = os.path.join(REPO_DIR, "results", temp_exp_name)
    run_dirs = []
    if os.path.isdir(result_base):
        run_dirs = sorted(
            d for d in os.listdir(result_base) if os.path.isdir(os.path.join(result_base, d))
        )
    if not run_dirs:
        print("⚠️  Warning: Eval directory not found")
        return

    eval_dir = os.path.join(result_base, run_dirs[-1], "eval")
    if not os.path.exists(eval_dir):
        print("⚠️  Warning: Eval directory not found")
        return

    upscaled_renders = sorted(glob.glob(os.path.join(eval_dir, "render_*.jpg")))
    if not upscaled_renders:
        print("⚠️  Warning: Could not find upscaled render")
        return

    upscaled_name = f"rendered_{render_height}px.jpg"
    shutil.copy2(upscaled_renders[0], os.path.join(output_path, upscaled_name))
    print()
    print("=" * 80)
    print("✅ UPSCALE COMPLETE")
    print("=" * 80)
    print(f"💾 Saved: {upscaled_name}")
    print(f"📁 Location: output/{output_folder}/")
    print("=" * 80)

print("✅ Functions loaded!")

## Step 11: Configuration

Set your input image and training parameters.

In [None]:
# Place your input image in: /workspace/input/
# Example: /workspace/input/cat.png

INPUT_FILENAME = "cat.png"  # Just the filename from input/

# Training parameters
NUM_GAUSSIANS = 10000       # More = better quality (5k-30k recommended)
MAX_STEPS = 5000            # More = better convergence (3k-10k recommended)
USE_PROGRESSIVE = True      # Enable progressive optimization (recommended)

# The output preview strips the extension with os.path.splitext, the same way
# train_image_gs names its folder, so names like "my.cat.png" preview correctly.
print("=" * 80)
print("CONFIGURATION")
print("=" * 80)
print(f"Input:       {os.path.join(ROOT_WORKSPACE, 'input', INPUT_FILENAME)}")
print(f"Output:      {os.path.join(ROOT_WORKSPACE, 'output')}/{os.path.splitext(INPUT_FILENAME)[0]}-{NUM_GAUSSIANS}-{MAX_STEPS}-[timestamp]/")
print(f"Gaussians:   {NUM_GAUSSIANS}")
print(f"Steps:       {MAX_STEPS}")
print(f"Progressive: {USE_PROGRESSIVE}")
print("=" * 80)

## Step 12: Training

Train the model and save to organized output folder.

In [None]:
# Start training with the values chosen in the configuration cell. The returned
# folder name is kept in `output_folder` for the cells that follow.
output_folder = train_image_gs(
    INPUT_FILENAME,
    NUM_GAUSSIANS,
    MAX_STEPS,
    use_progressive=USE_PROGRESSIVE,
)

print("✅ Training complete!")
print("📁 Output: " + os.path.join(ROOT_WORKSPACE, 'output', output_folder))

## Step 13: View Results

Display comparison visualization and save summary files.

In [None]:
# Use the output_folder returned by the training cell.
# This shows the comparison figure and writes summary.png / summary.txt into
# that folder.
view_results(output_folder)

# Or manually specify a folder from output/ (e.g. after a kernel restart):
# view_results("cat-10000-5000-20251027_143052")

## Step 14: Optional - Upscale Render

Render at higher resolution using the trained model.

In [None]:
# Example: render at a target height of 1024 px (2x for a 512 px tall source image)
# upscale_render(
#     output_folder=output_folder,  # or specify manually: "cat-10000-5000-20251027_143052"
#     render_height=1024
# )

## Optional: Upscale Render

Render at higher resolution using the trained model.

In [None]:
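# Example: render a previously trained folder at a target height of 1024 px.
# upscale_render takes only the output folder name and the target height.
# upscale_render(
#     output_folder="cat-10000-5000-20251027_143052",  # folder name inside output/
#     render_height=1024  # Adjust as needed
# )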