# Faster R-CNN Training for PPE Detection

## Configuration
- **Model**: Faster R-CNN with ResNet50-FPN backbone
- **Optimizer**: SGD (lr=0.005, momentum=0.9, weight_decay=0.0005)
- **Batch Size**: 16
- **Epochs**: 100
- **Scheduler**: StepLR (step_size=20, gamma=0.5, i.e. the learning rate is halved every 20 epochs)
- **Dataset**: PPE Detection Dataset

## 1. Import Libraries

In [None]:
import os
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import yaml
import wandb
from tqdm import tqdm
import numpy as np
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import torchvision.ops as ops

# Select the compute device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Dataset Class

In [None]:
class PPEDataset(Dataset):
    """PPE detection dataset stored as images plus YOLO-style ``.txt`` labels.

    Each sample is ``(image_tensor, target)``. ``target`` holds absolute
    ``[x_min, y_min, x_max, y_max]`` pixel boxes, class labels shifted by +1
    (so label 0 stays free for a background class) and the sample index.
    """

    def __init__(self, root_dir, split='train'):
        """Index the ``.jpg`` images of one split.

        Args:
            root_dir: Dataset root containing ``data.yaml`` and the split folders.
            split: ``'train'`` reads from ``train_aug``; any other value reads
                from ``valid``.
        """
        self.root_dir = root_dir

        with open(os.path.join(root_dir, 'data.yaml'), 'r') as f:
            data_config = yaml.safe_load(f)
        self.class_names = data_config['names']
        self.num_classes = data_config['nc']

        split_dir = 'train_aug' if split == 'train' else 'valid'
        self.img_dir = os.path.join(root_dir, split_dir, 'images')
        self.label_dir = os.path.join(root_dir, split_dir, 'labels')

        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices (and therefore image_id) reproducible between runs.
        self.img_files = sorted(
            f for f in os.listdir(self.img_dir) if f.endswith('.jpg')
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_files)

    @staticmethod
    def _load_targets(label_path, img_width, img_height):
        """Parse one label file into pixel-space boxes and shifted labels.

        Each non-empty line is ``class x_center y_center width height`` with
        coordinates normalised to the image size. Boxes are clamped to the
        image and dropped when they have no positive area.

        Returns:
            ``(boxes, labels)``: a float32 tensor of shape ``(N, 4)`` and an
            int64 tensor of shape ``(N,)``. A missing or empty file gives N = 0.
        """
        boxes = []
        labels = []

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue

                    class_id = int(float(parts[0]))
                    x_center, y_center, width, height = map(float, parts[1:])
                    if width <= 0 or height <= 0:
                        continue

                    x_min = (x_center - width / 2) * img_width
                    y_min = (y_center - height / 2) * img_height
                    x_max = (x_center + width / 2) * img_width
                    y_max = (y_center + height / 2) * img_height

                    # Clamp to the image bounds.
                    x_min = max(0, min(x_min, img_width))
                    y_min = max(0, min(y_min, img_height))
                    x_max = max(0, min(x_max, img_width))
                    y_max = max(0, min(y_max, img_height))

                    if x_max > x_min and y_max > y_min:
                        boxes.append([x_min, y_min, x_max, y_max])
                        labels.append(class_id + 1)

        # reshape keeps the (N, 4) layout even when no box survived; a bare
        # empty list would otherwise become a 1-D tensor of shape (0,).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Load image ``idx`` and its targets.

        Returns:
            ``(image_tensor, target)`` where ``image_tensor`` comes from
            ``F.to_tensor`` on the RGB image and ``target`` has the keys
            ``'boxes'``, ``'labels'`` and ``'image_id'``.
        """
        img_name = self.img_files[idx]
        image = Image.open(os.path.join(self.img_dir, img_name)).convert('RGB')

        # Swap only the extension; str.replace would also rewrite a '.jpg'
        # that happens to appear inside the file stem.
        label_name = os.path.splitext(img_name)[0] + '.txt'
        img_width, img_height = image.size
        boxes, labels = self._load_targets(
            os.path.join(self.label_dir, label_name), img_width, img_height
        )

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx])
        }

        image_tensor = F.to_tensor(image)
        return image_tensor, target

## 3. Model and Helper Functions

In [None]:
def get_model(num_classes):
    """Create a pretrained Faster R-CNN (ResNet50-FPN) with a new box head.

    Args:
        num_classes: Output size of the replacement box predictor. The caller
            in this notebook passes the dataset class count + 1.

    Returns:
        The model with ``roi_heads.box_predictor`` replaced by a fresh
        ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    try:
        # Newer torchvision releases deprecate `pretrained=` in favour of
        # `weights=`; "DEFAULT" selects the default pretrained checkpoint.
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        # Older torchvision releases only accept the legacy flag.
        model = fasterrcnn_resnet50_fpn(pretrained=True)

    # Keep the pretrained backbone/RPN and swap only the classification head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def collate_fn(batch):
    """Collate samples into parallel lists, discarding degenerate boxes.

    A box is kept only when ``x2 > x1`` and ``y2 > y1``. A sample that ends up
    with no valid box is removed from the batch altogether, so the returned
    lists can be shorter than ``batch`` (or empty).

    Returns:
        ``(images, targets)`` as two lists of equal length.
    """
    images, targets = zip(*batch)
    kept_images = []
    kept_targets = []

    for image, target in zip(images, targets):
        boxes = target['boxes']
        if len(boxes) == 0:
            continue

        # Row-wise validity mask: positive width and positive height.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        if not keep.any():
            continue

        kept_images.append(image)
        kept_targets.append({
            'boxes': boxes[keep],
            'labels': target['labels'][keep],
            'image_id': target['image_id'],
        })

    return kept_images, kept_targets

