In [8]:
import tqdm as notebook_tqdm
from datasets import load_dataset

hf_dataset_identifier = "segments/sidewalk-semantic"

# One seed for both the shuffle and the split, so that a fresh kernel
# reproduces exactly the same train/test partition.
SEED = 42

ds = load_dataset(hf_dataset_identifier)

ds = ds.shuffle(seed=SEED)
# Without an explicit `seed`, the split below was drawn from a fresh random
# state on every run, making the reported metrics non-reproducible.
ds = ds["train"].train_test_split(test_size=0.2, seed=SEED)
train_ds = ds["train"]
test_ds = ds["test"]


In [9]:
import json
from huggingface_hub import hf_hub_download

# NOTE: `repo_id` is not passed to the download call below (which uses the
# bare dataset identifier together with repo_type="dataset"); it is kept only
# so that any later cell referring to it keeps working.
repo_id = f"datasets/{hf_dataset_identifier}"
filename = "id2label.json"

# Download the label map and read it with a context manager so the file
# handle is closed deterministically (the previous `json.load(open(...))`
# left it open until garbage collection).
id2label_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset")
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)

# JSON object keys are always strings; convert them back to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

num_labels = len(id2label)

In [10]:
from torchvision.transforms import ColorJitter
from transformers import SegformerFeatureExtractor

# Pre-processor with library-default settings; the transform functions in this
# cell call it on (images, labels) pairs.
feature_extractor = SegformerFeatureExtractor()

# Colour augmentation applied to training images only.
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)

def train_transforms(example_batch):
    """Augment and pre-process one batch of training examples.

    Args:
        example_batch: mapping with a 'pixel_values' entry (iterable of
            images) and a 'label' entry (iterable of segmentation maps).

    Returns:
        Whatever ``feature_extractor`` returns for the colour-jittered images
        and the untouched labels.
    """
    # Only the images are jittered; the labels are passed through as a list.
    augmented_images = list(map(jitter, example_batch['pixel_values']))
    segmentation_maps = list(example_batch['label'])
    return feature_extractor(augmented_images, segmentation_maps)

def val_transforms(example_batch):
    """Pre-process one batch of evaluation examples (no augmentation).

    Args:
        example_batch: mapping with a 'pixel_values' entry (iterable of
            images) and a 'label' entry (iterable of segmentation maps).

    Returns:
        Whatever ``feature_extractor`` returns for the unmodified images and
        labels.
    """
    # Materialise both columns as lists before handing them to the extractor.
    raw_images = list(example_batch['pixel_values'])
    segmentation_maps = list(example_batch['label'])
    return feature_extractor(raw_images, segmentation_maps)

# Attach the matching transform to each split: jittered pre-processing for
# training, plain pre-processing for evaluation.
for _split, _transform in ((train_ds, train_transforms), (test_ds, val_transforms)):
    _split.set_transform(_transform)



In [11]:
from transformers import SegformerForSemanticSegmentation

# Checkpoint to start from. As the load-time message shows, the decode head is
# not in this checkpoint and is freshly initialised, so the model must be
# fine-tuned before use.
pretrained_model_name = "nvidia/mit-b0"

model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name, id2label=id2label, label2id=label2id
)


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.running_var', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.classifier.bias', 'decode_head.linear_fuse.weight', 'decode_head.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
from transformers import TrainingArguments

# Shortened to 2 epochs for benchmarking. The leftover "50" in the original
# note suggests the full run used 50 epochs — TODO confirm.
epochs = 2
lr = 6e-5
batch_size = 2

hub_model_id = "hufanyoung/segformer-b0-finetuned-segments-sidewalk-2"

training_args = TrainingArguments(
    output_dir="segformer-b0-finetuned-segments-sidewalk-outputs",
    # Optimisation
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate and checkpoint on the same 20-step cadence
    evaluation_strategy="steps",
    eval_steps=20,
    eval_accumulation_steps=5,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Logging
    logging_steps=1,
    # Hub upload
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
)


In [16]:
import torch
from torch import nn
import evaluate

