<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pandas as pd
from pathlib import Path
from tqdm.notebook import tqdm

# Collect every per-batch transcript CSV. `iterdir` yields entries in arbitrary
# order, so the paths are sorted to keep the row order of the combined frame
# reproducible between runs.
transcript_dir = Path('./tibetan-dataset/LhasaKanjur/transcript/')
file_paths = sorted(path for path in transcript_dir.iterdir() if path.suffix == '.csv')

dfs = []

for file_path in tqdm(file_paths):
    # the CSV's base name (without the extension) is recorded as the batch name
    batch_name = file_path.name.removesuffix('.csv')
    df = pd.read_csv(file_path)
    df['batch_name'] = batch_name
    dfs.append(df)

# one frame for all batches, re-indexed 0..n-1
df = pd.concat(dfs, ignore_index=True)

# change the column name line_image_id to file_name
df = df.rename(columns={'line_image_id': 'file_name'})

# some file names end with jp instead of jpg, let's fix this
# (only a trailing ".jp" is rewritten; names already ending in ".jpg" are untouched)
df['file_name'] = df['file_name'].str.replace(r'\.jp$', '.jpg', regex=True)
df.head()

FileNotFoundError: [Errno 2] No such file or directory: 'tibetan-dataset/transcript'

In [None]:
import pandas as pd
from pathlib import Path
from tqdm.notebook import tqdm

In [None]:
# Load the training split from a single CSV file.
# The previous `pd.concat(dfs)` on an always-empty list raised
# "ValueError: No objects to concatenate" and would have overwritten the frame
# that was just read, so the frame returned by `read_csv` is used directly.
file_path = Path("./trocr/tibetan-dataset/train_uni.csv")
train_df = pd.read_csv(file_path)
# the dataset class looks images up through a `file_name` column
train_df = train_df.rename(columns={'line_image_id': 'file_name'})
train_df.head()

In [None]:
# Load the evaluation split from a single CSV file.
# The previous `pd.concat(dfs)` on an always-empty list raised
# "ValueError: No objects to concatenate" and would have overwritten the frame
# that was just read, so the frame returned by `read_csv` is used directly.
file_path = Path("./trocr/tibetan-dataset/eval_uni.csv")
eval_df = pd.read_csv(file_path)
# the dataset class looks images up through a `file_name` column
eval_df = eval_df.rename(columns={'line_image_id': 'file_name'})
eval_df.head()

Run the cell below: if any file name listed in the dataframe does not exist in the folder, the cell prints that file name. It's recommended to remove the rows whose file names don't exist in the folder. You can do this by uncommenting the last line of that cell.

Each element of the dataset should return 2 things:
* `pixel_values`, which serve as input to the model.
* `labels`, which are the `input_ids` of the corresponding text in the image.

We use `TrOCRProcessor` to prepare the data for the model. `TrOCRProcessor` is just a wrapper around an image processor (here a `ViTImageProcessor`, which is used to resize + normalize images) and a tokenizer (here a `RobertaTokenizer`, which is used to encode text into `input_ids` and decode `input_ids` back into text).

In [None]:
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset


class TibetanImageLinePairDataset(Dataset):
    """Dataset of (line image, transcription) pairs.

    Each item is a dict with two tensors:

    * ``pixel_values``: the image after being run through ``processor``.
    * ``labels``: the token ids of the transcription, padded/truncated to
      ``max_target_length``, with padding positions set to -100.

    Args:
        root_dir: Folder that contains one sub-folder per batch.
        df: DataFrame with ``file_name``, ``batch_name`` and ``text`` columns.
        processor: Callable on an image (returning an object with
            ``pixel_values``) that also exposes a ``tokenizer`` attribute.
        max_target_length: Length every label sequence is padded/truncated to.
    """

    def __init__(self, root_dir, df, processor, max_target_length=512):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows (image/text pairs) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the ``idx``-th image/text pair."""
        # Positional access: stays correct even when `df` was filtered or
        # sliced and its index labels are no longer 0..n-1.
        file_name = self.df['file_name'].iloc[idx]
        batch_name = self.df['batch_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # prepare image (i.e. resize + normalize); the context manager closes
        # the underlying file as soon as the RGB copy has been made
        image_path = Path(self.root_dir) / batch_name / file_name
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncation keeps every
        # label sequence at exactly `max_target_length` so samples can be
        # stacked into a batch
        tokenizer = self.processor.tokenizer
        labels = tokenizer(text,
                           padding="max_length",
                           truncation=True,
                           max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        pad_token_id = tokenizer.pad_token_id
        labels = [-100 if token_id == pad_token_id else token_id for token_id in labels]

        # drop only the leading batch dimension added by `return_tensors="pt"`
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

In [None]:
# Checkpoint names: `encode` is the vision encoder, `decode` the text decoder/tokenizer.
encode = "google/vit-base-patch16-224-in21k"
decode = "sangjeedondrub/tibetan-roberta-base"

Let's initialize the training and evaluation datasets:

In [None]:
from transformers import TrOCRProcessor, ViTImageProcessor, RobertaTokenizer

# Pair the encoder's image processor with the decoder's tokenizer in a single
# TrOCR processor.
feature_extractor = ViTImageProcessor.from_pretrained(encode)
tokenizer = RobertaTokenizer.from_pretrained(decode)
print(tokenizer.vocab_size)
processor = TrOCRProcessor(image_processor=feature_extractor, tokenizer=tokenizer)

train_dataset = TibetanImageLinePairDataset(root_dir='./tibetan-dataset/train/',
                                            df=train_df,
                                            processor=processor)
# The evaluation frame is `eval_df`; the previously referenced `test_df` is
# never defined in this notebook and raised a NameError.
# NOTE(review): evaluation images are read from the same root folder as the
# training images — confirm this is intended.
eval_dataset = TibetanImageLinePairDataset(root_dir='./tibetan-dataset/train/',
                                           df=eval_df,
                                           processor=processor)

In [None]:
# Report how many image/text pairs each split contains.
num_train_examples = len(train_dataset)
num_eval_examples = len(eval_dataset)
print("Number of training examples:", num_train_examples)
print("Number of validation examples:", num_eval_examples)

Let's verify an example from the training dataset:

In [None]:
from PIL import Image
# Fetch the first training item and print the shape of each tensor it holds.
encoding = train_dataset[0]
for field_name in encoding:
    print(field_name, encoding[field_name].shape)

We can also check the original image and decode the labels:

In [None]:
# Peek at the `file_name` entry stored at index 0 of the training frame.
train_df['file_name'][0]

In [None]:
# Rebuild the path of the first training image from its batch folder and file
# name, then load it as RGB so the notebook can display it.
first_batch_name = train_df['batch_name'][0]
first_file_name = train_df['file_name'][0]
first_image_path = train_dataset.root_dir + first_batch_name + '/' + first_file_name
image = Image.open(first_image_path).convert("RGB")
image

In [None]:
# Turn the -100 placeholders back into PAD ids (in place, on the tensor held by
# `encoding`) so the label ids can be decoded to text.
labels = encoding['labels']
ignored_positions = labels == -100
labels[ignored_positions] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

Let's create corresponding dataloaders:

## Prepare data

Let's make a [regular PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We first create a Pandas dataframe with 2 columns. Each row consists of the file name of an image, and the corresponding text.

In [None]:

from torch.utils.data import DataLoader

# Same batch size for both loaders; only the training data is shuffled.
BATCH_SIZE = 4
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE)

## Train a model

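Here, we either reload a previously saved checkpoint (when `use_existing_model` is set) or build a new vision encoder-decoder model from two pretrained checkpoints: a ViT image encoder and a Tibetan RoBERTa text decoder. Note that in the second case the weights of the decoder, including its language modeling head, are initialized from the decoder's own pre-training, whereas the cross-attention layers that connect the decoder to the encoder are newly added and have to be learned during fine-tuning. Refer to the TrOCR paper for details on the architecture.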

In [None]:
# Resume switches: `use_existing_model` selects loading a saved checkpoint
# instead of building a fresh model; `date` is the timestamp suffix of the
# `best_model_{date}` checkpoint to load (None selects the "test" checkpoint).
use_existing_model, date = False, None

In [None]:
from transformers import VisionEncoderDecoderModel
import torch

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Pick the model source: a fresh encoder/decoder pair, the "test" checkpoint,
# or the timestamped best checkpoint.
if not use_existing_model:
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encode, decode)
elif date is None:
    model = VisionEncoderDecoderModel.from_pretrained("test")
else:
    model = VisionEncoderDecoderModel.from_pretrained(f"best_model_{date}")

model.to(device)

# Sanity checks: the text model must be configured as a decoder that
# cross-attends to the encoder output.
decoder_config = model.config.decoder
assert decoder_config.is_decoder is True
assert decoder_config.add_cross_attention is True

Importantly, we need to set a couple of attributes, namely:
* the attributes required for creating the `decoder_input_ids` from the `labels` (the model will automatically create the `decoder_input_ids` by shifting the `labels` one position to the right and prepending the `decoder_start_token_id`, as well as replacing ids which are -100 by the pad_token_id)
* the vocabulary size of the model (for the language modeling head on top of the decoder)
* beam-search related parameters which are used when generating text.

In [None]:
model_tokenizer = processor.tokenizer

# special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = model_tokenizer.cls_token_id
model.config.pad_token_id = model_tokenizer.pad_token_id
# the top-level vocab size must mirror the decoder's
model.config.vocab_size = model.config.decoder.vocab_size

# beam search parameters, applied in the order listed
beam_search_settings = {
    "eos_token_id": model_tokenizer.sep_token_id,
    "max_length": 512,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.5,
    "num_beams": 4,
}
for setting_name, setting_value in beam_search_settings.items():
    setattr(model.config, setting_name, setting_value)



We will evaluate the model on the Character Error Rate (CER), which is available in HuggingFace Datasets (see [here](https://huggingface.co/metrics/cer)).

In [None]:
from datasets import load_metric

# Load the "cer" metric object; `compute_cer` calls its
# `compute(predictions=..., references=...)` method.
# NOTE(review): `load_metric` is reported as deprecated in newer `datasets`
# releases in favor of the separate `evaluate` package — confirm against the
# installed version.
cer_metric = load_metric("cer")

In [None]:
def compute_cer(pred_ids, label_ids):
    """Return the character error rate of predicted ids against label ids.

    Args:
        pred_ids: Batch of generated token ids.
        label_ids: Batch of reference token ids as a torch tensor, in which
            ignored positions are marked with -100. The tensor is left
            unmodified.

    Returns:
        The value produced by ``cer_metric.compute`` for the decoded strings.
    """
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # `masked_fill` returns a new tensor, so the caller's labels keep their
    # -100 markers instead of being overwritten in place.
    decodable_ids = label_ids.masked_fill(label_ids == -100, processor.tokenizer.pad_token_id)
    label_str = processor.batch_decode(decodable_ids, skip_special_tokens=True)

    return cer_metric.compute(predictions=pred_str, references=label_str)

In [None]:
from transformers import AdamW
from tqdm.notebook import tqdm
import datetime

num_epochs = 10

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
if use_existing_model and date is not None:
    # resume the optimizer state that was saved next to the loaded checkpoint
    optimizer.load_state_dict(torch.load(f"best_model_optimizer_{date}.pt", map_location=device))

# Start from +inf so the first evaluated epoch is always saved as the best
# model; a fixed threshold would save nothing if every CER stayed above it.
best_cer = float("inf")
# from here on `date` names the checkpoints written by THIS run
date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

for epoch in range(num_epochs):  # loop over the dataset multiple times
    # train
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader):
        # get the inputs
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))

    # evaluate
    model.eval()
    valid_cer = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            # run batch generation
            outputs = model.generate(batch["pixel_values"])
            # compute metrics
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
            valid_cer += cer

    # mean of the per-batch CER values
    current_cer = valid_cer / len(eval_dataloader)
    print("Validation CER:", current_cer)
    if current_cer < best_cer:
        print('Updating the best model')
        best_cer = current_cer
        # keep model and optimizer state together so training can be resumed
        model.save_pretrained(f"best_model_{date}")
        torch.save(optimizer.state_dict(), f"best_model_optimizer_{date}.pt")


# always keep the weights of the final epoch as well
model.save_pretrained(f"test_{date}")

## Inference

Note that after training, you can easily load the model using the `.from_pretrained(output_dir)` method.

For inference on new images, I refer to my inference notebook, which can also be found in my [Transformers Tutorials repository](https://github.com/NielsRogge/Transformers-Tutorials) on GitHub.