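This notebook shows how to set up and train the narration model (a BART or T5 backbone, in either the baseline or the early-fusion variant), save it, and then use it to generate a narration for a test example.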

In [1]:
import argparse
import copy
import json
import os
import pickle as pk
import pickle as pk
import time
from re import S

import torch
import torch.nn as nn
from pytorch_lightning import seed_everything
from torch.nn import functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from tqdm import tqdm
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback,T5Config

from src.data_utils import *
#from src.model_utils import setupTokenizer
from src.datasethandler import NarrationDataSet,train_data_permutated_path,test_data_path,DataSetLoader,ClassificationReportPreprocessor
from src.inferenceUtils import PerformanceNarrator
from src.modeling_bart import BartNarrationModel
from src.modeling_t5 import T5NarrationModel
from src.trainer_utils import getTrainingArguments,CustomTrainerFusion,get_model
# Process-wide switches set before any training starts:
#   WANDB_DISABLED         -> opt out of Weights & Biases reporting
#   TOKENIZERS_PARALLELISM -> turn off the HF tokenizers' internal parallelism
os.environ.update({
    "WANDB_DISABLED": "true",
    "TOKENIZERS_PARALLELISM": "false",
})

In [3]:
# Define the parameters used to set up the models
modeltype = 'earlyfusion'  # either 'baseline' or 'earlyfusion'

# either t5-small, t5-base, t5-large, facebook/bart-base, or facebook/bart-large
modelbase = 'facebook/bart-base'

# Folder name for the backbone: keep only the last path component, dropping any
# hub organisation prefix ('facebook/bart-base' -> 'bart-base', 't5-small' stays
# 't5-small'). Unlike indexing [1], this never produces a nested save directory
# for namespaced T5 ids and never picks the wrong segment for a local path.
pre_trained_model_name = modelbase.split('/')[-1]

# where the trained model will be saved, e.g. 'TrainModels/earlyfusion/bart-base/'
output_path = 'TrainModels/' + modeltype + '/' + pre_trained_model_name + '/'


In [4]:
# Load the dataset used in the paper.
dataset_raw = DataSetLoader()

# Preprocess the raw data and build the tokenizer that matches `modelbase`.
narrationdataset = NarrationDataSet(
    modelbase,
    max_preamble_len=160,
    max_len_trg=185,
    max_rate_toks=8,
    lower_narrations=True,
    process_target=True,
)
narrationdataset.fit(dataset_raw.train_data, dataset_raw.test_data)

# Handles to what `fit` produced; the alias names are kept for later cells.
# NOTE(review): the test split doubles as the validation split. The training
# cell passes it as `eval_dataset` to the Trainer with early stopping, so the
# checkpoint is selected on the test data. Interpret the reported metrics accordingly.
dataset = train_dataset = narrationdataset.train_dataset
test_dataset = val_dataset = narrationdataset.test_dataset
tokenizer = tokenizer_ = narrationdataset.tokenizer_

train_size, val_size = len(train_dataset), len(val_dataset)
print(f'{train_size:>5,} training samples')
print(f'{val_size:>5,} validation samples')

4,529 training samples
  100 validation samples


In [5]:
# Hyper-parameters for training. The training cell turns this dict into the
# Trainer's arguments object via `getTrainingArguments`.
# NOTE(review): recent `transformers` releases renamed `evaluation_strategy`
# to `eval_strategy`. Confirm which spelling `getTrainingArguments` expects.
train_arguments = dict(
    output_dir=output_path,
    warmup_ratio=0.2,
    per_device_train_batch_size=8,
    num_train_epochs=10,
    lr_scheduler_type='cosine',
    learning_rate=5e-5,
    evaluation_strategy='steps',
    logging_steps=500,
    seed=456,
)
arguments = train_arguments  # legacy alias: the very same dict object


# Model Definition and Training

In [None]:
# Seed every RNG with the same value the Trainer itself receives.
seed_everything(train_arguments['seed'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build actual TrainingArguments object.
# NOTE(review): EarlyStoppingCallback requires load_best_model_at_end=True and a
# metric_for_best_model. Presumably getTrainingArguments sets both; confirm.
training_arguments = getTrainingArguments(train_arguments)

# Factory handed to the Trainer as `model_init` (the Trainer calls it to build the model).
getModel = get_model(narrationdataset, model_type=modeltype)

trainer = CustomTrainerFusion(
    model_init=getModel,
    args=training_arguments,
    train_dataset=narrationdataset.train_dataset,
    eval_dataset=narrationdataset.test_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
)

# Train the narrator
trainer.train()

# Persist the model currently held by the trainer and the trainer state.
# This is the lowest-eval-loss model only if load_best_model_at_end is enabled
# in getTrainingArguments; confirm.
trainer.save_model()
trainer.save_state()

# get the best checkpoint (may be None; json writes that as null)
best_check_point = trainer.state.best_model_checkpoint

# Copy instead of aliasing. Otherwise the bookkeeping keys added below would leak
# into `train_arguments` and be fed to getTrainingArguments when this cell is re-run.
params_dict = dict(train_arguments)
params_dict['best_check_point'] = best_check_point
params_dict['output_path'] = output_path

# Context manager guarantees the file is flushed and closed.
with open(os.path.join(output_path, 'parameters.json'), 'w') as params_file:
    json.dump(params_dict, params_file)



In [10]:
# Best checkpoint path recorded by the Trainer. It is the cell's last expression,
# so Jupyter renders it (the quoted output below comes from this form, not print).
trainer.state.best_model_checkpoint

'TrainModels/earlyfusion/bart-base/checkpoint-1500'

In [9]:
# Evaluate the model the trainer currently holds on the eval_dataset it was built
# with (the test split), and display the resulting metrics dict.
eval_metrics = trainer.evaluate()
eval_metrics

{'eval_loss': 1.2032346725463867,
 'eval_runtime': 0.572,
 'eval_samples_per_second': 174.815,
 'eval_steps_per_second': 22.726,
 'epoch': 2.0}

# Inference
Initialise the narrator from the trained model and the fitted `NarrationDataSet`.

In [13]:
# Wrap the trained model, the fitted dataset (tokenizer and preprocessing) and the
# device in a narrator. `sampling=False` presumably selects non-sampled decoding;
# confirm in src.inferenceUtils.
narrator = PerformanceNarrator(
    trainer.model,
    narrationdataset,
    device,
    sampling=False,
    verbose=False,
)



In [15]:
# Pick one raw test example to narrate, and the seed passed to generation.
# Index 99 is presumably the last of the 100 test samples counted above; verify.
EXAMPLE_INDEX = 99
example = dataset_raw.test_data[EXAMPLE_INDEX]
seed = 456

In [17]:
# Decoding settings forwarded unchanged to `generateNarration`.
generation_kwargs = dict(
    max_length=190,
    length_penalty=8.6,
    beam_size=10,
    repetition_penalty=3.5,
    return_top_beams=4,
)
outt = narrator.generateNarration(example, seed, **generation_kwargs)

Global seed set to 456


In [18]:
# Display the narration(s) returned by `generateNarration` (a list of strings).
outt

["On this multi-class classification problem, the model was trained to assign test samples one of the following classes #CA, #CB, #CC, and #CC. The classifier's performance assessment scores are as follows: (a) Accuracy equal to 76.44%. (b) F1score equal to about76.03%. These scores across the different metrics suggest that this model will be moderately effective at correctly predicting the true label for several test examples with only a small margin of error. Furthermore, from the precision and recall scores, we can draw the conclusion that it will likely have a lower misclassification error rate."]