## Arguments

### Model Arguments

In [None]:
# Model identifier passed to the `from_pretrained` calls below; the training
# cell also checks whether it is a local directory before calling `train`.
model_name_or_path = 't5-small'

### Data Arguments

In [None]:
# Data settings read by the later cells:
#   - train_file / validation_file: paths handed to `load_dataset`
#   - max_source_length / max_target_length: truncation limits used when tokenizing
#   - ignore_pad_token_for_loss: selects the label padding id for the data collator
data_args = dict(
    train_file='/path/to/training_data_file',
    validation_file='/path/to/validation_data_file',
    max_target_length=128,
    max_source_length=512,
    ignore_pad_token_for_loss=True,
)

### Training Arguments

In [None]:
# Settings for the training run. Later cells read 'seed' and 'output_dir'
# directly from this dict.
training_args = {
    'model_name_or_path': model_name_or_path,
    'output_dir': './output',
    # Must be True: the evaluation cell calls `trainer.predict(..., max_length=100)`
    # and decodes `predictions.predictions` with `tokenizer.batch_decode`, which
    # needs generated token ids rather than raw logits.
    'predict_with_generate': True,
    'do_train': True,
    'do_eval': True,
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'gradient_accumulation_steps': 2,
    'learning_rate': 5e-4,
    'evaluation_strategy': 'steps',
    'num_train_epochs': 10,
    'save_total_limit': 4,
    'save_strategy': 'epoch',
    'seed': 42,
    }

## Setting up the seed

In [None]:
from transformers import set_seed

# Seed the run with the value configured in training_args['seed'].
set_seed(training_args['seed'])

## Load Dataset

In [None]:
from datasets import load_dataset

# Map each split name to the file configured for it in data_args.
data_files = {
    'train': data_args['train_file'],
    'validation': data_args['validation_file'],
}

# Both files are read with the 'json' dataset loader.
datasets = load_dataset('json', data_files=data_files)

## Load pretrained model and tokenizer

In [None]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM

# Configuration and tokenizer are loaded from the same name/path as the model.
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# `from_tf` is switched on when the name/path contains '.ckpt'.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path,
    config=config,
    from_tf='.ckpt' in model_name_or_path,
)

