<a target="_blank" href="https://colab.research.google.com/github/Reslan-Tinawi/selva-box-tree-detection/blob/main/notebooks/06_deepforest_pretrained.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Import packages

In [None]:
# Detect whether the notebook runs in Google Colab: a successful import of
# `google.colab` is treated as the signal, an ImportError as "not in Colab".
try:
    import google.colab

    # IPython shell escape: install the extra packages imported further down.
    ! pip install torchmetrics deepforest
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import random
from pprint import pprint

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from datasets import load_from_disk
from deepforest import main
from PIL.TiffImagePlugin import TiffImageFile
from torch.utils.data import DataLoader, Dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from tqdm.notebook import tqdm

# --- A100 OPTIMIZATION: ENABLE TF32 ---
# Opt in to TF32 for CUDA matrix multiplications and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Pick the dataset root: a local checkout reads from the repository's data
# folder, while Colab mounts Google Drive first and reads from there.
if not IN_COLAB:
    BASE_PATH = "../data/selvabox/"
else:
    from google.colab import drive

    drive.mount("/content/drive")
    BASE_PATH = "/content/drive/MyDrive/datasets/SelvaBox/saved/"

In [None]:
def setup_seed(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (default generator plus all CUDA devices), then configures cuDNN for
    repeatable runs.

    Args:
        seed: Value handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Pin the cuDNN autotuner off as well, so kernel selection cannot vary
    # between runs if something else turned benchmarking on.
    torch.backends.cudnn.benchmark = False


setup_seed(42)

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load data

In [None]:
# Load the test split stored in the "test" directory under BASE_PATH.
hf_test_ds = load_from_disk(BASE_PATH + "test")

In [None]:
# Sanity check: report how many samples the loaded test split contains.
print(f"Number of test samples: {len(hf_test_ds)}")

# Utility functions

In [None]:
def df_to_dict(df):
    """Convert one image's prediction table into a dict of tensors.

    Args:
        df: pandas DataFrame with ``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``label`` and ``score`` columns, or ``None`` / an empty frame
            when the image has no detections.

    Returns:
        dict with ``boxes`` (float32, ``[N, 4]``), ``labels`` (int64, ``[N]``)
        and ``scores`` (float32, ``[N]``) tensors. ``N`` is 0 when there are
        no detections.
    """
    # NOTE(review): DeepForest appears to return None instead of an empty
    # frame for images without detections — confirm against the installed
    # version. Either way an empty result must still yield well-shaped tensors.
    if df is None or len(df) == 0:
        return {
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
            "scores": torch.zeros((0,), dtype=torch.float32),
        }

    boxes = df[["xmin", "ymin", "xmax", "ymax"]].to_numpy()
    labels = df["label"].to_numpy()
    scores = df["score"].to_numpy()

    return {
        "boxes": torch.tensor(boxes, dtype=torch.float32),
        "labels": torch.tensor(labels, dtype=torch.int64),
        "scores": torch.tensor(scores, dtype=torch.float32),
    }

In [None]:
def _to_numpy(value):
    """Return *value* as a NumPy array if it is a tensor; pass anything else through."""
    if isinstance(value, torch.Tensor):
        # detach() so tensors that still track gradients can be converted too.
        return value.detach().cpu().numpy()
    return value


def plot_image(
    img, boxes, scores=None, labels=None, class_names=None, save_path=None, show=True
):
    """
    Plots bounding boxes on an image with optional scores and labels.

    Args:
        img (np.array | torch.Tensor): Input image. Shape [H, W, C] (numpy) or [C, H, W] (torch).
        boxes (np.array | torch.Tensor): Bounding boxes [N, 4] format (xmin, ymin, xmax, ymax).
        scores (np.array | torch.Tensor, optional): Confidence scores [N]. Defaults to None.
        labels (np.array | torch.Tensor, optional): Class indices [N]. Defaults to None.
        class_names (list, optional): List of class string names. When omitted,
            generic names are derived from ``labels`` (or a single "Object"
            class is used when there are no labels). Defaults to None.
        save_path (str, optional): Path to save the figure. Defaults to None.
        show (bool, optional): Whether to display the plot. When False the
            figure is closed after the optional save. Defaults to True.
    """

    # --- 1. Data Standardization ---
    was_tensor = isinstance(img, torch.Tensor)
    img = _to_numpy(img)
    # Tensor images are channel-first; Matplotlib wants [H, W, C]. The ndim
    # guard keeps 2-D (grayscale) tensors from failing on shape[2].
    if was_tensor and img.ndim == 3 and img.shape[0] < img.shape[2]:
        img = img.transpose(1, 2, 0)

    boxes = _to_numpy(boxes)
    scores = _to_numpy(scores)
    labels = _to_numpy(labels)

    # Float images are clamped to [0, 1], the range imshow expects for floats.
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img, 0, 1)

    # --- 2. Setup Figure ---
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    # --- 3. Color Setup ---
    if class_names is None:
        # np.max raises on an empty array, so an image without any boxes
        # falls back to the generic single class as well.
        if labels is not None and len(labels) > 0:
            max_label = int(np.max(labels))
            class_names = [f"Class {i}" for i in range(max_label + 1)]
        else:
            class_names = ["Object"]

    # One distinct color per class; always at least one so the modulo lookup
    # below can never divide by zero.
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, max(len(class_names), 1))]

    # --- 4. Plotting Loop ---
    for i, box in enumerate(boxes):
        xmin, ymin, xmax, ymax = box

        # Without labels every box is drawn as class 0.
        cls_id = int(labels[i]) if labels is not None else 0

        color = colors[cls_id % len(colors)]
        class_name = (
            class_names[cls_id] if cls_id < len(class_names) else f"Class {cls_id}"
        )

        # Draw Rectangle
        rect = patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

        # Build Text String
        display_text = class_name
        if scores is not None:
            display_text += f" {int(100 * scores[i])}%"

        # Draw Text with a semi-transparent background in the class color
        ax.text(
            xmin,
            ymin,
            display_text,
            color="white",
            fontsize=10,
            verticalalignment="top",
            bbox={
                "color": color,
                "pad": 2,
                "alpha": 0.8,
            },
        )

    ax.axis("off")  # Hide axes ticks

    # Act on the figure created above rather than on pyplot's implicit
    # "current figure", which may be a different one.
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