# Load the "mean_iou" metric once at module level; compute_metrics below
# calls it on every evaluation pass.
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    """Turn raw evaluation outputs into a flat dict of mean-IoU metrics.

    Args:
        eval_pred: a ``(logits, labels)`` pair of NumPy arrays. The logits are
            resized with 2-D bilinear interpolation and arg-maxed over
            dimension 1, so they are expected to be 4-D with the class scores
            on axis 1; the last two axes of ``labels`` give the target size.

    Returns:
        The dict produced by the metric, with the two per-category arrays
        replaced by one ``accuracy_<label>`` and one ``iou_<label>`` entry per
        category (label names taken from ``id2label``).
    """
    with torch.no_grad():
        logits, labels = eval_pred

        # scale the logits to the size of the label, then pick the best class per pixel
        upsampled_logits = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled_logits.argmax(dim=1).detach().cpu().numpy()

        # currently using _compute instead of compute
        # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=0,
            reduce_labels=feature_extractor.do_reduce_labels,
        )

        # add per category metrics as individual key-value pairs
        category_accuracy = metrics.pop("per_category_accuracy").tolist()
        category_iou = metrics.pop("per_category_iou").tolist()
        for prefix, values in (("accuracy", category_accuracy), ("iou", category_iou)):
            for class_id, value in enumerate(values):
                metrics[f"{prefix}_{id2label[class_id]}"] = value

        return metrics


In [17]:
from transformers import Trainer

# Wire the model, the data splits and the metric function together. With
# push_to_hub=True in training_args, constructing the Trainer also clones the
# hub repository locally (see the output of this cell).
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)


