# Imports


In [None]:
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torchvision.transforms as T
from tqdm import tqdm
import glob
from pathlib import Path
import random
import shutil
import copy
#from kaggle_secrets import UserSecretsClient
from PIL import Image, ImageDraw, ImageFont

VERBOSE = True

# Helper Functions

In [None]:
class InflammatoryCellsDataset(Dataset):
    """Detection dataset of PNG patches with per-patch JSON box annotations.

    Expected layout, for every entry of ``splits``::

        <root_dir>/<split>/images/<stem>.png
        <annotations_dir>/<split>/annotations/<stem>.json

    An annotation file holds a list of records, each with a ``"bbox"`` entry
    (or a dict wrapping that list under ``"annotations"``). Boxes are
    presumably ``[xmin, ymin, xmax, ymax]`` in pixels, as torchvision
    detection models expect -- TODO confirm against the annotation export.
    Images without an annotation file, or with an empty one, are skipped.

    Args:
        root_dir: Directory containing the per-split ``images`` folders.
        annotations_dir: Directory containing the per-split ``annotations``
            folders.
        splits: Split name or iterable of split names to load.
        transforms: Optional callable applied to the PIL image only. The
            target is never passed to it, so geometric transforms (flips,
            crops, resizes) would leave the boxes out of sync with the image.
    """

    def __init__(self, root_dir, annotations_dir, splits, transforms=None):
        # A bare string would otherwise be iterated character by character.
        if isinstance(splits, str):
            splits = [splits]

        self.root_dir = Path(root_dir)
        self.annotations_dir = Path(annotations_dir)
        self.splits = splits
        self.transforms = transforms

        self.image_paths = []
        self.annotation_paths = []

        for split in splits:
            image_dir = self.root_dir / split / "images"
            ann_dir = self.annotations_dir / split / "annotations"

            # glob() order depends on the filesystem; sorting keeps sample
            # indices (and therefore seeded splits) stable between runs.
            for image_path in sorted(image_dir.glob("*.png")):
                # with_suffix only touches the extension, unlike
                # str.replace which rewrites every ".png" in the name.
                ann_path = ann_dir / image_path.with_suffix(".json").name
                if not ann_path.exists():
                    continue

                with open(ann_path) as f:
                    records = self._annotation_records(json.load(f))
                if not records:  # skip empty annotation files
                    continue

                self.image_paths.append(image_path)
                self.annotation_paths.append(ann_path)

    @staticmethod
    def _annotation_records(ann_data):
        """Return the annotation records, unwrapping an ``"annotations"`` key."""
        if isinstance(ann_data, dict) and "annotations" in ann_data:
            return ann_data["annotations"]
        return ann_data

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is a dict with ``boxes`` (float32), ``labels`` (int64, all
        ones: the single foreground class) and ``image_id``.
        """
        image = Image.open(self.image_paths[idx]).convert("RGB")

        with open(self.annotation_paths[idx]) as f:
            records = self._annotation_records(json.load(f))

        box_list = [record["bbox"] for record in records]
        boxes = torch.as_tensor(box_list, dtype=torch.float32)
        if boxes.numel() == 0:
            # Keep the (N, 4) layout even when there is nothing to detect.
            boxes = boxes.reshape(0, 4)

        target = {
            "boxes": boxes,
            # Binary task: label 1 = inflammatory cell (0 is background).
            "labels": torch.ones(len(box_list), dtype=torch.int64),
            "image_id": torch.tensor([idx]),
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target

def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    Detection samples carry a variable number of boxes, so they cannot be
    stacked into a single tensor; the fields are simply regrouped instead.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def get_transform(train):
    """Build the image preprocessing pipeline (PIL image -> float32 tensor).

    Args:
        train: Whether the pipeline is for training. Accepted for API
            compatibility; it currently does not change the pipeline (see
            the note below).

    Returns:
        A ``T.Compose`` that converts a PIL image to a tensor and then to
        ``torch.float32``.
    """
    # NOTE: the training pipeline used to append T.RandomHorizontalFlip(0.5).
    # That transform only sees the image (the dataset calls
    # ``self.transforms(image)`` without the target), so a flipped image kept
    # its unflipped boxes and roughly half of the training samples were
    # mislabelled. Geometric augmentation has to transform image and boxes
    # together (e.g. torchvision.transforms.v2 with bounding-box targets);
    # until the dataset supports that, no geometric augmentation is applied.
    return T.Compose([
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float32),
    ])


def get_model(num_classes=2):
    """Create a Faster R-CNN (ResNet-50 FPN) with a fresh box-prediction head.

    The detector is initialised from the default pretrained weights and its
    box predictor is swapped for a new ``FastRCNNPredictor`` sized for
    ``num_classes`` outputs (the count includes the background class).
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    )
    roi_heads = detector.roi_heads
    # Width of the features feeding the existing classification layer.
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

