## Testing work: Mask R-CNN instance segmentation on FoodSeg103

#### Dataset

In [None]:
import os
import json
import numpy as np
from PIL import Image, ImageDraw
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as F

class FoodSegJSONDataset(Dataset):
    """Instance-segmentation dataset of ``<root>/img/*.jpg`` images with ``<root>/ann/*.json`` polygons.

    Each annotation file shares its image's stem (``foo.jpg`` -> ``foo.json``) and holds an
    ``"objects"`` list whose entries carry a ``"category_id"`` and a ``"segmentation"`` list of
    polygons. Items are ``(image, target)`` pairs: ``image`` is a tensor (``F.to_tensor`` output
    unless a transform already produced a tensor) and ``target`` is a dict of ``boxes``,
    ``labels``, ``masks``, ``area``, ``iscrowd`` and ``image_id`` tensors.
    """

    def __init__(self, root_dir, transforms=None):
        """
        Args:
            root_dir: directory containing ``img/`` and ``ann/`` subdirectories.
            transforms: optional callable ``(PIL image, target) -> (image, target)`` applied
                before tensor conversion.
        """
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, "img")
        self.ann_dir = os.path.join(root_dir, "ann")
        self.transforms = transforms

        # Sorted so indices (and therefore image_id values) are stable across runs.
        self.img_files = sorted(
            f for f in os.listdir(self.img_dir) if f.endswith('.jpg')
        )

    def load_annotation(self, ann_path, image_size):
        """Rasterise the polygons in ``ann_path`` into per-object masks, boxes and labels.

        Args:
            ann_path: path of the JSON annotation file.
            image_size: ``(width, height)`` of the image, in PIL ``Image.size`` order.

        Returns:
            A target dict, or ``None`` when no object yields a usable (non-degenerate) mask.
            Boxes are ``[xmin, ymin, xmax, ymax]`` inclusive pixel indices of the mask.
        """
        with open(ann_path, encoding="utf-8") as f:
            data = json.load(f)

        masks = []
        boxes = []
        labels = []

        for obj in data['objects']:
            # All polygons of one object are drawn into the same binary mask.
            # NOTE(review): each poly is passed straight to ImageDraw.polygon, so it must be a
            # flat [x0, y0, x1, y1, ...] list or a list of (x, y) pairs — confirm against the data.
            mask = Image.new("L", image_size, 0)
            draw = ImageDraw.Draw(mask)
            for poly in obj['segmentation']:
                draw.polygon(poly, outline=1, fill=1)

            mask_np = np.array(mask, dtype=np.uint8)
            ys, xs = np.nonzero(mask_np)
            if xs.size == 0:
                continue  # the polygons rasterised to nothing

            xmin, xmax = int(xs.min()), int(xs.max())
            ymin, ymax = int(ys.min()), int(ys.max())
            # Zero-width/height boxes are rejected by torchvision detection models in training.
            if xmax <= xmin or ymax <= ymin:
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            masks.append(mask_np)
            labels.append(obj['category_id'])

        if not masks:
            return None

        masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "area": masks.sum(dim=(1, 2)).float(),  # pixel count per instance
            "iscrowd": torch.zeros((len(labels),), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``.

        If that sample has no usable instance, the following samples (wrapping around) are tried
        instead. ``image_id`` records the index actually returned.

        Raises:
            IndexError: if ``idx`` is out of range (keeps plain iteration over the dataset finite).
            RuntimeError: if no sample in the dataset has a usable annotation.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Iterative (not recursive) skip, bounded by the dataset size, so a dataset without
        # any usable annotation fails with a clear error instead of RecursionError.
        for offset in range(n):
            cur = (idx + offset) % n
            img_name = self.img_files[cur]
            img_path = os.path.join(self.img_dir, img_name)
            ann_path = os.path.join(self.ann_dir, os.path.splitext(img_name)[0] + '.json')

            with Image.open(img_path) as raw:
                img = raw.convert("RGB")

            target = self.load_annotation(ann_path, img.size)
            if target is not None:
                break
        else:
            raise RuntimeError(f"no sample in {self.root_dir!r} has a usable annotation")

        target["image_id"] = torch.tensor([cur])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        # A transform may already have produced a tensor; F.to_tensor rejects tensors.
        if not isinstance(img, torch.Tensor):
            img = F.to_tensor(img)
        return img, target

    def __len__(self):
        """Number of ``.jpg`` images found in ``img/``."""
        return len(self.img_files)


#### Engine

In [None]:
from tqdm import tqdm


def train_one_epoch(model, optimiser, data_loader, device, epoch):
    """Run one training epoch and report the mean total loss.

    Args:
        model: detection model that returns a dict of loss tensors when called as
            ``model(images, targets)`` in train mode.
        optimiser: optimiser over the model's trainable parameters.
        data_loader: iterable of ``(images, targets)`` batches (sequences of image tensors and
            target dicts of tensors).
        device: device the batches are moved to.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        The mean total loss over the epoch's batches, or ``None`` if the loader was empty.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0

    progress = tqdm(data_loader, desc=f"Epoch {epoch}")
    for images, targets in progress:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimiser.zero_grad()
        losses.backward()
        optimiser.step()

        loss_value = losses.item()
        total_loss += loss_value
        num_batches += 1
        progress.set_postfix(loss=f"{loss_value:.4f}")

    # Guard the empty-loader case: previously `losses` was unbound here.
    if num_batches == 0:
        print("Loss: n/a (no batches)")
        return None

    # Report the epoch mean rather than just the final batch's loss.
    mean_loss = total_loss / num_batches
    print(f"Loss: {mean_loss:.4f}")
    return mean_loss

#### Models

In [None]:
import torchvision

from torchvision.models.detection import (
    maskrcnn_resnet50_fpn_v2, 
    MaskRCNN_ResNet50_FPN_V2_Weights,
    faster_rcnn,
    mask_rcnn)


def get_model(num_classes, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN v2) with heads sized for ``num_classes``.

    Args:
        num_classes: number of output classes, including the background class.
        hidden_layer: channel width of the new mask predictor's intermediate layer.

    Returns:
        The model with freshly initialised box and mask predictor heads.
    """
    model = maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1)

    # Replace the box classification/regression head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = faster_rcnn.FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head; its first layer is `conv5_mask` (was misspelled `con5_mask`).
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = mask_rcnn.MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model

