In [2]:
import os
import json
from os.path import join
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import confusion_matrix

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import models
import torchvision.transforms as transforms
import warnings
warnings.filterwarnings("ignore")

In [3]:
"""
Implement and test the utilities in support of evaluating the results
from the region-by-region decisions and turning them into detections.

All rectangles are four component lists (or tuples) giving the upper
left and lower right corners of an axis-aligned rectangle.  For example, 
[2, 9, 12, 18] has upper left corner (2,9) and lower right (12, 18)

The region predictions for an image are stored in a list of dictionaries,
each giving the class, the activation and the bounding rectangle.
For example,

{
    "class": 2,
    "a":  0.67,
    "rectangle": (18, 14, 50, 75)
}

if the class is 0 this means there is no detection and the rectangle
should be ignored.  The region predictions must be turned into the
detection results by filtering out those with class 0 and applying
non-maximum suppression.  The resulting regions should be considered the
"detections" for the image.

After this, the detections should be compared to the ground truth.

The ground truth regions for an image are stored as a list of dictionaries. 
Each dictionary contains the region's class and bounding rectangle.
Here is an example dictionary:

{
    "class":  3,
    "rectangle": (15, 20, 56, 65)
}

Class 0 will not appear in the ground truth.  
"""

def area(rect):
    """
    Input: a rectangle given as (x0, y0, x1, y1)
    Output: height times width.  The value is signed, so a rectangle
    whose corners are swapped along one axis yields a negative area.
    """
    left, top, right, bottom = rect[0], rect[1], rect[2], rect[3]
    return (bottom - top) * (right - left)


def iou(rect1, rect2):
    """
    Input: two rectangles, each given as (x0, y0, x1, y1)
    Output: intersection-over-union of the two rectangles, or 0 when
    they do not overlap (rectangles that only touch along an edge or
    at a corner count as non-overlapping).
    """
    ax0, ay0, ax1, ay1 = rect1
    bx0, by0, bx1, by1 = rect2

    # The overlap region is bounded by the innermost of each pair of edges.
    overlap = (max(ax0, bx0), max(ay0, by0), min(ax1, bx1), min(ay1, by1))

    # An empty or zero-thickness overlap means there is no intersection.
    if overlap[0] >= overlap[2] or overlap[1] >= overlap[3]:
        return 0

    return area(overlap) / (area(rect1) + area(rect2) - area(overlap))


def predictions_to_detections(predictions, iou_threshold=0.5):
    """
    Input: List of region predictions (dictionaries with 'class', 'a'
    and 'rectangle' keys) and the IOU threshold used for suppression.

    Output: List of region predictions that are considered to be
    detection results. These are ordered by decreasing activation with
    all class 0 predictions eliminated, and non-maximum suppression
    applied using the standard greedy algorithm: a prediction is kept
    unless an already-kept prediction of the SAME class overlaps it
    with IOU >= iou_threshold.

    Several detections of one class can be returned as long as they do
    not overlap, so images containing more than one object of a class
    are handled.  The input list is not modified.
    """
    # sorted() builds a new list, so the caller's list is left untouched,
    # and it is stable, so ties in activation keep their input order.
    candidates = sorted(
        (p for p in predictions if p['class'] != 0),
        key=lambda item: item['a'],
        reverse=True,
    )

    detections = []  # Final selected detections, highest activation first
    for candidate in candidates:
        # Everything already kept has an activation at least as high, so
        # one sufficiently overlapping same-class box suppresses this one.
        suppressed = any(
            kept['class'] == candidate['class']
            and iou(kept['rectangle'], candidate['rectangle']) >= iou_threshold
            for kept in detections
        )
        if not suppressed:
            detections.append(candidate)

    return detections





