# Training BART with Google Colab

In [None]:
### The code was trained in Colab
!pip install simpletransformers -q
!pip install nltk==3.4.5

import nltk
print(nltk.__version__)

try:
  from nltk.translate.meteor_score import meteor_score
  print('Meteor score will not work without the right ntlk version')
except ImportError:
  print('Still import issue')

In [None]:
import tqdm
import spacy
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from simpletransformers.seq2seq import Seq2SeqModel
from sklearn.model_selection import train_test_split
import os
import os.path
import tensorflow as tf
import os
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.nist_score import sentence_nist

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('good to go')
print(f'Device: {DEVICE}')

# Remote data sources. Only `augmented_link` is read in the cells visible
# here; the other two links are kept for reference.
csv_link = 'https://raw.githubusercontent.com/chophilip21/covid_dialogue/main/dialogue.csv' #original
augmented_link = 'https://raw.githubusercontent.com/chophilip21/covid_dialogue/main/augmented.csv' #augmented version
txt_link = 'https://raw.githubusercontent.com/chophilip21/covid_dialogue/main/covid_additional.txt' # raw text version

# `header=0` together with `names=` discards the CSV's own header row and
# relabels the two columns as `input_text` / `target_text`.
dataset = pd.read_csv(augmented_link, names = ['input_text', 'target_text'], header=0)
# 80% train; the remaining 20% is halved into validation and test (10% each).
# NOTE(review): no `random_state` is passed, so the split differs on every
# run -- confirm whether reproducible splits are wanted.
train_df, test_df = train_test_split(dataset, test_size=0.2)
valid_df, test_df = train_test_split(test_df, test_size=0.5)

print('The length of train_df is: ', len(train_df))
print('The length of valid is: ', len(valid_df))
print('The length of test_df is: ', len(test_df))


def txt_to_dict(txt_path, save_path):
    """Collect the line following each speaker tag in a dialogue text file.

    Every line that starts with ``Patient:`` or ``Doctor:`` is treated as a
    tag; the line right after it (trailing newline included) is stored as
    that speaker's utterance. A tag on the very last line contributes an
    empty string.

    Args:
        txt_path: Path of the text file to read.
        save_path: Accepted for interface compatibility; not used.

    Returns:
        A dict with the patient utterances under ``'src'`` and the doctor
        utterances under ``'trg'``.
    """
    with open(txt_path, 'r') as handle:
        rows = handle.readlines()

    patient_turns, doctor_turns = [], []
    total = len(rows)

    for idx, row in enumerate(rows):
        # The utterance is whatever sits on the next line, or '' at EOF.
        following = rows[idx + 1] if idx + 1 < total else ''

        if row.startswith('Patient:'):
            patient_turns.append(following)
        elif row.startswith('Doctor:'):
            doctor_turns.append(following)

    return {'src': patient_turns, 'trg': doctor_turns}


In [None]:
# Training/evaluation options handed to simpletransformers' Seq2SeqModel.
# NOTE(review): `max_seq_length` and `max_length` are both 50 -- presumably
# the input-token cap and the generated-output cap respectively; confirm
# against the simpletransformers documentation.
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": 50,
    "train_batch_size": 4, # check if we can have bigger weights (presumably: a bigger batch size)
    "eval_batch_size": 1,
    "output_dir": 'weights',
    "num_train_epochs": 5,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_during_training": True,
    "evaluate_generated_text": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": True,
    "gradient_accumulation_steps": 1,
    "max_length": 50,
    "manual_seed": 4,
}


# BART encoder-decoder initialised from the `facebook/bart-large` checkpoint.
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
)


# Fine-tune on the training split; the validation split is supplied for the
# during-training evaluation enabled in `model_args`.
model.train_model(train_df, eval_data=valid_df)


In [None]:
def save_prediction(labels, preds):
    """Append every predicted string to ``prediction.txt``, one per line.

    The ``(labels, preds)`` signature lets this be passed to
    ``model.eval_model`` as a keyword metric, which is how the notebook
    uses it to capture the generated text.

    Args:
        labels: Reference texts; accepted for the metric signature, unused.
        preds: Iterable of predicted strings to write.

    Returns:
        None. Prints ``complete`` once everything has been written.
    """
    # Open the file once for the whole batch instead of once per prediction.
    # This also guarantees the file exists even when `preds` is empty.
    # NOTE: append mode is kept on purpose, so re-running the evaluation
    # adds to whatever an earlier run already wrote -- delete the file first
    # if a clean set of predictions is needed.
    with open('prediction.txt', 'a') as f:
        f.writelines(prediction + '\n' for prediction in preds)

    print('complete')



