# Image-GS Quick Start
Minimal setup notebook for training Image-GS on an NVIDIA RTX 4090 GPU. Tested on RunPod (runpod.io).

## Step 1: Install System Dependencies

In [None]:
import sys
import subprocess

print("Installing system dependencies...\n")

# Fixed command strings (no user input), so running them through the shell is fine.
commands = [
    "apt-get update -qq",
    "apt-get install -y -qq build-essential git wget curl",
]

# Output is captured to keep the cell quiet. Collect failures rather than
# discarding them, so a failed apt-get does not report success.
failed = []
for cmd in commands:
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        failed.append((cmd, result.returncode, result.stderr.strip()))

if failed:
    for cmd, code, stderr in failed:
        print(f"✗ `{cmd}` failed (exit code {code})")
        if stderr:
            print(stderr)
else:
    print("✓ System dependencies installed")

## Step 2: Upgrade pip

In [None]:
# Upgrade the packaging toolchain for the interpreter that runs this kernel.
# IPython expands `{sys.executable}` before running the shell command, so this
# needs `sys` imported (done in the Step 1 cell).
!{sys.executable} -m pip install --upgrade pip setuptools wheel -q

## Step 3: Install PyTorch (CUDA 12.1 for RTX 4090)

In [None]:
print("Installing PyTorch 2.4.1 with CUDA 12.1...\n")

# Pinned torch/torchvision/torchaudio versions from PyTorch's CUDA 12.1 wheel index,
# installed into the same interpreter as this kernel (`{sys.executable}`).
!{sys.executable} -m pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

# Import only after the install so the new build is the one loaded.
# NOTE(review): if torch was already imported earlier in this kernel, this returns
# the cached module; restart the kernel to pick up a different version.
import torch
print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## Step 4: Install Python Dependencies

In [None]:
print("Installing Python dependencies...\n")

dependencies = [
    "flip-evaluator",
    "lpips==0.1.4",
    "matplotlib==3.9.2",
    "numpy<2.1",
    "opencv-python==4.12.0.88",
    "pytorch-msssim==1.0.0",
    "scikit-image==0.24.0",
    "scipy==1.13.1",
    "torchmetrics==1.5.2",
    "jaxtyping",
    "rich>=12",
    "pyyaml==6.0",
    "ninja",
]

# A single pip call lets the resolver consider every pin together. Installing
# one package at a time only checks that package's own requirements, so a later
# install may change a version an earlier entry pinned.
# pip's output is captured to keep the cell quiet, so show it when the install fails.
result = subprocess.run(
    [sys.executable, "-m", "pip", "install", *dependencies, "-q"],
    capture_output=True,
    text=True,
)

if result.returncode == 0:
    print("✓ Python dependencies installed")
else:
    print(f"✗ pip install failed (exit code {result.returncode}):\n")
    print(result.stderr.strip() or result.stdout.strip())

## Step 5: Clone Repository

In [None]:
import os

# Location of the Image-GS checkout.
# NOTE(review): later cells hard-code "/workspace/image-gs" instead of using
# REPO_DIR; keep them in sync if this path changes.
REPO_DIR = "/workspace/image-gs"

# Re-running the cell updates an existing checkout instead of cloning again.
if os.path.exists(REPO_DIR):
    print(f"Repository already exists at {REPO_DIR}")
    !cd {REPO_DIR} && git pull
else:
    print(f"Cloning repository to {REPO_DIR}...\n")
    !git clone https://github.com/NYU-ICL/image-gs.git {REPO_DIR}

# Later cells use paths relative to the repository root (e.g. results/...).
os.chdir(REPO_DIR)
print(f"\n✓ Working directory: {os.getcwd()}")

## Step 6: Install fused-ssim

In [None]:
print("Installing fused-ssim...\n")

# Built from the GitHub sources. --no-build-isolation builds against packages
# already in this environment instead of a fresh isolated build env (presumably
# so the extension compiles against the torch installed in Step 3).
!{sys.executable} -m pip install git+https://github.com/rahul-goel/fused-ssim.git --no-build-isolation -q

# NOTE(review): printed unconditionally. A failed install only shows in pip's
# output above; the Step 8 cell re-checks the import.
print("✓ fused-ssim installed")

## Step 7: Install gsplat (regular, non-editable install)

In [None]:
print("Installing gsplat CUDA extension...\n")
print("This will take 5-10 minutes.\n")

# Build from the gsplat sources inside the Image-GS checkout
# (hard-coded path; matches REPO_DIR from Step 5).
os.chdir("/workspace/image-gs/gsplat")

# Remove any existing installation and purge pip's cache
# (presumably to rule out reusing a stale earlier build).
!{sys.executable} -m pip uninstall -y gsplat -q
!{sys.executable} -m pip cache purge -q