def subsample_dataset(dataset, fraction, generator=None):
    """Return a random ``Subset`` holding ``fraction`` of ``dataset``.

    Args:
        dataset: Any sized, indexable dataset.
        fraction: Share of samples to keep. Values above 1 keep everything,
            values below 0 keep nothing.
        generator: Optional ``torch.Generator`` for a reproducible selection;
            the global torch RNG is used when omitted.

    Returns:
        A ``torch.utils.data.Subset`` over randomly chosen indices. The order
        is shuffled even when ``fraction`` is 1.
    """
    num_total = len(dataset)
    num_keep = max(0, min(num_total, int(num_total * fraction)))
    # Hand Subset plain Python ints: with a tensor of indices every ``idx``
    # reaching the wrapped dataset's __getitem__ would be a 0-d tensor.
    indices = torch.randperm(num_total, generator=generator)[:num_keep].tolist()
    return torch.utils.data.Subset(dataset, indices)

In [None]:
def box_iou(boxes1, boxes2):
    """
    Compute pairwise IoU between two sets of boxes.

    boxes1: [N,4] tensor, boxes2: [M,4] tensor, both (xmin, ymin, xmax, ymax)
    Returns IoU matrix [N, M]. Pairs whose union has zero area (degenerate
    boxes) get an IoU of 0 instead of NaN.
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)

    # Corners of the intersection rectangle for every (box1, box2) pair.
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]; zero when the boxes do not overlap
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter
    # A zero union implies a zero intersection, so dividing by 1 there
    # yields 0 rather than the NaN produced by 0 / 0.
    safe_union = torch.where(union > 0, union, torch.ones_like(union))
    return inter / safe_union

def evaluate_precision_recall(outputs, targets, iou_threshold=0.5, score_thresh=0.5):
    """
    Compute precision and recall for a batch of predictions and targets.

    outputs: list of dicts with keys ['boxes', 'labels', 'scores']
    targets: list of dicts with keys ['boxes', 'labels']
    iou_threshold: minimum IoU for a prediction to match a ground-truth box
    score_thresh: predictions scoring below this value are discarded

    Predictions are visited in the order given; each one is matched to the
    still-unmatched ground-truth box of the same label with the highest IoU,
    provided that IoU reaches ``iou_threshold``. Every ground-truth box is
    matched at most once.

    Returns (precision, recall) over the whole batch; each is 0.0 when its
    denominator is zero.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    for preds, targs in zip(outputs, targets):
        true_boxes = targs['boxes']
        true_labels = targs['labels']

        # Filter by confidence first so that low-score predictions are never
        # counted as false positives, including on images without ground truth.
        keep = preds['scores'] >= score_thresh
        pred_boxes = preds['boxes'][keep]
        pred_labels = preds['labels'][keep]

        num_pred = len(pred_boxes)
        num_true = len(true_boxes)

        if num_pred == 0:
            # No predictions, all true boxes are false negatives
            false_neg += num_true
            continue

        if num_true == 0:
            # No ground truth, all remaining preds are false positives
            false_pos += num_pred
            continue

        ious = box_iou(pred_boxes, true_boxes)  # [num_pred, num_true]
        same_label = pred_labels[:, None] == true_labels[None, :]
        eligible = (ious >= iou_threshold) & same_label

        # One transfer to Python lists instead of a device sync per element.
        iou_rows = ious.cpu().tolist()
        eligible_rows = eligible.cpu().tolist()

        matched_gt = set()
        for iou_row, eligible_row in zip(iou_rows, eligible_rows):
            best_gt, best_iou = -1, -1.0
            for gt_idx, (value, ok) in enumerate(zip(iou_row, eligible_row)):
                if ok and gt_idx not in matched_gt and value > best_iou:
                    best_gt, best_iou = gt_idx, value
            if best_gt >= 0:
                matched_gt.add(best_gt)

        # Matching is one-to-one, so matched predictions == matched GT boxes.
        num_matched = len(matched_gt)
        true_pos += num_matched
        false_pos += num_pred - num_matched
        false_neg += num_true - num_matched

    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
    return precision, recall

