<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Evaluating_TrOCR_base_handwritten_on_the_IAM_test_set.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set-up environment

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 596 kB 7.2 MB/s 
[K     |████████████████████████████████| 56 kB 6.4 MB/s 
[K     |████████████████████████████████| 895 kB 59.6 MB/s 
[K     |████████████████████████████████| 3.3 MB 48.1 MB/s 
[?25h  Building wheel for transformers (PEP 517) ... [?25l[?25hdone


In [2]:
!pip install -q datasets jiwer

[K     |████████████████████████████████| 290 kB 8.1 MB/s 
[K     |████████████████████████████████| 243 kB 67.2 MB/s 
[K     |████████████████████████████████| 125 kB 73.8 MB/s 
[K     |████████████████████████████████| 1.3 MB 54.3 MB/s 
[K     |████████████████████████████████| 50 kB 8.3 MB/s 
[K     |████████████████████████████████| 160 kB 57.2 MB/s 
[K     |████████████████████████████████| 271 kB 73.0 MB/s 
[?25h  Building wheel for python-Levenshtein (setup.py) ... [?25l[?25hdone


## Load IAM test set

In [3]:
import pandas as pd

df = pd.read_fwf('/content/drive/MyDrive/TrOCR/Tutorial notebooks/IAM/gt_test.txt', header=None)
df.rename(columns={0: "file_name", 1: "text"}, inplace=True)
del df[2]
df.head()

Unnamed: 0,file_name,text
0,c04-110-00.jpg,Become a success with a disc and hey presto ! ...
1,c04-110-01.jpg,"assuredness "" Bella Bella Marie "" ( Parlophone..."
2,c04-110-02.jpg,I don't think he will storm the charts with th...
3,c04-110-03.jpg,"CHRIS CHARLES , 39 , who lives in Stockton-on-..."
4,c04-116-00.jpg,He is also a director of a couple of garages ....


In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text 
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # some file names end with jp instead of jpg, the two lines below fix this
        if file_name.endswith('jp'):
          file_name = file_name + 'g'
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [5]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
test_dataset = IAMDataset(root_dir='/content/drive/MyDrive/TrOCR/Tutorial notebooks/IAM/image/',
                           df=df,
                           processor=processor)

Downloading:   0%|          | 0.00/4.03k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/228 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

In [6]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(test_dataset, batch_size=8)

In [7]:
batch = next(iter(test_dataloader))

In [8]:
for k,v in batch.items():
  print(k, v.shape)

pixel_values torch.Size([8, 3, 384, 384])
labels torch.Size([8, 128])


In [9]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

In [10]:
labels = batch["labels"]
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels, skip_special_tokens=True)
label_str

["Become a success with a disc and hey presto! You're a star.... Rolly sings with",
 'assuredness " Bella Bella Marie " ( Parlophone ), a lively song that changes tempo mid-way',
 "I don't think he will storm the charts with this one, but it's a good start.",
 'CHRIS CHARLES, 39, who lives in Stockton-on-Tees, is an accountant.',
 'He is also a director of a couple of garages. And he finds time as well to be a lyric',
 'writer. He writes with Tolchard Evans, composer of " Lady of Spain " and other big hits.',
 'Tolch, as he is known in Tin Pan Alley, likes songs with a month in the title. He wrote',
 '" My September Love, " the big David Whitfield hit of 1956.']

## Run evaluation

In [11]:
from transformers import VisionEncoderDecoderModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device)

Downloading:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, 

In [12]:
from datasets import load_metric

cer = load_metric("cer")

Downloading:   0%|          | 0.00/1.91k [00:00<?, ?B/s]

In [13]:
from tqdm.notebook import tqdm

print("Running evaluation...")

for batch in tqdm(test_dataloader):
    # predict using generate
    pixel_values = batch["pixel_values"].to(device)
    outputs = model.generate(pixel_values)

    # decode
    pred_str = processor.batch_decode(outputs, skip_special_tokens=True)
    labels = batch["labels"]
    labels[labels == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels, skip_special_tokens=True)

    # add batch to metric
    cer.add_batch(predictions=pred_str, references=label_str)

final_score = cer.compute()

Running evaluation...


  0%|          | 0/365 [00:00<?, ?it/s]

In [14]:
print("Character error rate on test set:", final_score)

0.038336078808735505