## 4. Load Dataset

In [None]:
dataset_root = "/root/ocr-project/First/Detection/ppe-dataset-clean/versions/1/ppe-detection-project-dataset-c/versions/1"

# Class metadata (names and count) is read from the dataset's data.yaml.
config_path = os.path.join(dataset_root, 'data.yaml')
with open(config_path, 'r') as f:
    data_config = yaml.safe_load(f)

class_names, num_classes = data_config['names'], data_config['nc']

# Build both splits from the same root folder.
train_dataset = PPEDataset(dataset_root, split='train')
val_dataset = PPEDataset(dataset_root, split='val')

print(
    f"Classes: {num_classes} ({class_names})",
    f"Training: {len(train_dataset)} images",
    f"Validation: {len(val_dataset)} images",
    sep="\n",
)

## 5. Training Configuration and Weights & Biases Setup

In [None]:
# Hyperparameters shared by the data loaders, the optimizer and the run log.
batch_size = 16
num_epochs = 100
learning_rate = 0.005
momentum = 0.9
weight_decay = 0.0005

# Values recorded with the run; num_classes is logged with +1 to match the
# value passed to get_model.
wandb_config = {
    'batch_size': batch_size,
    'num_epochs': num_epochs,
    'learning_rate': learning_rate,
    'momentum': momentum,
    'weight_decay': weight_decay,
    'num_classes': num_classes + 1,
}
run = wandb.init(project="ppe-detection-fasterrcnn", config=wandb_config, reinit=True)

# Plot every tracked metric against the "epoch" counter.
wandb.define_metric("epoch")
for metric_name in (
    "train_loss",
    "val_loss",
    "learning_rate",
    "mAP",
    "precision",
    "recall",
    "f1_score",
):
    wandb.define_metric(metric_name, step_metric="epoch")

# One additional AP curve per class.
for class_name in class_names:
    wandb.define_metric(f"mAP_{class_name}", step_metric="epoch")

## 6. Create Data Loaders

In [None]:
# Settings common to both loaders; only shuffling differs between splits.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 7. Create Model and Optimizer

In [None]:
# The dataset shifts class ids by +1, so the head needs num_classes + 1 outputs.
model = get_model(num_classes + 1)
model.to(device)

sgd_settings = {
    'lr': learning_rate,
    'momentum': momentum,
    'weight_decay': weight_decay,
}
optimizer = torch.optim.SGD(model.parameters(), **sgd_settings)

# Step schedule: multiply the learning rate by 0.5 every 20 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

print(f"Model created and moved to {device}")

