This notebook fine-tunes our OCR post-correction model starting from mBART-50, with mBART-50-specific changes to the tokenizer and model setup: https://huggingface.co/facebook/mbart-large-50

In [None]:
# Mount Google Drive; the project paths configured below live inside this mount.
from google.colab import drive

# force_remount=True lets this cell be re-run even when Drive is already mounted.
drive.mount(mountpoint='/content/gdrive/', force_remount=True)

In [None]:
# Project root on the mounted Drive. The path is relative, so it assumes the working
# directory is /content (the Colab default) -- TODO confirm if the notebook ever changes cwd.
main_dir = 'gdrive/MyDrive/TPDL 2023 Colab Notebooks/'

# Fine-tuned checkpoints go here (math / citation / reference tokens are left in as raw text).
output_dir = f'{main_dir}mBART_models/ocrOnly_large/'

# Pre-aligned train / validation / test CSVs.
aligned_dataset_dir = f'{main_dir}data/alignments/'

# Earlier experiments started from 'google/byt5-small' or
# 'yelpfeast/byt5-base-english-ocr-correction'; this notebook loads
# 'facebook/mbart-large-50' directly in the tokenizer/model cells.

In [None]:
import pandas as pd

# Character-aligned source/target data for each split (file names carry a 20230503 date stamp).
train_df, eval_df, test_df = (
    pd.read_csv(aligned_dataset_dir + csv_name)
    for csv_name in (
        'train_masked_n500000_20230503.csv',
        'val_masked_n10000_20230503.csv',
        'test_masked_n10000_20230503.csv',
    )
)

# True: train on the word-level columns built by add_formatted_columns;
# False: train on the raw 'sentences source' / 'sentences target' columns.
only_words = True

In [None]:
!pip install transformers[sentencepiece]==4.28.0

In [None]:
#!pip install transformers

Install order matters here: `pybind11` must be installed before `fastwer`, which needs it at build time.

In [None]:
# !pip install pybind11 
# !pip install fastwer

In [None]:
from transformers import HfArgumentParser, TensorFlowBenchmark, TensorFlowBenchmarkArguments
#import pandas as pd
from transformers import T5ForConditionalGeneration, AutoTokenizer
from transformers import TrainingArguments
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer
from transformers import EarlyStoppingCallback

In [None]:
##import fastwer
from glob import glob
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from sys import path

# Make the project's helper modules (e.g. utils_ocr_mini) importable. Guarded so that
# re-running this cell does not keep appending the same directory to sys.path.
if main_dir + 'libraries/' not in path:
    path.append(main_dir + 'libraries/')
from utils_ocr_mini import get_fill_in_types

In [None]:
from torch import cuda

# Report which accelerator is visible; `device` is only printed here as a sanity check.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
cuda.empty_cache()  # release cached GPU memory left over from earlier runs in this kernel
print(device)