def evaluate(detections, gt_detections, iou_threshold=0.5):
    """
    Input:
    1. The detections returned by the predictions_to_detections function
       (already ordered by decreasing activation),
    2. The list of ground truth regions, and
    3. The IOU threshold

    Each detection, in order, is matched to the not-yet-matched ground
    truth region of the same class with the highest IOU.  The detection
    is correct if that IOU is at least iou_threshold, in which case the
    ground truth region is consumed and cannot be matched again;
    otherwise the detection is incorrect.

    The average precision is the 11-point interpolated AP: the mean,
    over recall levels 0, 0.1, ..., 1, of the best precision achieved
    at a recall at or above that level (0 when the level is never
    reached).  It is 0.0 when there are no detections or no ground
    truth regions.

    The gt_detections list passed in is NOT modified.

    Returns:
    list of correct detections,
    list of incorrect detections,
    list of ground truth regions that are missed,
    AP value (float).
    """
    gtd_num = len(gt_detections)

    # Work on a shallow copy: matched regions are removed as we go, and the
    # caller usually still needs the complete ground truth afterwards.
    unmatched = list(gt_detections)

    hits = []  # 1 for a correct detection, 0 for an incorrect one
    correct_dets = []
    incorrect_dets = []

    for det in detections:
        best_index, best_iou = None, 0

        # Only ground truth regions of the same class are candidates.
        for i, gt in enumerate(unmatched):
            if det['class'] != gt['class']:
                continue
            curr_iou = iou(det['rectangle'], gt['rectangle'])
            if curr_iou > best_iou:
                best_index, best_iou = i, curr_iou

        if best_index is None or best_iou < iou_threshold:
            hits.append(0)
            incorrect_dets.append(det)
        else:
            hits.append(1)
            correct_dets.append(det)
            del unmatched[best_index]  # each ground truth matches at most once

    # Whatever was never matched has been missed.
    missed = unmatched

    # Without detections or ground truth there is no precision/recall curve;
    # return 0 directly instead of dividing by zero below.
    if gtd_num == 0 or not hits:
        return correct_dets, incorrect_dets, missed, 0.0

    cumsum = np.cumsum(hits)
    precision = cumsum / np.arange(1, len(hits) + 1)
    recall = cumsum / gtd_num

    max_precision_at_recall = []
    for recall_level in np.linspace(0, 1, 11):
        valid_precisions = precision[recall >= recall_level]
        max_precision_at_recall.append(
            np.max(valid_precisions) if valid_precisions.size > 0 else 0
        )

    AP_n = float(np.mean(max_precision_at_recall))

    return correct_dets, incorrect_dets, missed, AP_n

In [4]:
def test_iou():
    """
    Use this function for your own testing of your IOU function.
    Prints the IOU of several rectangles against one fixed rectangle.
    """
    rect2 = (2, 9, 12, 18)

    # Expected results, in order: .370, 0 and 0.2
    first_rectangles = (
        (0, 5, 11, 15),
        (2, -3, 11, 4),
        (3, 12, 9, 15),
    )
    for rect1 in first_rectangles:
        res = iou(rect1, rect2)
        print(f"iou for {rect1} {rect2} is {res:1.2f}")

test_iou()

iou for (0, 5, 11, 15) (2, 9, 12, 18) is 0.37
iou for (2, -3, 11, 4) (2, 9, 12, 18) is 0.00
iou for (3, 12, 9, 15) (2, 9, 12, 18) is 0.20


In [5]:
def test_evaluation_code(in_json_file):
    """
    Load region predictions and ground truth regions from a JSON file
    (keys "region_predictions" and "gt_detections"), convert the
    predictions to detections, evaluate them against the ground truth
    and print a summary of both steps.
    """
    with open(in_json_file, "r") as in_fp:
        payload = json.load(in_fp)

    region_predictions = payload["region_predictions"]
    gt_detections = payload["gt_detections"]

    detections = predictions_to_detections(region_predictions)
    num_detections = len(detections)
    print(f"DETECTIONS: count = {num_detections}")
    if num_detections == 0:
        print("DETECTIONS: no activations")
    elif num_detections == 1:
        print(f"DETECTIONS: only activation {detections[0]['a']:.2f}")
    else:
        print(f"DETECTIONS: first activation {detections[0]['a']:.2f}" )
        print(f"DETECTIONS: last activation {detections[-1]['a']:.2f}")

    correct, incorrect, missed, ap = evaluate(detections, gt_detections)

    print(f"AP: num correct {len(correct)}")
    if correct:
        print(f"AP: first correct activation {correct[0]['a']:.2f}")

    print(f"AP: num incorrect {len(incorrect)}")
    if incorrect:
        print(f"AP: first incorrect activation {incorrect[0]['a']:.2f}")

    print(f"AP: num ground truth missed {len(missed)}")
    print(f"AP: final AP value {ap:1.3f}")


