In [14]:
# Standard library imports
import copy
import json
import os
import random
import shutil
import warnings
from collections import defaultdict
from pathlib import Path

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Computer vision imports
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNetHead
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Image processing imports
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

warnings.filterwarnings("ignore")

In [15]:
# --- Unified Configuration ---
current_dir = Path.cwd()
# The notebook is expected to sit two directory levels below the project root.
# Fall back to the working directory itself when the path is too shallow
# (``parents[1]`` would raise IndexError there).
project_root = current_dir.parents[1] if len(current_dir.parents) > 1 else current_dir

CONFIG = {
    # --- Dataset and Environment ---
    "DATA_DIR": project_root / "data" / "object_detection_dataset",
    "IMAGES_DIR": "images",
    "ANNOTATION_FILE": "_annotations.coco.json",
    "RESULTS_DIR": "research_results",
    "DEVICE": "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",

    # --- Reproducibility ---
    "SEED": 42,

    # --- Evaluation Parameters ---
    "K_SHOTS_TO_TEST": [1, 2, 3, 5, 10],
    "N_EVAL_EPISODES": 30,   # random support/query splits drawn per K
    "EVAL_BATCH_SIZE": 4,    # query-set batch size during evaluation

    # --- Fine-Tuning Parameters ---
    "FT_EPOCHS": 15,
    "FT_LEARNING_RATE": 0.001,
    "FT_BATCH_SIZE": 1,

    # --- Learning Rate Scheduler Parameters ---
    # Only used by experiments that set "use_scheduler": True (none below do).
    "LR_SCHEDULER_STEP_SIZE": 5,
    "LR_SCHEDULER_GAMMA": 0.5,

    # --- Visualization ---
    "VISUALIZE_SAMPLES": 30,
    "CONFIDENCE_THRESHOLD": 0.5
}

# Seed the RNGs used downstream so Restart & Run All reproduces the same
# episodes: support sampling draws from `random`, and newly created
# classifier heads are initialised from torch's generator.
random.seed(CONFIG["SEED"])
np.random.seed(CONFIG["SEED"])
torch.manual_seed(CONFIG["SEED"])

# Experiment definitions. Keys read by the pipeline:
#   freeze     -- "backbone" freezes model.backbone, "none" trains every layer
#   classifier -- "fc" (linear head) or "cos" (cosine similarity; temperature "scale")
#   augment    -- apply training-time flip / colour-jitter augmentation
EXPERIMENTS = [
    {"name": "TFA_fc_FasterRCNN", "freeze": "backbone", "classifier": "fc"},
    {"name": "TFA_cos_FasterRCNN", "freeze": "backbone", "classifier": "cos", "scale": 20.0},
    {"name": "Full_Finetune_FasterRCNN", "freeze": "none", "classifier": "fc"},
    {"name": "TFA_fc_Augmented", "freeze": "backbone", "classifier": "fc", "augment": True}
]

In [16]:
# --- Simplified Data Handling ---
def get_transforms(use_augmentation=False):
    """Build the albumentations pipeline used by ``SimpleCocoDataset``.

    Args:
        use_augmentation: when True, a random horizontal flip and a
            brightness/contrast jitter (each with probability 0.5) run before
            tensor conversion.

    Returns:
        An ``A.Compose`` ending in ``ToTensorV2`` whose COCO-format
        (``[x, y, w, h]``) boxes are kept in sync with the ``labels`` field.
    """
    steps = []
    if use_augmentation:
        steps += [
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
        ]
    steps.append(ToTensorV2())
    bbox_cfg = A.BboxParams(format='coco', label_fields=['labels'])
    return A.Compose(steps, bbox_params=bbox_cfg)

