In [None]:
import torch
import torchvision
from torch import optim
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from matplotlib import patches
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import OxfordIIITPet

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
class PetDataset(torch.utils.data.Dataset):
    """Oxford-IIIT Pet wrapper that yields single-object detection targets.

    Each item is ``(image, target)`` where ``image`` is the tensor produced by
    ``F.to_tensor`` and ``target`` is a dict with:

    * ``"boxes"``  - float32 tensor of shape (1, 4): ``[xmin, ymin, xmax, ymax]``
      taken from the extent of the foreground mask,
    * ``"labels"`` - int64 tensor ``[1]`` (a single foreground class),
    * ``"masks"``  - float32 tensor of shape (1, H, W) with 1.0 on foreground.

    Foreground is every pixel whose segmentation value equals 1 (presumably
    the "pet" class of the dataset's trimap - confirm against the dataset
    documentation).

    Samples whose mask is empty, or whose box would have zero width or height,
    are skipped: the next usable sample (wrapping around) is returned instead.
    """

    def __init__(self, root, split="trainval"):
        """Open the requested split under ``root``, downloading it if needed."""
        self.dataset = OxfordIIITPet(
            root=root,
            split=split,
            target_types="segmentation",
            download=True
        )

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx`` or the next usable sample.

        Raises:
            IndexError: propagated from the wrapped dataset for an invalid
                ``idx`` (this also keeps plain ``for`` iteration terminating).
            RuntimeError: if no sample in the dataset is usable.
        """
        total = len(self)
        candidate = idx
        # Bounded loop instead of recursion: at most one pass over the data,
        # so an all-empty dataset fails loudly rather than overflowing the stack.
        for _ in range(max(total, 1)):
            sample = self._build_sample(candidate)
            if sample is not None:
                return sample
            candidate = (candidate + 1) % total
        raise RuntimeError(
            "PetDataset contains no sample with a usable foreground mask"
        )

    def _build_sample(self, idx):
        """Build ``(image, target)`` for ``idx``; return None if it is unusable."""
        img, mask = self.dataset[idx]

        # assumes the segmentation mask is a 2-D (H, W) array - TODO confirm
        foreground = np.array(mask) == 1
        rows, cols = np.nonzero(foreground)
        if rows.size == 0:
            return None

        xmin, xmax = float(cols.min()), float(cols.max())
        ymin, ymax = float(rows.min()), float(rows.max())
        # A mask confined to one row or column gives a zero-area box, which
        # torchvision detection models reject during training.
        if xmax <= xmin or ymax <= ymin:
            return None

        target = {
            "boxes": torch.tensor([[xmin, ymin, xmax, ymax]], dtype=torch.float32),
            "labels": torch.ones((1,), dtype=torch.int64),
            "masks": torch.as_tensor(foreground, dtype=torch.float32).unsqueeze(0),
        }
        return F.to_tensor(img), target

    def __len__(self):
        """Number of samples in the wrapped split (including unusable ones)."""
        return len(self.dataset)

In [None]:
def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``;
    the fields are kept as tuples and are not stacked into tensors.
    """
    per_field = zip(*batch)
    return tuple(per_field)


from torch.utils.data import Subset

# Only the first 150 training and 50 test samples are used, which keeps the
# experiments below short.
train_dataset = Subset(PetDataset("./data", split="trainval"), range(150))
val_dataset = Subset(PetDataset("./data", split="test"), range(50))

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
    # Pinned memory only helps host-to-GPU copies; requesting it without CUDA
    # just triggers a DataLoader warning.
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

