# SV-SCN Training Notebook

Train the Single-View Shape Completion Network on Google Colab.

**Requirements:**
- GPU runtime (T4/V100/A100)
- ~50GB disk space for data
- 6-12 hours for full training

**Quick Start:**
1. Enable GPU: Runtime → Change runtime type → GPU
2. Place the project code in `/content/frozo-3d-model` (clone or upload it in the setup cell), then run all cells in order
3. Checkpoints are saved to Google Drive (`MyDrive/svscn/checkpoints`)

## 1. Setup Environment

In [None]:
# Check GPU
# Runs the nvidia-smi shell command so its GPU status report appears in the
# cell output; see Quick Start step 1 if no GPU is reported.
!nvidia-smi

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Create project directory in Drive
!mkdir -p /content/drive/MyDrive/svscn
!mkdir -p /content/drive/MyDrive/svscn/checkpoints
!mkdir -p /content/drive/MyDrive/svscn/data

In [None]:
# Clone repository (or upload your code)
# NOTE: as written, this cell only creates an empty /content/frozo-3d-model
# directory and changes into it. The later `from svscn...` imports need the
# project code to be placed there first, using one of the options below.
# Option 1: From GitHub (uncomment and replace the placeholder URL)
# !git clone https://github.com/yourusername/frozo-3d-model.git

# Option 2: Upload from local
# Use Colab file upload or copy from Drive
!mkdir -p /content/frozo-3d-model
%cd /content/frozo-3d-model

In [None]:
# Install dependencies
!pip install torch>=2.0.0 numpy>=1.24.0 open3d>=0.17.0 trimesh>=4.0.0 \
    objaverse>=0.1.7 tqdm>=4.65.0 tensorboard>=2.14.0 scipy

## 2. Prepare Dataset

In [None]:
import os
import sys

# Put the project checkout first on the import path
sys.path.insert(0, '/content/frozo-3d-model')

# Location of the combined dataset on Drive
DATA_DIR = '/content/drive/MyDrive/svscn/data/combined'

# The metadata file inside DATA_DIR is used as the "dataset exists" marker
metadata_path = f'{DATA_DIR}/dataset_metadata.json'
dataset_ready = os.path.exists(metadata_path)
print('Dataset found in Drive!' if dataset_ready else 'Dataset not found. Will prepare...')

In [None]:
# Option A: placeholder data, for quick testing of the pipeline
USE_PLACEHOLDER = True  # Set to False for real data

if USE_PLACEHOLDER:
    from pathlib import Path
    from svscn.data.shapenet import download_shapenet_sample

    # Placeholder data is written to the local runtime disk, under /content/data
    LOCAL_DATA = Path('/content/data')
    shapenet_dir = LOCAL_DATA.joinpath('shapenet')
    download_shapenet_sample(shapenet_dir)
    print('Placeholder data created!')

In [None]:
# Option B: Download Objaverse data (takes ~1-2 hours)
# The guard below makes this cell a no-op while USE_PLACEHOLDER is True

if not USE_PLACEHOLDER:
    from pathlib import Path
    from svscn.data.dataset_manager import prepare_combined_dataset

    combined_dir = Path(DATA_DIR)
    # config=None -> uses default config
    prepare_combined_dataset(output_dir=combined_dir, config=None)

In [None]:
# Preprocess meshes to point clouds
from pathlib import Path
from svscn.data.preprocess import process_dataset

# Placeholder runs use /content/data; real runs use sub-folders of DATA_DIR
input_dir, output_dir = (
    (Path('/content/data/shapenet'), Path('/content/data/processed'))
    if USE_PLACEHOLDER
    else (Path(DATA_DIR) / 'raw', Path(DATA_DIR) / 'pointclouds')
)

process_dataset(input_dir, output_dir, num_points=8192)
print(f'Point clouds saved to {output_dir}')

In [None]:
# Generate training pairs (partial + full) from the full point clouds
from pathlib import Path
from svscn.data.augment import process_to_training_data

