# CombNet Pretraining on Google Colab

This notebook trains the VectorCombinationUNet (CombNet) to fix boundary divergence issues in div-free vector field inpainting.

**Steps:**
1. Upload your project code (275 KB ZIP!)
2. Generate training data on Colab GPU (~3 min)
3. Train model (100 epochs, ~2-3 hours on T4 GPU)
4. Download the trained model

**What to upload:**
- Just the project ZIP file: `diffusion_tiny_v2.zip` (275 KB)
- No ocean data needed - we generate synthetic div-free fields!

## 1. Setup Environment

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected! Go to Runtime > Change runtime type > Select GPU")

In [None]:
# Install dependencies (PyTorch is pre-installed on Colab)
!pip install -q tqdm pyyaml scipy matplotlib

## 2. Upload Project ZIP (Only 275 KB!)

In [None]:
# Upload your tiny project ZIP
from google.colab import files
import zipfile
import os

print("Please upload diffusion_tiny_v2.zip (275 KB)")
uploaded = files.upload()

# Extract the ZIP file
for filename in uploaded.keys():
    if filename.endswith('.zip'):
        print(f"Extracting {filename}...")
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('.')
        print("Extraction complete!")
        
# Navigate to the project directory (assumes the ZIP extracts to a folder with this name)
%cd diffusionInpaintingVectorFields
!pwd
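
The `%cd` above depends on the archive unpacking into a folder with that exact name. If the command fails, a small helper like the one below can list what the ZIP actually contains at its top level. This is a minimal sketch; the function name `top_level_entries` is hypothetical and not part of the project.

```python
import zipfile

def top_level_entries(zip_path):
    """Return the sorted top-level file and folder names inside a ZIP archive."""
    with zipfile.ZipFile(zip_path) as zf:
        # ZIP member names always use forward slashes, so the first path
        # component is the top-level entry.
        return sorted({name.split("/", 1)[0] for name in zf.namelist() if name})

# Example (after uploading the archive):
#   print(top_level_entries("diffusion_tiny_v2.zip"))
```

If the list contains a single folder, that is the name to pass to `%cd`.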

## 3. Verify Project Structure

In [None]:
# Verify key files exist
import os

required_files = [
    'data.yaml',
    'scripts/pretrain_combnet.py',
    'scripts/generate_combnet_data.py',
    'ddpm/vector_combination/combiner_unet.py',
    'ddpm/vector_combination/combination_loss.py',
]

print("Checking required files:")
all_good = True
for f in required_files:
    exists = os.path.exists(f)
    status = "✅" if exists else "❌"
    print(f"{status} {f}")
    if not exists:
        all_good = False
        
if all_good:
    print("\n✅ All required files present!")
else:
    print("\n❌ Some files are missing. Please check your upload.")

## 4. Generate Training Data

Creates 20,000 training samples from synthetic div-free fields. Takes ~2-3 minutes on GPU.

**No ocean data needed!** We generate pure synthetic fields.
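
For background, a standard way to build a 2D divergence-free field is to differentiate a scalar stream function $\psi$. Whether `scripts/generate_combnet_data.py` uses this construction is not shown in this notebook, so treat it as an illustrative assumption rather than a description of the script.

```latex
u = \frac{\partial \psi}{\partial y}, \qquad
v = -\frac{\partial \psi}{\partial x}
\quad\Longrightarrow\quad
\nabla \cdot (u, v)
  = \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y}
  = \frac{\partial^2 \psi}{\partial x\,\partial y}
  - \frac{\partial^2 \psi}{\partial y\,\partial x}
  = 0
```

Any smooth $\psi$ therefore yields a field with zero divergence, which is why no measured data is required to produce training samples.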

**Note**: Reduced to 20k samples (from 40k) to avoid OOM on Colab free tier.
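
To see why the sample count matters for memory, the sketch below estimates the size of a single dense tensor. The shape is an assumption: 2 channels (u, v) on the 64x128 grid used by the test forward pass later in this notebook, stored as float32. The real dataset file holds several tensors (for example `known` and `mask`), so the total footprint is a multiple of this figure.

```python
def tensor_gib(n_samples, channels, height, width, bytes_per_value=4):
    """Size in GiB of one dense tensor of shape [n_samples, channels, height, width]."""
    return n_samples * channels * height * width * bytes_per_value / 1024**3

# Assumed shape: 2-channel float32 fields on a 64x128 grid.
for n in (20_000, 40_000):
    print(f"{n:>6} samples: {tensor_gib(n, 2, 64, 128):.2f} GiB per tensor")
```

Under these assumptions one tensor takes about 1.22 GiB at 20k samples and twice that at 40k, so halving the sample count halves the memory needed per stored tensor.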

In [None]:
# Generate training data (completely synthetic!)
!python scripts/generate_combnet_data.py

In [None]:
# Verify dataset was created
import os
import torch
dataset_path = "results/combnet_training_data.pt"

if os.path.exists(dataset_path):
    dataset = torch.load(dataset_path, weights_only=False)
    print("✅ Training dataset generated successfully!")
    print(f"\nDataset contents:")
    for key, val in dataset.items():
        print(f"  {key}: {val.shape}")
    print(f"\nTotal samples: {len(dataset['known'])}")
    
    # Show some statistics
    print(f"\nSample statistics:")
    print(f"  Field range: [{dataset['known'].min():.3f}, {dataset['known'].max():.3f}]")
    print(f"  Mask coverage: {dataset['mask'].mean():.1%}")
else:
    print("❌ Dataset generation failed!")

## 5. Run Training

Train for 100 epochs (~2-3 hours on T4 GPU, ~1 hour on A100).

In [None]:
# Run training with tee to save logs
!python scripts/pretrain_combnet.py --epochs 100 --batch_size 32 --lr 0.001 2>&1 | tee combnet_colab_training.log

## 6. Monitor Training Progress

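Colab runs one cell at a time, so this cell executes only after the training cell finishes or is interrupted. The training cell already streams its output live through `tee`; use this cell to review the tail of the saved log afterwards.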

In [None]:
# Check last few lines of log
!tail -30 combnet_colab_training.log
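
Because Colab executes cells one at a time, one way to inspect the log during a run is to start training as a background process instead of with `!python`. The sketch below is an assumption about how this could be done with the standard library; `launch_in_background` is a hypothetical helper, not part of the project.

```python
import subprocess
import sys

def launch_in_background(cmd, log_path):
    """Start cmd without blocking, writing its stdout and stderr to log_path."""
    with open(log_path, "w") as log_file:
        # The child process gets its own handle to the log file, so closing
        # ours when the with-block exits does not interrupt the logging.
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)

# Hypothetical usage with the flags from the training cell. "-u" asks Python
# not to buffer output, so the log file updates promptly:
#   proc = launch_in_background(
#       [sys.executable, "-u", "scripts/pretrain_combnet.py",
#        "--epochs", "100", "--batch_size", "32", "--lr", "0.001"],
#       "combnet_colab_training.log",
#   )
#   proc.poll()  # None while training is still running, else the exit code
```

The cell returns immediately, so the `tail` cell can then be run repeatedly. The background process still ends if the runtime disconnects.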

## 7. Plot Training Loss

In [None]:
import torch
import matplotlib.pyplot as plt

# Load the saved model to get training history
checkpoint = torch.load('ddpm/Trained_Models/pretrained_combnet.pt', map_location='cpu', weights_only=False)

if 'train_losses' in checkpoint:
    losses = checkpoint['train_losses']
    
    plt.figure(figsize=(10, 5))
    plt.plot(losses, linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('CombNet Training Loss', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('training_loss.png', dpi=150)
    plt.show()
    
    print(f"\nTraining Summary:")
    print(f"  Total epochs: {len(losses)}")
    print(f"  Initial loss: {losses[0]:.6f}")
    print(f"  Final loss: {losses[-1]:.6f}")
    print(f"  Best loss: {checkpoint['loss']:.6f}")
    print(f"  Improvement: {(1 - losses[-1]/losses[0])*100:.1f}%")
else:
    print("No training history found in checkpoint")

## 8. Download Trained Model

In [None]:
import os
from google.colab import files

# Download the trained model
print("Downloading trained model...")
files.download('ddpm/Trained_Models/pretrained_combnet.pt')

# Also download the training log
print("Downloading training log...")
files.download('combnet_colab_training.log')

# Download the loss plot
if os.path.exists('training_loss.png'):
    print("Downloading loss plot...")
    files.download('training_loss.png')

print("\n✅ Downloads complete!")
print("\nPlace the pretrained_combnet.pt file in:")
print("  ddpm/Trained_Models/pretrained_combnet.pt")
print("\nThen set in data.yaml:")
print("  use_comb_net: auto  # or 'yes' to force using it")

## 9. Quick Model Test (Optional)

Test that the model loads correctly.

In [None]:
import torch
from ddpm.vector_combination.combiner_unet import VectorCombinationUNet

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VectorCombinationUNet(n_channels=4, n_classes=2).to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only runtime
checkpoint = torch.load('ddpm/Trained_Models/pretrained_combnet.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("✅ Model loaded successfully!")
print(f"\nModel config:")
for key, value in checkpoint['config'].items():
    print(f"  {key}: {value}")

# Test forward pass
test_input = torch.randn(1, 4, 64, 128).to(device)  # [batch, channels, H, W]
with torch.no_grad():
    output = model(test_input)
    
print(f"\nTest forward pass:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Output range: [{output.min():.4f}, {output.max():.4f}]")
print("\n✅ Forward pass completed without errors!")

## Notes

**Free Tier Limitations:**
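- Sessions have a maximum lifetime (roughly 12 hours; the exact limit varies), and idle sessions are disconnected much sooner
- Files on the runtime's disk are lost when the runtime is recycled, so download any checkpoint you want to keep
- The model is saved every epoch, so you can resume with the `--resume` flag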

**To Resume Training:**
```python
!python scripts/pretrain_combnet.py --epochs 100 --batch_size 32 --resume
```
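
Resuming only works if the checkpoint still exists, and the runtime's local disk does not survive a recycled session. A minimal sketch for copying the checkpoint to a persistent location such as a mounted Google Drive follows. The helper name `backup_checkpoint` and the backup folder are hypothetical, and which file `--resume` reads is assumed to be the checkpoint path used elsewhere in this notebook.

```python
import os
import shutil

def backup_checkpoint(src_path, backup_dir):
    """Copy a checkpoint file into backup_dir and return the destination path."""
    os.makedirs(backup_dir, exist_ok=True)
    dst_path = os.path.join(backup_dir, os.path.basename(src_path))
    shutil.copy2(src_path, dst_path)
    return dst_path

# In Colab, mount Drive first (this prompts for authorization):
#   from google.colab import drive
#   drive.mount('/content/drive')
# Then, for example (the backup folder name is hypothetical):
#   backup_checkpoint('ddpm/Trained_Models/pretrained_combnet.pt',
#                     '/content/drive/MyDrive/combnet_backups')
```

In a new session, copy the file back to `ddpm/Trained_Models/` before running the resume command.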

**Batch Size Tips:**
- T4 GPU (free tier): Use batch_size=16 or 32
- A100 GPU (Colab Pro): Can use batch_size=64 or higher
- If you get OOM errors, reduce batch size to 8

**Key Improvements:**
- Only 275 KB upload (vs 2.6 GB!)
- No ocean data needed - fully synthetic training data
- Generates 20k samples in ~2-3 minutes on GPU
- Total time: ~2-3 hours on free T4 GPU