# Regular (non-editable) install. The original note called this "the fix",
# presumably for a problem seen with an editable (`pip install -e .`) install.
# Its reference to "step 9" does not match this notebook's step numbering.
# --no-build-isolation builds against already-installed packages (presumably the
# Step 3 torch) rather than an isolated build environment.
!{sys.executable} -m pip install . --no-build-isolation

# Return to the repository root; later cells use paths relative to it.
# NOTE(review): the success message is printed even if the build above failed.
os.chdir("/workspace/image-gs")
print("\n✓ gsplat installed")

## Step 8: Verify Installation

In [None]:
print("Verifying installation...\n")

# Packages were installed into this already-running kernel. Forget any gsplat
# modules imported before the (re)install, and clear the import finders' caches
# so newly installed files on disk are found.
import importlib
for mod in list(sys.modules.keys()):
    if 'gsplat' in mod:
        del sys.modules[mod]
importlib.invalidate_caches()

# Each check records "<component>: <reason>" on failure instead of stopping,
# so a single run reports every broken component.
errors = []

try:
    import torch
    cuda_ok = torch.cuda.is_available()
except Exception as e:
    errors.append(f"PyTorch: {e}")
else:
    # Explicit check rather than `assert`: a bare AssertionError has an empty
    # message (the report would read just "PyTorch: "), and asserts are
    # stripped under `python -O`.
    if cuda_ok:
        print("✓ PyTorch with CUDA")
    else:
        errors.append("PyTorch: CUDA is not available (torch.cuda.is_available() returned False)")

try:
    from fused_ssim import fused_ssim
    print("✓ fused_ssim")
except Exception as e:
    errors.append(f"fused_ssim: {e}")

try:
    from gsplat import (
        project_gaussians_2d_scale_rot,
        rasterize_gaussians_no_tiles,
        rasterize_gaussians_sum,
    )
    print("✓ gsplat CUDA extensions")
except Exception as e:
    errors.append(f"gsplat: {e}")

try:
    # Make the repository root (the current directory) importable, without
    # adding a duplicate sys.path entry each time the cell is re-run.
    repo_root = os.getcwd()
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from model import GaussianSplatting2D
    from utils.misc_utils import load_cfg
    print("✓ Image-GS modules")
except Exception as e:
    errors.append(f"Image-GS: {e}")

if errors:
    print(f"\n⚠️  {len(errors)} error(s):")
    for err in errors:
        print(f"  {err}")
else:
    print("\n✅ All components verified!")

## Step 9: Define Functions

In [None]:
import glob
import numpy as np
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

# Render figure text without an external LaTeX installation (text.usetex off).
matplotlib.rc('text', usetex=False)

def train_image_gs(input_image, exp_name, num_gaussians, max_steps,
                   repo_dir="/workspace/image-gs"):
    """
    Train an Image-GS model by running the repository's main.py.

    Args:
        input_image: Path relative to media/ (e.g., "images/cat.png")
        exp_name: Experiment name (e.g., "test/cat")
        num_gaussians: Number of Gaussians
        max_steps: Training steps
        repo_dir: Image-GS checkout containing main.py. The working directory
            is changed to it and left there, so relative paths such as
            results/{exp_name} resolve against it afterwards.

    Returns:
        True if main.py exited with status 0, otherwise False.
    """
    os.chdir(repo_dir)

    # An argv list (no shell) passes every value to main.py verbatim: quotes,
    # spaces or `$` in a path or experiment name cannot break the command.
    cmd = [
        sys.executable, "main.py",
        f"--input_path={input_image}",
        f"--exp_name={exp_name}",
        f"--num_gaussians={num_gaussians}",
        f"--max_steps={max_steps}",
        "--quantize",
        "--device=cuda:0",
    ]

    print("🚀 Training...")
    print(f"Estimated time: ~{max_steps * 0.002:.1f}-{max_steps * 0.005:.1f} minutes\n")

    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        # Previously the exit status was ignored and success was always reported.
        print(f"\n❌ Training failed (main.py exited with code {returncode})")
        return False

    print(f"\n✅ Complete! Results: results/{exp_name}/")
    return True


