
## Set-up environment

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git


In [2]:
!pip install -q datasets


## Load dataset

In [3]:
from datasets import load_dataset

dataset = load_dataset("nielsr/funsd")


In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 149
    })
    test: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 50
    })
})

In [5]:
dataset["train"].features

{'id': Value(dtype='string', id=None),
 'words': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'bboxes': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'ner_tags': Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER'], id=None), length=-1, id=None),
 'image_path': Value(dtype='string', id=None)}

In [6]:
labels = dataset["train"].features['ner_tags'].feature.names
id2label = {id:label for id, label in enumerate(labels)}
label2id = {label:id for id, label in enumerate(labels)}
print(id2label)

{0: 'O', 1: 'B-HEADER', 2: 'I-HEADER', 3: 'B-QUESTION', 4: 'I-QUESTION', 5: 'B-ANSWER', 6: 'I-ANSWER'}
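The seven names follow the IOB2 tagging scheme. `O` marks words outside any entity. Each of the three entity types (header, question, answer) gets a `B-` tag for the first word of an entity and an `I-` tag for the words that follow it. The sketch below generates the same list, in the same order, from the entity types. The `iob2_*` variable names are illustrative and do not come from the notebook.

```python
# Build IOB2 labels: "O" plus a B-/I- pair per entity type (illustrative names)
entity_types = ["HEADER", "QUESTION", "ANSWER"]
iob2_labels = ["O"] + [f"{prefix}-{t}" for t in entity_types for prefix in ("B", "I")]

iob2_id2label = dict(enumerate(iob2_labels))
iob2_label2id = {label: i for i, label in iob2_id2label.items()}
print(iob2_id2label)
```

This should print the same mapping as the cell above.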


In [7]:
example = dataset["train"][0]
print(example["words"])
print(example["bboxes"])
print(example["ner_tags"])

['R&D', ':', 'Suggestion:', 'Date:', 'Licensee', 'Yes', 'No', '597005708', 'R&D', 'QUALITY', 'IMPROVEMENT', 'SUGGESTION/', 'SOLUTION', 'FORM', 'Name', '/', 'Phone', 'Ext.', ':', 'M.', 'Hamann', 'P.', 'Harper,', 'P.', 'Martinez', '9/', '3/', '92', 'R&D', 'Group:', 'J.', 'S.', 'Wigand', 'Supervisor', '/', 'Manager', 'Discontinue', 'coal', 'retention', 'analyses', 'on', 'licensee', 'submitted', 'product', 'samples', '(Note', ':', 'Coal', 'Retention', 'testing', 'is', 'not', 'performed', 'by', 'most', 'licensees.', 'Other', 'B&W', 'physical', 'measurements', 'as', 'ends', 'stability', 'and', 'inspection', 'for', 'soft', 'spots', 'in', 'ciparettes', 'are', 'thought', 'to', 'be', 'sufficient', 'measures', 'to', 'assure', 'cigarette', 'physical', 'integrity.', 'The', 'proposed', 'action', 'will', 'increase', 'laboratory', 'productivity', '.', ')', 'Suggested', 'Solutions', '(s)', ':', 'Delete', 'coal', 'retention', 'from', 'the', 'list', 'of', 'standard', 'analyses', 'performed', 'on', 'licen

## Transform dataset

In [8]:
from transformers import AutoTokenizer

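# LiLT's text side here is RoBERTa-based; the LayoutLMv3 tokenizer uses a RoBERTa-style BPE vocabulary
# and, unlike a plain RoBERTa tokenizer, also accepts word-level `boxes` and `word_labels`
tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlmv3-base")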


In [9]:
def prepare_examples(batch):
  encoding = tokenizer(batch["words"],
                        boxes=batch["bboxes"],
                        word_labels=batch["ner_tags"],
                        padding="max_length",
                        max_length=128,
                        truncation=True,
                        return_tensors="pt")
  
  return encoding

dataset.set_transform(prepare_examples)

In [10]:
example = dataset["train"][0]
print(example.keys())

dict_keys(['input_ids', 'attention_mask', 'bbox', 'labels'])


In [11]:
tokenizer.decode(example["input_ids"])

'<s> R&D : Suggestion: Date: Licensee Yes No 597005708 R&D QUALITY IMPROVEMENT SUGGESTION/ SOLUTION FORM Name / Phone Ext. : M. Hamann P. Harper, P. Martinez 9/ 3/ 92 R&D Group: J. S. Wigand Supervisor / Manager Discontinue coal retention analyses on licensee submitted product samples (Note : Coal Retention testing is not performed by most licensees. Other B&W physical measurements as ends stability and inspection for soft spots in ciparettes are thought to be sufficient measures to assure cigarette</s>'

In [12]:
# Print each token with its box and label (-100 marks positions ignored by the loss)
for token_id, box, label in zip(example["input_ids"].tolist(),
                                example["bbox"].tolist(),
                                example["labels"].tolist()):
  if label != -100:
    print(tokenizer.decode([token_id]), box, id2label[label])
  else:
    print(tokenizer.decode([token_id]), box, label)

<s> [0, 0, 0, 0] -100
 R [383, 91, 493, 175] O
& [383, 91, 493, 175] -100
D [383, 91, 493, 175] -100
 : [287, 316, 295, 327] B-QUESTION
 Suggest [124, 355, 221, 370] B-QUESTION
ion [124, 355, 221, 370] -100
: [124, 355, 221, 370] -100
 Date [632, 268, 679, 282] B-QUESTION
: [632, 268, 679, 282] -100
 License [670, 309, 748, 323] B-ANSWER
e [670, 309, 748, 323] -100
 Yes [604, 605, 633, 619] B-QUESTION
 No [715, 603, 738, 617] B-QUESTION
 5 [688, 904, 841, 926] O
97 [688, 904, 841, 926] -100
005 [688, 904, 841, 926] -100
708 [688, 904, 841, 926] -100
 R [337, 203, 366, 214] B-HEADER
& [337, 203, 366, 214] -100
D [337, 203, 366, 214] -100
 QU [374, 203, 438, 216] I-HEADER
AL [374, 203, 438, 216] -100
ITY [374, 203, 438, 216] -100
 IM [447, 201, 548, 211] I-HEADER
PROV [447, 201, 548, 211] -100
EMENT [447, 201, 548, 211] -100
 S [335, 215, 425, 229] I-HEADER
UG [335, 215, 425, 229] -100
G [335, 215, 425, 229] -100
EST [335, 215, 425, 229] -100
ION [335, 215, 425, 229] -100
/ [335, 215, 42
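In the output above, only the first sub-word token of each word carries the word's tag. For example, ` R` of `R&D` is tagged `O`. The remaining pieces and the special tokens get -100, which is the default `ignore_index` of PyTorch's cross-entropy loss. The tokenizer applies this rule internally when `word_labels` is passed.

The sketch below reproduces the same rule in plain Python. It works on a `word_ids` list shaped like the one a fast tokenizer's `BatchEncoding.word_ids()` returns. Here that list is written by hand, and `align_labels` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def align_labels(word_ids, word_labels):
    """Label the first sub-word token of each word; mark all other tokens as ignored.

    word_ids: one entry per token, giving the index of the word it came from,
              or None for special and padding tokens.
    word_labels: one label id per word.
    """
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(IGNORE_INDEX)          # <s>, </s>, <pad>
        elif word_id != previous_word:
            aligned.append(word_labels[word_id])  # first piece of a new word
        else:
            aligned.append(IGNORE_INDEX)          # continuation piece of the same word
        previous_word = word_id
    return aligned

# "<s> R & D : </s>": word 0 ("R&D", tag O=0) splits into 3 pieces, word 1 (":", B-QUESTION=3) is one piece
print(align_labels([None, 0, 0, 0, 1, None], [0, 3]))
```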

## Create PyTorch Dataloaders

In [13]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset["train"], batch_size=2, shuffle=True)
test_dataloader = DataLoader(dataset["test"], batch_size=2, shuffle=False)  # order does not matter for evaluation

In [14]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

input_ids torch.Size([2, 128])
attention_mask torch.Size([2, 128])
bbox torch.Size([2, 128, 4])
labels torch.Size([2, 128])
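Each token in the batch carries a 4-value box `(x0, y0, x1, y1)`. The values in the token-level output above (for example `[383, 91, 493, 175]`) lie within 0 to 1000. That is the normalized scale LayoutLM-family models expect, and LiLT is assumed here to share it. The boxes are presumably already normalized by the `nielsr/funsd` loading script, so nothing more is needed in this notebook.

For your own OCR output in pixel coordinates, the usual normalization can be sketched with a hypothetical `normalize_bbox` helper. It assumes the image width and height in pixels are known.

```python
def normalize_bbox(bbox, width, height):
    """Scale a pixel-space box (x0, y0, x1, y1) to the 0-1000 range.

    Hypothetical helper for illustration; width and height are the image size in pixels.
    """
    x0, y0, x1, y1 = bbox
    return [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]

print(normalize_bbox([50, 100, 150, 200], width=500, height=1000))  # [100, 100, 300, 200]
```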


## Load model

In [15]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("nielsr/lilt-roberta-en-base", id2label=id2label, label2id=label2id)

Some weights of LiltForTokenClassification were not initialized from the model checkpoint at nielsr/lilt-roberta-en-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train!

In [16]:
!pip install -q evaluate seqeval



In [17]:
import evaluate
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metric
metric = evaluate.load("seqeval")

def get_labels(predictions, references):
    # Move tensors to the CPU and convert them to NumPy arrays
    # (.cpu() returns the tensor itself when it is already on the CPU)
    y_pred = predictions.detach().cpu().numpy()
    y_true = references.detach().cpu().numpy()

    # Remove positions labeled -100 (special tokens, padding and non-first sub-word tokens)
    true_predictions = [
        [labels[p] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    true_labels = [
        [labels[l] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    return true_predictions, true_labels
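`seqeval` scores entities rather than tokens. A predicted entity counts as correct only when both its type and its exact span match a gold entity. The `overall_*` precision, recall and F1 are micro-averaged over all entities.

The simplified sketch below applies that idea to the IOB2 tag lists produced by `get_labels`. The helper names are hypothetical, and this is not seqeval's code, so results may differ on malformed tag sequences. In this sketch, an `I-` tag that does not continue an entity of the same type simply starts a new entity.

```python
def extract_entities(tags):
    """Return a set of (type, start, end) spans from a list of IOB2 tags."""
    entities = set()
    current_type, start = None, None
    for i, tag in enumerate(tags + ["O"]):  # a sentinel "O" flushes the last entity
        prefix, _, entity_type = tag.partition("-")
        continues = prefix == "I" and entity_type == current_type
        if not continues:
            if current_type is not None:
                entities.add((current_type, start, i - 1))
            if prefix in ("B", "I"):
                current_type, start = entity_type, i
            else:
                current_type, start = None, None
    return entities

def entity_prf(true_seqs, pred_seqs):
    """Micro-averaged entity-level precision, recall and F1 over several sequences."""
    n_true = n_pred = n_correct = 0
    for true_tags, pred_tags in zip(true_seqs, pred_seqs):
        gold = extract_entities(true_tags)
        pred = extract_entities(pred_tags)
        n_true += len(gold)
        n_pred += len(pred)
        n_correct += len(gold & pred)  # exact type and span match
    precision = n_correct / n_pred if n_pred else 0.0
    recall = n_correct / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

true = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
pred = [["B-QUESTION", "O", "O", "B-ANSWER"]]  # the question entity is cut short
print(entity_prf(true, pred))  # (0.5, 0.5, 0.5)
```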


In [18]:
from torch.optim import AdamW
from tqdm.auto import tqdm

optimizer = AdamW(model.parameters(), lr=5e-5)

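model.to(device)
# from_pretrained() returns the model in eval mode; switch to train mode so dropout is active
model.train()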

for epoch in range(50):
  print("Epoch:", epoch+1)
  for idx, batch in enumerate(tqdm(train_dataloader)):
      # move batch to device
      batch = {k:v.to(device) for k,v in batch.items()}
      outputs = model(**batch)

      predictions = outputs.logits.argmax(-1)
      true_predictions, true_labels = get_labels(predictions, batch["labels"])
      metric.add_batch(references=true_labels, predictions=true_predictions)

      loss = outputs.loss

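      # There are 75 batches per epoch, so this fires only at idx 0, i.e. once per epoch.
      # compute() scores every batch added since the previous call and then discards them,
      # so the metrics cover the rest of the previous epoch plus the current batch;
      # the printed loss is that of the current batch only.
      if idx % 100 == 0:
        print("Loss:", loss.item())
        results = metric.compute()
        print("Overall f1:", results["overall_f1"])
        print("Overall precision:", results["overall_precision"])
        print("Overall recall:", results["overall_recall"])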

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

Epoch: 1
Loss: 2.00651216506958
Overall f1: 0.032
Overall precision: 0.032
Overall recall: 0.037037037037037035
Epoch: 2
Loss: 1.1887439489364624
Overall f1: 0.48512585812356973
Overall precision: 0.48512585812356973
Overall recall: 0.5553892215568862
Epoch: 3
Loss: 0.30801185965538025
Overall f1: 0.746275551449359
Overall precision: 0.746275551449359
Overall recall: 0.7915237628613425
Epoch: 4
Loss: 0.3589589297771454
Overall f1: 0.8290979164157533
Overall precision: 0.8290979164157533
Overall recall: 0.8566450970632156
Epoch: 5
Loss: 0.06351487338542938
Overall f1: 0.8925579701347577
Overall precision: 0.8925579701347577
Overall recall: 0.9121588089330025
Epoch: 6
Loss: 0.08621978759765625
Overall f1: 0.9389312977099237
Overall precision: 0.9389312977099237
Overall recall: 0.950161973585846
Epoch: 7
Loss: 0.020198002457618713
Overall f1: 0.9644694927713795
Overall precision: 0.9644694927713795
Overall recall: 0.9708929452392698
Epoch: 8
Loss: 0.014185592532157898
Overall f1: 0.968711199404023
Overall precision: 0.968711199404023
Overall recall: 0.9747626186906547
Epoch: 9
Loss: 0.09443214535713196
Overall f1: 0.977924944812362
Overall precision: 0.977924944812362
Overall recall: 0.9800884955752213
Epoch: 10
Loss: 0.005804331507533789
Overall f1: 0.9799256505576208
Overall precision: 0.9799256505576208
Overall recall: 0.9830929885629041
Epoch: 11
Loss: 0.006674524862319231
Overall f1: 0.9818406423718344
Overall precision: 0.9818406423718344
Overall recall: 0.9841505695889053
Epoch: 12
Loss: 0.17881818115711212
Overall f1: 0.9746819809806102
Overall precision: 0.9746819809806102
Overall recall: 0.9801291604570294
Epoch: 13
Loss: 0.11992207914590836
Overall f1: 0.9413356290739147
Overall precision: 0.9413356290739147
Overall recall: 0.9522269221199303
Epoch: 14
Loss: 0.016758734360337257
Overall f1: 0.9520758901071825
Overall precision: 0.9520758901071825
Overall recall: 0.9604772557792692
Epoch: 15
Loss: 0.01718759350478649
Overall f1: 0.9544065517662876
Overall precision: 0.9544065517662876
Overall recall: 0.9618132544961814
Epoch: 16
Loss: 0.01741744391620159
Overall f1: 0.9649602385685885
Overall precision: 0.9649602385685885
Overall recall: 0.9692960559161258
Epoch: 17
Loss: 0.0030022652354091406
Overall f1: 0.9863081287775995
Overall precision: 0.9863081287775995
Overall recall: 0.9886251236399605
Epoch: 18
Loss: 0.02549370750784874
Overall f1: 0.9913280475718533
Overall precision: 0.9913280475718533
Overall recall: 0.9918195339613287
Epoch: 19
Loss: 0.010269198566675186
Overall f1: 0.9981391886862672
Overall precision: 0.9981391886862672
Overall recall: 0.9980153807988092
Epoch: 20
Loss: 0.061909887939691544
Overall f1: 0.9960129578868676
Overall precision: 0.9960129578868676
Overall recall: 0.9970067348465952
Epoch: 21
Loss: 0.007349176798015833
Overall f1: 0.9954517516902275
Overall precision: 0.9954517516902275
Overall recall: 0.9965542702436623
Epoch: 22
Loss: 0.00041503208922222257
Overall f1: 0.9981224183251972
Overall precision: 0.9981224183251972
Overall recall: 0.9984973703981969
Epoch: 23
Loss: 0.0003140615299344063
Overall f1: 0.9987736080451312
Overall precision: 0.9987736080451312
Overall recall: 0.9992638036809816
Epoch: 24
Loss: 0.0034247231669723988
Overall f1: 0.9934951213410056
Overall precision: 0.9934951213410056
Overall recall: 0.9939924906132666
Epoch: 25
Loss: 0.03678509593009949
Overall f1: 0.9827943868212324
Overall precision: 0.9827943868212324
Overall recall: 0.985078277886497
Epoch: 26
Loss: 0.002256282838061452
Overall f1: 0.982372796599575
Overall precision: 0.982372796599575
Overall recall: 0.9852056168505516
Epoch: 27
Loss: 0.0027514216490089893
Overall f1: 0.9908570035352919
Overall precision: 0.9908570035352919
Overall recall: 0.9919453258481816
Epoch: 28
Loss: 0.0011266421061009169
Overall f1: 0.984359233097881
Overall precision: 0.984359233097881
Overall recall: 0.9848561332660273
Epoch: 29
Loss: 0.061145029962062836
Overall f1: 0.9871858058156727
Overall precision: 0.9871858058156727
Overall recall: 0.9886475814412635
Epoch: 30
Loss: 0.0009299630764871836
Overall f1: 0.9936559273541485
Overall precision: 0.9936559273541485
Overall recall: 0.9945219123505976
Epoch: 31
Loss: 0.0004975183983333409
Overall f1: 0.9980183304433986
Overall precision: 0.9980183304433986
Overall recall: 0.9982656095143707
Epoch: 32
Loss: 0.09349943697452545
Overall f1: 0.9858701041150223
Overall precision: 0.9858701041150223
Overall recall: 0.9875838092873106
Epoch: 33
Loss: 0.027504080906510353
Overall f1: 0.9768643859217329
Overall precision: 0.9768643859217329
Overall recall: 0.9809688581314879
Epoch: 34
Loss: 0.0768735483288765
Overall f1: 0.9153864910997318
Overall precision: 0.9153864910997318
Overall recall: 0.9296681525507677
Epoch: 35
Loss: 0.009324297308921814
Overall f1: 0.9304262616364527
Overall precision: 0.9304262616364527
Overall recall: 0.940797621996532
Epoch: 36
Loss: 0.00737017672508955
Overall f1: 0.9703576678098971
Overall precision: 0.9703576678098971
Overall recall: 0.9760965993100049
Epoch: 37
Loss: 0.10725349187850952
Overall f1: 0.977343072923115
Overall precision: 0.977343072923115
Overall recall: 0.9815966177567769
Epoch: 38
Loss: 0.0010490071726962924
Overall f1: 0.9861042183622829
Overall precision: 0.9861042183622829
Overall recall: 0.9892954941498631
Epoch: 39
Loss: 0.003957876469939947
Overall f1: 0.9959112873249908
Overall precision: 0.9959112873249908
Overall recall: 0.9965286387304736
Epoch: 40
Loss: 0.0008391257142648101
Overall f1: 0.9966670781385014
Overall precision: 0.9966670781385014
Overall recall: 0.9967901234567901
Epoch: 41
Loss: 0.0006110325339250267
Overall f1: 0.9994971083731455
Overall precision: 0.9994971083731455
Overall recall: 0.9992458521870287
Epoch: 42
Loss: 0.0005302277277223766
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 43
Loss: 0.0005414157058112323
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 44
Loss: 0.00025672270567156374
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 45
Loss: 0.0002690097317099571
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 46
Loss: 0.0002749921986833215
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 47
Loss: 0.00021576719882432371
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 48
Loss: 0.0001653546787565574
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 49
Loss: 0.00024371480685658753
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
Epoch: 50
Loss: 0.00015494016406591982
Overall f1: 1.0
Overall precision: 1.0
Overall recall: 1.0
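In this recorded run, every `Overall precision` line repeats the F1 value, because the print statement read `results["overall_f1"]`. The recall lines are genuine. F1 is the harmonic mean of precision P and recall R, F = 2PR / (P + R), which rearranges to P = FR / (2R - F). So the true precision can still be recovered from the logged numbers. A small sketch with a hypothetical helper:

```python
def precision_from_f1_recall(f1, recall):
    """Recover precision P from F1 = 2PR / (P + R), i.e. P = F*R / (2R - F)."""
    denominator = 2 * recall - f1
    if denominator <= 0:
        raise ValueError("need 2 * recall > f1 for a positive precision")
    return f1 * recall / denominator

# Epoch 1 values from the log above (F1 = 0.032, recall = 0.0370...)
print(precision_from_f1_recall(0.032, 0.037037037037037035))
```

For epoch 1 this gives a precision of about 0.028, slightly below the logged F1. For epochs 42 to 50, where F1 and recall are both 1.0, it returns 1.0.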


## Evaluate

In [19]:
from tqdm.auto import tqdm

eval_metric = evaluate.load("seqeval")

model.eval()  # make sure dropout is disabled during evaluation

for idx, batch in enumerate(tqdm(test_dataloader)):
    # move batch to device
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    predictions = outputs.logits.argmax(-1)
    true_predictions, true_labels = get_labels(predictions, batch["labels"])
    eval_metric.add_batch(references=true_labels, predictions=true_predictions)


In [20]:
results = eval_metric.compute()
results["overall_f1"]

0.7719969395562356

In [21]:
results["overall_accuracy"]

0.7827023699599877
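`overall_accuracy` is a token-level number. It is the fraction of labeled positions (here, the first sub-word token of each word) whose predicted tag equals the gold tag, with `O` tokens included. `overall_f1`, by contrast, is computed over whole entities. That is why the two test-set numbers above (about 0.78 and 0.77) need not agree.

A minimal sketch of the token-level computation over the nested tag lists that `get_labels` returns follows. The helper name is hypothetical.

```python
def token_accuracy(true_seqs, pred_seqs):
    """Fraction of labeled positions where the predicted tag equals the gold tag."""
    pairs = [
        (t, p)
        for true_tags, pred_tags in zip(true_seqs, pred_seqs)
        for t, p in zip(true_tags, pred_tags)
    ]
    if not pairs:
        return 0.0
    return sum(t == p for t, p in pairs) / len(pairs)

true = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
pred = [["B-QUESTION", "O", "O", "B-ANSWER"]]
print(token_accuracy(true, pred))  # 3 of 4 tags match -> 0.75
```

In this example, one wrong tag costs only one token out of four. Under entity-level scoring, the same mistake makes the whole truncated question entity count as wrong.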