## 8. Training and Evaluation Functions

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Model called as ``model(images, targets)``; a returned dict is
            treated as named loss terms and summed.
        optimizer: Optimizer stepped once per processed batch.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Dict mapping each loss name to its mean over the processed batches
        (empty when no batch was processed or no loss dict was returned).
    """
    model.train()
    loss_history = {}

    for batch in tqdm(data_loader, desc=f'Epoch {epoch}'):
        # Skip malformed batches and batches emptied by the collate function.
        if len(batch) != 2:
            continue
        images, targets = batch
        if len(images) == 0:
            continue

        images = [image.to(device) for image in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        loss_dict = model(images, targets)

        if isinstance(loss_dict, dict):
            losses = sum(loss_dict.values())
            for name, value in loss_dict.items():
                loss_history.setdefault(name, []).append(value.item())
        else:
            losses = loss_dict

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    return {name: np.mean(values) for name, values in loss_history.items()}

def calculate_detection_metrics(predictions, targets, class_names, iou_threshold=0.5):
    """Calculate per-class AP, mAP, precision, recall and F1 score.

    Predictions are matched to ground-truth boxes of the SAME image only
    (``predictions[i]`` against ``targets[i]``); the resulting hit/miss flags
    are then ranked by score across all images to build the precision/recall
    curve of each class.

    Args:
        predictions: Per-image dicts with 'boxes', 'labels' and 'scores'.
        targets: Per-image dicts with 'boxes' and 'labels', aligned by index
            with ``predictions``.
        class_names: Class names; class ``i`` is compared against label ``i + 1``.
        iou_threshold: Minimum IoU for a prediction to count as a true positive.

    Returns:
        Dict with one ``'mAP_<class>'`` entry per class plus ``'mAP'``,
        ``'precision'``, ``'recall'`` and ``'f1_score'``. A class with no
        predictions or no ground truth scores 0 for AP, precision and recall.
    """
    from itertools import zip_longest

    metrics = {}
    aps = []
    precisions = []
    recalls = []

    for class_idx, class_name in enumerate(class_names):
        label = class_idx + 1
        score_chunks = []
        hit_chunks = []
        num_gt = 0

        # zip_longest keeps unmatched entries, so ground truth without a
        # prediction entry still counts towards recall.
        for pred, target in zip_longest(predictions, targets, fillvalue={}):
            if 'boxes' in target and 'labels' in target:
                gt_boxes = target['boxes'][target['labels'] == label]
                gt_count = len(gt_boxes)
            else:
                gt_boxes = None
                gt_count = 0
            num_gt += gt_count

            if not ('boxes' in pred and 'labels' in pred and 'scores' in pred):
                continue
            class_mask = pred['labels'] == label
            if not class_mask.any():
                continue

            # Highest-scoring predictions claim ground-truth boxes first.
            order = torch.argsort(pred['scores'][class_mask], descending=True)
            pred_scores = pred['scores'][class_mask][order]
            hits = torch.zeros(len(order), dtype=torch.float64)

            if gt_count > 0:
                pred_boxes = pred['boxes'][class_mask][order]
                best_iou, best_gt = ops.box_iou(pred_boxes, gt_boxes).max(dim=1)
                claimed = set()
                for rank, (iou, gt_idx) in enumerate(
                    zip(best_iou.tolist(), best_gt.tolist())
                ):
                    # Each ground-truth box may be matched at most once.
                    if iou >= iou_threshold and gt_idx not in claimed:
                        claimed.add(gt_idx)
                        hits[rank] = 1

            score_chunks.append(pred_scores.detach().cpu())
            hit_chunks.append(hits)

        if score_chunks and num_gt > 0:
            ranking = torch.argsort(torch.cat(score_chunks), descending=True)
            tp = torch.cat(hit_chunks)[ranking]
            tp_cumsum = torch.cumsum(tp, dim=0)
            fp_cumsum = torch.cumsum(1 - tp, dim=0)
            precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
            recall = tp_cumsum / num_gt

            # 11-point interpolated AP. Thresholds are exact i/10 doubles so
            # the >= comparison is not affected by float32 range rounding.
            ap = 0.0
            for step in range(11):
                reached = recall >= step / 10
                if reached.any():
                    ap += precision[reached].max().item() / 11

            metrics[f'mAP_{class_name}'] = ap
            aps.append(ap)
            precisions.append(precision[-1].item())
            recalls.append(recall[-1].item())
        else:
            metrics[f'mAP_{class_name}'] = 0
            aps.append(0)
            precisions.append(0)
            recalls.append(0)

    metrics['mAP'] = np.mean(aps) if aps else 0
    metrics['precision'] = np.mean(precisions) if precisions else 0
    metrics['recall'] = np.mean(recalls) if recalls else 0

    if metrics['precision'] + metrics['recall'] > 0:
        metrics['f1_score'] = 2 * (metrics['precision'] * metrics['recall']) / (metrics['precision'] + metrics['recall'])
    else:
        metrics['f1_score'] = 0

    return metrics

def evaluate(model, data_loader, device, class_names):
    """Evaluate the model on a data loader.

    For every batch the model is called once in eval mode to obtain
    predictions and once in train mode (with targets) to obtain a loss value.
    No gradients are computed.

    Args:
        model: Detection model to evaluate.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.
        class_names: Class names forwarded to ``calculate_detection_metrics``.

    Returns:
        The metrics dict from ``calculate_detection_metrics`` extended with
        ``'val_loss'`` (mean summed loss per processed batch, 0 if none).
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    all_predictions = []
    all_targets = []

    try:
        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Evaluating"):
                # Skip malformed batches and batches emptied by collation.
                if len(batch) != 2:
                    continue
                images, targets = batch
                if len(images) == 0:
                    continue

                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                predictions = model(images)

                # Second forward pass in train mode, with targets, to get the
                # loss terms; the mode is switched back right afterwards.
                model.train()
                loss_dict = model(images, targets)
                model.eval()

                if isinstance(loss_dict, dict):
                    losses = sum(loss_dict.values())
                else:
                    losses = loss_dict

                total_loss += losses.item()
                num_batches += 1

                # Keep the accumulated results on the CPU so device memory
                # does not grow with the size of the validation set.
                all_predictions.extend(
                    {k: v.cpu() for k, v in p.items()} for p in predictions
                )
                all_targets.extend(
                    {k: v.cpu() for k, v in t.items()} for t in targets
                )
    finally:
        # Never leave the model in train mode if the loss pass raised.
        model.eval()

    metrics = calculate_detection_metrics(all_predictions, all_targets, class_names)
    metrics['val_loss'] = total_loss / num_batches if num_batches > 0 else 0

    return metrics

