<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import evaluate

# Character error rate (CER) metric, looked up by name through `evaluate`;
# compute_cer() below calls its .compute() method.
cer_metric = evaluate.load(path="cer")

  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [21]:
def compute_cer(pred_ids, label_ids):
    """Compute the character error rate (CER) between predictions and labels.

    Both arguments are batches of token ids. They are decoded to strings with
    the module-level ``processor`` and scored with the module-level
    ``cer_metric``.

    Args:
        pred_ids: Batch of predicted token ids, in any form accepted by
            ``processor.batch_decode``.
        label_ids: Batch of reference token ids. Entries equal to ``-100`` are
            replaced by the tokenizer's pad token id before decoding. The
            argument itself is left untouched.

    Returns:
        The value returned by ``cer_metric.compute`` for the decoded strings.
    """
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # Work on a copy so the caller's labels keep their -100 entries: objects
    # exposing .clone() (e.g. torch tensors) are cloned, anything else is
    # duplicated with .copy() (e.g. numpy arrays).
    label_ids = label_ids.clone() if hasattr(label_ids, "clone") else label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    return cer_metric.compute(predictions=pred_str, references=label_str)

In [None]:
import os
from accelerate.utils import write_basic_config

# write_basic_config() writes the default Accelerate config file and returns a
# falsy value when a config already exists. Only exit the kernel when a new
# config was actually written; exiting unconditionally would kill the kernel on
# every "Run All", so no later cell could ever execute.
if write_basic_config():
    os._exit(0)  # Hard-exit the process so the notebook restarts with the new config

In [None]:
from accelerate import Accelerator, ProjectConfiguration
import pandas as pd
import os
from pathlib import Path
from sklearn.model_selection import train_test_split    
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate.utils import set_seed

# Collect the transcript CSVs. iterdir() yields entries in arbitrary,
# filesystem-dependent order, so sort before slicing: otherwise the subset
# taken below can differ between runs and machines.
file_paths = sorted(file_path for file_path in Path('./tibetan-dataset/transcript/').iterdir()
                    if file_path.suffix == '.csv')
# NOTE(review): only the first two transcript files are used; this looks like
# a debugging limit. Confirm before a full training run.
file_paths = file_paths[:2]

if not file_paths:
    # Fail with an explicit message instead of pd.concat's generic
    # "No objects to concatenate".
    raise FileNotFoundError("No .csv transcript files found in ./tibetan-dataset/transcript/")

dfs = []

for file_path in tqdm(file_paths):
    # The CSV's base name is reused below as the name of the image sub-directory.
    batch_name = file_path.stem
    df = pd.read_csv(str(file_path), sep=',')
    df['batch_name'] = batch_name
    dfs.append(df)

df = pd.concat(dfs, ignore_index=True)

# change the column name line_image_id to file_name
df = df.rename(columns={'line_image_id': 'file_name'})

# Remove the rows whose image file does not exist.
# To limit filesystem calls, existence is checked only on the first row of each
# run of consecutive rows sharing the same prefix (the part of file_name before
# the first '_'); the remaining rows of that run inherit the result.
original_image_name = ''
thumbs_up = True
potential_missing_files = []
print('started')
for file_name, batch_name in tqdm(zip(df['file_name'], df['batch_name']), total=len(df)):
    this_original_image_name = file_name.split('_')[0]
    if this_original_image_name != original_image_name:
        original_image_name = this_original_image_name
        thumbs_up = os.path.isfile('./tibetan-dataset/train/' + batch_name + '/' + file_name)
    if not thumbs_up:
        potential_missing_files.append(file_name)

df = df[~df['file_name'].isin(potential_missing_files)]

In [None]:
from trocr.datautils import TibetanImageLinePairDataset
from sklearn.model_selection import train_test_split
from transformers import ViTImageProcessor, RobertaTokenizer, TrOCRProcessor
def get_dataloaders(batch_size):
    """Split the module-level ``df`` 80/20 and wrap both parts in DataLoaders.

    Args:
        batch_size: Batch size used for both the training and the evaluation
            loader.

    Returns:
        A tuple ``(train_dataloader, eval_dataloader, processor)``. The
        training loader shuffles, the evaluation loader does not, and
        ``processor`` is the module-level processor passed to both datasets.
    """
    # Both parts come from the same frame, so they share one image root.
    images_root = './tibetan-dataset/train/'

    # 80 % training / 20 % evaluation, each re-indexed from zero.
    train_part, eval_part = (
        part.reset_index(drop=True)
        for part in train_test_split(df, test_size=0.2)
    )

    train_set = TibetanImageLinePairDataset(root_dir=images_root, df=train_part, processor=processor)
    eval_set = TibetanImageLinePairDataset(root_dir=images_root, df=eval_part, processor=processor)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_set, batch_size=batch_size)

    return train_loader, eval_loader, processor
    
    

