In [None]:
# Download the Cityscapes dataset. The wget cells below read the login session from cookies.txt, so that file must exist first.

In [None]:
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1

In [None]:
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3

In [None]:
!unzip leftImg8bit_trainvaltest.zip -d ./cityscapes

In [None]:
# Extract the fine annotations into ./cityscapes; "-o" overwrites existing files without prompting.
!unzip -o gtFine_trainvaltest.zip -d ./cityscapes

In [None]:
import os
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T

from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

from torchvision.tv_tensors import Image as TVImage, Mask, BoundingBoxes

# Cityscapes label table: (name, labelId, trainId, has_instances).
# trainId 255 marks labels that are ignored for training and evaluation.
# NOTE(review): 'polegroup' shares trainId 5 with 'pole' here, while the official
# cityscapesScripts table appears to mark it as ignored (255) — confirm this is intended.
CS_LABELS_INFO_FULL = [
    ('unlabeled',            0, 255, False),
    ('ego vehicle',          1, 255, False),
    ('rectification border', 2, 255, False),
    ('out of roi',           3, 255, False),
    ('static',               4, 255, False),
    ('dynamic',              5, 255, False),
    ('ground',               6, 255, False),
    ('road',                 7, 0,   False),
    ('sidewalk',             8, 1,   False),
    ('parking',              9, 255, False),
    ('rail track',          10, 255, False),
    ('building',            11, 2,   False),
    ('wall',                12, 3,   False),
    ('fence',               13, 4,   False),
    ('guard rail',          14, 255, False),
    ('bridge',              15, 255, False),
    ('tunnel',              16, 255, False),
    ('pole',                17, 5,   False),
    ('polegroup',           18, 5,   False),
    ('traffic light',       19, 6,   False),
    ('traffic sign',        20, 7,   False),
    ('vegetation',          21, 8,   False),
    ('terrain',             22, 9,   False),
    ('sky',                 23, 10,  False),
    ('person',              24, 11,  True),
    ('rider',               25, 12,  True),
    ('car',                 26, 13,  True),
    ('truck',               27, 14,  True),
    ('bus',                 28, 15,  True),
    ('caravan',             29, 255, True),
    ('trailer',             30, 255, True),
    ('train',               31, 16,  True),
    ('motorcycle',          32, 17,  True),
    ('bicycle',             33, 18,  True),
    ('license plate',       -1, 255, True)
]

# labelId -> trainId lookup derived from the table above.
LABELID_TO_TRAINID = {label[1]: label[2] for label in CS_LABELS_INFO_FULL}

# Evaluation classes: (trainId, name, is_thing). The last entry is the ignore/void class.
CITYSCAPES_CLASSES = [
    (0, 'road', False), (1, 'sidewalk', False), (2, 'building', False),
    (3, 'wall', False), (4, 'fence', False), (5, 'pole', False),
    (6, 'traffic light', False), (7, 'traffic sign', False), (8, 'vegetation', False),
    (9, 'terrain', False), (10, 'sky', False), (11, 'person', True),
    (12, 'rider', True), (13, 'car', True), (14, 'truck', True),
    (15, 'bus', True), (16, 'train', True), (17, 'motorcycle', True),
    (18, 'bicycle', True), (255, 'void', False)
]
NUM_CITYSCAPES_CLASSES = 19

# trainIds of the countable ("thing") classes, in ascending order.
CITYSCAPES_THING_CLASSES_TRAIN_IDS = [info[0] for info in CITYSCAPES_CLASSES if info[2] and info[0] != 255]

# trainId -> Mask R-CNN label. torchvision detection models reserve label 0 for the
# background, so thing classes are numbered from 1. Numbering them from 0 would make
# the first thing class ('person') indistinguishable from background during training
# and would leave the highest predicted label without an inverse mapping.
CITYSCAPES_THING_MAP = {
    train_id: label
    for label, train_id in enumerate(CITYSCAPES_THING_CLASSES_TRAIN_IDS, start=1)
}
# Mask R-CNN label -> trainId (inverse of the map above).
CITYSCAPES_THING_INV_MAP = {i: train_id for train_id, i in CITYSCAPES_THING_MAP.items()}
NUM_CITYSCAPES_THING_CLASSES = len(CITYSCAPES_THING_CLASSES_TRAIN_IDS)

CITYSCAPES_IGNORE_INDEX = 255
# Panoptic ids are encoded as trainId * CITYSCAPES_PANOPTIC_OFFSET + instance index.
CITYSCAPES_PANOPTIC_OFFSET = 1000


## Function for computing the IoU of two masks