Cloning https://huggingface.co/hufanyoung/segformer-b0-finetuned-segments-sidewalk-2 into local empty directory.
Download file pytorch_model.bin:   0%|          | 6.25k/14.3M [00:00<?, ?B/s]
[A

[A[A


[A[A[A



[A[A[A[A




[A[A[A[A[A





[A[A[A[A[A[A






[A[A[A[A[A[A[A







[A[A[A[A[A[A[A[A








Download file pytorch_model.bin: 100%|██████████| 14.3M/14.3M [00:01<00:00, 14.7MB/s]
Download file runs/Apr15_21-49-20_7ea360766af0/events.out.tfevents.1650059381.7ea360766af0.71.2: 100%|██████████| 9.82k/9.82k [00:01<?, ?B/s]

[A
Download file runs/Apr15_21-49-20_7ea360766af0/1650059381.5713542/events.out.tfevents.1650059381.7ea360766af0.71.3: 100%|██████████| 4.86k/4.86k [00:01<?, ?B/s]


[A[A

Clean file runs/Apr15_21-49-20_7ea360766af0/events.out.tfevents.1650059381.7ea360766af0.71.2: 100%|██████████| 9.82k/9.82k [00:01<00:00, 8.87kB/s]



[A[A[A


Download file training_args.bin: 100%|██████████| 3.11k/3.11k [00:01<?, ?B/s]




[A[A[A

In [18]:
# Run fine-tuning. Left as a bare last expression so the returned TrainOutput
# is shown as the cell result.
trainer.train()




Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Unlabeled,Accuracy Flat-road,Accuracy Flat-sidewalk,Accuracy Flat-crosswalk,Accuracy Flat-cyclinglane,Accuracy Flat-parkingdriveway,Accuracy Flat-railtrack,Accuracy Flat-curb,Accuracy Human-person,Accuracy Human-rider,Accuracy Vehicle-car,Accuracy Vehicle-truck,Accuracy Vehicle-bus,Accuracy Vehicle-tramtrain,Accuracy Vehicle-motorcycle,Accuracy Vehicle-bicycle,Accuracy Vehicle-caravan,Accuracy Vehicle-cartrailer,Accuracy Construction-building,Accuracy Construction-door,Accuracy Construction-wall,Accuracy Construction-fenceguardrail,Accuracy Construction-bridge,Accuracy Construction-tunnel,Accuracy Construction-stairs,Accuracy Object-pole,Accuracy Object-trafficsign,Accuracy Object-trafficlight,Accuracy Nature-vegetation,Accuracy Nature-terrain,Accuracy Sky,Accuracy Void-ground,Accuracy Void-dynamic,Accuracy Void-static,Accuracy Void-unclear,Iou Unlabeled,Iou Flat-road,Iou Flat-sidewalk,Iou Flat-crosswalk,Iou Flat-cyclinglane,Iou Flat-parkingdriveway,Iou Flat-railtrack,Iou Flat-curb,Iou Human-person,Iou Human-rider,Iou Vehicle-car,Iou Vehicle-truck,Iou Vehicle-bus,Iou Vehicle-tramtrain,Iou Vehicle-motorcycle,Iou Vehicle-bicycle,Iou Vehicle-caravan,Iou Vehicle-cartrailer,Iou Construction-building,Iou Construction-door,Iou Construction-wall,Iou Construction-fenceguardrail,Iou Construction-bridge,Iou Construction-tunnel,Iou Construction-stairs,Iou Object-pole,Iou Object-trafficsign,Iou Object-trafficlight,Iou Nature-vegetation,Iou Nature-terrain,Iou Sky,Iou Void-ground,Iou Void-dynamic,Iou Void-static,Iou Void-unclear
20,2.7683,3.171063,0.085942,0.136347,0.594143,,0.425367,0.918091,0.001617,0.001761,0.003946,,0.00047,0.025641,0.0,0.880821,0.0,0.0,0.0,0.0,0.0,0.0,7.8e-05,0.234857,0.0,0.04031,0.002604,0.0,,0.0,0.083042,0.0,0.0,0.960732,0.084151,0.67911,0.0,0.0,0.020493,0.0,0.0,0.349773,0.637718,0.001589,0.001743,0.00387,,0.000464,0.01771,0.0,0.39521,0.0,0.0,0.0,0.0,0.0,0.0,6.1e-05,0.214766,0.0,0.023958,0.002158,0.0,0.0,0.0,0.029465,0.0,0.0,0.578248,0.072505,0.579119,0.0,0.0,0.013666,0.0
40,2.5592,2.401905,0.091625,0.140696,0.629858,,0.569521,0.937952,0.00059,0.00126,0.008058,,1e-05,0.0,0.0,0.844238,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.349027,0.0,0.019593,0.000114,0.0,,0.0,0.016707,0.0,0.0,0.985233,0.042347,0.727295,0.0,0.0,0.000343,0.0,0.0,0.416701,0.666785,0.000586,0.001258,0.007925,,9e-06,0.0,0.0,0.493006,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.305228,0.0,0.018039,0.000102,0.0,0.0,0.0,0.015439,0.0,0.0,0.522057,0.037089,0.630707,0.0,0.0,0.000333,0.0
60,2.0324,1.987704,0.105863,0.158909,0.672657,,0.621367,0.942299,0.000445,0.014513,0.006183,,0.0,0.0,0.0,0.886214,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.66121,0.0,0.00328,2e-05,0.0,,0.0,0.001091,0.0,0.0,0.960764,0.13295,0.854655,0.0,0.0,9.1e-05,0.0,0.0,0.442175,0.687083,0.000445,0.014491,0.006105,,0.0,0.0,0.0,0.500478,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.494995,0.0,0.003209,2e-05,0.0,0.0,0.0,0.001083,0.0,0.0,0.642263,0.11255,0.694363,0.0,0.0,9.1e-05,0.0
80,1.8813,1.798102,0.121482,0.168806,0.689422,,0.704669,0.924858,0.0,0.001656,0.004294,,0.0,0.0,0.0,0.879187,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.772755,0.0,0.002765,0.0,0.0,,0.0,0.0,0.0,0.0,0.939929,0.296693,0.874989,0.0,0.0,0.0,0.0,,0.451567,0.705295,0.0,0.001655,0.004259,,0.0,0.0,0.0,0.529135,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.533065,0.0,0.002746,0.0,0.0,,0.0,0.0,0.0,0.0,0.692694,0.245449,0.721549,0.0,0.0,0.0,0.0
100,1.9248,1.705227,0.123429,0.173524,0.693114,,0.803303,0.888598,0.000258,0.007243,0.004888,,0.0,0.0,0.0,0.878016,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.774209,0.0,0.000171,0.0,0.0,,0.0,0.0,0.0,0.0,0.951154,0.375808,0.869109,0.0,0.0,0.0,0.0,,0.444882,0.735195,0.000258,0.007237,0.004845,,0.0,0.0,0.0,0.528566,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.542223,0.0,0.00017,0.0,0.0,,0.0,0.0,0.0,0.0,0.693268,0.304217,0.688867,0.0,0.0,0.0,0.0
120,1.8037,1.616712,0.130274,0.177558,0.705256,,0.773595,0.920598,0.0,0.006601,0.00433,,0.0,0.0,0.0,0.881188,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.764255,0.0,0.000304,0.0,0.0,,0.0,0.0,0.0,0.0,0.956899,0.515889,0.858188,0.0,0.0,0.0,0.0,,0.469071,0.733433,0.0,0.006597,0.004305,,0.0,0.0,0.0,0.552361,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.560057,0.0,0.000302,0.0,0.0,,0.0,0.0,0.0,0.0,0.695135,0.393204,0.754296,0.0,0.0,0.0,0.0
140,2.1306,1.569754,0.13562,0.184112,0.713675,,0.693539,0.939396,7.2e-05,0.07097,0.012644,,0.0,0.0,0.0,0.915231,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.842609,0.0,0.000144,0.0,0.0,,0.0,0.0,0.0,0.0,0.923252,0.587083,0.906643,0.0,0.0,0.0,0.0,,0.487303,0.715979,7.2e-05,0.070696,0.012479,,0.0,0.0,0.0,0.522704,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.559094,0.0,0.000144,0.0,0.0,,0.0,0.0,0.0,0.0,0.750228,0.489031,0.732111,0.0,0.0,0.0,0.0
160,1.4232,1.464143,0.141076,0.186501,0.723536,,0.790765,0.935636,0.0,0.08125,0.004991,,0.0,0.0,0.0,0.896222,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.824478,0.0,0.000781,0.0,0.0,,0.0,0.0,0.0,0.0,0.934803,0.643812,0.855292,0.0,0.0,0.0,0.0,,0.49777,0.735782,0.0,0.080878,0.00497,,0.0,0.0,0.0,0.562398,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.576034,0.0,0.000779,0.0,0.0,,0.0,0.0,0.0,0.0,0.747289,0.531261,0.777261,0.0,0.0,0.0,0.0
180,1.6526,1.464034,0.146169,0.195837,0.724751,,0.864107,0.883223,0.0,0.293588,0.005661,,0.0,0.0,0.0,0.928574,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.811248,0.0,0.000212,0.0,0.0,,0.0,0.0,0.0,0.0,0.925141,0.636817,0.91822,0.0,0.0,0.0,0.0,,0.47229,0.756175,0.0,0.285185,0.005631,,0.0,0.0,0.0,0.527402,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.568967,0.0,0.000212,0.0,0.0,,0.0,0.0,0.0,0.0,0.755407,0.544859,0.761276,0.0,0.0,0.0,0.0
200,2.0379,1.35479,0.151914,0.193911,0.735647,,0.74421,0.950613,0.0,0.290935,0.007878,,0.0,0.0,0.0,0.812206,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.931743,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.887568,0.701789,0.878212,0.0,0.0,0.0,0.0,,0.546592,0.733262,0.0,0.280545,0.007849,,0.0,0.0,0.0,0.605184,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.533929,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.767236,0.576146,0.810507,0.0,0.0,0.0,0.0


  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
  iou = total_area_intersect / total_are

TrainOutput(global_step=400, training_loss=1.783346957564354, metrics={'train_runtime': 843.0153, 'train_samples_per_second': 1.898, 'train_steps_per_second': 0.474, 'total_flos': 2.81087582404608e+16, 'train_loss': 1.783346957564354, 'epoch': 2.0})