# Import packages

In [None]:
# detect if running in colab
try:
    import google.colab

    ! pip install torchmetrics
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import gc
import math
import random
import sys
import time
from pprint import pprint

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import wandb
from datasets import load_from_disk
from PIL.TiffImagePlugin import TiffImageFile
from torch.utils.data import DataLoader, Dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import tv_tensors
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import v2 as T
from tqdm.notebook import tqdm

# --- A100 OPTIMIZATION: ENABLE TF32 ---
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2


  self.setter(val)


In [None]:
if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive")
    BASE_PATH = "/content/drive/MyDrive/datasets/SelvaBox/saved/"
else:
    BASE_PATH = "../data/selvabox/"

Mounted at /content/drive


In [None]:
def setup_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(42)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Load data

In [None]:
hf_train_ds = load_from_disk(BASE_PATH + "train")
hf_val_ds = load_from_disk(BASE_PATH + "validation")
hf_test_ds = load_from_disk(BASE_PATH + "test")

Loading dataset from disk:   0%|          | 0/34 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/24 [00:00<?, ?it/s]

In [None]:
print(f"Number of training samples: {len(hf_train_ds)}")
print(f"Number of validation samples: {len(hf_val_ds)}")
print(f"Number of test samples: {len(hf_test_ds)}")

Number of training samples: 585
Number of validation samples: 387
Number of test samples: 1477


# Utility functions

In [None]:
class EarlyStopper:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = -float(
            "inf"
        )  # Looking for max mAP, so init with negative inf
        self.early_stop = False

    def __call__(self, current_score):
        # Logic for maximizing metric (mAP)
        if current_score > (self.best_score + self.min_delta):
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
            print(
                f"   --> EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
def plot_image(
    img, boxes, scores=None, labels=None, class_names=None, save_path=None, show=True
):
    """
    Plots bounding boxes on an image with optional scores and labels.

    Args:
        img (np.array | torch.Tensor): Input image. Shape [H, W, C] (numpy) or [C, H, W] (torch).
        boxes (np.array | torch.Tensor): Bounding boxes [N, 4] format (xmin, ymin, xmax, ymax).
        scores (np.array | torch.Tensor, optional): Confidence scores [N]. Defaults to None.
        labels (np.array | torch.Tensor, optional): Class indices [N]. Defaults to None.
        class_names (list, optional): List of class string names. Defaults to None.
        save_path (str, optional): Path to save the figure. Defaults to None.
        show (bool, optional): Whether to display the plot. Defaults to True.
    """

    # --- 1. Data Standardization ---
    # Convert PyTorch tensors to Numpy if necessary
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy()
        # If image is [C, H, W], transpose to [H, W, C] for Matplotlib
        if img.shape[0] < img.shape[2]:
            img = img.transpose(1, 2, 0)

    if isinstance(boxes, torch.Tensor):
        boxes = boxes.cpu().numpy()

    if isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()

    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    # Normalize image range if it's float 0-1, mostly for display consistency
    # (Matplotlib handles 0-1 floats or 0-255 ints, but mixing is bad)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img, 0, 1)

    # --- 2. Setup Figure ---
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    # --- 3. Color Setup ---
    # If no class names provided, default to a generic list
    if class_names is None:
        if labels is not None:
            max_label = int(np.max(labels))
            class_names = [f"Class {i}" for i in range(max_label + 1)]
        else:
            class_names = ["Object"]

    # Generate distinct colors for classes
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_names))]

    # --- 4. Plotting Loop ---
    for i, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box

        # Determine Label
        if labels is not None:
            cls_id = int(labels[i])
        else:
            cls_id = 0  # Default to 0 if no labels provided

        color = colors[cls_id % len(colors)]
        class_name = (
            class_names[cls_id] if cls_id < len(class_names) else f"Class {cls_id}"
        )

        # Draw Rectangle
        width = xmax - xmin
        height = ymax - ymin
        rect = patches.Rectangle(
            (xmin, ymin), width, height, linewidth=2, edgecolor=color, facecolor="none"
        )
        ax.add_patch(rect)

        # Build Text String
        display_text = class_name
        if scores is not None:
            display_text += f" {int(100 * scores[i])}%"

        # Draw Text with background
        ax.text(
            xmin,
            ymin,
            display_text,
            color="white",
            fontsize=10,
            verticalalignment="top",
            bbox={
                "color": color,
                "pad": 2,
                "alpha": 0.8,
            },  # Added alpha for better visibility
        )

    plt.axis("off")  # Hide axes ticks

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close()

