# Dice Detection Training — GAN Augmented Dataset

## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!sudo apt-get update && sudo apt-get install -y libavif-dev libheif-dev
!pip install roboflow matplotlib seaborn tqdm numpy pillow

# Clone and setup the project
import os
if not os.path.exists('Dice-Detection'):
    !git clone https://github.com/Adr44mo/Dice-Detection.git
os.chdir('Dice-Detection')
!pip install -e .
print('\n✓ Setup complete!')

## 2. Imports

In [None]:
import torch
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os

# Import custom modules
from src.dataset import DiceDetectionDataset, collate_fn
from src.model import get_fasterrcnn_model, save_model_checkpoint
from src.training import train_one_epoch, evaluate, get_optimizer, get_lr_scheduler
from src.metrics import evaluate_map, print_metrics
from src.visualization import plot_training_history
from src.aug.annotation_manager import AnnotationManager

# Report the PyTorch build and, when CUDA is usable, the name of device 0.
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Configuration

In [None]:
# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================

# Hyperparameters read by the data loaders, the optimizer and the epoch loop.
BATCH_SIZE, NUM_WORKERS = 16, 8
NUM_EPOCHS = 5
LEARNING_RATE = 0.005

# The detector is trained with the default Faster R-CNN loss (no custom loss).

# =============================================================================
# ANNOTATION FILES  (GAN-augmented dataset)
# =============================================================================
# NOTE(review): the validation annotations also sit under train/annotations —
# confirm this matches the layout on disk.
_GAN_ANNOTATION_DIR = 'synthetic_coco_dataset/train/annotations'
TRAIN_ANNOTATION = _GAN_ANNOTATION_DIR + '/gan_train.coco.json'
VAL_ANNOTATION = _GAN_ANNOTATION_DIR + '/gan_val.coco.json'
TEST_ANNOTATION = 'test_balanced.coco.json'

# =============================================================================
# DATASET PATHS
# =============================================================================
# Image directories for the GAN-generated train/val splits.
# NOTE(review): both splits end in a 'train' sub-folder (val images resolve to
# <base>/val/train) — confirm against the dataset layout.
GAN_DATASET_BASE = os.path.join(os.curdir, 'synthetic_coco_dataset')
TRAIN_PATH, VAL_PATH = (
    os.path.join(GAN_DATASET_BASE, split_dir, 'train')
    for split_dir in ('train', 'val')
)

# The test split uses the original Roboflow data, which is downloaded separately.
# Point ROBOFLOW_DATASET_PATH elsewhere if that dataset lives at another location.
ROBOFLOW_DATASET_PATH = 'dice-2'
TEST_PATH = os.path.join(ROBOFLOW_DATASET_PATH, 'test')

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# =============================================================================
# PRINT SUMMARY
# =============================================================================
# The report is assembled line by line and emitted with a single print call;
# section headings carry a leading newline so each section is preceded by a
# blank line.
_rule = '=' * 60
_report_lines = [
    _rule,
    'GAN DATASET TRAINING CONFIGURATION',
    _rule,
    '\n[Annotation Files]',
    f'  Train: {TRAIN_ANNOTATION}',
    f'  Val:   {VAL_ANNOTATION}',
    f'  Test:  {TEST_ANNOTATION}',
    '\n[Dataset Paths]',
    f'  Train images: {TRAIN_PATH}',
    f'  Val images:   {VAL_PATH}',
    f'  Test images:  {TEST_PATH}',
    '\n[Training]',
    f'  Batch size: {BATCH_SIZE}',
    f'  Epochs:     {NUM_EPOCHS}',
    f'  LR:         {LEARNING_RATE}',
    f'  Device:     {DEVICE}',
    _rule,
]
print('\n'.join(_report_lines))

## 4. Load Annotations & Prepare Datasets

In [None]:
# Initialize Annotation Manager
anno_manager = AnnotationManager('./Annotations')

# Load annotations
train_anno = anno_manager.load_annotation_set(TRAIN_ANNOTATION)
val_anno = anno_manager.load_annotation_set(VAL_ANNOTATION)
test_anno = anno_manager.load_annotation_set(TEST_ANNOTATION)

# Print stats
for name, ann_file in [('Train', TRAIN_ANNOTATION), ('Val', VAL_ANNOTATION), ('Test', TEST_ANNOTATION)]:
    stats = anno_manager.get_dataset_stats(ann_file)
    print(f'{name}: {stats["num_images"]} images, {stats["num_annotations"]} annotations')

# Write annotations to dataset directories
import shutil  # not used in this cell; kept in case later cells rely on it


def _write_coco_annotations(directory, annotation):
    """Serialize `annotation` as `_annotations.coco.json` inside `directory`."""
    with open(os.path.join(directory, '_annotations.coco.json'), 'w') as f:
        json.dump(annotation, f)


# Train/val: create the directory if needed, then drop the annotation file in.
for path, anno in [(TRAIN_PATH, train_anno), (VAL_PATH, val_anno)]:
    os.makedirs(path, exist_ok=True)
    _write_coco_annotations(path, anno)