# Pick the folder pair that matches the data source chosen earlier
if USE_PLACEHOLDER:
    pairs_root = Path('/content/data')
    full_dir, train_dir = pairs_root / 'processed', pairs_root / 'training'
else:
    pairs_root = Path(DATA_DIR)
    full_dir, train_dir = pairs_root / 'pointclouds', pairs_root / 'training'

num_pairs = process_to_training_data(
    full_clouds_dir=full_dir, output_dir=train_dir, views_per_object=3
)
print(f'Created {num_pairs} training pairs')

## 3. Create Data Loaders

In [None]:
from pathlib import Path

import torch
from svscn.data.dataset import FurnitureDataset, create_data_loaders

# Training pairs: /content/data for placeholder runs, under DATA_DIR otherwise
TRAIN_DATA = Path('/content/data/training') if USE_PLACEHOLDER else Path(DATA_DIR) / 'training'

# One dataset per split
train_dataset = FurnitureDataset(TRAIN_DATA, split='train')
val_dataset = FurnitureDataset(TRAIN_DATA, split='val')

# Both loaders share batch size and worker count; only the training loader shuffles
loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')

In [None]:
# Plot the first training pair: partial input on the left, full cloud on the right
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sample = train_dataset[0]
partial = sample['partial'].numpy()
full = sample['full'].numpy()

fig = plt.figure(figsize=(12, 5))

# Left panel: partial cloud (columns 0..2 are used as x, y, z)
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
px, py, pz = partial[:, 0], partial[:, 1], partial[:, 2]
ax1.scatter(px, py, pz, s=1, c='blue')
ax1.set_title(f'Partial ({partial.shape[0]} points)')

# Right panel: full cloud, drawn with smaller markers
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
fx, fy, fz = full[:, 0], full[:, 1], full[:, 2]
ax2.scatter(fx, fy, fz, s=0.5, c='green')
ax2.set_title(f'Full ({full.shape[0]} points)')

plt.tight_layout()
plt.show()

## 4. Initialize Model

In [None]:
import torch
from svscn.models import SVSCN

# Use the GPU when one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network, then place it on the chosen device
model = SVSCN(num_classes=3, input_points=2048, output_points=8192)
model = model.to(device)

# Count only parameters that require gradients
trainable = (weight for weight in model.parameters() if weight.requires_grad)
num_params = sum(weight.numel() for weight in trainable)

print('Model: SVSCN')
print(f'Parameters: {num_params:,}')
print(f'Device: {device}')

## 5. Training

In [None]:
from pathlib import Path
from svscn.training import Trainer

# Checkpoints are written to Drive for persistence; logs go to /content/logs
CHECKPOINT_DIR = Path('/content/drive/MyDrive/svscn/checkpoints')
LOG_DIR = Path('/content/logs')

# Collect the trainer arguments in one place, then construct it
trainer_kwargs = {
    'model': model,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'checkpoint_dir': CHECKPOINT_DIR,
    'log_dir': LOG_DIR,
    'device': device,
}
trainer = Trainer(**trainer_kwargs)

print('Trainer initialized!')

In [None]:
# TensorBoard (run in separate cell)
# Loads the tensorboard notebook extension and points it at /content/logs,
# the same directory the trainer cell passes to Trainer as log_dir.
%load_ext tensorboard
%tensorboard --logdir /content/logs

In [None]:
# Train!
# Placeholder data gets a short run; real data gets the full epoch count
if USE_PLACEHOLDER:
    EPOCHS = 10
else:
    EPOCHS = 150

summary = trainer.train(epochs=EPOCHS)

print('\nTraining Complete!')
best_val_loss = summary["best_val_loss"]
print(f'Best validation loss: {best_val_loss:.6f}')

## 6. Evaluate Results

In [None]:
# Load best model
best_ckpt = CHECKPOINT_DIR / 'best.pt'

