In [2]:
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim


In [3]:
# Device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 


In [6]:
# 1. Load Pretrained Model 
# ----------------------------- 
model =  torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) 




In [8]:
# Freeze backbone layers 
for param in model.backbone.body.parameters(): 
    param.requires_grad = False 


In [10]:
# Replace the detection head (for custom classes) 
num_classes = 3 # background + 2 object classes 
in_features = model.roi_heads.box_predictor.cls_score.in_features 
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
model = model.to(device) 


In [12]:
# 2. Dummy Dataset (for testing)
# -----------------------------
class DummyDataset(Dataset):
    def __init__(self, n=10):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        
        # Fake image (RGB, 3x224x224)
        img = torch.rand(3, 224, 224)
        
        # Fake bounding box and label
        target = {
            "boxes": torch.tensor([[30, 40, 180, 200]], dtype=torch.float32),
            "labels": torch.tensor([1], dtype=torch.int64)
        }
        return img, target

# DataLoader with collate_fn for object detection
train_loader = DataLoader(
    DummyDataset(8),
    batch_size=2,
    shuffle=True,
    collate_fn=lambda x: tuple(zip(*x))
)


In [26]:
# 3. Train Only New Head 
# ----------------------------- 
params = [p for p in model.parameters() if p.requires_grad] 
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005) 
num_epochs = 3 
for epoch in range(num_epochs): 
    model.train() 
    total_loss = 0 
    for images, targets in train_loader: 
        images = [img.to(device) for img in images] 
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 
        loss_dict = model(images, targets) 
        losses = sum(loss for loss in loss_dict.values()) 
        optimizer.zero_grad() 
        losses.backward() 
        optimizer.step() 
        total_loss += losses.item() 
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}") 
print(" Training pipeline works (with dummy data).")


Epoch [1/3], Loss: 0.4574
Epoch [2/3], Loss: 0.4051
Epoch [3/3], Loss: 0.4377
 Training pipeline works (with dummy data).