## 9. Training Loop

In [None]:
# Per-epoch history used by the plots in the next section.
train_losses = []
val_losses = []
learning_rates = []

print("Starting training...")
print("=" * 50)
print(f"Training: {num_epochs} epochs, batch size {batch_size}")
print(f"Learning rate will reduce by 50% every 20 epochs")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 30)

    train_metrics = train_one_epoch(model, optimizer, train_loader, device, epoch+1)
    val_metrics = evaluate(model, val_loader, device, class_names)

    # Read the learning rate BEFORE stepping the scheduler so the value that
    # is logged and plotted is the one this epoch was actually trained with.
    current_lr = optimizer.param_groups[0]['lr']
    lr_scheduler.step()

    # Total training loss = sum of the per-term epoch means.
    total_train_loss = sum(train_metrics.values())
    train_losses.append(total_train_loss)
    val_losses.append(val_metrics['val_loss'])
    learning_rates.append(current_lr)

    # Log to wandb
    log_dict = {
        'epoch': epoch+1,
        'train_loss': total_train_loss,
        'val_loss': val_metrics['val_loss'],
        'learning_rate': current_lr,
        'mAP': val_metrics['mAP'],
        'precision': val_metrics['precision'],
        'recall': val_metrics['recall'],
        'f1_score': val_metrics['f1_score']
    }

    for class_name in class_names:
        log_dict[f'mAP_{class_name}'] = val_metrics[f'mAP_{class_name}']

    wandb.log(log_dict)

    print(f"Results:")
    print(f"  - Train Loss: {total_train_loss:.4f}")
    print(f"  - Val Loss: {val_metrics['val_loss']:.4f}")
    print(f"  - Learning Rate: {current_lr:.6f}")
    print(f"  - mAP: {val_metrics['mAP']:.3f}")
    print(f"  - Precision: {val_metrics['precision']:.3f}")
    print(f"  - Recall: {val_metrics['recall']:.3f}")
    print(f"  - F1 Score: {val_metrics['f1_score']:.3f}")

print("\nTraining completed!")
print("=" * 50)

## 10. Visualize Training Progress

In [None]:
# One row of three panels: losses, learning-rate schedule, per-class AP.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
title_style = {'fontsize': 14, 'fontweight': 'bold'}
epochs = range(1, len(train_losses) + 1)

# Panel 1: training vs. validation loss per epoch.
ax1.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
ax1.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
ax1.set_title('Training and Validation Loss', **title_style)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Panel 2: learning rate recorded for each epoch.
ax2.plot(epochs, learning_rates, 'g-', linewidth=2)
ax2.set_title('Learning Rate Schedule', **title_style)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Learning Rate')
ax2.grid(True, alpha=0.3)

# Panel 3: per-class AP from a fresh evaluation of the current model.
last_val_metrics = evaluate(model, val_loader, device, class_names)
class_maps = [last_val_metrics[f'mAP_{class_name}'] for class_name in class_names]
ax3.bar(class_names, class_maps, color='skyblue')
ax3.set_title('mAP per Class', **title_style)
ax3.set_ylabel('mAP')
ax3.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Attach the figure to the wandb run.
wandb.log({"training_progress": wandb.Image(fig)})

## 11. Save Model

In [None]:
# Persist only the weights (state_dict), not the full model object; the file
# is written relative to the current working directory.
model_path = 'fasterrcnn_ppe_model.pth'
torch.save(model.state_dict(), model_path)
# Hand the checkpoint file to wandb, then close the run.
wandb.save(model_path)
wandb.finish()

print(f"Model saved to: {model_path}")