# Faster R-CNN K-Fold Cross Validation Training

## Configuration
- **Model**: Faster R-CNN with ResNet50-FPN backbone
- **Optimizer**: SGD (lr=0.005, momentum=0.9, weight_decay=0.0005)
- **Batch Size**: 16 per step (gradients are accumulated over 4 steps before each optimizer update)
- **Effective Batch Size**: 64
- **K-Folds**: 5
- **Epochs per Fold**: 20
- **Scheduler**: StepLR (reduce lr by 50% every 10 epochs)
- **Mixed Precision**: Enabled
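- **Evaluation**: Full detection metrics (mAP, precision, recall, F1) every 5 epochs and after the final epoch; validation loss only on the remaining epochs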

## 1. Import Libraries

In [None]:
import os
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import yaml
import wandb
from tqdm import tqdm
import numpy as np
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import torchvision.ops as ops
from sklearn.model_selection import StratifiedKFold
import json
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Combined Dataset Class for K-Fold

In [None]:
class CombinedPPEDataset(Dataset):
    """Detection dataset that merges the ``train_aug`` and ``valid`` splits.

    The dataset root is expected to contain ``data.yaml`` (with ``names`` and
    ``nc`` entries) plus ``train_aug/{images,labels}`` and
    ``valid/{images,labels}``.  Only ``.jpg`` images are picked up.  Each label
    file is read as whitespace-separated lines of
    ``class x_center y_center width height`` with coordinates normalised to
    the image size.

    Attributes populated by ``__init__`` (all index-aligned):
        img_files: absolute image paths.
        label_files: label path for each image (the file may not exist).
        stratified_labels: integer bucket per image, used to stratify folds.
    """

    def __init__(self, root_dir):
        """Index both splits under *root_dir* and print a short summary."""
        self.root_dir = root_dir

        with open(os.path.join(root_dir, 'data.yaml'), 'r') as f:
            data_config = yaml.safe_load(f)
        self.class_names = data_config['names']
        self.num_classes = data_config['nc']

        # Combine train_aug and valid datasets
        self.train_img_dir = os.path.join(root_dir, 'train_aug', 'images')
        self.train_label_dir = os.path.join(root_dir, 'train_aug', 'labels')
        self.valid_img_dir = os.path.join(root_dir, 'valid', 'images')
        self.valid_label_dir = os.path.join(root_dir, 'valid', 'labels')

        self.img_files = []
        self.label_files = []
        self.stratified_labels = []

        train_count = self._add_split(self.train_img_dir, self.train_label_dir)
        valid_count = self._add_split(self.valid_img_dir, self.valid_label_dir)

        print(f"Combined dataset: {len(self.img_files)} images")
        print(f"  - Train images: {train_count}")
        print(f"  - Valid images: {valid_count}")

    def _add_split(self, img_dir, label_dir):
        """Append every ``.jpg`` in *img_dir* to the index; return how many were added."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore the fold assignment) reproducible.
        names = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
        for name in names:
            # Swap only the final extension: str.replace would also rewrite a
            # '.jpg' that appears in the middle of the file name.
            stem = os.path.splitext(name)[0]
            label_path = os.path.join(label_dir, stem + '.txt')
            self.img_files.append(os.path.join(img_dir, name))
            self.label_files.append(label_path)
            self.stratified_labels.append(self._get_stratified_label(label_path))
        return len(names)

    def _get_stratified_label(self, label_path):
        """Return an integer bucket describing which classes an image contains.

        Images with no label file, or with a label file that lists no
        objects, fall in bucket 0.  Otherwise the bucket is derived from the
        sorted set of class ids, reduced modulo 1000 (so distinct class
        combinations can occasionally share a bucket).
        """
        if not os.path.exists(label_path):
            return 0

        with open(label_path, 'r') as f:
            classes = {int(float(line.split()[0])) for line in f if line.strip()}

        # A file holding only blank lines describes no objects either.
        if not classes:
            return 0

        return hash(tuple(sorted(classes))) % 1000

    def __len__(self):
        """Number of indexed images."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load sample *idx*.

        Returns:
            ``(image_tensor, target)`` where ``target`` holds ``boxes``
            (float32, ``[N, 4]`` as x_min, y_min, x_max, y_max in pixels),
            ``labels`` (int64, class id + 1) and ``image_id``.

        Raises:
            ValueError: if a non-empty label line does not have exactly five
                fields or a field is not numeric.
        """
        img_path = self.img_files[idx]
        label_path = self.label_files[idx]

        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixels are decoded.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = F.to_tensor(image)
        img_width, img_height = image.size

        boxes = []
        labels = []

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    parts = line.split()
                    class_id = int(float(parts[0]))
                    x_center, y_center, width, height = map(float, parts[1:])

                    if width <= 0 or height <= 0:
                        continue

                    # Normalised centre/size -> absolute corner coordinates,
                    # clamped to the image bounds.
                    x_min = max(0, min((x_center - width / 2) * img_width, img_width))
                    y_min = max(0, min((y_center - height / 2) * img_height, img_height))
                    x_max = max(0, min((x_center + width / 2) * img_width, img_width))
                    y_max = max(0, min((y_center + height / 2) * img_height, img_height))

                    # Clamping can collapse a box that lies outside the image.
                    if x_max > x_min and y_max > y_min:
                        boxes.append([x_min, y_min, x_max, y_max])
                        # Shift ids by one; presumably so that 0 stays free
                        # for the detector's background class.
                        labels.append(class_id + 1)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx])
        }

        return image_tensor, target

## 3. Model and Helper Functions

In [None]:
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet50-FPN) whose box head outputs *num_classes* scores.

    The detector is created with ``pretrained=True`` and only its box
    predictor is replaced, sized from the existing classifier's input width.
    """
    # NOTE(review): newer torchvision releases prefer a ``weights=`` argument
    # over ``pretrained=True`` - confirm against the installed version.
    detector = fasterrcnn_resnet50_fpn(pretrained=True)
    existing_head = detector.roi_heads.box_predictor
    detector.roi_heads.box_predictor = FastRCNNPredictor(
        existing_head.cls_score.in_features, num_classes
    )
    return detector

def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into two parallel lists.

    Samples whose image does not have 3 channels are dropped.  Within each
    kept target, boxes that do not satisfy ``x2 > x1 and y2 > y1`` are
    removed together with their labels; a target left without boxes becomes
    an empty ``[0, 4]`` float32 / ``[0]`` int64 pair.  Only the ``boxes``,
    ``labels`` and ``image_id`` keys are carried over.
    """
    def _empty_target(image_id):
        return {
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
            'image_id': image_id
        }

    images, targets = zip(*batch)
    kept_images = []
    kept_targets = []

    for image, target in zip(images, targets):
        if image.shape[0] != 3:
            continue
        kept_images.append(image)

        boxes = target['boxes']
        if len(boxes) == 0:
            kept_targets.append(_empty_target(target['image_id']))
            continue

        # One boolean per box: True when the box has positive width and height.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        if keep.any():
            kept_targets.append({
                'boxes': boxes[keep],
                'labels': target['labels'][keep],
                'image_id': target['image_id']
            })
        else:
            kept_targets.append(_empty_target(target['image_id']))

    return kept_images, kept_targets

## 4. Load Combined Dataset

In [None]:
dataset_root = "/root/ocr-project/First/Detection/ppe-dataset-clean/versions/1/ppe-detection-project-dataset-c/versions/1"

# Read the class list and class count from the dataset descriptor.
data_yaml_path = os.path.join(dataset_root, 'data.yaml')
with open(data_yaml_path, 'r') as yaml_file:
    data_config = yaml.safe_load(yaml_file)

class_names, num_classes = data_config['names'], data_config['nc']

# Index the train_aug and valid splits as a single dataset for K-fold splitting.
combined_dataset = CombinedPPEDataset(dataset_root)

print(f"Classes: {num_classes} ({class_names})")
print(f"Total dataset size: {len(combined_dataset)} images")

## 5. Training Configuration

In [None]:
# Cross-validation layout
K_FOLDS = 5
num_epochs = 20

# Batching: the optimizer steps once per `gradient_accumulation_steps` batches
batch_size = 16
gradient_accumulation_steps = 4

# SGD hyper-parameters
learning_rate = 0.005
momentum = 0.9
weight_decay = 0.0005

# Runtime options
num_workers = 2
use_mixed_precision = True

# Echo the settings so they are recorded in the cell output.
config_report = [
    "Configuration:",
    f"  - K-Folds: {K_FOLDS}",
    f"  - Batch Size: {batch_size}",
    f"  - Gradient Accumulation: {gradient_accumulation_steps}",
    f"  - Effective Batch Size: {batch_size * gradient_accumulation_steps}",
    f"  - Epochs per Fold: {num_epochs}",
    f"  - Mixed Precision: {use_mixed_precision}",
    "  - Evaluation: Every 5 epochs",
]
print("\n".join(config_report))

## 6. Initialize Weights & Biases

In [None]:
# Hyper-parameters recorded with the run.
wandb_config = {
    'k_folds': K_FOLDS,
    'batch_size': batch_size,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'effective_batch_size': batch_size * gradient_accumulation_steps,
    'use_mixed_precision': use_mixed_precision,
    'num_epochs': num_epochs,
    'learning_rate': learning_rate,
    'momentum': momentum,
    'weight_decay': weight_decay,
    # One more than the dataset's class count, matching the model head size.
    'num_classes': num_classes + 1,
    'dataset_size': len(combined_dataset)
}

run = wandb.init(
    project=f"ppe-detection-kfold-{K_FOLDS}",
    config=wandb_config,
    reinit=True
)

# Define metrics: "fold" and "epoch" stand alone, everything else is
# plotted against "epoch".
wandb.define_metric("fold")
wandb.define_metric("epoch")

epoch_metric_names = [
    "train_loss", "val_loss", "learning_rate",
    "mAP", "precision", "recall", "f1_score",
]
epoch_metric_names.extend(f"mAP_{class_name}" for class_name in class_names)

for metric_name in epoch_metric_names:
    wandb.define_metric(metric_name, step_metric="epoch")

## 7. Training and Evaluation Functions

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, scaler=None):
    """Train for one epoch with gradient accumulation and optional mixed precision.

    Reads the module-level ``use_mixed_precision`` and
    ``gradient_accumulation_steps`` settings.  Mixed precision is only used
    when ``use_mixed_precision`` is true and a ``scaler`` is supplied.

    Args:
        model: detector that returns a dict of losses (or a single loss) when
            called as ``model(images, targets)`` in training mode.
        optimizer: optimizer stepping the model parameters.
        data_loader: yields ``(images, targets)`` batches.
        device: device the batch is moved to.
        epoch: epoch number, used only for the progress-bar label.
        scaler: gradient scaler for mixed precision, or ``None``.

    Returns:
        Dict mapping each loss name to its mean over the processed batches
        (empty when no batch was processed or the model returns a bare loss).
    """
    model.train()
    metric_logger = {}
    amp_enabled = use_mixed_precision and scaler is not None

    def _apply_accumulated_gradients():
        # One optimizer update for everything accumulated so far, then reset.
        if amp_enabled:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    # Count the batches that actually contributed gradients instead of using
    # the loader index, so a skipped batch cannot desynchronise the
    # zero_grad / step schedule.
    pending = 0
    optimizer.zero_grad()

    for idx, batch in enumerate(tqdm(data_loader, desc=f'Epoch {epoch}')):
        if len(batch) != 2:
            continue
        images, targets = batch

        if len(images) == 0:
            continue

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Mixed precision training
        if amp_enabled:
            with autocast():
                loss_dict = model(images, targets)
        else:
            loss_dict = model(images, targets)

        if isinstance(loss_dict, dict):
            losses = sum(loss_dict.values())
            for k, v in loss_dict.items():
                metric_logger.setdefault(k, []).append(v.item())
        else:
            losses = loss_dict

        # Scale loss so the accumulated gradient is an average over the group
        losses = losses / gradient_accumulation_steps

        # Backward pass
        if amp_enabled:
            scaler.scale(losses).backward()
        else:
            losses.backward()
        pending += 1

        # Update weights once a full accumulation group is available
        if pending == gradient_accumulation_steps:
            _apply_accumulated_gradients()
            pending = 0

        # Clear memory
        del images, targets, loss_dict, losses
        if idx % 10 == 0:
            torch.cuda.empty_cache()

    # Apply the trailing partial group; otherwise its gradients would be
    # discarded at the start of the next epoch.  Its losses were still divided
    # by the full accumulation count, so this last update is proportionally
    # smaller.
    if pending:
        _apply_accumulated_gradients()

    return {k: np.mean(v) for k, v in metric_logger.items()}

def _match_detections(det_boxes, det_scores, gt_boxes, iou_threshold):
    """Flag one image's detections of one class as true or false positives.

    Detections are visited in descending score order.  Each one is compared
    with its highest-IoU ground-truth box and counts as a true positive only
    if that IoU reaches ``iou_threshold`` and the box has not already been
    claimed by a higher-scoring detection.

    Returns:
        ``(scores, is_tp)`` as CPU tensors, both in descending score order.
    """
    det_boxes = det_boxes.detach().cpu().float()
    det_scores = det_scores.detach().cpu().float()
    order = torch.argsort(det_scores, descending=True)
    det_scores = det_scores[order]
    is_tp = torch.zeros(len(order), dtype=torch.bool)

    if gt_boxes is None or len(gt_boxes) == 0:
        # Nothing to match against: every detection is a false positive.
        return det_scores, is_tp

    ious = ops.box_iou(det_boxes[order], gt_boxes.detach().cpu().float())
    gt_taken = torch.zeros(ious.shape[1], dtype=torch.bool)
    for det_idx in range(ious.shape[0]):
        best_iou, best_gt = torch.max(ious[det_idx], dim=0)
        if best_iou >= iou_threshold and not gt_taken[best_gt]:
            is_tp[det_idx] = True
            gt_taken[best_gt] = True

    return det_scores, is_tp


def calculate_detection_metrics(predictions, targets, class_names, iou_threshold=0.5):
    """Calculate per-class AP plus mean AP, precision, recall and F1.

    Detections are matched to ground truth within each image, then pooled
    over the dataset and ranked by score to build the precision/recall curve.
    AP uses 11-point interpolation.

    Args:
        predictions: per-image dicts with ``boxes``, ``labels``, ``scores``.
        targets: per-image dicts with ``boxes`` and ``labels``; must be
            index-aligned with ``predictions`` (entry ``i`` of both lists
            describes the same image).
        class_names: class names; class ``i`` is expected under label ``i + 1``.
        iou_threshold: minimum IoU for a detection to count as a match.

    Returns:
        Dict with ``mAP_<class>`` per class and the aggregate ``mAP``,
        ``precision``, ``recall`` and ``f1_score``.  Precision and recall are
        taken at the end of the ranked list (all detections kept).  A class
        with no detections or no ground truth contributes 0 to every average.
    """
    metrics = {}
    num_classes = len(class_names)
    aps = []
    precisions = []
    recalls = []

    for class_idx in range(num_classes):
        class_name = class_names[class_idx]
        label = class_idx + 1
        score_chunks = []
        tp_chunks = []
        num_gt = 0

        for pred, target in zip(predictions, targets):
            gt_boxes = None
            if 'boxes' in target and 'labels' in target:
                gt_boxes = target['boxes'][target['labels'] == label]
                num_gt += len(gt_boxes)

            if not ('boxes' in pred and 'labels' in pred and 'scores' in pred):
                continue
            det_mask = pred['labels'] == label
            if not det_mask.any():
                continue

            # Matching stays inside the image: a detection must never be
            # credited with a ground-truth box that belongs to another image.
            scores, is_tp = _match_detections(
                pred['boxes'][det_mask], pred['scores'][det_mask],
                gt_boxes, iou_threshold
            )
            score_chunks.append(scores)
            tp_chunks.append(is_tp)

        ap = 0.0
        final_precision = 0.0
        final_recall = 0.0

        if score_chunks and num_gt > 0:
            all_scores = torch.cat(score_chunks)
            all_tp = torch.cat(tp_chunks)

            # Rank every detection of this class across the whole dataset.
            ranking = torch.argsort(all_scores, descending=True)
            tp_cumsum = torch.cumsum(all_tp[ranking].float(), dim=0)
            ranks = torch.arange(1, len(ranking) + 1, dtype=torch.float32)
            precision = tp_cumsum / ranks
            recall = tp_cumsum / num_gt

            # 11-point interpolation AP: best precision at recall >= 0.0 ... 1.0
            for step in range(11):
                reached = recall >= step / 10
                if reached.any():
                    ap += precision[reached].max().item() / 11

            final_precision = precision[-1].item()
            final_recall = recall[-1].item()

        metrics[f'mAP_{class_name}'] = ap
        aps.append(ap)
        precisions.append(final_precision)
        recalls.append(final_recall)

    metrics['mAP'] = np.mean(aps) if aps else 0
    metrics['precision'] = np.mean(precisions) if precisions else 0
    metrics['recall'] = np.mean(recalls) if recalls else 0

    if metrics['precision'] + metrics['recall'] > 0:
        metrics['f1_score'] = 2 * (metrics['precision'] * metrics['recall']) / (metrics['precision'] + metrics['recall'])
    else:
        metrics['f1_score'] = 0

    return metrics

def evaluate(model, data_loader, device, class_names):
    """Loss-only validation pass.

    Computes the mean summed loss over the validation batches without
    tracking gradients.  The model is switched to train mode just for the
    forward pass (presumably because the detector only returns losses in
    that mode) and put back in eval mode straight afterwards.

    Returns:
        Dict with ``val_loss`` (0 when no batch was processed) and every
        detection metric (``mAP_<class>``, ``mAP``, ``precision``, ``recall``,
        ``f1_score``) present as a 0.0 placeholder.
    """
    model.eval()
    loss_sum = 0
    batches_seen = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            if len(batch) != 2:
                continue
            images, targets = batch

            if len(images) == 0:
                continue

            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            model.train()
            if use_mixed_precision:
                with autocast():
                    loss_dict = model(images, targets)
            else:
                loss_dict = model(images, targets)
            model.eval()

            batch_loss = sum(loss_dict.values()) if isinstance(loss_dict, dict) else loss_dict

            loss_sum += batch_loss.item()
            batches_seen += 1

            del images, targets, loss_dict, batch_loss
            if batches_seen % 25 == 0:
                torch.cuda.empty_cache()

    metrics = {'val_loss': loss_sum / batches_seen if batches_seen > 0 else 0}
    # Detection metrics are not computed here; keep the keys with placeholders.
    metrics.update({f'mAP_{class_name}': 0.0 for class_name in class_names})
    metrics.update({'mAP': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0})

    return metrics

def evaluate_full(model, data_loader, device, class_names):
    """Full evaluation with detection metrics (run every 5 epochs).

    Runs the model in eval mode over ``data_loader``, collects the
    predictions and targets of every image, and passes them to
    ``calculate_detection_metrics``.

    Args:
        model: detector returning a list of prediction dicts for
            ``model(images)`` in eval mode.
        data_loader: yields ``(images, targets)`` batches.
        device: device the images are moved to for inference.
        class_names: class names forwarded to the metric computation.

    Returns:
        The metrics dict produced by ``calculate_detection_metrics``.
    """
    print("  Running full evaluation...")
    model.eval()
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Full evaluation"):
            if len(batch) != 2:
                continue
            images, targets = batch

            if len(images) == 0:
                continue

            images = [image.to(device) for image in images]
            predictions = model(images)

            # Keep the accumulated results on the CPU so device memory does
            # not grow with the size of the validation set.
            all_predictions.extend(
                {k: v.cpu() for k, v in prediction.items()} for prediction in predictions
            )
            all_targets.extend(
                {k: v.cpu() for k, v in target.items()} for target in targets
            )

            del images, targets, predictions

    metrics = calculate_detection_metrics(all_predictions, all_targets, class_names)
    del all_predictions, all_targets
    torch.cuda.empty_cache()

    return metrics

## 8. K-Fold Cross Validation Training

In [None]:
# Stratified split over the per-image class-combination buckets, so every fold
# sees a similar mix of images.
stratified_kfold = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=42)
indices = list(range(len(combined_dataset)))
stratified_labels = combined_dataset.stratified_labels

# Full detection metrics run every EVAL_INTERVAL epochs and on the last epoch;
# the other epochs only measure the validation loss.
EVAL_INTERVAL = 5
FULL_EVAL_KEYS = ['mAP', 'precision', 'recall', 'f1_score']

all_fold_results = []

print("Starting Stratified K-Fold Cross Validation...")
print("=" * 60)

for fold, (train_idx, val_idx) in enumerate(stratified_kfold.split(indices, stratified_labels)):
    print(f"\nFold {fold+1}/{K_FOLDS}")
    print("-" * 40)

    # Create datasets for this fold
    train_subset = Subset(combined_dataset, train_idx)
    val_subset = Subset(combined_dataset, val_idx)

    print(f"Train samples: {len(train_subset)}")
    print(f"Validation samples: {len(val_subset)}")

    # Create data loaders
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    # Create a fresh model and optimizer so folds do not share weights
    model = get_model(num_classes + 1)
    model.to(device)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay
    )

    # Scheduler: reduce LR by 50% every 10 epochs
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    scaler = GradScaler() if use_mixed_precision else None

    # Training loop for this fold
    fold_val_losses = []
    fold_metrics = []

    for epoch in range(num_epochs):
        train_metrics = train_one_epoch(model, optimizer, train_loader, device, epoch+1, scaler)
        total_train_loss = sum(train_metrics.values())

        # Read the learning rate before the scheduler advances, so the logged
        # value is the one this epoch was actually trained with.
        current_lr = optimizer.param_groups[0]['lr']

        log_dict = {
            'fold': fold+1,
            'epoch': epoch+1,
            'train_loss': total_train_loss,
            'learning_rate': current_lr
        }

        # Only log what this epoch really measured: the loss-only pass returns
        # 0.0 placeholders for the detection metrics and the full pass has no
        # validation loss, and neither should appear as a data point.
        run_full_eval = (epoch + 1) % EVAL_INTERVAL == 0 or epoch == num_epochs - 1
        if run_full_eval:
            val_metrics = evaluate_full(model, val_loader, device, class_names)
            print(f"  Epoch {epoch+1}/{num_epochs}: Train Loss: {total_train_loss:.4f}, mAP: {val_metrics['mAP']:.3f}")
            for key in FULL_EVAL_KEYS:
                log_dict[key] = val_metrics[key]
            for class_name in class_names:
                log_dict[f'mAP_{class_name}'] = val_metrics[f'mAP_{class_name}']
        else:
            val_metrics = evaluate(model, val_loader, device, class_names)
            print(f"  Epoch {epoch+1}/{num_epochs}: Train Loss: {total_train_loss:.4f}, Val Loss: {val_metrics['val_loss']:.4f}")
            log_dict['val_loss'] = val_metrics['val_loss']
            fold_val_losses.append(val_metrics['val_loss'])

        lr_scheduler.step()
        fold_metrics.append(val_metrics)

        # Log to wandb
        wandb.log(log_dict)

        del train_metrics, val_metrics
        torch.cuda.empty_cache()

    # Save model for this fold
    os.makedirs('kfold-model', exist_ok=True)
    model_path = f'kfold-model/fasterrcnn_ppe_fold{fold+1}.pth'
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)

    # Store results; the last epoch always runs the full evaluation, so
    # fold_metrics[-1] holds real detection metrics.
    fold_result = {
        'fold': fold+1,
        'train_samples': len(train_subset),
        'val_samples': len(val_subset),
        'final_metrics': fold_metrics[-1],
        'model_path': model_path
    }
    all_fold_results.append(fold_result)

    print(f"  Fold {fold+1} completed - Final mAP: {fold_metrics[-1]['mAP']:.3f}")

    # Release this fold's model before the next one is created
    del model, optimizer, lr_scheduler, scaler
    torch.cuda.empty_cache()

print("\nK-Fold Cross Validation completed!")
print("=" * 60)

## 9. Aggregate Results and Visualization

In [None]:
# Summarise each metric across folds as mean and standard deviation:
# first the aggregate metrics, then the per-class mAP values.
metric_keys = ['mAP', 'precision', 'recall', 'f1_score']
overall_metrics = {}

final_metrics_by_fold = [result['final_metrics'] for result in all_fold_results]
summary_keys = metric_keys + [f'mAP_{class_name}' for class_name in class_names]

for summary_key in summary_keys:
    fold_values = [fold_final[summary_key] for fold_final in final_metrics_by_fold]
    overall_metrics[f'{summary_key}_mean'] = np.mean(fold_values)
    overall_metrics[f'{summary_key}_std'] = np.std(fold_values)

# Log overall results
wandb.log(overall_metrics)

# Print results
print(f"\nOverall Performance (Mean ± Std):")
headline_rows = [
    ('mAP', 'mAP'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1 Score', 'f1_score'),
]
for display_name, summary_key in headline_rows:
    mean_val = overall_metrics[f'{summary_key}_mean']
    std_val = overall_metrics[f'{summary_key}_std']
    print(f"  - {display_name}: {mean_val:.3f} ± {std_val:.3f}")

print(f"\nPer-Class mAP:")
for class_name in class_names:
    mean_val = overall_metrics[f'mAP_{class_name}_mean']
    std_val = overall_metrics[f'mAP_{class_name}_std']
    print(f"  - {class_name}: {mean_val:.3f} ± {std_val:.3f}")

In [None]:
# Visualization: three side-by-side bar charts summarising the run.
fig = plt.figure(figsize=(15, 5))

# mAP per fold
ax1 = plt.subplot(1, 3, 1)
fold_maps = [result['final_metrics']['mAP'] for result in all_fold_results]
# Take the x positions from the results themselves so the chart still draws
# when fewer than K_FOLDS folds were completed.
fold_numbers = [result['fold'] for result in all_fold_results]
ax1.bar(fold_numbers, fold_maps, color='skyblue')
ax1.set_title('mAP per Fold', fontsize=14, fontweight='bold')
ax1.set_xlabel('Fold')
ax1.set_ylabel('mAP')

# Per-class mAP
ax2 = plt.subplot(1, 3, 2)
class_means = [overall_metrics[f'mAP_{c}_mean'] for c in class_names]
class_stds = [overall_metrics[f'mAP_{c}_std'] for c in class_names]
ax2.bar(class_names, class_means, yerr=class_stds, capsize=5, color='lightgreen')
ax2.set_title('Per-Class mAP', fontsize=14, fontweight='bold')
ax2.set_ylabel('mAP')
ax2.tick_params(axis='x', rotation=45)

# Overall metrics
ax3 = plt.subplot(1, 3, 3)
# Map each display label to the key prefix used in overall_metrics; deriving
# the key by lower-casing the label would give 'map' and 'f1 score', which
# are not keys of overall_metrics.
metric_key_by_label = {
    'mAP': 'mAP',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1 Score': 'f1_score',
}
metrics_names = list(metric_key_by_label)
metrics_means = [overall_metrics[f'{key}_mean'] for key in metric_key_by_label.values()]
ax3.bar(metrics_names, metrics_means, color='gold')
ax3.set_title('Overall Metrics', fontsize=14, fontweight='bold')
ax3.set_ylabel('Score')
ax3.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

wandb.log({"kfold_results": wandb.Image(fig)})

## 10. Save Results

In [None]:
# Write the whole run (settings, per-fold results, aggregate metrics) to a
# timestamped JSON file, upload it to wandb and close the run.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"kfold_results_{timestamp}.json"

results_summary = dict(
    k_folds=K_FOLDS,
    total_samples=len(combined_dataset),
    overall_metrics=overall_metrics,
    fold_results=all_fold_results,
    class_names=class_names,
    timestamp=timestamp,
)

with open(results_file, 'w') as results_handle:
    json.dump(results_summary, results_handle, indent=2)

wandb.save(results_file)
wandb.finish()

print(f"Results saved to: {results_file}")