[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DiTo97/binarization-segformer/blob/main/fine-tuning.ipynb)

# Fine-tuning Segformer for Document Image Binarization

A notebook by F. Minutoli ([@DiTo97](https://github.com/DiTo97)) that fine-tunes a SegFormer model for document image binarization.

In [None]:
# Pinned versions of the third-party packages used by this notebook, joined
# into a single space-separated string for `pip install`
requirements = " ".join([
    "accelerate==0.18.0",
    "albumentations==1.3.0",
    "datasets==2.11.0",
    "evaluate==0.4.0",
    "huggingface-hub==0.13.4",
    "transformers==4.27.4"
])

# IPython shell escapes; `$requirements` expands to the string built above
!python -m pip install --upgrade pip
!python -m pip install $requirements

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Authenticate with the Hugging Face Hub: both the image processor and the
# trainer push artifacts to the Hub later in this notebook
notebook_login()

## 1. Dataset

In [None]:
# Fetch the SauvolaNet repository: it provides the data-collection helper
# (`dataUtils.collect_binarization_by_dataset`) and the `Dataset` folder used below
!git clone https://github.com/Leedeng/SauvolaNet.git

In [None]:
import pathlib
import sys
import typing
from typing import Any

import datasets
import numpy as np
import numpy.typing as nptyping
from PIL import Image
from tqdm.auto import tqdm

In [None]:
def normalize(bitmap: Image.Image) -> Image.Image:
    """Turn a ground-truth image into a boolean foreground mask.

    The image is converted to 8-bit grayscale; every pixel strictly darker
    than the brightest value in the image becomes foreground (True) and all
    remaining pixels become background (False).

    Args:
        bitmap: The ground-truth image to binarize.

    Returns:
        A PIL image built from the boolean mask.
    """
    gray = np.array(bitmap.convert("L"), dtype=np.uint8)
    foreground = gray < gray.max()
    return Image.fromarray(foreground)

In [None]:
def prepare_examples(
    batch: typing.Dict[str, typing.List[Any]]
) -> typing.Dict[str, typing.List[Any]]:
    """Prepare a batch of examples for semantic segmentation.

    Args:
        batch: A batch with a ``"source"`` column (input image paths) and a
            ``"target"`` column (ground-truth image paths).

    Returns:
        A batch with a ``"labelmap"`` column holding the normalized
        ground-truth masks and a ``"pixelmap"`` column holding the inputs.
    """
    labelmaps = []

    for path in batch["target"]:
        # Close each ground-truth file as soon as its mask has been built,
        # instead of leaving one open handle per example until garbage
        # collection (batched maps process up to 1000 examples at once)
        with Image.open(path) as target:
            labelmaps.append(normalize(target))

    # The `datasets.Image` feature accepts plain file paths. A never-loaded,
    # file-backed `Image.open` result is presumably stored by its path too
    # (datasets 2.x behaviour), so passing the path stores the same thing
    # without keeping a file handle open per source image.
    # NOTE(review): unreadable source images now surface on first access
    # rather than during this map.
    pixelmaps = list(batch["source"])

    return {"labelmap": labelmaps, "pixelmap": pixelmaps}

In [None]:
sauvolanet_src = "SauvolaNet/SauvolaDocBin"
sauvolanet_dataset = "SauvolaNet/Dataset"

# Make the SauvolaNet sources importable just long enough to collect the data
sys.path.insert(0, sauvolanet_src)
from dataUtils import collect_binarization_by_dataset

# Presumably a mapping from dataset name to (source, target) path pairs
collection = collect_binarization_by_dataset(sauvolanet_dataset)

sys.path.remove(sauvolanet_src)

del sauvolanet_src
del sauvolanet_dataset
del collect_binarization_by_dataset

features = datasets.Features({
    "ensemble": datasets.Value("string"),
    "source": datasets.Value("string"),
    "target": datasets.Value("string"),
})

for name, examples in tqdm(collection.items(), desc="Loading datasets"):
    # Sort the (source, target) pairs jointly so that every input image keeps
    # its own ground truth. Sorting the two sequences independently can
    # mismatch them whenever the naming schemes order differently, e.g.
    # "a.jpg" < "a1.jpg" but "a1_gt.png" < "a_gt.png".
    sources, targets = zip(*sorted(examples))

    dataset = {
        "source": list(sources),
        "target": list(targets),
        "ensemble": [name] * len(sources),
    }
    dataset = datasets.Dataset.from_dict(dataset, features)

    # Overwriting existing keys while iterating is safe: the dict size is unchanged
    collection[name] = dataset

collection = datasets.concatenate_datasets(list(collection.values()))

# Replace the path columns with images: normalized masks and input images
features = datasets.Features({
    "ensemble": datasets.Value("string"),
    "labelmap": datasets.Image(),
    "pixelmap": datasets.Image(),
})

collection = collection.map(
    prepare_examples,
    batched=True,
    features=features,
    remove_columns=["source", "target"]
)

# Encode the dataset name as a class label so the split can be stratified on it
collection = collection.class_encode_column("ensemble")

del features

# 75/25 train/test split, stratified by originating dataset
collection = collection.train_test_split(
    seed=10,
    shuffle=True,
    stratify_by_column="ensemble",
    train_size=0.75
)

train_dataset = collection["train"]
test_dataset = collection["test"]

del collection

In [None]:
# Binary segmentation classes, keyed by label id (0 = background, 1 = text)
id2label = dict(enumerate(["background", "text"]))
label2id = {label: index for index, label in id2label.items()}
num_labels = len(id2label)

## 2. Augmentation

In [None]:
import albumentations
import cv2
import transformers
from transformers import set_seed

In [None]:
# Fix the global random seed; 10 matches the seed of the train/test split
# and the seed passed to the training arguments
set_seed(10)

In [None]:
# Default Segformer image processor; its target `size` also sets the crop size below
processor = transformers.SegformerImageProcessor()

FLAGS = {
    # The general kwargs: constant-border filling (255 for image pixels,
    # 0 = "background" for mask pixels) and the probability shared by the
    # flip, affine and crop steps
    "border_mode": cv2.BORDER_CONSTANT,
    "fill_value": 255,
    "mask_fill_value": 0,
    "proba": 0.1,

    # The color kwargs (ColorJitter strengths)
    "brightness": 0.25, 
    "contrast": 0.25, 
    "saturation": 0.25, 
    "hue": 0.1,
    
    # The crop kwargs: the processor's target height and width
    "min_height": processor.size["height"],
    "min_width" : processor.size[ "width"],
    
    # The geometric kwargs: rotation range and translation as a fraction of the size
    "rotate": (-90, 90),
    "translate_percent": 0.1
}

# Photometric augmentation; only ever applied to the input image, never to the mask
transform1 = albumentations.Compose([
    albumentations.ColorJitter(
        brightness=FLAGS["brightness"], 
        contrast=FLAGS["contrast"], 
        saturation=FLAGS["saturation"], 
        hue=FLAGS["hue"]
    )
])

# Geometric augmentation, applied jointly to the image and its mask so the
# two stay aligned. The order matters: padding up to the processor size
# comes before the crop, so the crop always has enough room.
transform2 = albumentations.Compose([
    albumentations.Flip(p=FLAGS["proba"]),
    albumentations.Affine(
        p=FLAGS["proba"],
        cval=FLAGS["fill_value"],
        cval_mask=FLAGS["mask_fill_value"],
        mode=FLAGS["border_mode"],
        rotate=FLAGS["rotate"], 
        translate_percent=FLAGS["translate_percent"],
    ),
    albumentations.PadIfNeeded(
        border_mode=FLAGS["border_mode"],
        mask_value=FLAGS["mask_fill_value"],
        min_height=FLAGS["min_height"], 
        min_width=FLAGS["min_width"], 
        value=FLAGS["fill_value"],
    ),
    albumentations.RandomCrop(
        p=FLAGS["proba"],
        height=FLAGS["min_height"], 
        width=FLAGS["min_width"],
    )
])

def train_transform(
    batch: typing.Dict[str, typing.List[Any]]
) -> transformers.BatchFeature:
    """Augment a training batch and encode it with the Segformer processor.

    Every input image is converted to an RGB array and color-jittered on its
    own. Each image/mask pair (masks as uint8 arrays) then goes through the
    shared geometric pipeline, so the image and mask stay aligned.

    Args:
        batch: A batch with ``"pixelmap"`` (input images) and ``"labelmap"``
            (masks) columns.

    Returns:
        The processor's encoding of the augmented images and masks.
    """
    # All jitters first, then all geometric transforms: same RNG call order
    jittered = [
        transform1(image=np.array(pixelmap.convert("RGB")))["image"]
        for pixelmap in batch["pixelmap"]
    ]
    masks = [np.array(labelmap).astype(np.uint8) for labelmap in batch["labelmap"]]

    augmented = [
        transform2(image=picture, mask=mask)
        for picture, mask in zip(jittered, masks)
    ]

    return processor(
        [sample["image"] for sample in augmented],
        [sample["mask"] for sample in augmented],
    )

def test_transform(
    batch: typing.Dict[str, typing.List[Any]]
) -> transformers.BatchFeature:
    """Encode an evaluation batch with the Segformer processor, without augmentation.

    Masks are converted to uint8 arrays, mirroring the mask conversion used
    for training, so the processor sees the same mask representation in both
    phases. Previously the decoded mask images were passed to the processor
    as-is.

    Args:
        batch: A batch with ``"pixelmap"`` (input images) and ``"labelmap"``
            (masks) columns.

    Returns:
        The processor's encoding of the images and masks.
    """
    images = [image.convert("RGB") for image in batch["pixelmap"]]
    labels = [np.array(label).astype(np.uint8) for label in batch["labelmap"]]

    encoding = processor(images, labels)
    return encoding

# Attach the per-split transforms: augmentation + encoding for training,
# plain encoding for evaluation
train_dataset.set_transform(train_transform)
test_dataset.set_transform(test_transform)

## 3. Training

In [None]:
import evaluate
import torch
from torch import nn
from transformers.trainer_utils import get_last_checkpoint

In [None]:
# Whether a CUDA device is available; used to enable mixed precision (`fp16`) below
cuda = torch.cuda.is_available()

In [None]:
# Only report errors from the Hugging Face libraries to keep the output readable
datasets.logging.set_verbosity_error()
evaluate.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

In [None]:
# Training configuration
FLAGS = dict(
    accumulation_steps=4,                        # gradient and eval accumulation steps
    base_model_name="nvidia/mit-b0",             # pretrained checkpoint to fine-tune
    batch_size=4,                                # per-device train and eval batch size
    fp16=cuda,                                   # mixed precision only when a GPU exists
    learning_rate=5e-5,
    metric="mean_iou",                           # name passed to `evaluate.load`
    model_name="segformer-b0-for-binarization",  # Hub repository and output directory
    num_epochs=50,
    optimizer="adamw_torch",
    scheduler_type="cosine",
)

In [None]:
# Upload the image processor to the Hub repository named after the fine-tuned model
processor.push_to_hub(FLAGS["model_name"])

In [None]:
logger = transformers.logging.get_logger()
metric = evaluate.load(FLAGS["metric"])

# Label mappings forwarded to the model configuration
model_kwargs = dict(id2label=id2label, label2id=label2id)

model = transformers.SegformerForSemanticSegmentation.from_pretrained(
    FLAGS["base_model_name"],
    **model_kwargs,
)


def compute_metrics(outputs: transformers.EvalPrediction) -> typing.Dict[str, float]:
    """Compute segmentation metrics for the Trainer's evaluation loop.

    The logits are upsampled to the label resolution and turned into
    per-pixel predictions with ``argmax``. They are then scored with the
    ``metric`` loaded above. Per-category accuracy and IoU are flattened into
    ``accuracy_<label>`` / ``iou_<label>`` entries.

    Args:
        outputs: The logits and label maps gathered by the Trainer.

    Returns:
        A flat mapping from metric name to value.
    """
    with torch.no_grad():
        logits, labels = outputs
        # Cast to float32: with `fp16` evaluation the logits may arrive as
        # float16, and bilinear interpolation of half tensors on CPU is not
        # supported by every torch version
        logits = torch.from_numpy(logits).float()

        # It upscales the logits to the size of the label
        logits = nn.functional.interpolate(
            logits,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        predictions = logits.detach().cpu().numpy()

        # FIXME: For more information, see
        # https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        metrics = metric._compute(
                predictions=predictions,
                references=labels,
                num_labels=num_labels,
                # Labels are 0 (background) or 1 (text), and padding is
                # filled with background, so no pixel should be ignored.
                # Ignoring 0 dropped every background pixel, which hid all
                # false positives and left the background metrics meaningless.
                # 255 does not occur as long as the processor keeps
                # `do_reduce_labels` disabled (its default).
                ignore_index=255,
                reduce_labels=processor.do_reduce_labels,
            )
        
        # It adds per-category metrics as separate key-val pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[key]}": val for key, val in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[key]}": val for key, val in enumerate(per_category_iou)})
        
        return metrics


# Trainer configuration. Evaluation and checkpointing both happen every 10
# steps, which `load_best_model_at_end` requires. No `metric_for_best_model`
# is set, so presumably the Trainer's default (evaluation loss) selects the
# best checkpoint — confirm.
training_args = transformers.TrainingArguments(
    auto_find_batch_size=True,
    eval_accumulation_steps=FLAGS["accumulation_steps"],
    eval_steps=10,
    evaluation_strategy="steps",
    fp16=FLAGS["fp16"],
    full_determinism=True,
    gradient_accumulation_steps=FLAGS["accumulation_steps"],
    hub_model_id=FLAGS["model_name"],
    hub_strategy="end",
    learning_rate=FLAGS["learning_rate"],
    load_best_model_at_end=True,
    logging_steps=5,
    lr_scheduler_type=FLAGS["scheduler_type"],
    num_train_epochs=FLAGS["num_epochs"],
    optim=FLAGS["optimizer"],
    output_dir=FLAGS["model_name"],
    per_device_eval_batch_size=FLAGS["batch_size"],
    per_device_train_batch_size=FLAGS["batch_size"],
    push_to_hub=True,
    remove_unused_columns=False,  # https://discuss.huggingface.co/t/divide-by-zero-error-when-following-ch7-tutorial/18393/6
    report_to="tensorboard",
    save_steps=10,
    save_strategy="steps",
    save_total_limit=3,
    seed=10,
    warmup_steps=10,
)

# Stop training once 5 consecutive evaluations bring no improvement
callbacks = [
    transformers.EarlyStoppingCallback(early_stopping_patience=5)
]

trainer = transformers.Trainer(
    args=training_args,
    callbacks=callbacks,
    compute_metrics=compute_metrics,
    eval_dataset=test_dataset,
    model=model,    
    train_dataset=train_dataset
)

# Resume from the latest checkpoint in the output directory, if there is one
checkpoint = None
try:
    checkpoint = get_last_checkpoint(FLAGS["model_name"])
except FileNotFoundError:
    # The output directory does not exist yet
    logger.debug("No checkpoint")

resume_from_checkpoint = checkpoint is not None
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

# Metadata forwarded by `Trainer.push_to_hub` when publishing the model
kwargs = {
    "finetuned_from": FLAGS["base_model_name"],
    "tags": [
        "computer-vision",
        # Each tag needs its trailing comma: adjacent string literals are
        # otherwise concatenated into one bogus tag
        # ("document-image-binarizationimage-segmentation")
        "document-image-binarization",
        "image-segmentation",
    ],
}

trainer.push_to_hub(**kwargs)

## 4. Inference

For a complete inference example, see T. Cornille's official SegFormer [blog post](https://huggingface.co/blog/fine-tune-segformer).