In [None]:
%pip install -U albumentations

import torch
import torchvision
from pycocotools.cocoeval import COCOeval
import numpy as np
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import copy
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from functools import partial
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.amp import GradScaler, autocast
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
import seaborn as sns

In [3]:
# Define the CocoDataset class
class CocoDataset(torchvision.datasets.CocoDetection):
    """``CocoDetection`` that optionally runs an Albumentations pipeline.

    ``__getitem__`` returns ``(image, annotations)``:

    * with ``transform``: the pipeline's ``"image"`` output and a fresh list of
      ``{"bbox", "category_id", "image_id"}`` dicts built from the transformed
      boxes/labels (boxes in whatever format the pipeline's ``BboxParams`` use —
      COCO ``[x, y, w, h]`` in this notebook);
    * without ``transform``: the image and annotation dicts as loaded by
      ``CocoDetection``, each tagged with ``image_id``.
    """

    def __init__(self, root, annFile, transform=None):
        # Zero-argument super() is bound to this class object. The explicit form
        # super(CocoDataset, self) looks the global name up at call time, which
        # breaks existing instances once a later notebook cell redefines CocoDataset.
        super().__init__(root, annFile)
        self.transform = transform

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        image_id = self.ids[index]

        if self.transform is None:
            for ann in target:
                ann["image_id"] = image_id
            return img, target

        transformed = self.transform(
            image=np.array(img),
            bboxes=[ann["bbox"] for ann in target],
            labels=[ann["category_id"] for ann in target],
        )
        # Boxes may be dropped by the pipeline, so rebuild annotations from its output
        annotations = [
            {"bbox": bbox, "category_id": label, "image_id": image_id}
            for bbox, label in zip(transformed["bboxes"], transformed["labels"])
        ]
        return transformed["image"], annotations

In [4]:
# Define Albumentations transformations.
# Both pipelines resize to 640x640 and apply ImageNet mean/std normalisation;
# boxes are kept in COCO [x, y, w, h] format with class ids in the "labels" field.
_IMAGENET_NORMALIZE = {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}

train_transform = A.Compose(
    [
        A.Resize(640, 640),
        # geometric augmentation
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=20),
        # photometric augmentation
        A.RandomBrightnessContrast(p=0.5),
        A.RandomGamma(p=0.5),
        A.Normalize(**_IMAGENET_NORMALIZE),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["labels"]),
)

# Evaluation: deterministic resize + normalise only
val_transform = A.Compose(
    [A.Resize(640, 640), A.Normalize(**_IMAGENET_NORMALIZE), ToTensorV2()],
    bbox_params=A.BboxParams(format="coco", label_fields=["labels"]),
)

In [None]:
# Create datasets: one CocoDataset per split folder, each holding its own COCO annotation file
dataset_path = "/kaggle/input/oil-palm-bunch-3910"


def _split_dataset(split, transform):
    """Build the CocoDataset for ``<dataset_path>/<split>``."""
    split_dir = os.path.join(dataset_path, split)
    return CocoDataset(
        root=split_dir,
        annFile=os.path.join(split_dir, "_annotations.coco.json"),
        transform=transform,
    )


train_dataset = _split_dataset("train", train_transform)
val_dataset = _split_dataset("valid", val_transform)
test_dataset = _split_dataset("test", val_transform)

# Create data loaders
def collate_fn(batch):
    """Collate ``(image, annotations)`` pairs into torchvision detection inputs.

    Returns ``(images, targets)`` where ``images`` is the list of images unchanged
    and each target dict holds:

    * ``boxes``: float32 tensor of shape (N, 4); COCO ``[x, y, w, h]`` converted to
      ``[x1, y1, x2, y2]``. Shape is (0, 4) for images without annotations.
    * ``labels``: int64 tensor of shape (N,) with the category ids.
    * ``image_id``: 0-d tensor; 0 when the image has no annotations, because the id
      is only carried on the annotation dicts.
    """
    images = []
    targets = []

    for image, target in batch:
        images.append(image)

        boxes = [[x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in target)]
        labels = [ann["category_id"] for ann in target]

        targets.append({
            # reshape keeps the (0, 4) shape torchvision detectors require for images
            # without boxes; a bare empty tensor would have shape (0,) and be rejected.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor(target[0]["image_id"] if target else 0),
        })

    return images, targets