def upscale_render(input_image, exp_name, num_gaussians, render_height,
                   repo_dir="/workspace/image-gs"):
    """
    Render a trained model at a chosen output height (upscaling) via main.py --eval.

    Args:
        input_image: Original input path (relative to media/)
        exp_name: Experiment name from training
        num_gaussians: Number of Gaussians used
        render_height: Target height in pixels
        repo_dir: Image-GS checkout containing main.py. The working directory
            is changed to it and left there.

    Returns:
        True if main.py exited with status 0, otherwise False.
    """
    os.chdir(repo_dir)

    # An argv list (no shell) passes every value to main.py verbatim.
    cmd = [
        sys.executable, "main.py",
        f"--input_path={input_image}",
        f"--exp_name={exp_name}",
        f"--num_gaussians={num_gaussians}",
        "--quantize",
        "--eval",
        f"--render_height={render_height}",
        "--device=cuda:0",
    ]

    print(f"🔍 Rendering at {render_height}px height...\n")

    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        # Previously the exit status was ignored and success was always reported.
        print(f"\n❌ Rendering failed (main.py exited with code {returncode})")
        return False

    print(f"\n✅ Upscaled render complete! Check results/{exp_name}/.../eval/")
    return True


def _checkpoint_step(path):
    """Return the step number from a ``ckpt_step-<N>.pt`` path, or -1 if it has none."""
    stem = os.path.splitext(os.path.basename(path))[0]
    step = stem.partition("ckpt_step-")[2]
    return int(step) if step.isdigit() else -1


def view_results(exp_name, input_image, num_gaussians, max_steps, save_results=True):
    """
    View and analyze the latest run of an experiment.

    Shows ground truth, render and per-pixel difference side by side, then
    prints a summary (resolution, file/model sizes, difference statistics).
    Paths are resolved relative to the current working directory.

    Args:
        exp_name: Experiment name (results are read from results/{exp_name}/)
        input_image: Original input path (only echoed in the summary)
        num_gaussians: Number of Gaussians used (only echoed in the summary)
        max_steps: Steps used (only echoed in the summary)
        save_results: Save comparison image and summary to <run>/analysis/

    Raises:
        FileNotFoundError: if the experiment has no run directory, or the
            latest run has no render_res-*.jpg / gt_res-*.jpg image.
        ValueError: if the render and ground-truth images differ in shape.
    """
    result_base = f"results/{exp_name}"
    if not os.path.isdir(result_base):
        raise FileNotFoundError(f"No results directory: {os.path.abspath(result_base)}")

    run_dirs = sorted(
        d for d in os.listdir(result_base) if os.path.isdir(os.path.join(result_base, d))
    )
    if not run_dirs:
        raise FileNotFoundError(f"No run directories found in {result_base}")
    # NOTE(review): assumes run directory names sort chronologically. Confirm.
    result_dir = os.path.join(result_base, run_dirs[-1])

    if save_results:
        output_dir = os.path.join(result_dir, "analysis")
        os.makedirs(output_dir, exist_ok=True)

    # Sorted so the choice is deterministic (glob order is arbitrary).
    renders = sorted(glob.glob(os.path.join(result_dir, "render_res-*.jpg")))
    gts = sorted(glob.glob(os.path.join(result_dir, "gt_res-*.jpg")))
    if not renders or not gts:
        raise FileNotFoundError(f"Missing render_res-*.jpg or gt_res-*.jpg in {result_dir}")
    render_path, gt_path = renders[0], gts[0]

    # Latest checkpoint = highest step number. A plain string sort would rank
    # "ckpt_step-10000.pt" before "ckpt_step-5000.pt" and pick the wrong file.
    ckpt_files = glob.glob(os.path.join(result_dir, "checkpoints", "ckpt_step-*.pt"))
    model_size = os.path.getsize(max(ckpt_files, key=_checkpoint_step)) if ckpt_files else None

    # Load images as floats in [0, 1]
    gt_img = np.array(Image.open(gt_path)).astype(np.float32) / 255.0
    render_img = np.array(Image.open(render_path)).astype(np.float32) / 255.0
    if gt_img.shape != render_img.shape:
        raise ValueError(
            f"Render shape {render_img.shape} does not match ground truth {gt_img.shape}"
        )

    # Per-pixel difference. Colour images are (H, W, C) and are averaged over
    # channels; grayscale JPEGs load as (H, W) and are used directly.
    diff = np.abs(gt_img - render_img)
    diff_gray = diff.mean(axis=2) if diff.ndim == 3 else diff

    # Statistics
    mean_diff = np.mean(diff_gray)
    max_diff = np.max(diff_gray)

    # File info
    gt_size = os.path.getsize(gt_path)
    height, width = gt_img.shape[:2]
    total_pixels = width * height

    # Visualize. cmap/vmin/vmax only affect 2-D (grayscale) images;
    # matplotlib ignores them for RGB input.
    img_cmap = "gray" if gt_img.ndim == 2 else None
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(gt_img, cmap=img_cmap, vmin=0, vmax=1)
    axes[0].set_title("Ground Truth", fontsize=14)
    axes[0].axis('off')

    axes[1].imshow(render_img, cmap=img_cmap, vmin=0, vmax=1)
    axes[1].set_title("2D Gaussians", fontsize=14)
    axes[1].axis('off')

    im = axes[2].imshow(diff_gray, cmap='hot', vmin=0, vmax=1)
    axes[2].set_title(f"Difference\n(Mean: {mean_diff:.4f})", fontsize=14)
    axes[2].axis('off')

    cbar = plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
    cbar.set_label('Difference', rotation=270, labelpad=20)

    plt.tight_layout()

    if save_results:
        comparison_path = os.path.join(output_dir, "comparison.png")
        plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
        print(f"💾 Saved: {comparison_path}")

    plt.show()

    # Format helpers
    def fmt_size(size_bytes):
        if size_bytes is None:
            return "N/A"
        if size_bytes < 1024 * 1024:
            return f"{size_bytes / 1024:.2f} KB"
        return f"{size_bytes / (1024 * 1024):.2f} MB"

    def fmt_ratio(num, denom):
        # A missing or zero-byte checkpoint has no meaningful ratio.
        if num is None or not denom:
            return "N/A"
        return f"{num / denom:.2f}x"

    # Print summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Input:              {input_image}")
    print(f"Experiment:         {exp_name}")
    print(f"Gaussians:          {num_gaussians}")
    print(f"Steps:              {max_steps}")
    print()
    print(f"Resolution:         {width}x{height} ({total_pixels:,} pixels)")
    print(f"Original size:      {fmt_size(gt_size)}")
    print(f"Model size:         {fmt_size(model_size)}")
    print(f"Compression:        {fmt_ratio(gt_size, model_size)}")
    print()
    print(f"Mean difference:    {mean_diff:.6f} ({mean_diff*100:.4f}%)")
    print(f"Max difference:     {max_diff:.6f} ({max_diff*100:.4f}%)")
    print(f">5% diff pixels:    {np.sum(diff_gray > 0.05) / diff_gray.size * 100:.2f}%")
    print("=" * 80)

    if save_results:
        summary_path = os.path.join(output_dir, "summary.txt")
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write(f"Experiment: {exp_name}\n")
            f.write(f"Input: {input_image}\n")
            f.write(f"Gaussians: {num_gaussians}\n")
            f.write(f"Steps: {max_steps}\n\n")
            f.write(f"Resolution: {width}x{height}\n")
            f.write(f"Original: {fmt_size(gt_size)}\n")
            f.write(f"Model: {fmt_size(model_size)}\n")
            f.write(f"Compression: {fmt_ratio(gt_size, model_size)}\n\n")
            f.write(f"Mean diff: {mean_diff:.6f}\n")
            f.write(f"Max diff: {max_diff:.6f}\n")
        print(f"💾 Saved: {summary_path}")

