In [1]:
# Importing libraries
import os
import json
import torch
import torchvision
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torchvision.models.detection import fasterrcnn_resnet50_fpn

In [None]:
# Dataset Class
class CocoFasterRCNNDataset(Dataset):
    def __init__(self, root, annotation_file, transforms=None):
        self.root = root
        self.transforms = transforms
        with open(annotation_file, 'r') as f:
            self.coco = json.load(f)
        self.image_map = {img['id']: img for img in self.coco['images']}
        self.image_list = list(self.image_map.values())
        self.annotations = {}
        for ann in self.coco['annotations']:
            self.annotations.setdefault(ann['image_id'], []).append(ann)
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):
        img_info = self.image_list[idx]
        img_path = os.path.join(self.root, img_info['file_name'])
        img = Image.open(img_path).convert("RGB")
        image_id = img_info['id']
        anns = self.annotations.get(image_id, [])
        boxes, labels, areas, iscrowd = [], [], [], []
        for ann in anns:
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
            areas.append(ann['area'])
            iscrowd.append(ann.get('iscrowd', 0))
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([image_id]),
            'area': torch.tensor(areas, dtype=torch.float32),
            'iscrowd': torch.tensor(iscrowd, dtype=torch.uint8)
        }
        if self.transforms:
            img = self.transforms(img)
        return img, target

In [2]:
# Main Training Script
if __name__ == '__main__':
    # Paths — updated to point directly to folders containing images
    train_dir = r'dataset\train'
    valid_dir = r'dataset\valid'
    # Load datasets
    train_dataset = CocoFasterRCNNDataset(
        root=train_dir,
        annotation_file=os.path.join(train_dir, '_annotations.fixed.json'),
        transforms=ToTensor()
    )
    valid_dataset = CocoFasterRCNNDataset(
        root=valid_dir,
        annotation_file=os.path.join(valid_dir, '_annotations.fixed.json'),
        transforms=ToTensor()
    )
    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
    # Model setup
    num_classes = 2  # 1 class (camouflage) + background
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    # Optimizer
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4, weight_decay=1e-4)
    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for images, targets in train_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            total_loss += losses.item()
        print(f"\n[Epoch {epoch+1}/{num_epochs}] Training Loss: {total_loss:.4f}")
        # Basic Validation
        model.eval()
        print(f"[Epoch {epoch+1}] Validation Prediction Summary:")
        with torch.no_grad():
            for i, (images, targets) in enumerate(valid_loader):
                images = [img.to(device) for img in images]
                outputs = model(images)
                for j, output in enumerate(outputs):
                    boxes = output['boxes'].cpu().numpy()
                    scores = output['scores'].cpu().numpy()
                    print(f"  Val sample {i*len(images)+j+1}: {len(boxes)} boxes, top score = {scores[0]:.2f}" if len(scores) > 0 else "  No detections")
    # Saving model
    torch.save(model.state_dict(), 'fasterrcnn_camouflage.pth')
    print("Model saved")


[Epoch 1/10] Training Loss: 238.7373
[Epoch 1] Validation Prediction Summary:
  Val sample 1: 17 boxes, top score = 0.92
  Val sample 2: 19 boxes, top score = 0.97
  Val sample 3: 12 boxes, top score = 0.94
  Val sample 4: 8 boxes, top score = 0.92
  Val sample 5: 16 boxes, top score = 0.89
  Val sample 6: 17 boxes, top score = 0.83
  Val sample 7: 7 boxes, top score = 0.96
  Val sample 8: 20 boxes, top score = 0.87
  Val sample 9: 17 boxes, top score = 0.91
  Val sample 10: 25 boxes, top score = 0.86
  Val sample 11: 44 boxes, top score = 0.89
  Val sample 12: 17 boxes, top score = 0.90
  Val sample 13: 17 boxes, top score = 0.78
  Val sample 14: 23 boxes, top score = 0.56
  Val sample 15: 25 boxes, top score = 0.89
  Val sample 16: 12 boxes, top score = 0.86
  Val sample 17: 19 boxes, top score = 0.86
  Val sample 18: 35 boxes, top score = 0.80
  Val sample 19: 17 boxes, top score = 0.81
  Val sample 20: 19 boxes, top score = 0.89
  Val sample 21: 15 boxes, top score = 0.86
  Val sa