# Hyper-parameters

In [None]:
CONFIG = {
    "project_name": "selva-box-tree-detection",  # WandB project name
    "name": "vanilla-fasterrcnn-experiment",
    "num_classes": 2,  # Background + your classes (e.g., 1 class + 1 background = 2)
    "batch_size": 16,
    "num_workers": 4,
    "num_epochs": 20,
    "learning_rate": 0.005,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "step_size": 3,  # Scheduler step size
    "gamma": 0.1,  # Scheduler gamma
    "patience": 5,  # Early stopping patience
    "device": device,
    "model_name": "fasterrcnn_resnet50_fpn",
}

In [30]:
wandb.init(
    project=CONFIG["project_name"],
    name=CONFIG["name"],
    config=CONFIG,
)

# Custom dataset

In [None]:
# inspired from: https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html
class SelvaBoxDataset(Dataset):
    def __init__(self, hf_dataset, n_classes=1, transforms=None):
        self.dataset = hf_dataset
        self.n_classes = n_classes
        self.transforms = transforms

    def __getitem__(self, index):
        sample = self.dataset[index]
        image: TiffImageFile = sample["image"]
        annotations_dict = sample["annotations"]

        if image.mode != "RGB":
            image = image.convert("RGB")

        # PIL returns (Width, Height)
        w, h = image.size

        image = tv_tensors.Image(image)

        # number of objects/trees in the image
        num_objs = len(annotations_dict["bbox"])

        target = {
            "boxes": tv_tensors.BoundingBoxes(
                data=annotations_dict["bbox"],
                format="XYWH",  # COCO format
                canvas_size=(h, w),
            ),
            "labels": torch.ones((num_objs,), dtype=torch.int64),
            "image_id": torch.tensor(
                index
            ),  # TODO: is this necessary? when moving data to GPU, it expects a tensor
            "area": torch.tensor(annotations_dict["area"], dtype=torch.float32),
            "iscrowd": torch.tensor(annotations_dict["iscrowd"], dtype=torch.int64),
        }

        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        return len(self.dataset)

In [None]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [None]:
transforms = T.Compose(
    [
        T.ConvertBoundingBoxFormat(format="XYXY"),  # Convert COCO format to xyxy
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ]
)

In [None]:
train_dataset = SelvaBoxDataset(hf_train_ds, transforms=transforms)
val_dataset = SelvaBoxDataset(hf_val_ds, transforms=transforms)
test_dataset = SelvaBoxDataset(hf_test_ds, transforms=transforms)

In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
    persistent_workers=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
    persistent_workers=True,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
    persistent_workers=True,
)

# Model definition

In [None]:
def get_model(num_classes):
    # Load a model pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the classifier with a new one, that has num_classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# Model training

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()

    running_loss = 0.0
    running_cls_loss = 0.0
    running_box_reg_loss = 0.0

    for i, (images, targets) in tqdm(
        enumerate(data_loader), total=len(data_loader), desc=f"Training Epoch {epoch}"
    ):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        loss_value = losses.item()

        # Specific losses for detailed logging
        cls_loss = loss_dict["loss_classifier"].item()
        box_reg_loss = loss_dict["loss_box_reg"].item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(
                f"Loss Dict: {loss_dict}"
            )  # Added print to see which specific loss failed
            sys.exit(1)

        losses.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss_value
        running_cls_loss += cls_loss
        running_box_reg_loss += box_reg_loss

        if i % 10 == 0:
            wandb.log(
                {
                    "train/batch_loss": loss_value,
                    "train/batch_cls_loss": cls_loss,
                    "train/batch_box_loss": box_reg_loss,
                }
            )

    epoch_loss = running_loss / len(data_loader)
    return epoch_loss


