In [None]:
import os
# Keep downloaded Hugging Face models out of the home directory (space concerns)
# by pointing the transformers cache at project storage.
# NOTE(review): this cell runs before `transformers` is imported in the next cell;
# presumably the variable has to be set first to take effect — confirm.
os.environ['TRANSFORMERS_CACHE'] = '/projectnb/sparkgrp/colejh'

# Installs

In [None]:
# !pip install datasets
# !pip install jiwer
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.utils.data import DataLoader
import torch.nn as nn
import pandas as pd
import glob
import os
import random
from tqdm import tqdm
from datasets import load_metric
from transformers import AdamW
import matplotlib.pyplot as plt
import wandb
import torch.optim as optim


# Suppressing Warnings

In [None]:
# Silence Hugging Face log output below ERROR level (numeric level 40).
# Flip SUPPRESS to False to see the library's warnings again.
SUPPRESS = True
if SUPPRESS:
    from transformers.utils import logging
    logging.set_verbosity(40)

# Also hide Python-level UserWarning and FutureWarning messages.
import warnings
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)

# Dataset preparation

In [None]:
# Ground-truth directories of the IAM data, one per granularity (lines, words,
# sentences). Each is passed to get_lists below to collect image/text pairs.
IAM_lines = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/tesseract-training/training/IAM/gt/lines/'
IAM_words = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/tesseract-training/training/IAM/gt/words/'
IAM_sentences = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/tesseract-training/training/IAM/gt/sentences/'

# Output directory for the best checkpoint of each model. The training loop
# builds the file name by string concatenation, so the trailing slash matters.
model_directory = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/TROCR_Training/'
# Collect every .png image in a directory together with the transcription stored in its matching .gt.txt file:
def get_lists(directory,directory_percentage):
    """Return a random sample of (image path, transcription) pairs from a directory.

    Every ``*.png`` file in ``directory`` is paired with the text stored in the
    ``.gt.txt`` file that shares its stem (``foo.png`` -> ``foo.gt.txt``).

    Args:
        directory: Directory holding the images and ground-truth files. A
            trailing path separator is accepted but not required.
        directory_percentage: Share of the directory to sample, in percent
            (0-100).

    Returns:
        A pair ``(image_list, text_list)`` of equally long tuples, where
        ``text_list[i]`` is the transcription of ``image_list[i]``. Both are
        empty when the requested share rounds down to zero samples.

    Raises:
        ValueError: If ``directory_percentage`` is outside 0-100.
        FileNotFoundError: If an image has no matching ``.gt.txt`` file.
    """
    if not 0 <= directory_percentage <= 100:
        raise ValueError(
            "directory_percentage must be between 0 and 100, got {}".format(directory_percentage)
        )

    # Sorted so that a seeded `random` picks the same files on every run,
    # regardless of the order the filesystem lists them in.
    image_paths = sorted(glob.glob(os.path.join(directory, "*.png")))

    pairs = []
    for image_path in image_paths:
        # splitext strips only the final ".png"; splitting on the first "."
        # would truncate any path that contains another dot.
        gt_path = os.path.splitext(image_path)[0] + '.gt.txt'
        with open(gt_path, 'r') as gt_file:
            # Exactly one transcription per image keeps both lists aligned even
            # when a ground-truth file is empty or spans several lines.
            transcription = ' '.join(gt_file.read().splitlines())
        pairs.append((image_path, transcription))

    # Take a random percentage of the data
    sample_size = round(directory_percentage / 100 * len(pairs))
    if sample_size == 0:
        # zip(*[]) yields nothing to unpack, so handle the empty sample here.
        return (), ()

    image_list, text_list = zip(*random.sample(pairs, sample_size))
    return image_list, text_list


# Taken from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb
class IAMDataset(Dataset):
    """Dataset that turns (image path, transcription) rows into TrOCR inputs.

    Args:
        df: DataFrame with an ``image`` column (path to an image file) and a
            ``text`` column (its transcription).
        processor: Object that is callable on a PIL image (returning an object
            with ``pixel_values``) and exposes a ``tokenizer`` attribute, e.g.
            a ``TrOCRProcessor``.
        max_target_length: Fixed length every label sequence is padded or
            truncated to.
    """

    def __init__(self, df, processor, max_target_length=128):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels"}`` for the row at position ``idx``."""
        # Positional lookup: works for any DataFrame index, not only a freshly
        # reset 0..n-1 one (label-based `df['image'][idx]` raises KeyError otherwise).
        image_path = self.df['image'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # prepare image (i.e. resize + normalize)
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncation keeps every
        # label sequence at max_target_length so that samples can be batched.
        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(text,
                              padding="max_length",
                              truncation=True,
                              max_length=self.max_target_length).input_ids

        # important: make sure that PAD tokens are ignored by the loss function
        pad_id = tokenizer.pad_token_id
        labels = [-100 if token == pad_id else token for token in token_ids]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}


# One seed for both the sampling below and the train/validation split, so a
# fresh kernel reproduces the same dataset.
SEED = 42
random.seed(SEED)

# Share of each IAM subset to sample, in percent.
lines_percentage = 1
words_percentage = 5
sentences_percentage = 1

sampling_plan = (
    (IAM_lines, lines_percentage),
    (IAM_words, words_percentage),
    (IAM_sentences, sentences_percentage),
)

# Pool the sampled image paths and transcriptions from all three subsets.
image_list = []
text_list = []
for directory, percentage in sampling_plan:
    images, text = get_lists(directory, percentage)
    image_list.extend(images)
    text_list.extend(text)

df = pd.DataFrame({'image': image_list, 'text': text_list})


