***
# Sentence Simplification using BERT-to-GPT2 
***

***
## neuroCraft Project
***
<br>

Christine Sigrist

***

In [24]:
# Imports
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
import pickle
import click
import os
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from data import WikiDataset

from transformers import BertTokenizer, GPT2Tokenizer
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from transformers import EncoderDecoderModel, BertConfig, EncoderDecoderConfig, GPT2Tokenizer, BertModel, GPT2Model


import time
import tqdm
import logging
import gc
import shutil
import sari

from IPython.display import display, HTML
display(HTML("<style>.container { width:70% !important; }</style>"))

***

In [2]:
# DATA LOADING
"""
DATA INFO

Training:
Wiki dataset comprising of parallel corpus of normal sentences and simple sentences is used to train the model.
The original dataset consists of around 167k English sentence pairs from the Wikipedia articles.
The dataset comprises of mapping of one-to-many, one-to-one and many-to-one sentence pairs.
But the dataset was not suitable for the training without preprocessing.
Upon tokenizing the sentences, sentences having token length of more than 80 were removed keeping the maximum token length of sentences to 80.
The resulting training dataset became 138k from 167k.

Testing:
For the evaluation and testing purpose, TurkCorpus is used.
The dataset consists of 2k manually prepared sentence pairs with 8 reference sentences and 300 sentences for testing purpose which also has 8 reference sentences.
"""

# file opener
def open_file(file_path, ref=False):
    data = []
    if ref:
        ref_data = pickle.load(open(file_path, 'rb'))
        return ref_data
    else:
        with open(file_path, 'r', encoding="utf8") as f:
            sents = f.readlines()
            for s in sents:
                data.append(s.strip())
        return data

# data loader
def load_dataset(src_path, tgt_path=None, ref_path=None, ref=False):
    src = open_file(src_path)
    tgt = None
    ref = None
    if tgt_path is not None:
        tgt = open_file(tgt_path)
    if ref_path is not None:
        ref = open_file(ref_path, ref)
    return src, tgt, ref

# training data
src_train_file = 'dataset/src_train.txt'
tgt_train_file = 'dataset/tgt_train.txt'

# original train data and target train data (no ref data)
src_train, tgt_train, ref_train = load_dataset(src_train_file, tgt_train_file, ref=False)

print('training data')
print(pd.DataFrame(src_train).shape)
print(pd.DataFrame(tgt_train).shape)

sample_index = 1000
src_sample = src_train[sample_index]
tgt_sample = tgt_train[sample_index]

print("Source:", src_sample)
print("Target:", tgt_sample)

# validation data
src_valid_file = 'dataset/src_valid.txt'
tgt_valid_file = 'dataset/tgt_valid.txt'

# original train data and target train data (no ref data)
src_valid, tgt_valid, ref_valid = load_dataset(src_valid_file, tgt_valid_file, ref=False)

print(' ')
print('validation data')
print(pd.DataFrame(src_valid).shape)
print(pd.DataFrame(tgt_valid).shape)

sample_index = 99
src_sample = src_valid[sample_index]
tgt_sample = tgt_valid[sample_index]

print("Source:", src_sample)
print("Target:", tgt_sample)

# testing data
src_test_file = 'dataset/src_test.txt'
tgt_test_file = 'dataset/tgt_test.txt'

# original train data and target train data (no ref data)
src_test, tgt_test, ref_test = load_dataset(src_test_file, tgt_test_file, ref=False)

print(' ')
print('test data')
print(pd.DataFrame(src_test).shape)
print(pd.DataFrame(tgt_test).shape)

sample_index = 10
src_sample = src_test[sample_index]
tgt_sample = tgt_test[sample_index]

print("Source:", src_sample)
print("Target:", tgt_sample)


training data
(138413, 1)
(138413, 1)
Source: Quincy in 1767 was the `` north precinct '' of Braintree , Massachusetts .
Target: He was born in Braintree , Massachusetts , in 1767 .
 
validation data
(2000, 1)
(2000, 1)
Source: Since 1980 the senior pastor has been John Piper .
Target: Since 1980 the main pastor has been John Piper .
 
test data
(359, 1)
(359, 1)
Source: Alessandro ( " Sandro " ) Mazzola ( born 8 November 1942 ) is an Italian former football player .
Target: Alessandro Mazzola is an Italian former football player.


***

In [3]:
# DATA PROCESSING
# function for encoding (tokenizing) the data
def encode_batch(batch, max_len=80):
    '''
    Inputs:
    batch: This function expects a list containing two elements:
    The first element is a string representing the source text.
    The second element is a string representing the target text.

    Tokenization:
    src_tokens: The source text (first element of batch) is tokenized using the BERT tokenizer (bert_tokenizer).
    It is processed to generate tokens (input_ids) and an attention mask for the source text.
    BERT tokenization includes adding special tokens, padding to a maximum length, and truncating if needed.
    tgdt_tokens: The target text (second element of batch) is tokenized using the GPT-2 tokenizer (gpt2_tokenizer). Similar to the source, tokens (input_ids) and an attention mask for the target text are generated.

    Creating Labels:
    labels: This is created from the tgt_tokens.input_ids tensor. It's used for calculating loss during training.
    The tgt_tokens.attention_mask is used to identify where the padding is in the target tokens and sets those positions in the labels tensor to -100.

    Output:
    The function returns five values:
    src_tokens.input_ids: Tensor containing tokenized representation of the source text.
    src_tokens.attention_mask: Tensor containing attention mask for the source text.
    tgt_tokens.input_ids: Tensor containing tokenized representation of the target text.
    tgt_tokens.attention_mask: Tensor containing attention mask for the target text.

    Labels:
    Tensor containing the labels for the target text, modified with -100 in places corresponding to padding.
    '''
    src_tokens = bert_tokenizer(batch[0], max_length=max_len, add_special_tokens=True,
                                return_token_type_ids=False, padding="max_length", truncation=True,
                                return_attention_mask=True, return_tensors="pt")

    tgt_tokens = gpt2_tokenizer(batch[1], max_length=max_len, add_special_tokens=True,
                                return_token_type_ids=False, padding="max_length", truncation=True,
                                return_attention_mask=True, return_tensors="pt")

    labels = tgt_tokens.input_ids.clone()
    labels[tgt_tokens.attention_mask == 0] = -100

    return src_tokens.input_ids, src_tokens.attention_mask, tgt_tokens.input_ids, tgt_tokens.attention_mask, labels

# function for decoding the tokenized sentences to human readable text
def decode_sent_tokens(data):
    '''
    This function takes a list of tokenized sentences (data) and decodes them back into human-readable sentences.

    Inputs:
    data: A list containing tokenized sentences.

    Functionality:
    It iterates through each tokenized sentence in the data list.
    For each tokenized sentence, it uses the GPT-2 tokenizer (gpt2_tokenizer) to decode the tokens into a human-readable sentence (s), skipping special tokens and cleaning up tokenization spaces.
    The decoded sentences (s) are added to a list (sents_list).

    Output:
    The function returns a list (sents_list) containing the decoded sentences
    '''
    sents_list = []
    for sent in data:
        s = gpt2_tokenizer.decode(sent, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        sents_list.append(s)

    return sents_list

# function
def get_sent_tokens(sents):
    '''
    This function tokenizes input sentences and prepares them for downstream processing, perhaps for model input or further manipulation.

    Inputs:
    sents: A list or string containing sentences.

    Functionality:
    It tokenizes the input sentences using the GPT-2 tokenizer (gpt2_tokenizer) with specific settings:
    Adding special tokens.
    Not returning token type IDs.
    Truncating sequences if needed.
    Padding to the longest sequence in the batch.
    Not returning attention masks (only input tensors are returned).

    Output:
    The function returns a list (ref) containing lists of tokenized sentences, where each inner list represents the tokenized form of a sentence from the input.
    '''
    ref = []
    tokens = gpt2_tokenizer(sents, add_special_tokens=True,
                            return_token_type_ids=False, truncation=True, padding="longest",
                            return_attention_mask=False, return_tensors="pt")

    for tok in tokens.input_ids.tolist():
        ref.append([tok])

    return ref




***

In [65]:
# Evaluation

# function to evaluate model
def evaluate_model(model, data_loader, criterion, target_tokenizer):
    '''
    Inputs:
    model: The trained seq2seq model to evaluate.
    data_loader: The data loader providing batches of evaluation data. It should contain pairs of source and target sequences.

    criterion: The loss function used for training the model, typically nn.CrossEntropyLoss() or similar.
    target_tokenizer: The tokenizer used for tokenizing the target sequences.

    Functionality:
    Sets the model to evaluation mode (model.eval()).
    Iterates through the evaluation data in batches.
    Passes the source sequences through the model to generate predictions.
    Computes the loss between predicted sequences and actual target sequences.
    Converts model predictions and target sequences from token IDs to text.
    Stores references (actual target sentences) and hypotheses (predicted sentences) to calculate BLEU score.
    Calculates the average loss and BLEU score over the evaluation dataset.

    Outputs:
    avg_loss: Average loss over the evaluation dataset.
    bleu_score: BLEU score indicating the quality of the model's translations compared to the ground truth.
    '''
    
    model.eval()
    total_loss = 0.0
    total_batches = 0
    references = []
    hypotheses = []

    with torch.no_grad():
        for batch in data_loader:
            batch_source = batch[0]
            batch_target = batch[1]

            outputs = model(input_ids=batch_source, decoder_input_ids=batch_target)

            logits_flat = outputs.logits.view(-1, outputs.logits.size(-1))
            target_flat = batch_target.view(-1)

            loss = criterion(logits_flat, target_flat)
            total_loss += loss.item()
            total_batches += 1

            predicted_ids = outputs.logits.argmax(-1)
            predicted_sentences = [target_tokenizer.decode(ids, skip_special_tokens=True) for ids in predicted_ids]
            target_sentences = [target_tokenizer.decode(ids, skip_special_tokens=True) for ids in batch_target]

            references.extend([sent.split() for sent in target_sentences])
            hypotheses.extend([sent.split() for sent in predicted_sentences])

    avg_loss = total_loss / total_batches

    # Calculate BLEU score with smoothing
    smooth_func = SmoothingFunction().method4
    bleu_score = corpus_bleu(references, hypotheses, smoothing_function=smooth_func)

    return avg_loss, bleu_score

***

In [66]:
# prepare validation dataset

# Set up tokenizers
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

# function to load evaluation data
def load_eval_data(src_path, tgt_path):
    src_data = open_file(src_path)
    tgt_data = open_file(tgt_path)
    return src_data, tgt_data

# evaluation data paths
src_eval_file = 'dataset/src_valid.txt'
tgt_eval_file = 'dataset/tgt_valid.txt'

# load evaluation data
src_eval, tgt_eval = load_eval_data(src_eval_file, tgt_eval_file)

# prepare evaluation data tensors using encode_batch function
source_eval_data_encoded = encode_batch((src_eval))
target_eval_data_encoded = encode_batch((tgt_eval))

# create evaluation data loader
eval_dataset = TensorDataset(*source_eval_data_encoded, *target_eval_data_encoded)
eval_data_loader = DataLoader(eval_dataset, batch_size=3, shuffle=False)

# ref sentences
# Load data from the pickle file
with open('dataset/ref_valid.pkl', 'rb') as file:
    ref_sentences = pickle.load(file)


***

In [68]:
# TRAINING

# constants
TRAIN_BATCH_SIZE = 3
N_EPOCHS = 2000
max_token_len = 80
LOG_EVERY = 10000

# encode source and target data using encode_batch function
source_train_data_encoded = encode_batch((src_train))
target_train_data_encoded = encode_batch((tgt_train))

# Initialize models
bert_model = BertModel.from_pretrained('bert-base-cased')
gpt2_model = GPT2Model.from_pretrained('gpt2')

# Configuration
max_token_len = 80
start_token_id = bert_tokenizer.cls_token_id
end_token_id = gpt2_tokenizer.eos_token_id

# Initialize the Encoder-Decoder model with cross-attention enabled
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')

# Update configuration for the model
model.config.decoder_start_token_id = start_token_id
model.config.eos_token_id = end_token_id
model.config.max_length = max_token_len
model.config.no_repeat_ngram_size = 3
model.config.add_cross_attention = True  # Enable cross-attention

# Loss function and optimizer setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    for i in range(0, len(source_train_data_encoded[0]), TRAIN_BATCH_SIZE):
        batch_source = source_train_data_encoded[0][i:i+TRAIN_BATCH_SIZE]
        batch_target = target_train_data_encoded[0][i:i+TRAIN_BATCH_SIZE]

        optimizer.zero_grad()

        outputs = model(input_ids=batch_source, decoder_input_ids=batch_target)

        logits_flat = outputs.logits.view(-1, outputs.logits.size(-1))
        target_flat = batch_target.view(-1)

        loss = criterion(logits_flat, target_flat)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == 0 and i > 0:
            print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Batch [{i + 1}/{len(source_train_data_encoded[0])}], Loss: {running_loss / LOG_EVERY:.4f}')
            running_loss = 0.0

    # Evaluate
    val_loss, val_bleu = evaluate_model(model, eval_data_loader, criterion, gpt2_tokenizer)
    print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Evaluation Loss: {val_loss:.4f}, Evaluation BLEU: {val_bleu:.4f}')
    #val_loss, val_bleu, val_sari = evaluate_model(model, eval_data_loader, criterion, gpt2_tokenizer, ref_sentences[0])
    #print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Evaluation Loss: {val_loss:.4f}, Evaluation BLEU: {val_bleu:.4f}, Evaluation SARI: {val_sari:.4f}')

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.3.crossattention.c_attn.weight', 'h.5.ln_cross_attn.weight', 'h.6.ln_cross_attn.weight', 'h.9.crossattention.q_attn.bias', 'h.9.ln_cross_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.6.crossattention.q_attn.weight', 'h.6.crossattention.c_attn.bias', 'h.5.crossattention.c_attn.weight', 'h.3.ln_cross_attn.bias', 'h.7.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.2.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.bias', 'h.3.crossattention.c_attn.bias', 'h.11.crossattention.q_attn.bias', 'h.4.ln_cross_attn.weight', 'h.0.ln_cross_attn.bias', 'h.5.crossattention.q_attn.weight', 'h.7.crossattention.q_attn.bias', 'h.5.crossattention.c_proj.bias', 'h.1.ln_cross_attn.weight', 'h.4.crossattention.q_attn.bias', 'h.1.crossattention.c_proj.weight', 'h.10.crossattention.c_proj.weight', 'h.2.crossattention.q_attn.

Epoch [1/2000], Evaluation Loss: 1.3135, Evaluation BLEU: 0.0000
Epoch [2/2000], Evaluation Loss: 4.9308, Evaluation BLEU: 0.0000
Epoch [3/2000], Evaluation Loss: 8.6167, Evaluation BLEU: 0.0000
Epoch [4/2000], Evaluation Loss: 8.1539, Evaluation BLEU: 0.0000
Epoch [5/2000], Evaluation Loss: 9.7744, Evaluation BLEU: 0.0000
Epoch [6/2000], Evaluation Loss: 5.3611, Evaluation BLEU: 0.0000
Epoch [7/2000], Evaluation Loss: 5.2928, Evaluation BLEU: 0.0000
Epoch [8/2000], Evaluation Loss: 13.3926, Evaluation BLEU: 0.0000
Epoch [9/2000], Evaluation Loss: 6.7120, Evaluation BLEU: 0.0000
Epoch [10/2000], Evaluation Loss: 10.0990, Evaluation BLEU: 0.0000
Epoch [11/2000], Evaluation Loss: 9.5260, Evaluation BLEU: 0.0000
Epoch [12/2000], Evaluation Loss: 8.7591, Evaluation BLEU: 0.0000
Epoch [13/2000], Evaluation Loss: 8.6172, Evaluation BLEU: 0.0000
Epoch [14/2000], Evaluation Loss: 9.0559, Evaluation BLEU: 0.0000
Epoch [15/2000], Evaluation Loss: 9.2633, Evaluation BLEU: 0.0000
Epoch [16/2000], 

Epoch [124/2000], Evaluation Loss: 22.6129, Evaluation BLEU: 0.0000
Epoch [125/2000], Evaluation Loss: 22.6177, Evaluation BLEU: 0.0000
Epoch [126/2000], Evaluation Loss: 22.6225, Evaluation BLEU: 0.0000
Epoch [127/2000], Evaluation Loss: 22.6271, Evaluation BLEU: 0.0000
Epoch [128/2000], Evaluation Loss: 22.6316, Evaluation BLEU: 0.0000
Epoch [129/2000], Evaluation Loss: 22.6359, Evaluation BLEU: 0.0000
Epoch [130/2000], Evaluation Loss: 22.6401, Evaluation BLEU: 0.0000
Epoch [131/2000], Evaluation Loss: 22.6442, Evaluation BLEU: 0.0000
Epoch [132/2000], Evaluation Loss: 22.6482, Evaluation BLEU: 0.0000
Epoch [133/2000], Evaluation Loss: 22.6521, Evaluation BLEU: 0.0000
Epoch [134/2000], Evaluation Loss: 22.6558, Evaluation BLEU: 0.0000
Epoch [135/2000], Evaluation Loss: 22.6595, Evaluation BLEU: 0.0000
Epoch [136/2000], Evaluation Loss: 22.6632, Evaluation BLEU: 0.0000
Epoch [137/2000], Evaluation Loss: 22.6668, Evaluation BLEU: 0.0000
Epoch [138/2000], Evaluation Loss: 22.6703, Eval

Epoch [245/2000], Evaluation Loss: 22.9237, Evaluation BLEU: 0.0000
Epoch [246/2000], Evaluation Loss: 22.9253, Evaluation BLEU: 0.0000
Epoch [247/2000], Evaluation Loss: 22.9268, Evaluation BLEU: 0.0000
Epoch [248/2000], Evaluation Loss: 22.9283, Evaluation BLEU: 0.0000
Epoch [249/2000], Evaluation Loss: 22.9299, Evaluation BLEU: 0.0000
Epoch [250/2000], Evaluation Loss: 22.9314, Evaluation BLEU: 0.0000
Epoch [251/2000], Evaluation Loss: 22.9329, Evaluation BLEU: 0.0000
Epoch [252/2000], Evaluation Loss: 22.9344, Evaluation BLEU: 0.0000
Epoch [253/2000], Evaluation Loss: 22.9359, Evaluation BLEU: 0.0000
Epoch [254/2000], Evaluation Loss: 22.9373, Evaluation BLEU: 0.0000
Epoch [255/2000], Evaluation Loss: 22.9388, Evaluation BLEU: 0.0000
Epoch [256/2000], Evaluation Loss: 22.9403, Evaluation BLEU: 0.0000
Epoch [257/2000], Evaluation Loss: 22.9417, Evaluation BLEU: 0.0000
Epoch [258/2000], Evaluation Loss: 22.9431, Evaluation BLEU: 0.0000
Epoch [259/2000], Evaluation Loss: 22.9446, Eval

Epoch [366/2000], Evaluation Loss: 23.0615, Evaluation BLEU: 0.0000
Epoch [367/2000], Evaluation Loss: 23.0624, Evaluation BLEU: 0.0000
Epoch [368/2000], Evaluation Loss: 23.0632, Evaluation BLEU: 0.0000
Epoch [369/2000], Evaluation Loss: 23.0640, Evaluation BLEU: 0.0000
Epoch [370/2000], Evaluation Loss: 23.0648, Evaluation BLEU: 0.0000
Epoch [371/2000], Evaluation Loss: 23.0657, Evaluation BLEU: 0.0000
Epoch [372/2000], Evaluation Loss: 23.0665, Evaluation BLEU: 0.0000
Epoch [373/2000], Evaluation Loss: 23.0673, Evaluation BLEU: 0.0000
Epoch [374/2000], Evaluation Loss: 23.0681, Evaluation BLEU: 0.0000
Epoch [375/2000], Evaluation Loss: 23.0689, Evaluation BLEU: 0.0000
Epoch [376/2000], Evaluation Loss: 23.0697, Evaluation BLEU: 0.0000
Epoch [377/2000], Evaluation Loss: 23.0705, Evaluation BLEU: 0.0000
Epoch [378/2000], Evaluation Loss: 23.0713, Evaluation BLEU: 0.0000
Epoch [379/2000], Evaluation Loss: 23.0721, Evaluation BLEU: 0.0000
Epoch [380/2000], Evaluation Loss: 23.0729, Eval

Epoch [487/2000], Evaluation Loss: 23.1465, Evaluation BLEU: 0.0000
Epoch [488/2000], Evaluation Loss: 23.1471, Evaluation BLEU: 0.0000
Epoch [489/2000], Evaluation Loss: 23.1477, Evaluation BLEU: 0.0000
Epoch [490/2000], Evaluation Loss: 23.1483, Evaluation BLEU: 0.0000
Epoch [491/2000], Evaluation Loss: 23.1490, Evaluation BLEU: 0.0000
Epoch [492/2000], Evaluation Loss: 23.1496, Evaluation BLEU: 0.0000
Epoch [493/2000], Evaluation Loss: 23.1502, Evaluation BLEU: 0.0000
Epoch [494/2000], Evaluation Loss: 23.1508, Evaluation BLEU: 0.0000
Epoch [495/2000], Evaluation Loss: 23.1514, Evaluation BLEU: 0.0000
Epoch [496/2000], Evaluation Loss: 23.1520, Evaluation BLEU: 0.0000
Epoch [497/2000], Evaluation Loss: 23.1526, Evaluation BLEU: 0.0000
Epoch [498/2000], Evaluation Loss: 23.1532, Evaluation BLEU: 0.0000
Epoch [499/2000], Evaluation Loss: 23.1537, Evaluation BLEU: 0.0000
Epoch [500/2000], Evaluation Loss: 23.1543, Evaluation BLEU: 0.0000
Epoch [501/2000], Evaluation Loss: 23.1549, Eval

Epoch [608/2000], Evaluation Loss: 23.2064, Evaluation BLEU: 0.0000
Epoch [609/2000], Evaluation Loss: 23.2068, Evaluation BLEU: 0.0000
Epoch [610/2000], Evaluation Loss: 23.2073, Evaluation BLEU: 0.0000
Epoch [611/2000], Evaluation Loss: 23.2077, Evaluation BLEU: 0.0000
Epoch [612/2000], Evaluation Loss: 23.2082, Evaluation BLEU: 0.0000
Epoch [613/2000], Evaluation Loss: 23.2086, Evaluation BLEU: 0.0000
Epoch [614/2000], Evaluation Loss: 23.2091, Evaluation BLEU: 0.0000
Epoch [615/2000], Evaluation Loss: 23.2095, Evaluation BLEU: 0.0000
Epoch [616/2000], Evaluation Loss: 23.2100, Evaluation BLEU: 0.0000
Epoch [617/2000], Evaluation Loss: 23.2104, Evaluation BLEU: 0.0000
Epoch [618/2000], Evaluation Loss: 23.2109, Evaluation BLEU: 0.0000
Epoch [619/2000], Evaluation Loss: 23.2113, Evaluation BLEU: 0.0000
Epoch [620/2000], Evaluation Loss: 23.2118, Evaluation BLEU: 0.0000
Epoch [621/2000], Evaluation Loss: 23.2122, Evaluation BLEU: 0.0000
Epoch [622/2000], Evaluation Loss: 23.2126, Eval

Epoch [729/2000], Evaluation Loss: 23.2597, Evaluation BLEU: 0.0000
Epoch [730/2000], Evaluation Loss: 23.2600, Evaluation BLEU: 0.0000
Epoch [731/2000], Evaluation Loss: 23.2603, Evaluation BLEU: 0.0000
Epoch [732/2000], Evaluation Loss: 23.2607, Evaluation BLEU: 0.0000
Epoch [733/2000], Evaluation Loss: 23.2610, Evaluation BLEU: 0.0000
Epoch [734/2000], Evaluation Loss: 23.2613, Evaluation BLEU: 0.0000
Epoch [735/2000], Evaluation Loss: 23.2616, Evaluation BLEU: 0.0000
Epoch [736/2000], Evaluation Loss: 23.2620, Evaluation BLEU: 0.0000
Epoch [737/2000], Evaluation Loss: 23.2623, Evaluation BLEU: 0.0000
Epoch [738/2000], Evaluation Loss: 23.2627, Evaluation BLEU: 0.0000
Epoch [739/2000], Evaluation Loss: 23.2630, Evaluation BLEU: 0.0000
Epoch [740/2000], Evaluation Loss: 23.2633, Evaluation BLEU: 0.0000
Epoch [741/2000], Evaluation Loss: 23.2637, Evaluation BLEU: 0.0000
Epoch [742/2000], Evaluation Loss: 23.2640, Evaluation BLEU: 0.0000
Epoch [743/2000], Evaluation Loss: 23.2644, Eval

Epoch [850/2000], Evaluation Loss: 23.2996, Evaluation BLEU: 0.0000
Epoch [851/2000], Evaluation Loss: 23.2999, Evaluation BLEU: 0.0000
Epoch [852/2000], Evaluation Loss: 23.3001, Evaluation BLEU: 0.0000
Epoch [853/2000], Evaluation Loss: 23.3003, Evaluation BLEU: 0.0000
Epoch [854/2000], Evaluation Loss: 23.3005, Evaluation BLEU: 0.0000
Epoch [855/2000], Evaluation Loss: 23.3006, Evaluation BLEU: 0.0000
Epoch [856/2000], Evaluation Loss: 23.3008, Evaluation BLEU: 0.0000
Epoch [857/2000], Evaluation Loss: 23.3010, Evaluation BLEU: 0.0000
Epoch [858/2000], Evaluation Loss: 23.3012, Evaluation BLEU: 0.0000
Epoch [859/2000], Evaluation Loss: 23.3013, Evaluation BLEU: 0.0000
Epoch [860/2000], Evaluation Loss: 23.3014, Evaluation BLEU: 0.0000
Epoch [861/2000], Evaluation Loss: 23.3016, Evaluation BLEU: 0.0000
Epoch [862/2000], Evaluation Loss: 23.3017, Evaluation BLEU: 0.0000
Epoch [863/2000], Evaluation Loss: 23.3019, Evaluation BLEU: 0.0000
Epoch [864/2000], Evaluation Loss: 23.3020, Eval

Epoch [971/2000], Evaluation Loss: 23.3317, Evaluation BLEU: 0.0000
Epoch [972/2000], Evaluation Loss: 23.3319, Evaluation BLEU: 0.0000
Epoch [973/2000], Evaluation Loss: 23.3320, Evaluation BLEU: 0.0000
Epoch [974/2000], Evaluation Loss: 23.3321, Evaluation BLEU: 0.0000
Epoch [975/2000], Evaluation Loss: 23.3323, Evaluation BLEU: 0.0000
Epoch [976/2000], Evaluation Loss: 23.3324, Evaluation BLEU: 0.0000
Epoch [977/2000], Evaluation Loss: 23.3326, Evaluation BLEU: 0.0000
Epoch [978/2000], Evaluation Loss: 23.3327, Evaluation BLEU: 0.0000
Epoch [979/2000], Evaluation Loss: 23.3328, Evaluation BLEU: 0.0000
Epoch [980/2000], Evaluation Loss: 23.3330, Evaluation BLEU: 0.0000
Epoch [981/2000], Evaluation Loss: 23.3331, Evaluation BLEU: 0.0000
Epoch [982/2000], Evaluation Loss: 23.3332, Evaluation BLEU: 0.0000
Epoch [983/2000], Evaluation Loss: 23.3334, Evaluation BLEU: 0.0000
Epoch [984/2000], Evaluation Loss: 23.3335, Evaluation BLEU: 0.0000
Epoch [985/2000], Evaluation Loss: 23.3337, Eval

Epoch [1091/2000], Evaluation Loss: 23.3704, Evaluation BLEU: 0.0000
Epoch [1092/2000], Evaluation Loss: 23.3708, Evaluation BLEU: 0.0000
Epoch [1093/2000], Evaluation Loss: 23.3711, Evaluation BLEU: 0.0000
Epoch [1094/2000], Evaluation Loss: 23.3714, Evaluation BLEU: 0.0000
Epoch [1095/2000], Evaluation Loss: 23.3717, Evaluation BLEU: 0.0000
Epoch [1096/2000], Evaluation Loss: 23.3721, Evaluation BLEU: 0.0000
Epoch [1097/2000], Evaluation Loss: 23.3725, Evaluation BLEU: 0.0000
Epoch [1098/2000], Evaluation Loss: 23.3728, Evaluation BLEU: 0.0000
Epoch [1099/2000], Evaluation Loss: 23.3733, Evaluation BLEU: 0.0000
Epoch [1100/2000], Evaluation Loss: 23.3736, Evaluation BLEU: 0.0000
Epoch [1101/2000], Evaluation Loss: 23.3740, Evaluation BLEU: 0.0000
Epoch [1102/2000], Evaluation Loss: 23.3743, Evaluation BLEU: 0.0000
Epoch [1103/2000], Evaluation Loss: 23.3747, Evaluation BLEU: 0.0000
Epoch [1104/2000], Evaluation Loss: 23.3750, Evaluation BLEU: 0.0000
Epoch [1105/2000], Evaluation Loss

Epoch [1210/2000], Evaluation Loss: 23.4011, Evaluation BLEU: 0.0000
Epoch [1211/2000], Evaluation Loss: 23.4012, Evaluation BLEU: 0.0000
Epoch [1212/2000], Evaluation Loss: 23.4013, Evaluation BLEU: 0.0000
Epoch [1213/2000], Evaluation Loss: 23.4014, Evaluation BLEU: 0.0000
Epoch [1214/2000], Evaluation Loss: 23.4016, Evaluation BLEU: 0.0000
Epoch [1215/2000], Evaluation Loss: 23.4018, Evaluation BLEU: 0.0000
Epoch [1216/2000], Evaluation Loss: 23.4020, Evaluation BLEU: 0.0000
Epoch [1217/2000], Evaluation Loss: 23.4022, Evaluation BLEU: 0.0000
Epoch [1218/2000], Evaluation Loss: 23.4023, Evaluation BLEU: 0.0000
Epoch [1219/2000], Evaluation Loss: 23.4025, Evaluation BLEU: 0.0000
Epoch [1220/2000], Evaluation Loss: 23.4027, Evaluation BLEU: 0.0000
Epoch [1221/2000], Evaluation Loss: 23.4029, Evaluation BLEU: 0.0000
Epoch [1222/2000], Evaluation Loss: 23.4031, Evaluation BLEU: 0.0000
Epoch [1223/2000], Evaluation Loss: 23.4033, Evaluation BLEU: 0.0000
Epoch [1224/2000], Evaluation Loss

Epoch [1329/2000], Evaluation Loss: 23.4268, Evaluation BLEU: 0.0000
Epoch [1330/2000], Evaluation Loss: 23.4272, Evaluation BLEU: 0.0000
Epoch [1331/2000], Evaluation Loss: 23.4275, Evaluation BLEU: 0.0000
Epoch [1332/2000], Evaluation Loss: 23.4278, Evaluation BLEU: 0.0000
Epoch [1333/2000], Evaluation Loss: 23.4282, Evaluation BLEU: 0.0000
Epoch [1334/2000], Evaluation Loss: 23.4284, Evaluation BLEU: 0.0000
Epoch [1335/2000], Evaluation Loss: 23.4287, Evaluation BLEU: 0.0000
Epoch [1336/2000], Evaluation Loss: 23.4290, Evaluation BLEU: 0.0000
Epoch [1337/2000], Evaluation Loss: 23.4293, Evaluation BLEU: 0.0000
Epoch [1338/2000], Evaluation Loss: 23.4296, Evaluation BLEU: 0.0000
Epoch [1339/2000], Evaluation Loss: 23.4299, Evaluation BLEU: 0.0000
Epoch [1340/2000], Evaluation Loss: 23.4301, Evaluation BLEU: 0.0000
Epoch [1341/2000], Evaluation Loss: 23.4304, Evaluation BLEU: 0.0000
Epoch [1342/2000], Evaluation Loss: 23.4307, Evaluation BLEU: 0.0000
Epoch [1343/2000], Evaluation Loss

Epoch [1448/2000], Evaluation Loss: 23.4377, Evaluation BLEU: 0.0000
Epoch [1449/2000], Evaluation Loss: 23.4379, Evaluation BLEU: 0.0000
Epoch [1450/2000], Evaluation Loss: 23.4380, Evaluation BLEU: 0.0000
Epoch [1451/2000], Evaluation Loss: 23.4382, Evaluation BLEU: 0.0000
Epoch [1452/2000], Evaluation Loss: 23.4384, Evaluation BLEU: 0.0000
Epoch [1453/2000], Evaluation Loss: 23.4385, Evaluation BLEU: 0.0000
Epoch [1454/2000], Evaluation Loss: 23.4387, Evaluation BLEU: 0.0000
Epoch [1455/2000], Evaluation Loss: 23.4389, Evaluation BLEU: 0.0000
Epoch [1456/2000], Evaluation Loss: 23.4391, Evaluation BLEU: 0.0000
Epoch [1457/2000], Evaluation Loss: 23.4393, Evaluation BLEU: 0.0000
Epoch [1458/2000], Evaluation Loss: 23.4395, Evaluation BLEU: 0.0000
Epoch [1459/2000], Evaluation Loss: 23.4397, Evaluation BLEU: 0.0000
Epoch [1460/2000], Evaluation Loss: 23.4398, Evaluation BLEU: 0.0000
Epoch [1461/2000], Evaluation Loss: 23.4400, Evaluation BLEU: 0.0000
Epoch [1462/2000], Evaluation Loss

Epoch [1567/2000], Evaluation Loss: 23.4455, Evaluation BLEU: 0.0000
Epoch [1568/2000], Evaluation Loss: 23.4458, Evaluation BLEU: 0.0000
Epoch [1569/2000], Evaluation Loss: 23.4461, Evaluation BLEU: 0.0000
Epoch [1570/2000], Evaluation Loss: 23.4463, Evaluation BLEU: 0.0000
Epoch [1571/2000], Evaluation Loss: 23.4465, Evaluation BLEU: 0.0000
Epoch [1572/2000], Evaluation Loss: 23.4467, Evaluation BLEU: 0.0000
Epoch [1573/2000], Evaluation Loss: 23.4468, Evaluation BLEU: 0.0000
Epoch [1574/2000], Evaluation Loss: 23.4470, Evaluation BLEU: 0.0000
Epoch [1575/2000], Evaluation Loss: 23.4471, Evaluation BLEU: 0.0000
Epoch [1576/2000], Evaluation Loss: 23.4473, Evaluation BLEU: 0.0000
Epoch [1577/2000], Evaluation Loss: 23.4476, Evaluation BLEU: 0.0000
Epoch [1578/2000], Evaluation Loss: 23.4478, Evaluation BLEU: 0.0000
Epoch [1579/2000], Evaluation Loss: 23.4481, Evaluation BLEU: 0.0000
Epoch [1580/2000], Evaluation Loss: 23.4483, Evaluation BLEU: 0.0000
Epoch [1581/2000], Evaluation Loss

Epoch [1686/2000], Evaluation Loss: 23.4604, Evaluation BLEU: 0.0000
Epoch [1687/2000], Evaluation Loss: 23.4604, Evaluation BLEU: 0.0000
Epoch [1688/2000], Evaluation Loss: 23.4604, Evaluation BLEU: 0.0000
Epoch [1689/2000], Evaluation Loss: 23.4605, Evaluation BLEU: 0.0000
Epoch [1690/2000], Evaluation Loss: 23.4606, Evaluation BLEU: 0.0000
Epoch [1691/2000], Evaluation Loss: 23.4607, Evaluation BLEU: 0.0000
Epoch [1692/2000], Evaluation Loss: 23.4608, Evaluation BLEU: 0.0000
Epoch [1693/2000], Evaluation Loss: 23.4609, Evaluation BLEU: 0.0000
Epoch [1694/2000], Evaluation Loss: 23.4611, Evaluation BLEU: 0.0000
Epoch [1695/2000], Evaluation Loss: 23.4612, Evaluation BLEU: 0.0000
Epoch [1696/2000], Evaluation Loss: 23.4614, Evaluation BLEU: 0.0000
Epoch [1697/2000], Evaluation Loss: 23.4617, Evaluation BLEU: 0.0000
Epoch [1698/2000], Evaluation Loss: 23.4619, Evaluation BLEU: 0.0000
Epoch [1699/2000], Evaluation Loss: 23.4621, Evaluation BLEU: 0.0000
Epoch [1700/2000], Evaluation Loss

Epoch [1805/2000], Evaluation Loss: 23.4862, Evaluation BLEU: 0.0000
Epoch [1806/2000], Evaluation Loss: 23.4863, Evaluation BLEU: 0.0000
Epoch [1807/2000], Evaluation Loss: 23.4864, Evaluation BLEU: 0.0000
Epoch [1808/2000], Evaluation Loss: 23.4864, Evaluation BLEU: 0.0000
Epoch [1809/2000], Evaluation Loss: 23.4864, Evaluation BLEU: 0.0000
Epoch [1810/2000], Evaluation Loss: 23.4863, Evaluation BLEU: 0.0000
Epoch [1811/2000], Evaluation Loss: 23.4862, Evaluation BLEU: 0.0000
Epoch [1812/2000], Evaluation Loss: 23.4861, Evaluation BLEU: 0.0000
Epoch [1813/2000], Evaluation Loss: 23.4860, Evaluation BLEU: 0.0000
Epoch [1814/2000], Evaluation Loss: 23.4859, Evaluation BLEU: 0.0000
Epoch [1815/2000], Evaluation Loss: 23.4857, Evaluation BLEU: 0.0000
Epoch [1816/2000], Evaluation Loss: 23.4855, Evaluation BLEU: 0.0000
Epoch [1817/2000], Evaluation Loss: 23.4854, Evaluation BLEU: 0.0000
Epoch [1818/2000], Evaluation Loss: 23.4852, Evaluation BLEU: 0.0000
Epoch [1819/2000], Evaluation Loss

Epoch [1924/2000], Evaluation Loss: 23.4975, Evaluation BLEU: 0.0000
Epoch [1925/2000], Evaluation Loss: 23.4975, Evaluation BLEU: 0.0000
Epoch [1926/2000], Evaluation Loss: 23.4975, Evaluation BLEU: 0.0000
Epoch [1927/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1928/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1929/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1930/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1931/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1932/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1933/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1934/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1935/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1936/2000], Evaluation Loss: 23.4976, Evaluation BLEU: 0.0000
Epoch [1937/2000], Evaluation Loss: 23.4977, Evaluation BLEU: 0.0000
Epoch [1938/2000], Evaluation Loss

In [None]:
def SARIngram(sgrams, cgrams, rgramslist, numref):
    rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]
    rgramcounter = Counter(rgramsall)
	
    sgramcounter = Counter(sgrams)
    sgramcounter_rep = Counter()
    for sgram, scount in sgramcounter.items():
        sgramcounter_rep[sgram] = scount * numref
        
    cgramcounter = Counter(cgrams)
    cgramcounter_rep = Counter()
    for cgram, ccount in cgramcounter.items():
        cgramcounter_rep[cgram] = ccount * numref
	
    
    # KEEP
    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    keeptmpscore1 = 0
    keeptmpscore2 = 0
    for keepgram in keepgramcountergood_rep:
        keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
        keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
        #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]
    keepscore_precision = 0
    if len(keepgramcounter_rep) > 0:
    	keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
    keepscore_recall = 0
    if len(keepgramcounterall_rep) > 0:
    	keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
    keepscore = 0
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)


    # DELETION
    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
    delgramcountergood_rep = delgramcounter_rep - rgramcounter
    delgramcounterall_rep = sgramcounter_rep - rgramcounter
    deltmpscore1 = 0
    deltmpscore2 = 0
    for delgram in delgramcountergood_rep:
        deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
        deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]
    delscore_precision = 0
    if len(delgramcounter_rep) > 0:
    	delscore_precision = deltmpscore1 / len(delgramcounter_rep)
    delscore_recall = 0
    if len(delgramcounterall_rep) > 0:
    	delscore_recall = deltmpscore1 / len(delgramcounterall_rep)
    delscore = 0
    if delscore_precision > 0 or delscore_recall > 0:
        delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)


    # ADDITION
    addgramcounter = set(cgramcounter) - set(sgramcounter)
    addgramcountergood = set(addgramcounter) & set(rgramcounter)
    addgramcounterall = set(rgramcounter) - set(sgramcounter)

    addtmpscore = 0
    for addgram in addgramcountergood:
        addtmpscore += 1

    addscore_precision = 0
    addscore_recall = 0
    if len(addgramcounter) > 0:
    	addscore_precision = addtmpscore / len(addgramcounter)
    if len(addgramcounterall) > 0:
    	addscore_recall = addtmpscore / len(addgramcounterall)
    addscore = 0
    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)
    
    return (keepscore, delscore_precision, addscore)
    

def SARIsent (ssent, csent, rsents) :
    numref = len(rsents)	

    s1grams = ssent.lower().split(" ")
    c1grams = csent.lower().split(" ")
    s2grams = []
    c2grams = []
    s3grams = []
    c3grams = []
    s4grams = []
    c4grams = []
 
    r1gramslist = []
    r2gramslist = []
    r3gramslist = []
    r4gramslist = []
    for rsent in rsents:
        r1grams = rsent.lower().split(" ")    
        r2grams = []
        r3grams = []
        r4grams = []
        r1gramslist.append(r1grams)
        for i in range(0, len(r1grams)-1) :
            if i < len(r1grams) - 1:
                r2gram = r1grams[i] + " " + r1grams[i+1]
                r2grams.append(r2gram)
            if i < len(r1grams)-2:
                r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2]
                r3grams.append(r3gram)
            if i < len(r1grams)-3:
                r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3]
                r4grams.append(r4gram)        
        r2gramslist.append(r2grams)
        r3gramslist.append(r3grams)
        r4gramslist.append(r4grams)
       
    for i in range(0, len(s1grams)-1) :
        if i < len(s1grams) - 1:
            s2gram = s1grams[i] + " " + s1grams[i+1]
            s2grams.append(s2gram)
        if i < len(s1grams)-2:
            s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2]
            s3grams.append(s3gram)
        if i < len(s1grams)-3:
            s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3]
            s4grams.append(s4gram)
            
    for i in range(0, len(c1grams)-1) :
        if i < len(c1grams) - 1:
            c2gram = c1grams[i] + " " + c1grams[i+1]
            c2grams.append(c2gram)
        if i < len(c1grams)-2:
            c3gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2]
            c3grams.append(c3gram)
        if i < len(c1grams)-3:
            c4gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2] + " " + c1grams[i+3]
            c4grams.append(c4gram)


    (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref)
    (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref)
    (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref)
    (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref)
    avgkeepscore = sum([keep1score,keep2score,keep3score,keep4score])/4
    avgdelscore = sum([del1score,del2score,del3score,del4score])/4
    avgaddscore = sum([add1score,add2score,add3score,add4score])/4
    finalscore = (avgkeepscore + avgdelscore + avgaddscore ) / 3

    return finalscore


In [69]:
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from collections import Counter

# BLEU score fore evaluation
def compute_bleu_score(logits, labels):
    refs = get_sent_tokens(labels)
    weights = (1.0/2.0, 1.0/2.0, )
    score = corpus_bleu(refs, logits.tolist(), smoothing_function=SmoothingFunction(epsilon=1e-10).method1, weights=weights)
    return score

def compute_sari(norm, pred_tensor, ref):
    pred = decode_sent_tokens(pred_tensor)
    score = 0
    for step, item in enumerate(ref):
        score += sari.SARIsent(norm[step], pred[step], item)
    return score/TRAIN_BATCH_SIZE

def evaluate_model(model, data_loader, criterion, target_tokenizer, ref_sentences):
    model.eval()
    total_loss = 0.0
    total_batches = 0
    references = []
    hypotheses = []

    with torch.no_grad():
        for batch in data_loader:
            batch_source = batch[0]
            batch_target = batch[1]

            outputs = model(input_ids=batch_source, decoder_input_ids=batch_target)

            logits_flat = outputs.logits.view(-1, outputs.logits.size(-1))
            target_flat = batch_target.view(-1)

            loss = criterion(logits_flat, target_flat)
            total_loss += loss.item()
            total_batches += 1

            predicted_ids = outputs.logits.argmax(-1)
            predicted_sentences = [target_tokenizer.decode(ids, skip_special_tokens=True) for ids in predicted_ids]
            target_sentences = [target_tokenizer.decode(ids, skip_special_tokens=True) for ids in batch_target]

            references.extend([sent.split() for sent in target_sentences])
            hypotheses.extend([sent.split() for sent in predicted_sentences])

    avg_loss = total_loss / total_batches

    # Calculate BLEU score using your function
    #bleu_score = compute_bleu_score(torch.argmax(outputs, dim=-1), batch[1])
    logits = outputs.logits  # Extract the logits from Seq2SeqLMOutput
    predicted_ids = torch.argmax(logits, dim=-1)  # Get the predicted IDs
    bleu_score = compute_bleu_score(predicted_ids, batch[1])  # Compute BLEU score
    
    # Calculate SARI score
    sari_scores = []
    for idx, gen_sent in enumerate(hypotheses):
        sari_score = compute_sari(ref_sentences[idx], gen_sent)
        sari_scores.append(sari_score)
    
    avg_sari_score = sum(sari_scores) / len(sari_scores)

    return avg_loss, bleu_score, avg_sari_score


In [70]:
# TRAINING

# constants
TRAIN_BATCH_SIZE = 3
N_EPOCHS = 20
max_token_len = 80
LOG_EVERY = 10000

# encode source and target data using encode_batch function
source_train_data_encoded = encode_batch((src_train))
target_train_data_encoded = encode_batch((tgt_train))

# Initialize models
bert_model = BertModel.from_pretrained('bert-base-cased')
gpt2_model = GPT2Model.from_pretrained('gpt2')

# Configuration
max_token_len = 80
start_token_id = bert_tokenizer.cls_token_id
end_token_id = gpt2_tokenizer.eos_token_id

# Initialize the Encoder-Decoder model with cross-attention enabled
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')

# Update configuration for the model
model.config.decoder_start_token_id = start_token_id
model.config.eos_token_id = end_token_id
model.config.max_length = max_token_len
model.config.no_repeat_ngram_size = 3
model.config.add_cross_attention = True  # Enable cross-attention

# Loss function and optimizer setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    for i in range(0, len(source_train_data_encoded[0]), TRAIN_BATCH_SIZE):
        batch_source = source_train_data_encoded[0][i:i+TRAIN_BATCH_SIZE]
        batch_target = target_train_data_encoded[0][i:i+TRAIN_BATCH_SIZE]

        optimizer.zero_grad()

        outputs = model(input_ids=batch_source, decoder_input_ids=batch_target)

        logits_flat = outputs.logits.view(-1, outputs.logits.size(-1))
        target_flat = batch_target.view(-1)

        loss = criterion(logits_flat, target_flat)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == 0 and i > 0:
            print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Batch [{i + 1}/{len(source_train_data_encoded[0])}], Loss: {running_loss / LOG_EVERY:.4f}')
            running_loss = 0.0

    # Evaluate
    #val_loss, val_bleu = evaluate_model(model, eval_data_loader, criterion, gpt2_tokenizer)
    #print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Evaluation Loss: {val_loss:.4f}, Evaluation BLEU: {val_bleu:.4f}')
    val_loss, val_bleu, val_sari = evaluate_model(model, eval_data_loader, criterion, gpt2_tokenizer, ref_sentences[0])
    print(f'Epoch [{epoch + 1}/{N_EPOCHS}], Evaluation Loss: {val_loss:.4f}, Evaluation BLEU: {val_bleu:.4f}, Evaluation SARI: {val_sari:.4f}')

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.3.crossattention.c_attn.weight', 'h.5.ln_cross_attn.weight', 'h.6.ln_cross_attn.weight', 'h.9.crossattention.q_attn.bias', 'h.9.ln_cross_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.6.crossattention.q_attn.weight', 'h.6.crossattention.c_attn.bias', 'h.5.crossattention.c_attn.weight', 'h.3.ln_cross_attn.bias', 'h.7.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.2.crossattention.c_attn.weight', 'h.4.crossattention.c_attn.bias', 'h.3.crossattention.c_attn.bias', 'h.11.crossattention.q_attn.bias', 'h.4.ln_cross_attn.weight', 'h.0.ln_cross_attn.bias', 'h.5.crossattention.q_attn.weight', 'h.7.crossattention.q_attn.bias', 'h.5.crossattention.c_proj.bias', 'h.1.ln_cross_attn.weight', 'h.4.crossattention.q_attn.bias', 'h.1.crossattention.c_proj.weight', 'h.10.crossattention.c_proj.weight', 'h.2.crossattention.q_attn.

ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [None]:
def evaluate(data_loader, e_loss):
    was_training = model.training
    model.eval()
    eval_loss = e_loss
    bleu_score = 0
    sari_score = 0
    softmax = nn.LogSoftmax(dim = -1)

    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            src_tensors, src_attn_tensors, tgt_tensors, tgt_attn_tensors, labels = tokenizer.encode_batch(batch)
            loss, logits = model(input_ids = src_tensors.to(device), 
                            decoder_input_ids = tgt_tensors.to(device),
                            attention_mask = src_attn_tensors.to(device),
                            decoder_attention_mask = tgt_attn_tensors.to(device),
                            labels = labels.to(device))[:2]
            outputs = softmax(logits)
            score = compute_bleu_score(torch.argmax(outputs, dim=-1), batch[1])
            s_score = compute_sari(batch[0], torch.argmax(outputs, dim=-1), batch[2])
            if step == 0:
                eval_loss = loss.item()
                bleu_score = score
                sari_score = s_score
            else:
                eval_loss = (1/2.0)*(eval_loss + loss.item())
                bleu_score = (1/2.0)* (bleu_score+score)
                sari_score = (1/2.0)* (sari_score+s_score)
        
    if was_training:
        model.train()

    return eval_loss, bleu_score, sari_score

def load_checkpt(checkpt_path, optimizer=None):
    checkpoint = torch.load(checkpt_path)
    if device == "cpu":
        model.load_state_dict(checkpoint["model_state_dict"], map_location=torch.device("cpu"))
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"], map_location=torch.device("cpu"))
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    eval_loss = checkpoint["eval_loss"]
    epoch = checkpoint["epoch"]

    return optimizer, eval_loss, epoch

def save_model_checkpt(state, is_best, check_pt_path, best_model_path):
    f_path = check_pt_path
    torch.save(state, f_path)

    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)

In [None]:
def train_model(start_epoch, eval_loss, loaders, optimizer, check_pt_path, best_model_path):
    best_eval_loss = eval_loss
    print("Model training started...")
    for epoch in range(start_epoch, N_EPOCH):
        print(f"Epoch {epoch} running...")
        epoch_start_time = time.time()
        epoch_train_loss = 0
        epoch_eval_loss = 0
        model.train()
        for step, batch in enumerate(loaders[0]):
            src_tensors, src_attn_tensors, tgt_tensors, tgt_attn_tensors, labels = tokenizer.encode_batch(batch)
            optimizer.zero_grad()
            model.zero_grad()
            loss = model(input_ids = src_tensors.to(device), 
                            decoder_input_ids = tgt_tensors.to(device),
                            attention_mask = src_attn_tensors.to(device),
                            decoder_attention_mask = tgt_attn_tensors.to(device),
                            labels = labels.to(device))[0]
            if step == 0:
                epoch_train_loss = loss.item()
            else:
                epoch_train_loss = (1/2.0)*(epoch_train_loss + loss.item())
            
            loss.backward()
            optimizer.step()

            if (step+1) % LOG_EVERY == 0:
                print(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
                logging.info(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
        
        eval_start_time = time.time()
        epoch_eval_loss, bleu_score, sari_score = evaluate(loaders[1], epoch_eval_loss)
        epoch_eval_loss = epoch_eval_loss/TRAIN_BATCH_SIZE
        print(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | blue score: {bleu_score} | Sari score: {sari_score} | time elapsed: {time.time() - eval_start_time}')
        logging.info(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | blue score: {bleu_score}| Sari score: {sari_score} | time elapsed: {time.time() - eval_start_time}')

        check_pt = {
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'eval_loss': epoch_eval_loss,
            'sari_score': sari_score,
            'bleu_score': bleu_score
        }
        check_pt_time = time.time()
        print("Saving Checkpoint.......")
        if epoch_eval_loss < best_eval_loss:
            print("New best model found")
            logging.info(f"New best model found")
            best_eval_loss = epoch_eval_loss
            save_model_checkpt(check_pt, True, check_pt_path, best_model_path)
        else:
            save_model_checkpt(check_pt, False, check_pt_path, best_model_path)  
        print(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")
        logging.info(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")

        gc.collect()
        torch.cuda.empty_cache()  