@torch.no_grad()
def evaluate_map(model, data_loader, device):
    model.eval()
    metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")

    for images, targets in tqdm(data_loader, desc="Validating"):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        predictions = model(images)

        predictions = [{k: v.cpu() for k, v in p.items()} for p in predictions]
        targets_cpu = [{k: v.cpu() for k, v in t.items()} for t in targets]

        metric.update(predictions, targets_cpu)

    results = metric.compute()
    return results

In [None]:
model = get_model(CONFIG["num_classes"])
model.to(device)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


100%|██████████| 160M/160M [00:00<00:00, 234MB/s]


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [None]:
# 4. Optimizer & Scheduler
params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.SGD(
    params,
    lr=CONFIG["learning_rate"],
    momentum=CONFIG["momentum"],
    weight_decay=CONFIG["weight_decay"],
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=CONFIG["step_size"], gamma=CONFIG["gamma"]
)

early_stopper = EarlyStopper(patience=CONFIG["patience"])

In [None]:
print("Starting training...")

best_map = 0.0

# Memory tracking
torch.cuda.reset_peak_memory_stats()
start_train_time = time.time()

for epoch in tqdm(range(CONFIG["num_epochs"]), desc="Overall Training Progress"):
    epoch_start = time.time()

    # --- Train ---
    avg_train_loss = train_one_epoch(
        model, optimizer, train_loader, device, epoch
    )

    # --- Update Learning Rate ---
    lr_scheduler.step()
    curr_lr = optimizer.param_groups[0]["lr"]

    # --- Validation (Loss Proxy) ---
    val_metrics = evaluate_map(model, val_loader, device)
    val_map_50 = val_metrics["map_50"].item()
    val_map = val_metrics["map"].item()

    epoch_end = time.time()
    epoch_duration = (epoch_end - epoch_start) / 60
    peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024

    print(
        f"Epoch [{epoch + 1}/{CONFIG['num_epochs']}] "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val mAP: {val_map:.4f} | "
        f"Val mAP_50: {val_map_50:.4f} | "
        f"Time: {epoch_duration:.1f}m | "
        f"Peak Mem: {peak_mem:.0f} MB"
    )

    # --- Logging ---
    wandb.log(
        {
            "epoch": epoch + 1,
            "train/epoch_loss": avg_train_loss,
            "val/mAP": val_map,
            "val/mAP_50": val_map_50,
            "learning_rate": curr_lr,
            "system/peak_mem_mb": peak_mem,
        }
    )

    # --- Save Best Model ---
    if val_map > best_map:
        best_map = val_map
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "map": best_map,
                "config": CONFIG,
            },
            "baseline_faster_r_cnn.pth",
        )
        print(f"--> New Best Model Saved (mAP: {best_map:.4f})")

    # 6. Early Stopping
    early_stopper(val_map)
    if early_stopper.early_stop:
        print("--> Early stopping triggered.")
        break

    # 7. Flush Memory
    del avg_train_loss, val_metrics
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


# wandb.finish()
print("Training complete.")

Starting training...


Overall Training Progress:   0%|          | 0/20 [00:00<?, ?it/s]

Training Epoch 0:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [1/20] Train Loss: 2.7365 | Val mAP: 0.0457 | Val mAP_50: 0.1302 | Time: 3.7m | Peak Mem: 19347 MB
--> New Best Model Saved (mAP: 0.0457)


Training Epoch 1:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [2/20] Train Loss: 1.7214 | Val mAP: 0.0662 | Val mAP_50: 0.1610 | Time: 2.7m | Peak Mem: 19315 MB
--> New Best Model Saved (mAP: 0.0662)


Training Epoch 2:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [3/20] Train Loss: 1.6033 | Val mAP: 0.0950 | Val mAP_50: 0.2135 | Time: 2.7m | Peak Mem: 19091 MB
--> New Best Model Saved (mAP: 0.0950)


Training Epoch 3:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [4/20] Train Loss: 1.5566 | Val mAP: 0.0886 | Val mAP_50: 0.2011 | Time: 2.7m | Peak Mem: 19597 MB
   --> EarlyStopping counter: 1 out of 5


Training Epoch 4:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [5/20] Train Loss: 1.5476 | Val mAP: 0.0952 | Val mAP_50: 0.2142 | Time: 2.7m | Peak Mem: 19631 MB
--> New Best Model Saved (mAP: 0.0952)