In [None]:
def calculate_iou(mask1, mask2):
    """Return the intersection-over-union of two masks as a Python float.

    Both arguments are converted with ``.bool()`` before comparison, so any
    non-zero element counts as foreground. When neither mask has a foreground
    element the union is empty and 0.0 is returned.
    """
    first = mask1.bool()
    second = mask2.bool()

    union_area = torch.logical_or(first, second).sum().float()
    if union_area == 0:
        return 0.0

    overlap_area = torch.logical_and(first, second).sum().float()
    return (overlap_area / union_area).item()

In [None]:
def _parse_panoptic_segments(panoptic_map_np, thing_train_ids):
    """Split a panoptic id map into one record per segment, skipping ignored ids."""
    segments = []
    for combined_id in np.unique(panoptic_map_np):
        combined_id = int(combined_id)
        semantic_id = combined_id // CITYSCAPES_PANOPTIC_OFFSET
        if semantic_id == CITYSCAPES_IGNORE_INDEX:
            continue
        segments.append({
            'combined_id': combined_id,
            'mask': torch.from_numpy(panoptic_map_np == combined_id),
            'category_id': semantic_id,
            'instance_id': combined_id % CITYSCAPES_PANOPTIC_OFFSET,
            'isthing': semantic_id in thing_train_ids,
        })
    return segments


def _union_of_masks(segments, shape):
    """OR together the masks of ``segments`` into a single boolean tensor of ``shape``."""
    merged = torch.zeros(shape, dtype=torch.bool)
    for segment in segments:
        merged |= segment['mask']
    return merged


def calculate_pq(gt_panoptic_map_np, pred_panoptic_map_tensor,
                                         iou_threshold=0.5):
    """Compute panoptic, segmentation and recognition quality for one image.

    Both maps encode each pixel as
    ``trainId * CITYSCAPES_PANOPTIC_OFFSET + instance index``; segments whose
    trainId equals ``CITYSCAPES_IGNORE_INDEX`` are skipped.

    Thing classes are matched segment-by-segment: every predicted segment is
    paired with the still-unmatched ground-truth segment of the same class that
    has the highest IoU, and counts as a true positive when that IoU exceeds
    ``iou_threshold``. Stuff classes are treated as a single segment per class
    (the union of all their segments) and follow the same rule: a true positive
    only when the IoU exceeds ``iou_threshold``, otherwise a false positive if
    the class is predicted and a false negative if it is present in the ground
    truth (both when the two regions exist but do not overlap enough).

    Args:
        gt_panoptic_map_np: 2-D NumPy array with the ground-truth panoptic ids.
        pred_panoptic_map_tensor: torch tensor of the same shape with the
            predicted panoptic ids.
        iou_threshold: minimum IoU (exclusive) for a match to count as a true
            positive.

    Returns:
        ``(pq, sq, rq, pq_per_class)``. The first three values are averaged over
        the classes that occur in the ground truth; ``pq_per_class`` maps each
        trainId seen in either map to its ``pq``/``sq``/``rq``/``tp``/``fp``/``fn``.

    Raises:
        ValueError: if the two maps do not have the same shape.
    """
    pred_panoptic_map_np = pred_panoptic_map_tensor.cpu().numpy()
    gt_panoptic_map_np = np.asarray(gt_panoptic_map_np)
    if gt_panoptic_map_np.shape != pred_panoptic_map_np.shape:
        raise ValueError(
            f"Panoptic map shapes differ: gt {gt_panoptic_map_np.shape} "
            f"vs pred {pred_panoptic_map_np.shape}"
        )
    map_shape = tuple(gt_panoptic_map_np.shape)

    thing_train_ids = {c[0] for c in CITYSCAPES_CLASSES if c[2]}
    parsed_gt_segments = _parse_panoptic_segments(gt_panoptic_map_np, thing_train_ids)
    parsed_pred_segments = _parse_panoptic_segments(pred_panoptic_map_np, thing_train_ids)

    # Group the segments by class once instead of re-filtering for every class.
    gt_by_class, pred_by_class = {}, {}
    for segment in parsed_gt_segments:
        gt_by_class.setdefault(segment['category_id'], []).append(segment)
    for segment in parsed_pred_segments:
        pred_by_class.setdefault(segment['category_id'], []).append(segment)

    pq_per_class = {}
    for category_id in sorted(set(gt_by_class) | set(pred_by_class)):
        gt_class_segments = gt_by_class.get(category_id, [])
        pred_class_segments = pred_by_class.get(category_id, [])

        tp, fp, fn = 0, 0, 0
        sum_iou_of_tp = 0.0
        if category_id in thing_train_ids:
            matched_gt_indices = set()
            for pred_seg in pred_class_segments:
                best_iou = 0.0
                best_gt_idx = -1
                for gt_idx, gt_seg in enumerate(gt_class_segments):
                    if gt_idx in matched_gt_indices:
                        continue
                    iou = calculate_iou(pred_seg['mask'], gt_seg['mask'])
                    if iou > best_iou:
                        best_iou = iou
                        best_gt_idx = gt_idx

                if best_gt_idx != -1 and best_iou > iou_threshold:
                    tp += 1
                    sum_iou_of_tp += best_iou
                    matched_gt_indices.add(best_gt_idx)
                else:
                    fp += 1

            fn = len(gt_class_segments) - len(matched_gt_indices)
        else:
            gt_present = bool(gt_class_segments)
            pred_present = bool(pred_class_segments)
            iou_stuff = 0.0
            if gt_present and pred_present:
                iou_stuff = calculate_iou(
                    _union_of_masks(pred_class_segments, map_shape),
                    _union_of_masks(gt_class_segments, map_shape),
                )
            if iou_stuff > iou_threshold:
                tp = 1
                sum_iou_of_tp = iou_stuff
            else:
                # An unmatched class is penalised on both sides when both regions exist.
                fp = 1 if pred_present else 0
                fn = 1 if gt_present else 0

        sq = sum_iou_of_tp / tp if tp > 0 else 0.0
        rq_denominator = 2 * tp + fp + fn
        rq = (2 * tp) / rq_denominator if rq_denominator > 0 else 0.0
        pq_per_class[category_id] = {'pq': sq * rq, 'sq': sq, 'rq': rq, 'tp': tp, 'fp': fp, 'fn': fn}

    # Average only over classes that are present in the ground truth of this image.
    final_pq = 0.0
    final_sq = 0.0
    final_rq = 0.0
    num_classes_for_avg = 0
    for cat_id, metrics in pq_per_class.items():
        if cat_id in gt_by_class:
            final_pq += metrics['pq']
            final_sq += metrics['sq']
            final_rq += metrics['rq']
            num_classes_for_avg += 1

    if num_classes_for_avg > 0:
        final_pq /= num_classes_for_avg
        final_sq /= num_classes_for_avg
        final_rq /= num_classes_for_avg

    return final_pq, final_sq, final_rq, pq_per_class

