# Acknowledgment

This implementation is fully based on the following code:

- code: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb
- Author: NielsRogge

# Description
- Task: semantic segmentation
- Dataset: ADE20K
- Model: DINOv2-base
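- Evaluation: ms (the `*_ms` variant used throughout this notebook: the linear head is fed the concatenation of the backbone's last four hidden layers)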

# Installations and Imports

In [1]:
# Flip Install to True to pip-install the third-party packages imported below.
# The `!pip3` lines are IPython shell escapes, so this cell only runs inside a notebook.
Install = False
if Install:
    !pip3 install evaluate
    !pip3 install transformers
    !pip3 install datasets
    !pip3 install albumentations

In [2]:
import os
# Route Hugging Face traffic through the hf-mirror.com endpoint.
# NOTE(review): presumably this has to be set before the Hugging Face libraries
# are imported in the next cell — confirm.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

In [3]:
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import evaluate
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from datasets import load_dataset
import albumentations as A

# Functions and Classes

In [4]:
def get_dataset(dataset_dir):
    """Load a dataset, print its summary, and hand it back.

    Args:
        dataset_dir: Value forwarded unchanged to ``load_dataset``.

    Returns:
        The object returned by ``load_dataset(dataset_dir)``.
    """
    loaded = load_dataset(dataset_dir)
    print(loaded)
    return loaded

def generate_id2label(num_labels):
    """Build an ``{id: "label<id>"}`` mapping for ids ``0..num_labels`` inclusive.

    Id 0 is always present, so the result holds ``num_labels + 1`` entries
    for any non-negative ``num_labels`` (and just ``{0: "label0"}`` otherwise).

    Args:
        num_labels: Highest id to include.

    Returns:
        dict mapping each integer id to the string ``"label<id>"``.
    """
    mapping = {0: "label0"}
    mapping.update((idx, f"label{idx}") for idx in range(1, num_labels + 1))
    return mapping

def get_transform_ms():
    """Create the albumentations pipelines for training and validation.

    Both pipelines resize to 560x560 and normalize with the same per-channel
    mean/std; the training pipeline additionally applies a horizontal flip
    with probability 0.5.

    Returns:
        Tuple ``(train_transform, val_transform)`` of ``A.Compose`` objects.
    """
    # Per-channel statistics given on a 0-255 scale, rescaled by 1/255.
    mean = np.array([123.675, 116.280, 103.530]) / 255
    std = np.array([58.395, 57.120, 57.375]) / 255

    def _resize():
        return A.Resize(width=560, height=560)

    def _normalize():
        # Fresh lists per call so the two pipelines never share arguments.
        return A.Normalize(mean=mean.tolist(), std=std.tolist())

    train_transform = A.Compose([_resize(), A.HorizontalFlip(p=0.5), _normalize()])
    val_transform = A.Compose([_resize(), _normalize()])

    return train_transform, val_transform

def collate_fn(inputs):
    """Collate a list of 4-element samples into one batch dictionary.

    Elements 0 and 1 of every sample are stacked along a new leading axis;
    elements 2 and 3 are gathered into plain Python lists.

    Args:
        inputs: Sequence of indexable samples
            ``(pixel_values, labels, original_image, original_segmentation_map)``.

    Returns:
        dict with keys ``pixel_values``, ``labels``, ``original_images`` and
        ``original_segmentation_maps``.
    """
    pixel_values = torch.stack([sample[0] for sample in inputs], dim=0)
    labels = torch.stack([sample[1] for sample in inputs], dim=0)

    return {
        "pixel_values": pixel_values,
        "labels": labels,
        "original_images": [sample[2] for sample in inputs],
        "original_segmentation_maps": [sample[3] for sample in inputs],
    }

def train_model(model, train_dataloader, metric, id2label, learning_rate=1e-5, epochs=3, device=None):
    """Train ``model`` with AdamW while the ``dinov2`` parameters stay frozen.

    Every parameter whose name starts with ``"dinov2"`` gets
    ``requires_grad = False``. Predictions of each step are added to
    ``metric``; every 100th batch (including batch 0) the metric is computed,
    printed together with the current loss, and appended to the history.

    Args:
        model: Module called as ``model(pixel_values, labels=labels)`` whose
            output exposes ``.loss`` and ``.logits``.
        train_dataloader: Iterable of batches with ``"pixel_values"`` and
            ``"labels"`` tensors.
        metric: Object exposing ``add_batch(predictions=, references=)`` and
            ``compute(num_labels=, ignore_index=, reduce_labels=)``.
            NOTE(review): whether ``compute`` clears the accumulated batches
            depends on the metric implementation — confirm.
        id2label: Label mapping; only its length is used (as ``num_labels``).
        learning_rate: AdamW learning rate.
        epochs: Number of passes over ``train_dataloader``.
        device: Target device; ``None`` selects CUDA when available, else CPU.

    Returns:
        list of dicts with keys ``epoch``, ``batch_idx``, ``mean_iou`` and
        ``mean_accuracy``, one per logging step.
    """
    # Freeze the backbone: only non-"dinov2" parameters keep their gradients.
    for param_name, parameter in model.named_parameters():
        if not param_name.startswith("dinov2"):
            continue
        parameter.requires_grad = False

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    optimizer = AdamW(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()

    metrics_history = []

    for epoch in range(epochs):
        print("Epoch:", epoch)
        for step, batch in enumerate(tqdm(train_dataloader)):
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)

            # Forward pass; the model computes the loss itself from the labels.
            outputs = model(inputs, labels=targets)
            loss = outputs.loss

            # Backward pass, parameter update, then reset gradients.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Feed the running metric with the per-pixel argmax over the class axis.
            with torch.no_grad():
                predictions = outputs.logits.argmax(dim=1)
                metric.add_batch(
                    predictions=predictions.detach().cpu().numpy(),
                    references=targets.detach().cpu().numpy(),
                )

            # Only log on every 100th batch.
            if step % 100 != 0:
                continue

            metrics = metric.compute(
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=False,
            )
            record = {
                'epoch': epoch,
                'batch_idx': step,
                'mean_iou': metrics['mean_iou'],
                'mean_accuracy': metrics['mean_accuracy'],
            }
            metrics_history.append(record)

            print("Loss:", loss.item())
            print("Mean IOU:", metrics["mean_iou"])
            print("Mean Accuracy:", metrics["mean_accuracy"])
            print("----------------------------------")

    return metrics_history

def val_eval_ms(model, val_dataloader, metric, id2label, device=None):
    """Evaluate ``model`` on ``val_dataloader`` and print mean IoU / mean accuracy.

    The model is switched to eval mode and run without gradient tracking.
    Each batch's argmax predictions are added to ``metric``, which is computed
    once after the whole loader has been consumed. Nothing is returned.

    Args:
        model: Module called as ``model(pixel_values, labels=labels)`` whose
            output exposes ``.logits``.
        val_dataloader: Iterable of batches with ``"pixel_values"`` and
            ``"labels"`` tensors.
        metric: Object exposing ``add_batch`` and ``compute``.
        id2label: Label mapping; only its length is used (as ``num_labels``).
        device: Target device; ``None`` selects CUDA when available, else CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)

            outputs = model(inputs, labels=targets)
            # Per-pixel class decision: argmax over the class axis.
            predictions = outputs.logits.argmax(dim=1)

            metric.add_batch(
                predictions=predictions.detach().cpu().numpy(),
                references=targets.detach().cpu().numpy(),
            )

    final_metrics = metric.compute(
        num_labels=len(id2label),
        ignore_index=0,
        reduce_labels=False,
    )

    print(f"Final Mean IOU: {final_metrics['mean_iou']}")
    print(f"Final Mean Accuracy: {final_metrics['mean_accuracy']}")

In [5]:
class SegmentationDataset(Dataset):
    """Map-style dataset yielding transformed tensors plus the untouched arrays.

    Args:
        dataset: Indexable collection of records supporting ``len()``.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
        feature_img: Record key holding the image.
        feature_seg: Record key holding the segmentation map.
    """

    def __init__(self, dataset, transform, feature_img, feature_seg):
        self.dataset, self.transform = dataset, transform
        self.feature_img, self.feature_seg = feature_img, feature_seg

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, target, original_image, original_segmentation_map)``."""
        record = self.dataset[idx]
        original_image = np.array(record[self.feature_img])
        original_segmentation_map = np.array(record[self.feature_seg])

        augmented = self.transform(image=original_image, mask=original_segmentation_map)
        image = torch.tensor(augmented['image'])
        target = torch.LongTensor(augmented['mask'])

        # A 2-D image has no channel axis: append one and expand it to 3 channels.
        if image.dim() == 2:
            image = image.unsqueeze(2).expand(-1, -1, 3)

        # Move the last (channel) axis to the front.
        channels_first = image.permute(2, 0, 1)

        return channels_first, target, original_image, original_segmentation_map

class LinearClassifier(torch.nn.Module):
    """Per-token linear head implemented as a 1x1 convolution over a token grid.

    Args:
        in_channels: Feature size of every incoming token.
        tokenW: Width of the token grid.
        tokenH: Height of the token grid.
        num_labels: Number of output channels of the convolution.
    """

    def __init__(self, in_channels, tokenW=40, tokenH=40, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, kernel_size=(1, 1))

    def forward(self, embeddings):
        """Reshape tokens to ``(-1, height, width, in_channels)`` and classify them.

        Returns:
            Tensor of shape ``(batch, num_labels, height, width)``.
        """
        grid_shape = (-1, self.height, self.width, self.in_channels)
        # Conv2d wants channels first, so move the feature axis to position 1.
        channels_first = embeddings.reshape(grid_shape).permute(0, 3, 1, 2)

        return self.classifier(channels_first)

class Dinov2ForSemanticSegmentationMS(Dinov2PreTrainedModel):
    """DINOv2 backbone topped with a linear (1x1 convolution) segmentation head.

    The head consumes the patch tokens of the last four hidden states,
    concatenated along the feature axis, which is why it is built with
    ``config.hidden_size * 4`` input channels. It is constructed for a fixed
    40x40 token grid.
    """

    def __init__(self, config):
        super().__init__(config)

        self.dinov2 = Dinov2Model(config)
        self.classifier = LinearClassifier(config.hidden_size * 4, 40, 40, config.num_labels)

    def forward(self, pixel_values, output_hidden_states=True, output_attentions=False, labels=None):
        """Compute per-pixel class logits and, when labels are given, the loss.

        Args:
            pixel_values: Image batch; its last two axes set the size the
                logits are upsampled to.
            output_hidden_states: Whether the backbone hidden states are
                included in the returned output. They are always computed
                internally because the head needs them.
            output_attentions: Forwarded to the backbone.
            labels: Optional class-index targets shaped ``(batch, H, W)``
                (a singleton channel axis, ``(batch, 1, H, W)``, is accepted
                too). Index 0 is ignored by the loss.

        Returns:
            ``SemanticSegmenterOutput`` with ``loss`` (``None`` without
            labels), ``logits``, ``hidden_states`` and ``attentions``.

        Raises:
            ValueError: If the backbone yields a number of patch tokens that
                does not match the head's token grid.
        """
        # The head always needs the hidden states, so request them regardless
        # of what the caller asked to have returned.
        outputs = self.dinov2(pixel_values,
                              output_hidden_states=True,
                              output_attentions=output_attentions)

        # Concatenate the last four layers feature-wise, then drop the first
        # token (the CLS token in Dinov2Model's output) to keep patch tokens only.
        hidden_states = torch.cat(outputs.hidden_states[-4:], dim=-1)
        patch_embeddings = hidden_states[:, 1:, :]

        # Guard the fixed-grid reshape inside the head: a different token count
        # could otherwise be folded silently into the batch axis.
        expected_tokens = self.classifier.height * self.classifier.width
        if patch_embeddings.shape[1] != expected_tokens:
            raise ValueError(
                f"Expected {expected_tokens} patch tokens "
                f"({self.classifier.height}x{self.classifier.width} grid), "
                f"got {patch_embeddings.shape[1]}"
            )

        logits = self.classifier(patch_embeddings)
        logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        loss = None
        if labels is not None:
            # Remove only an explicit singleton channel axis. A blanket
            # squeeze() would also drop the batch axis for batches of size 1
            # and break the (N, C, H, W) / (N, H, W) contract of the loss.
            if labels.dim() == logits.dim():
                labels = labels.squeeze(1)

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
            loss = loss_fct(logits, labels)

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

# Hyperparameters

In [6]:
# generate_id2label always adds id 0 on top of ids 1..num_labels, so id2label
# holds num_labels + 1 entries.
num_labels = 150
id2label = generate_id2label(num_labels)

# Settings consumed by the DataLoaders (batch_size) and train_model (the rest).
batch_size = 16
learning_rate = 1e-5
epochs = 10

# Dataset Processing

In [7]:
# Identifier handed to load_dataset through get_dataset, which also prints the summary below.
dataset_dir = "sezer12138/ADE20k_Segementation"
dataset = get_dataset(dataset_dir)

DatasetDict({
    train: Dataset({
        features: ['image', 'annotated', 'Scene_category'],
        num_rows: 20210
    })
    val: Dataset({
        features: ['image', 'annotated', 'Scene_category'],
        num_rows: 2000
    })
})


In [8]:
# Split names and feature (column) names as listed in the DatasetDict summary above.
train_set = "train"
val_set = "val"
feature_img = "image"
feature_seg = "annotated"

In [9]:
# Wrap each split with its transform pipeline, then batch it with collate_fn.
# Only the training loader shuffles.
train_transform, val_transform = get_transform_ms()
train_dataset = SegmentationDataset(dataset[train_set], transform=train_transform, feature_img=feature_img, feature_seg=feature_seg)
val_dataset = SegmentationDataset(dataset[val_set], transform=val_transform, feature_img=feature_img, feature_seg=feature_seg)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Train and Evaluation

### Base Model

In [10]:
# Load the facebook/dinov2-base checkpoint; the head is sized from len(id2label).
model = Dinov2ForSemanticSegmentationMS.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))

Some weights of Dinov2ForSemanticSegmentationMS were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.classifier.bias', 'classifier.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Separate metric objects for training and validation so their accumulated batches never mix.
metric = evaluate.load("mean_iou")
metric_val = evaluate.load("mean_iou")

In [12]:
# Train the model; results is the list of per-logging-step metric dicts returned by train_model.
results = train_model(model, train_dataloader, metric, id2label, learning_rate=learning_rate, epochs=epochs, device=None)

Epoch: 0


  0%|          | 0/1264 [00:00<?, ?it/s]

  acc = total_area_intersect / total_area_label


Loss: 5.831729412078857
Mean IOU: 0.0008315554997648702
Mean Accuracy: 0.0061604621828927515
----------------------------------
Loss: 5.100131988525391
Mean IOU: 0.0026693451567006998
Mean Accuracy: 0.009756585076266274
----------------------------------
Loss: 4.058493137359619
Mean IOU: 0.006204779582129662
Mean Accuracy: 0.014799875468185977
----------------------------------
Loss: 3.7275755405426025
Mean IOU: 0.013465360998585607
Mean Accuracy: 0.025005634967895568
----------------------------------
Loss: 2.980827808380127
Mean IOU: 0.022096361008635955
Mean Accuracy: 0.037901427100769665
----------------------------------
Loss: 3.0593085289001465
Mean IOU: 0.03080645428479513
Mean Accuracy: 0.04967027211910477
----------------------------------
Loss: 2.2927896976470947
Mean IOU: 0.042041000022931145
Mean Accuracy: 0.062466115351662264
----------------------------------
Loss: 2.497021198272705
Mean IOU: 0.05221973480061944
Mean Accuracy: 0.07414532977767335
-------------------------

  iou = total_area_intersect / total_area_union


Loss: 2.122680425643921
Mean IOU: 0.06304852050078932
Mean Accuracy: 0.08687217815800934
----------------------------------
Loss: 1.8659729957580566
Mean IOU: 0.07237735878845941
Mean Accuracy: 0.09781667765504165
----------------------------------
Loss: 1.9589581489562988
Mean IOU: 0.08154372881825184
Mean Accuracy: 0.10932430793101985
----------------------------------
Loss: 1.7093273401260376
Mean IOU: 0.09206894898392287
Mean Accuracy: 0.12207304604227337
----------------------------------
Loss: 1.590934157371521
Mean IOU: 0.10011680271966333
Mean Accuracy: 0.13045772427733976
----------------------------------
Epoch: 1


  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 2.002190589904785
Mean IOU: 0.10654349777044954
Mean Accuracy: 0.13938728826240684
----------------------------------
Loss: 2.3491148948669434
Mean IOU: 0.11565203719217522
Mean Accuracy: 0.14925586276748687
----------------------------------
Loss: 1.4814821481704712
Mean IOU: 0.1263542612977016
Mean Accuracy: 0.16108003618534558
----------------------------------
Loss: 1.5033583641052246
Mean IOU: 0.1295991962242095
Mean Accuracy: 0.16626829790557265
----------------------------------
Loss: 1.5855528116226196
Mean IOU: 0.14298338105761213
Mean Accuracy: 0.1799435828714534
----------------------------------
Loss: 1.4443718194961548
Mean IOU: 0.15431830336366792
Mean Accuracy: 0.19430993944200495
----------------------------------
Loss: 1.3551511764526367
Mean IOU: 0.16020401035629145
Mean Accuracy: 0.20201119641264595
----------------------------------
Loss: 1.1157556772232056
Mean IOU: 0.16627879249701596
Mean Accuracy: 0.20989290270900493
----------------------------------
Loss

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 1.1402921676635742
Mean IOU: 0.2157005031329137
Mean Accuracy: 0.2696402925332665
----------------------------------
Loss: 1.0918604135513306
Mean IOU: 0.22484996399187204
Mean Accuracy: 0.2811610869350024
----------------------------------
Loss: 1.2253432273864746
Mean IOU: 0.22615526482641518
Mean Accuracy: 0.2832045168832052
----------------------------------
Loss: 1.4216843843460083
Mean IOU: 0.2408445643782442
Mean Accuracy: 0.30284786975943545
----------------------------------
Loss: 1.4370543956756592
Mean IOU: 0.23516192367641703
Mean Accuracy: 0.29486368894467335
----------------------------------
Loss: 1.3368408679962158
Mean IOU: 0.245293548330552
Mean Accuracy: 0.30637786702312847
----------------------------------
Loss: 0.9766488075256348
Mean IOU: 0.2556535235783035
Mean Accuracy: 0.317895493600781
----------------------------------
Loss: 1.2688446044921875
Mean IOU: 0.2549860792241386
Mean Accuracy: 0.3209022605915446
----------------------------------
Loss: 1.0734

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 1.0273890495300293
Mean IOU: 0.28661539153348414
Mean Accuracy: 0.35808546007110825
----------------------------------
Loss: 1.147711992263794
Mean IOU: 0.29185688502331414
Mean Accuracy: 0.3656431578298572
----------------------------------
Loss: 1.2281999588012695
Mean IOU: 0.29944243779869806
Mean Accuracy: 0.37690855795533823
----------------------------------
Loss: 0.9882869720458984
Mean IOU: 0.29840340948793226
Mean Accuracy: 0.37270233414874554
----------------------------------
Loss: 0.9076258540153503
Mean IOU: 0.299458253649164
Mean Accuracy: 0.3736571756898562
----------------------------------
Loss: 1.0382825136184692
Mean IOU: 0.31171577434220743
Mean Accuracy: 0.38826496820387
----------------------------------
Loss: 0.9849675893783569
Mean IOU: 0.3076762752559487
Mean Accuracy: 0.3850586287847981
----------------------------------
Loss: 0.920588493347168
Mean IOU: 0.3113426603532226
Mean Accuracy: 0.39255582087715923
----------------------------------
Loss: 1.1960

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 0.9241369962692261
Mean IOU: 0.32717565549690136
Mean Accuracy: 0.4187631496346474
----------------------------------
Loss: 0.6052500009536743
Mean IOU: 0.327621646975631
Mean Accuracy: 0.40961111024788244
----------------------------------
Loss: 0.9026890993118286
Mean IOU: 0.33001663620715116
Mean Accuracy: 0.41542614855091725
----------------------------------
Loss: 0.7441354393959045
Mean IOU: 0.3320889825740222
Mean Accuracy: 0.4184500392848245
----------------------------------
Loss: 0.8226580619812012
Mean IOU: 0.3381598460600915
Mean Accuracy: 0.42385906777785437
----------------------------------
Loss: 0.9209801554679871
Mean IOU: 0.3332556027475867
Mean Accuracy: 0.4207053714914702
----------------------------------
Loss: 0.7239039540290833
Mean IOU: 0.3428808142218207
Mean Accuracy: 0.4314723122921686
----------------------------------
Loss: 0.8450959920883179
Mean IOU: 0.35064698622205676
Mean Accuracy: 0.4398410049658005
----------------------------------
Loss: 0.977

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 1.0382267236709595
Mean IOU: 0.36221553965310865
Mean Accuracy: 0.4548146770426632
----------------------------------
Loss: 0.8132150173187256
Mean IOU: 0.3591089620382225
Mean Accuracy: 0.4501198022590941
----------------------------------
Loss: 0.8251363635063171
Mean IOU: 0.3629439275421772
Mean Accuracy: 0.4578543294890386
----------------------------------
Loss: 1.193735957145691
Mean IOU: 0.3567679637693436
Mean Accuracy: 0.4488129828352466
----------------------------------
Loss: 1.1221084594726562
Mean IOU: 0.36259045469804
Mean Accuracy: 0.45525957945131024
----------------------------------
Loss: 0.9492073655128479
Mean IOU: 0.3667214573066583
Mean Accuracy: 0.4666785807993854
----------------------------------
Loss: 0.8693735003471375
Mean IOU: 0.3725062286649378
Mean Accuracy: 0.46628961111244993
----------------------------------
Loss: 1.0988085269927979
Mean IOU: 0.3641809284709186
Mean Accuracy: 0.46186671064045093
----------------------------------
Loss: 0.9274651

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 0.8200410008430481
Mean IOU: 0.36641292438375783
Mean Accuracy: 0.4728202017940772
----------------------------------
Loss: 0.8858672976493835
Mean IOU: 0.37592290398500977
Mean Accuracy: 0.47609379507435867
----------------------------------
Loss: 0.6337774395942688
Mean IOU: 0.38212064503829746
Mean Accuracy: 0.4772971548186333
----------------------------------
Loss: 0.7036575078964233
Mean IOU: 0.3719925007550242
Mean Accuracy: 0.4694301947544859
----------------------------------
Loss: 1.0246458053588867
Mean IOU: 0.3713603196440643
Mean Accuracy: 0.4704197814355505
----------------------------------
Loss: 0.7555981874465942
Mean IOU: 0.3955420846761106
Mean Accuracy: 0.4957486922374278
----------------------------------
Loss: 1.0133531093597412
Mean IOU: 0.38266271778108174
Mean Accuracy: 0.48053773597494287
----------------------------------
Loss: 0.7731493711471558
Mean IOU: 0.39363249841806164
Mean Accuracy: 0.4905179431368426
----------------------------------
Loss: 1.1

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 0.6446806192398071
Mean IOU: 0.3768991019566543
Mean Accuracy: 0.4832787055916506
----------------------------------
Loss: 0.7878640294075012
Mean IOU: 0.38998255049388536
Mean Accuracy: 0.4923372362079866
----------------------------------
Loss: 0.9009037017822266
Mean IOU: 0.3868370968274154
Mean Accuracy: 0.4856155962169552
----------------------------------
Loss: 0.9052491784095764
Mean IOU: 0.3858334075518636
Mean Accuracy: 0.48933893526943667
----------------------------------
Loss: 0.6707981824874878
Mean IOU: 0.3894228947263257
Mean Accuracy: 0.48758571403371226
----------------------------------
Loss: 0.7744745016098022
Mean IOU: 0.39192043795709375
Mean Accuracy: 0.4973286926811971
----------------------------------
Loss: 0.7146725058555603
Mean IOU: 0.39925951842388363
Mean Accuracy: 0.5035986245375632
----------------------------------
Loss: 0.8916160464286804
Mean IOU: 0.40548666029814256
Mean Accuracy: 0.5080911564171025
----------------------------------
Loss: 0.71

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 0.7894209027290344
Mean IOU: 0.3995296027079726
Mean Accuracy: 0.5077579535010336
----------------------------------
Loss: 0.7875381112098694
Mean IOU: 0.4101894205446828
Mean Accuracy: 0.5142072176968431
----------------------------------
Loss: 0.617716372013092
Mean IOU: 0.3983185905708091
Mean Accuracy: 0.5028568257841407
----------------------------------
Loss: 0.7780706882476807
Mean IOU: 0.39885536245241543
Mean Accuracy: 0.5050216786171395
----------------------------------
Loss: 0.8695551753044128
Mean IOU: 0.4045222835709344
Mean Accuracy: 0.506004642500338
----------------------------------
Loss: 0.8154429793357849
Mean IOU: 0.4125347384959342
Mean Accuracy: 0.5172993233085864
----------------------------------
Loss: 1.0863122940063477
Mean IOU: 0.4027635058235533
Mean Accuracy: 0.5021646565168699
----------------------------------
Loss: 0.8667376041412354
Mean IOU: 0.4047414093730824
Mean Accuracy: 0.5135460007484243
----------------------------------
Loss: 0.644693195

  0%|          | 0/1264 [00:00<?, ?it/s]

Loss: 0.8058062791824341
Mean IOU: 0.39488048747604665
Mean Accuracy: 0.5030590294770495
----------------------------------
Loss: 0.9744255542755127
Mean IOU: 0.4134884394402623
Mean Accuracy: 0.5174823595449999
----------------------------------
Loss: 0.6840226054191589
Mean IOU: 0.40485910931917907
Mean Accuracy: 0.508888533401606
----------------------------------
Loss: 0.48600995540618896
Mean IOU: 0.41037242238532723
Mean Accuracy: 0.5210318706286644
----------------------------------
Loss: 0.664333701133728
Mean IOU: 0.4166706961527839
Mean Accuracy: 0.525741565069832
----------------------------------
Loss: 0.678834080696106
Mean IOU: 0.3927997976237352
Mean Accuracy: 0.4994417515498418
----------------------------------
Loss: 0.8244014382362366
Mean IOU: 0.40702757298598274
Mean Accuracy: 0.5174764025936285
----------------------------------
Loss: 0.6373482346534729
Mean IOU: 0.4113615890111823
Mean Accuracy: 0.5203104760335844
----------------------------------
Loss: 0.9373562

In [13]:
# Evaluate the trained model on the validation loader; the final metrics are printed, not returned.
val_eval_ms(model, val_dataloader, metric_val, id2label, device=None)

  0%|          | 0/125 [00:00<?, ?it/s]

Final Mean IOU: 0.40189161446495286
Final Mean Accuracy: 0.5044378567386025


In [14]:
# Persist the training history, one metrics dict per line.
with open("./autodl-tmp/seg_ade20k_base_ms.txt", "w") as file:
    file.writelines(f"{item}\n" for item in results)

### Small Model

In [15]:
# model = Dinov2ForSemanticSegmentationMS.from_pretrained("facebook/dinov2-small", id2label=id2label, num_labels=len(id2label))

In [16]:
#metric = evaluate.load("mean_iou")
#metric_val = evaluate.load("mean_iou")

In [17]:
#results = train_model(model, train_dataloader, metric, id2label, learning_rate=learning_rate, epochs=epochs, device=None)

In [18]:
#val_eval_ms(model, val_dataloader, metric_val, id2label, device=None)

In [19]:
#with open("./autodl-tmp/seg_ade20k_small_ms.txt", "w") as file:
#    for item in results:
#        file.write(f"{item}\n")