Training Epoch 5:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [6/20] Train Loss: 1.5414 | Val mAP: 0.0929 | Val mAP_50: 0.2070 | Time: 2.7m | Peak Mem: 19271 MB
   --> EarlyStopping counter: 1 out of 5


Training Epoch 6:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [7/20] Train Loss: 1.5412 | Val mAP: 0.0955 | Val mAP_50: 0.2137 | Time: 2.7m | Peak Mem: 19573 MB
--> New Best Model Saved (mAP: 0.0955)


Training Epoch 7:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [8/20] Train Loss: 1.5380 | Val mAP: 0.0962 | Val mAP_50: 0.2150 | Time: 2.7m | Peak Mem: 19409 MB
--> New Best Model Saved (mAP: 0.0962)


Training Epoch 8:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [9/20] Train Loss: 1.5396 | Val mAP: 0.0963 | Val mAP_50: 0.2154 | Time: 2.7m | Peak Mem: 19298 MB
--> New Best Model Saved (mAP: 0.0963)


Training Epoch 9:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [10/20] Train Loss: 1.5398 | Val mAP: 0.0964 | Val mAP_50: 0.2156 | Time: 2.7m | Peak Mem: 19482 MB
--> New Best Model Saved (mAP: 0.0964)


Training Epoch 10:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [11/20] Train Loss: 1.5396 | Val mAP: 0.0964 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19531 MB
--> New Best Model Saved (mAP: 0.0964)


Training Epoch 11:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [12/20] Train Loss: 1.5392 | Val mAP: 0.0965 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19576 MB
--> New Best Model Saved (mAP: 0.0965)


Training Epoch 12:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [13/20] Train Loss: 1.5376 | Val mAP: 0.0965 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19090 MB
--> New Best Model Saved (mAP: 0.0965)


Training Epoch 13:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [14/20] Train Loss: 1.5340 | Val mAP: 0.0964 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19538 MB
   --> EarlyStopping counter: 1 out of 5


Training Epoch 14:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [15/20] Train Loss: 1.5377 | Val mAP: 0.0965 | Val mAP_50: 0.2158 | Time: 2.7m | Peak Mem: 19149 MB
--> New Best Model Saved (mAP: 0.0965)


Training Epoch 15:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [16/20] Train Loss: 1.5406 | Val mAP: 0.0965 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19444 MB
   --> EarlyStopping counter: 1 out of 5


Training Epoch 16:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [17/20] Train Loss: 1.5377 | Val mAP: 0.0965 | Val mAP_50: 0.2158 | Time: 2.7m | Peak Mem: 19090 MB
--> New Best Model Saved (mAP: 0.0965)


Training Epoch 17:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [18/20] Train Loss: 1.5372 | Val mAP: 0.0965 | Val mAP_50: 0.2158 | Time: 2.7m | Peak Mem: 19294 MB
   --> EarlyStopping counter: 1 out of 5


Training Epoch 18:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [19/20] Train Loss: 1.5395 | Val mAP: 0.0965 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 19221 MB
   --> EarlyStopping counter: 2 out of 5


Training Epoch 19:   0%|          | 0/37 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [20/20] Train Loss: 1.5388 | Val mAP: 0.0965 | Val mAP_50: 0.2157 | Time: 2.7m | Peak Mem: 18973 MB
   --> EarlyStopping counter: 3 out of 5


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,██▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
system/peak_mem_mb,▅▅▂██▄▇▆▄▆▇▇▂▇▃▆▂▄▄▁
train/batch_box_loss,█▄▄▇▄▄▂▅▃▃▃▃▁▄▃▃▄▄▃▂▃▃▂▃▂▃▄▄▃▄▄▄▄▂▂▃▃▃▄▂
train/batch_cls_loss,██▆▆▃▂▂▂▁▂▁▂▁▁▂▂▂▁▁▂▁▁▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▂▂▁
train/batch_loss,█▄▄▄▂▂▂▁▂▂▂▂▁▁▂▂▁▂▁▂▁▁▁▂▂▁▂▁▂▂▁▂▁▂▂▁▁▂▂▂
train/epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/mAP,▁▄█▇████████████████
val/mAP_50,▁▄█▇█▇██████████████