In [6]:
test_evaluation_code('eval_test1.json')

DETECTIONS: count = 2
DETECTIONS: first activation 0.90
DETECTIONS: last activation 0.70
AP: num correct 1
AP: first correct activation 0.90
AP: num incorrect 1
AP: first incorrect activation 0.70
AP: num ground truth missed 2
AP: final AP value 0.364


In [7]:
test_evaluation_code('eval_test2.json')

DETECTIONS: count = 5
DETECTIONS: first activation 0.94
DETECTIONS: last activation 0.55
AP: num correct 4
AP: first correct activation 0.90
AP: num incorrect 1
AP: first incorrect activation 0.94
AP: num ground truth missed 1
AP: final AP value 0.655


In [8]:
test_evaluation_code('eval_test3.json')

DETECTIONS: count = 1
DETECTIONS: only activation 0.94
AP: num correct 0
AP: num incorrect 1
AP: first incorrect activation 0.94
AP: num ground truth missed 1
AP: final AP value 0.000


In [9]:
test_evaluation_code('eval_test4.json')

DETECTIONS: count = 11
DETECTIONS: first activation 0.89
DETECTIONS: last activation 0.65
AP: num correct 10
AP: first correct activation 0.89
AP: num incorrect 1
AP: first incorrect activation 0.88
AP: num ground truth missed 1
AP: final AP value 0.835


In [10]:
'''
Skeleton model class. You will have to implement the classification and regression layers,
along with the forward method.
'''

class RCNN(nn.Module):
    """
    Region classifier / bounding-box regressor on a frozen ResNet-18 backbone.

    For every candidate region the network produces
      * class logits of shape (batch, num_classes + 1), where index 0 is
        the 'nothing' class, and
      * one bounding box per real class, of shape (batch, num_classes, 4);
        slot k - 1 holds the box predicted for class k.

    Args:
        num_classes: number of object classes, NOT counting 'nothing'.
    """

    def __init__(self, num_classes=4):
        super(RCNN, self).__init__()

        # forward() needs this to reshape the regressor output.
        self.num_classes = num_classes

        # Pretrained backbone. If you are on the cci machine then this will not be able to automatically download
        #  the pretrained weights. You will have to download them locally then copy them over.
        #  During the local download it should tell you where torch is downloading the weights to, then copy them to
        #  ~/.cache/torch/checkpoints/ on the supercomputer.
        resnet = models.resnet18(pretrained=True)

        # Remove the last fc layer of the pretrained network.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Freeze backbone weights.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Width of the feature vector fed to both heads.
        feature_dim = 512

        # Classification head (num_classes + 1 outputs; the extra one is 'nothing').
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes + 1))

        # Regression head: 4 coordinates for each real class.  'nothing' has
        # no box, so the output width is 4 * num_classes.
        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 4 * num_classes))

    def forward(self, x):
        """
        Args:
            x: batch of candidate regions accepted by the backbone.
        Returns:
            (class_logits, bbox_reg) with shapes (batch, num_classes + 1)
            and (batch, num_classes, 4).
        """
        x = self.backbone(x)
        x = torch.flatten(x, start_dim=1)

        class_logits = self.classifier(x)
        bbox_reg = self.regressor(x).view(-1, self.num_classes, 4)
        return class_logits, bbox_reg

    def compute_loss(self, class_logits, bbox_reg, gt_classes, gt_bboxs):
        """
        Cross-entropy classification loss (averaged over the batch) plus
        0.1 times the squared-error regression loss, summed over every
        sample whose ground truth class is not 'nothing'.  For each such
        sample only the box predicted for its ground truth class is used.
        """
        classification_loss_fn = nn.CrossEntropyLoss()
        classification_loss = classification_loss_fn(class_logits, gt_classes)

        # Select all positive samples at once instead of looping in Python.
        positive = torch.nonzero(gt_classes != 0, as_tuple=True)[0]
        if positive.numel() > 0:
            # -1 because class 0 is 'nothing' and has no regression slot.
            predicted_bboxes = bbox_reg[positive, gt_classes[positive] - 1]
            regression_loss = torch.sum((predicted_bboxes - gt_bboxs[positive]) ** 2)
        else:
            regression_loss = 0

        total_loss = classification_loss + (0.1 * regression_loss)  # Weighted sum
        return total_loss

