# UVCGAN Training on Google Colab

This notebook trains UVCGAN on the BRATS19 dataset using Colab's free GPU.

**Before starting:**
1. Enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU
2. Upload your BRATS19 dataset to Google Drive or prepare to upload it here


## Step 0: Quick Pipeline Test (Optional - Recommended First!)


In [None]:
# Quick test to verify that the pipeline works before full training.
# It runs 1 epoch each of pretraining and training (~5-10 minutes).
# The script needs the dependencies, repositories, and dataset from
# Steps 1-4, so run those cells first, then come back here.
# Uncomment to run:

# !python scripts/brats19/quick_test.py \
#     --pretrain-epochs 1 \
#     --train-epochs 1 \
#     --batch-size 4

print("Quick test cell ready. Uncomment the command above to run a test.")
print("This verifies the entire pipeline works before committing GPU hours to full training.")
print("\nOr run the test script with its default settings:")
print("  python scripts/brats19/quick_test.py")


## Step 1: Install Dependencies


In [None]:
# Install PyTorch with CUDA support
# Using CUDA 11.8 for compatibility (works with most Colab GPUs)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install other dependencies
!pip install numpy pillow matplotlib tqdm gitpython

# Verify PyTorch installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
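
Since later steps depend on every package installed above, it can help to list what actually landed in the environment. The following is a minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_versions` is made up for this example, and the package list simply mirrors the `pip install` lines above.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Package names copied from the pip install commands in the cell above.
required = ["torch", "torchvision", "torchaudio",
            "numpy", "pillow", "matplotlib", "tqdm", "gitpython"]

for name, version in installed_versions(required).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

Any line reporting `NOT INSTALLED` points at a failed `pip install` that is worth rerunning before moving on.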


## Step 2: Clone and Install Repositories


In [None]:
# Clone uvcgan4slats
!git clone https://github.com/LS4GAN/uvcgan4slats.git
%cd uvcgan4slats

# Clone and install toytools
!git clone https://github.com/LS4GAN/toytools
%cd toytools
!pip install -e .
%cd ..

# Install uvcgan4slats
!pip install -e .


## Step 2.5: Apply Fixes (Scheduler, Grayscale, Config)


## Step 3: Mount Google Drive (Optional - if dataset is on Drive)


In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Copy dataset from Drive (modify path as needed)
# !cp -r /content/drive/MyDrive/brats19 /content/brats19
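
As an alternative to the commented-out `cp` command, the copy can be done in Python so that a missing source directory fails loudly and a second run does not copy everything again. This is only a sketch: `copy_dataset` is a hypothetical helper, and the paths in the usage comment are the ones from the cell above.

```python
import shutil
from pathlib import Path


def copy_dataset(src, dst):
    """Copy the directory tree at src to dst.

    Returns True if a copy was made, False if dst already exists.
    Raises FileNotFoundError if src is not a directory.
    """
    src, dst = Path(src), Path(dst)
    if not src.is_dir():
        raise FileNotFoundError(f"Source dataset directory not found: {src}")
    if dst.exists():
        return False  # already copied; leave the existing copy untouched
    shutil.copytree(src, dst)
    return True


# Usage (paths taken from the cp command above; adjust as needed):
# copy_dataset('/content/drive/MyDrive/brats19', '/content/brats19')
```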


## Step 4: Set Environment Variables


In [None]:
import os

# Set data and output directories
os.environ['UVCGAN_DATA'] = '/content'
os.environ['UVCGAN_OUTDIR'] = '/content/outputs'

# Verify
print(f"Data directory: {os.environ['UVCGAN_DATA']}")
print(f"Output directory: {os.environ['UVCGAN_OUTDIR']}")

# Check if dataset exists
if os.path.exists('/content/brats19'):
    print("✓ Dataset found")
    !ls -la /content/brats19
else:
    print("✗ Dataset not found. Please upload or mount your dataset.")
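
The check above only confirms that the directory exists. A rough way to see whether the copy is complete is to count files per subdirectory, as sketched below. The helper `count_files_by_subdir` is a made-up name, and no particular BRATS19 folder layout is assumed; the code reports whatever it finds.

```python
import os


def count_files_by_subdir(root):
    """Return {relative_subdir: file_count} for each directory under root that holds files."""
    counts = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        if filenames:
            counts[os.path.relpath(dirpath, root)] = len(filenames)
    return counts


# The dataset is expected at $UVCGAN_DATA/brats19, matching the cell above.
data_root = os.path.join(os.environ.get('UVCGAN_DATA', '/content'), 'brats19')

if os.path.isdir(data_root):
    for subdir, n_files in sorted(count_files_by_subdir(data_root).items()):
        print(f"{subdir}: {n_files} files")
else:
    print(f"Dataset directory not found: {data_root}")
```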


## Step 5: Verify GPU Availability


In [None]:
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
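
Because the batch size in the next steps depends on available GPU memory, it may be useful to print that as well. The sketch below uses `torch.cuda.get_device_properties` and `torch.cuda.mem_get_info`; the helper `bytes_to_gib` is introduced here purely for formatting.

```python
def bytes_to_gib(n_bytes):
    """Convert a byte count to gibibytes (GiB)."""
    return n_bytes / 1024**3


try:
    import torch
except ImportError:  # keeps the helper usable where PyTorch is absent
    torch = None

if torch is not None and torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    print(f"GPU: {props.name}")
    print(f"Total memory: {bytes_to_gib(props.total_memory):.1f} GiB")
    print(f"Free memory:  {bytes_to_gib(free_bytes):.1f} GiB")
else:
    print("No CUDA device visible; check the runtime type.")
```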


## Step 6: Run Pretraining (Optional but Recommended)


In [None]:
# Pretrain generators
# Adjust batch_size to fit GPU memory: start at 32 and drop to 16 (or lower) on out-of-memory errors
!python scripts/brats19/pretrain_brats19.py --gen uvcgan --batch_size 32
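
The manual "retry with a smaller batch size" loop can be automated along the following lines. This is a sketch under assumptions: `run_with_fallback` and `pretrain` are hypothetical helpers, the batch sizes 32, 16, 8 are illustrative, and any non-zero exit code triggers a retry, so a failure unrelated to memory would also be retried.

```python
import subprocess
import sys


def run_with_fallback(batch_sizes, run):
    """Try each batch size in order; return the first for which run(bs) == 0, else None."""
    for batch_size in batch_sizes:
        if run(batch_size) == 0:
            return batch_size
    return None


def pretrain(batch_size):
    """Launch the pretraining script from the cell above and return its exit code."""
    cmd = [sys.executable, 'scripts/brats19/pretrain_brats19.py',
           '--gen', 'uvcgan', '--batch_size', str(batch_size)]
    return subprocess.run(cmd).returncode


# Uncomment to run instead of the single command above:
# used = run_with_fallback([32, 16, 8], pretrain)
# print(f"Pretraining succeeded with batch size {used}" if used else "All batch sizes failed")
```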


## Step 7: Run Main Training


In [None]:
# Train translation model
# If you skipped pretraining (Step 6), add the --no-pretrain flag
!python scripts/brats19/train_brats19.py --gen uvcgan --labmda-cycle 1.0 --lr-gen 1e-5 --lr-disc 5e-5
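
When trying several hyperparameter settings, assembling the command in Python avoids editing the shell line by hand. The sketch below is illustrative: `build_train_command` is a hypothetical helper, and every flag is copied verbatim from the command and comment above without checking it against the script's argument parser.

```python
import shlex


def build_train_command(lambda_cycle=1.0, lr_gen=1e-5, lr_disc=5e-5, pretrained=True):
    """Return the training command as an argument list."""
    cmd = [
        'python', 'scripts/brats19/train_brats19.py',
        '--gen', 'uvcgan',
        # Flag spelled exactly as in the command above.
        '--labmda-cycle', str(lambda_cycle),
        '--lr-gen', str(lr_gen),
        '--lr-disc', str(lr_disc),
    ]
    if not pretrained:
        # Flag named in the cell comment above for runs without pretraining.
        cmd.append('--no-pretrain')
    return cmd


# Print the command for a run that skips pretraining.
print(shlex.join(build_train_command(pretrained=False)))
```

The resulting list can be passed directly to `subprocess.run`, which sidesteps shell quoting issues.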


## Step 8: Visualize Results (After Training)


In [None]:
# Evaluate the most recent training checkpoint and display sample images
import glob
import os
import subprocess
import sys
from IPython.display import Image, display

# Collect checkpoint directories; guard against a missing output directory,
# since os.listdir raises FileNotFoundError if training has not run yet
checkpoint_base = '/content/outputs/brats19'
checkpoint_dirs = []
if os.path.isdir(checkpoint_base):
    checkpoint_dirs = [d for d in os.listdir(checkpoint_base) if d.startswith('model_m')]

if checkpoint_dirs:
    # Use the training checkpoint (not pretrain)
    train_checkpoints = [d for d in checkpoint_dirs if 'train' in d]
    if train_checkpoints:
        # sorted() picks the lexicographically last directory name
        checkpoint_dir = os.path.join(checkpoint_base, sorted(train_checkpoints)[-1])
        print(f"Using checkpoint: {checkpoint_dir}")

        # Evaluate and visualize with the same interpreter that runs this notebook
        result = subprocess.run([
            sys.executable, 'scripts/brats19/eval_and_visualize.py',
            checkpoint_dir,
            '--n-samples', '10',
            '--split', 'test'
        ], capture_output=True, text=True)

        print(result.stdout)
        if result.stderr:
            print("Errors:", result.stderr)

        # Display sample visualizations
        vis_dir = os.path.join(checkpoint_dir, 'visualizations', 'fake_vs_real')
        if os.path.exists(vis_dir):
            images = sorted(glob.glob(os.path.join(vis_dir, '*.png')))
            print(f"\nFound {len(images)} visualization images")

            for img_path in images[:5]:  # Show first 5
                print(f"\n{os.path.basename(img_path)}:")
                display(Image(img_path))
        else:
            print("Visualizations not found. Make sure eval_and_visualize.py completed successfully.")
    else:
        print("No training checkpoints found. Run training first.")
else:
    print("No checkpoints found. Run training first.")


## Step 9: Save Results to Google Drive


In [None]:
# Save outputs to Drive (requires Drive to be mounted first, see Step 3)
!mkdir -p /content/drive/MyDrive/uvcgan_outputs
!cp -r /content/outputs /content/drive/MyDrive/uvcgan_outputs/
print("Results saved to Google Drive!")
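
If Drive is not mounted, the shell commands above create an ordinary local folder under `/content/drive` and the success message is still printed. A Python version can refuse to copy in that case. The sketch below is one way to do it; `save_outputs` is a hypothetical helper, and the paths in the usage comment are the ones from the cell above.

```python
import os
import shutil


def save_outputs(src, drive_root, dest_name='uvcgan_outputs'):
    """Copy src into drive_root/dest_name/<basename of src> and return the destination.

    Raises RuntimeError if drive_root does not exist (Drive not mounted).
    """
    if not os.path.isdir(drive_root):
        raise RuntimeError(f"{drive_root} not found; mount Google Drive first.")
    dest = os.path.join(drive_root, dest_name, os.path.basename(os.path.normpath(src)))
    # dirs_exist_ok lets a rerun refresh an earlier copy (Python 3.8+)
    shutil.copytree(src, dest, dirs_exist_ok=True)
    return dest


# Usage (same source and destination as the shell commands above):
# print("Saved to", save_outputs('/content/outputs', '/content/drive/MyDrive'))
```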