# Splitting the data into training and validation sets (80/20)
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED)
# Renumber rows from 0 in each split; assignment instead of inplace=True keeps
# the cell idempotent when it is re-run.
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)



# All Possible Models

In [None]:
# TrOCR checkpoint identifiers. The training cell loads each one with
# `from_pretrained` and runs one fine-tuning run per entry, in this order.
all_models = ['microsoft/trocr-base-printed',
              'microsoft/trocr-base-handwritten',
              'microsoft/trocr-large-handwritten',
              'microsoft/trocr-small-handwritten',
              'microsoft/trocr-large-printed',
              'microsoft/trocr-base-stage1',
              'microsoft/trocr-small-printed',
              'microsoft/trocr-small-stage1',
              'microsoft/trocr-base-str',
              'microsoft/trocr-large-str',
              'microsoft/trocr-large-stage1']

# Full Training/Logging for Each Model

In [None]:
# Fine-tune every checkpoint in `all_models` on the IAM sample, logging one
# Weights & Biases run per model and keeping the checkpoint with the lowest
# validation character error rate (CER).

# Hyperparameters shared by all runs
BATCH_SIZE = 4
LEARNING_RATE = 5e-5
# Floor of the cosine schedule. It has to be below LEARNING_RATE: the previous
# value (0.003) was larger than the base rate, so the "annealing" raised the
# learning rate towards 0.003 instead of lowering it.
MIN_LEARNING_RATE = 1e-6
EPOCHS = 5
# Number of consecutive epochs with rising validation CER before stopping early
patience = 3

# Loop-invariant setup: done once instead of once per model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Character error rate calculation
# NOTE(review): `load_metric` is reported as deprecated in newer `datasets`
# releases in favour of the `evaluate` package — confirm against the installed version.
cer_metric = load_metric("cer")

for model_name in all_models:
    # New run for each model
    run = wandb.init(reinit = True,name = model_name.split('/')[1], project = "Testing-Tr-OCR")

    # Loading model and processor
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)

    # Replicate the model over every visible GPU (the underlying model stays
    # reachable as `model.module`) and move it to the selected device.
    model = nn.DataParallel(model, list(range(torch.cuda.device_count()))).to(device)

    # Datasets and dataloaders for train validation
    train_dataset = IAMDataset(df=train_df, processor=processor)
    val_dataset = IAMDataset(df=val_df, processor=processor)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    config = model.module.config
    # set special tokens used for creating the decoder_input_ids from the labels
    config.decoder_start_token_id = processor.tokenizer.cls_token_id
    config.pad_token_id = processor.tokenizer.pad_token_id
    # make sure vocab size is set correctly
    config.vocab_size = config.decoder.vocab_size

    # set beam search parameters
    config.eos_token_id = processor.tokenizer.sep_token_id
    config.max_length = 64
    config.early_stopping = True
    config.no_repeat_ngram_size = 3
    config.length_penalty = 2.0
    config.num_beams = 4

    def compute_cer(pred_ids, label_ids):
        """Return the character error rate of one batch of generated ids.

        ``label_ids`` uses -100 for padding; those positions are mapped back to
        the pad token on a copy, so the caller's tensor is left untouched.
        """
        pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
        restored_ids = label_ids.masked_fill(label_ids == -100, processor.tokenizer.pad_token_id)
        label_str = processor.batch_decode(restored_ids, skip_special_tokens=True)
        return cer_metric.compute(predictions=pred_str, references=label_str)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=4, eta_min=MIN_LEARNING_RATE
    )

    # Per-run bookkeeping
    avg_train_loss = []
    avg_val_cer = []
    last_val = float('inf')   # validation CER of the previous epoch
    min_val = float('inf')    # best validation CER seen so far
    counter = 0               # consecutive epochs in which the CER got worse

    # Main Training Loop
    for epoch in range(EPOCHS):  # loop over the dataset multiple times
        # train
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_dataloader):
            outputs = model(**batch)
            # DataParallel yields one loss per replica. Reduce once, so the
            # value that is backpropagated is also the value that is reported
            # (summing the replicas inflated the logged loss by the GPU count).
            loss = outputs.loss.mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        # Save loss locally
        current_loss = train_loss/len(train_dataloader)
        avg_train_loss.append(current_loss)
        print(f"Loss after epoch {epoch+1}:", current_loss)
        scheduler.step()

        # evaluate
        model.eval()
        valid_cer = 0.0
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                # run batch generation on the unwrapped model
                outputs = model.module.generate(batch['pixel_values'].to(device))
                # compute metrics
                valid_cer += compute_cer(pred_ids=outputs, label_ids=batch["labels"])

        # Mean of the per-batch CER values
        current_val = valid_cer/len(val_dataloader)

        # Logging values
        # NOTE(review): "Test/Loss" carries the *training* loss; the key is kept
        # so existing W&B dashboards stay comparable.
        wandb.log({"Test/Loss": current_loss, "Character-Error-Rate": current_val}, step=epoch)
        avg_val_cer.append(current_val)
        print("Validation CER:", current_val)

        # If the validation CER rises for `patience` consecutive epochs, stop training.
        # Save if new best model.
        if current_val > last_val:
            counter += 1
        else:
            counter = 0
            if current_val < min_val:
                min_val = current_val
                # saving the best model so far; the state dict is taken from the
                # DataParallel wrapper, and the file name is
                # "<model name>best_model.pt" with no separator in between
                torch.save(model.state_dict(), model_directory + model_name.split('/')[1]+'best_model.pt')
        last_val = current_val

        if counter == patience:
            print('Validation CER has increased for {} epochs, ending training...'.format(patience))
            break

    # Close the W&B run exactly once, whether training finished or stopped early
    run.finish()