print("✅ Functions loaded!")

## Configuration

In [None]:
# Source image, given relative to /workspace/image-gs/media/
INPUT_IMAGE = "images/your_image.png"

# Run label; outputs go to results/{EXP_NAME}/
EXP_NAME = "test/demo"

# Model size and optimisation length
NUM_GAUSSIANS = 10000  # higher gives better quality (5k-30k recommended)
MAX_STEPS = 5000       # higher gives better convergence (3k-10k recommended)

# Echo the chosen settings as a framed table.
_rule = "=" * 60
for _line in (
    _rule,
    "CONFIGURATION",
    _rule,
    f"Input:       media/{INPUT_IMAGE}",
    f"Output:      results/{EXP_NAME}/",
    f"Gaussians:   {NUM_GAUSSIANS}",
    f"Steps:       {MAX_STEPS}",
    _rule,
):
    print(_line)

## Training

In [None]:
# Launch training with the values chosen in the Configuration cell.
_train_kwargs = dict(
    input_image=INPUT_IMAGE,
    exp_name=EXP_NAME,
    num_gaussians=NUM_GAUSSIANS,
    max_steps=MAX_STEPS,
)
train_image_gs(**_train_kwargs)

## View Results

In [None]:
# Compare the latest run with its ground truth and save the analysis files.
_view_kwargs = dict(
    exp_name=EXP_NAME,
    input_image=INPUT_IMAGE,
    num_gaussians=NUM_GAUSSIANS,
    max_steps=MAX_STEPS,
    save_results=True,
)
view_results(**_view_kwargs)

## Optional: Upscale Render

Render at higher resolution using the trained model.

In [None]:
# Example: render at a higher output height (e.g. twice the training image height for 2x)
# upscale_render(
#     input_image=INPUT_IMAGE,
#     exp_name=EXP_NAME,
#     num_gaussians=NUM_GAUSSIANS,
#     render_height=1024  # Adjust as needed
# )