In [None]:
# --- Training hyper-parameters (read by training_function) ---
batch_size = 8
learning_rate = 5e-5
seed = 42
mixed_precision = "fp16"

# --- Pretrained checkpoints: `encode` feeds the image side, `decode` the text side ---
encode = "google/vit-base-patch16-224-in21k"
decode = "sangjeedondrub/tibetan-roberta-base"

# Image preprocessing comes from the encoder checkpoint and tokenization from
# the decoder checkpoint; TrOCRProcessor bundles the two into one object.
feature_extractor = ViTImageProcessor.from_pretrained(encode)
tokenizer = RobertaTokenizer.from_pretrained(decode)
print(tokenizer.vocab_size)
processor = TrOCRProcessor(
    image_processor=feature_extractor,
    tokenizer=tokenizer,
)

In [None]:
import datetime

def training_function():
    """Fine-tune the encoder-decoder OCR model and track validation CER.

    Takes no arguments so it can be handed to ``notebook_launcher``. It reads
    the module-level settings ``seed``, ``mixed_precision``, ``batch_size``,
    ``learning_rate``, ``encode`` and ``decode``.

    For each of 10 epochs it runs one training pass, generates predictions for
    the evaluation set, averages the per-batch CER, and saves the Accelerate
    state whenever that average improves on the best value seen so far. After
    the last epoch the model is written to ``test_<timestamp>``.
    """
    # Local imports, as in the original cell.
    import os

    import torch
    from transformers import VisionEncoderDecoderModel

    set_seed(seed)
    project_dir = "./project"
    config = ProjectConfiguration(project_dir=project_dir, logging_dir="./logging")
    accelerator = Accelerator(log_with="tensorboard", project_config=config, mixed_precision=mixed_precision)

    train_dataloader, eval_dataloader, processor = get_dataloaders(batch_size)
    tokenizer = processor.tokenizer

    device = accelerator.device
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encode, decode)
    model.to(device)

    # set special tokens used for creating the decoder_input_ids from the labels
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size

    # set beam search parameters
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.max_length = 512
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Start at +inf so the first epoch always counts as an improvement,
    # whatever the CER scale.
    best_cer = float("inf")
    date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # Explicit checkpoint location: automatic checkpoint naming is not enabled
    # on the project configuration, so save_state() needs a directory.
    checkpoint_dir = os.path.join(project_dir, "checkpoints", "best")

    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    for epoch in range(10):  # loop over the dataset multiple times
        # train
        model.train()
        for batch in tqdm(train_dataloader, disable=not accelerator.is_local_main_process):
            # forward + backward + optimize
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        # evaluate
        model.eval()
        # prepare() may wrap the model for distributed training; generate() is
        # only available on the underlying model.
        generator = accelerator.unwrap_model(model)
        valid_cer = 0.0
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, disable=not accelerator.is_local_main_process):
                # run batch generation
                generated = generator.generate(batch["pixel_values"])
                # Generated sequences can differ in length between processes;
                # pad them to a common length before gathering.
                generated = accelerator.pad_across_processes(
                    generated, dim=1, pad_index=tokenizer.pad_token_id
                )
                # -100 is the value compute_cer() swaps for the pad token id.
                labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=-100)
                # gather_for_metrics also drops the samples duplicated to fill
                # the last batch in distributed runs.
                pred_ids = accelerator.gather_for_metrics(generated)
                label_ids = accelerator.gather_for_metrics(labels)
                # compute metrics
                valid_cer += compute_cer(pred_ids=pred_ids, label_ids=label_ids)

        current_cer = valid_cer / len(eval_dataloader)
        accelerator.print(f"Validation CER: {current_cer}")
        # The CER is computed from gathered tensors, so every process should
        # see the same value and take this branch together.
        if current_cer < best_cer:
            best_cer = current_cer
            accelerator.wait_for_everyone()
            accelerator.save_state(output_dir=checkpoint_dir)

    # Write the final model once, from the main process, using the unwrapped
    # model (the wrapper does not provide save_pretrained()).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained(f"test_{date}")
    

In [None]:
from accelerate import notebook_launcher
# training_function takes no arguments, hence the empty tuple.
# Two processes are launched from inside the notebook.
args = tuple()
notebook_launcher(function=training_function, args=args, num_processes=2)