# Data prep

In [None]:
# Kaggle input locations for the image patches and their JSON annotations.
root_dir = "/kaggle/input/patches-with-annotations/patches_with_annotations"
# NOTE(review): test_dir is not referenced anywhere else in the visible code.
test_dir = "/kaggle/input/patches/patches_newest"
annotations_dir = "/kaggle/input/annotations/json_mm"

# Prepare combined train splits
train_splits = ['pas-original', 'pas-diagnostic']
# NOTE(review): root_dir is passed for BOTH the image root and the annotation
# root, so annotations are read from <root_dir>/<split>/annotations and the
# annotations_dir variable is only used for the test set below -- confirm
# that this is intended.
full_train = InflammatoryCellsDataset(root_dir, root_dir, splits=train_splits, transforms=get_transform(True))
# 80/20 train/val split
# NOTE(review): random_split is unseeded here, and both resulting subsets wrap
# the same full_train object, so the validation samples also go through
# get_transform(True).
train_size = int(0.8 * len(full_train))
val_size = len(full_train) - train_size
train_ds, val_ds = random_split(full_train, [train_size, val_size])

# PAS-CPG test set
test_ds = InflammatoryCellsDataset(root_dir, annotations_dir, splits=['cpg_test'], transforms=get_transform(False))

In [None]:
# Fraction of each dataset to keep; 1 keeps every sample (lower it, e.g. to
# 0.1, for quick smoke runs). subsample_dataset also shuffles the order.
fraction = 1
train_ds = subsample_dataset(train_ds, fraction)
val_ds = subsample_dataset(val_ds, fraction)
test_ds = subsample_dataset(test_ds, fraction)

# DataLoaders: collate_fn keeps images/targets as tuples because every image
# has a different number of boxes.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Two classes: background + inflammatory cell.
model = get_model(num_classes=2)
model.to(device)

# Only optimise parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

if VERBOSE:
    # The closing parenthesis was missing from the first message.
    print(f"Training on {len(train_ds)} samples (80% split from: {train_splits})")
    print(f"Validating on {len(val_ds)} samples (20% split from: {train_splits})")
    #print(f"Evaluating on {len(test_ds)} samples from split: pas-cpg")

# Model Training