In [None]:
def train_one_epoch(model, loader, optimizer, device, epoch, epochs):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: callable as ``model(images, targets)`` returning a dict of
            loss tensors.
        loader: iterable of ``(images, targets)`` batches; must support ``len``.
        optimizer: optimizer stepping the model's parameters.
        device: device every image and target tensor is moved to.
        epoch: zero-based index of the current epoch (progress-bar label only).
        epochs: total number of epochs (progress-bar label only).

    Returns:
        Sum of the per-batch total losses divided by ``len(loader)``.
    """
    model.train()
    progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
    loss_total = 0

    for batch_images, batch_targets in progress:
        batch_images = [image.to(device) for image in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]

        # The model returns several named loss terms; optimise their sum.
        losses = model(batch_images, batch_targets)
        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_total += loss_value
        progress.set_postfix(loss=loss_value)

    return loss_total / len(loader)

In [None]:
def train_model(model, train_loader, val_loader, optimizer, device, epochs):
    """Train for ``epochs`` epochs, tracking validation loss and mAP per epoch.

    After each training epoch the mean validation loss is computed over
    ``val_loader`` and ``evaluate_map`` is run on the same loader.

    Args:
        model: detection model; ``model(images, targets)`` must return a dict
            of loss tensors while in training mode.
        train_loader: batches used for optimisation.
        val_loader: batches used for validation loss and mAP; must be non-empty.
        optimizer: optimizer stepping the model's parameters.
        device: device tensors are moved to.
        epochs: number of epochs to run.

    Returns:
        Dict with the lists ``"train_loss"``, ``"val_loss"`` and ``"map"``,
        each holding one float per epoch.
    """
    history = {
        "train_loss": [],
        "val_loss": [],
        "map": []
    }

    for epoch in range(epochs):
        train_loss = train_one_epoch(
            model, train_loader, optimizer, device, epoch, epochs
        )

        # Deliberately NOT model.eval(): the model is called with targets and
        # must return a loss dict, which torchvision detection models only do
        # in training mode. torch.no_grad() disables gradient tracking for
        # this pass, and the optimizer is never stepped here.
        model.train()
        val_loss_total = 0.0

        with torch.no_grad():
            for images, targets in val_loader:
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                # Accumulate plain floats, matching train_one_epoch.
                val_loss_total += sum(loss_dict.values()).item()

        val_loss = val_loss_total / len(val_loader)

        # Named map_score so the builtin `map` is not shadowed.
        map_score = evaluate_map(model, val_loader, device)["map"].item()

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["map"].append(map_score)

        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Map: {map_score:.4f}"
        )

    return history

In [None]:
def plot_curves(results):
    """Plot training loss, validation loss and mAP curves side by side.

    Args:
        results: mapping ``run name -> history`` where each history is a dict
            with the per-epoch lists ``"train_loss"``, ``"val_loss"`` and
            ``"map"``. Every run gets one line in each of the three panels.
    """
    # (history key, legend suffix, y-axis label, panel title)
    panels = (
        ("train_loss", "Train Loss", "Loss", "Training loss"),
        ("val_loss", "Val Loss", "Loss", "Validation loss"),
        ("map", "Map", "Map", "Map"),
    )

    plt.figure(figsize=(15,4))

    for position, (key, suffix, y_label, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        for run_name, run_history in results.items():
            plt.plot(run_history[key], label=f"{run_name} {suffix}")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(panel_title)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def _draw_box(ax, box, color):
    """Add an unfilled rectangle for ``box = (x1, y1, x2, y2)`` to ``ax``."""
    # Plain floats so matplotlib never receives tensor scalars.
    x1, y1, x2, y2 = (float(v) for v in box)
    ax.add_patch(
        patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                          linewidth=2, edgecolor=color, facecolor='none')
    )


def show_detection(model, dataset, device, n=6, score_thr=0.5):
    """Show random samples with ground-truth (green) and predicted (red) boxes.

    Args:
        model: detection model; in eval mode ``model([image])`` must return a
            list of dicts with ``"boxes"`` and ``"scores"``.
        dataset: indexable dataset of ``(image, target)`` pairs where
            ``target["boxes"]`` holds the ground-truth boxes.
        device: device the image is moved to for inference.
        n: number of samples to draw; clamped to ``len(dataset)``.
        score_thr: predictions scoring below this value are not drawn.

    The model is switched to eval mode. Nothing is plotted when no sample can
    be shown (``n <= 0`` or an empty dataset).
    """
    import random

    model.eval()

    # random.sample raises if asked for more items than exist, and a
    # zero-width figure is invalid, so clamp and bail out early.
    n_show = max(0, min(n, len(dataset)))
    if n_show == 0:
        return

    idxs = random.sample(range(len(dataset)), n_show)

    plt.figure(figsize=(4 * n_show, 4))

    for i, idx in enumerate(idxs):
        img, target = dataset[idx]
        with torch.no_grad():
            pred = model([img.to(device)])[0]

        # matplotlib wants channels last, not channels first
        img_np = img.permute(1, 2, 0).cpu().numpy()

        plt.subplot(1, n_show, i + 1)
        plt.imshow(img_np)
        ax = plt.gca()

        for box in target["boxes"]:
            _draw_box(ax, box, 'g')

        for box, score in zip(pred["boxes"], pred["scores"]):
            if score < score_thr:
                continue
            _draw_box(ax, box, 'r')

        ax.set_title("GT=Green\nPred=Red", fontsize=8)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def create_model(num_classes, device):
    """Build a Mask R-CNN whose box and mask heads predict ``num_classes`` classes.

    The network is created with its ``"DEFAULT"`` weights, then both the box
    predictor and the mask predictor are replaced by freshly constructed heads
    sized for ``num_classes``.

    Args:
        num_classes: number of output classes for the new heads.
        device: device the finished model is moved to.

    Returns:
        The model, moved to ``device``.
    """
    mask_hidden_channels = 256

    net = maskrcnn_resnet50_fpn(weights="DEFAULT")
    heads = net.roi_heads

    # New box head, sized from the input width of the existing classifier.
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        box_in, num_classes
    )

    # New mask head, sized from the input channels of the existing one.
    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(
        mask_in, mask_hidden_channels, num_classes
    )

    net = net.to(device)
    return net

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision

def evaluate_map(model, loader, device):
    """Compute bounding-box mean average precision of ``model`` over ``loader``.

    Args:
        model: detection model; in eval mode ``model(images)`` must return one
            dict per image containing at least ``"boxes"``, ``"scores"`` and
            ``"labels"``.
        loader: iterable of ``(images, targets)`` batches; every target needs
            ``"boxes"`` and ``"labels"``.
        device: device the images are moved to for inference.

    Returns:
        The result of ``MeanAveragePrecision.compute()``.

    Note:
        The model is left in eval mode.
    """
    # Only these fields are forwarded to the bbox metric. Skipping the rest
    # avoids copying every predicted mask to the CPU for nothing.
    pred_keys = ("boxes", "scores", "labels")
    target_keys = ("boxes", "labels")

    model.eval()
    metric = MeanAveragePrecision(iou_type="bbox")

    with torch.no_grad():
        for images, targets in loader:
            images = [img.to(device) for img in images]
            preds = model(images)

            preds = [{k: p[k].cpu() for k in pred_keys if k in p} for p in preds]
            targets = [{k: t[k].cpu() for k in target_keys if k in t} for t in targets]

            metric.update(preds, targets)

    return metric.compute()

In [None]:
# Experiment: effect of the number of training epochs at a fixed learning rate.
results = {}
lr = 1e-3

for epochs in [5, 10, 15]:

    print(f"Training with epochs={epochs}")

    # Fresh model and optimizer for every run, so no run continues from
    # another run's weights.
    model = create_model(num_classes=2, device=device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # train_model already evaluates mAP on val_loader after every epoch and
    # stores it in history["map"], so no extra evaluation pass is run here.
    history = train_model(model, train_loader, val_loader, optimizer, device, epochs)

    results[f"epochs={epochs}"] = history

plot_curves(results)
# `model` is the last one trained in the loop above (epochs=15).
show_detection(model, val_dataset, device, n=6, score_thr=0.5)

In [None]:
# Experiment: effect of the learning rate at a fixed budget of 10 epochs.
lrs = [1e-4, 5e-4, 1e-3]
results = {}

for lr in lrs:
    print(f"Training with lr={lr}")

    # Fresh model and optimizer for every learning rate.
    model = create_model(num_classes=2, device=device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    history = train_model(
        model, train_loader, val_loader, optimizer, device, epochs=10
    )
    results[f"lr={lr}"] = history

plot_curves(results)
# `model` is the last one trained in the loop above (the final entry of lrs).
show_detection(model, val_dataset, device, n=6, score_thr=0.5)