# map_location remaps the stored tensors onto the current device. Without it,
# a checkpoint written on a GPU runtime fails to load once the notebook is
# running on CPU, a case `device` already allows for.
checkpoint = torch.load(best_ckpt, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode before evaluation

print(f'Loaded best model (epoch {checkpoint["epoch"]})')

In [None]:
# Show input, prediction and ground truth for one validation sample
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# First validation sample, with a leading batch dimension added for the model
sample = val_dataset[0]
partial = sample['partial'].unsqueeze(0).to(device)
full = sample['full'].numpy()
class_id = sample['class_id'].unsqueeze(0).to(device)

# Forward pass without gradient tracking
with torch.no_grad():
    pred = model(partial, class_id)
pred = pred.cpu().numpy()[0]

# Host-side copy of the input cloud, taken once for plotting
partial_np = partial.cpu().numpy()[0]

fig = plt.figure(figsize=(15, 5))

ax1 = fig.add_subplot(1, 3, 1, projection='3d')
ax1.scatter(partial_np[:, 0], partial_np[:, 1], partial_np[:, 2], s=1, c='blue')
ax1.set_title('Input (Partial)')

ax2 = fig.add_subplot(1, 3, 2, projection='3d')
ax2.scatter(pred[:, 0], pred[:, 1], pred[:, 2], s=0.5, c='red')
ax2.set_title('Prediction')

ax3 = fig.add_subplot(1, 3, 3, projection='3d')
ax3.scatter(full[:, 0], full[:, 1], full[:, 2], s=0.5, c='green')
ax3.set_title('Ground Truth')

plt.tight_layout()
plt.show()

In [None]:
# Compute metrics on validation set
from svscn.models.losses import chamfer_distance, coverage_ratio

# Running sums over all validation samples
total_cd = 0
total_coverage = 0
num_samples = 0

model.eval()
with torch.no_grad():
    for batch in val_loader:
        partial = batch['partial'].to(device)
        full = batch['full'].to(device)
        class_id = batch['class_id'].to(device)

        pred = model(partial, class_id)

        # reduce='none' is requested so the values can be summed here and
        # divided by the sample count after the loop
        cd = chamfer_distance(pred, full, reduce='none')
        # NOTE(review): the averaging below assumes coverage_ratio returns one
        # value per sample -- confirm against svscn.models.losses
        cov = coverage_ratio(pred, full)

        total_cd += cd.sum().item()
        total_coverage += cov.sum().item()
        num_samples += len(batch['partial'])

# An empty validation loader would otherwise surface as a bare
# ZeroDivisionError in the averages below; fail with an actionable message.
if num_samples == 0:
    raise ValueError('Validation loader yielded no samples; cannot compute metrics.')

avg_cd = total_cd / num_samples
avg_cov = total_coverage / num_samples

print(f'Validation Metrics:')
print(f'  Average Chamfer Distance: {avg_cd:.6f}')
print(f'  Average Coverage: {avg_cov:.4f}')

## 7. Export Model

In [None]:
# Save final model for inference
from datetime import datetime

# The export file name carries today's date (YYYYMMDD)
version = datetime.now().strftime('%Y%m%d')
export_path = CHECKPOINT_DIR / f'sv_scn_v0.1.0_{version}.pt'

# Bundle the weights with the model settings and the validation metrics
export_metrics = {
    'chamfer_distance': avg_cd,
    'coverage': avg_cov,
}
export_payload = {
    'model_state_dict': model.state_dict(),
    'num_classes': 3,
    'input_points': 2048,
    'output_points': 8192,
    'version': 'sv_scn_v0.1.0',
    'metrics': export_metrics,
}
torch.save(export_payload, export_path)

print(f'Model exported to: {export_path}')

In [None]:
# Download model from Colab
# Requests a download of the exported file; export_path is set by the export
# cell and is converted to str before being passed to files.download.
from google.colab import files
files.download(str(export_path))