In [None]:
import sys
import os
# Make the parent of the current working directory importable.
# The membership check keeps sys.path free of duplicate entries when this
# cell is executed more than once in the same kernel.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
# Zelle 2: Dataloader für Instance Segmentation

import json
from pathlib import Path
from PIL import Image, ImageDraw
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as F

class CocoInstSegDataset(Dataset):
    """COCO-format instance-segmentation dataset producing Mask R-CNN style targets.

    ``__getitem__`` returns ``(image, target)``: the image as a CxHxW float
    tensor in [0, 1] and a dict with ``boxes`` (xyxy, float32), ``labels``
    (int64, 1..K), ``masks`` (uint8, one HxW mask per instance), ``image_id``,
    ``area`` and ``iscrowd``.

    Only polygon segmentations (a list of flat ``[x1, y1, x2, y2, ...]``
    lists) are rasterised; other encodings such as RLE dicts are not handled.
    """

    def __init__(self, ann_file, img_root, resize_short=640):
        """Load the annotation file and build the lookup tables.

        Args:
            ann_file: Path to a COCO JSON file with the keys ``images``,
                ``annotations`` and ``categories``.
            img_root: Directory the ``file_name`` entries are relative to.
            resize_short: Target length (pixels) of the shorter image side;
                ``None`` disables resizing.
        """
        # Context manager closes the handle deterministically; explicit UTF-8
        # keeps parsing independent of the platform's locale encoding.
        with open(ann_file, "r", encoding="utf-8") as fh:
            self.ann = json.load(fh)
        self.img_root = Path(img_root)
        self.images = self.ann["images"]
        self.anns   = self.ann["annotations"]
        self.cats   = self.ann["categories"]
        self.resize_short = resize_short

        # index: image_id -> list[ann]
        self.by_img = {}
        for a in self.anns:
            self.by_img.setdefault(a["image_id"], []).append(a)

        # category_id -> 1..K (foreground labels), ordered by ascending category id
        cat_ids_sorted = sorted(c["id"] for c in self.cats)
        self.cid_to_lbl = {cid: i + 1 for i, cid in enumerate(cat_ids_sorted)}
        self.num_classes_fg = len(self.cid_to_lbl)

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.images)

    def _resize(self, img, boxes, segs):
        """Scale image, boxes and polygons so the shorter side equals ``resize_short``.

        Args:
            img: PIL image.
            boxes: List of ``[x1, y1, x2, y2]`` boxes.
            segs: One entry per box; each entry is a list of flat polygons
                (or an empty value, which is passed through unchanged).

        Returns:
            ``(img, boxes, segs, sx, sy)``. The same factor is used for both
            axes, so ``sx == sy``; both are ``1.0`` when resizing is disabled.
        """
        w0, h0 = img.size
        if self.resize_short is None:
            return img, boxes, segs, 1.0, 1.0

        scale = self.resize_short / min(w0, h0)
        new_w, new_h = int(round(w0 * scale)), int(round(h0 * scale))
        img = img.resize((new_w, new_h), resample=Image.BILINEAR)

        # Scale boxes
        if boxes:
            boxes = [[coord * scale for coord in b[:4]] for b in boxes]

        # Scale segments
        new_segs = []
        for seg in segs:
            if not seg:
                new_segs.append(seg)
                continue
            scaled = []
            for s in seg:  # s is a flat list [x1, y1, x2, y2, ...]
                # Only complete (x, y) pairs are kept: a trailing unpaired
                # value in a malformed polygon is dropped instead of raising
                # IndexError.
                n_even = len(s) - (len(s) % 2)
                scaled.append([v * scale for v in s[:n_even]])
            new_segs.append(scaled)

        return img, boxes, new_segs, scale, scale

    def _polygons_to_mask(self, seg, h, w):
        """Rasterise polygons into a binary mask.

        Args:
            seg: List of polygons, each a flat list ``[x1, y1, x2, y2, ...]``.
            h: Mask height in pixels.
            w: Mask width in pixels.

        Returns:
            HxW uint8 array with 1 inside/on the polygons and 0 elsewhere.
            Polygons with fewer than three points are ignored.
        """
        mask_img = Image.new("L", (w, h), 0)
        draw = ImageDraw.Draw(mask_img)
        for poly in seg:
            # zip() pairs x with y and silently drops a trailing unpaired value.
            xy = list(zip(poly[0::2], poly[1::2]))
            if len(xy) >= 3:
                draw.polygon(xy, outline=1, fill=1)
        return np.array(mask_img, dtype=np.uint8)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target_dict)`` for the image at position ``idx``.

        Annotations whose bbox has non-positive width or height are skipped,
        because torchvision's detection models reject such boxes during
        training. ``area`` is rescaled together with boxes and masks so all
        target fields refer to the resized image.
        """
        meta = self.images[idx]
        img_path = self.img_root / meta["file_name"]
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        anns = self.by_img.get(meta["id"], [])

        # COCO bbox: [x, y, w, h] -> xyxy
        boxes = []
        labels = []
        segs = []
        areas = []
        iscrowd = []

        for a in anns:
            x, y, bw, bh = a["bbox"]
            if bw <= 0 or bh <= 0:
                # Degenerate box: would make the detection model raise.
                continue
            boxes.append([x, y, x + bw, y + bh])
            labels.append(self.cid_to_lbl[a["category_id"]])
            segs.append(a.get("segmentation", []))
            areas.append(a.get("area", bw * bh))
            iscrowd.append(a.get("iscrowd", 0))

        # Resizing
        img, boxes, segs, sx, sy = self._resize(img, boxes, segs)
        w, h = img.size

        # Areas are given in original-image pixels; bring them to the resized scale.
        areas = [ar * sx * sy for ar in areas]

        # Build masks (an all-zero mask stands in for a missing segmentation)
        masks = []
        for seg in segs:
            if not seg:
                masks.append(np.zeros((h, w), dtype=np.uint8))
            else:
                masks.append(self._polygons_to_mask(seg, h, w))

        # Explicitly shaped empty tensors keep images without annotations usable.
        boxes_t = torch.tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32)
        labels_t = torch.tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)
        # as_tensor avoids a second copy of the stacked uint8 masks.
        masks_t  = torch.as_tensor(np.stack(masks), dtype=torch.uint8) if masks else torch.zeros((0, h, w), dtype=torch.uint8)
        areas_t  = torch.tensor(areas, dtype=torch.float32) if areas else torch.zeros((0,), dtype=torch.float32)
        iscrowd_t = torch.tensor(iscrowd, dtype=torch.int64) if iscrowd else torch.zeros((0,), dtype=torch.int64)

        target = {
            "boxes": boxes_t,
            "labels": labels_t,
            "masks": masks_t,
            "image_id": torch.tensor([meta["id"]], dtype=torch.int64),
            "area": areas_t,
            "iscrowd": iscrowd_t,
        }

        img_t = F.to_tensor(img)  # [0..1], CxHxW
        return img_t, target


def collate_fn(batch):
    """Turn a batch of ``(image, target)`` pairs into two parallel lists.

    The samples are regrouped without stacking, so every image and target is
    kept as its own object.

    Returns:
        ``(images, targets)`` as two lists of equal length.
    """
    columns = tuple(zip(*batch))
    images, annotations = columns
    return [*images], [*annotations]


In [None]:
# Cell 3: datasets and data loaders

from torch.utils.data import DataLoader

ROOT = "../final_dataset"  # same as in your OD notebook

# Training and validation splits use the same short-side resize.
train_ds = CocoInstSegDataset(
    ann_file=f"{ROOT}/annotations/instances_train.json",
    img_root=f"{ROOT}/images/train",
    resize_short=640,
)
val_ds = CocoInstSegDataset(
    ann_file=f"{ROOT}/annotations/instances_val.json",
    img_root=f"{ROOT}/images/val",
    resize_short=640,
)

# Settings common to both loaders; only the shuffling differs between them.
_loader_kwargs = dict(
    batch_size=4,
    num_workers=4,
    collate_fn=collate_fn,
    pin_memory=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# Number of foreground classes derived from the training annotations' categories.
num_classes_fg = train_ds.num_classes_fg
print("Foreground-Klassen:", num_classes_fg)


In [None]:
# Cell 4: model (Mask R-CNN)

import torch
import matplotlib.pyplot as plt
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the pretrained model without moving it yet: the replacement heads
# below are created on the CPU, so one transfer at the end moves everything.
model = maskrcnn_resnet50_fpn_v2(weights="DEFAULT")

# Dataset labels run from 1..num_classes_fg; the extra class is the background.
num_classes = num_classes_fg + 1

# Replace the box classification head
in_feats = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_feats, num_classes)

# Replace the mask head
in_feats_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_feats_mask, hidden_layer, num_classes
)

# Single device transfer covering the backbone and both new heads.
model = model.to(device)


In [None]:
# Cell 6: training (instance segmentation)

from torch.optim import AdamW
from tqdm.auto import tqdm
import time

optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Mixed precision is switched on only when CUDA is available; with
# enabled=False both the scaler and autocast are pass-throughs.
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

EPOCHS = 40

loss_per_step = []   # one entry per optimizer step, across all epochs
loss_per_epoch = []  # mean step loss of each epoch

# Create the checkpoint directory once instead of on every epoch.
os.makedirs("models", exist_ok=True)

for epoch in range(EPOCHS):
    model.train()
    total = 0.0
    n_steps = 0
    start = time.time()

    for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        images = [img.to(device) for img in images]
        # Move every tensor in each target dict to the device; leave other values as they are.
        targets = [
            {
                k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                for k, v in t.items()
            }
            for t in targets
        ]

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            # In train mode the model returns a dict of losses; optimise their sum.
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once: each .item() call synchronises with the device.
        step_loss = loss.item()
        total += step_loss
        n_steps += 1
        loss_per_step.append(step_loss)

    epoch_loss = total / max(1, n_steps)
    loss_per_epoch.append(epoch_loss)
    dur = time.time() - start
    print(f"Epoch {epoch+1}/{EPOCHS} - train loss: {epoch_loss:.4f} ({dur:.1f}s)")

    # Checkpoint inside the loop so every epoch gets its own file and an
    # interrupted run still leaves usable weights. The file of the last epoch
    # has the same name as the single checkpoint written previously.
    # Note: this writes EPOCHS files; prune old ones if disk space is tight.
    torch.save(model.state_dict(), f"models/maskrcnn_epoch_{epoch}.pth")


In [None]:
# Cell 7: training curves

import matplotlib.pyplot as plt
import os

os.makedirs("metrics", exist_ok=True)

# (values, extra plot kwargs, x-axis label, title, output file) for each figure.
_curve_specs = [
    (loss_per_step, {}, "Training step", "Training loss per step", "metrics/loss_per_step.png"),
    (loss_per_epoch, {"marker": 'o'}, "Epoch", "Training loss per epoch", "metrics/loss_per_epoch.png"),
]

for _values, _plot_kwargs, _xlabel, _title, _out_path in _curve_specs:
    plt.figure()
    plt.plot(_values, **_plot_kwargs)
    plt.xlabel(_xlabel)
    plt.ylabel("Loss")
    plt.title(_title)
    # Save before show(): the figure is written to disk first, then displayed.
    plt.savefig(_out_path)
    plt.show()
