# Traffic Object Detection - Model Training

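This notebook demonstrates how to train a traffic object detection model. It loads the RetinaNet configuration by default; to train a different architecture (for example Deformable DETR), point `config_path` in the Configuration section at that model's configuration file.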

In [None]:
import os
import sys

import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Add src to path
sys.path.append('../src')

from datasets import TrafficDataset, get_transforms
from models import create_model
from train import Trainer
from utils.config import load_config, create_default_config, save_config
from utils.logger import setup_logger
from utils.visualization import plot_training_curves

## Configuration

In [None]:
# Resolve the configuration: read the YAML file when it is present, otherwise
# use whatever create_default_config() returns.
config_path = "../configs/retinanet_config.yaml"
if not os.path.exists(config_path):
    print("Config file not found, creating default configuration")
    config = create_default_config()
else:
    config = load_config(config_path)

# Notebook-scale overrides applied on top of the loaded configuration
config.training.epochs = 10      # short run for the demo
config.dataset.batch_size = 4    # small batches
config.dataset.image_size = 512

# Echo the effective settings so the run is self-describing
print(f"Model: {config.model.name}")
print(f"Epochs: {config.training.epochs}")
print(f"Batch size: {config.dataset.batch_size}")
print(f"Learning rate: {config.training.learning_rate}")

## Dataset Preparation

In [None]:
# Dataset paths (relative to the notebook's working directory)
dataset_root = "../data/traffic"
images_dir = os.path.join(dataset_root, "images")
train_annotations = os.path.join(dataset_root, "train_annotations.json")
val_annotations = os.path.join(dataset_root, "val_annotations.json")

# Check every path that is later handed to TrafficDataset, not only the
# training annotations, so a partially downloaded dataset is reported here
# rather than surfacing as an error in a later cell.
required_paths = (
    ("Training annotations", train_annotations),
    ("Validation annotations", val_annotations),
    ("Images directory", images_dir),
)
missing_paths = [
    (label, path) for label, path in required_paths if not os.path.exists(path)
]

if missing_paths:
    for label, path in missing_paths:
        print(f"Warning: {label} not found!")
        print(f"Expected: {path}")
    print("Please run the data download script first.")
else:
    print("Dataset found!")

In [None]:
# Create datasets and data loaders
def _collate_detection_batch(batch):
    """Transpose a list of dataset samples into one tuple per sample field.

    Items are kept as separate tuple entries rather than being stacked into a
    single tensor. Shared by both loaders instead of two identical lambdas.
    """
    return tuple(zip(*batch))


def create_data_loaders(config, num_workers=2):
    """Build the training/validation datasets and their DataLoaders.

    The image directory and annotation files are taken from the notebook-level
    variables ``images_dir``, ``train_annotations`` and ``val_annotations``,
    so the dataset-paths cell must have been run first.

    Args:
        config: Configuration object. ``config.dataset.image_size``,
            ``config.dataset.batch_size`` and ``config.classes.names`` are read.
        num_workers: Number of DataLoader worker processes used by both
            loaders. Defaults to 2 (reduced for notebook use); pass 0 to load
            batches in the main process, e.g. when worker processes cannot be
            used on the current platform.

    Returns:
        Tuple ``(train_loader, val_loader, train_dataset, val_dataset)``.
    """
    # Get transforms
    train_transform = get_transforms(
        phase="train",
        image_size=config.dataset.image_size
    )
    val_transform = get_transforms(
        phase="val",
        image_size=config.dataset.image_size
    )

    # Create datasets
    train_dataset = TrafficDataset(
        root=images_dir,
        annotation_file=train_annotations,
        transform=train_transform,
        class_names=config.classes.names
    )

    val_dataset = TrafficDataset(
        root=images_dir,
        annotation_file=val_annotations,
        transform=val_transform,
        class_names=config.classes.names
    )

    # Create data loaders; only the training loader shuffles
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.dataset.batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=_collate_detection_batch
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.dataset.batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_detection_batch
    )

    return train_loader, val_loader, train_dataset, val_dataset

# Build the loaders only when the training annotation file is on disk
if os.path.exists(train_annotations):
    (
        train_loader,
        val_loader,
        train_dataset,
        val_dataset,
    ) = create_data_loaders(config)

    # Report dataset sizes and the configured class count
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Number of classes: {len(config.classes.names)}")

## Model Creation

In [None]:
# Pick the compute device: CUDA when available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Instantiate the model described by the configuration
model = create_model(
    model_name=config.model.name,
    num_classes=config.model.num_classes,
    config=config.model,
    pretrained=config.model.pretrained
)

# Count parameters: collect (element count, trainable flag) once, then
# derive both totals from that list
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)

print(f"Model: {type(model).__name__}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training Setup

In [None]:
# Create optimizer
# Any optimizer name other than "adamw" (case-insensitive) selects SGD with
# momentum 0.9.
optimizer_name = config.training.optimizer.lower()
shared_optimizer_kwargs = {
    "lr": config.training.learning_rate,
    "weight_decay": config.training.weight_decay,
}
if optimizer_name != "adamw":
    optimizer = torch.optim.SGD(
        model.parameters(), momentum=0.9, **shared_optimizer_kwargs
    )
else:
    optimizer = torch.optim.AdamW(model.parameters(), **shared_optimizer_kwargs)

# Create scheduler
# Any scheduler name other than "cosine" (case-insensitive) selects StepLR.
scheduler_name = config.training.scheduler.lower()
if scheduler_name != "cosine":
    # NOTE(review): StepLR only decays after 30 scheduler steps; if the
    # trainer steps once per epoch, a 10-epoch run never decays - confirm
    # this is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=30, gamma=0.1
    )
else:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.training.epochs
    )

print(f"Optimizer: {type(optimizer).__name__}")
print(f"Scheduler: {type(scheduler).__name__}")

## Training

In [None]:
# Train only when the training annotations exist; otherwise record that no
# training happened by leaving `history` as None for the later cells.
if not os.path.exists(train_annotations):
    print("Skipping training - dataset not available")
    history = None
else:
    # Wire the model, loaders, optimizer and scheduler into the trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        save_dir="../checkpoints/notebook_training"
    )

    print("Starting training...")

    # Run the configured number of epochs; `save_every=5` is passed through
    # to the trainer
    history = trainer.train(
        epochs=config.training.epochs,
        save_every=5
    )

    print("Training completed!")

## Training Results Visualization

In [None]:
if history is not None:
    # Pull the three recorded series out of the history dict once
    train_losses = history['train_losses']
    val_losses = history['val_losses']
    learning_rates = history['learning_rates']

    # Plot training curves
    fig = plot_training_curves(
        train_losses=train_losses,
        val_losses=val_losses,
        learning_rates=learning_rates
    )
    plt.show()

    # Summarize the last recorded values and the best validation loss
    print("\nFinal Training Results:")
    print(f"Final training loss: {train_losses[-1]:.4f}")
    print(f"Final validation loss: {val_losses[-1]:.4f}")
    print(f"Best validation loss: {min(val_losses):.4f}")
    print(f"Final learning rate: {learning_rates[-1]:.6f}")

## Model Inference Test

In [None]:
if os.path.exists(train_annotations):
    # Test inference on a few samples
    model.eval()

    # Take the first validation batch; None signals an empty validation
    # loader (a bare next() would raise StopIteration instead)
    first_batch = next(iter(val_loader), None)

    if first_batch is None:
        print("Skipping inference test - validation set is empty")
    else:
        images, targets = first_batch

        # Move to device
        images_device = [img.to(device) for img in images]

        # Run inference without tracking gradients
        with torch.no_grad():
            predictions = model.predict(images_device, score_threshold=0.3)

        # Print results
        class_names = config.classes.names
        for i, (pred, target) in enumerate(zip(predictions, targets)):
            print(f"\nImage {i+1}:")
            print(f"  Ground truth objects: {len(target['boxes'])}")
            print(f"  Predicted objects: {len(pred['boxes'])}")

            if len(pred['boxes']) > 0:
                # Show up to three highest-scoring predictions, using the
                # values topk already returns instead of re-indexing scores
                top_k = min(3, len(pred['scores']))
                top_scores, top_indices = torch.topk(pred['scores'], top_k)
                for j, (score, idx) in enumerate(zip(top_scores, top_indices)):
                    class_id = pred['labels'][idx].item()
                    # Fall back to a generic label for ids outside the class
                    # list (including negative ids, which would otherwise
                    # index from the end)
                    if 0 <= class_id < len(class_names):
                        class_name = class_names[class_id]
                    else:
                        class_name = f"Class_{class_id}"
                    print(f"    {j+1}. {class_name}: {score.item():.3f}")
else:
    print("Skipping inference test - dataset not available")

## Save Model

In [None]:
# Save the trained model
# Only write outputs when the training cell actually produced a history.
# Otherwise "trained_model.pth" would hold weights that were never trained in
# this notebook and could overwrite a checkpoint from an earlier run.
if history is None:
    print("Skipping model save - no training was run")
else:
    output_dir = "../outputs/notebook_training"
    os.makedirs(output_dir, exist_ok=True)

    # Persist only the weights (state_dict), not the full model object
    model_path = os.path.join(output_dir, "trained_model.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to: {model_path}")

    # Save the configuration next to the weights so the run can be reproduced
    config_save_path = os.path.join(output_dir, "config.yaml")
    save_config(config, config_save_path)
    print(f"Configuration saved to: {config_save_path}")

## Next Steps

After training, you can:

1. **Evaluate the model**: Use the evaluation notebook to assess model performance
2. **Run inference**: Test the model on new images using the inference script
3. **Compare models**: Train different architectures and compare their performance
4. **Fine-tune**: Adjust hyperparameters and retrain for better results
5. **Deploy**: Use the trained model in a production environment

For production training, use the command-line scripts with full datasets and longer training times.