# Hyper-parameters

In [None]:
# Run settings: passed to wandb.init as the run config and read when building
# the DataLoader.
CONFIG = {
    "project_name": "selva-box-tree-detection",  # WandB project name
    "name": "deepforest-pretrained",  # WandB run name
    # NOTE(review): "num_classes" is not read by any cell shown here; the
    # dataset below labels every tree as class 0.
    "num_classes": 2,  # Background + your classes (e.g., 1 class + 1 background = 2)
    "batch_size": 16,
    "num_workers": 4,
    "device": device,
    "model_name": "deepforest_pretrained",
}

In [None]:
# Start the Weights & Biases run, recording CONFIG as the run configuration.
wandb.init(
    project=CONFIG["project_name"],
    name=CONFIG["name"],
    config=CONFIG,
)

# Custom dataset

In [None]:
# inspired from: https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html
class SelvaBoxDataset(Dataset):
    """Detection dataset wrapper around an indexable dataset of samples.

    Each sample must provide ``sample["image"]`` (a PIL image) and
    ``sample["annotations"]``, a dict with ``bbox`` (XYWH boxes), ``area``
    and ``iscrowd`` lists.

    Args:
        hf_dataset: Indexable dataset yielding the samples described above.
        n_classes: Stored on the instance; not used by the class itself.
        transforms: Optional callable applied as
            ``image, target = transforms(image, target)``.
    """

    def __init__(self, hf_dataset, n_classes=1, transforms=None):
        self.dataset = hf_dataset
        self.n_classes = n_classes
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` holds ``boxes``, ``labels`` (all zeros: a single tree
        class), ``image_id``, ``area`` and ``iscrowd``.
        """
        sample = self.dataset[index]
        image: TiffImageFile = sample["image"]
        annotations_dict = sample["annotations"]

        if image.mode != "RGB":
            image = image.convert("RGB")

        # PIL returns (Width, Height)
        w, h = image.size

        image = tv_tensors.Image(image)

        # An empty annotation list becomes a shape-(0,) tensor; reshaping to
        # (-1, 4) keeps images without trees as a proper (0, 4) box tensor.
        # float32 matches the dtype used for the empty fallback below.
        raw_boxes = torch.as_tensor(
            annotations_dict["bbox"], dtype=torch.float32
        ).reshape(-1, 4)

        # number of objects/trees in the image
        num_objs = raw_boxes.shape[0]

        target = {
            "boxes": tv_tensors.BoundingBoxes(
                data=raw_boxes,
                format="XYWH",  # COCO format
                canvas_size=(h, w),
            ),
            # all trees have label 0
            "labels": torch.zeros((num_objs,), dtype=torch.int64),
            # TODO: is this necessary? when moving data to GPU, it expects a tensor
            "image_id": torch.tensor(index),
            "area": torch.tensor(annotations_dict["area"], dtype=torch.float32),
            "iscrowd": torch.tensor(annotations_dict["iscrowd"], dtype=torch.int64),
        }

        if self.transforms:
            image, target = self.transforms(image, target)

        # Normalise targets without objects to plain, correctly shaped and
        # typed tensors.
        if target["boxes"].shape[0] == 0:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["area"] = torch.zeros((0,), dtype=torch.float32)
            target["iscrowd"] = torch.zeros((0,), dtype=torch.int64)

        return image, target

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

In [None]:
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field tuples.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``, each a
    tuple with one entry per sample; nothing is stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# Evaluation-time preprocessing applied jointly to (image, target); no augmentation.
transforms = T.Compose(
    [
        T.ConvertBoundingBoxFormat(format="XYXY"),  # Convert COCO format to xyxy
        T.ToDtype(torch.float, scale=True),  # cast to float with value scaling enabled
        T.ToPureTensor(),  # final step: ToPureTensor transform
    ]
)

In [None]:
# Wrap the loaded test split with the preprocessing above (n_classes keeps its default).
test_dataset = SelvaBoxDataset(hf_test_ds, transforms=transforms)

In [None]:
# Test loader: fixed order, samples kept as tuples by collate_fn.
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=CONFIG["num_workers"],
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=torch.cuda.is_available(),
    # DataLoader rejects persistent_workers=True when num_workers is 0, so
    # derive the flag from the configured worker count.
    persistent_workers=CONFIG["num_workers"] > 0,
)

# Model definition

In [None]:
def get_model():
    """Build a DeepForest model and load the pretrained tree-detection weights.

    Returns:
        The ``main.deepforest`` instance after ``load_model`` has been called
        with the ``weecology/deepforest-tree`` checkpoint (revision ``main``).
    """
    detector = main.deepforest()
    # Pretrained tree detection weights hosted on Hugging Face
    detector.load_model(model_name="weecology/deepforest-tree", revision="main")
    return detector

# Model evaluation

In [None]:
# Instantiate the pretrained detector and move it to the selected device.
model = get_model()
model.to(device)

In [None]:
# Accumulate mean average precision over the whole test set
metric = MeanAveragePrecision(
    box_format="xyxy",
    iou_type="bbox",
    max_detection_thresholds=[1, 100, 400],
    class_metrics=True,
)

for images, targets in tqdm(test_loader, desc="Testing"):
    # Stack the batch on the host and transfer it to the device once
    # (torch.stack requires every image in the batch to have the same shape).
    images_batch = torch.stack(images).to(device)

    # Forward pass
    predictions = model.predict_batch(images_batch)

    # Convert each per-image prediction table into a dict of CPU tensors
    predictions = [df_to_dict(pred) for pred in predictions]

    # Targets must be a list of dicts on the same device as the predictions
    targets_cpu = [{k: v.cpu() for k, v in t.items()} for t in targets]

    # Update the metric with this batch
    metric.update(predictions, targets_cpu)

    # Clear GPU cache to prevent OOM errors
    torch.cuda.empty_cache()

# Compute the final metrics over the whole dataset
results = metric.compute()

# Print results
print(f"mAP (IoU=0.50:0.95): {results['map']:.4f}")
print(f"mAP (IoU=0.50): {results['map_50']:.4f}")
print(f"mAP (IoU=0.75): {results['map_75']:.4f}")

pprint(results)

In [None]:
# Log the computed metric dictionary to the active Weights & Biases run.
wandb.log(results)

In [None]:
# Close the Weights & Biases run; nothing is logged after this point.
wandb.finish()

# Visualize results

In [None]:
# Visualize predictions and ground truth for the first test batch, one image
# after the other.

# Get a batch from the test set
test_iter = iter(test_loader)
images, targets = next(test_iter)

# Forward pass: stack on the host and transfer the batch to the device once
images_batch = torch.stack(images).to(device)
predictions = model.predict_batch(images_batch)

predictions = [df_to_dict(pred) for pred in predictions]

# Plotting happens on the CPU
images = [img.cpu() for img in images]
targets = [{k: v.cpu() for k, v in t.items()} for t in targets]

# Fail loudly rather than silently pairing images with the wrong predictions.
if len(predictions) != len(images):
    raise ValueError(
        f"Got {len(predictions)} predictions for {len(images)} images"
    )

for i, (img, pred, target) in enumerate(zip(images, predictions, targets)):
    # keep only predictions with score > 0.5
    keep_idxs = pred["scores"] > 0.5
    pred["boxes"] = pred["boxes"][keep_idxs]
    pred["scores"] = pred["scores"][keep_idxs]
    pred["labels"] = pred["labels"][keep_idxs]

    print(f"Image {i + 1} Predictions:")
    plot_image(
        img,
        boxes=pred["boxes"],
        scores=pred["scores"],
        labels=pred["labels"],
        class_names=["tree"],
        show=True,
    )

    print(f"Image {i + 1} Ground Truth:")
    plot_image(
        img,
        boxes=target["boxes"],
        labels=target["labels"],
        class_names=["tree"],
        show=True,
    )