In [None]:
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from pycocotools.coco import COCO
from PIL import Image
import torchvision.transforms as T


class CustomCocoDataset(Dataset):
    """COCO-format instance-segmentation dataset yielding ``(image, target)`` pairs.

    Only images that carry at least one annotation are indexed. Every target is
    a dict with the keys ``boxes`` (N, 4, xyxy), ``labels`` (N,), ``masks``
    (N, H, W), ``image_id`` (1,), ``area`` (N,) and ``iscrowd`` (N,).
    """

    def __init__(self, root, annotation, transform=None):
        """
        Args:
            root (str): Path to the images folder
            annotation (str): Path to the COCO annotation JSON file
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root = root
        self.coco = COCO(annotation)
        # Filter images to include only those with at least one annotation
        self.ids = [
            img_id
            for img_id in self.coco.imgs.keys()
            if len(self.coco.getAnnIds(imgIds=img_id)) > 0
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of indexed (annotated) images."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load the image at ``index`` and build its target dict.

        Returns:
            tuple: ``(img, target)`` where ``img`` is the RGB image (after
            ``transform`` when one was given) and ``target`` is the dict
            described in the class docstring.
        """
        coco = self.coco
        img_id = self.ids[index]
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        # Load image
        img_info = coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info["file_name"])
        img = Image.open(img_path).convert("RGB")

        # COCO stores boxes as [x, y, w, h]; the target uses [x1, y1, x2, y2].
        boxes = [
            [x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in anns)
        ]
        labels = [ann["category_id"] for ann in anns]
        areas = [ann["area"] for ann in anns]
        iscrowd = [ann["iscrowd"] for ann in anns]

        if anns:
            # Convert each mask on its own and stack the tensors: handing a
            # list of arrays to torch.as_tensor takes a slow element-wise path.
            masks = torch.stack(
                [torch.as_tensor(coco.annToMask(ann), dtype=torch.uint8) for ann in anns]
            )
        else:
            # Keep a well-formed (0, H, W) tensor when there is nothing to stack.
            masks = torch.zeros(
                (0, img_info["height"], img_info["width"]), dtype=torch.uint8
            )

        target = {
            # reshape keeps the (N, 4) layout even when N == 0
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

        if self.transform:
            img = self.transform(img)

        return img, target


def get_transform():
    """Build the image preprocessing pipeline: tensor conversion, then normalisation.

    NOTE(review): torchvision detection models are documented to normalise
    their inputs internally - confirm this extra Normalize step is intended.
    """
    steps = [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)


# Dataset location: one 'train' folder holding both the images and the annotation file
data_root = "/content/final_main_seg_dataset_just_ARAS-3/train/"
ann_file = os.path.join(data_root, "_annotations.coco.json")

# Build the dataset over every annotated image
full_dataset = CustomCocoDataset(
    root=data_root, annotation=ann_file, transform=get_transform()
)
total_samples = len(full_dataset)

# Contiguous 70/20/10 split in the original order (no shuffling);
# the test split absorbs whatever the integer rounding leaves over
train_size = int(0.7 * total_samples)
valid_size = int(0.2 * total_samples)
test_size = total_samples - train_size - valid_size

# Index lists for each split
valid_start = train_size
test_start = train_size + valid_size
train_indices = list(range(valid_start))
valid_indices = list(range(valid_start, test_start))
test_indices = list(range(test_start, total_samples))

# Wrap each index list as a view over the full dataset
train_dataset, valid_dataset, test_dataset = (
    Subset(full_dataset, indices)
    for indices in (train_indices, valid_indices, test_indices)
)


# Batch collation: regroup per-sample tuples into one tuple per field
def collate_fn(batch):
    """Transpose ``batch`` so each field becomes its own tuple.

    A batch of ``(image, target)`` samples turns into ``(images, targets)``;
    the items are left as they are rather than stacked.
    """
    fields = zip(*batch)
    return tuple(fields)


# DataLoaders iterate in dataset order (shuffle disabled); only the batch size differs
_loader_options = dict(shuffle=False, num_workers=2, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, batch_size=8, **_loader_options)
valid_loader = DataLoader(valid_dataset, batch_size=1, **_loader_options)
test_loader = DataLoader(test_dataset, batch_size=1, **_loader_options)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mask R-CNN with a ResNet50-FPN backbone, initialised from pre-trained weights
model = maskrcnn_resnet50_fpn(pretrained=True)

# Read the head input sizes from the stock predictors so the replacements match
roi_heads = model.roi_heads
in_features = roi_heads.box_predictor.cls_score.in_features
in_features_mask = roi_heads.mask_predictor.conv5_mask.in_channels
dim_reduced = roi_heads.mask_predictor.conv5_mask.out_channels

# 13 outputs: 12 categories + background (adjust this to the dataset)
num_classes = 13

# Swap in fresh box and mask predictors sized for num_classes
roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
roi_heads.mask_predictor = MaskRCNNPredictor(
    in_features_mask, dim_reduced, num_classes
)

# Place the model on the selected device
model.to(device)

In [None]:
# Optimise only the parameters that require gradients
params = [param for param in model.parameters() if param.requires_grad]

# AdamW optimiser.
# (Alternative kept for reference: SGD with lr=0.005, momentum=0.9, weight_decay=0.0005.)
optimizer = torch.optim.AdamW(params, lr=0.0005, weight_decay=0.0005)

# Step schedule: multiply the learning rate by 0.1 every 3 scheduler steps
# (the training loop steps it once per epoch)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
from tqdm import tqdm

# Relies on model, train_loader, optimizer, lr_scheduler and device from earlier cells
num_epochs = 10

for epoch in range(num_epochs):
    model.train()  # training mode: the forward pass returns losses
    total_loss = 0

    # tqdm renders one progress bar per epoch
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, targets in progress:
        # Move the batch onto the training device
        images = [image.to(device) for image in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        # Forward pass: a dict of individual losses, summed into one scalar
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Backward pass and parameter update
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        total_loss += losses.item()

    # Advance the learning-rate schedule once per epoch
    lr_scheduler.step()

    # Report the mean batch loss for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

In [None]:
# Persist the weights (state dict) twice: under the Drive path and in the working directory
for checkpoint_path in (
    "/content/drive/MyDrive/maskrcnn_finetuned.pth",
    "maskrcnn_finetuned.pth",
):
    torch.save(model.state_dict(), checkpoint_path)
print("Model saved as 'maskrcnn_finetuned.pth'")

In [None]:
def evaluate_model(model, test_loader, test_ann_file, device):
    """Run COCO mask ('segm') evaluation on ``test_loader`` and print mAP.

    Args:
        model: Detection model whose eval-mode outputs are dicts with
            ``boxes``, ``scores``, ``labels`` and ``masks``.
        test_loader: Iterable of ``(images, targets)`` batches; each target
            must carry an ``image_id`` tensor.
        test_ann_file (str): COCO annotation JSON holding the ground truth
            for the images served by ``test_loader``.
        device: Device the images are moved to before inference.

    Prints the COCO summary table plus mAP@0.5 and mAP@0.5:0.95. Nothing is
    returned.
    """
    # The RLE encoder is a function of pycocotools.mask; a COCO instance has
    # no `maskUtils` attribute, so it is imported here explicitly.
    from pycocotools import mask as mask_utils

    model.eval()
    coco_gt = COCO(test_ann_file)  # Ground truth annotations

    detections = []  # predictions in COCO result format
    evaluated_ids = []  # ids of the images that were actually run
    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = [img.to(device) for img in images]
            outputs = model(images)

            for target, output in zip(targets, outputs):
                image_id = int(target["image_id"].item())
                evaluated_ids.append(image_id)

                boxes = output["boxes"].cpu().numpy()
                scores = output["scores"].cpu().numpy()
                labels = output["labels"].cpu().numpy()
                masks = output["masks"].cpu().numpy()

                for box, score, label, mask in zip(boxes, scores, labels, masks):
                    # Convert box from [x_min, y_min, x_max, y_max] to [x, y, w, h]
                    x_min, y_min, x_max, y_max = box

                    # Masks come out as [N, 1, H, W] probabilities: take the
                    # single channel, binarize, and RLE-encode it. encode()
                    # requires a Fortran-ordered uint8 array.
                    binary = np.asfortranarray((mask[0] > 0.5).astype(np.uint8))
                    rle = mask_utils.encode(binary)
                    # counts is bytes; convert it to str so it is JSON-friendly
                    rle["counts"] = rle["counts"].decode("utf-8")

                    detections.append(
                        {
                            "image_id": image_id,
                            "category_id": int(label),
                            "bbox": [
                                float(x_min),
                                float(y_min),
                                float(x_max - x_min),
                                float(y_max - y_min),
                            ],
                            "score": float(score),
                            "segmentation": rle,
                        }
                    )

    # loadRes cannot handle an empty result list, so stop early.
    if not detections:
        print("No detections were produced; skipping COCO evaluation.")
        return

    # Load predictions into COCO format
    coco_dt = coco_gt.loadRes(detections)

    # Run COCO evaluation
    coco_eval = COCOeval(coco_gt, coco_dt, "segm")  # Use 'segm' for mask evaluation
    # Score only the images that were inferred. Without this, every other
    # image in the annotation file counts as having no detections and drags
    # the averages down when the loader covers only a subset.
    coco_eval.params.imgIds = sorted(set(evaluated_ids))
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # Extract mAP@0.5 and mAP@0.5:0.95
    mAP_50 = coco_eval.stats[1]  # AP at IoU=0.5
    mAP_50_95 = coco_eval.stats[0]  # AP at IoU=0.5:0.95

    print(f"mAP@0.5: {mAP_50:.4f}")
    print(f"mAP@0.5:0.95: {mAP_50_95:.4f}")


# Run evaluation on the held-out test split. The split was carved out of the
# single annotation file, so ann_file holds its ground truth (no separate
# test annotation file is defined anywhere in this notebook).
evaluate_model(model, test_loader, ann_file, device)

## Report mAP per class

In [None]:
!pip install torchmetrics

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from pycocotools.coco import COCO
from PIL import Image
import torchvision.transforms as T
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchmetrics.detection import MeanAveragePrecision
import numpy as np
from tqdm import tqdm
import cv2


# -------------------- COCO Dataset Definition --------------------
class CustomCocoDataset(Dataset):
    """COCO-format instance-segmentation dataset returning ``(image, target)``.

    Images without annotations are skipped. The target dict holds ``boxes``
    (N, 4, xyxy), ``labels`` (N,), ``masks`` (N, H, W), ``image_id`` (1,),
    ``area`` (N,) and ``iscrowd`` (N,).
    """

    def __init__(self, root, annotation, transform=None):
        """
        Args:
            root (str): Folder containing the image files.
            annotation (str): Path to the COCO annotation JSON file.
            transform (callable, optional): Applied to the image only.
        """
        self.root = root
        self.coco = COCO(annotation)
        # Index only the images that have at least one annotation.
        self.ids = [
            img_id
            for img_id in self.coco.imgs.keys()
            if len(self.coco.getAnnIds(imgIds=img_id)) > 0
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of indexed (annotated) images."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load the image at ``index`` together with its target dict."""
        img_id = self.ids[index]
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info["file_name"])
        img = Image.open(img_path).convert("RGB")

        boxes, labels, areas, iscrowd, mask_tensors = [], [], [], [], []
        for ann in anns:
            x, y, w, h = ann["bbox"]
            boxes.append([x, y, x + w, y + h])  # COCO xywh -> xyxy
            labels.append(ann["category_id"])
            areas.append(ann["area"])
            iscrowd.append(ann["iscrowd"])
            # Convert each mask on its own: handing a list of arrays to
            # torch.as_tensor takes a slow element-wise path.
            mask_tensors.append(
                torch.as_tensor(self.coco.annToMask(ann), dtype=torch.uint8)
            )

        if mask_tensors:
            masks = torch.stack(mask_tensors)
        else:
            # Keep a well-formed (0, H, W) tensor when there is nothing to stack.
            masks = torch.zeros(
                (0, img_info["height"], img_info["width"]), dtype=torch.uint8
            )

        target = {
            # reshape keeps the (N, 4) layout even when N == 0
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

        if self.transform:
            img = self.transform(img)
        return img, target


# -------------------- Transforms --------------------
def get_transform():
    """Return the preprocessing pipeline: ``ToTensor`` followed by ``Normalize``.

    NOTE(review): torchvision detection models are documented to normalise
    inputs internally - confirm that normalising here as well is intended.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return T.Compose([to_tensor, normalize])


# -------------------- Data Splitting (70/20/10) --------------------
data_root = "train/"
ann_file = os.path.join(data_root, "_annotations.coco.json")
full_dataset = CustomCocoDataset(
    root=data_root, annotation=ann_file, transform=get_transform()
)
total = len(full_dataset)

# Split boundaries in dataset order; the last split takes the remainder
train_end = int(0.7 * total)
valid_end = train_end + int(0.2 * total)
split_bounds = [(0, train_end), (train_end, valid_end), (valid_end, total)]

train_dataset, valid_dataset, test_dataset = (
    Subset(full_dataset, list(range(start, stop))) for start, stop in split_bounds
)

# Collate fn
def collate_fn(batch):
    """Regroup a batch of samples into one tuple per field (no stacking)."""
    return tuple(zip(*batch))

# All three loaders keep dataset order; only the batch size differs
_loader_options = dict(shuffle=False, num_workers=2, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, batch_size=8, **_loader_options)
valid_loader = DataLoader(valid_dataset, batch_size=1, **_loader_options)
test_loader = DataLoader(test_dataset, batch_size=1, **_loader_options)

# -------------------- Model Setup --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the fine-tuned architecture: stock Mask R-CNN with both predictor
# heads replaced for num_classes outputs
num_classes = 13  # 12 classes + background
model = maskrcnn_resnet50_fpn(pretrained=False)
model.roi_heads.box_predictor = FastRCNNPredictor(1024, num_classes)
model.roi_heads.mask_predictor = MaskRCNNPredictor(256, 256, num_classes)

# Restore the saved weights, mapped onto the selected device
model.load_state_dict(
    torch.load("/content/maskrcnn_finetuned.pth", map_location=device)
)
model.to(device)


# -------------------- Evaluation --------------------
def _overlay_detection(vis, box, mask, color, text):
    """Blend one predicted mask into ``vis`` and draw its box and caption in place."""
    # Blend only when the mask matches the image size; boolean indexing
    # would fail on a mismatch.
    if mask.shape == vis.shape[:2]:
        blended = 0.5 * vis[mask].astype(np.float32) + 0.5 * np.asarray(
            color, dtype=np.float32
        )
        vis[mask] = np.rint(blended).astype(np.uint8)
    x1, y1, x2, y2 = (int(v) for v in box)
    cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
    cv2.putText(vis, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)


def _per_class_scores(result):
    """Map each evaluated (0-based) class index to its AP in one metric result."""
    values = result["map_per_class"].reshape(-1).tolist()
    if "classes" in result:
        # map_per_class follows the order of `classes` (the classes that were
        # observed), so pair the two instead of indexing by class id.
        class_ids = [int(c) for c in result["classes"].reshape(-1).tolist()]
    else:
        # No `classes` entry in the result: fall back to positional order.
        class_ids = list(range(len(values)))
    return dict(zip(class_ids, values))


def evaluate_model(model, data_loader, coco_annotation, data_root, output_folder):
    """Compute box/mask mAP on ``data_loader`` and save annotated images.

    Args:
        model: Detection model whose eval-mode outputs are dicts with
            ``boxes``, ``scores``, ``labels`` and ``masks``.
        data_loader: Iterable of ``(images, targets)`` batches; each target
            carries ``image_id``, ``boxes``, ``labels`` and ``masks``.
        coco_annotation (str): COCO annotation JSON, used to look up image
            file names and category ids.
        data_root (str): Folder the image files are read from.
        output_folder (str): Folder the visualizations are written to
            (created when missing).

    Prints per-class and overall mAP for boxes and masks, at IoU 0.5 and with
    the metric's default IoU thresholds. Nothing is returned.
    """
    os.makedirs(output_folder, exist_ok=True)
    model.eval()
    # Infer on whatever device the model lives on, rather than depending on a
    # notebook-level `device` global.
    device = next(model.parameters()).device

    # Metrics
    metas = {
        "mask_full": {"iou_type": "segm", "iou_thresholds": None},
        "mask_50": {"iou_type": "segm", "iou_thresholds": [0.5]},
        "box_full": {"iou_type": "bbox", "iou_thresholds": None},
        "box_50": {"iou_type": "bbox", "iou_thresholds": [0.5]},
    }
    metrics = {
        key: MeanAveragePrecision(
            box_format="xyxy",
            iou_type=cfg["iou_type"],
            iou_thresholds=cfg["iou_thresholds"],
            class_metrics=True,
        ).to("cpu")
        for key, cfg in metas.items()
    }

    coco_gt = COCO(coco_annotation)

    # Visual params (using original fixed class colors and names)
    class_colors = {
        1: (178, 129, 241),
        2: (111, 205, 151),
        3: (188, 114, 0),
        4: (54, 49, 173),
        5: (104, 34, 237),
        6: (30, 145, 246),
        7: (153, 62, 98),
        8: (85, 154, 0),
        9: (220, 196, 71),
        10: (115, 229, 223),
        11: (0, 128, 128),
        12: (0, 22, 103),
    }
    class_names = {
        1: "Cannula",
        2: "Cap-Cystotome",
        3: "Cap-Forceps",
        4: "Cornea",
        5: "Forceps",
        6: "I-A-Handpiece",
        7: "Lens-Injector",
        8: "Phaco-Handpiece",
        9: "Primary-Knife",
        10: "Pupil",
        11: "Second-Instrument",
        12: "Secondary-Knife",
    }

    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Eval"):
            imgs = [img.to(device) for img in images]
            outputs = model(imgs)

            for out, tgt in zip(outputs, targets):
                img_id = tgt["image_id"].item()
                img_info = coco_gt.loadImgs(img_id)[0]

                # Drop background predictions, then shift labels to 0-based
                # for the metric (ground truth is shifted the same way).
                keep = out["labels"] != 0
                boxes = out["boxes"][keep].cpu()
                labels = out["labels"][keep].cpu()
                preds = {
                    "boxes": boxes,
                    "scores": out["scores"][keep].cpu(),
                    "labels": labels - 1,
                }
                # Binarize mask probabilities: [N, 1, H, W] -> bool [N, H, W]
                pm = out["masks"][keep].cpu().squeeze(1) > 0.5
                preds_mask = {**preds, "masks": pm}

                # GT
                gt = {
                    "boxes": tgt["boxes"].cpu(),
                    "labels": (tgt["labels"] - 1).cpu(),
                    "masks": tgt["masks"].cpu().bool(),
                }

                # Update metrics
                metrics["box_full"].update([preds], [gt])
                metrics["box_50"].update([preds], [gt])
                metrics["mask_full"].update([preds_mask], [gt])
                metrics["mask_50"].update([preds_mask], [gt])

                # Save a visualization; an unreadable image skips only the
                # drawing, the metrics above are already updated.
                img_path = os.path.join(data_root, img_info["file_name"])
                img = cv2.imread(img_path)
                if img is None:
                    print(f"Could not read {img_path}; skipping visualization.")
                    continue
                vis = img.copy()
                for box, label, mask in zip(boxes, labels, pm):
                    cid = int(label)
                    _overlay_detection(
                        vis,
                        box.tolist(),
                        mask.numpy(),
                        class_colors[cid],
                        class_names[cid],
                    )
                cv2.imwrite(os.path.join(output_folder, img_info["file_name"]), vis)

    # Compute and print
    results = {k: m.compute() for k, m in metrics.items()}
    per_class = {k: _per_class_scores(r) for k, r in results.items()}

    print("\n====== Per-class mAP ======")
    for cid in coco_gt.getCatIds():
        # Skip category ids that have no entry in class_names.
        if cid not in class_names:
            continue
        idx = cid - 1  # metric labels were shifted to 0-based above
        name = class_names[cid]
        # Classes the metric never saw are reported as nan.
        row = {k: table.get(idx, float("nan")) for k, table in per_class.items()}
        print(
            f"{name}: Box@50 {row['box_50']:.4f}, Box@full {row['box_full']:.4f} | "
            f"Mask@50 {row['mask_50']:.4f}, Mask@full {row['mask_full']:.4f}"
        )

    print("\n====== Overall mAP ======")
    for k in ["box_50", "box_full", "mask_50", "mask_full"]:
        print(f"{k}: {results[k]['map'].item():.4f}")


# -------------------- Execute Evaluation --------------------
# Evaluate the test split; visualizations are written under output_folder
output_folder = "output"
evaluate_model(
    model=model,
    data_loader=test_loader,
    coco_annotation=ann_file,
    data_root=data_root,
    output_folder=output_folder,
)