# Settings shared by every split; only the training loader shuffles.
_loader_kwargs = dict(batch_size=4, collate_fn=collate_fn, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
def visualize_samples(dataset, num_samples=5):
    """Show ``num_samples`` raw images with their COCO boxes in a single row.

    Subplot ``i`` shows a random image containing category ``i`` (cycling over the
    categories that have at least one annotated image) and draws only that
    category's boxes. Images are read from disk, so boxes are drawn in original
    image coordinates (no dataset transform is applied).

    Raises:
        ValueError: if no category has any annotated image.
    """
    coco = dataset.coco
    category_names = {cat['id']: cat['name'] for cat in coco.loadCats(coco.getCatIds())}
    # A category without images (e.g. an unused parent category) would make
    # random.choice() fail on an empty list, so only sample populated ones.
    category_ids = [cat_id for cat_id in coco.getCatIds() if coco.getImgIds(catIds=[cat_id])]
    if not category_ids:
        raise ValueError("No annotated images found in dataset")

    # squeeze=False keeps a 2-D axes array so indexing also works for num_samples == 1
    fig, axs = plt.subplots(1, num_samples, figsize=(20, 5), squeeze=False)

    for i, ax in enumerate(axs[0]):
        category_id = category_ids[i % len(category_ids)]
        img_id = random.choice(coco.getImgIds(catIds=[category_id]))
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=[category_id]))
        img_info = coco.loadImgs(img_id)[0]
        img = plt.imread(os.path.join(dataset.root, img_info['file_name']))

        ax.imshow(img)
        ax.axis('off')

        for ann in anns:
            x, y, w, h = ann['bbox']  # COCO [x, y, width, height]
            ax.add_patch(
                patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='r', facecolor='none')
            )
            ax.text(x, y - 10, category_names[ann['category_id']], color='red',
                    fontsize=12, backgroundcolor='white')

        ax.set_title(f"({chr(97 + i)})", fontsize=16)

    plt.show()

# Visualize random 5 samples from different categories in the training dataset
visualize_samples(train_dataset, num_samples=5)

In [None]:
# Initialize model
# One more classification output than there are COCO categories.
# NOTE(review): torchvision's RetinaNet appears to use target labels directly as class
# indices, so num_classes must exceed the largest category id; "+ 1" is enough only if
# ids are contiguous starting at 0 or 1 — confirm against the annotation file.
num_classes = len(train_dataset.coco.getCatIds()) + 1

# Define the model
def load_model(num_classes=num_classes):
    """Build a pretrained RetinaNet-ResNet50-FPN-v2 with a new classification head.

    The model is created with ``RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT``. Its
    classification head is replaced by a newly constructed
    ``RetinaNetClassificationHead`` sized for ``num_classes``; the rest of the model
    (including the box-regression head) keeps the loaded weights. The new head's
    ``cls_logits`` layer is wrapped in an extra conv/ReLU/dropout stack.

    Args:
        num_classes: classification outputs per anchor. The default is captured once,
            from the module-level ``num_classes`` at the time this cell runs.

    Returns:
        The modified RetinaNet model.

    The module structure built here fixes the ``state_dict`` keys: a checkpoint can
    only be loaded into a model built by an identical ``load_model``.
    """
    model = retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
    )
    # Keep the pretrained anchors-per-location so the new head matches the anchor generator
    num_anchors = model.head.classification_head.num_anchors
    classification_head = RetinaNetClassificationHead(
        in_channels=256,  # presumably the FPN output width — verify if the backbone changes
        num_anchors=num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32)  # GroupNorm with num_groups=32
    )
    # NOTE(review): RetinaNetClassificationHead.forward appears to apply `conv` before
    # `cls_logits`; listing `classification_head.conv` here too therefore runs the same
    # (weight-shared) conv stack twice. The extra Conv2d below also seems to get PyTorch's
    # default init rather than the head's own init. Changing this stack changes the
    # state_dict keys and breaks loading checkpoints saved from this architecture.
    classification_head.cls_logits = nn.Sequential(
        classification_head.conv,
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Dropout(p=0.5),  # nn.Dropout is a no-op in eval mode
        classification_head.cls_logits
    )
    model.head.classification_head = classification_head
    return model


