In [2]:
from datasets import load_dataset

seed = 42
train_sample_size = 5
validation_sample_size = 2
test_sample_size = 2

ds = load_dataset("Chris1/cityscapes", streaming=True)
train_ds = ds["train"].shuffle(seed=seed).take(train_sample_size)
validation_ds = ds["validation"].shuffle(seed=seed).take(validation_sample_size)
test_ds = ds["test"].shuffle(seed=seed).take(test_sample_size)

print("Train dataset:")
print(list(train_ds))
print("Validation dataset:")
print(list(validation_ds))
print("Test dataset:")
print(list(test_ds))

Train dataset:
[{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F25477E6A50>, 'semantic_segmentation': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F2547756590>}, {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F254750F710>, 'semantic_segmentation': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F254750F8D0>}, {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F254750FB10>, 'semantic_segmentation': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F254750FC50>}, {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F254750FE10>, 'semantic_segmentation': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F2547528050>}, {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F2547528290>, 'semantic_segmentation': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x

In [3]:
from labels import labels, Label

# create lebel2id and id2label dictionaries
label2id = {label.name: label.id for label in labels}
id2label = {label.id: label.name for label in labels}
print(label2id)
print(id2label)

{'unlabeled': 0, 'ego vehicle': 1, 'rectification border': 2, 'out of roi': 3, 'static': 4, 'dynamic': 5, 'ground': 6, 'road': 7, 'sidewalk': 8, 'parking': 9, 'rail track': 10, 'building': 11, 'wall': 12, 'fence': 13, 'guard rail': 14, 'bridge': 15, 'tunnel': 16, 'pole': 17, 'polegroup': 18, 'traffic light': 19, 'traffic sign': 20, 'vegetation': 21, 'terrain': 22, 'sky': 23, 'person': 24, 'rider': 25, 'car': 26, 'truck': 27, 'bus': 28, 'caravan': 29, 'trailer': 30, 'train': 31, 'motorcycle': 32, 'bicycle': 33, 'license plate': -1}
{0: 'unlabeled', 1: 'ego vehicle', 2: 'rectification border', 3: 'out of roi', 4: 'static', 5: 'dynamic', 6: 'ground', 7: 'road', 8: 'sidewalk', 9: 'parking', 10: 'rail track', 11: 'building', 12: 'wall', 13: 'fence', 14: 'guard rail', 15: 'bridge', 16: 'tunnel', 17: 'pole', 18: 'polegroup', 19: 'traffic light', 20: 'traffic sign', 21: 'vegetation', 22: 'terrain', 23: 'sky', 24: 'person', 25: 'rider', 26: 'car', 27: 'truck', 28: 'bus', 29: 'caravan', 30: 'tra

In [4]:
from transformers import AutoImageProcessor

checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

Downloading (…)rocessor_config.json: 100%|██████████| 272/272 [00:00<00:00, 1.43MB/s]
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [6]:
from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

def train_transforms(example):
    image = jitter(example["image"])
    label = example["semantic_segmentation"]
    return image_processor(image, label)


def val_transforms(example):
    image = example["image"]
    label = example["semantic_segmentation"]
    return image_processor(image, label)

train_ds = train_ds.map(train_transforms)
validation_ds = validation_ds.map(val_transforms)


In [7]:
train_example = next(iter(train_ds))
image = train_example["image"]
target = train_example["semantic_segmentation"]

print(image)
print(target)

<PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F251E8A4790>
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024 at 0x7F251E8A4950>


In [None]:
import matplotlib.pyplot as plt
import numpy as np

plt.imshow(np.swapaxes(np.swapaxes(pixels, 0, 2), 0, 1))

In [None]:
import evaluate
import numpy as np
import torch
from torch import nn

metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics

In [None]:
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

In [None]:
training_args = TrainingArguments(
    output_dir="segformer-b0-scene-parse-150",
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

trainer.train()