In [1]:
#Let's import all the necessary libraries
import os

import numpy as np
import torch
import torchvision
from torchvision import transforms as T
from PIL import Image
import requests
import cv2
from torch.utils.data import Dataset, random_split, DataLoader
from pycocotools.coco import COCO

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # For high res images

In [33]:
# Model initiation: Mask R-CNN (ResNet-50 FPN) built without detection weights
# (weights=None), with the box classifier replaced to output NUM_CLASSES classes.
NUM_CLASSES = 5  # classes produced by the box predictor

model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, num_classes=NUM_CLASSES
)
# NOTE(review): only the box predictor is resized; model.roi_heads.mask_predictor keeps
# the constructor's default class count (91 in current torchvision — confirm). Resizing it
# too would make the existing "trained_maskrcnn.pth" checkpoint fail to load, so it is
# left as-is here; revisit when retraining from scratch.

In [35]:
# Load the custom-trained weights written by the training cell.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the model is moved to `device` later, before training/evaluation.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load("trained_maskrcnn.pth", map_location="cpu"))


<All keys matched successfully>

In [None]:
model.train(mode=True)  # enable training-mode behaviour (model returns losses when given targets)

In [None]:
class COCODataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs from a COCO-format annotation file.

    The target dict holds ``boxes`` as an (N, 4) float32 tensor in
    [x_min, y_min, x_max, y_max] order, ``labels`` (N,) int64 taken from
    ``category_id``, ``masks`` (N, H, W) uint8 0/1, plus ``image_id``, ``area``
    and ``iscrowd``. Images without usable annotations still yield well-formed
    empty tensors: boxes (0, 4), masks (0, H, W).

    Args:
        root: directory containing the image files named in the annotation file.
        annFile: path to the COCO JSON annotation file.
        transforms: optional callable applied to the PIL image only (not the target).
    """

    def __init__(self, root, annFile, transforms=None):
        self.root = root
        self.coco = COCO(annFile)  # Load COCO JSON annotations
        self.ids = list(self.coco.imgs.keys())  # Image IDs, indexed by position
        self.transforms = transforms

    def __len__(self):
        """Number of images in the annotation file (DataLoader needs this to iterate)."""
        return len(self.ids)

    def _ann_to_mask(self, ann, height, width):
        """Rasterize one annotation's segmentation into a (height, width) uint8 0/1 mask."""
        segmentation = ann["segmentation"]
        if isinstance(segmentation, dict):
            # RLE-encoded segmentation (COCO uses this for iscrowd=1 annotations);
            # iterating it as polygons would walk the dict keys, so let pycocotools decode it.
            return self.coco.annToMask(ann).astype(np.uint8)
        mask = np.zeros((height, width), dtype=np.uint8)
        for seg in segmentation:
            # Polygon as a flat [x1, y1, x2, y2, ...] list.
            poly = np.array(seg, dtype=np.int32).reshape(-1, 2)
            cv2.fillPoly(mask, [poly], 1)
        return mask

    def __getitem__(self, index):
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        # torchvision's detection models raise on boxes with zero width or height,
        # so drop degenerate annotations here (their masks/labels go with them).
        annotations = [
            ann for ann in self.coco.loadAnns(ann_ids)
            if ann["bbox"][2] > 0 and ann["bbox"][3] > 0
        ]
        image_info = self.coco.imgs[img_id]
        height, width = image_info["height"], image_info["width"]

        # Load the image; the context manager closes the underlying file handle
        # (convert() returns an independent, fully loaded copy).
        img_path = os.path.join(self.root, image_info["file_name"])
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        boxes, masks, labels = [], [], []
        for ann in annotations:
            # COCO bbox is [x_min, y_min, width, height]; the model wants corner format.
            x_min, y_min, w, h = ann["bbox"]
            boxes.append([x_min, y_min, x_min + w, y_min + h])
            masks.append(self._ann_to_mask(ann, height, width))
            labels.append(ann["category_id"])

        # reshape / explicit zeros keep the (0, 4) and (0, H, W) shapes for empty images;
        # np.stack avoids the slow tensor-from-list-of-ndarrays path.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if masks:
            masks = torch.from_numpy(np.stack(masks))
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": torch.tensor([ann["area"] for ann in annotations], dtype=torch.float32),
            "iscrowd": torch.tensor([ann["iscrowd"] for ann in annotations], dtype=torch.int64),
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

In [1]:
from torchvision import transforms as T

def get_transform():
    """Build the image preprocessing pipeline: a single ToTensor step (PIL image -> tensor)."""
    steps = [T.ToTensor()]
    return T.Compose(steps)

In [None]:
# Define dataset and DataLoader
# The training images and their COCO-format annotation file share one folder.
dataset = COCODataset(
    root="cocoseg/train",
    annFile="cocoseg/train/_annotations.coco.json",
    transforms=get_transform(),
)

def collate_fn(batch):
    """Transpose a list of (image, target) samples into an (images, targets) pair of tuples.

    Detection images may differ in size, so they are kept as a tuple rather than
    being stacked into one tensor by the default collate.
    """
    return tuple(tuple(column) for column in zip(*batch))

# Shuffled mini-batches of 16; collate_fn keeps images/targets as tuples.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)


In [None]:
import torch.optim as optim

# Pick the device and move the model there *before* building the optimizer,
# so the optimizer is constructed over the parameters in their final location.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and learning rate: SGD with momentum and weight decay,
# over the parameters that are actually trainable (requires_grad=True).
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)