In [None]:
def add_formatted_columns(datain, fill_in_types=None):
    """Add word-level source/target columns derived from the character-aligned sentences.

    For every row, the aligned source ('aligned sentences source', gap marker '^') and the
    aligned target ('aligned sentences target', gap marker '@') are filtered down to the
    characters whose type -- taken from ``fill_in_types(row['aligned sentences target types'])``,
    one type character per aligned position -- is ' ', 'W' or 'w'.

    If source and target lengths differ, one repair is attempted: when the target starts with
    a space the source lacks, that first character is dropped from the target and the types.

    Args:
        datain: DataFrame with the three 'aligned sentences ...' columns. Modified in place.
        fill_in_types: callable expanding the stored types string into per-character type
            codes. Defaults to ``get_fill_in_types`` (from utils_ocr_mini).

    Returns:
        ``datain`` with four new columns:
        'words source aligned' (source with '^' rewritten to '@'), 'words target aligned',
        'words source' (gap markers removed) and 'words target' (gap markers removed).

    Raises:
        ValueError: if a row's source and target cannot be brought to the same length.
    """
    if fill_in_types is None:
        # Looked up lazily so callers passing their own function don't need utils_ocr_mini.
        fill_in_types = get_fill_in_types
    keep_types = [' ', 'W', 'w']  # type codes of characters kept in the word-level text

    def _chars(text):
        # dtype=str keeps empty inputs as (empty) string arrays rather than float arrays.
        return np.array(list(text), dtype=str)

    source, target, source_aligned, target_aligned = [], [], [], []
    rows = zip(datain['aligned sentences source'],
               datain['aligned sentences target'],
               datain['aligned sentences target types'])
    for i, (src_text, tgt_text, tgt_types) in enumerate(rows):
        s = _chars(src_text)  # aligned source, with ^ symbols
        t = _chars(tgt_text)  # aligned target, with @ symbols
        a = _chars(fill_in_types(tgt_types))
        if len(s) != len(t):
            print('have issue, testing')
            # Known misalignment: the target carries an extra leading space the source lacks.
            if t.size > 0 and t[0] == ' ' and (s.size == 0 or s[0] != ' '):
                t, a = t[1:], a[1:]
            if len(s) != len(t):
                # Fail loudly rather than appending stale words from a previous row.
                raise ValueError(
                    f'row {i}: aligned source ({len(s)} chars) and target ({len(t)} chars) '
                    'have different lengths and could not be realigned'
                )
        keep = np.flatnonzero(np.isin(a, keep_types))
        ss = ''.join(s[keep].tolist())
        tt = ''.join(t[keep].tolist())

        source_aligned.append(ss.replace('^', '@'))  # use the target's gap marker
        target_aligned.append(tt)
        source.append(ss.replace('^', ''))
        target.append(tt.replace('@', ''))

    datain['words source aligned'] = source_aligned
    datain['words target aligned'] = target_aligned
    datain['words source'] = source
    datain['words target'] = target
    return datain

In [None]:
# Peek at the first rows of the raw training split before any reformatting.
train_df.head(n=5)

In [None]:
def _to_model_columns(frame):
    """Return ``frame`` with its chosen source/target columns renamed to input_text / target_text.

    With ``only_words`` set, the word-level columns built by add_formatted_columns are used;
    otherwise the full 'sentences source' / 'sentences target' columns. A frame that already
    has both 'input_text' and 'target_text' is returned unchanged: renaming again would add
    duplicate columns, so re-running this cell is now a no-op instead of corrupting the frames.
    """
    if {'input_text', 'target_text'} <= set(frame.columns):
        return frame
    if only_words:
        frame = add_formatted_columns(frame)
        source_col, target_col = 'words source', 'words target'
    else:
        source_col, target_col = 'sentences source', 'sentences target'
    return frame.rename(columns={source_col: 'input_text', target_col: 'target_text'})


train_df, eval_df, test_df = (_to_model_columns(frame) for frame in (train_df, eval_df, test_df))

In [None]:
# Hugging Face TrainingArguments as a plain dict; parsed into the dataclass further down.
args_dict = dict(
    output_dir=output_dir,
    overwrite_output_dir=True,
    # Effective train batch per device = 4 examples x 4 accumulation steps = 16.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    warmup_steps=250,
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=1000,
    # NOTE: a positive max_steps takes precedence, so num_train_epochs is effectively unused.
    num_train_epochs=4,
    do_train=True,
    do_eval=True,
    fp16=False,
    max_steps=100000,
    save_steps=1000,
    save_strategy='steps',
    # With no metric_for_best_model given, the Trainer ranks checkpoints by eval loss.
    load_best_model_at_end=True,
)

In [None]:
#!pip install --upgrade accelerate

In [None]:
# Turn the dict into a validated TrainingArguments. parse_dict returns one object per
# dataclass type given, so the single TrainingArguments is unpacked from a 1-tuple.
# No explicit set_seed here: the Trainer applies args.seed itself.
parser = HfArgumentParser(TrainingArguments)
training_args = parser.parse_dict(args_dict)
(args,) = training_args

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

In [None]:
#!pip install sentencepiece

In [None]:
#!pip install transformers[sentencepiece]

In [None]:
# mBART-50 needs its dedicated tokenizer class. Source and target language are both
# English (en_XX): this is same-language correction, not translation.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    src_lang="en_XX",
    tgt_lang="en_XX",
)

