In [1]:
# Cell 1: Mount Drive and install dependencies
from google.colab import drive
drive.mount('/content/drive')

# Install required packages (should mostly be pre-installed)
!pip install torch torchvision torchaudio
!pip install matplotlib numpy pillow scikit-image opencv-python

Mounted at /content/drive


In [2]:
# Cell 2: Create the project folder layout on Google Drive
import os
from pathlib import Path

# Root folder of the project on the mounted Drive (created if missing).
project_dir = Path('/content/drive/MyDrive/ResearchProject')
project_dir.mkdir(parents=True, exist_ok=True)

# Output folders that training writes into: model checkpoints and logs.
for subdir_name in ('checkpoints', 'logs'):
    (project_dir / subdir_name).mkdir(exist_ok=True)

print("Project structure ready!", f"Project dir: {project_dir}", sep="\n")

Project structure ready!
Project dir: /content/drive/MyDrive/ResearchProject


In [3]:
# Cell 3: Import and setup
import random
import sys

import numpy as np
import torch

# Make the project's own modules (unet_denoiser, dataloader,
# unet_denoiser_training) importable from the Drive project folder created in
# Cell 2. The membership check keeps sys.path free of duplicates when this cell
# is re-run.
project_path = str(project_dir)
if project_path not in sys.path:
    sys.path.append(project_path)

from unet_denoiser import BlindVideoDenoiserUNet
from dataloader import BlindDenoiseDataset
from unet_denoiser_training import train, TemporalDataLoader, create_train_val_split

# Seed the RNGs that weight initialisation and data shuffling may draw from, so
# training runs are repeatable. 42 matches the seed of the train/val split in Cell 4.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print(f"GPU: {torch.cuda.get_device_name(0) if device == 'cuda' else 'CPU'}")

# Report GPU memory (important for choosing batch_size / resize_to in Cell 4).
# `total_memory` is the device's capacity, not what is currently free.
if device == 'cuda':
    print(f"GPU Memory Total: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

ModuleNotFoundError: No module named 'unet_denoiser'

In [None]:
# Cell 4: Split DAVIS into train/val and wrap each split in a temporal loader
davis_root = '/content/drive/MyDrive/ResearchProject/DAVISDataset'

# 80/20 split at the video level: all frames of a given video land on the same
# side, so validation never sees a video that was used for training.
train_dataset, val_dataset = create_train_val_split(
    davis_root_dir=davis_root,
    val_split=0.2,  # fraction of videos held out for validation
    seed=42,
)

# Loader settings shared by both splits; lower either if the GPU runs out of memory.
batch_size = 8
resize_to = (384, 384)  # every frame is resized to this size

shared_loader_args = {'batch_size': batch_size, 'resize_to': resize_to}
train_loader = TemporalDataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = TemporalDataLoader(val_dataset, shuffle=False, **shared_loader_args)

print(f"\nTrain loader batches per epoch: {len(train_loader)}")
print(f"Val loader batches per epoch: {len(val_loader)}")
print(f"Frame resolution: {resize_to}")

Total videos: 150
Train videos: 120
Val videos: 30
Train frames: 8714
Val frames: 2017

Train loader batches per epoch: 1089
Val loader batches per epoch: 252
Frame resolution: (384, 384)


In [None]:
# Cell 5: Build the blind video denoiser and train it
# NOTE(review): in_channels=9 presumably stacks 3 RGB frames (3 x 3 channels)
# from TemporalDataLoader -- confirm against the loader's output shape.
model = BlindVideoDenoiserUNet(
    in_channels=9,
    out_channels=3,
    base_channels=64,
    num_stages=3,
)

param_count = sum(param.numel() for param in model.parameters())
print(f"Total model parameters: {param_count:,}")

# Optimisation / loss settings forwarded to train().
train_settings = dict(
    num_epochs=100,
    initial_lr=1e-3,
    loss_type='combined',  # 'l1', 'l2', or 'combined'
    loss_alpha=0.7,        # Only used if loss_type='combined'. 0.7 = 70% L1, 30% L2
)

# NOTE(review): every run writes into the same checkpoints/ and logs/ folders,
# so runs with different loss_type/loss_alpha (whose loss values are on
# different scales) end up sharing one best_model.pt -- consider a per-run
# subfolder.
trained_model, logger = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_dir=str(project_dir / 'checkpoints'),
    log_dir=str(project_dir / 'logs'),
    **train_settings,
)

Total model parameters: 4,541,184
Epoch    1 | Train Loss: 0.064194 | Val Loss: 0.088483 | LR: 1.00e-03
  → Best model saved! (Val Loss: 0.088483)
Epoch    2 | Train Loss: 0.048143 | Val Loss: 0.050756 | LR: 1.00e-03
  → Best model saved! (Val Loss: 0.050756)
Epoch    3 | Train Loss: 0.046503 | Val Loss: 0.041066 | LR: 9.99e-04
  → Best model saved! (Val Loss: 0.041066)
Epoch    4 | Train Loss: 0.044451 | Val Loss: 0.049860 | LR: 9.98e-04
Epoch    5 | Train Loss: 0.043551 | Val Loss: 0.040767 | LR: 9.96e-04
  → Best model saved! (Val Loss: 0.040767)


In [1]:
# Cell 6: Plot training curves (run this periodically or after training)
# Cell 5 blocks the kernel while it trains, so this cell is often run in a
# fresh kernel or a second session where Cells 2-3 have not run. It therefore
# sets up the project path itself instead of relying on `project_dir` and the
# sys.path entry from those cells.
import sys
from pathlib import Path

import matplotlib.pyplot as plt

project_dir = Path('/content/drive/MyDrive/ResearchProject')
if not project_dir.is_dir():
    raise FileNotFoundError(
        f"{project_dir} not found - mount Google Drive (Cell 1) before plotting."
    )
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))

from unet_denoiser_training import TrainingLogger

# NOTE(review): this rebinds `logger`, replacing the one returned by train() in
# Cell 5. Presumably TrainingLogger reads the history saved under log_dir --
# confirm that constructing it does not reset existing logs.
logger = TrainingLogger(log_dir=str(project_dir / 'logs'))
logger.plot_metrics()  # This will save and display the curves

ModuleNotFoundError: No module named 'unet_denoiser_training'

In [None]:
# Cell 7: Restore the best checkpoint into `model` for inference
# NOTE(review): best_model.pt lives in the shared checkpoints/ folder and may
# come from a different run than the one trained in Cell 5. The recorded
# outputs in this notebook disagree: this cell reported epoch 1 / val loss
# 0.001884, while Cell 5's log shows the best as 0.040767 at epoch 5. Check the
# printed epoch/loss against the training log before trusting these weights.
best_ckpt_path = project_dir / 'checkpoints' / 'best_model.pt'
checkpoint = torch.load(str(best_ckpt_path), map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch layers such as dropout/batch-norm to inference behaviour

print(f"Best model loaded from epoch {checkpoint['epoch']}")
print(f"Best val loss: {checkpoint['val_loss']:.6f}")

Best model loaded from epoch 1
Best val loss: 0.001884