## Tokenize the inputs and targets

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of examples and attach the target ids as labels.

    Reads the 'inputs' and 'targets' entries of `examples`, tokenizes both
    with truncation and without padding (limits come from
    data_args['max_source_length'] and data_args['max_target_length']), and
    returns the tokenized inputs with the target 'input_ids' stored under
    'labels'.
    """
    source_texts = list(examples['inputs'])
    target_texts = list(examples['targets'])

    encoded = tokenizer(
        source_texts,
        max_length=data_args['max_source_length'],
        padding=False,
        truncation=True,
    )

    # Targets are tokenized inside the tokenizer's target context.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(
            target_texts,
            max_length=data_args['max_target_length'],
            padding=False,
            truncation=True,
        )

    encoded['labels'] = encoded_targets['input_ids']
    return encoded

In [None]:
train_dataset = datasets['train']
eval_dataset = datasets['validation']

# The training split's columns are removed from both splits after tokenization.
column_names = train_dataset.column_names

# Tokenize the training split first, then the validation split, with the
# same mapping options.
train_dataset, eval_dataset = (
    split.map(
        preprocess_function,
        batched=True,
        remove_columns=column_names,
        load_from_cache_file=True,
    )
    for split in (train_dataset, eval_dataset)
)

### Data Collator

In [None]:
from transformers import DataCollatorForSeq2Seq

# Labels are padded with -100 when 'ignore_pad_token_for_loss' is set,
# otherwise with the tokenizer's own pad token id.
if data_args['ignore_pad_token_for_loss']:
    label_pad_token_id = -100
else:
    label_pad_token_id = tokenizer.pad_token_id

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    label_pad_token_id=label_pad_token_id,
)

### Initialise Trainer

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Seq2SeqTrainer needs a Seq2SeqTrainingArguments instance, not a plain dict.
# 'model_name_or_path' is a model setting rather than a training argument, so
# it is left out. The `training_args` dict itself is kept untouched because
# later cells still index it (e.g. training_args['output_dir']).
seq2seq_training_args = Seq2SeqTrainingArguments(
    **{key: value for key, value in training_args.items() if key != 'model_name_or_path'}
)

trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,
)

## Training

In [None]:
import os

# The model path is handed to `train` only when it is an existing local
# directory; otherwise None is passed.
if os.path.isdir(model_name_or_path):
    train_result = trainer.train(model_path=model_name_or_path)
else:
    train_result = trainer.train(model_path=None)
trainer.save_model()

output_train_file = os.path.join(training_args['output_dir'], 'train_results.txt')
if trainer.is_world_process_zero():
    # One "key = value" line per training metric, in sorted order.
    with open(output_train_file, 'w') as writer:
        for key, value in sorted(train_result.metrics.items()):
            writer.write(f'{key} = {value}\n')

    # The trainer state is saved separately because Trainer.save_model only
    # saves the tokenizer along with the model.
    trainer.state.save_to_json(
        os.path.join(training_args['output_dir'], 'trainer_state.json')
    )

## Evaluation

In [None]:
import json

def _gold_target(raw_line, is_json):
    """Extract the target text from one raw line of the gold file."""
    if not is_json:
        # Tab-separated line: the target is the second field.
        return raw_line.strip().split('\t')[1]

    record = json.loads(raw_line)
    if 'translation' in record:
        return record['translation']['tgt']
    # Same column that preprocess_function reads the training targets from.
    return record['targets']


def evaluate_predictions(pred_filename, gold_filename):
    """Compute first-word and exact-match accuracy of predictions.

    Args:
        pred_filename: Text file with one prediction per line.
        gold_filename: Gold file aligned line by line with the predictions.
            When the name ends with '.json', every line is parsed as a JSON
            object and the target is taken from ['translation']['tgt'] or,
            if there is no 'translation' key, from ['targets']. Otherwise
            every line is tab-separated and the target is the second field.

    Returns:
        A tuple (first_word_accuracy, exact_match_accuracy) of floats.
        Both are 0.0 when the predictions file has no lines.

    Raises:
        IndexError: If the gold file has fewer lines than the predictions
            file, or a tab-separated gold line has no second field.
    """
    with open(pred_filename, 'r') as pred_f, open(gold_filename) as gold_f:
        pred_lines = pred_f.readlines()
        gold_lines = gold_f.readlines()

    if len(gold_lines) < len(pred_lines):
        raise IndexError(
            f'{gold_filename} has {len(gold_lines)} lines but '
            f'{pred_filename} has {len(pred_lines)} predictions'
        )
    if not pred_lines:
        # Nothing to score; avoid dividing by zero.
        return 0.0, 0.0

    is_json = gold_filename.endswith('.json')
    total = len(pred_lines)
    full_correct = 0
    first_correct = 0

    for raw_pred, raw_gold in zip(pred_lines, gold_lines):
        pred_line = raw_pred.strip()
        gold_line = _gold_target(raw_gold, is_json)

        # Remove the space before a question mark, period or comma.
        gold_line = gold_line.replace(' ?', '?').replace(' .', '.').replace(' ,', ',')

        if pred_line == gold_line:
            full_correct += 1
            first_correct += 1
            continue

        pred_words = pred_line.split()
        gold_words = gold_line.split()
        # Either side may be empty; then there is no first word to compare.
        if pred_words and gold_words and pred_words[0] == gold_words[0]:
            first_correct += 1

    return first_correct / total, full_correct / total

In [None]:
basename = os.path.basename(data_args['validation_file']).replace('.json', '')

predictions = trainer.predict(test_dataset=eval_dataset, max_length=100)
output_pred_file = os.path.join(training_args['output_dir'], basename + '.eval_preds_seq2seq.txt')
output_eval_file = os.path.join(training_args['output_dir'], basename + '.eval_results_seq2seq.txt')

# Only the main process writes the predictions file, so only the main process
# reads it back for scoring; other processes would otherwise look for a file
# they never wrote.
if trainer.is_world_process_zero():
    # One decoded prediction per line.
    with open(output_pred_file, 'w') as writer:
        for pred in tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True):
            writer.write(pred + '\n')

    first_acc, full_acc = evaluate_predictions(output_pred_file, data_args['validation_file'])
    with open(output_eval_file, 'w') as writer:
        writer.write(f'Exact match accuracy: {full_acc}\n')
        writer.write(f'First word accuracy: {first_acc}\n')