In [11]:
# Per-channel mean and standard deviation used to normalize candidate regions.
# NOTE(review): these look like the usual ImageNet statistics expected by the
# pretrained torchvision backbone -- confirm.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


# Dictionaries mapping class labels to names.  The position of a name in the
# tuple is its integer label; label 0 is always 'nothing'.
LABELS_TO_NAMES = dict(enumerate((
    'nothing',
    'bicycle',
    'car',
    'motorbike',
    'person',
)))


LABELS_TO_NAMES_LARGE = dict(enumerate((
    'nothing',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)))


class HW5Dataset(Dataset):
    """
    Dataset of individual candidate regions for training and validation.

    Input:
        data_root - directory containing the images
        json_file - JSON file mapping each image file name to a list of
                    dictionaries with the keys 'bbox' (candidate region),
                    'gt_bbox' (ground truth box) and 'class'
        candidate_region_size - side length M that each crop is resized to
        cache_dir - optional directory in which transformed crops are
                    cached as .pt files
    Each item is:
        candidate_region - the cropped, resized and normalized region
        resized_gt_bbox - the ground truth box expressed relative to the
                          candidate region (0 at its left/top edge, 1 at
                          its right/bottom edge), float32
        ground truth class of the region
    """
    def __init__(self, data_root, json_file, candidate_region_size=224, cache_dir=None):
        with open(json_file, 'r') as f:
            data_dict = json.load(f)

        self.data_root = data_root
        self.candidate_region_size = candidate_region_size
        self.cache_dir = cache_dir  # Directory to cache cropped images

        # Collect plain Python lists first and convert once at the end;
        # concatenating tensors inside the loop is quadratic in the number
        # of regions.
        self.images = []
        candidate_bboxes, ground_truth_bboxes, ground_truth_classes = [], [], []
        for key, values in data_dict.items():
            for val in values:
                self.images.append(key)
                candidate_bboxes.append(val['bbox'])
                ground_truth_bboxes.append(val['gt_bbox'])
                ground_truth_classes.append(val['class'])

        if self.images:
            self.candidate_bboxes = torch.tensor(candidate_bboxes)
            self.ground_truth_bboxes = torch.tensor(ground_truth_bboxes)
            self.ground_truth_classes = torch.tensor(ground_truth_classes)
        else:
            self.candidate_bboxes = torch.empty((0, 4), dtype=int)
            self.ground_truth_bboxes = torch.empty((0, 4), dtype=int)
            self.ground_truth_classes = torch.empty(0, dtype=int)

        # Define transform for candidate regions (resize, tensor conversion, and normalization)
        self.transform = transforms.Compose([transforms.Resize((candidate_region_size, candidate_region_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=MEAN, std=STD)])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Candidate bounding box as plain Python numbers.
        candidate_bbox = self.candidate_bboxes[idx, :]
        x0, y0, x1, y1 = candidate_bbox.tolist()

        # Unique cache file per (image, bounding box).  Plain numbers keep
        # the name free of "tensor(...)" text.
        cache_filepath = None
        if self.cache_dir:
            cache_filename = f"{self.images[idx]}_{x0}_{y0}_{x1}_{y1}.pt"
            cache_filepath = join(self.cache_dir, cache_filename)

        if cache_filepath and os.path.exists(cache_filepath):
            # Cache hit: the image file does not need to be opened at all.
            candidate_region = torch.load(cache_filepath)
        else:
            image_path = join(self.data_root, self.images[idx])
            with Image.open(image_path) as image:
                # Crop first, then force three channels so that grayscale or
                # RGBA files still match the three-channel normalization.
                candidate_region = image.crop((x0, y0, x1, y1)).convert('RGB')

            # Transform (resize, tensor, normalize)
            candidate_region = self.transform(candidate_region)

            # Cache the cropped region for future access.  The image name may
            # contain sub-directories, so create the file's own parent folder.
            if cache_filepath:
                os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
                torch.save(candidate_region, cache_filepath)

        # Express the ground truth box relative to the candidate region:
        # subtract the candidate's corner and divide by its width/height.
        gt_bbox = self.ground_truth_bboxes[idx, :]
        origin = torch.tensor([x0, y0, x0, y0], dtype=torch.float32)
        extent = torch.tensor([x1 - x0, y1 - y0, x1 - x0, y1 - y0], dtype=torch.float32)
        resized_gt_bbox = (gt_bbox.to(torch.float32) - origin) / extent

        return candidate_region, resized_gt_bbox, self.ground_truth_classes[idx]


class HW5DatasetTest(Dataset):
    """
    Dataset for Test.
    Input:
        data_root - path to the test image directory
        json_file - path to test.json, mapping each image file name to a
                    dictionary with 'candidate_bboxes' (list of boxes) and
                    'gt_bboxes' (list of {'bbox': ..., 'class': ...})
        candidate_region_size - side length M each crop is resized to
    Returns:
        image - numpy array A x B x 3 (RGB)
        candidate_regions - NUM_CANDIDATE_REGIONS x 3 x M x M tensor
        candidate_bboxes - all candidate bounding boxes for an image
        ground_truth_bboxes - all ground truth bounding boxes for an image
        ground_truth_classes - all ground truth classes for an image
    """
    def __init__(self, data_root, json_file, candidate_region_size=224):
        with open(json_file, 'r') as f:
            data_dict = json.load(f)

        self.data_root = data_root

        self.images = []
        self.candidate_bboxes = []
        self.ground_truth_bboxes = []
        self.ground_truth_classes = []
        for key, values in data_dict.items():
            self.images.append(key)

            # One (N, 4) integer tensor per image; the explicit reshape keeps
            # the shape (0, 4) when an image has no boxes.
            candidates = values['candidate_bboxes']
            self.candidate_bboxes.append(
                torch.tensor(candidates, dtype=torch.int64).reshape(len(candidates), 4))

            gt_entries = values['gt_bboxes']
            self.ground_truth_bboxes.append(
                torch.tensor([entry['bbox'] for entry in gt_entries],
                             dtype=torch.int64).reshape(len(gt_entries), 4))
            self.ground_truth_classes.append(
                torch.tensor([entry['class'] for entry in gt_entries], dtype=torch.int64))

        self.candidate_region_size = candidate_region_size

        # Transform to resize, convert to tensor, and normalize.
        self.transform = transforms.Compose([transforms.Resize((candidate_region_size, candidate_region_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=MEAN, std=STD)])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load the image and release the file handle right away.  convert()
        # returns a fully loaded three-channel copy, so grayscale or RGBA
        # files still give an A x B x 3 array and 3-channel crops.
        image_path = join(self.data_root, self.images[idx])
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')

        # Apply transform to resize and normalize the candidate images.
        idx_candidate_bboxes = self.candidate_bboxes[idx]
        candidate_regions = torch.empty((len(idx_candidate_bboxes), 3, self.candidate_region_size, self.candidate_region_size))
        for i, bbox in enumerate(idx_candidate_bboxes):
            candidate_region = image.crop(tuple(bbox.tolist()))
            candidate_regions[i] = self.transform(candidate_region)

        return np.array(image), candidate_regions, self.candidate_bboxes[idx], self.ground_truth_bboxes[idx], self.ground_truth_classes[idx]


In [12]:
def _run_epoch(model, loader, device, optimizer=None):
    """
    Run one pass over `loader`: a training pass when `optimizer` is given,
    otherwise an evaluation pass with gradients disabled.

    Returns (average loss per batch, mean IoU over correctly classified
    non-'nothing' samples, predicted classes, ground truth classes).
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    ious, preds, labels = [], [], []

    with torch.set_grad_enabled(training):
        for images, gt_bboxes, gt_classes in loader:
            images = images.to(device)
            gt_bboxes = gt_bboxes.to(device, dtype=torch.float32)
            gt_classes = gt_classes.to(device, dtype=torch.long)

            class_logits, bbox_reg = model(images)
            loss = model.compute_loss(class_logits, bbox_reg, gt_classes, gt_bboxes)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

            # Predictions for the confusion matrix
            _, pred_classes = torch.max(class_logits, 1)
            preds.extend(pred_classes.cpu().numpy())
            labels.extend(gt_classes.cpu().numpy())

            # Box quality is only measured where the class is right and real.
            correct_mask = (pred_classes == gt_classes) & (gt_classes != 0)
            for i in torch.where(correct_mask)[0]:
                pred_bbox = bbox_reg[i, gt_classes[i] - 1].detach().cpu().numpy()
                true_bbox = gt_bboxes[i].detach().cpu().numpy()
                ious.append(iou(pred_bbox, true_bbox))

    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    avg_loss = total_loss / max(len(loader), 1)
    avg_iou = np.mean(ious) if ious else 0
    return avg_loss, avg_iou, preds, labels


def _plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """Draw confusion matrix `cm` with one tick per entry of `classes`."""
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # White text on dark cells, black text on light cells.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()


def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.001, device='cuda'):
    """
    Train `model` with Adam for `num_epochs` epochs, validating after each.

    Whenever the average validation loss improves, the model's state dict
    is saved to 'best_model.pth'.  After training, the loss and IoU curves
    are plotted, followed by training and validation confusion matrices
    built from the predictions of the FINAL epoch (so they describe the
    returned model rather than a mixture of all intermediate models).

    Returns the model as it is after the last epoch; the best checkpoint
    is only available through 'best_model.pth'.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')
    train_losses, val_losses = [], []
    train_ious, val_ious = [], []

    # Predictions / labels of the most recent epoch, for the confusion matrices.
    train_preds, train_labels = [], []
    val_preds, val_labels = [], []

    for epoch in range(num_epochs):
        print(f"Training epoch {epoch+1}")
        avg_train_loss, avg_train_iou, train_preds, train_labels = _run_epoch(
            model, train_loader, device, optimizer)

        print(f"Beginning validation for epoch {epoch+1}")
        avg_val_loss, avg_val_iou, val_preds, val_labels = _run_epoch(
            model, val_loader, device)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_ious.append(avg_train_iou)
        val_ious.append(avg_val_iou)

        # Model checkpointing based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model.pth')

        # Print training and validation results for each epoch
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}')
        print(f'Train IoU: {avg_train_iou:.4f} | Val IoU: {avg_val_iou:.4f}')
        print('-' * 50)

    # Plot training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_ious, label='Train IoU')
    plt.plot(val_ious, label='Val IoU')
    plt.xlabel('Epoch')
    plt.ylabel('IoU')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Passing labels= pins the matrix to one row/column per known class even
    # when a class never occurs, so the tick names always line up.
    class_ids = list(LABELS_TO_NAMES.keys())
    class_labels = list(LABELS_TO_NAMES.values())

    if train_labels:
        train_cm = confusion_matrix(train_labels, train_preds, labels=class_ids)
        _plot_confusion_matrix(train_cm, class_labels, 'Training Confusion Matrix')

    if val_labels:
        val_cm = confusion_matrix(val_labels, val_preds, labels=class_ids)
        _plot_confusion_matrix(val_cm, class_labels, 'Validation Confusion Matrix')

    return model


In [13]:
def visualize_detections(image, detections, gt_detections, class_names, iou_threshold=0.5):
    """
    Visualize detections and ground truth on an image
    :param image: PIL Image, numpy array or torch tensor (height x width x 3)
    :param detections: List of detection dictionaries with 'rectangle', 'class', 'a'
    :param gt_detections: List of ground truth dictionaries with 'rectangle', 'class'
    :param class_names: Dictionary mapping class IDs to names
    :param iou_threshold: Minimum IOU with a same-class ground truth box for
        a detection to be drawn as correct

    Ground truth boxes are drawn in yellow.  A detection is drawn in green
    when ANY ground truth box of its class overlaps it with at least
    iou_threshold, otherwise in red.
    """
    # A DataLoader hands back tensors; bring them to numpy first.
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image)

    # Draw ground truth boxes (yellow)
    for gt in gt_detections:
        rect = gt['rectangle']
        box = patches.Rectangle((rect[0], rect[1]), rect[2]-rect[0], rect[3]-rect[1],
                               linewidth=2, edgecolor='y', facecolor='none')
        ax.add_patch(box)
        # Look the name up by the class id stored in the dictionary.
        ax.text(rect[0], rect[1]-5, f"GT: {class_names[gt['class']]}",
                color='y', fontsize=10, bbox=dict(facecolor='black', alpha=0.5))

    # Draw detections (green=correct, red=incorrect)
    for det in detections:
        rect = det['rectangle']
        class_id = det['class']

        # Check if detection matches any ground truth
        is_correct = any(
            gt['class'] == class_id and iou(rect, gt['rectangle']) >= iou_threshold
            for gt in gt_detections
        )

        color = 'g' if is_correct else 'r'
        box = patches.Rectangle((rect[0], rect[1]), rect[2]-rect[0], rect[3]-rect[1],
                               linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(box)
        ax.text(rect[0], rect[1]-5, f"{class_names[class_id]} ({det['a']:.2f})",
                color=color, fontsize=10, bbox=dict(facecolor='black', alpha=0.5))

    ax.axis('off')
    plt.show()

def evaluate_model(model, test_loader, class_names, device='cuda',
                   visualize_indices=(0, 15, 27, 29, 38, 74, 77, 87, 92, 145)):
    """
    Evaluate `model` on every image of `test_loader`.

    For each image the candidate regions are classified, predicted boxes
    are mapped from candidate-relative to image coordinates, non-maximum
    suppression is applied and the result is scored against the ground
    truth.  Per-class TP/FP/FN counts and the mean of the per-image AP
    values are printed.

    :param model: network returning (class_logits, bbox_reg)
    :param test_loader: loader yielding (image, candidate_regions,
        candidate_bboxes, gt_bboxes, gt_classes) for ONE image at a time
        (batch_size=1, or unbatched)
    :param class_names: Dictionary mapping class IDs to names
    :param device: device to run the model on
    :param visualize_indices: 0-based indices of the images whose AP is
        printed and whose detections are drawn; all images are evaluated
    :return: mean average precision over all evaluated images (0.0 if none)
    """
    model.eval()
    model.to(device)

    aps = []
    class_stats = {class_id: {'tp': 0, 'fp': 0, 'fn': 0}
                  for class_id in class_names if class_id != 0}
    visualize_indices = set(visualize_indices)

    with torch.no_grad():
        for idx, (image, candidate_regions, candidate_bboxes, gt_bboxes, gt_classes) in enumerate(test_loader):
            # A DataLoader with batch_size=1 prepends a batch dimension to
            # every element; strip it so indices refer to regions again.
            if candidate_regions.dim() == 5:
                if candidate_regions.size(0) != 1:
                    raise ValueError("evaluate_model expects one image per batch")
                image = image[0]
                candidate_regions = candidate_regions[0]
                candidate_bboxes = candidate_bboxes[0]
                gt_bboxes = gt_bboxes[0]
                gt_classes = gt_classes[0]

            detections = []
            if candidate_regions.size(0) > 0:
                class_logits, bbox_reg = model(candidate_regions.to(device))

                # Bring everything back to the CPU as plain Python numbers.
                pred_scores, pred_classes = torch.softmax(class_logits, 1).max(1)
                pred_classes = pred_classes.cpu().tolist()
                pred_scores = pred_scores.cpu().tolist()
                bbox_reg = bbox_reg.cpu()
                candidate_list = candidate_bboxes.tolist()

                for i, class_id in enumerate(pred_classes):
                    if class_id == 0:  # Skip 'nothing' class
                        continue
                    # -1 because class 0 is nothing and has no box slot
                    rx0, ry0, rx1, ry1 = bbox_reg[i, class_id - 1].tolist()
                    cx0, cy0, cx1, cy1 = candidate_list[i]
                    width, height = cx1 - cx0, cy1 - cy0
                    # Convert bbox from candidate-relative to image coordinates
                    detections.append({
                        'class': class_id,
                        'a': pred_scores[i],
                        'rectangle': [rx0 * width + cx0, ry0 * height + cy0,
                                      rx1 * width + cx0, ry1 * height + cy0]
                    })

            # Apply NMS
            final_detections = predictions_to_detections(detections)

            # Prepare ground truth
            gt_detections = [{
                'class': gt_classes[i].item(),
                'rectangle': gt_bboxes[i].tolist()
            } for i in range(len(gt_classes))]

            # Evaluate on a copy so the full ground truth is still available
            # for the visualization below.
            correct_dets, incorrect_dets, missed_dets, ap = evaluate(final_detections, list(gt_detections))
            aps.append(ap)

            # Update class statistics
            for det in correct_dets:
                class_stats[det['class']]['tp'] += 1
            for det in incorrect_dets:
                class_stats[det['class']]['fp'] += 1
            for gt in missed_dets:
                class_stats[gt['class']]['fn'] += 1

            # Visualize only the selected images
            if idx in visualize_indices:
                print(f"Image {idx+1} AP: {ap:.4f}")
                visualize_detections(image, final_detections, gt_detections, class_names)

    # Calculate mAP and class statistics
    mAP = float(np.mean(aps)) if aps else 0.0
    print(f"\nMean Average Precision (mAP): {mAP:.4f}")

    print("\nClass-wise Statistics:")
    for class_id, stats in class_stats.items():
        precision = stats['tp'] / (stats['tp'] + stats['fp']) if (stats['tp'] + stats['fp']) > 0 else 0
        recall = stats['tp'] / (stats['tp'] + stats['fn']) if (stats['tp'] + stats['fn']) > 0 else 0
        print(f"{class_names[class_id]}:")
        print(f"  Precision: {precision:.2%}")
        print(f"  Recall: {recall:.2%}")
        print(f"  TP: {stats['tp']}, FP: {stats['fp']}, FN: {stats['fn']}")

    return mAP

In [None]:
if __name__ == "__main__":
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Root folder holding the train / valid / test image directories.
    data_dir = "../../../../Documents/hw5_2024_data"

    # Load datasets
    train_dataset = HW5Dataset(f"{data_dir}/train", "train.json")
    val_dataset = HW5Dataset(f"{data_dir}/valid", "valid.json")
    test_dataset = HW5DatasetTest(f"{data_dir}/test", "test.json")
    print("Loaded data sets")

    # Create data loaders: only the training data is shuffled, and test
    # images are served one at a time.
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
    print("Loaded data sets into DataLoader")

    # Initialize and train model
    model = train_model(
        RCNN(),  # 4 classes + nothing
        train_loader,
        val_loader,
        num_epochs=10,
        device=device,
    )

    # Evaluate on test set
    evaluate_model(model, test_loader, LABELS_TO_NAMES, device=device)

Loaded data sets
Loaded data sets into DataLoader
Training epoch 1
.train()
