In [None]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from coco_dataset import COCODetectionDataset


In [None]:
# --- Dataset locations -------------------------------------------------------
# NOTE: the training and validation splits currently point at the same image
# folder and the same annotation file, so the "validation" numbers reported
# further down are measured on the training data.
train_img_dir = val_img_dir = 'data/coco_top10_2000/data'

# The balanced annotation file is used for both splits; 'labels.json' in the
# same folder was the file referenced here previously.
train_ann_file = val_ann_file = 'data/coco_top10_2000/labels_balanced.json'

# --- Training hyper-parameters -----------------------------------------------
num_classes = 11  # 10 classes + 1 background
num_epochs = 5
batch_size = 2

# --- Compute device ----------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
def collateFn(batch):
    """Transpose a list of samples into per-field tuples.

    A batch such as ``[(img1, tgt1), (img2, tgt2)]`` becomes
    ``((img1, img2), (tgt1, tgt2))``. Nothing is stacked, so the individual
    items may differ in size.

    Args:
        batch: Iterable of equally long sample tuples.

    Returns:
        A tuple with one inner tuple per sample field; an empty tuple when
        ``batch`` is empty.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Converting the image to a tensor is the only preprocessing step applied here.
transform = transforms.ToTensor()

# Both datasets share the same transform; they differ only in the image
# directory and annotation file they are given.
trainData = COCODetectionDataset(train_img_dir, train_ann_file, transforms=transform)
valData = COCODetectionDataset(val_img_dir, val_ann_file, transforms=transform)

# Options common to both loaders. collateFn keeps each batch as tuples of
# per-sample items instead of stacking them into a single tensor.
loaderOptions = dict(batch_size=batch_size, num_workers=4, collate_fn=collateFn)

# Only the training loader shuffles its samples.
trainLoader = DataLoader(trainData, shuffle=True, **loaderOptions)
valLoader = DataLoader(valData, shuffle=False, **loaderOptions)

print(f'Train samples: {len(trainData)}')
print(f'Val samples: {len(valData)}')

In [None]:
def getModel(num_classes=11):
    """Create a Faster R-CNN (ResNet-50 FPN) with a new box-prediction head.

    The network is instantiated with torchvision's ``'DEFAULT'`` weights and
    its box predictor is then replaced by a fresh ``FastRCNNPredictor`` sized
    for ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the new head. The notebook
            configuration counts the background as one of them.

    Returns:
        The modified detection model (not yet moved to any device).
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

    # The replacement head must accept the same feature width as the old one.
    head_input_size = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_input_size, num_classes)
    return detector

# Build the detector, then place it on the configured device.
model = getModel(num_classes)
model = model.to(device)
print(f"Using device: {device}")


## Training

In [None]:
# Optimise only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Multiply the learning rate by 0.1 after every 3 epochs.
lrScheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress = tqdm(trainLoader, desc=f'Epoch {epoch + 1}/{num_epochs}')
    for images, targets in progress:
        # Move the whole batch (images and every target tensor) to the device.
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # The model returns a dict of loss terms; their sum is the objective.
        loss_dict = model(images, targets)
        batch_loss = sum(loss_dict.values())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # The scheduler advances once per epoch, after all optimiser steps.
    lrScheduler.step()
    mean_loss = running_loss / len(trainLoader)
    print(f'Epoch {epoch+1}, Loss: {mean_loss:.4f}')

In [None]:
import traceback

# The model is deliberately kept in training mode: this cell reads a dict of
# losses from the forward pass, whereas in eval mode (see the inference cell)
# the same call yields predictions. Gradients are disabled via no_grad instead.
model.train()
valLoss = 0.0
numBatches = 0
with torch.no_grad():
    for i, (images, targets) in enumerate(tqdm(valLoader, desc='Validation')):
        print(f"Batch {i}: {[len(t['boxes']) for t in targets]}")

        # A batch is dropped as a whole if any of its targets has no boxes or
        # a box/label count mismatch.
        has_bad_target = any(
            len(t['boxes']) == 0 or len(t['boxes']) != len(t['labels'])
            for t in targets
        )
        if has_bad_target:
            print(f"Skipping batch {i} due to empty or mismatched targets")
            continue

        try:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            batch_loss = sum(loss_dict.values())
            valLoss += batch_loss.item()
            numBatches += 1
        except Exception as e:
            # Best effort: report the failing batch in detail and move on, so
            # one bad batch does not abort the whole validation pass.
            print(f"Error in batch {i}: {e}")
            traceback.print_exc()
            print(f"Target boxes: {[t['boxes'].shape for t in targets]}")
            print(f"Target labels: {[t['labels'].shape for t in targets]}")
            continue

# The average is taken over the batches that were actually evaluated.
if numBatches == 0:
    print("No valid batches in validation set!")
else:
    print(f'Validation Loss: {valLoss / numBatches:.4f}')

## Inference Example

In [None]:
import random

classNames = valData.class_names
model.eval()

# Draw one validation sample at random; its target is not needed here.
sample_index = random.randrange(len(valData))
img, _ = valData[sample_index]

with torch.no_grad():
    prediction = model([img.to(device)])

# A single image was passed in, so only the first result is of interest.
detections = prediction[0]
boxes = detections['boxes'].cpu().numpy()
scores = detections['scores'].cpu().numpy()
labels = detections['labels'].cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(transforms.ToPILImage()(img))
for box, score, label in zip(boxes, scores, labels):
    # Draw only detections scoring above 0.5.
    if not score > 0.5:
        continue
    x1, y1, x2, y2 = box
    outline = plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, edgecolor='red', linewidth=2)
    ax.add_patch(outline)
    # label is 1-based, so subtract 1 for 0-based indexing
    if 1 <= label <= len(classNames):
        className = classNames[label - 1]
    else:
        className = 'unknown'
    ax.text(x1, y1, f'{className}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
ax.axis('off')
plt.show()

## Save Model

In [None]:
# Persist only the weights (state dict), not the full model object; the load
# helper below rebuilds the architecture before restoring them.
trained_weights = model.state_dict()
torch.save(trained_weights, 'model.pth')
print('Model saved as model.pth')

In [None]:
def loadModel(model_path, num_classes, device):
    """Rebuild the detector and restore weights from a saved state dict.

    The architecture is created through ``getModel`` so that the network used
    for loading always matches the one used for training, instead of keeping
    a second copy of the construction code here.

    Args:
        model_path: Path of a checkpoint written with
            ``torch.save(model.state_dict(), ...)``.
        num_classes: Number of classes of the box predictor; it must match
            the value the checkpoint was trained with, otherwise
            ``load_state_dict`` will report mismatching shapes.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # getModel starts from torchvision's default weights; every one of them
    # is overwritten by the checkpoint loaded below.
    model = getModel(num_classes)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms

# Restore the fine-tuned detector from the checkpoint file.
# The device and class count are set again here, with the same values as in
# the configuration cell at the top of the notebook.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 11  # 10 classes + background
model_path = "model.pth"
model = loadModel(model_path=model_path, num_classes=num_classes, device=device)