# Evaluate on the held-out test split. `save_prediction` is passed as an
# extra keyword metric named `save`, so it receives the labels and generated
# predictions and writes the latter to prediction.txt.
print(model.eval_model(test_df, save = save_prediction))

"""
Predictions on a random string
"""

# Quick sanity check: generate a reply for a single hand-written question.
test = "Hi doctor, What are the symptoms of Covid-19..?"
inference = model.predict([test])
print(inference)

In [None]:
# Fetch the WordNet corpus (presumably needed by meteor_score for synonym
# matching -- confirm against the NLTK documentation).
nltk.download('wordnet')


"""
The code crashes when these metrics are computed inside the evaluate function.
- The functions below are an alternative.
"""

def _join_sentences(texts):
    """Return `texts` as one string: a str as-is, an iterable space-joined."""
    if isinstance(texts, str):
        # Iterating a str would split it into single characters.
        return texts
    return ' '.join(str(elem) for elem in texts)


def _nist_score(labels, preds, n):
    """Compute word-level NIST-`n` of all predictions against all labels."""
    # sentence_nist works on token sequences; handing it raw strings makes
    # it count character n-grams instead of word n-grams.
    reference = _join_sentences(labels).split()
    hypothesis = _join_sentences(preds).split()

    # Either side having fewer than `n` tokens leaves no n-grams of the
    # highest order to score, so report 0 instead of calling NLTK.
    if len(hypothesis) < n or len(reference) < n:
        return 0

    return sentence_nist([reference], hypothesis, n)


def nist_2(labels, preds):
    """Return the NIST score with n-grams up to length 2.

    Args:
        labels: Reference sentences (iterable of str) or one string.
        preds: Predicted sentences (iterable of str) or one string.

    Returns:
        The score from ``sentence_nist``, or 0 when either side has fewer
        than 2 whitespace-separated tokens.
    """
    return _nist_score(labels, preds, 2)


def nist_4(labels, preds):
    """Return the NIST score with n-grams up to length 4.

    Args:
        labels: Reference sentences (iterable of str) or one string.
        preds: Predicted sentences (iterable of str) or one string.

    Returns:
        The score from ``sentence_nist``, or 0 when either side has fewer
        than 4 whitespace-separated tokens.
    """
    return _nist_score(labels, preds, 4)

def calculate_m_score(target, predictions, length):
    """Return the METEOR score summed over sentence pairs, divided by `length`.

    Args:
        target: Iterable of reference sentences (str).
        predictions: Iterable of predicted sentences (str), paired with
            `target` by position; surplus items on either side are ignored.
        length: Divisor for the summed score, normally the number of pairs.

    Returns:
        The summed per-pair METEOR score divided by `length`.

    Raises:
        ZeroDivisionError: If `length` is 0.
    """
    # With the nltk 3.4.5 pinned by this notebook, meteor_score takes a
    # *list* of reference strings. A bare string is iterated character by
    # character, so each pair was scored against single letters.
    score = sum(meteor_score([t], p) for t, p in zip(target, predictions))

    return score / length


# Reload the generated replies from prediction.txt, one per line.
with open('prediction.txt') as fp:
    # Strip *every* line (the original stripped only the first, leaving a
    # trailing newline on all the others) and drop blank lines.
    predictions = [text for text in (line.strip() for line in fp) if text]

target = test_df.target_text

# Corpus-level strings: all references and all predictions joined by spaces.
label = ' '.join(str(elem) for elem in target)
pred = ' '.join(str(elem) for elem in predictions)

# NLTK's BLEU works on token sequences; passing raw strings would score
# character n-grams rather than word n-grams.
reference_tokens = label.split()
hypothesis_tokens = pred.split()

# BLEU-n needs n uniform weights (the original gave BLEU-4 only two).
bleu2 = sentence_bleu([reference_tokens], hypothesis_tokens, weights=(1 / 2,) * 2)
bleu4 = sentence_bleu([reference_tokens], hypothesis_tokens, weights=(1 / 4,) * 4)

# Pass the list of predictions: the NIST helpers join their inputs
# themselves, and joining an already-joined string splits it into characters.
nist2 = nist_2(target, predictions)
nist4 = nist_4(target, predictions)


print('bleu2 is {}'.format(bleu2))
print('bleu4 is {}'.format(bleu4))
print('nist2 is {}'.format(nist2))
print('nist4 is {}'.format(nist4))

# Average over the pairs actually scored rather than a hard-coded 71.
pair_count = min(len(target), len(predictions))
meteor = calculate_m_score(target, predictions, pair_count) if pair_count else 0.0

print('meteor is {}'.format(meteor))