In [None]:
class CityscapesPanopticDataset(Dataset):
    """Cityscapes samples with semantic, instance and panoptic ground truth.

    The dataset pairs every ``*_leftImg8bit.png`` under
    ``<root_dir>/leftImg8bit/<split>/<city>/`` with the ``*_gtFine_labelIds.png``
    and ``*_gtFine_instanceIds.png`` files under ``<root_dir>/gtFine/<split>/<city>/``.
    Images for which either annotation file is missing are skipped.

    Args:
        root_dir: directory that contains ``leftImg8bit`` and ``gtFine``.
        split: sub-directory name of the split to load (default ``'train'``).
        transforms: optional callable applied as
            ``transforms(image, semantic_mask, instance_target)``.
        max_samples: if truthy, keep only this many randomly chosen samples.
        seed: optional seed for the ``max_samples`` subsample; when ``None`` the
            global ``random`` state is used.
    """

    def __init__(self, root_dir, split='train', transforms=None, max_samples=None, seed=None):
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms
        self.img_dir = os.path.join(root_dir, 'leftImg8bit', split)
        self.ann_dir = os.path.join(root_dir, 'gtFine', split)
        self.images = []
        self.labelid_map_files = []
        self.instanceid_map_files = []

        # os.listdir order is arbitrary; sorting keeps the sample order stable across runs.
        for city in sorted(os.listdir(self.img_dir)):
            city_img_dir = os.path.join(self.img_dir, city)
            if not os.path.isdir(city_img_dir):
                continue  # stray files next to the city folders
            city_ann_dir = os.path.join(self.ann_dir, city)
            for file_name in sorted(os.listdir(city_img_dir)):
                if not file_name.endswith('_leftImg8bit.png'):
                    continue
                base_ann_name = file_name.replace('_leftImg8bit.png', '_gtFine')
                labelid_path = os.path.join(city_ann_dir, f"{base_ann_name}_labelIds.png")
                instanceid_path = os.path.join(city_ann_dir, f"{base_ann_name}_instanceIds.png")
                if os.path.exists(labelid_path) and os.path.exists(instanceid_path):
                    self.images.append(os.path.join(city_img_dir, file_name))
                    self.labelid_map_files.append(labelid_path)
                    self.instanceid_map_files.append(instanceid_path)

        if max_samples:
            rng = random.Random(seed) if seed is not None else random
            indices = list(range(len(self.images)))
            rng.shuffle(indices)
            selected_indices = indices[:max_samples]
            self.images = [self.images[i] for i in selected_indices]
            self.labelid_map_files = [self.labelid_map_files[i] for i in selected_indices]
            self.instanceid_map_files = [self.instanceid_map_files[i] for i in selected_indices]

    def __len__(self):
        """Return the number of image/annotation triples kept by ``__init__``."""
        return len(self.images)

    def _generate_panoptic_targets(self, labelid_map_path, instanceid_map_path):
        """Build all ground-truth targets for one sample from its two annotation PNGs.

        Returns a 6-tuple:
            * semantic map of trainIds (``uint8``, ignore = ``CITYSCAPES_IGNORE_INDEX``),
            * list of ``[xmin, ymin, xmax, ymax]`` boxes, one per usable thing instance,
            * list of Mask R-CNN labels (looked up in ``CITYSCAPES_THING_MAP``),
            * list of binary ``uint8`` instance masks,
            * panoptic map (``int32``) encoded as
              ``trainId * CITYSCAPES_PANOPTIC_OFFSET + instance index``; pixels that
              belong to no evaluated segment carry the ignore trainId,
            * list of ``{'id', 'category_id', 'isthing'}`` dicts, one per panoptic segment.
        """
        # Context managers close the file handles as soon as the pixels are copied.
        with Image.open(labelid_map_path) as labelid_img_pil:
            labelid_map_np = np.array(labelid_img_pil)
        with Image.open(instanceid_map_path) as instance_img_pil:
            instanceid_map_np = np.array(instance_img_pil)

        # labelId -> trainId through a lookup table (one pass instead of one per label).
        # Assumes labelIds are stored as 8-bit values, as in the gtFine PNGs.
        trainid_lut = np.full(256, CITYSCAPES_IGNORE_INDEX, dtype=np.uint8)
        for cityscapes_id, train_id_val in LABELID_TO_TRAINID.items():
            if 0 <= cityscapes_id < 256:
                trainid_lut[cityscapes_id] = train_id_val
        semantic_map_trainid_np = trainid_lut[labelid_map_np]

        thing_train_ids = set(CITYSCAPES_THING_CLASSES_TRAIN_IDS)
        boxes_np, labels_np, masks_np = [], [], []
        # Start from "ignore" so unannotated pixels are not mistaken for trainId 0.
        gt_panoptic_eval_map = np.full(
            semantic_map_trainid_np.shape,
            CITYSCAPES_IGNORE_INDEX * CITYSCAPES_PANOPTIC_OFFSET,
            dtype=np.int32,
        )
        gt_segments_info_for_pq = []
        current_gt_instance_id_counter = 1

        for inst_val in np.unique(instanceid_map_np):
            inst_val = int(inst_val)
            # instanceIds encoding: values >= 1000 are labelId * 1000 + instance index;
            # smaller values are the plain labelId (stuff, or thing regions that were
            # not annotated as individual instances).
            has_instance_annotation = inst_val >= 1000
            original_label_id_gt = inst_val // 1000 if has_instance_annotation else inst_val

            train_id_gt = LABELID_TO_TRAINID.get(original_label_id_gt, CITYSCAPES_IGNORE_INDEX)
            if train_id_gt == CITYSCAPES_IGNORE_INDEX:
                continue
            is_thing_gt = train_id_gt in thing_train_ids
            if is_thing_gt and not has_instance_annotation:
                # Group/crowd region without per-instance masks: leave it as ignore.
                continue

            mask_region_gt = (instanceid_map_np == inst_val)
            if is_thing_gt:
                combined_id_gt = train_id_gt * CITYSCAPES_PANOPTIC_OFFSET + current_gt_instance_id_counter
                current_gt_instance_id_counter += 1

                rows, cols = np.where(mask_region_gt)
                xmin, xmax = np.min(cols), np.max(cols)
                ymin, ymax = np.min(rows), np.max(rows)
                # Boxes with zero width or height are not used as detection targets.
                if xmax > xmin and ymax > ymin:
                    boxes_np.append([xmin, ymin, xmax, ymax])
                    labels_np.append(CITYSCAPES_THING_MAP[train_id_gt])
                    masks_np.append(mask_region_gt.astype(np.uint8))
            else:
                combined_id_gt = train_id_gt * CITYSCAPES_PANOPTIC_OFFSET

            gt_panoptic_eval_map[mask_region_gt] = combined_id_gt
            gt_segments_info_for_pq.append({
                'id': combined_id_gt,
                'category_id': train_id_gt,
                'isthing': is_thing_gt
            })

        return semantic_map_trainid_np, boxes_np, labels_np, masks_np, gt_panoptic_eval_map, gt_segments_info_for_pq

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a 6-tuple ``(image, semantic_target, instance_target,
        original_image, gt_panoptic_eval_map, gt_segments_info_for_pq)``. The first
        three elements have ``self.transforms`` applied when it is set; the last
        three are never passed through the transforms.
        """
        with Image.open(self.images[idx]) as raw_image:
            image_pil = raw_image.convert("RGB")

        semantic_map_np, boxes_np, labels_np, masks_np, gt_panoptic_eval_map, gt_segments_info_for_pq = \
            self._generate_panoptic_targets(self.labelid_map_files[idx], self.instanceid_map_files[idx])

        image_tv = TVImage(image_pil)
        semantic_target_tv = Mask(semantic_map_np.astype(np.int64))

        w, h = image_pil.size  # PIL reports (width, height)
        if len(boxes_np) > 0:
            boxes_tv = BoundingBoxes(np.array(boxes_np, dtype=np.float32), format="XYXY", canvas_size=(h, w))
            labels_tv = torch.tensor(labels_np, dtype=torch.int64)
            masks_tv = Mask(np.array(masks_np, dtype=np.uint8))
        else:
            # Empty targets keep the same dtypes and trailing shapes as populated ones.
            boxes_tv = BoundingBoxes(torch.empty((0, 4), dtype=torch.float32), format="XYXY", canvas_size=(h, w))
            labels_tv = torch.empty((0,), dtype=torch.int64)
            masks_tv = Mask(torch.empty((0, h, w), dtype=torch.uint8))

        instance_target = {"boxes": boxes_tv, "labels": labels_tv, "masks": masks_tv}
        if self.transforms:
            image_tv_transformed, semantic_target_tv_transformed, instance_target_transformed = \
                self.transforms(image_tv, semantic_target_tv, instance_target)
            return image_tv_transformed, semantic_target_tv_transformed, instance_target_transformed, \
                   image_tv, gt_panoptic_eval_map, gt_segments_info_for_pq
        else:
            return image_tv, semantic_target_tv, instance_target, \
                   image_tv, gt_panoptic_eval_map, gt_segments_info_for_pq

In [None]:
def get_cityscapes_color_map():
    """Return a ``(256, 3)`` ``uint8`` RGB lookup table.

    The first 20 rows hold the hard-coded colours below; every remaining row
    is black.
    """
    palette_rows = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
        (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 0),
    )
    lookup = np.zeros((256, 3), dtype=np.uint8)
    lookup[:len(palette_rows)] = np.asarray(palette_rows, dtype=np.uint8)
    return lookup

In [None]:
# Small demo configuration used to build a dataset for visualisation.
CITYSCAPES_ROOT = "./cityscapes"  # directory the archives were extracted into
MAX_SAMPLES_FOR_DATASET_DEMO = 5  # number of randomly chosen samples kept by the dataset
IMG_SIZE_FOR_DATASET_DEMO = (256, 512)  # target size passed to T.Resize — presumably (height, width)
demo_transforms = T.Compose([
    T.Resize(IMG_SIZE_FOR_DATASET_DEMO, antialias=True),
    T.ToDtype(torch.float, scale=True),
    # Normalisation is left disabled for the demo (presumably so images stay displayable — confirm).
    # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# The dataset subsamples with the global `random` state, so the selection differs between runs.
dataset_for_viz = CityscapesPanopticDataset(
    CITYSCAPES_ROOT,
    'train',
    transforms=demo_transforms,
    max_samples=MAX_SAMPLES_FOR_DATASET_DEMO
)

# Pick up to three distinct sample indices to visualise.
num_samples_to_viz = min(3, len(dataset_for_viz))
random_indices = random.sample(range(len(dataset_for_viz)), num_samples_to_viz)


In [None]:
def get_transform(train, new_size=(256, 512)):
    """Build the preprocessing pipeline for one sample.

    The pipeline resizes to ``new_size``, optionally flips horizontally with
    probability 0.5 when ``train`` is truthy, converts to float with scaling,
    and normalises with the fixed mean/std below.
    """
    steps = [T.Resize(new_size, antialias=True)]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    steps += [
        T.ToDtype(torch.float, scale=True),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)

def collate_fn(batch):
    """Collate 6-element samples into a batch for evaluation.

    Elements 0 and 1 of every sample are stacked along a new leading dimension;
    elements 2 to 5 are returned as plain Python lists, one entry per sample.
    """
    def column(position):
        return [sample[position] for sample in batch]

    stacked_images = torch.stack(column(0), 0)
    stacked_semantic = torch.stack(column(1), 0)
    return (stacked_images, stacked_semantic, column(2),
            column(3), column(4), column(5))

def get_semantic_model(num_classes=NUM_CITYSCAPES_CLASSES):
    """Create a DeepLabV3-ResNet50 with default pretrained weights and new output layers.

    The last layer of the main classifier, and of the auxiliary classifier when
    the model has one, is replaced by a fresh 1x1 convolution with
    ``num_classes`` output channels.
    """
    def _fresh_output_layer():
        return nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))

    net = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
    net.classifier[-1] = _fresh_output_layer()

    aux_head = getattr(net, 'aux_classifier', None)
    if aux_head is not None:
        aux_head[-1] = _fresh_output_layer()
    return net

def get_instance_model(num_thing_classes=NUM_CITYSCAPES_THING_CLASSES):
    """Create a Mask R-CNN (ResNet50-FPN) with default pretrained weights and new heads.

    Both the box predictor and the mask predictor are replaced so that they
    output ``num_thing_classes + 1`` classes.
    """
    detection = torchvision.models.detection
    total_classes = num_thing_classes + 1
    mask_hidden_channels = 256

    net = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
    heads = net.roi_heads

    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(box_in_features, total_classes)

    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, total_classes
    )
    return net


def train_semantic_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train the semantic model for one pass over ``data_loader``.

    Each batch must unpack into ``(images, semantic_targets, _)``. The loss is
    cross-entropy on ``model(images)['out']`` with ``CITYSCAPES_IGNORE_INDEX``
    pixels ignored.

    Returns:
        The mean batch loss, or ``float('nan')`` when the loader yields no batch
        (same convention as ``train_instance_one_epoch``).
    """
    model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=CITYSCAPES_IGNORE_INDEX)
    total_loss = 0.0
    iter_count = 0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch} [Semantic Training]", unit="batch")

    for images, semantic_targets, _ in progress_bar:
        images = images.to(device)
        semantic_targets = semantic_targets.to(device)
        # Only drop a channel axis when there is one; a (B, H, W) target is used as is.
        if semantic_targets.dim() == 4:
            semantic_targets = semantic_targets.squeeze(1)

        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, semantic_targets)
        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        total_loss += current_loss
        iter_count += 1
        progress_bar.set_postfix(loss=f"{current_loss:.4f}", avg_loss=f"{total_loss/iter_count:.4f}")

    if iter_count == 0:
        # An empty loader would otherwise end in a division by zero.
        print(f"Epoch {epoch} Semantic: No valid batches processed.")
        return float('nan')

    avg_epoch_loss = total_loss / iter_count
    print(f"Epoch {epoch} Semantic Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss

def train_instance_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train the instance model for one pass over ``data_loader``.

    Each batch must unpack into ``(images, _, instance_targets)``. Images are
    passed to the model as a list of per-sample tensors and every tensor value
    of each target dict is moved to ``device``. The model is expected to return
    a dict of losses, whose sum is back-propagated.

    Returns:
        The mean summed loss per batch, or ``float('nan')`` when no batch was
        processed.
    """
    def _move_target(target):
        moved = {}
        for key, value in target.items():
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        return moved

    model.train()
    running_loss = 0
    batches_done = 0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch} [Instance Training]", unit="batch")

    for images, _, instance_targets in progress_bar:
        image_list = [single_image.to(device) for single_image in images]
        target_list = [_move_target(target) for target in instance_targets]

        optimizer.zero_grad()
        loss_dict = model(image_list, target_list)
        combined_loss = sum(loss_dict.values())
        combined_loss.backward()
        optimizer.step()

        batch_loss = combined_loss.item()
        running_loss += batch_loss
        batches_done += 1

        loss_details_str = ", ".join([f"{k}: {v.item():.3f}" for k, v in loss_dict.items()])
        progress_bar.set_postfix(
            total_loss=f"{batch_loss:.4f}",
            avg_loss=f"{running_loss/batches_done:.4f}",
            details=loss_details_str,
        )

    if batches_done == 0:
        print(f"Epoch {epoch} Instance: No valid batches processed.")
        return float('nan')

    avg_epoch_loss = running_loss / batches_done
    print(f"Epoch {epoch} Instance Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss


def panoptic_fusion(semantic_pred, instance_preds):
    """Merge a semantic prediction and instance predictions into one panoptic map.

    Every pixel of the result is encoded as
    ``trainId * CITYSCAPES_PANOPTIC_OFFSET + instance index``:

    1. Stuff pixels take their class from ``semantic_pred`` (instance index 0).
    2. Instances are pasted in descending score order; a later instance only
       claims pixels that no earlier instance has taken, and an instance left
       without pixels is dropped.
    3. Pixels that ``semantic_pred`` assigns to a thing class but that no
       instance covers keep that class with instance index 0.

    Args:
        semantic_pred: 2-D tensor of predicted trainIds.
        instance_preds: iterable of dicts with ``'masks'`` (mask with a leading
            singleton dimension, thresholded at 0.5), ``'labels'`` and
            ``'scores'`` (scalar tensors).

    Returns:
        ``int32`` tensor with the same shape as ``semantic_pred``.

    Raises:
        ValueError: if ``semantic_pred`` is not 2-D.
    """
    if semantic_pred.dim() != 2:
        raise ValueError(f"semantic_pred must be 2-D, got shape {tuple(semantic_pred.shape)}")

    panoptic_seg = torch.zeros_like(semantic_pred, dtype=torch.int32)
    thing_train_ids = set(CITYSCAPES_THING_CLASSES_TRAIN_IDS)

    for train_id in range(NUM_CITYSCAPES_CLASSES):
        if train_id not in thing_train_ids:
            panoptic_seg[semantic_pred == train_id] = train_id * CITYSCAPES_PANOPTIC_OFFSET

    candidates = []
    for inst in instance_preds:
        label_train_id = CITYSCAPES_THING_INV_MAP.get(int(inst['labels'].item()))
        if label_train_id is None:
            # The detector emitted a label with no known trainId; skip it instead of
            # aborting the whole fusion with a KeyError.
            continue
        candidates.append({
            'mask': inst['masks'].squeeze(0),
            'label': label_train_id,
            'score': inst['scores'].item(),
        })
    candidates.sort(key=lambda candidate: candidate['score'], reverse=True)

    used_instance_pixels = torch.zeros_like(semantic_pred, dtype=torch.bool)
    current_instance_id = 1
    for candidate in candidates:
        current_mask_region = (candidate['mask'] > 0.5) & ~used_instance_pixels
        if current_mask_region.sum() == 0:
            continue
        panoptic_seg[current_mask_region] = candidate['label'] * CITYSCAPES_PANOPTIC_OFFSET + current_instance_id
        used_instance_pixels |= current_mask_region
        current_instance_id += 1
        # Instance indices must stay below the offset to keep the encoding unambiguous.
        if current_instance_id >= CITYSCAPES_PANOPTIC_OFFSET:
            break

    remaining_pixels = ~used_instance_pixels
    for train_id in CITYSCAPES_THING_CLASSES_TRAIN_IDS:
        semantic_thing_pixels = (semantic_pred == train_id) & remaining_pixels
        panoptic_seg[semantic_thing_pixels] = train_id * CITYSCAPES_PANOPTIC_OFFSET
    return panoptic_seg



def predict_panoptic(semantic_model, instance_model, image_tensor, device, instance_score_thresh=0.5):
    """Run both models on one image and fuse their outputs.

    Both models are switched to eval mode and executed without gradients on
    ``image_tensor`` with a batch dimension added.

    Returns:
        ``(semantic_pred, filtered_instances, panoptic_result)`` where
        ``semantic_pred`` is the per-pixel argmax of the semantic output (on CPU),
        ``filtered_instances`` is a list of ``{'masks', 'labels', 'scores'}``
        dicts for detections scoring above ``instance_score_thresh``, and
        ``panoptic_result`` is what ``panoptic_fusion`` returns for the two.
    """
    semantic_model.eval()
    instance_model.eval()

    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)

        semantic_logits = semantic_model(batch)['out']
        semantic_pred = torch.argmax(semantic_logits.squeeze(), dim=0).cpu()

        detections = instance_model(batch)[0]
        all_masks = detections['masks'].cpu()
        all_labels = detections['labels'].cpu()
        all_scores = detections['scores'].cpu()

        confident = all_scores > instance_score_thresh
        filtered_instances = []
        if confident.any():
            filtered_instances = [
                {'masks': mask, 'labels': label, 'scores': score}
                for mask, label, score in zip(
                    all_masks[confident], all_labels[confident], all_scores[confident]
                )
            ]

        panoptic_result = panoptic_fusion(semantic_pred, filtered_instances)

    return semantic_pred, filtered_instances, panoptic_result

def get_cityscapes_color_map():
    """Return a ``(256, 3)`` ``uint8`` RGB lookup table.

    The first 20 rows hold the hard-coded colours below; every remaining row
    is black. This repeats the definition given in an earlier cell of the
    notebook.
    """
    palette_rows = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
        (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 0),
    )
    lookup = np.zeros((256, 3), dtype=np.uint8)
    lookup[:len(palette_rows)] = np.asarray(palette_rows, dtype=np.uint8)
    return lookup

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {DEVICE}")

# Training configuration.
CITYSCAPES_ROOT = "./cityscapes"
NUM_EPOCHS_SEMANTIC = 15
NUM_EPOCHS_INSTANCE = 15
BATCH_SIZE = 4
LEARNING_RATE_SEM = 1e-4
LEARNING_RATE_INST = 1e-4
MAX_SAMPLES_TRAIN = 300  # random subset size of the train split
MAX_SAMPLES_VAL = 100  # random subset size of the val split
IMG_SIZE = (256, 512)  # size passed to T.Resize — presumably (height, width)
SAVE_EPOCH = 5  # a checkpoint is written every SAVE_EPOCH epochs and after the last epoch
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# Per-epoch average losses, filled by the training loops below and plotted afterwards.
semantic_losses = []
instance_losses = []






# NOTE(review): the training dataset is built with train=False, so the
# RandomHorizontalFlip in get_transform is never applied — confirm that
# training without augmentation is intended.
# NOTE(review): no random seed is set, so the sampled subsets differ between runs.
dataset_train = CityscapesPanopticDataset(CITYSCAPES_ROOT, 'train', get_transform(train=False, new_size=IMG_SIZE), max_samples=MAX_SAMPLES_TRAIN)
dataset_val = CityscapesPanopticDataset(CITYSCAPES_ROOT, 'val', get_transform(train=False, new_size=IMG_SIZE), max_samples=MAX_SAMPLES_VAL)
def train_collate_fn(batch):
    """Collate samples for training, keeping only their first three elements.

    Elements 0 and 1 of every sample are stacked along a new leading dimension;
    element 2 is returned as a list with one entry per sample.
    """
    image_stack = torch.stack([sample[0] for sample in batch], 0)
    semantic_stack = torch.stack([sample[1] for sample in batch], 0)
    instance_targets = [sample[2] for sample in batch]
    return image_stack, semantic_stack, instance_targets

# Training loaders: the last batch is dropped whenever it would be incomplete.
dataloader_train_sem = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=train_collate_fn, drop_last=(BATCH_SIZE > 1 and len(dataset_train) % BATCH_SIZE !=0) )
dataloader_train_inst = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=train_collate_fn, drop_last=(BATCH_SIZE > 1 and len(dataset_train) % BATCH_SIZE !=0) )
# Validation loader: one sample per batch, collated with all six sample elements.
dataloader_val_pq = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=2, collate_fn=collate_fn)



# One model and one AdamW optimizer per task.
semantic_model = get_semantic_model(num_classes=NUM_CITYSCAPES_CLASSES).to(DEVICE)
instance_model = get_instance_model(num_thing_classes=NUM_CITYSCAPES_THING_CLASSES).to(DEVICE)
optimizer_sem = optim.AdamW(semantic_model.parameters(), lr=LEARNING_RATE_SEM)
optimizer_inst = optim.AdamW(instance_model.parameters(), lr=LEARNING_RATE_INST)




# Train the semantic model; checkpoint every SAVE_EPOCH epochs and after the final epoch.
for epoch in range(NUM_EPOCHS_SEMANTIC):
    avg_loss = train_semantic_one_epoch(semantic_model, optimizer_sem, dataloader_train_sem, DEVICE, epoch)
    semantic_losses.append(avg_loss)
    if (epoch + 1) % SAVE_EPOCH == 0 or (epoch + 1) == NUM_EPOCHS_SEMANTIC:
        sem_checkpoint_path = os.path.join(CHECKPOINT_DIR, f"semantic_model_epoch_{epoch+1}.pth")
        torch.save({'epoch': epoch + 1, 'model_state_dict': semantic_model.state_dict(),
                    'optimizer_state_dict': optimizer_sem.state_dict()}, sem_checkpoint_path)
        print(f"Saved semantic model {sem_checkpoint_path}")




# Train the instance model with the same checkpointing schedule.
for epoch in range(NUM_EPOCHS_INSTANCE):
    avg_loss = train_instance_one_epoch(instance_model, optimizer_inst, dataloader_train_inst, DEVICE, epoch)
    instance_losses.append(avg_loss)
    if (epoch + 1) % SAVE_EPOCH == 0 or (epoch + 1) == NUM_EPOCHS_INSTANCE:
        inst_checkpoint_path = os.path.join(CHECKPOINT_DIR, f"instance_model_epoch_{epoch+1}.pth")
        torch.save({'epoch': epoch + 1, 'model_state_dict': instance_model.state_dict(),
                    'optimizer_state_dict': optimizer_inst.state_dict()}, inst_checkpoint_path)
        print(f"Saved instance model {inst_checkpoint_path}")



In [None]:

# Plot the recorded training losses, one panel per model that has data, and save the figure.
# Each entry: (per-epoch losses, legend label, panel title, extra plot style).
loss_curves = []
if semantic_losses:
    loss_curves.append((semantic_losses, 'Semantic Training Loss', 'Semantic Model Training Loss',
                        {'marker': 'o'}))
if instance_losses:
    loss_curves.append((instance_losses, 'Instance Training Loss', 'Instance Model Training Loss',
                        {'marker': 'x', 'color': 'r'}))

if loss_curves:
    # The grid has exactly as many columns as there are curves, so a single curve
    # is not squeezed into half of the figure. squeeze=False keeps `axes` 2-D.
    fig, axes = plt.subplots(1, len(loss_curves), figsize=(12, 5), squeeze=False)
    for ax, (epoch_losses, curve_label, panel_title, plot_style) in zip(axes[0], loss_curves):
        ax.plot(range(1, len(epoch_losses) + 1), epoch_losses, label=curve_label, **plot_style)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Average Loss')
        ax.set_title(panel_title)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    loss_plot_path = os.path.join(CHECKPOINT_DIR, "training_losses.png")
    fig.savefig(loss_plot_path)
    print(f"Saved training loss plot to {loss_plot_path}")