#### Utils

In [None]:
import torchvision.transforms as T


def get_transform(train=True, flip_prob=0.5):
    """Return a paired ``(image, target) -> (image, target)`` augmentation callable.

    The dataset calls its transform as ``transforms(img, target)``, so augmentations must
    update the target together with the image. Tensor conversion is not done here; the
    dataset converts the image afterwards.

    Args:
        train: when True, apply a random horizontal flip; otherwise return the inputs unchanged.
        flip_prob: probability of flipping when ``train`` is True.

    Returns:
        A callable taking a PIL image or a ``(C, H, W)`` tensor plus a target dict. Only the
        ``masks`` and ``boxes`` entries are flipped, if present. Boxes are assumed to be
        inclusive pixel indices (as the dataset produces), so ``x -> width - 1 - x``.
    """
    import torch

    def horizontal_flip(image, target):
        if torch.rand(1).item() >= flip_prob:
            return image, target

        if isinstance(image, torch.Tensor):
            width = image.shape[-1]
            image = image.flip(-1)
        else:
            from PIL import ImageOps
            width = image.width
            image = ImageOps.mirror(image)

        target = dict(target)  # do not mutate the caller's dict
        if "masks" in target:
            target["masks"] = target["masks"].flip(-1)
        if "boxes" in target:
            boxes = target["boxes"]
            flipped = boxes.clone()
            flipped[:, 0] = (width - 1) - boxes[:, 2]
            flipped[:, 2] = (width - 1) - boxes[:, 0]
            target["boxes"] = flipped
        return image, target

    steps = [horizontal_flip] if train else []

    def apply(image, target):
        for step in steps:
            image, target = step(image, target)
        return image, target

    return apply


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import random