class SimpleCocoDataset(Dataset):
    """Minimal COCO-format detection dataset.

    Reads the annotation JSON at ``CONFIG["DATA_DIR"]/CONFIG["ANNOTATION_FILE"]``
    and loads images from ``CONFIG["DATA_DIR"]/CONFIG["IMAGES_DIR"]``.

    Each item is ``(img_tensor, target, original_img)``:
      * ``img_tensor`` -- float image tensor scaled to [0, 1] (CHW on the
        no-transform path);
      * ``target`` -- ``{"boxes": float32 (N, 4) xyxy pixels, "labels": int64 (N,)}``,
        with ``N == 0`` for images that have no usable annotation;
      * ``original_img`` -- the untransformed RGB ``PIL.Image``.

    COCO category ids are remapped to contiguous labels starting at 1, leaving
    0 free for the detector's background class.
    """

    def __init__(self, transforms=None):
        self.root = os.path.join(CONFIG["DATA_DIR"], CONFIG["IMAGES_DIR"])
        # Optional albumentations pipeline (see get_transforms); a falsy value
        # means "convert to tensor only".
        self.transforms = transforms

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        annotation_path = os.path.join(CONFIG["DATA_DIR"], CONFIG["ANNOTATION_FILE"])
        with open(annotation_path, encoding="utf-8") as f:
            coco_data = json.load(f)

        self.images = coco_data['images']
        self.annotations = coco_data['annotations']
        self.categories = coco_data['categories']

        self._build_mappings()

    def _build_mappings(self):
        """Build category-label lookups and per-image / per-category indices."""
        self.cat_id_to_name = {cat['id']: cat['name'] for cat in self.categories}
        # Contiguous labels 1..C in category-list order.
        self.cat_id_to_label = {cat['id']: i + 1 for i, cat in enumerate(self.categories)}
        self.label_to_name = {v: self.cat_id_to_name[k] for k, v in self.cat_id_to_label.items()}

        # image id -> list of its annotations; category id -> set of image ids containing it.
        self.img_to_anns = defaultdict(list)
        self.cat_to_imgs = defaultdict(set)
        for ann in self.annotations:
            self.img_to_anns[ann['image_id']].append(ann)
            self.cat_to_imgs[ann['category_id']].add(ann['image_id'])

        # Image ids in dataset-index order and the reverse lookup.
        self.img_ids = [img['id'] for img in self.images]
        self.id_to_idx = {img_id: i for i, img_id in enumerate(self.img_ids)}

    def __getitem__(self, idx):
        """Return ``(img_tensor, target, original_img)`` for dataset index ``idx``."""
        img_info = self.images[idx]
        img_path = os.path.join(self.root, img_info['file_name'])
        # Context manager guarantees the file handle is closed.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        anns = self.img_to_anns[img_info['id']]
        boxes, labels = self._process_annotations(anns, img.shape[:2])

        if self.transforms:
            transformed = self.transforms(image=img, bboxes=boxes, labels=labels)
            # ToTensorV2 does not rescale pixel values, so normalise here.
            img_tensor = transformed['image'] / 255.0
            boxes = self._convert_boxes(transformed['bboxes'])
            labels = torch.tensor(transformed['labels'], dtype=torch.int64)
        else:
            img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
            boxes = self._convert_boxes(boxes)
            labels = torch.tensor(labels, dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}
        original_img = Image.fromarray(img)

        return img_tensor, target, original_img

    def _process_annotations(self, anns, img_shape):
        """Clip COCO ``[x, y, w, h]`` boxes to the image and drop degenerate ones.

        Args:
            anns: annotation dicts for one image (``bbox`` and ``category_id`` are read).
            img_shape: ``(height, width)`` of the decoded image.

        Returns:
            ``(boxes, labels)`` as parallel lists. Both are empty when the image
            has no usable annotation; the caller turns that into a ``(0, 4)``
            box tensor, i.e. a background-only image.
        """
        boxes, labels = [], []
        img_h, img_w = img_shape

        for ann in anns:
            x, y, w, h = ann['bbox']
            # Clamp all four edges into [0, W] x [0, H].
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = min(x + w, img_w), min(y + h, img_h)
            if x1 > x0 and y1 > y0:
                boxes.append([x0, y0, x1 - x0, y1 - y0])
                labels.append(self.cat_id_to_label[ann['category_id']])

        # No usable boxes: return empty targets for THIS image. (Substituting
        # another image's sample would mismatch the pixels and its targets.)
        return boxes, labels

    def _convert_boxes(self, boxes):
        """Convert COCO ``[x, y, w, h]`` boxes to a float32 ``(N, 4)`` ``[x1, y1, x2, y2]`` tensor."""
        # len() instead of truthiness: depending on the albumentations version the
        # transformed bboxes may be a numpy array, whose truth value is ambiguous.
        if len(boxes) == 0:
            return torch.zeros((0, 4), dtype=torch.float32)
        boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
        # (x, y, w, h) -> (x1, y1, x1 + w, y1 + h)
        boxes_tensor[:, 2:] += boxes_tensor[:, :2]
        return boxes_tensor

    def __len__(self):
        return len(self.images)