# Load the model
model = load_model()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Anchors: torchvision's AnchorGenerator precomputes its `cell_anchors` from `sizes` and
# `aspect_ratios` when it is constructed, so reassigning those attributes on an existing
# model has no effect. Custom anchors have to be passed as a new AnchorGenerator when the
# RetinaNet is built (and the head's num_anchors must then match), so the pretrained
# default anchors are used here.

# Training setup
num_epochs = 50
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

In [None]:
# Metrics calculation
def calculate_detection_metrics(coco_eval):
    """Derive a heuristic "accuracy" and an F1 score from an accumulated ``COCOeval``.

    Uses ``coco_eval.eval['precision']`` (indexed [T, R, K, A, M]: IoU thresholds,
    recall thresholds, categories, area ranges, max detections) and
    ``coco_eval.eval['recall']`` ([T, K, A, M]), taking IoU=0.50, area 'all' and
    the largest maxDets (default COCOeval Params). Entries equal to -1 mark
    categories without ground truth and are excluded from the means.

    Returns:
        ``(accuracy, f1_score)``. F1 combines mean precision and mean recall at
        IoU=0.50. "accuracy" is a heuristic ratio of counts (see note below), not
        a standard detection metric.
    """
    precision = coco_eval.eval['precision']
    recall = coco_eval.eval['recall']

    # IoU=0.50 (first threshold), area range index 0 ('all'), last maxDets (100)
    precision_at_iou50 = precision[0, :, :, 0, -1]  # [recall thresholds, categories]
    recall_at_iou50 = recall[0, :, 0, -1]  # [categories]; also pin maxDets, not all three

    valid_precision = precision_at_iou50[precision_at_iou50 > -1]
    valid_recall = recall_at_iou50[recall_at_iou50 > -1]
    # Guard empty selections (no GT at all) to avoid NaN from np.mean([])
    mean_precision = float(np.mean(valid_precision)) if valid_precision.size else 0.0
    mean_recall = float(np.mean(valid_recall)) if valid_recall.size else 0.0

    # NOTE: heuristic counts, not true detection TP/FP/FN. TP/FP count cells of the
    # interpolated precision curve that are > 0 / == 0; FN counts categories with
    # zero recall.
    TP = np.sum(precision_at_iou50 > 0)
    FP = np.sum(precision_at_iou50 == 0)
    FN = np.sum(recall_at_iou50 == 0)

    accuracy = TP / (TP + FP + FN + 1e-6)

    if mean_precision + mean_recall > 0:
        f1_score = 2 * (mean_precision * mean_recall) / (mean_precision + mean_recall)
    else:
        f1_score = 0.0

    return accuracy, f1_score