In [None]:
num_epochs = 100  # Adjust based on dataset size

# The evaluation cell puts the model in eval mode, where Mask R-CNN returns
# detections instead of a loss dict; switch back so this cell is safe to re-run.
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    count = 0  # images dropped this epoch because they have no annotated objects

    for images, targets in dataloader:
        # Drop only the images without objects (empty bounding boxes) instead of
        # discarding the whole batch; the remaining images still train.
        kept = [(img, t) for img, t in zip(images, targets) if len(t["boxes"]) > 0]
        count += len(images) - len(kept)
        if not kept:
            continue  # every image in this batch was empty

        images = [img.to(device) for img, _ in kept]
        # Move targets to the device. Boxes are cast to float32 with as_tensor,
        # which converts tensors without the copy-construct warning of torch.tensor(tensor).
        targets = [
            {
                k: (torch.as_tensor(v, dtype=torch.float32, device=device) if k == "boxes"
                    else v.to(device) if isinstance(v, torch.Tensor) else v)
                for k, v in t.items()
            }
            for _, t in kept
        ]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())  # total of the per-component losses the model returns
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print("images skipped (no objects):", count)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")


# Save the trained model
torch.save(model.state_dict(), "trained_maskrcnn.pth")


In [None]:
from torchvision.ops.boxes import box_iou
from torchvision.datasets import CocoDetection

def compute_iou(pred_boxes, true_boxes):
    """Pairwise IoU between predicted and ground-truth boxes.

    Returns the matrix produced by ``box_iou(pred_boxes, true_boxes)``, or an
    empty 1-D tensor when either side has no boxes.
    """
    has_both = len(pred_boxes) > 0 and len(true_boxes) > 0
    return box_iou(pred_boxes, true_boxes) if has_both else torch.tensor([])

def _greedy_true_positives(iou_matrix, iou_threshold):
    """Count one-to-one matches between predictions (rows, best score first) and ground truths (columns).

    Each prediction, in row order, claims its highest-IoU ground-truth box that is
    still unclaimed, provided that IoU exceeds ``iou_threshold``. This stops several
    overlapping predictions from all counting as hits on the same object.
    """
    if iou_matrix.numel() == 0:
        return 0
    iou = iou_matrix.clone()
    true_positives = 0
    for row in range(iou.shape[0]):
        best_iou, best_gt = iou[row].max(dim=0)
        if best_iou.item() > iou_threshold:
            true_positives += 1
            iou[:, int(best_gt)] = -1.0  # mark this ground-truth box as matched
    return true_positives


def calculate_precision_recall(model, dataloader, device, iou_threshold=0.5, conf_threshold=0.5,
                               class_aware=False):
    """Per-image precision and recall of box predictions, averaged over the dataloader.

    Predictions with score > ``conf_threshold`` are matched one-to-one to the
    target boxes (greedily, highest score first); a match needs IoU > ``iou_threshold``.
    Per image: precision = TP / max(#predictions, 1), recall = TP / max(#targets, 1).

    Args:
        model: detection model; put in eval mode and called as ``model(images)``,
            returning one dict per image with "boxes", "scores" (and "labels" when
            ``class_aware``).
        dataloader: iterable of ``(images, targets)``; each target has "boxes"
            (and "labels" when ``class_aware``).
        device: device the images are moved to; must match the model's device.
        iou_threshold: minimum (exclusive) IoU for a prediction to count as a hit.
        conf_threshold: minimum (exclusive) score for a prediction to be considered.
        class_aware: if True, a prediction may only match a box with the same label.

    Returns:
        ``(avg_precision, avg_recall)``; ``(nan, nan)`` if the dataloader yields no images.
    """
    model.eval()
    all_precisions, all_recalls = [], []

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for output, target in zip(outputs, targets):
                keep = output["scores"] > conf_threshold
                # Highest-scoring predictions get first pick in the greedy matching.
                order = output["scores"][keep].argsort(descending=True)
                pred_boxes = output["boxes"][keep][order]
                # as_tensor avoids re-copying tensors; reshape keeps empty targets as (0, 4).
                true_boxes = torch.as_tensor(
                    target["boxes"], dtype=torch.float32, device=device
                ).reshape(-1, 4)

                iou_matrix = compute_iou(pred_boxes, true_boxes)
                if class_aware and iou_matrix.numel() > 0:
                    pred_labels = output["labels"][keep][order]
                    true_labels = torch.as_tensor(target["labels"], device=pred_labels.device)
                    mismatch = pred_labels[:, None] != true_labels[None, :]
                    iou_matrix = iou_matrix.masked_fill(mismatch.to(iou_matrix.device), -1.0)

                true_positives = _greedy_true_positives(iou_matrix, iou_threshold)

                precision = true_positives / max(len(pred_boxes), 1)  # Avoid division by zero
                recall = true_positives / max(len(true_boxes), 1)  # Avoid division by zero

                all_precisions.append(precision)
                all_recalls.append(recall)

    if not all_precisions:
        return float("nan"), float("nan")

    avg_precision = sum(all_precisions) / len(all_precisions)
    avg_recall = sum(all_recalls) / len(all_recalls)

    return avg_precision, avg_recall

# Evaluate the model on the same device it was moved to for training
# (passing "cpu" here would mismatch a model living on CUDA).
# NOTE(review): `dataloader` is the training set, so these numbers measure fit to
# the training data, not generalization — evaluate on a held-out split for that.
precision, recall = calculate_precision_recall(model, dataloader, device=device)
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