0,1
epoch,20.0
learning_rate,0.0
system/peak_mem_mb,18973.31445
train/batch_box_loss,0.40632
train/batch_cls_loss,0.3635
train/batch_loss,1.53133
train/epoch_loss,1.53881
val/mAP,0.09647
val/mAP_50,0.21575


Training complete.


# Model evaluation

In [None]:
model = get_model(CONFIG["num_classes"])

In [None]:
# load best model
checkpoint = torch.load("baseline_faster_r_cnn.pth")

model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [None]:
# Initialize the metric
metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=True)

model.eval()

# Assuming test_loader is defined
for images, targets in tqdm(test_loader, desc="Testing"):
    images = list(image.to(device) for image in images)

    # Forward pass
    with torch.no_grad():
        predictions = model(images)

    # Move to CPU (torchmetrics handles CPU/GPU, but consistency is good)
    predictions = [{k: v.cpu() for k, v in p.items()} for p in predictions]

    # Update the metric with this batch
    # targets need to be a list of dicts on the same device as predictions
    # If targets are on GPU, move to CPU to match predictions
    targets_cpu = [{k: v.cpu() for k, v in t.items()} for t in targets]

    metric.update(predictions, targets_cpu)

    # Clear GPU cache to prevent OOM errors
    torch.cuda.empty_cache()

# Compute the final metrics over the whole dataset
results = metric.compute()

# Print results
print(f"mAP (IoU=0.50:0.95): {results['map']:.4f}")
print(f"mAP (IoU=0.50): {results['map_50']:.4f}")
print(f"mAP (IoU=0.75): {results['map_75']:.4f}")

# Since it is single class, you can also look at 'map_per_class'
# to ensure it matches the overall map
pprint(results)

Testing:   0%|          | 0/93 [00:00<?, ?it/s]

mAP (IoU=0.50:0.95): 0.0928
mAP (IoU=0.50): 0.2107
mAP (IoU=0.75): 0.0680
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.0928),
 'map_50': tensor(0.2107),
 'map_75': tensor(0.0680),
 'map_large': tensor(0.1295),
 'map_medium': tensor(0.0334),
 'map_per_class': tensor(0.0928),
 'map_small': tensor(0.0035),
 'mar_1': tensor(0.0054),
 'mar_10': tensor(0.0416),
 'mar_100': tensor(0.1565),
 'mar_100_per_class': tensor(0.1565),
 'mar_large': tensor(0.1894),
 'mar_medium': tensor(0.1101),
 'mar_small': tensor(0.0024)}


In [31]:
wandb.log(results)

In [32]:
wandb.finish()

0,1
classes,▁
map,▁
map_50,▁
map_75,▁
map_large,▁
map_medium,▁
map_per_class,▁
map_small,▁
mar_1,▁
mar_10,▁

0,1
classes,1
map,0.09281
map_50,0.21065
map_75,0.06797
map_large,0.12947
map_medium,0.03342
map_per_class,0.09281
map_small,0.00348
mar_1,0.00543
mar_10,0.04162


# Visualize results

In [None]:
# visualize prediction and ground truth on some test images side by side
model.eval()
for images, targets in tqdm(test_loader, desc="Visualizing Predictions"):
    images = list(image.to(device) for image in images)

    # Forward pass
    with torch.no_grad():
        predictions = model(images)

    for i in range(len(images)):
        img = images[i].cpu()
        gt_boxes = targets[i]["boxes"].cpu()
        pred_boxes = predictions[i]["boxes"].cpu()
        pred_scores = predictions[i]["scores"].cpu()

        # Filter predictions with a threshold (e.g., 0.5)
        threshold = 0.5
        keep_idxs = pred_scores >= threshold
        pred_boxes = pred_boxes[keep_idxs]
        pred_scores = pred_scores[keep_idxs]

        # Plot ground truth
        print("Ground Truth:")
        plot_image(
            img,
            boxes=gt_boxes,
            class_names=["Tree"],
            show=True,
        )

        # Plot predictions
        print("Predictions:")
        plot_image(
            img,
            boxes=pred_boxes,
            scores=pred_scores,
            class_names=["Tree"],
            show=True,
        )

    break  # Remove this break to visualize more batches