In [None]:
# Pretrained mBART-50 seq2seq model (this notebook's replacement for the earlier
# T5ForConditionalGeneration / ByT5 starting point).
model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50",
)

In [None]:
# Raise the tokenizer's and the generation config's length limits, but never beyond the
# model's learned position-embedding table (config.max_position_embeddings): positions past
# that table cannot be embedded, so a 4096 limit would let long inputs/generations fail.
max_positions = model.config.max_position_embeddings
tokenizer.model_max_length = min(4096, max_positions)
model.config.max_length = min(4096, max_positions)

In [None]:
class GPReviewDataset(Dataset):
    """Map-style Dataset that tokenizes (source text, target text) pairs on access for seq2seq training.

    Each item is a dict with ``input_ids`` / ``attention_mask`` for the encoder, and
    ``labels`` / ``decoder_attention_mask`` for the decoder. Every sequence is padded or
    truncated to ``max_len``. Padded label positions are set to -100 so the loss ignores them.

    Args:
        Text: indexable sequence of source strings; each item is coerced with ``str``.
        Label: indexable sequence of target strings, aligned with ``Text``; coerced with ``str``.
        tokenizer: tokenizer for both sides. Defaults to the notebook-global ``tokenizer``,
            looked up when an item is fetched.
        max_len: pad/truncate length for source and target. Defaults to 512.
    """

    def __init__(self, Text, Label, tokenizer=None, max_len=512):
        self.Text = Text
        self.Label = Label
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.Text)

    def _get_tokenizer(self):
        # Fall back to the notebook-global tokenizer so existing two-argument calls keep working.
        return self.tokenizer if self.tokenizer is not None else tokenizer

    def __getitem__(self, item):
        tok = self._get_tokenizer()
        text = str(self.Text[item])
        label = str(self.Label[item])
        inputs = tok(text, padding="max_length", truncation=True, max_length=self.max_len)
        # text_target= makes the tokenizer add target-side special tokens (tgt_lang for mBART-50).
        outputs = tok(text_target=label, padding="max_length", truncation=True, max_length=self.max_len)
        # Replace padding with -100 (the cross-entropy ignore index); otherwise the model is
        # also trained to predict <pad> at every padded position.
        labels = [
            token_id if keep else -100
            for token_id, keep in zip(outputs.input_ids, outputs.attention_mask)
        ]
        return {
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,
            "labels": labels,
            "decoder_attention_mask": outputs.attention_mask,
        }

In [None]:
# Training split, wrapped so that rows are tokenized lazily when the Trainer asks for them.
ds_train = GPReviewDataset(
    Text=train_df['input_text'].to_numpy(),
    Label=train_df['target_text'].to_numpy(),
)

In [None]:
# Validation split: despite the `ds_test` name, this wraps eval_df (the val CSV), not test_df.
ds_test = GPReviewDataset(
    Text=eval_df['input_text'].to_numpy(),
    Label=eval_df['target_text'].to_numpy(),
)

In [None]:
# Names used by the Trainer cell; valid_dataset is the eval_df-based split.
train_dataset, valid_dataset = ds_train, ds_test

In [None]:
# Plain Trainer: default data collator (items are already padded), eval loss as the only
# metric, and no early-stopping callback.
trainer = Trainer(
    args=args,
    model=model,
    eval_dataset=valid_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Override a few arguments on the already-built Trainer. Assigning to trainer.args skips
# TrainingArguments validation, which would reject save_steps=500 together with
# eval_steps=1000 and load_best_model_at_end=True (it requires save_steps to be a multiple
# of eval_steps). Every 1000-step evaluation still coincides with a save.
trainer.args.save_total_limit = 10  # cap on checkpoints kept in output_dir
trainer.args.logging_steps = 100    # same value as in args_dict
trainer.args.save_steps = 500       # checkpoint twice as often as args_dict's 1000

# Resume from a specific saved checkpoint; call trainer.train() without it to start fresh.
trainer.train(resume_from_checkpoint=output_dir + 'checkpoint-11000')