In [None]:
def train_model(model, optimizer, train_loader, val_loader, device, num_epochs=10, lr_scheduler=None):
    """Train a torchvision-style detection model and track validation metrics.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in train mode and a list of prediction
            dicts when called with ``images`` in eval mode.
        optimizer: Optimizer over the model parameters.
        train_loader: Iterable of ``(images, targets)`` training batches.
        val_loader: Iterable of ``(images, targets)`` validation batches.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr_scheduler: Optional scheduler, stepped once per epoch after the
            training pass.

    Returns:
        ``(model, history)``. ``model`` holds the weights of the LAST epoch
        (no best-checkpoint selection is done). ``history`` maps
        ``"train_loss"``, ``"val_precision"`` and ``"val_recall"`` to one
        value per epoch. Precision and recall are averaged over validation
        batches, not computed from dataset-wide counts.
    """
    model.to(device)

    # Record training history
    history = {
        "train_loss": [],
        "val_precision": [],
        "val_recall": []
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 20)

        # ---------- Training ----------
        model.train()
        total_train_loss = 0.0
        num_train_batches = 0

        for images, targets in tqdm(train_loader, desc="Training"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # In train mode the model returns its individual loss terms.
            loss_dict = model(images, targets)
            total_loss = sum(loss_dict.values())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            total_train_loss += total_loss.item()
            num_train_batches += 1

        # Count batches ourselves: avoids ZeroDivisionError on an empty
        # loader and works for loaders that do not define len().
        avg_train_loss = total_train_loss / num_train_batches if num_train_batches else float("nan")
        history["train_loss"].append(avg_train_loss)
        print(f"Train Loss: {avg_train_loss:.4f}")

        if lr_scheduler is not None:
            lr_scheduler.step()

        # ---------- Validation ----------
        model.eval()
        val_precision = 0.0
        val_recall = 0.0
        num_val_batches = 0

        with torch.no_grad():
            for images, targets in val_loader:
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                outputs = model(images)  # list of prediction dicts

                precision, recall = evaluate_precision_recall(outputs, targets)
                val_precision += precision
                val_recall += recall
                num_val_batches += 1

        # An empty validation loader reports 0.0 instead of crashing.
        avg_precision = val_precision / num_val_batches if num_val_batches else 0.0
        avg_recall = val_recall / num_val_batches if num_val_batches else 0.0
        history["val_precision"].append(avg_precision)
        history["val_recall"].append(avg_recall)

        print(f"Validation Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}")

    return model, history

In [None]:
# Optional LR scheduler: multiplies the learning rate by 0.1 every 5
# scheduler steps; train_model steps it once per epoch.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Run training; rebinds `model` to the trained model and keeps the per-epoch
# metrics in `history`.
model, history = train_model(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=10,
    lr_scheduler=lr_scheduler
)

### Automatically saving to public dataset (for Kaggle)

In [None]:
# # Setup Kaggle API token from input dataset (uploaded file)
# os.makedirs("/root/.config/kaggle", exist_ok=True)
# shutil.copy("/kaggle/input/kagglekey/kaggle.json", "/root/.config/kaggle/kaggle.json")
# os.chmod("/root/.config/kaggle/kaggle.json", 0o600)
# print("✅ Kaggle API token configured.")

# # Save your model + training history together
# output_dir = "saved_model_new"
# os.makedirs(output_dir, exist_ok=True)
# model_path = os.path.join(output_dir, "FasterRCNN_test.pth")

# torch.save({
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'history': history
# }, model_path)

# print(f"Model + history saved to: {model_path}")

# # Create metadata file for Kaggle dataset
# dataset_name = "AIMI-project"
# username = "luukneervens"

# metadata = {
#     "title": dataset_name,
#     "id": f"{username}/{dataset_name}",
#     "licenses": [{"name": "CC0-1.0"}]
# }

# with open(os.path.join(output_dir, "dataset-metadata.json"), "w") as f:
#     json.dump(metadata, f)

# # Upload or update dataset on Kaggle
# !kaggle datasets create -p saved_model_new -u

### Plot training loss

In [None]:
# plt.figure(figsize=(12, 5))
# plt.plot(history['train_loss'], label='Train Loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Training Loss Over Epochs')

# Model Evaluation

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_model(num_classes=2)
# map_location remaps tensors that were saved from a GPU, so the checkpoint
# also loads on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True where the torch version allows).
state_dict = torch.load(
    "/kaggle/input/fasterrcnn1/pytorch/default/1/FasterRCNN1.pth",
    map_location=device,
)
model.load_state_dict(state_dict)
model.to(device)

In [None]:
def iou(boxA, boxB):
    """Return the intersection-over-union of two [xmin, ymin, xmax, ymax] boxes.

    Returns 0 when the union has no positive area.
    """
    # Corners of the overlap rectangle.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Non-overlapping boxes produce negative extents, clipped to zero.
    overlap = max(0, right - left) * max(0, bottom - top)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    total = area_a + area_b - overlap

    if not total > 0:
        return 0
    return overlap / total

# Module-level PIL -> tensor converter; evaluate_on_patches below uses it to
# preprocess every patch before inference.
to_tensor = transforms.ToTensor()

def visualize_and_save(img: Image.Image, gt_boxes, pred_boxes, 
                       out_path: str, 
                       gt_color=(0,255,0), pred_color=(255,0,0), 
                       thickness=2):
    """
    Draws gt_boxes and pred_boxes onto img and saves to out_path.

    img: PIL image; it is drawn on IN PLACE, so pass a copy to keep the original
    gt_boxes, pred_boxes: lists of [xmin,ymin,xmax,ymax]
    out_path: destination file; missing parent directories are created
    gt_color, pred_color: RGB outline/label colours for the two box sets
    thickness: outline width in pixels
    """
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    draw = ImageDraw.Draw(img)

    # Labels are best-effort: if no default font can be loaded, only the
    # rectangles are drawn. Unlike a bare ``except``, this does not swallow
    # KeyboardInterrupt / SystemExit.
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    # Draw GT boxes, label just above the top-left corner
    for box in gt_boxes:
        draw.rectangle(box, outline=gt_color, width=thickness)
        if font:
            draw.text((box[0], box[1]-10), "GT", fill=gt_color, font=font)

    # Draw predicted boxes, label just below the bottom-left corner
    for box in pred_boxes:
        draw.rectangle(box, outline=pred_color, width=thickness)
        if font:
            draw.text((box[0], box[3]+2), "PRED", fill=pred_color, font=font)

    img.save(out_path)

def _load_gt_boxes(ann_path):
    """Read the ground-truth boxes from one annotation JSON file."""
    with open(ann_path, 'r') as f:
        gt_data = json.load(f)
    # Accept both a bare list of records and {"annotations": [...]}.
    if isinstance(gt_data, dict) and "annotations" in gt_data:
        gt_data = gt_data["annotations"]
    return [ann["bbox"] for ann in gt_data]


def _match_boxes(pred_boxes, gt_boxes, iou_threshold):
    """Greedily match predictions to ground truth; return (TP, FP, FN) counts."""
    matched_gt = set()
    tp = fp = 0
    for pb in pred_boxes:
        best_iou, best_j = 0, -1
        for j, gb in enumerate(gt_boxes):
            if j in matched_gt:
                continue
            score = iou(pb, gb)
            if score > best_iou:
                best_iou, best_j = score, j

        # best_j stays -1 when nothing overlaps; without this check a
        # threshold of 0 would count such a prediction as a true positive.
        if best_j >= 0 and best_iou >= iou_threshold:
            tp += 1
            matched_gt.add(best_j)
        else:
            fp += 1

    return tp, fp, len(gt_boxes) - len(matched_gt)


def evaluate_on_patches(model, images_dir, anns_dir, output_vis_dir, iou_threshold=0.5, score_thresh=0.5, device="cuda", fraction=0.5, seed=None):
    """Evaluate a detector on a random subset of PNG patches.

    For every selected patch the model is run, predictions with a score not
    above ``score_thresh`` are dropped, the rest are matched against the
    ground-truth boxes, and an image with both box sets drawn on it is
    written to ``output_vis_dir`` as ``<stem>_vis.png``.

    Args:
        model: Detection model returning a list of dicts with ``boxes`` and
            ``scores`` when called with a list of image tensors.
        images_dir: Directory containing the ``*.png`` patches.
        anns_dir: Directory containing one ``<stem>.json`` per patch.
        output_vis_dir: Directory for the visualisation images.
        iou_threshold: Minimum IoU for a prediction to count as a match.
        score_thresh: Predictions must score strictly above this value.
        device: Device used for inference.
        fraction: Share of the patches to evaluate.
        seed: Optional seed for a reproducible subset; when ``None`` the
            global ``random`` state is used.

    Returns:
        Dict with ``TP``, ``FP``, ``FN``, ``precision``, ``recall`` and
        ``f1``, computed from counts pooled over all evaluated patches.

    Raises:
        FileNotFoundError: If a selected patch has no annotation file.
    """
    model.to(device).eval()

    TP = FP = FN = 0
    # Sort first so that a given seed always selects the same patches.
    image_paths = sorted(glob.glob(os.path.join(Path(images_dir), "*.png")))
    num_selected = int(fraction * len(image_paths))
    print(f"Found {len(image_paths)} patches, testing on random subset of {num_selected} patches")

    # take random subset of all test patches
    if seed is None:
        random.shuffle(image_paths)
    else:
        random.Random(seed).shuffle(image_paths)
    image_paths = image_paths[:num_selected]

    with torch.no_grad():
        for img_path in tqdm(image_paths):
            stem = Path(img_path).stem

            # 1) Load, preprocess and run the model on a single image
            img = Image.open(img_path).convert("RGB")
            img_t = to_tensor(img).to(device)
            preds = model([img_t])[0]  # dict with 'boxes', 'scores', 'labels'

            # 2) Filter out low-confidence preds
            keep = preds["scores"] > score_thresh
            pred_boxes = preds["boxes"][keep].cpu().tolist()

            # 3) Load GT boxes
            gt_boxes = _load_gt_boxes(Path(anns_dir) / f"{stem}.json")

            # 4) Match preds to GT; unmatched GT boxes are false negatives
            tp, fp, fn = _match_boxes(pred_boxes, gt_boxes, iou_threshold)
            TP += tp
            FP += fp
            FN += fn

            # 5) Save a visualisation (on a copy, the drawing is in place)
            out_file = os.path.join(output_vis_dir, f"{stem}_vis.png")
            visualize_and_save(img.copy(), gt_boxes, pred_boxes, out_file)

    # 6) Compute metrics
    precision = TP / (TP + FP) if TP + FP > 0 else 0.0
    recall    = TP / (TP + FN) if TP + FN > 0 else 0.0
    f1_score  = 2 * (precision * recall) / (precision + recall) \
                    if (precision + recall) else 0.0

    print(f"--- Evaluation over {len(image_paths)} patches ---")
    print(f"TP: {TP}, FP: {FP}, FN: {FN}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
    print(f"F1 Score : {f1_score:.4f}")

    return {"TP": TP, "FP": FP, "FN": FN, "precision": precision, "recall": recall, "f1": f1_score}

In [None]:
# Point to the correct absolute Kaggle paths
images_dir = "/kaggle/input/test-patches-with-annotations/test_patches_with_annotations/cpg/images"
anns_dir   = "/kaggle/input/test-patches-with-annotations/test_patches_with_annotations/cpg/annotations"
# Relative path: visualisations are written under the current working directory.
output_vis_dir = "detection_visuals"


# Run evaluation on a random half of the patches (fraction=0.5); the returned
# dict holds the pooled TP/FP/FN counts and precision/recall/F1.
metrics = evaluate_on_patches(
    model,
    images_dir,
    anns_dir,
    output_vis_dir,
    iou_threshold=0.5,
    score_thresh=0.5,
    device=device,  # or "cpu" if you prefer
    fraction=0.5
)

In [None]:
# Zip test detection output images for download (for Kaggle)
#shutil.make_archive("detection_visuals1.0", 'zip', "detection_visuals")