def collate_fn(batch):
    """Transpose a list of ``(image, target, original_img)`` samples into per-field tuples.

    Detection images may differ in size, so they are not stacked; the
    DataLoader yields ``(images, targets, original_imgs)`` tuples instead.
    """
    return tuple(field for field in zip(*batch))

In [17]:
# --- Simplified Model Components ---
class CosineSimilarityClassifier(nn.Module):
    """Scaled cosine-similarity classifier, a drop-in for the ``nn.Linear`` ``cls_score``.

    Logits are ``scale * cos(x_i, w_j)``. Input rows and class-weight rows are
    both L2-normalised before the dot product, so for ``scale > 0`` every logit
    lies in ``[-scale, scale]``.

    Args:
        in_features: size of each input feature vector.
        out_features: number of output classes (background included when used
            inside a Fast R-CNN predictor).
        scale: temperature multiplying the cosine similarities.
    """

    def __init__(self, in_features, out_features, scale=20.0):
        super().__init__()
        # Mirror nn.Linear's attributes so code that inspects
        # ``cls_score.in_features`` / ``out_features`` keeps working.
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        # kaiming_uniform_ overwrites the storage, so no random fill is needed first.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, x):
        """Return ``(N, out_features)`` scaled cosine logits; expects ``x`` of shape ``(N, in_features)``."""
        x_unit = F.normalize(x, dim=1)
        w_unit = F.normalize(self.weight, dim=1)
        return self.scale * F.linear(x_unit, w_unit)

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}"

class CosineFastRCNNPredictor(FastRCNNPredictor):
    """Fast R-CNN box predictor whose classification head uses scaled cosine similarity.

    The parent constructor builds the standard heads. Only ``cls_score`` is then
    replaced by a ``CosineSimilarityClassifier``; the parent's box-regression
    head is kept as is.
    """

    def __init__(self, in_channels, num_classes, scale=20.0):
        super(CosineFastRCNNPredictor, self).__init__(in_channels, num_classes)
        cosine_head = CosineSimilarityClassifier(in_channels, num_classes, scale)
        self.cls_score = cosine_head