# Define function to calculate metrics
def calculate_metrics(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` with COCO bbox metrics.

    Detections (xyxy) are converted to COCO ``[x, y, w, h]``, rescaled from the
    transformed input resolution back to the original image size recorded in
    ``data_loader.dataset.coco``, and scored with ``COCOeval``.

    Args:
        model: detector returning, per image, a dict with ``boxes`` (xyxy),
            ``scores`` and ``labels``.
        data_loader: loader over a dataset exposing ``.coco`` (the GT index).
        device: device the images are moved to.

    Returns:
        ``(mAP, precision, recall, f1_score, accuracy)``:
        - mAP is AP@[.50:.95].
        - precision is AP@.50.
        - recall is AR@100.
        - f1/accuracy come from ``calculate_detection_metrics``.
        All zeros if the model produced no detections.
    """
    from torch.utils.data import SequentialSampler

    model.eval()
    coco_gt = data_loader.dataset.coco
    coco_dt = []

    # collate_fn reports image_id 0 for images without annotations, which would
    # attribute their detections to the wrong image. For an unshuffled loader the
    # n-th sample is dataset.ids[n], so the id is taken from there instead.
    dataset_ids = getattr(data_loader.dataset, "ids", None)
    ids_in_order = dataset_ids is not None and isinstance(data_loader.sampler, SequentialSampler)
    sample_idx = 0

    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc="Validation", unit="batch")

        for images, targets in progress_bar:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for image, target, output in zip(images, targets, outputs):
                if ids_in_order:
                    image_id = dataset_ids[sample_idx]
                else:
                    image_id = target['image_id'].item()
                sample_idx += 1

                # Predicted boxes live in the transformed (e.g. resized) input frame,
                # while the COCO ground truth is in original pixels: rescale back.
                input_h, input_w = image.shape[-2:]
                img_info = coco_gt.imgs.get(image_id, {})
                scale_x = img_info.get('width', input_w) / input_w
                scale_y = img_info.get('height', input_h) / input_h

                boxes = output['boxes'].cpu().tolist()
                scores = output['scores'].cpu().tolist()
                labels = output['labels'].cpu().tolist()

                for (x1, y1, x2, y2), score, label in zip(boxes, scores, labels):
                    coco_dt.append({
                        'image_id': image_id,
                        'category_id': label,
                        'bbox': [x1 * scale_x, y1 * scale_y,
                                 (x2 - x1) * scale_x,
                                 (y2 - y1) * scale_y],
                        'score': score
                    })

    if len(coco_dt) == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    coco_eval = COCOeval(coco_gt, coco_gt.loadRes(coco_dt), 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # Extract metrics
    mAP = coco_eval.stats[0]  # mAP@IoU=0.50:0.95
    precision = coco_eval.stats[1]  # mAP@IoU=0.50
    recall = coco_eval.stats[8]  # AR@IoU=0.50:0.95, maxDets=100

    accuracy, f1_score = calculate_detection_metrics(coco_eval)

    return mAP, precision, recall, f1_score, accuracy

# Define label smoothing function
def smooth_labels(labels, num_classes, smoothing=0.1):
    """Turn 1-D integer class ``labels`` into label-smoothed one-hot rows.

    The true class gets ``1 - smoothing``; every other class gets
    ``smoothing / (num_classes - 1)``. Returns a float tensor of shape
    ``(len(labels), num_classes)`` on ``labels.device``, built without autograd.
    """
    off_value = smoothing / (num_classes - 1)
    on_value = 1.0 - smoothing
    with torch.no_grad():
        soft = torch.full((labels.size(0), num_classes), off_value, device=labels.device)
        soft.scatter_(1, labels.data.unsqueeze(1), on_value)
    return soft

# Initialize GradScaler
scaler = GradScaler()

# Training loop
metrics_history = {
    "train_loss": [],
    "mAP": [],
    "precision": [],
    "recall": [],
    "f1_score": [],
    "accuracy": [],
}

best_map = 0
best_model_wts = copy.deepcopy(model.state_dict())
patience = 5  # epochs without validation-mAP improvement before stopping
patience_counter = 0
accumulation_steps = 4  # batches whose gradients are summed per optimizer step
amp_device_type = "cuda" if torch.cuda.is_available() else "cpu"

# smooth_labels is deliberately not applied to the targets. The detector's loss consumes
# the integer target["labels"], and converting them only after model(images, targets)
# has returned cannot influence the loss.

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    optimizer.zero_grad()
    num_batches = len(train_loader)

    progress_bar = tqdm(
        train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch"
    )

    for i, (images, targets) in enumerate(progress_bar):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        with autocast(device_type=amp_device_type):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss = losses / accumulation_steps

        scaler.scale(loss).backward()

        # Step every `accumulation_steps` batches and also on the last batch, so the
        # gradients of a trailing partial group are applied rather than discarded by
        # the next epoch's zero_grad().
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        train_loss += losses.item()
        progress_bar.set_postfix(loss=losses.item())

    scheduler.step()

    avg_loss = train_loss / len(train_loader)
    mAP, precision, recall, f1_score, accuracy = calculate_metrics(model, val_loader, device)

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {avg_loss:.4f}")
    print(f"mAP: {mAP:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1_score:.4f}")
    print(f"Accuracy: {accuracy:.4f}\n")

    metrics_history["train_loss"].append(avg_loss)
    metrics_history["mAP"].append(mAP)
    metrics_history["precision"].append(precision)
    metrics_history["recall"].append(recall)
    metrics_history["f1_score"].append(f1_score)
    metrics_history["accuracy"].append(accuracy)

    # Keep the weights of the best validation-mAP epoch; stop after `patience` epochs without improvement
    if mAP > best_map:
        best_map = mAP
        best_model_wts = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered")
        break

model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), "retinanet_resnet50_best_model.pth")

In [None]:
# Final evaluation of the restored best checkpoint on the held-out test split
test_mAP, test_precision, test_recall, test_f1, test_accuracy = calculate_metrics(
    model, test_loader, device
)

print("\nTest Set Results:")
for metric_name, metric_value in (
    ("mAP", test_mAP),
    ("Precision", test_precision),
    ("Recall", test_recall),
    ("F1 Score", test_f1),
    ("Accuracy", test_accuracy),
):
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
def calculate_confusion_matrix(model, data_loader, device, confidence_threshold=0.5, iou_threshold=0.5):
    """Accumulate a detection confusion matrix over ``data_loader``.

    Rows are ground-truth classes and columns predicted classes. Both are ordered
    like ``data_loader.dataset.coco.cats``, plus one final "Background" row/column:

    * ``[gt, pred]``: a GT box whose best-overlapping kept prediction
      (IoU >= ``iou_threshold``) has class ``pred``;
    * ``[gt, -1]``: a GT box with no such prediction (missed detection);
    * ``[-1, pred]``: a kept prediction overlapping no GT box (false positive).

    Predictions are kept when their score is >= ``confidence_threshold`` and their
    label is a category of the dataset. Matching is greedy per GT box, so one
    prediction may match several GT boxes.

    Returns:
        ``np.ndarray`` of shape ``(len(categories) + 1, len(categories) + 1)``.
    """
    model.eval()
    categories = data_loader.dataset.coco.cats
    # COCO category ids need not start at 1 or be contiguous, so map them to
    # matrix positions following the iteration order of `categories`.
    cat_index = {cat_id: pos for pos, cat_id in enumerate(categories)}
    background = len(cat_index)
    known_ids = np.array(list(cat_index), dtype=np.int64)

    confusion_mat = np.zeros((background + 1, background + 1))

    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Calculating Confusion Matrix"):
            images = [img.to(device) for img in images]
            outputs = model(images)

            for target, output in zip(targets, outputs):
                # reshape(-1, 4) keeps empty box sets 2-D, i.e. shape (0, 4)
                gt_boxes = target['boxes'].cpu().numpy().reshape(-1, 4)
                gt_labels = target['labels'].cpu().numpy()

                pred_boxes = output['boxes'].cpu().numpy().reshape(-1, 4)
                pred_scores = output['scores'].cpu().numpy()
                pred_labels = output['labels'].cpu().numpy()

                keep = (pred_scores >= confidence_threshold) & np.isin(pred_labels, known_ids)
                pred_boxes = pred_boxes[keep]
                pred_labels = pred_labels[keep]

                # Ground-truth rows: matched predicted class, else the background column
                for gt_box, gt_label in zip(gt_boxes, gt_labels):
                    row = cat_index.get(int(gt_label))
                    if row is None:
                        continue
                    col = background
                    if len(pred_boxes) > 0:
                        ious = bbox_iou(gt_box, pred_boxes)
                        best = int(np.argmax(ious))
                        if ious[best] >= iou_threshold:
                            col = cat_index[int(pred_labels[best])]
                    confusion_mat[row, col] += 1

                # Background row: kept predictions that overlap no ground-truth box
                for pred_label, pred_box in zip(pred_labels, pred_boxes):
                    if not np.any(bbox_iou(pred_box, gt_boxes) >= iou_threshold):
                        confusion_mat[background, cat_index[int(pred_label)]] += 1

    return confusion_mat

def bbox_iou(box1, box2):
    """IoU between one xyxy box ``box1`` and each xyxy box in ``box2``.

    ``box2`` may be a single box of shape (4,), an array of shape (N, 4), or empty.
    Returns a 1-D array of N IoU values (empty when ``box2`` is empty). A 1e-7 term
    in the denominator avoids division by zero for degenerate boxes.
    """
    # reshape also turns an empty (0,) array into (0, 4) instead of failing on box2[:, 0]
    box2 = np.asarray(box2).reshape(-1, 4)

    # Intersection rectangle, clamped at zero when the boxes do not overlap
    x1 = np.maximum(box1[0], box2[:, 0])
    y1 = np.maximum(box1[1], box2[:, 1])
    x2 = np.minimum(box1[2], box2[:, 2])
    y2 = np.minimum(box1[3], box2[:, 3])
    intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = box1_area + box2_area - intersection

    return intersection / (union + 1e-7)

def plot_confusion_matrix(confusion_mat, categories):
    """Plot ``confusion_mat`` as an annotated heatmap.

    ``categories`` is the COCO ``cats`` dict. Its names, in iteration order, label
    the rows/columns, followed by a final "Background" entry. The matrix must
    therefore be square with ``len(categories) + 1`` rows.
    """
    plt.figure(figsize=(12, 10))

    category_names = [cat['name'] for cat in categories.values()]
    category_names.append('Background')

    sns.heatmap(
        confusion_mat,
        annot=True,
        fmt='.0f',
        cmap='Blues',
        xticklabels=category_names,
        yticklabels=category_names
    )

    plt.title('Object Detection Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Ground Truth')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Calculate and plot confusion matrix
def analyze_detection_performance(model, data_loader, device):
    """Plot the confusion matrix for ``data_loader`` and print per-class P/R/F1.

    For each class:
    - TP is its diagonal cell.
    - FP is the rest of its column, including background-row false positives.
    - FN is the rest of its row, including missed detections in the background column.
    """
    confusion_mat = calculate_confusion_matrix(
        model,
        data_loader,
        device,
        confidence_threshold=0.5,
        iou_threshold=0.5
    )
    categories = data_loader.dataset.coco.cats

    plot_confusion_matrix(confusion_mat, categories)

    print("\nPer-class Performance Metrics:")
    print("-----------------------------")

    # Matrix positions follow the iteration order of `categories`, not raw category ids
    for idx, cat_info in enumerate(categories.values()):
        tp = confusion_mat[idx, idx]
        fp = np.sum(confusion_mat[:, idx]) - tp
        fn = np.sum(confusion_mat[idx, :]) - tp

        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-7)

        print(f"\nClass: {cat_info['name']}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")

# Confusion matrix plus per-class precision/recall/F1 on the test split
print("\nAnalyzing detection performance...")
analyze_detection_performance(model=model, data_loader=test_loader, device=device)

In [None]:
def plot_metrics(metrics_history):
    """Plot the per-epoch curves stored in ``metrics_history`` on a 2x2 grid.

    Panels:
    - (a) training loss
    - (b) mAP
    - (c) precision and recall
    - (d) F1 score

    Each panel is tagged with a bold letter above its top-left corner.
    """
    epochs = range(1, len(metrics_history["train_loss"]) + 1)

    # (title, y-label, [(history key, line style, legend label), ...]) per panel
    panels = [
        ("Training Loss", "Loss", [("train_loss", "b-", "Training Loss")]),
        ("Mean Average Precision", "mAP", [("mAP", "r-", "mAP")]),
        ("Precision and Recall", "Score",
         [("precision", "g-", "Precision"), ("recall", "y-", "Recall")]),
        ("F1 Score", "Score", [("f1_score", "m-", "F1 Score")]),
    ]

    plt.figure(figsize=(15, 10))

    for position, (title, ylabel, series) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for key, style, label in series:
            plt.plot(epochs, metrics_history[key], style, label=label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.legend()
        panel_tag = f"({chr(96 + position)})"
        plt.text(-0.1, 1.1, panel_tag, transform=plt.gca().transAxes, fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.show()

# Plot the metrics
plot_metrics(metrics_history)

# Prediction Visualization

In [None]:
# Import necessary libraries
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import os
import random
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import torchvision.ops as ops
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from functools import partial
import torch.nn as nn


# Define the CocoDataset class
class CocoDataset(torchvision.datasets.CocoDetection):
    """``CocoDetection`` that optionally runs an Albumentations pipeline.

    ``__getitem__`` returns ``(image, annotations)``:

    * with ``transform``: the pipeline's ``"image"`` output and a fresh list of
      ``{"bbox", "category_id", "image_id"}`` dicts built from the transformed
      boxes/labels (boxes in whatever format the pipeline's ``BboxParams`` use —
      COCO ``[x, y, w, h]`` in this notebook);
    * without ``transform``: the image and annotation dicts as loaded by
      ``CocoDetection``, each tagged with ``image_id``.
    """

    def __init__(self, root, annFile, transform=None):
        # Zero-argument super() is bound to this class object. The explicit form
        # super(CocoDataset, self) looks the global name up at call time, which
        # breaks existing instances once another notebook cell redefines CocoDataset.
        super().__init__(root, annFile)
        self.transform = transform

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        image_id = self.ids[index]

        if self.transform is None:
            for ann in target:
                ann["image_id"] = image_id
            return img, target

        transformed = self.transform(
            image=np.array(img),
            bboxes=[ann["bbox"] for ann in target],
            labels=[ann["category_id"] for ann in target],
        )
        # Boxes may be dropped by the pipeline, so rebuild annotations from its output
        annotations = [
            {"bbox": bbox, "category_id": label, "image_id": image_id}
            for bbox, label in zip(transformed["bboxes"], transformed["labels"])
        ]
        return transformed["image"], annotations


# Evaluation-time preprocessing: resize to 640x640 and ImageNet-normalise (no augmentation);
# boxes stay in COCO [x, y, w, h] with class ids in the "labels" field.
val_transform = A.Compose(
    [
        A.Resize(640, 640),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["labels"]),
)

# Test split only: this cell just needs data for visualising predictions
dataset_path = "/kaggle/input/oil-palm-bunch-3910"
test_split_dir = os.path.join(dataset_path, "test")
test_dataset = CocoDataset(
    root=test_split_dir,
    annFile=os.path.join(test_split_dir, "_annotations.coco.json"),
    transform=val_transform,
)


# Create data loaders
def collate_fn(batch):
    """Collate ``(image, annotations)`` pairs into torchvision detection inputs.

    Returns ``(images, targets)`` where ``images`` is the list of images unchanged
    and each target dict holds:

    * ``boxes``: float32 tensor of shape (N, 4); COCO ``[x, y, w, h]`` converted to
      ``[x1, y1, x2, y2]``. Shape is (0, 4) for images without annotations.
    * ``labels``: int64 tensor of shape (N,) with the category ids.
    * ``image_id``: 0-d tensor; 0 when the image has no annotations, because the id
      is only carried on the annotation dicts.
    """
    images = []
    targets = []

    for image, target in batch:
        images.append(image)

        boxes = [[x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in target)]
        labels = [ann["category_id"] for ann in target]

        targets.append({
            # reshape keeps the (0, 4) shape torchvision detectors require for images
            # without boxes; a bare empty tensor would have shape (0,) and be rejected.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor(target[0]["image_id"] if target else 0),
        })

    return images, targets


test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=4,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Must equal the value used at training time or the checkpoint will not load.
# NOTE(review): assumes the test annotation file lists the same categories as train.
num_classes = 1 + len(test_dataset.coco.getCatIds())


# Define the model
# NOTE(review): must stay identical to the load_model used for training, otherwise the
# saved state_dict keys/shapes will not match.
def load_model(num_classes=num_classes):
    """Build a pretrained RetinaNet-ResNet50-FPN-v2 with a new classification head.

    The model is created with ``RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT``. Its
    classification head is replaced by a newly constructed
    ``RetinaNetClassificationHead`` sized for ``num_classes``; the rest of the model
    (including the box-regression head) keeps the loaded weights. The new head's
    ``cls_logits`` layer is wrapped in an extra conv/ReLU/dropout stack.

    Args:
        num_classes: classification outputs per anchor. The default is captured once,
            from the module-level ``num_classes`` at the time this cell runs.

    Returns:
        The modified RetinaNet model.
    """
    model = retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
    )
    # Keep the pretrained anchors-per-location so the new head matches the anchor generator
    num_anchors = model.head.classification_head.num_anchors
    classification_head = RetinaNetClassificationHead(
        in_channels=256,  # presumably the FPN output width — verify if the backbone changes
        num_anchors=num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32)  # GroupNorm with num_groups=32
    )
    # NOTE(review): RetinaNetClassificationHead.forward appears to apply `conv` before
    # `cls_logits`; listing `classification_head.conv` here too therefore runs the same
    # (weight-shared) conv stack twice. Changing this stack changes the state_dict keys
    # and breaks loading checkpoints saved from this architecture.
    classification_head.cls_logits = nn.Sequential(
        classification_head.conv,
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Dropout(p=0.5),  # nn.Dropout is a no-op in eval mode
        classification_head.cls_logits
    )
    model.head.classification_head = classification_head
    return model


# Load the best checkpoint written by the training cell onto the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(num_classes)
checkpoint_state = torch.load(
    "retinanet_resnet50_best_model.pth", map_location=device, weights_only=True
)
model.load_state_dict(checkpoint_state)
model.to(device)
model.eval()


# Function to denormalize images
def denormalize(image, mean, std):
    """Undo per-channel normalisation of a CHW tensor and return an HWC uint8 array.

    Args:
        image: normalised image tensor laid out (C, H, W).
        mean, std: per-channel statistics used for normalisation (values in 0-1
            scale), broadcast over the channel (last) axis after the permute.

    Returns:
        ``np.ndarray`` of shape (H, W, C) and dtype uint8, suitable for ``imshow``.
    """
    pixels = image.permute(1, 2, 0).cpu().numpy()
    pixels = (pixels * std + mean) * 255
    # Round instead of truncating (avoids off-by-one darkening), and clip before
    # the cast: out-of-range floats would otherwise wrap around in uint8.
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


# Function to visualize predictions
def visualize_predictions(
    dataset, model, device, num_samples=5, confidence_threshold=0.5, iou_threshold=0.5
):
    """Plot ground truth (green) and model predictions (red) for random images.

    Subplot ``i`` shows a random image containing category ``i``, cycling over the
    categories that have at least one annotated image. Image and ground-truth boxes
    come from ``dataset[...]``, i.e. after its transform. They therefore share the
    coordinate frame of the model input and its predictions.

    Args:
        dataset: CocoDataset whose transform returns a normalised CHW tensor.
        model: detector returning ``boxes`` (xyxy), ``scores`` and ``labels`` per image.
        device: device to run inference on.
        num_samples: number of images / subplots.
        confidence_threshold: minimum score for a prediction to be drawn.
        iou_threshold: IoU threshold of an extra class-agnostic NMS applied to the
            model output before drawing.

    Raises:
        ValueError: if no category has any annotated image.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    coco = dataset.coco

    # Skip categories without images: random.choice() would fail on an empty list.
    category_ids = [cat_id for cat_id in coco.getCatIds() if coco.getImgIds(catIds=[cat_id])]
    if not category_ids:
        raise ValueError("No annotated images found in dataset")

    selected_images = [
        random.choice(coco.getImgIds(catIds=[category_ids[i % len(category_ids)]]))
        for i in range(num_samples)
    ]

    # squeeze=False keeps a 2-D axes array so indexing also works for num_samples == 1
    fig, axs = plt.subplots(1, num_samples, figsize=(20, 5), squeeze=False)

    for i, (ax, img_id) in enumerate(zip(axs[0], selected_images)):
        image, target = dataset[dataset.ids.index(img_id)]
        image = image.to(device)

        with torch.no_grad():
            output = model([image])[0]

        # Extra class-agnostic NMS on top of the model's own post-processing
        keep = ops.nms(output["boxes"], output["scores"], iou_threshold)
        pred_boxes = output["boxes"][keep]
        pred_scores = output["scores"][keep]
        pred_labels = output["labels"][keep]

        ax.imshow(denormalize(image, mean, std))
        ax.axis("off")

        # Ground truth: COCO [x, y, w, h] in transformed-image pixels; name below the box
        for ann in target:
            x, y, w, h = ann["bbox"]
            category_name = coco.cats[ann["category_id"]]["name"]
            ax.add_patch(
                patches.Rectangle(
                    (x, y), w, h,
                    linewidth=2, edgecolor="g", facecolor="none", label="Ground Truth",
                )
            )
            ax.text(x, y + h + 10, category_name, color="green", fontsize=12,
                    backgroundcolor="white")

        # Predictions: xyxy boxes at or above the confidence threshold; name above the box
        for box, score, label in zip(pred_boxes, pred_scores, pred_labels):
            score_value = score.item()
            if score_value < confidence_threshold:
                continue
            x1, y1, x2, y2 = box.cpu().numpy()
            label_id = label.item()
            # A predicted label need not be a category id of this annotation file
            category_name = (
                coco.cats[label_id]["name"] if label_id in coco.cats else f"class {label_id}"
            )
            ax.add_patch(
                patches.Rectangle(
                    (x1, y1), x2 - x1, y2 - y1,
                    linewidth=2, edgecolor="r", facecolor="none", label="Prediction",
                )
            )
            ax.text(x1, y1 - 10, f"{category_name}: {score_value:.2f}", color="red",
                    fontsize=12, backgroundcolor="white")

        # One legend entry per label (many patches share "Ground Truth"/"Prediction")
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax.legend(by_label.values(), by_label.keys(), loc="upper right", fontsize=10)

        ax.set_title(f"({chr(97 + i)})", fontsize=16)

    plt.tight_layout()
    plt.show()

# Visualize predictions on 5 random samples from the test dataset with a confidence threshold of 0.5
visualize_predictions(
    test_dataset,
    model,
    device,
    num_samples=5,
    confidence_threshold=0.5,
    iou_threshold=0.5,
)