[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DiTo97/binarization-segformer/blob/main/fine-tuning.ipynb)

# Fine-tuning Segformer for Document Image Binarization

A notebook by F. Minutoli ([@DiTo97](https://github.com/DiTo97)) that fine-tunes a Segformer model for document image binarization

In [None]:
# Pinned versions of the third-party packages this notebook relies on,
# joined into a single space-separated string for the pip command below
requirements = " ".join([
    "accelerate==0.18.0",
    "albumentations==1.3.0",
    "bitsandbytes==0.38.1",
    "datasets==2.11.0",
    "evaluate==0.4.0",
    "huggingface-hub==0.13.4",
    "transformers==4.27.4"
])

# IPython shell escapes: upgrade pip first, then install the pinned
# requirements (`$requirements` is expanded from the Python variable above)
!python -m pip install --upgrade pip
!python -m pip install $requirements

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Logs in to the Hugging Face Hub; the processor and the trained model are
# pushed to the Hub further below
notebook_login()

## 1. Dataset

In [None]:
# Clones the SauvolaNet repository; its `SauvolaDocBin` and `Dataset` folders
# are read by the dataset-loading cell below
!git clone https://github.com/Leedeng/SauvolaNet.git

In [None]:
import pathlib
import sys
import typing
from typing import Any

import datasets
import numpy as np
import numpy.typing as nptyping
from PIL import Image
from tqdm.auto import tqdm

In [None]:
def normalize(bitmap: Image.Image) -> Image.Image:
    """It turns a bitmap into a boolean label image

    The bitmap is converted to 8-bit grayscale and every pixel that is strictly
    darker than the brightest pixel of the image is marked as `True`; all the
    other pixels are marked as `False`.
    """
    grayscale = np.asarray(bitmap.convert("L"), dtype=np.uint8)

    # The comparison already yields the boolean mask that is stored
    darker_than_max = grayscale < grayscale.max()

    return Image.fromarray(darker_than_max)

In [None]:
def prepare_examples(
    batch: typing.Dict[str, typing.List[Any]]
) -> typing.Dict[str, typing.List[Any]]:
    """It prepares a batch of examples for semantic segmentation

    The `source` and `target` columns hold image paths. Each target is opened
    and normalized into a label map, while each source is opened as is.
    """
    source_paths = batch["source"]
    target_paths = batch["target"]

    # Targets are processed before sources, one image at a time
    labelmaps = list(map(normalize, map(Image.open, target_paths)))
    pixelmaps = list(map(Image.open, source_paths))

    return {"labelmap": labelmaps, "pixelmap": pixelmaps}

In [None]:
# Folders inside the cloned SauvolaNet repository: the helper code and the data
sauvolanet_src = "SauvolaNet/SauvolaDocBin"
sauvolanet_dataset = "SauvolaNet/Dataset"

# `dataUtils` is imported from the cloned repository, so its folder is put on
# `sys.path` only for the duration of the import and of the call
sys.path.insert(0, sauvolanet_src)
from dataUtils import collect_binarization_by_dataset

# The loop below consumes this as a mapping from a dataset name to a sequence
# of (source, target) pairs
collection = collect_binarization_by_dataset(sauvolanet_dataset)

sys.path.remove(sauvolanet_src)

# It drops the temporary names from the notebook namespace
del sauvolanet_src
del sauvolanet_dataset
del collect_binarization_by_dataset

# Schema of the intermediate dataset that only stores file paths
features = datasets.Features({
    "ensemble": datasets.Value("string"),
    "source": datasets.Value("string"),
    "target": datasets.Value("string"),
})

for name, examples in tqdm(collection.items(), desc="Loading datasets"):
    # The (source, target) pairs are sorted jointly, so every source stays
    # aligned with its own target. Sorting the two columns independently only
    # keeps the pairing when both file names happen to sort the same way.
    examples = sorted(examples)

    sources = [source for source, _ in examples]
    targets = [target for _, target in examples]

    dataset = {"source": sources, "target": targets, "ensemble": [name] * len(sources)}
    dataset = datasets.Dataset.from_dict(dataset, features)

    # Only the value of an existing key is replaced, which is safe while iterating
    collection[name] = dataset

# It merges the per-ensemble datasets into a single dataset
collection = datasets.concatenate_datasets(list(collection.values()))

# Schema after the paths are replaced by the decoded images
features = datasets.Features({
    "ensemble": datasets.Value("string"),
    "labelmap": datasets.Image(),
    "pixelmap": datasets.Image(),
})

# It loads the images in batches and drops the path columns
collection = collection.map(
    prepare_examples, 
    batched=True,
    features=features, 
    remove_columns=["source", "target"]
)

# The ensemble name becomes a class label, which the stratified split below relies on
collection = collection.class_encode_column("ensemble")

del features

# A seeded 75/25 split, stratified by the originating ensemble
collection = collection.train_test_split(
    seed=10,
    shuffle=True,
    stratify_by_column="ensemble",
    train_size=0.75
)

train_dataset = collection["train"]
test_dataset  = collection[ "test"]

del collection

In [None]:
# The segmentation classes; the position in the list is the class id
labels = ["background", "text"]
num_labels = len(labels)

# Look-up tables between class ids and class names, in both directions
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

del labels

## 2. Augmentation

In [None]:
import albumentations
import cv2
import transformers
from transformers import set_seed

In [None]:
# Fixed seed; the same value is used for the dataset split and the trainer
set_seed(10)

In [None]:
# The pretrained checkpoint that is fine-tuned and the input size used for it
base_model_name = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
base_model_size = {"height": 640, "width": 640}

# The processor of the base checkpoint, with its target size overridden in place
processor = transformers.SegformerImageProcessor.from_pretrained(base_model_name)
processor.size.update(base_model_size)

In [None]:
# Hyperparameters of the augmentation pipelines defined below
FLAGS = {
    # The general kwargs
    "border_mode": cv2.BORDER_CONSTANT,
    "fill_value": 255,
    "mask_fill_value": 0,
    "proba": 0.1,

    # The color kwargs
    "brightness": 0.25, 
    "contrast": 0.25, 
    "saturation": 0.25, 
    "hue": 0.1,
    
    # The crop kwargs
    "min_height": processor.size["height"],
    "min_width" : processor.size[ "width"],
    
    # The geometric kwargs
    "rotate": (-90, 90),
    "translate_percent": 0.1
}

# Color-only augmentation; `train_transform` applies it to the image alone
transform1 = albumentations.Compose([
    albumentations.ColorJitter(
        brightness=FLAGS["brightness"], 
        contrast=FLAGS["contrast"], 
        saturation=FLAGS["saturation"], 
        hue=FLAGS["hue"]
    )
])

# Geometric augmentation; `train_transform` applies it to the image and its
# mask together. Exposed borders are filled with `fill_value` in the image and
# with `mask_fill_value` in the mask.
transform2 = albumentations.Compose([
    albumentations.Flip(p=FLAGS["proba"]),
    albumentations.Affine(
        p=FLAGS["proba"],
        cval=FLAGS["fill_value"],
        cval_mask=FLAGS["mask_fill_value"],
        mode=FLAGS["border_mode"],
        rotate=FLAGS["rotate"], 
        translate_percent=FLAGS["translate_percent"],
    ),
    # Padding up to the processor size, so that the crop below always fits
    albumentations.PadIfNeeded(
        border_mode=FLAGS["border_mode"],
        mask_value=FLAGS["mask_fill_value"],
        min_height=FLAGS["min_height"], 
        min_width=FLAGS["min_width"], 
        value=FLAGS["fill_value"],
    ),
    albumentations.RandomCrop(
        p=FLAGS["proba"],
        height=FLAGS["min_height"], 
        width=FLAGS["min_width"],
    )
])

def train_transform(
    batch: typing.Dict[str, typing.List[Any]]
) -> transformers.BatchFeature:
    """It augments a training batch and encodes it with the processor

    The color pipeline runs on every image first; the geometric pipeline then
    runs on every (image, mask) pair, so that both stay aligned.
    """
    # First pass: the color augmentation on the RGB pixel arrays
    jittered = []

    for pixelmap in batch["pixelmap"]:
        pixels = np.array(pixelmap.convert("RGB"))
        jittered.append(transform1(image=pixels)["image"])

    masks = [np.array(labelmap).astype(np.uint8) for labelmap in batch["labelmap"]]

    # Second pass: the geometric augmentation on each image with its mask
    images = []
    labels = []

    for pixels, mask in zip(jittered, masks):
        augmented = transform2(image=pixels, mask=mask)

        images.append(augmented["image"])
        labels.append(augmented["mask"])

    return processor(images, labels)

def test_transform(
    batch: typing.Dict[str, typing.List[Any]]
) -> transformers.BatchFeature:
    """It encodes an evaluation batch with the processor, without augmentation"""
    images = []

    for pixelmap in batch["pixelmap"]:
        images.append(pixelmap.convert("RGB"))

    labels = list(batch["labelmap"])

    return processor(images, labels)

In [None]:
# The transforms are attached to the datasets: the training split is
# augmented, the test split is only encoded
train_dataset.set_transform(train_transform)
test_dataset.set_transform(test_transform)

## 3. Training

In [None]:
# import evaluate
import bitsandbytes
import torch
from torch import nn
from transformers.trainer_utils import get_last_checkpoint
from transformers.trainer_pt_utils import get_parameter_names

In [None]:
# Whether a CUDA device is available; stored in FLAGS["fp16"] below
cuda = torch.cuda.is_available()

In [None]:
# Only errors are reported by the libraries from here on
datasets.logging.set_verbosity_error()
# evaluate.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

In [None]:
# Training hyperparameters; this rebinds the FLAGS name used by the
# augmentation cell above.
# NOTE(review): "fp16" is not read by the TrainingArguments below (they pass
# fp16=False) and "metric" is only referenced from commented-out code.
FLAGS = {
    "accumulation_steps": 4,
    "base_model_name": base_model_name,
    "batch_size": 4,
    "fp16": cuda,
    "learning_rate": 5e-5,
    "metric": "dibco",
    "model_name": "binarization-segformer-b3",
    "num_epochs": 50,
    "optimizer": "adamw_torch",
    "scheduler_type": "cosine"
}

In [None]:
# It uploads the processor (with the overridden size) to the Hub repository
# named after the model
processor.push_to_hub(FLAGS["model_name"])

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Tue May 21 21:02:46 2019

@author: VIPlab
"""
import numpy as np
import cv2
import math
from scipy import ndimage as ndi

#predict_img = 'E:\Document-Binarization\DIBCO_metrics\DIBCO_metrics\P03_adotsu.tif'
#predict_img = cv2.imread(predict_img, 0)
#GT_img = 'E:\Document-Binarization\DIBCO_metrics\DIBCO_metrics\P03_GT.tif'
#GT_img = cv2.imread(GT_img, 0)
#predict_img_ = np.copy(predict_img)
#predict_img_ = predict_img_/255
#GT_img_ = np.copy(GT_img)
#GT_img_ = GT_img_/255


# Deletion table of the first thinning sub-iteration in `bwmorph_thin`: one
# boolean per 8-neighbourhood code (256 entries), where True marks the pixel
# for deletion
G123_LUT = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
       0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,
       1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
       0, 0, 0], dtype=bool)

# Deletion table of the second thinning sub-iteration in `bwmorph_thin`,
# indexed in the same way as the table above
G123P_LUT = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0,
       0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
       0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0], dtype=bool)

def bwmorph_thin(image, n_iter=None):
    """Thin a 2-D binary image with two alternating lookup-table passes.

    Every iteration runs two sub-iterations, one per deletion table
    (`G123_LUT`, then `G123P_LUT`). In a sub-iteration each pixel gets an
    8-neighbourhood code by correlating the image with a power-of-two mask,
    and the pixels whose code is flagged in the table are cleared.

    Args:
        image: 2-D array-like holding only 0 and 1 (or booleans). It is not
            modified.
        n_iter: maximum number of iterations, or None to iterate until the
            image stops changing.

    Returns:
        A boolean array of the same shape holding the thinned image.

    Raises:
        ValueError: if `n_iter` is not positive, the input is not 2-D, or the
            input holds values other than 0 and 1.
    """
    # check parameters
    if n_iter is None:
        n = -1
    elif n_iter <= 0:
        raise ValueError('n_iter must be > 0')
    else:
        n = n_iter

    # Validate the converted array rather than `image.flat`, so that any
    # array-like (e.g. nested lists) is accepted and not only ndarrays
    pixels = np.asarray(image)

    if pixels.ndim != 2:
        raise ValueError('2D array required')
    if not np.isin(pixels, (0, 1)).all():
        raise ValueError('Image contains values other than 0 and 1')

    # working copy as uint8, so the caller's data is left untouched
    skel = pixels.astype(np.uint8)

    # neighborhood mask: each neighbour contributes one bit of the code
    mask = np.array([[ 8,  4,  2],
                     [16,  0,  1],
                     [32, 64,128]],dtype=np.uint8)

    # iterate either 1) indefinitely or 2) up to iteration limit
    while n != 0:
        before = np.sum(skel) # count points before thinning

        # for each subiteration
        for lut in [G123_LUT, G123P_LUT]:
            # correlate image with neighborhood mask
            N = ndi.correlate(skel, mask, mode='constant')
            # take deletion decision from this subiteration's LUT
            D = np.take(lut, N)
            # perform deletion
            skel[D] = 0

        after = np.sum(skel) # count points after thinning

        if before == after:
            # iteration had no effect: finish
            break

        # count down to iteration limit (or endlessly negative)
        n -= 1

    return skel.astype(bool)


def Fmeasure(predict_img_,GT_img_):
    """Return the F-measure of a binary prediction against its ground truth.

    A pixel counts as a true positive when it is 0 in both images, so 0 is the
    positive class and 1 the negative one.

    Args:
        predict_img_: array-like prediction holding 0/1 values.
        GT_img_: array-like ground truth of the same shape holding 0/1 values.

    Returns:
        The harmonic mean of precision and recall. A 1e-4 term in every
        denominator keeps the result finite when a count is zero.
    """
    # Counting in float64 avoids the wrap-around that small integer dtypes
    # (e.g. uint8 masks) hit once a count exceeds their range
    pred = np.asarray(predict_img_, dtype=np.float64)
    gt = np.asarray(GT_img_, dtype=np.float64)

    count_tp = np.sum((1 - pred) * (1 - gt))
    count_fp = np.sum((1 - pred) * gt)
    count_fn = np.sum(pred * (1 - gt))

    temp_p = count_tp / (count_fp + count_tp + 1e-4)
    temp_r = count_tp / (count_fn + count_tp + 1e-4)
    temp_f = 2 * (temp_p * temp_r) / (temp_p + temp_r + 1e-4)
    return temp_f

def Psnr(predict_img_,GT_img_):
    """Return the PSNR, in dB, of a binary prediction against its ground truth.

    The error is the fraction of pixels on which the two images disagree and
    the result is `10 * log10(1 / error)`.

    Args:
        predict_img_: 2-D array-like prediction holding 0/1 values.
        GT_img_: 2-D array-like ground truth of the same shape holding 0/1 values.

    Returns:
        The PSNR as a float, or `inf` when the two images are identical.
    """
    pred = np.asarray(predict_img_, dtype=np.float64)
    gt = np.asarray(GT_img_, dtype=np.float64)

    xm = gt.shape[0]
    ym = gt.shape[1]

    # A pixel is wrong when it is a false positive or a false negative.
    # Counting the mismatches directly avoids integer wrap-around on uint8 masks.
    fp_fn = (1 - pred) * gt + pred * (1 - gt)
    err = np.count_nonzero(fp_fn > 0) / (xm * ym)

    if err == 0:
        # No wrong pixel at all: the ratio is unbounded
        return math.inf

    temp_PSNR = 10 * math.log( 1 / err,10)
    return temp_PSNR

def Pfmeasure(predict_img_,GT_img_):
    """Return the pseudo F-measure of a binary prediction against its ground truth.

    Precision is computed against the full ground truth, while recall is
    computed against the ground truth thinned with `bwmorph_thin`. As in
    `Fmeasure`, the value 0 is the positive class.

    Args:
        predict_img_: 2-D array-like prediction holding 0/1 values.
        GT_img_: 2-D array-like ground truth of the same shape holding 0/1 values.

    Returns:
        The harmonic mean of the precision and the pseudo-recall. A 1e-4 term
        in every denominator keeps the result finite when a count is zero.
    """
    # Counting in float64 avoids wrap-around on small integer dtypes
    pred = np.asarray(predict_img_, dtype=np.float64)
    gt = np.asarray(GT_img_, dtype=np.float64)

    # The thinning works on 1-valued pixels, so the ground truth is inverted
    # before the call and the skeleton is inverted back afterwards
    skel_GT = 1 - bwmorph_thin(1 - gt).astype(np.float64)

    # Precision against the full ground truth
    count_tp = np.sum((1 - pred) * (1 - gt))
    count_fp = np.sum((1 - pred) * gt)
    temp_p = count_tp / (count_fp + count_tp + 1e-4)

    # Recall against the thinned ground truth
    count_skl_tp = np.sum((1 - pred) * (1 - skel_GT))
    count_skl_fn = np.sum(pred * (1 - skel_GT))
    temp_pseudo_r = count_skl_tp / (count_skl_fn + count_skl_tp + 1e-4)

    temp_pseudo_f = 2 * (temp_p * temp_pseudo_r) / (temp_p + temp_pseudo_r + 1e-4)
    return temp_pseudo_f


def DRD(predict_img_,GT_img_):
    """Return the distance reciprocal distortion of a binary prediction.

    Every pixel on which the prediction and the ground truth disagree
    contributes the weighted share of its 5x5 ground-truth neighbourhood that
    differs from the predicted value, with weights proportional to the
    reciprocal distance from the centre. The total is divided by the number
    of non-uniform 8x8 blocks of the ground truth (NUBN).

    Args:
        predict_img_: 2-D array-like prediction holding 0/1 values.
        GT_img_: 2-D array-like ground truth of the same shape holding 0/1 values.

    Returns:
        The summed distortion divided by `NUBN + 1e-4`.
    """
    pred = np.asarray(predict_img_, dtype=np.float64)
    gt = np.asarray(GT_img_, dtype=np.float64)

    xm = gt.shape[0]
    ym = gt.shape[1]
    blkSize = 8
    MaskSize = 5

    # Integral image of the ground truth padded by one pixel per side, used
    # to sum every 8x8 block in constant time
    u0_GT1 = np.zeros((xm + 2, ym + 2))
    u0_GT1[1 : xm + 1, 1 : ym + 1] = gt
    intim = np.cumsum(np.cumsum(u0_GT1, 0), 1)

    # NUBN: the blocks that are neither all 0 nor all 1. The last block starts
    # at `size - blkSize + 1` in the padded coordinates, so the range stop has
    # to be one past it; otherwise the last block row and column are skipped.
    NUBN = 0
    blkSizeSQR = blkSize * blkSize
    for i in range(1, xm - blkSize + 2, blkSize):
        for j in range(1, ym - blkSize + 2, blkSize):
            blkSum = (
                intim[i + blkSize - 1, j + blkSize - 1]
                - intim[i - 1, j + blkSize - 1]
                - intim[i + blkSize - 1, j - 1]
                + intim[i - 1, j - 1]
            )
            if blkSum != 0 and blkSum != blkSizeSQR:
                NUBN = NUBN + 1

    # Normalized weights: reciprocal distance from the centre, with the
    # centre itself weighted 0
    ic = int((MaskSize + 1) / 2)
    jc = ic
    row_offsets = np.arange(1, MaskSize + 1) - ic
    col_offsets = np.arange(1, MaskSize + 1) - jc
    distance = np.sqrt(row_offsets[:, None] ** 2.0 + col_offsets[None, :] ** 2.0)
    wm = np.zeros((MaskSize, MaskSize))
    np.divide(1.0, distance, out=wm, where=distance > 0)
    wnm = wm / np.sum(wm)

    # Both images are zero-padded so that a full window fits around every pixel
    half = ic - 1
    u0_GT_Resized = np.zeros((xm + ic + 1, ym + jc + 1))
    u0_GT_Resized[half : xm + half, half : ym + half] = gt
    u_Resized = np.zeros((xm + ic + 1, ym + jc + 1))
    u_Resized[half : xm + half, half : ym + half] = pred

    # The pixels on which the prediction and the ground truth disagree; the
    # padding is 0 in both images, so it never shows up here
    Diff = ((1 - u_Resized) * u0_GT_Resized + u_Resized * (1 - u0_GT_Resized)) > 0

    SumDRDk = 0.0
    for i, j in np.argwhere(Diff):
        window = u0_GT_Resized[i - half : i + half + 1, j - half : j + half + 1]
        value = u_Resized[i, j]
        # Ground-truth neighbours that differ from the predicted centre value
        Local_Diff = ((1 - window) * value + window * (1 - value)) > 0
        SumDRDk = SumDRDk + np.sum(wnm[Local_Diff])

    temp_DRD = SumDRDk / (NUBN + 1e-4)
    return temp_DRD

In [None]:
# The library logger, used further below to report a missing checkpoint
logger = transformers.logging.get_logger()
# metric = evaluate.load(FLAGS["metric"])

# The two-class label mappings replace those of the base checkpoint.
# NOTE(review): `ignore_mismatched_sizes` is presumably needed because the
# pretrained head has a different number of classes — confirm.
model_kwargs = {
    "id2label": id2label, 
    "label2id": label2id,
    "ignore_mismatched_sizes": True
}

model = transformers.SegformerForSemanticSegmentation.from_pretrained(
    FLAGS["base_model_name"], **model_kwargs
)


def compute_metrics(outputs: transformers.EvalPrediction) -> typing.Dict[str, float]:
    """It computes the binarization metrics over the evaluation predictions

    The logits are upscaled to the label size and turned into class ids, then
    every image is scored with `Fmeasure`, `Pfmeasure`, `Psnr` and `DRD`, and
    the scores are averaged over the images.

    Parameters
    ----------
    outputs
        The (logits, labels) pair given by the trainer, as NumPy arrays

    Returns
    -------
        The mean `fmeasure`, `pfmeasure`, `psnr` and `drd`
    """
    logits, labels = outputs

    with torch.no_grad():
        logits = torch.from_numpy(logits)

        # It upscales the logits to the size of the label
        predictions = nn.functional.interpolate(
            logits,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        predictions = predictions.detach().cpu().numpy()

    fmeasures = []
    pfmeasures = []
    psnrs = []
    drds = []

    for prediction, label in zip(predictions, labels):
        # The metric functions treat 0 as the positive (text) class, while the
        # masks of this notebook use 1 for text. Both masks are inverted so
        # that the scores refer to the text and not to the background.
        im = 1 - prediction
        im_gt = 1 - label

        fmeasures.append(Fmeasure(im, im_gt))
        pfmeasures.append(Pfmeasure(im, im_gt))
        psnrs.append(Psnr(im, im_gt))
        drds.append(DRD(im, im_gt))

    # NOTE: a metric loaded through `evaluate` was the original plan; see
    # https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
    metrics = {
        "fmeasure": float(np.mean(fmeasures)),
        "pfmeasure": float(np.mean(pfmeasures)),
        "psnr": float(np.mean(psnrs)),
        "drd": float(np.mean(drds))
    }

    return metrics


# The trainer configuration: evaluation and checkpointing share the same
# 10-step cadence and the output folder doubles as the Hub repository name.
# NOTE(review): fp16 is hard-coded to False although FLAGS["fp16"] is
# defined — confirm which one is intended.
training_args = transformers.TrainingArguments(
    auto_find_batch_size=True,
    eval_accumulation_steps=FLAGS["accumulation_steps"],
    eval_steps=10,
    evaluation_strategy="steps",
    fp16=False,
    full_determinism=False,
    gradient_accumulation_steps=FLAGS["accumulation_steps"],
    hub_model_id=FLAGS["model_name"],
    hub_strategy="end",
    learning_rate=FLAGS["learning_rate"],
    load_best_model_at_end=True,
    logging_steps=5,
    lr_scheduler_type=FLAGS["scheduler_type"],
    num_train_epochs=FLAGS["num_epochs"],
    optim=FLAGS["optimizer"],
    output_dir=FLAGS["model_name"],
    per_device_eval_batch_size=FLAGS["batch_size"],
    per_device_train_batch_size=FLAGS["batch_size"],
    push_to_hub=True,
    remove_unused_columns=False,  # https://discuss.huggingface.co/t/divide-by-zero-error-when-following-ch7-tutorial/18393/6
    report_to="tensorboard",
    save_steps=10,
    save_strategy="steps",
    save_total_limit=3,
    seed=10,
    warmup_steps=50,
)

# decay_parameters = get_parameter_names(model, [nn.LayerNorm])
# decay_parameters = [name for name in decay_parameters if "bias" not in name]

# c = [
#     {
#         "params": [
#              param for name, param in model.named_parameters() 
#              if name in decay_parameters
#         ],
#         "weight_decay": training_args.weight_decay,
#     },
#     {
#         "params": [
#              param for name, param in model.named_parameters() 
#              if name in decay_parameters
#         ],
#         "weight_decay": 0.0,
#     },
# ]

# optim_kwargs = {
#     "betas": (training_args.adam_beta1, training_args.adam_beta2),
#     "eps": training_args.adam_epsilon,
# }
# optim_kwargs["lr"] = training_args.learning_rate

# adam_8bit_optim = bitsandbytes.optim.Adam8bit(
#     optim_kwargs,
#     betas=(training_args.adam_beta1, training_args.adam_beta2),
#     eps=training_args.adam_epsilon,
#     lr=training_args.learning_rate,
# )

# Training stops early after 5 evaluations without improvement
callbacks = [
    transformers.EarlyStoppingCallback(early_stopping_patience=5)
]

# The trainer ties together the model, the transformed datasets and the
# metric function; the default optimizer is used, as the custom one is
# commented out
trainer = transformers.Trainer(
    args=training_args,
    callbacks=callbacks,
    compute_metrics=compute_metrics,
    eval_dataset=test_dataset,
    model=model,    
    train_dataset=train_dataset,
    # optimizers=(adam_8bit_optim, None)
)

# It looks for a previous checkpoint in the output folder; the folder may
# not exist yet on a first run
try:
    checkpoint = get_last_checkpoint(FLAGS["model_name"])
except FileNotFoundError:
    logger.debug("No checkpoint")
    checkpoint = None

# Training resumes only when a checkpoint was found
resume_from_checkpoint = checkpoint is not None

trainer.train(resume_from_checkpoint=resume_from_checkpoint)

# The metadata forwarded to the final push of the trained model
kwargs = {
    "finetuned_from": FLAGS["base_model_name"],
    # Two separate tags: without the comma the adjacent string literals are
    # concatenated into a single, meaningless tag
    "tags": [
        "document-image-binarization",
        "image-segmentation",
    ]
}

trainer.push_to_hub(**kwargs)

## 4. Inference

For a complete example, see T. Cornille's official Segformer [blog post](https://huggingface.co/blog/fine-tune-segformer)