
## Code to fine-tune the SegFormer model for soil-quality estimation from multispectral satellite images
### Written by Ayush Talukder

## This code was adapted from the SegFormer fine-tuning notebook at the source URL below
## https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb

## Fine-tune SegFormer on the custom multispectral satellite image dataset

In this notebook we fine-tune [SegformerForSemanticSegmentation](https://huggingface.co/docs/transformers/main/model_doc/segformer#transformers.SegformerForSemanticSegmentation) on our multispectral satellite **semantic segmentation** dataset. In semantic segmentation, the goal of the model is to label each pixel of an image with one of a list of predefined classes.

We load encoder weights pre-trained on ImageNet-1k, for encoders of different sizes, and fine-tune them together with the decoder head, which starts from randomly initialized weights.

In [19]:
## INSTALL ALL THE REQUIRED LIBRARIES
!pip install datasets evaluate

# Transformers from source (or, for the latest release: pip install --upgrade transformers)
!pip install --upgrade git+https://github.com/huggingface/transformers

# PyTorch wheels built for CUDA 11.7 are served from the PyTorch package index
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117


In [20]:
import requests, zipfile, io

from datasets import load_dataset



## PyTorch dataset and dataloaders

We define a [custom PyTorch dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html). Each item of the dataset consists of an image and its corresponding segmentation map.

In [22]:
from torch.utils.data import Dataset
import os
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir, image_processor, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.
        """
        self.root_dir = root_dir
        self.image_processor = image_processor
        self.train = train

        sub_path = "training" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, "images", sub_path)
        self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path)

        # read images
        image_file_names = []
        for root, dirs, files in os.walk(self.img_dir):
          image_file_names.extend(files)
        self.images = sorted(image_file_names)

        # read annotations
        annotation_file_names = []
        for root, dirs, files in os.walk(self.ann_dir):
          annotation_file_names.extend(files)
        self.annotations = sorted(annotation_file_names)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):

        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))

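        # resize + normalize the image and resize the segmentation map to the processor's
        # target size (512x512 by default), reducing the labels if the processor is set to do so
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k, v in encoded_inputs.items():
            encoded_inputs[k].squeeze_(0)  # remove the batch dimension added by the processor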

        return encoded_inputs
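The class above pairs the *i*-th image with the *i*-th annotation after sorting the two file lists independently, so a single extra or missing file silently shifts every later pair. A stricter pairing can be sketched as follows, assuming (as the test paths used later suggest, e.g. `patch_287.jpg` and `patch_287.png`) that an image and its annotation share a file stem. `pair_by_stem` is a hypothetical helper, not part of the original notebook, and unlike `os.walk` it does not descend into subfolders:

```python
import os

def pair_by_stem(img_dir, ann_dir):
    """Return sorted (image_path, annotation_path) pairs matched on file stem.

    Hypothetical helper: assumes an image and its annotation share a stem,
    e.g. patch_287.jpg <-> patch_287.png. Raises if any file is unmatched.
    """
    images = {os.path.splitext(f)[0]: f for f in os.listdir(img_dir)}
    annotations = {os.path.splitext(f)[0]: f for f in os.listdir(ann_dir)}
    unmatched = set(images) ^ set(annotations)  # stems present in only one folder
    if unmatched:
        raise ValueError(f"Unmatched files: {sorted(unmatched)}")
    return [
        (os.path.join(img_dir, images[s]), os.path.join(ann_dir, annotations[s]))
        for s in sorted(images)
    ]
```

Inside `__init__`, `self.pairs = pair_by_stem(self.img_dir, self.ann_dir)` could replace the two `os.walk` loops, with `__getitem__` then opening the two paths in `self.pairs[idx]`.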

Let's initialize the training and validation datasets. Important: we initialize the image processor with `do_reduce_labels=True` (called `reduce_labels` in older releases of Transformers). This follows the ADE20k convention of the original tutorial: ADE20k classes go from 0 to 150, with 0 meaning "background", but we want the labels to go from 0 to 149 and to train the model only on the 150 classes that exclude "background". Hence, all labels are reduced by 1 and 0 is replaced by 255, which is the `ignore_index` of SegFormer's loss function.
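A small NumPy sketch of this label mapping, written only for illustration (it reproduces the rule described above; it is not the image processor's own code):

```python
import numpy as np

def reduce_labels(seg_map):
    """Shift labels down by one and map 0 ("background") to the ignore index 255."""
    seg = np.asarray(seg_map, dtype=np.uint8).copy()
    seg[seg == 0] = 255      # background -> ignore index
    seg = seg - 1            # 1..150 -> 0..149 (255 temporarily becomes 254)
    seg[seg == 254] = 255    # restore the ignore index
    return seg

raw = np.array([[0, 1, 25], [120, 150, 0]])
print(reduce_labels(raw).tolist())  # [[255, 0, 24], [119, 149, 255]]
```

Under this rule a reduced value of 119, like the one seen in the label tensors further down, corresponds to a raw annotation value of 120, and 24 corresponds to 25.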

In [23]:
# PyTorch (not TensorFlow) SegFormer classes are used throughout this notebook
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import requests


[1, 150, 128, 128]

In [24]:
from transformers import SegformerImageProcessor

# https://huggingface.co/docs/transformers/model_doc/segformer

# Load the preprocessed multispectral soil data (originally all in one large image)
root_dir = 'Soil3sat3chanimgscombinedbinsubset2'  # earlier runs: '/notebooks/Soil3sat3chanimgscombined', 'ADE20k_toy_dataset'
image_processor = SegformerImageProcessor(do_reduce_labels=True)



train_dataset = SemanticSegmentationDataset(root_dir=root_dir, image_processor=image_processor)

# Validation split (empty for this subset, see the counts below)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, image_processor=image_processor, train=False)

In [25]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(valid_dataset))

Number of training examples: 395
Number of validation examples: 0
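The validation split of this subset is empty, so no held-out data is available for evaluation. One possible workaround, sketched under the assumption that holding out a random share of the 395 training patches is acceptable, is a reproducible index split. `split_indices` and the 20% fraction are illustrative choices; if neighbouring patches are spatially correlated, a spatially separated split would give a stricter estimate.

```python
import random

def split_indices(n, val_fraction=0.2, seed=42):
    """Shuffle range(n) reproducibly and return (train_indices, val_indices)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: global random state is untouched
    n_val = int(round(n * val_fraction))
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = split_indices(395)
print(len(train_idx), len(val_idx))  # 316 79
```

The index lists can then be wrapped as `torch.utils.data.Subset(train_dataset, train_idx)` and `torch.utils.data.Subset(train_dataset, val_idx)` and passed to `DataLoader` in place of the current datasets.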


Let's inspect the first example:

In [26]:
encoded_inputs = train_dataset[0]

In [27]:
encoded_inputs["pixel_values"].shape

torch.Size([3, 512, 512])

In [28]:
encoded_inputs["labels"].shape

torch.Size([512, 512])

In [29]:
encoded_inputs["labels"]

tensor([[119, 119, 119,  ..., 119, 119, 119],
        [119, 119, 119,  ..., 119, 119, 119],
        [119, 119, 119,  ..., 119, 119, 119],
        ...,
        [119, 119, 119,  ..., 119, 119, 119],
        [119, 119, 119,  ..., 119, 119, 119],
        [119, 119, 119,  ..., 119, 119, 119]])

In [30]:
encoded_inputs["labels"].squeeze().unique()

tensor([ 24, 119])

Next, we define corresponding dataloaders.

In [31]:
from torch.utils.data import DataLoader

BATCHSIZE = 4  # other values tried: 16, 2
train_dataloader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCHSIZE)

In [32]:
batch = next(iter(train_dataloader))

In [33]:
for k,v in batch.items():
  print(k, v.shape)

pixel_values torch.Size([4, 3, 512, 512])
labels torch.Size([4, 512, 512])


In [34]:
batch["labels"].shape

torch.Size([4, 512, 512])

In [35]:
mask = (batch["labels"] != 255)
mask

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [36]:
batch["labels"][mask]

tensor([119, 119, 119,  ..., 119, 119, 119])
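The mask matters because pixels labelled 255 are excluded from training: SegFormer's loss is a cross-entropy with `ignore_index=255` (the model config's `semantic_loss_ignore_index`), computed after upsampling the logits to the label size. The averaging over valid pixels only can be illustrated with a small NumPy re-implementation; this is a sketch for intuition, not the library's code:

```python
import numpy as np

def masked_cross_entropy(logits, labels, ignore_index=255):
    """Mean pixel-wise cross-entropy, skipping pixels whose label is ignore_index.

    logits: (num_classes, H, W) raw scores; labels: (H, W) integer class ids.
    """
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    rows, cols = np.nonzero(labels != ignore_index)
    # log-probability of the true class at every valid pixel
    picked = log_probs[labels[rows, cols], rows, cols]
    return -picked.mean()

logits = np.zeros((3, 2, 2))           # uniform scores over 3 classes
labels = np.array([[0, 1], [255, 2]])  # one ignored pixel
print(round(masked_cross_entropy(logits, labels), 4))  # 1.0986 (= ln 3)
```

Changing the logits at the ignored pixel leaves the loss unchanged, which is exactly why masked pixels contribute nothing to the gradients.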

## Select and define the model before fine-tuning

Here we load the model and equip the encoder with weights pre-trained on ImageNet-1k. We evaluate three variants from the [hub](https://huggingface.co/models?other=segformer): `nvidia/mit-b0`, `nvidia/mit-b3` and `nvidia/mit-b5` (the cell below loads `nvidia/mit-b3`). We also set the `id2label` and `label2id` mappings, which will be useful when performing inference. These mappings are taken from ADE20k, so the model has 150 output classes.

In [37]:
from transformers import SegformerForSemanticSegmentation
import json
from huggingface_hub import hf_hub_download

# load id2label mapping from a JSON on the hub (ADE20k's 150 classes)
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
with open(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"), "r") as f:
    id2label = json.load(f)
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

# define model (other variants tried: "nvidia/mit-b0", "nvidia/mit-b5")
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b3",
    num_labels=150,
    id2label=id2label,
    label2id=label2id,
)


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Fine-tune the model

Here we fine-tune the model in native PyTorch, using the AdamW optimizer. We use the same learning rate as the one reported in the [paper](https://arxiv.org/abs/2105.15203).
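The loop below keeps the learning rate constant at 6e-5. The SegFormer paper pairs this initial rate with a "poly" decay schedule (power 1.0); a minimal sketch of that formula, with `max_steps` and `power` as its only parameters, looks like this:

```python
def poly_lr(step, max_steps, base_lr=6e-5, power=1.0):
    """'Poly' decay: base_lr * (1 - step / max_steps) ** power, reaching 0 at max_steps."""
    step = min(step, max_steps)  # keep the rate at 0 after the last step
    return base_lr * (1 - step / max_steps) ** power

# e.g. 75 epochs x 99 batches per epoch, the settings of the training run below
max_steps = 75 * 99
for step in (0, max_steps // 2, max_steps):
    print(step, poly_lr(step, max_steps))
```

To apply it, one option is `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: (1 - min(s, max_steps) / max_steps) ** 1.0)`, calling `scheduler.step()` after every `optimizer.step()`. `transformers.get_polynomial_decay_schedule_with_warmup` offers a ready-made variant with warmup that decays to a small `lr_end` rather than exactly 0.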

It's also very useful to track metrics during training. For semantic segmentation, typical metrics include the mean intersection-over-union (mIoU) and pixel-wise accuracy. These are available in the Evaluate library, from which we can load them as follows:

In [38]:
import evaluate

metric = evaluate.load("mean_iou")


In [39]:
image_processor.do_reduce_labels

True
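To make the logged numbers concrete, here is a small NumPy sketch of the two quantities, assuming the usual definitions used by the `mean_iou` metric: per-class IoU is intersection over union, and per-class accuracy is correctly predicted pixels over ground-truth pixels, each averaged over classes with `nanmean` after dropping pixels labelled `ignore_index`. Classes that never occur yield NaN (with 150 classes most are absent from any batch, and the metric's vectorized division can emit RuntimeWarnings for them). `mean_accuracy` is therefore a class average rather than the overall pixel accuracy, which the metric reports separately as `overall_accuracy`. `mean_iou_and_accuracy` is a hypothetical helper:

```python
import numpy as np

def mean_iou_and_accuracy(pred, label, num_labels, ignore_index=255):
    """Class-averaged IoU and accuracy for integer label maps of equal shape."""
    valid = label != ignore_index
    pred, label = pred[valid], label[valid]
    ious, accs = [], []
    for c in range(num_labels):
        p, t = pred == c, label == c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        ious.append(inter / union if union else np.nan)      # class absent everywhere -> NaN
        accs.append(inter / t.sum() if t.sum() else np.nan)  # class absent from ground truth -> NaN
    return np.nanmean(ious), np.nanmean(accs)

pred = np.array([[0, 0], [1, 1]])
label = np.array([[0, 1], [1, 255]])
miou, macc = mean_iou_and_accuracy(pred, label, num_labels=3)
print(float(miou), float(macc))  # 0.5 0.75
```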

In [None]:
import torch
from torch import nn
from tqdm.notebook import tqdm

losslist = []
IoUlist = []
Acclist = []
# define optimizer (lr = 6e-5, as in the SegFormer paper)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
# move model to GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(75):  # loop over the dataset multiple times (earlier runs: 8 to 200 epochs)
    print("Epoch:", epoch)
    for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # print loss and metrics every 100 batches, computed on the current batch only
        if idx % 100 == 0:
            with torch.no_grad():
                # logits are at 1/4 of the input resolution: upsample before taking the argmax
                upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
                predicted = upsampled_logits.argmax(dim=1)

            # currently using _compute instead of compute
            # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
            metrics = metric._compute(
                predictions=predicted.cpu().numpy(),
                references=labels.cpu().numpy(),
                num_labels=len(id2label),
                ignore_index=255,
                reduce_labels=False,  # the image processor has already reduced the labels
            )

            print("Loss:", loss.item(), "Mean_iou:", metrics["mean_iou"], "Mean accuracy:", metrics["mean_accuracy"])
            losslist.append(loss.item())
            IoUlist.append(metrics["mean_iou"])
            Acclist.append(metrics["mean_accuracy"])

Epoch: 0
Loss: 5.1850128173828125 Mean_iou: 0.00013376001036470767 Mean accuracy: 0.010284337792267548
Epoch: 1
Loss: 1.6385059356689453 Mean_iou: 0.36310013069465596 Mean accuracy: 0.5458416452471003
Epoch: 2
Loss: 0.7878708839416504 Mean_iou: 0.41272356614630706 Mean accuracy: 0.6490026088882825
Epoch: 3
Loss: 0.5493640303611755 Mean_iou: 0.5839314931819564 Mean accuracy: 0.7428850452862733
Epoch: 4
Loss: 0.48193854093551636 Mean_iou: 0.5835414669389427 Mean accuracy: 0.755703847205941
Epoch: 5
Loss: 0.8494604825973511 Mean_iou: 0.3109601096552831 Mean accuracy: 0.5825065409558076
Epoch: 6
Loss: 0.572008490562439 Mean_iou: 0.5346720529479377 Mean accuracy: 0.6913431041684456
Epoch: 7
Loss: 0.600440502166748 Mean_iou: 0.5420757812646227 Mean accuracy: 0.7023387751037853
Epoch: 8
Loss: 0.5206790566444397 Mean_iou: 0.6109182999898238 Mean accuracy: 0.7574637005457272
Epoch: 9
Loss: 0.49237722158432007 Mean_iou: 0.42811516999803967 Mean accuracy: 0.6078103826051617
Epoch: 10
Loss: 0.6627342104911804 Mean_iou: 0.4024880483252158 Mean accuracy: 0.6063443518704434
Epoch: 11
Loss: 0.4749429523944855 Mean_iou: 0.5435412584589141 Mean accuracy: 0.6930484494476892
Epoch: 12
Loss: 0.40914788842201233 Mean_iou: 0.5965445759380681 Mean accuracy: 0.7434167991604589
Epoch: 13
Loss: 0.4303412437438965 Mean_iou: 0.6300061895644302 Mean accuracy: 0.8050924473516086
Epoch: 14
Loss: 0.41725048422813416 Mean_iou: 0.649887648935886 Mean accuracy: 0.7807726079584711
Epoch: 15
Loss: 0.5796909332275391 Mean_iou: 0.5763932146063334 Mean accuracy: 0.7320773833815355
Epoch: 16
Loss: 0.5765277147293091 Mean_iou: 0.5418630991043436 Mean accuracy: 0.7067622374835694
Epoch: 17
Loss: 0.22436192631721497 Mean_iou: 0.7708295881599767 Mean accuracy: 0.877206691519611
Epoch: 18
Loss: 0.3113657832145691 Mean_iou: 0.7711920512540329 Mean accuracy: 0.8726008779796751
Epoch: 19
Loss: 0.3466954529285431 Mean_iou: 0.5684338695071558 Mean accuracy: 0.667308084000477
Epoch: 20
Loss: 0.39616140723228455 Mean_iou: 0.6422415989321133 Mean accuracy: 0.758891806531107
Epoch: 21
Loss: 0.5270740389823914 Mean_iou: 0.5753536118055398 Mean accuracy: 0.7277966673364309
Epoch: 22
Loss: 0.3220931589603424 Mean_iou: 0.7271143397687662 Mean accuracy: 0.8767545479433846
Epoch: 23
Loss: 0.42756617069244385 Mean_iou: 0.5806766836996258 Mean accuracy: 0.7086879713160876
Epoch: 24
Loss: 0.5580666065216064 Mean_iou: 0.5145672180163454 Mean accuracy: 0.6538616029568541
Epoch: 25
Loss: 0.2400899976491928 Mean_iou: 0.7929197338896865 Mean accuracy: 0.9054526321084485
Epoch: 26
Loss: 0.42239272594451904 Mean_iou: 0.6195745293932169 Mean accuracy: 0.748728975710544
Epoch: 27
Loss: 0.31949087977409363 Mean_iou: 0.7404783553042873 Mean accuracy: 0.8569531004182558
Epoch: 28
Loss: 0.35113969445228577 Mean_iou: 0.715332593568641 Mean accuracy: 0.8369763344012254
Epoch: 29
Loss: 0.208771213889122 Mean_iou: 0.822878375131728 Mean accuracy: 0.91903929449139
Epoch: 30
Loss: 0.48001569509506226 Mean_iou: 0.6120812194508443 Mean accuracy: 0.7526444162521082
Epoch: 31
Loss: 0.3041592240333557 Mean_iou: 0.7434021343199045 Mean accuracy: 0.8398322101766693
Epoch: 32
Loss: 0.32369667291641235 Mean_iou: 0.7417807884135469 Mean accuracy: 0.8541959723135264
Epoch: 33
Loss: 0.3158841133117676 Mean_iou: 0.7357438557059977 Mean accuracy: 0.8526961259935417
Epoch: 34
Loss: 0.3838197886943817 Mean_iou: 0.6954366977636882 Mean accuracy: 0.8202070418357568
Epoch: 35
Loss: 0.47195371985435486 Mean_iou: 0.6249534430722052 Mean accuracy: 0.7764612439009492
Epoch: 36
Loss: 0.35369792580604553 Mean_iou: 0.7105977631948419 Mean accuracy: 0.834408584270941
Epoch: 37
Loss: 0.39849215745925903 Mean_iou: 0.6495358855588503 Mean accuracy: 0.7749611564352277
Epoch: 38
Loss: 0.4165491461753845 Mean_iou: 0.5815552792618889 Mean accuracy: 0.7059226770277793
Epoch: 39
Loss: 0.33976811170578003 Mean_iou: 0.7202795161045854 Mean accuracy: 0.8447560483706524
Epoch: 40
Loss: 0.31605011224746704 Mean_iou: 0.74352966926708 Mean accuracy: 0.852987793125074
Epoch: 41
Loss: 0.47904613614082336 Mean_iou: 0.6397908196826458 Mean accuracy: 0.7809793734615493
Epoch: 42
Loss: 0.41146019101142883 Mean_iou: 0.6695008523290219 Mean accuracy: 0.8066666367474011
Epoch: 43
Loss: 0.36432239413261414 Mean_iou: 0.6964168401686521 Mean accuracy: 0.8331198328717333
Epoch: 44
Loss: 0.46296584606170654 Mean_iou: 0.6043694756754249 Mean accuracy: 0.7546966609365702
Epoch: 45
Loss: 0.3265584409236908 Mean_iou: 0.6624281932975703 Mean accuracy: 0.776978993080444
Epoch: 46
Loss: 0.2166537195444107 Mean_iou: 0.8160937354990822 Mean accuracy: 0.917729228673023
Epoch: 47
Loss: 0.22918489575386047 Mean_iou: 0.7473351251702491 Mean accuracy: 0.8270768018170296
Epoch: 48
Loss: 0.4281662702560425 Mean_iou: 0.6443683230689996 Mean accuracy: 0.7828778521692539
Epoch: 49
Loss: 0.40397995710372925 Mean_iou: 0.6119830146228957 Mean accuracy: 0.731424318998349
Epoch: 50
Loss: 0.4031391143798828 Mean_iou: 0.6687161184088993 Mean accuracy: 0.8079544121585756
Epoch: 51
Loss: 0.2803560495376587 Mean_iou: 0.7663661189813253 Mean accuracy: 0.881274232647068
Epoch: 52
Loss: 0.3686921298503876 Mean_iou: 0.7042852892102522 Mean accuracy: 0.8228874522187195
Epoch: 53
Loss: 0.3618180453777313 Mean_iou: 0.6781638606319222 Mean accuracy: 0.7888632922040328
Epoch: 54
Loss: 0.5154937505722046 Mean_iou: 0.5897283513674529 Mean accuracy: 0.7463716542208774
Epoch: 55
Loss: 0.48292461037635803 Mean_iou: 0.6206822855041731 Mean accuracy: 0.7656229991170143
Epoch: 56
Loss: 0.38455623388290405 Mean_iou: 0.6795794194530111 Mean accuracy: 0.7987200965672321
Epoch: 57
Loss: 0.16405479609966278 Mean_iou: 0.8175828831299298 Mean accuracy: 0.9325362776004129


In [None]:
print(len(IoUlist))

print(IoUlist)

In [None]:
import matplotlib.pyplot as plt

# one value per logging step (every 100 batches, i.e. once per epoch with 99 batches)
IoU = IoUlist
epochs = range(1, len(IoU) + 1)
plt.plot(epochs, IoU, 'y', label='Mean IoU')
plt.title('Mean IoU')
plt.xlabel('Epochs')
plt.ylabel('Mean IoU')
plt.legend()
plt.show()
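Each point in these curves comes from a single batch of four images (metrics are logged every 100 batches, which with 99 batches per epoch means only the first batch of each epoch), so the curves are noisy. A trailing moving average can make the trend easier to read; the sketch below is one simple option, and the window of 5 points is an arbitrary choice:

```python
import numpy as np

def moving_average(values, window=5):
    """Trailing moving average; early points average over the values available so far."""
    values = np.asarray(values, dtype=float)
    return np.array([values[max(0, i - window + 1): i + 1].mean() for i in range(len(values))])

print(moving_average([1, 2, 3, 4, 5, 6], window=3).tolist())  # [1.0, 1.5, 2.0, 3.0, 4.0, 5.0]
```

For example, `plt.plot(epochs, moving_average(IoU), 'k--', label='Mean IoU (5-point average)')` could be added before `plt.legend()` in the plotting cell.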

In [None]:
Acc = Acclist
epochs = range(1, len(Acc) + 1)
plt.plot(epochs, Acc, 'r-', label='Mean Accuracy')
plt.title('Mean Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Mean Accuracy')
plt.legend()
plt.show()

In [None]:
loss = losslist
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'b-', label='Training loss')
plt.title('Training loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
import pandas as pd

results = {'Mean IoU': IoUlist, 'Training Loss': losslist, 'Accuracy': Acclist}  # avoid shadowing the built-in dict
df = pd.DataFrame(results)

# saving the dataframe; files from earlier runs:
#df.to_csv('Result80Epochs-batch4_CSV_segsoil.csv')  # "nvidia/mit-b0"
#df.to_csv('Result25Epochs-batch4_CSV_segsoil_mitb3.csv')
#df.to_csv('Result36Epochs-run2_batch4_CSV_segsoil_mitb5.csv')
#df.to_csv('Result40Epochs-run4_batch4_CSV_segsoil_mitb5.csv')  # run on Apr 22
#df.to_csv('Result50Epochs-run4_batch4_CSV_segsoil_mitb3.csv')
#df.to_csv('Result50Epochs-run4_batch4_CSV_segsoil_mitb0.csv')
#df.to_csv('Result100Epochs-run4_batch4_CSV_segsoil_mitb0.csv')

# derive the file name from the run itself so it matches the checkpoint and epoch count actually used
model_tag = model.config.name_or_path.split('/')[-1].replace('-', '')  # "nvidia/mit-b3" -> "mitb3"
df.to_csv(f'Result{epoch + 1}Epochs-run4_batch{BATCHSIZE}_CSV_segsoil_{model_tag}.csv')

## Inference

Finally, let's check whether the model has really learned something.

Let's test the trained model on an image (see Niels Rogge's [inference notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Segformer_inference_notebook.ipynb) for details):

In [None]:

# Test the fine-tuned model on a specific 512x512 patch of the test dataset
import os
from PIL import Image

imgnum = 287  # (good)

image = Image.open(os.path.join(root_dir, 'images', 'test', f'patch_{imgnum}.jpg'))
labelorig = Image.open(os.path.join(root_dir, 'annotations', 'test', f'patch_{imgnum}.png'))



In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# prepare the image for the model
pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

# forward pass in evaluation mode
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# logits are of shape (batch_size, num_labels, height/4, width/4);
# post-processing upsamples them to the original image size and takes the per-pixel argmax
predicted_segmentation_map = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
# predictions are reduced labels (raw annotation value - 1)
imgpredlabel = predicted_segmentation_map.cpu().numpy().astype(np.uint8)

fig, axs = plt.subplots(1, 2)
axs[0].imshow(imgpredlabel, cmap='gray')
axs[1].imshow(labelorig, cmap='gray')

# Add titles to the subplots
axs[0].set_title('Predicted Soil Quality')
axs[1].set_title('Ground-Truth Soil Quality')

# Show the plot
plt.show()


In [None]:
# Display the input, ground truth and prediction side-by-side
fig, axs = plt.subplots(1, 3)

axs[0].imshow(image)
axs[1].imshow(labelorig, cmap='gray')
axs[2].imshow(imgpredlabel, cmap='gray')

# Add titles to the subplots
axs[0].set_title('Input Image (Pseudocolor)', fontsize=9)
axs[1].set_title('Ground-Truth Soil Quality', fontsize=9)
axs[2].set_title('Predicted Soil Quality', fontsize=9)

# hide tick marks and labels on all panels
for ax in axs:
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

# Show the plot
plt.show()

In [None]:
# Display the images side-by-side
fig, axs = plt.subplots(1, 3)

# cmap and norm apply only to single-channel data: imshow ignores them for a
# 3-channel (RGB) image, so the first panel looks the same as in the previous cell
norm = plt.Normalize(vmin=0, vmax=1)
cmap = plt.cm.jet
axs[0].imshow(image, cmap=cmap, norm=norm)
axs[1].imshow(labelorig, cmap='gray')
axs[2].imshow(imgpredlabel, cmap='gray')

# Add titles to the subplots
axs[0].set_title('Input Image (Pseudocolor)', fontsize=9)
axs[1].set_title('Ground-Truth Soil Quality', fontsize=9)
axs[2].set_title('Predicted Soil Quality', fontsize=9)

# hide tick marks and labels on all panels
for ax in axs:
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

# Show the plot
plt.show()



In [None]:
import numpy as np
import matplotlib.pyplot as plt

img = np.array(image)

# Contrast-stretch each channel: subtract the mean and divide by 1.3 x the standard deviation.
# A float array is required; np.zeros_like(image) would be uint8, which corrupts negative
# and fractional values on assignment.
img_norm = np.zeros(img.shape, dtype=np.float32)
for i in range(3):
    img_norm[:, :, i] = (img[:, :, i] - img[:, :, i].mean()) / (1.3 * img[:, :, i].std())
    # alternative min-max scaling to [0, 1]:
    # img_norm[:, :, i] = (img[:, :, i] - img[:, :, i].min()) / (img[:, :, i].max() - img[:, :, i].min())

# imshow expects float RGB values in [0, 1]; clip explicitly rather than relying on its own clipping
img_norm = np.clip(img_norm, 0, 1)

fig, axs = plt.subplots(1, 3)
axs[0].imshow(img_norm)  # a cmap/norm would be ignored for RGB data
axs[1].imshow(labelorig, cmap='gray')
axs[2].imshow(imgpredlabel, cmap='gray')

# Add titles to the subplots
axs[0].set_title('Input Image (Pseudocolor)', fontsize=9)
axs[1].set_title('Ground-Truth Soil Quality', fontsize=9)
axs[2].set_title('Predicted Soil Quality', fontsize=9)

# hide tick marks and labels on all panels
for ax in axs:
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

# Show the plot
plt.show()

Compare this to the ground truth segmentation map: