## Import the necessary packages, libraries, classes and methods

In [None]:
# importing all the necessary packages

import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader

from disfdata import DisflQA
from encdecmod import LSTM_ED, WordEmbedding, EarlyStopping
from extrastuff import train, test, save, load

import json
import sentencepiece as spm
import glob

import warnings
# Silence every warning category for the rest of the session.
warnings.simplefilter('ignore')

from nltk.translate.bleu_score import corpus_bleu

In [None]:
# Load the Disfl-QA training split. It is a JSON object whose values carry
# 'original' and 'disfluent' text (both are read when writing the corpora).
# The `with` block guarantees the file handle is closed after parsing.
# NOTE: this rebinds `data`, shadowing the `torch.utils.data as data` alias
# imported at the top; later cells should use the directly imported
# `DataLoader` rather than `data.DataLoader`.
with open('./Datasets/Disfl-QA/train.json', encoding='utf-8') as json_file:
    data = json.load(json_file)

In [None]:
# Plain-text corpora (one sentence per line) used to train SentencePiece;
# the next cell fills both files and closes them.
original_txt_file = open(file='./Datasets/Disfl-QA/original.txt', mode='w', encoding='utf-8')
disfluent_txt_file = open(file='./Datasets/Disfl-QA/disfluent.txt', mode='w', encoding='utf-8')

In [None]:
# Write every lower-cased 'original' / 'disfluent' sentence on its own line.
# These files are the SentencePiece training corpora; no token counting
# happens here. The keys of `data` are not needed, only the records.
# try/finally closes both handles even if a record is malformed.
try:
    for record in data.values():
        original_txt_file.write(record['original'].lower() + '\n')
        disfluent_txt_file.write(record['disfluent'].lower() + '\n')
finally:
    original_txt_file.close()
    disfluent_txt_file.close()

In [None]:
# Punctuation characters handed to SentencePiece as `user_defined_symbols`
# in the training cells below, so each one is always kept as its own piece.
punc_list = list("`~!@#$%^&*-_+=\\|:;\"',.?/(){}[]<>")

In [None]:
# Unigram tokenizer (1000 pieces) trained on the disfluent questions only.
spm.SentencePieceTrainer.Train(
    input='./Datasets/Disfl-QA/disfluent.txt',
    model_prefix='./Datasets/Disfl-QA/spm_disfluent',
    model_type='unigram',
    vocab_size=1000,
    # Reserved special-token ids.
    pad_id=3, eos_id=2, bos_id=1, unk_id=0,
    user_defined_symbols=punc_list,
)

In [None]:
# Unigram tokenizer (1000 pieces) trained on the fluent/original questions only.
spm.SentencePieceTrainer.Train(
    input='./Datasets/Disfl-QA/original.txt',
    model_prefix='./Datasets/Disfl-QA/spm_original',
    model_type='unigram',
    vocab_size=1000,
    # Reserved special-token ids.
    pad_id=3, eos_id=2, bos_id=1, unk_id=0,
    user_defined_symbols=punc_list,
)

In [None]:
# Joint unigram tokenizer trained on every .txt file in the dataset folder
# (original.txt and disfluent.txt are written by the cells above; any other
# .txt placed there would be picked up too).
# sorted() fixes the input order: glob.glob returns files in arbitrary,
# filesystem-dependent order, which would make training input vary by machine.
spm.SentencePieceTrainer.Train(
    input=sorted(glob.glob('./Datasets/Disfl-QA/*.txt')),
    model_prefix='./Datasets/Disfl-QA/spm',
    vocab_size=1000,
    model_type='unigram',
    unk_id=0, bos_id=1, eos_id=2, pad_id=3,
    user_defined_symbols=punc_list)

In [None]:
# Sanity-check the trained tokenizers: load the three models and round-trip a
# sample question through the joint model (the decoded text is the cell output).
# Paths must match the model_prefix used for training ('Disfl-QA'); a
# differently-cased directory name only resolves on case-insensitive filesystems.
sp_dis = spm.SentencePieceProcessor(model_file='./Datasets/Disfl-QA/spm_disfluent.model')
sp_ori = spm.SentencePieceProcessor(model_file='./Datasets/Disfl-QA/spm_original.model')
sp_all = spm.SentencePieceProcessor(model_file='./Datasets/Disfl-QA/spm.model')
enc = sp_all.Encode('how long did julia butterfly hill live near a nuclear-missile installation?')
sp_all.Decode(enc)

## Set model parameters

In [None]:
# --- Hyperparameters and runtime settings ---
MODEL_NAME = 'LSTM_BI_ED_FINETUNE'  # used to name checkpoints, plots and logs
EPOCHS = 100

# Expose only GPU 0 and limit OpenMP to one thread.
# NOTE(review): these env vars only take effect if set before CUDA/OpenMP
# initialise — confirm nothing earlier in the session touched the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "1"})
torch.set_num_threads(1)
# Seed torch's RNG for reproducible runs.
torch.manual_seed(0)

## Dataset preparations

In [None]:
# Training and validation splits (max_len=100, lengths not returned).
# The directly imported `DataLoader` is used because the dataset-loading cell
# rebinds `data` to the parsed train.json dict, so `data.DataLoader` would
# raise AttributeError.
train_dataset = DisflQA(file_name='Datasets/Disfl-QA/train.json', max_len=100, return_len=False)
train_loader = DataLoader(train_dataset, batch_size=8, num_workers=2, shuffle=True)

val_dataset = DisflQA(file_name='Datasets/Disfl-QA/dev.json', max_len=100, return_len=False)
val_loader = DataLoader(val_dataset, batch_size=8, num_workers=2)

# Source/target embeddings sized from each vocabulary, 256 dims; the third
# argument (0.2) is presumably the dropout rate — confirm in encdecmod.WordEmbedding.
src_vocab_emb = WordEmbedding(len(train_dataset.src_vocab), 256, 0.2)
tgt_vocab_emb = WordEmbedding(len(train_dataset.tgt_vocab), 256, 0.2)

In [None]:
# Display the source embedding's `dropout` attribute as the cell output.
getattr(src_vocab_emb, 'dropout')

In [None]:
# Display the target embedding's `dropout` attribute as the cell output.
getattr(tgt_vocab_emb, 'dropout')

## Model initialization

In [None]:
# Encoder-decoder LSTM built on the two embeddings, moved to the CUDA device.
model = LSTM_ED(
    src_vocab_emb,
    tgt_vocab_emb,
    emb_dim=256,
    hid_dim=256,
    n_layers=4,
    dropout=0.1,
).cuda()

## Optimizer, loss function and gradient scaler for the model

In [None]:
class CrossLoss(nn.Module):
    """Cross-entropy for a model output that holds probabilities, not logits.

    The output is log-transformed and handed to ``nn.CrossEntropyLoss``, which
    applies its own ``log_softmax``. For a probability distribution ``p``,
    ``log_softmax(log p) == log p``, so the result is the NLL of ``p``.

    Args:
        ignore_index: target id excluded from the loss (e.g. padding).
        eps: lower bound applied to the probabilities before the log, so a
            zero probability gives a large finite loss instead of ``inf``
            (and NaN gradients).
    """

    def __init__(self, ignore_index=-1, eps=1e-9):
        super().__init__()
        self.eps = eps
        self.CrossLoss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, output, target):
        # NOTE(review): assumes `output` is (..., vocab) probabilities, as the
        # log transform implies — confirm LSTM_ED returns softmax outputs.
        if output.dtype == torch.float16:
            # eps=1e-9 rounds to 0 in float16, which would disable the guard.
            output = output.float()
        output = torch.log(output.clamp_min(self.eps))
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1).long()
        return self.CrossLoss(output, target)


class CrossLost(nn.Module):
    """``CrossLoss`` with next-token alignment.

    Drops the last time step of ``output`` and the first token of ``target``,
    so the prediction at step t is scored against target token t+1. The
    slicing implies a (batch, seq, vocab) output and a (batch, seq) target.

    Args:
        ignore_index: target id excluded from the loss (e.g. padding).
        eps: probability floor forwarded to ``CrossLoss``.
    """

    def __init__(self, ignore_index=-1, eps=1e-9):
        super().__init__()
        self.CrossLoss = CrossLoss(ignore_index=ignore_index, eps=eps)

    def forward(self, output, target):
        return self.CrossLoss(output[:, :-1, :], target[:, 1:])

In [None]:
# No learning-rate schedule; `scheduler` is still passed to train().
scheduler = None
# AMP gradient scaler, passed to train() alongside the optimizer.
scaler = torch.cuda.amp.GradScaler()
# Id 3 is excluded from the loss; it matches pad_id=3 in the SentencePiece
# training cells — assumes DisflQA pads with the same id (confirm).
criterion = CrossLost(ignore_index=3)
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3)

## Initial state

In [None]:
# Fresh-run bookkeeping: the epoch loop starts at start_epoch + 1, and 1e9
# acts as an "infinitely bad" initial best validation loss.
history = dict(train_loss=[], val_loss=[])
best_loss = 1e9
start_epoch = -1

print(f'Total Parameters: {sum(param.numel() for param in model.parameters())}')

## Training

In [None]:
# Train with early stopping and best-model checkpointing.
# EarlyStopping comes from encdecmod; it is fed both train and val loss and its
# exact stopping criterion is not visible here.
early_stopping = EarlyStopping(tolerance=3, min_delta=2)

best_model_path = 'Model/{}.pt'.format(MODEL_NAME)
os.makedirs('Model', exist_ok=True)
# Remove a best-model file left over from an earlier run so the file on disk
# always comes from this run.
if os.path.exists(best_model_path):
    os.remove(best_model_path)

last_epoch = None  # last epoch actually trained in this cell
for i in range(start_epoch + 1, EPOCHS):
    print('Epoch {}:'.format(i))
    train_loss = train(train_loader, model, optimizer, criterion, scheduler, device='cuda', scaler=scaler, kw_src=['input', 'output'])
    val_loss = test(val_loader, model, criterion, device='cuda', return_results=False, kw_src=['input', 'output'])
    last_epoch = i

    # Log of loss values
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    # Checkpoint an improvement BEFORE the early-stopping check, so a new best
    # on the epoch that triggers the stop is still saved.
    if val_loss < best_loss:
        best_loss = val_loss
        save(best_model_path, model, optimizer, epoch=i, stats={'val_loss': best_loss, 'history': history})

    # early stopping
    early_stopping(train_loss, val_loss)
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break

# Checkpoint of the last trained epoch with the full loss history. Skipped when
# no epoch ran, instead of failing on (or silently reusing) a stale `i`.
if last_epoch is not None:
    save('Model/{}_epoch_{}.pt'.format(MODEL_NAME, last_epoch), model, optimizer, epoch=last_epoch, stats={'val_loss': best_loss, 'history': history})

## Test

In [None]:
# Reload a per-epoch checkpoint written at the end of training and plot the
# loss history stored in it.
# 19 is presumably the last epoch of the run being evaluated — update it to
# match the `<MODEL_NAME>_epoch_<n>.pt` file the training cell produced.
CHECKPOINT_EPOCH = 19
checkpoint_path = 'Model/{}_epoch_{}.pt'.format(MODEL_NAME, CHECKPOINT_EPOCH)

# Check the file that is actually loaded and stop with a clear error if it is
# missing, rather than printing a notice and carrying on.
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError('Model not available in directory: {}'.format(checkpoint_path))

start_epoch, stats = load(checkpoint_path, model, optimizer)
best_loss = stats['val_loss']
history = stats['history']

# Loss curves; the x axis is the 0-based epoch index.
os.makedirs('Results', exist_ok=True)
plt.ylabel('Loss Value')
plt.xlabel('Number of Epoch')
plt.plot(np.arange(len(history['train_loss'])), history['train_loss'], linestyle='--', color='g', label='Train Loss')
plt.plot(np.arange(len(history['val_loss'])), history['val_loss'], linestyle='--', color='r', label='Validation Loss')
plt.legend()
plt.savefig('Results/Loss_{}.png'.format(MODEL_NAME))
plt.show()

In [None]:
class Bleu(nn.Module):
    """Corpus-level BLEU-1..BLEU-4 computed with NLTK's ``corpus_bleu``.

    This is a score, not a loss: there are no parameters and nothing is
    differentiable. It subclasses ``nn.Module`` only to be callable.

    ``forward(output, target)``:
        output: hypotheses, one token list per sentence.
        target: references, either one token list per hypothesis or a list of
            reference token lists per hypothesis (the shape ``corpus_bleu``
            expects). Single token lists are wrapped automatically; otherwise
            ``corpus_bleu`` would treat each word as a separate "reference".
        Returns a 4-tuple (BLEU-1, BLEU-2, BLEU-3, BLEU-4).
    """

    # Cumulative n-gram weights for BLEU-1..BLEU-4; each tuple sums to 1.
    WEIGHTS = (
        (1.0, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (1 / 3, 1 / 3, 1 / 3, 0),
        (0.25, 0.25, 0.25, 0.25),
    )

    def __init__(self, ignore_index=-1):
        super().__init__()
        # Accepted for signature compatibility with the loss classes; unused.
        self.ignore_index = ignore_index

    @staticmethod
    def _as_references(target):
        """Wrap each single reference token list into a one-element reference list."""
        return [
            ref if ref and isinstance(ref[0], (list, tuple)) else [ref]
            for ref in target
        ]

    def forward(self, output, target):
        references = self._as_references(target)
        bleu_1, bleu_2, bleu_3, bleu_4 = (
            corpus_bleu(references, output, weights=weights) for weights in self.WEIGHTS
        )
        return bleu_1, bleu_2, bleu_3, bleu_4

In [None]:
# Metric object used by the test cell; -1 is the constructor's default.
bleu = Bleu(ignore_index=-1)

In [None]:
# Token lists collected for corpus BLEU: model outputs and reference targets.
output_bleu, target_bleu = [], []

In [None]:
# Decode test-set predictions, log inputs/outputs/targets to text files and
# report corpus BLEU.
test_dataset = DisflQA(file_name='Datasets/Disfl-QA/test.json', max_len=100, return_len=False, infer=True)
# `DataLoader` is used directly because `data` is rebound to the parsed
# train.json dict earlier in the notebook.
test_loader = DataLoader(test_dataset, batch_size=128, num_workers=2)

_, outputs, targets = test(test_loader, model, device='cuda', return_results=True)
outputs = outputs.numpy()
targets = targets.numpy()

# Matches eos_id=2 in the SentencePiece training cells; assumes DisflQA uses
# the same id (confirm).
EOS_ID = 2

# Reset here so re-running this cell does not append duplicate sentences.
output_bleu = []
target_bleu = []

os.makedirs('Output', exist_ok=True)
with open('Output/{}_log_inputs.txt'.format(MODEL_NAME), 'w', encoding='utf-8') as write_input, \
        open('Output/{}_log_outputs.txt'.format(MODEL_NAME), 'w', encoding='utf-8') as write_output, \
        open('Output/{}_log_targets.txt'.format(MODEL_NAME), 'w', encoding='utf-8') as write_target:
    for i in range(len(test_dataset)):
        sample = test_dataset[i]
        str_input = test_dataset.src_vocab.decode(sample[0].tolist())
        str_target = test_dataset.tgt_vocab.decode(sample[1].tolist())

        # Keep predicted ids up to and including the first EOS.
        post_process_output = outputs[i].tolist()
        if EOS_ID in post_process_output:
            post_process_output = post_process_output[:post_process_output.index(EOS_ID) + 1]
        str_output = test_dataset.tgt_vocab.decode(post_process_output)

        write_input.write(str_input + '\n')
        write_output.write(str_output + '\n')
        write_target.write(str_target + '\n')

        output_bleu.append(str_output.split())
        # corpus_bleu expects a list of references per hypothesis; there is one.
        target_bleu.append([str_target.split()])

# Hypotheses first, references second, matching Bleu.forward(output, target).
bleu_1, bleu_2, bleu_3, bleu_4 = bleu(output_bleu, target_bleu)
print('BLEU-1 Score : ', bleu_1)
print('BLEU-2 Score : ', bleu_2)
print('BLEU-3 Score : ', bleu_3)
print('BLEU-4 Score : ', bleu_4)