def create_model(num_classes, exp_config, verbose=True):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN v2) configured for one experiment.

    Args:
        num_classes: number of foreground classes; one extra slot is added for background.
        exp_config: experiment dict. Keys read:
            ``classifier`` -- ``"fc"`` (default, linear head) or ``"cos"`` (cosine head);
            ``scale`` -- cosine temperature, default 20.0 (only used with ``"cos"``);
            ``freeze`` -- ``"backbone"`` freezes ``model.backbone``; ``"none"``
            (default) leaves every parameter trainable.
        verbose: print the chosen freezing strategy. Defaults to True; pass False
            to keep per-episode progress bars uncluttered.

    Returns:
        The configured model. It is not moved to a device here.

    Raises:
        ValueError: if ``classifier`` or ``freeze`` is set to an unsupported value
            (so a typo cannot silently fall back to a different experiment).
    """
    classifier = exp_config.get('classifier') or 'fc'
    freeze = exp_config.get('freeze') or 'none'
    # Validate before the (slow) weight load.
    if classifier not in ('fc', 'cos'):
        raise ValueError(f"Unknown classifier {classifier!r}; expected 'fc' or 'cos'")
    if freeze not in ('backbone', 'none'):
        raise ValueError(f"Unknown freeze strategy {freeze!r}; expected 'backbone' or 'none'")

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights='DEFAULT')
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    num_classes_with_bg = num_classes + 1

    # Replace the COCO-sized predictor with one sized for this dataset.
    if classifier == 'cos':
        scale = exp_config.get('scale', 20.0)
        model.roi_heads.box_predictor = CosineFastRCNNPredictor(in_features, num_classes_with_bg, scale)
    else:
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes_with_bg)

    # NOTE(review): TFA (Wang et al., 2020) fine-tunes only the final box predictor.
    # Here only ``model.backbone`` is frozen, so the RPN and RoI box head stay trainable.
    if freeze == 'backbone':
        for param in model.backbone.parameters():
            param.requires_grad = False
        message = "Backbone frozen"
    else:
        message = "Model fully trainable"
    if verbose:
        print(message)

    return model

In [18]:
# --- Simplified Helper Functions ---
def create_episode_split(dataset, k_shot, rng=None):
    """Sample a K-shot support set; every other image becomes the query set.

    For each category, up to ``k_shot`` images containing it are drawn without
    replacement; the support set is the union over categories. An image can
    contain several categories, so a category may receive more than
    ``k_shot`` support instances.

    Args:
        dataset: object exposing ``cat_to_imgs`` (category id -> set of image ids),
            ``id_to_idx`` (image id -> index) and ``img_ids`` (ids in index order),
            e.g. ``SimpleCocoDataset``.
        k_shot: maximum number of images sampled per category.
        rng: optional ``random.Random`` for reproducible episodes. Defaults to
            the module-level ``random`` functions.

    Returns:
        ``(support_subset, query_subset)``: ``torch.utils.data.Subset`` objects
        over disjoint, ascending index lists.
    """
    rng = random if rng is None else rng
    support_img_ids = set()

    for cat_img_ids in dataset.cat_to_imgs.values():
        # Sort before shuffling: set iteration order is not guaranteed stable
        # across runs (e.g. string ids under hash randomisation), which would
        # make episodes irreproducible even with a fixed seed.
        candidates = sorted(cat_img_ids)
        rng.shuffle(candidates)
        support_img_ids.update(candidates[:k_shot])

    support_indices = sorted(dataset.id_to_idx[img_id] for img_id in support_img_ids)
    query_indices = [idx for idx, img_id in enumerate(dataset.img_ids) if img_id not in support_img_ids]

    return (torch.utils.data.Subset(dataset, support_indices),
            torch.utils.data.Subset(dataset, query_indices))

def evaluate_model(model, query_loader):
    """Run ``model`` over ``query_loader`` and report COCO-style mAP.

    Args:
        model: detection model; it is switched to eval mode here.
        query_loader: yields ``(images, targets, _)`` batches (see ``collate_fn``),
            where each target is a dict of tensors.

    Returns:
        ``{'map', 'map_50', 'map_75'}`` as Python floats
        (mAP@[.50:.95], mAP@.50, mAP@.75).
    """
    model.eval()
    device = CONFIG['DEVICE']
    map_metric = MeanAveragePrecision(box_format='xyxy').to(device)

    with torch.no_grad():
        for batch_images, batch_targets, _ in query_loader:
            inputs = [image.to(device) for image in batch_images]
            gt = [{name: tensor.to(device) for name, tensor in target.items()}
                  for target in batch_targets]
            map_metric.update(model(inputs), gt)

    summary = map_metric.compute()
    return {key: summary[key].item() for key in ('map', 'map_50', 'map_75')}

def train_model(model, support_loader, exp_config):
    """Fine-tune ``model`` in place on the support set.

    Only parameters with ``requires_grad=True`` are optimised. The optimiser is
    SGD with momentum 0.9 and weight decay 5e-4 at ``CONFIG["FT_LEARNING_RATE"]``,
    run for ``CONFIG["FT_EPOCHS"]`` epochs. A ``StepLR`` schedule is attached
    only when ``exp_config["use_scheduler"]`` is truthy. With
    ``freeze == "backbone"`` the backbone is also put in eval mode, so any
    batch-norm layers inside it do not update their running statistics.
    """
    device = CONFIG['DEVICE']
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        trainable,
        lr=CONFIG["FT_LEARNING_RATE"],
        momentum=0.9,
        weight_decay=0.0005,
    )
    scheduler = (
        StepLR(optimizer,
               step_size=CONFIG["LR_SCHEDULER_STEP_SIZE"],
               gamma=CONFIG["LR_SCHEDULER_GAMMA"])
        if exp_config.get("use_scheduler")
        else None
    )

    model.train()
    if exp_config.get('freeze') == 'backbone':
        model.backbone.eval()

    for _epoch in range(CONFIG["FT_EPOCHS"]):
        for batch_images, batch_targets, _ in support_loader:
            inputs = [image.to(device) for image in batch_images]
            gt = [{name: tensor.to(device) for name, tensor in target.items()}
                  for target in batch_targets]

            loss_dict = model(inputs, gt)
            total_loss = sum(loss_dict.values())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

def run_single_experiment(exp_config, dataset):
    """Run every K-shot setting for one experiment and save aggregated metrics.

    For each K in ``CONFIG["K_SHOTS_TO_TEST"]``, runs ``CONFIG["N_EVAL_EPISODES"]``
    episodes. Each episode samples a new support/query split, fine-tunes a fresh
    pretrained model on the support images and evaluates it on the query images.
    The per-K mean and std of mAP, mAP@50 and mAP@75 are written to
    ``<RESULTS_DIR>/<name>/results.json``.

    Args:
        exp_config: experiment dict (see ``EXPERIMENTS``); a truthy ``augment``
            enables training-time augmentation.
        dataset: a ``SimpleCocoDataset``. It is not mutated; shallow copies with
            their own ``transforms`` are used for training and for evaluation.

    Returns:
        ``defaultdict(list)`` with key ``'k'`` plus aligned ``'<metric>_mean'`` /
        ``'<metric>_std'`` lists. A K is omitted when none of its episodes
        produced a usable split.
    """
    exp_name = exp_config['name']
    results_dir = f"{CONFIG['RESULTS_DIR']}/{exp_name}"
    os.makedirs(results_dir, exist_ok=True)

    print(f"\n{'='*50}\nRUNNING: {exp_name}\n{'='*50}")

    # Augmentation applies to support (training) images only; query images are
    # always evaluated un-augmented. Both shallow copies share the same image
    # list, so split indices are valid for either.
    train_dataset = copy.copy(dataset)
    train_dataset.transforms = get_transforms(exp_config.get('augment', False))
    eval_dataset = copy.copy(dataset)
    eval_dataset.transforms = get_transforms(False)

    results = defaultdict(list)
    num_classes = len(dataset.categories)
    eval_batch_size = CONFIG.get("EVAL_BATCH_SIZE", 4)

    for k in CONFIG["K_SHOTS_TO_TEST"]:
        print(f"\n--- K={k} shots ---")
        episode_results = defaultdict(list)

        for _episode in tqdm(range(CONFIG["N_EVAL_EPISODES"]), desc=f"K={k}"):
            support_set, query_set = create_episode_split(train_dataset, k)
            if len(support_set) == 0 or len(query_set) == 0:
                continue
            # Same query images, served through the non-augmenting dataset.
            query_set = torch.utils.data.Subset(eval_dataset, query_set.indices)

            # Build the (expensive) pretrained model only once the split is usable.
            model = create_model(num_classes, exp_config).to(CONFIG['DEVICE'])

            support_loader = DataLoader(support_set, batch_size=CONFIG["FT_BATCH_SIZE"],
                                        shuffle=True, collate_fn=collate_fn)
            query_loader = DataLoader(query_set, batch_size=eval_batch_size,
                                      shuffle=False, collate_fn=collate_fn)

            train_model(model, support_loader, exp_config)
            metrics = evaluate_model(model, query_loader)

            for metric, value in metrics.items():
                episode_results[metric].append(value)

            # Drop the reference now so two models never coexist on the device.
            del model

        if episode_results['map']:
            results['k'].append(k)
            for metric in ('map', 'map_50', 'map_75'):
                values = episode_results[metric]
                results[f'{metric}_mean'].append(float(np.mean(values)))
                results[f'{metric}_std'].append(float(np.std(values)))

    with open(f"{results_dir}/results.json", 'w') as f:
        json.dump(results, f, indent=2)

    return results

In [None]:
# --- Simplified Main Execution ---
def plot_results(results_dir, experiments):
    """Save and show one K-vs-metric comparison figure per mAP variant.

    Reads ``<results_dir>/<name>/results.json`` for each experiment (missing
    files, or files without that metric, are skipped) and draws mean ± std
    error bars. Each figure is written to
    ``<results_dir>/comparison_<metric>.png`` at 300 dpi.
    """
    metric_labels = {'map': 'mAP (IoU .50-.95)', 'map_50': 'mAP@.50', 'map_75': 'mAP@.75'}

    for metric_key, label in metric_labels.items():
        fig, ax = plt.subplots(figsize=(12, 8))

        for exp in experiments:
            json_path = f"{results_dir}/{exp['name']}/results.json"
            if not os.path.exists(json_path):
                continue
            with open(json_path) as fh:
                data = json.load(fh)
            mean_key = f'{metric_key}_mean'
            if mean_key not in data:
                continue
            ax.errorbar(data['k'], data[mean_key], yerr=data[f'{metric_key}_std'],
                        fmt='-o', capsize=5, label=exp['name'], markersize=8)

        ax.set_title(f'Few-Shot Object Detection Comparison - {label}')
        ax.set_xlabel('Number of Support Shots (K)')
        ax.set_ylabel(label)
        ax.set_xticks(CONFIG["K_SHOTS_TO_TEST"])
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(f'{results_dir}/comparison_{metric_key}.png', dpi=300, bbox_inches='tight')
        plt.show()

def main():
    """Run all ``EXPERIMENTS`` end to end and plot the comparison.

    Deletes and recreates ``CONFIG["RESULTS_DIR"]`` (any earlier results there
    are lost), loads the dataset once, runs each experiment in order, then
    writes and shows the comparison figures.

    Returns:
        dict mapping experiment name to the aggregated results returned by
        ``run_single_experiment``.
    """
    results_dir = CONFIG["RESULTS_DIR"]
    # Start clean so stale experiments never leak into the plots.
    if os.path.exists(results_dir):
        shutil.rmtree(results_dir)
    os.makedirs(results_dir)

    dataset = SimpleCocoDataset()
    print(f"Dataset: {len(dataset)} images, {len(dataset.categories)} classes")
    print(f"Device: {CONFIG['DEVICE']}")

    all_results = {
        exp_config['name']: run_single_experiment(exp_config, dataset)
        for exp_config in EXPERIMENTS
    }

    plot_results(results_dir, EXPERIMENTS)
    print(f"\n✅ All experiments completed! Results saved to '{results_dir}/'")
    return all_results


# Run experiments. Inside a Jupyter/IPython kernel ``__name__`` is also
# '__main__', so this executes when the cell runs.
if __name__ == '__main__':
    main()

Dataset: 298 images, 7 classes
Device: mps

RUNNING: TFA_fc_FasterRCNN

--- K=1 shots ---


K=1:   0%|          | 0/30 [00:00<?, ?it/s]

Backbone frozen