# The test set is optional: the dataset cell decides whether a test set is
# available by checking that TEST_PATH and its _annotations.coco.json exist.
# Creating the directory here would make that check pass even when the Roboflow
# images were never downloaded, so the test annotations are only written into a
# directory that is already present.
if os.path.isdir(TEST_PATH):
    _write_coco_annotations(TEST_PATH, test_anno)
else:
    print(f'\n⚠ Test image directory not found: {TEST_PATH} — test annotations were not written')

print('\n✓ Annotations copied to dataset directories')

## 5. Create Datasets & Data Loaders

In [None]:
# Datasets are built without augmentation — the GAN data is already augmented.
_COCO_FILE = '_annotations.coco.json'


def _open_split(root, split):
    """Instantiate a DiceDetectionDataset for `split` from the images under `root`."""
    return DiceDetectionDataset(
        root_dir=root,
        annotation_file=_COCO_FILE,
        split=split,
    )


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the configured batch size, worker count and collate_fn."""
    return data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
    )


train_dataset = _open_split(TRAIN_PATH, 'train')
val_dataset = _open_split(VAL_PATH, 'val')

# The test split is only used when both its directory and annotation file exist.
has_test_set = all(
    os.path.exists(candidate)
    for candidate in (TEST_PATH, os.path.join(TEST_PATH, _COCO_FILE))
)
test_dataset = _open_split(TEST_PATH, 'test') if has_test_set else None

# Dataset sizes
print(f'Training dataset:   {len(train_dataset)} images')
print(f'Validation dataset: {len(val_dataset)} images')
print(
    f'Test dataset:       {len(test_dataset)} images'
    if has_test_set
    else 'Test dataset:       Not available (will use validation for evaluation)'
)
print(f'Number of classes:  {train_dataset.num_classes}')

# Data loaders: only the training loader shuffles.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False) if has_test_set else None

# Batch counts
print(f'\nTraining batches:   {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')
if test_loader is not None:
    print(f'Test batches:       {len(test_loader)}')

## 6. Initialize Model

In [None]:
# Build the Faster R-CNN detector with one output per dataset class, requesting
# pretrained weights and three trainable backbone layers, then move it to DEVICE.
_detector_options = {
    'num_classes': train_dataset.num_classes,
    'pretrained': True,
    'trainable_backbone_layers': 3,
}
model = get_fasterrcnn_model(**_detector_options)
model.to(DEVICE)

# Optimizer at the configured learning rate, plus a scheduler configured with
# step_size=3 and gamma=0.1.
optimizer = get_optimizer(model, lr=LEARNING_RATE)
lr_scheduler = get_lr_scheduler(optimizer, step_size=3, gamma=0.1)

print(
    f'Model initialized on {DEVICE}',
    '  Using default Faster R-CNN detection loss',
    sep='\n',
)

## 7. Training Loop

In [None]:
# Training history: each list receives one entry per completed epoch.
history = {
    'train_loss': [],
    'val_loss': [],
    'learning_rate': []
}

# Lowest validation loss seen so far; any finite first-epoch loss beats infinity.
best_val_loss = float('inf')

# Checkpoint directory
CHECKPOINT_DIR = 'checkpoints_gan'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
config_str = 'gan'  # tag embedded in checkpoint file names and metadata

print('Starting training...')
print(f'Checkpoint directory: {CHECKPOINT_DIR}\n')

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    # Learning rate in effect for THIS epoch. It has to be read before
    # lr_scheduler.step(): reading it afterwards reports the next epoch's rate,
    # which mislabels every epoch that sits on a decay boundary.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Train (epochs are passed 1-based)
    train_metrics = train_one_epoch(
        model, optimizer, train_loader, DEVICE, epoch + 1
    )

    # Evaluate
    val_metrics = evaluate(model, val_loader, DEVICE)

    # Update learning rate (takes effect from the next epoch on)
    lr_scheduler.step()

    # The code below reads train_metrics['loss'], train_metrics['time'] and
    # val_metrics['val_loss'].
    train_loss = train_metrics['loss']
    val_loss = val_metrics['val_loss']

    # Record history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['learning_rate'].append(epoch_lr)

    # Print summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Learning Rate: {epoch_lr:.6f}")
    print(f"  Time: {train_metrics['time']:.2f}s")

    # Save best model: only on a strict improvement of the validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f'best_model_{config_str}.pth')
        save_model_checkpoint(
            model, optimizer, epoch + 1, val_loss,
            checkpoint_path,
            additional_info={
                'train_loss': train_loss,
                'config': config_str
            }
        )
        print('  ✓ New best model saved!')

    # Save latest checkpoint (same path every epoch)
    latest_path = os.path.join(CHECKPOINT_DIR, f'latest_model_{config_str}.pth')
    save_model_checkpoint(
        model, optimizer, epoch + 1, val_loss,
        latest_path
    )

print('\n' + '='*60)
print('Training completed!')
print('='*60)

## 8. Plot Training History

In [None]:
# Plot the per-epoch series collected in `history`, relabelled for display.
_curve_titles = {
    'train_loss': 'Training Loss',
    'val_loss': 'Validation Loss',
    'learning_rate': 'Learning Rate',
}
plot_training_history(
    {title: history[key] for key, title in _curve_titles.items()}
)