def show_image_with_masks(img, pred, categories=None, score_thresh=0.5):
    """Plot an image with the contours, boxes and labels of confident detections.

    Args:
        img: image tensor, assumed ``(C, H, W)`` with values in [0, 1] (``F.to_tensor`` output).
        pred: a single detection-model prediction dict with ``masks`` ``(N, 1, H, W)`` soft
            masks, ``boxes`` ``(N, 4)`` xyxy, ``labels`` ``(N,)`` and ``scores`` ``(N,)``.
        categories: optional mapping from label id to display name.
        score_thresh: detections scoring below this are not drawn.
    """
    image = img.detach().cpu().permute(1, 2, 0).numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    ax = plt.gca()

    masks = pred["masks"].detach().cpu()
    boxes = pred["boxes"].detach().cpu()
    labels = pred["labels"].detach().cpu()
    scores = pred["scores"].detach().cpu()

    for mask, box, label, score in zip(masks, boxes, labels, scores):
        score = score.item()
        if score < score_thresh:
            continue

        color = np.random.rand(3)

        # Contour the soft mask at probability 0.5. The previous code scaled to 0..255 bytes
        # first, which made the 0.5 level an effective threshold of ~1/255.
        ax.contour(mask[0].numpy(), levels=[0.5], colors=[color])

        x1, y1, x2, y2 = box.numpy()
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, color=color, linewidth=2))

        label_id = label.item()
        label_name = categories[label_id] if categories and label_id in categories else str(label_id)
        ax.text(x1, y1, f"{label_name}:{score:.2f}", color=color, fontsize=12,
                bbox=dict(facecolor='white', edgecolor=color, boxstyle='round,pad=0.2'))

    plt.axis("off")
    plt.tight_layout()
    plt.show()




#### Main.py

In [None]:
import torch
import json
import os

from torch.utils.data import DataLoader

from dataset.foodseg_json_dataset import FoodSegJSONDataset
from models.mask_rcnn import get_model
from utils.transforms import get_transform
from engine.train import train_one_epoch
from utils.visualise import show_image_with_masks


def collate_fn(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Samples are kept as-is rather than stacked, since detection samples can differ in
    image size and number of instances.
    """
    columns = zip(*batch)
    return tuple(column for column in columns)


def load_categories(meta_path):
    """Read ``meta_path`` and map each class ``id`` to its ``title``.

    Args:
        meta_path: JSON file with a ``"classes"`` list of objects carrying ``id`` and ``title``.

    Returns:
        Dict ``{id: title}``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    return {cat['id']: cat['title'] for cat in meta['classes']}


def main(num_epochs=3):
    """Fine-tune Mask R-CNN on FoodSeg103, checkpoint each epoch, and plot one test prediction.

    Data is read from ``data/foodseg103/{train,test}`` and ``data/foodseg103/meta.json``.
    A checkpoint is written to ``outputs/models/model_epoch_<epoch>.pth`` after every epoch.

    Args:
        num_epochs: number of training epochs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    train_data_path = "data/foodseg103/train"
    test_data_path = "data/foodseg103/test"
    meta_path = "data/foodseg103/meta.json"

    categories = load_categories(meta_path)
    print(categories)

    # Dataset and DataLoader
    dataset = FoodSegJSONDataset(train_data_path, transforms=get_transform(train=True))
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    test_dataset = FoodSegJSONDataset(test_data_path, transforms=get_transform(train=False))

    # Category ids are used directly as class indices, so the heads need max id + 1 outputs.
    # NOTE(review): assumes no real category uses id 0, which is reserved for background.
    num_classes = max(categories) + 1
    print(f"num classes: {num_classes}")
    model = get_model(num_classes).to(device)

    # Optimiser
    params = [p for p in model.parameters() if p.requires_grad]
    optimiser = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # Training loop; the checkpoint directory only needs creating once.
    os.makedirs("outputs/models", exist_ok=True)
    for epoch in range(num_epochs):
        train_one_epoch(model, optimiser, data_loader, device, epoch)
        torch.save(model.state_dict(), f"outputs/models/model_epoch_{epoch}.pth")

    # Visualise a prediction on the test set.
    model.eval()
    with torch.no_grad():
        img, _ = test_dataset[0]
        pred = model([img.to(device)])[0]
    pred = {k: v.cpu() for k, v in pred.items()}  # plotting works on CPU tensors

    show_image_with_masks(img, pred, categories)