# Using RAG with Word files



In [None]:
from google.colab import drive
drive.mount('/content/drive')

### imports

In [None]:
!pip install llama-index
!pip install llama-index-embeddings-huggingface
!pip install peft
!pip install auto-gptq
!pip install optimum
!pip install bitsandbytes

In [None]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor

### Define Settings

In [None]:
# import any embedding model on HF hub (https://huggingface.co/spaces/mteb/leaderboard)
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Settings.embed_model = HuggingFaceEmbedding(model_name="thenlper/gte-large") # alternative model

Settings.llm = None             #  in this case, the system will only use an embedding model and not a full-fledged LLM
Settings.chunk_size = 100       #  input text is being divided into chunks of 256 tokens/characters (depending on the implementation)
Settings.chunk_overlap = 25     # there is an overlap of 25 tokens/characters between consecutive chunks. Overlap helps maintain context continuity between chunks

### Read and Store Docs into Vector DB

In [None]:
import os

# check if storage already exists
PERSIST_DIR = "/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/Storage"

if not os.path.exists(PERSIST_DIR):
    # load the documents and create the index
    documents = SimpleDirectoryReader("/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/documents").load_data()
    index = VectorStoreIndex.from_documents(documents)
    # store it for later
    index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
    # load the existing index
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)


### Set Up Search Function

In [None]:
# set number of docs to retreive
top_k = 1

# configure retriever
retriever = VectorIndexRetriever(
    index=index,                                # retrives the similar documents from created index
    similarity_top_k=top_k,
)

In [None]:
# assemble query engine
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.5)],
)

### Retrieve Relevant Docs

In [None]:
# query documents
query = "Where is Taleju Temple?"
response = query_engine.query(query)
response

In [None]:
# Reformat response to form a paragraph
context = "Context:\n"
for i in range(top_k):
    # Append each text from the source nodes, stripping excess newlines and spaces
    context += response.source_nodes[i].text.replace("\n", " ").strip() + " "

# Remove extra spaces at the start/end and print the final formatted context
context = context.strip()

# Print the reformatted response as a paragraph
print(context)


# Retrive the context


In [1]:
!pip install wandb
!pip install rouge-score

Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=d22ab900980aac8cb6df3683332dc9cd5a052f64622028ffe2b1e2a435b2b5cd
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2


In [2]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import wandb
from rouge_score import rouge_scorer
# from dotenv import load_dotenv
import os
from tabulate import tabulate
from nltk.translate.bleu_score import corpus_bleu
import sympy as sp
# load_dotenv()

In [4]:
df = pd.read_csv('/kaggle/input/test-rag/updated_dataset.csv')

# If 'Unnamed: 0' column still exists, you can drop it
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

In [None]:
values = df['object'].value_counts()
values

In [None]:
df = df[~(df['object'] == "what are the customary actions performed at the conclusion of a prayer wheel practice session?\"")]

In [None]:
values = df['object'].value_counts()
values

In [None]:
# 80% -> Training Data, 20% -> Testing Data
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# 90% -> Training Data, 10% -> Validation Data
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

In [None]:
train_df

In [None]:
values = train_df['object'].value_counts()
values

In [None]:
# Placeholder function for retrieving context using RAG
def retrieve_context(question):
    # Here, you would implement your RAG querying logic
    # For example, using the query engine you've set up
    context = ""
    response = query_engine.query(question)  # Replace with actual querying method
    # Check if source_nodes exist in the response and get the number of available nodes
    num_nodes = len(response.source_nodes) if response.source_nodes else 0

    # Iterate only up to the available number of nodes or top_k, whichever is smaller
    for i in range(min(top_k, num_nodes)):
        # Append each text from the source nodes, stripping excess newlines and spaces
        context += response.source_nodes[i].text.replace("\n", " ").strip() + " "

    # Remove extra spaces at the start/end and print the final formatted context
    context = context.strip() # Adjust based on how your response object is structured
    return context

In [None]:
# Add a new column for context
df['Context'] = ''

# Iterate through each question and retrieve the context
for index, row in df.iterrows():
    question = row['question']
    # Retrieve context for the current question
    context = retrieve_context(question)
    # Update the DataFrame with the retrieved context
    df.at[index, 'Context'] = context

# Save the updated DataFrame to a new CSV file
output_file_path = 'updated_dataset'  # Replace with your desired output file path
# Define an escape character (e.g., backslash)
df.to_csv("/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/updated_dataset.csv", index=False, escapechar='\\')

print("Context added and saved to:", output_file_path)

# BART QA with context Training

In [5]:
df = pd.read_csv('/kaggle/input/test-rag/updated_dataset.csv')

# # If 'Unnamed: 0' column still exists, you can drop it
# df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

# 80% -> Training Data, 20% -> Testing Data
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# 90% -> Training Data, 10% -> Validation Data
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

In [6]:
train_df

Unnamed: 0,question,object,answer,Context
3791,How many arms can Bhairava be depicted with?,kala bhairav,"Bhairava may be depicted with four, eight, or ...","Often shown with a wrathful expression, bared ..."
6210,How did the brothers celebrate completing the ...,boudhanath,"They stood in front of it and made prayers, ge...",Why? Because it is so powerful that the wishes...
1515,What specific physical characteristics must a ...,taleju bell,"The girl must be in excellent health, never ha...",The king and other religious leaders that migh...
1469,During which event does the Nepalese King typi...,taleju bell,The Nepalese King seeks the blessing of the Ro...,The Kumari is also revered and worshipped by s...
2763,What does the prayer wheel represent in Buddhi...,prayer wheel,The prayer wheel embodies all the actions of t...,"A prayer wheel, or mani wheel, is a cylindrica..."
...,...,...,...,...
6165,What relevance does the Bouddha Stupa legend h...,boudhanath,It is considered important and incredibly insp...,A painting of Samvari is on the rear of the Pu...
310,What does the Nara Yali emblem signify in trad...,yali,The Nara Yali symbolizes the protection of dha...,"Kanjivaram sari, it is intricately woven with ..."
995,Which individual is currently recognized as th...,Taleju Temple,The current Royal Kumari is Trishna Shakya age...,The Kumari is also revered and worshipped by s...
6310,How has the adoption of Mahayana Buddhism evol...,boudhanath,It has spread and been preserved for many year...,So they stayed around him on the mountain to p...


In [7]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

def calculate_max_length(column_name):
    df[column_name] = df[column_name].astype(str)
    return df[column_name].apply(lambda x: len(tokenizer.tokenize(x))).max()

max_length_question = calculate_max_length('question')
max_length_object = calculate_max_length('object')
max_length_context = calculate_max_length('Context')
max_length_answer = calculate_max_length('answer')

print(f"Maximum token length in the 'Input' : {max_length_question + max_length_object + max_length_context}")
print(f"Maximum token length in the 'Output' : {max_length_answer}")

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Maximum token length in the 'Input' : 182
Maximum token length in the 'Output' : 57


In [None]:
all_text = ' '.join(df['Question']) + ' ' + ' '.join(df['Answer']) + ' ' + ' '.join(df['Context'])
words = list(all_text.split())

len(words)

In [None]:
from collections import Counter

word_counts = Counter(all_text.split())
word_counts

In [None]:
unique_words = list(word_counts.keys())
unique_words

In [None]:
print(len(unique_words))

In [None]:
def compare_words_with_bart_vocab(word_list):
    # Tokenize the words in the list
    tokenized_words = [token for word in word_list for token in tokenizer.tokenize(word)]

    # Get BART's vocabulary
    bart_vocab = set(tokenizer.get_vocab().keys())

    # Check which tokens are in or not in BART's vocabulary
    words_in_vocab = {word for word in tokenized_words if word in bart_vocab}
    words_not_in_vocab = {word for word in tokenized_words if word not in bart_vocab}

    return words_in_vocab, words_not_in_vocab

In [None]:
words_in_vocab, words_not_in_vocab = compare_words_with_bart_vocab(unique_words)

print("Words in BART vocabulary:", words_in_vocab)
print("Words not in BART vocabulary:", words_not_in_vocab)

In [None]:
print(f"Words to be added: {len(words_not_in_vocab)}")

In [None]:
def add_new_tokens_to_vocab(new_tokens):
    num_added_tokens = tokenizer.add_tokens(new_tokens)
    model.resize_token_embeddings(len(tokenizer))
    vocab = tokenizer.get_vocab()
    print("Vocabulary size:", len(vocab))

new_tokens = list(words_not_in_vocab)
num_added_tokens = add_new_tokens_to_vocab(new_tokens)

In [9]:
class QADataset(Dataset):
    '''For Loading the Dataset for Question Answering'''
    def __init__(self, data, tokenizer, question_max_length=200, context_max_length=200, answer_max_length=80):
        self.data = data
        self.tokenizer = tokenizer
        self.question_max_length = question_max_length
        self.context_max_length = context_max_length
        self.answer_max_length = answer_max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Extract the question, object (context), and answer from the dataset
        question = self.data.iloc[idx]['question']
        object_name = self.data.iloc[idx]['object']
        answer = self.data.iloc[idx]['answer']
        context = self.data.iloc[idx]['Context']

        # Combine the object name (context) and question
        combined = 'Object :' + str(object_name) + ' Question: ' + str(question) + ' Context: ' + str(context)            # if object_name is the float

        # Tokenize the combined input (question + context)
        inputs = self.tokenizer(
            combined,
            max_length=self.context_max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        targets = self.tokenizer(
            answer,
            max_length=self.answer_max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        input_ids = inputs.input_ids.squeeze()
        attention_mask = inputs.attention_mask.squeeze()
        target_ids = targets.input_ids.squeeze()

        return {

            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': target_ids
        }

In [10]:
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [11]:
train_dataset = QADataset(train_df, tokenizer=tokenizer)
val_dataset = QADataset(val_df, tokenizer=tokenizer)
test_dataset = QADataset(test_df, tokenizer=tokenizer)

### Code changed with freezing the encoder layers, gradient accumulation and low learning rate

In [13]:
class VQA_Trainer:
    '''Class for Trainer Setup to Train the BART Model for VQA'''

    def __init__(self, model, train_dataloader, eval_dataloader, device, config, gradient_accumulation_steps=4, learning_rate=1e-5):
        ''' Constructor '''
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.device = device
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

        # Freeze initial encoder layers to prevent overfitting
        for param in self.model.model.encoder.parameters():
            param.requires_grad = False

        # Optimizer with a lower learning rate
        self.optimizer = AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=learning_rate)

        # Learning rate scheduler (adjust if necessary)
        total_steps = config["epochs"] * (len(train_dataloader) // gradient_accumulation_steps)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=total_steps)

        # WandB setup
        wandb.init(project=config['project_name'], config=config)
        wandb.watch(self.model, log="all")

        self.gradient_accumulation_steps = gradient_accumulation_steps

    def evaluate(self):
        ''' For Evaluation at the End of Each Epoch '''
        self.model.eval()
        total_loss = 0
        predictions, references, token_level_accuracies = [], [], []

        progress_bar = tqdm(self.eval_dataloader, desc="Evaluating")
        for batch in progress_bar:
            with torch.no_grad():
                inputs = {key: val.to(self.device) for key, val in batch.items()}
                outputs = self.model(**inputs)
                total_loss += outputs.loss.item()

                # Generate predictions and decode
                summary_ids = self.model.generate(inputs['input_ids'], min_length=10, max_length=100, num_beams=4, early_stopping=True)
                decoded_preds = self.tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

                labels = batch['labels']
                labels = torch.where(labels != -100, labels, self.tokenizer.pad_token_id)
                decoded_refs = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

                predictions.extend([pred.split() for pred in decoded_preds])
                references.extend([[ref.split()] for ref in decoded_refs])

                # Token-level accuracy
                for pred, ref in zip(decoded_preds, decoded_refs):
                    pred_tokens, ref_tokens = pred.split(), ref.split()
                    token_accuracy = sum(1 for p, r in zip(pred_tokens, ref_tokens) if p == r) / max(len(ref_tokens), 1)
                    token_level_accuracies.append(token_accuracy)

        avg_loss = total_loss / len(self.eval_dataloader)
        perplexity = math.exp(avg_loss)
        bleu_score = corpus_bleu(references, predictions)

        # ROUGE-L scores
        rouge_l_f1_scores, rouge_l_precision_scores, rouge_l_recall_scores = [], [], []
        for pred, ref in zip(decoded_preds, decoded_refs):
            rouge_l = self.scorer.score(ref, pred)['rougeL']
            rouge_l_f1_scores.append(rouge_l.fmeasure)
            rouge_l_precision_scores.append(rouge_l.precision)
            rouge_l_recall_scores.append(rouge_l.recall)

        avg_token_level_accuracy = np.mean(token_level_accuracies)
        avg_rouge_l_f1 = np.mean(rouge_l_f1_scores)
        avg_rouge_l_precision = np.mean(rouge_l_precision_scores)
        avg_rouge_l_recall = np.mean(rouge_l_recall_scores)

        # Log metrics
        metrics = {
            "Validation Loss": avg_loss,
            "Perplexity": perplexity,
            "BLEU": bleu_score,
            "ROUGE-L F1": avg_rouge_l_f1,
            "ROUGE-L Precision": avg_rouge_l_precision,
            "ROUGE-L Recall": avg_rouge_l_recall,
            "Token-Level Accuracy": avg_token_level_accuracy,
        }
        wandb.log(metrics)
        print(tabulate(pd.DataFrame([metrics]), headers="keys", tablefmt="psql"))

        return metrics

    def train_epoch(self):
        ''' To Train for Single Epoch with Gradient Accumulation '''
        self.model.train()
        total_loss = 0
        for i, batch in enumerate(tqdm(self.train_dataloader, desc="Training")):
            inputs = {key: val.to(self.device) for key, val in batch.items()}
            outputs = self.model(**inputs)
            loss = outputs.loss / self.gradient_accumulation_steps
            loss.backward()
            total_loss += loss.item()

            # Step only after gradient accumulation steps
            if (i + 1) % self.gradient_accumulation_steps == 0:
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            wandb.log({"train_loss": loss.item()})

        print(f"Total training loss for epoch: {total_loss / len(self.train_dataloader)}")

    def train(self, epochs):
        ''' To Train for N Number of Epochs Passed from User '''
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            self.train_epoch()
            metrics = self.evaluate()
            print(f"Metrics: {metrics}")

        torch.save(self.model.state_dict(), f"model_final.pth")


In [14]:
api_key = os.getenv('API_KEY')
!wandb login --relogin $api_key

Traceback (most recent call last):
  File "/opt/conda/bin/wandb", line 8, in <module>
    sys.exit(cli())
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/wandb/cli/cli.py", line 108, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/wandb/cli/cli.py", line 258, in login
    wandb.login(
  File "/opt/conda/lib/python3.1

In [15]:
config = {
     "batch_size":8,
     "epochs" : 10,
     "model_name": "facebook/bart-large-cnn",
     "project_name": "RAG_Conditional_BART",
}

In [16]:
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BartForConditionalGeneration.from_pretrained(config['model_name']).to(device)

In [18]:
trainer = VQA_Trainer(model, train_dataloader, val_dataloader,device,config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112775211111309, max=1.0…

In [19]:
trainer.train(config['epochs'])

Epoch 1/10


Training: 100%|██████████| 594/594 [02:51<00:00,  3.47it/s]


Total training loss for epoch: 0.5514911087825644


Evaluating: 100%|██████████| 66/66 [01:43<00:00,  1.57s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |           1.03936 |      2.82742 | 0.150997 |     0.294921 |             0.38086 |         0.245384 |               0.115591 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 1.039363952297153, 'Perplexity': 2.8274180695367317, 'BLEU': 0.15099734749378416, 'ROUGE-L F1': 0.29492116815589925, 'ROUGE-L Precision': 0.38086007130124777, 'ROUGE-L Recall': 0.24538371168305378, 'Token-Level Accuracy': 0.11559075947283412}
Epoch 2/10


Training: 100%|██████████| 594/594 [02:51<00:00,  3.46it/s]


Total training loss for epoch: 0.2628271894824224


Evaluating: 100%|██████████| 66/66 [01:42<00:00,  1.55s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.955087 |       2.5989 | 0.186126 |     0.240002 |            0.239882 |         0.289683 |               0.138437 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.9550869916424607, 'Perplexity': 2.5988966553748085, 'BLEU': 0.18612625994313448, 'ROUGE-L F1': 0.24000175104826268, 'ROUGE-L Precision': 0.2398824124791617, 'ROUGE-L Recall': 0.289683105872622, 'Token-Level Accuracy': 0.13843740976005522}
Epoch 3/10


Training: 100%|██████████| 594/594 [02:51<00:00,  3.47it/s]


Total training loss for epoch: 0.2377465318900969


Evaluating: 100%|██████████| 66/66 [01:44<00:00,  1.58s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.899177 |      2.45758 | 0.204559 |     0.407488 |             0.35554 |         0.508628 |               0.161969 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.8991767863432566, 'Perplexity': 2.457579165469841, 'BLEU': 0.2045587644916054, 'ROUGE-L F1': 0.4074876043598919, 'ROUGE-L Precision': 0.355540293040293, 'ROUGE-L Recall': 0.5086278396311291, 'Token-Level Accuracy': 0.1619690212907936}
Epoch 4/10


Training: 100%|██████████| 594/594 [02:52<00:00,  3.45it/s]


Total training loss for epoch: 0.2198980438839707


Evaluating: 100%|██████████| 66/66 [01:44<00:00,  1.58s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.851344 |      2.34279 | 0.220163 |     0.427761 |            0.400789 |         0.482045 |               0.177755 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.8513437446319696, 'Perplexity': 2.342792853061659, 'BLEU': 0.2201625223379602, 'ROUGE-L F1': 0.4277607735162914, 'ROUGE-L Precision': 0.40078879453879457, 'ROUGE-L Recall': 0.4820446832579186, 'Token-Level Accuracy': 0.1777549619804236}
Epoch 5/10


Training: 100%|██████████| 594/594 [02:58<00:00,  3.32it/s]


Total training loss for epoch: 0.20586173048224113


Evaluating: 100%|██████████| 66/66 [01:38<00:00,  1.50s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.817772 |      2.26545 | 0.242235 |     0.405655 |            0.350111 |         0.537078 |               0.201423 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.8177719883846514, 'Perplexity': 2.2654467687599964, 'BLEU': 0.24223507281538234, 'ROUGE-L F1': 0.40565546279021775, 'ROUGE-L Precision': 0.3501107241142071, 'ROUGE-L Recall': 0.5370775058275058, 'Token-Level Accuracy': 0.20142337953304915}
Epoch 6/10


Training: 100%|██████████| 594/594 [02:59<00:00,  3.31it/s]


Total training loss for epoch: 0.19425647582871344


Evaluating: 100%|██████████| 66/66 [01:47<00:00,  1.63s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.794841 |      2.21409 | 0.249753 |     0.404239 |            0.371011 |         0.513709 |               0.219904 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.7948410348458723, 'Perplexity': 2.214089005827602, 'BLEU': 0.24975296294149343, 'ROUGE-L F1': 0.4042385343391861, 'ROUGE-L Precision': 0.3710106779185726, 'ROUGE-L Recall': 0.5137089932126697, 'Token-Level Accuracy': 0.21990424094330854}
Epoch 7/10


Training: 100%|██████████| 594/594 [02:51<00:00,  3.45it/s]


Total training loss for epoch: 0.18587819386511942


Evaluating: 100%|██████████| 66/66 [01:45<00:00,  1.60s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.759708 |      2.13765 | 0.259478 |     0.363818 |            0.383284 |         0.354671 |               0.229864 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.7597078464247964, 'Perplexity': 2.1376516067002473, 'BLEU': 0.2594777338604868, 'ROUGE-L F1': 0.3638181436654484, 'ROUGE-L Precision': 0.3832840646139879, 'ROUGE-L Recall': 0.3546705304928989, 'Token-Level Accuracy': 0.2298636812131819}
Epoch 8/10


Training: 100%|██████████| 594/594 [02:51<00:00,  3.47it/s]


Total training loss for epoch: 0.17958330031898287


Evaluating: 100%|██████████| 66/66 [01:47<00:00,  1.63s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.743172 |       2.1026 | 0.273745 |     0.299257 |            0.310565 |          0.30786 |               0.242748 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.7431724514022018, 'Perplexity': 2.102595326435502, 'BLEU': 0.27374548705959273, 'ROUGE-L F1': 0.2992565435491805, 'ROUGE-L Precision': 0.31056547619047614, 'ROUGE-L Recall': 0.3078602406048058, 'Token-Level Accuracy': 0.2427480355186944}
Epoch 9/10


Training: 100%|██████████| 594/594 [02:59<00:00,  3.31it/s]


Total training loss for epoch: 0.17552879851574849


Evaluating: 100%|██████████| 66/66 [01:39<00:00,  1.50s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.734266 |      2.08395 | 0.275009 |     0.616154 |             0.58632 |         0.683643 |                0.25073 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.7342657103683009, 'Perplexity': 2.083951206789657, 'BLEU': 0.2750088220682335, 'ROUGE-L F1': 0.6161542862651733, 'ROUGE-L Precision': 0.5863204197763021, 'ROUGE-L Recall': 0.6836425957942986, 'Token-Level Accuracy': 0.2507301745016528}
Epoch 10/10


Training: 100%|██████████| 594/594 [02:58<00:00,  3.32it/s]


Total training loss for epoch: 0.1735944439867129


Evaluating: 100%|██████████| 66/66 [01:46<00:00,  1.61s/it]


+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
|    |   Validation Loss |   Perplexity |     BLEU |   ROUGE-L F1 |   ROUGE-L Precision |   ROUGE-L Recall |   Token-Level Accuracy |
|----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------|
|  0 |          0.732179 |      2.07961 | 0.277023 |     0.394495 |            0.365138 |           0.4391 |               0.256147 |
+----+-------------------+--------------+----------+--------------+---------------------+------------------+------------------------+
Metrics: {'Validation Loss': 0.7321786979834238, 'Perplexity': 2.079606510106488, 'BLEU': 0.2770234858450497, 'ROUGE-L F1': 0.39449547742230673, 'ROUGE-L Precision': 0.3651382942000791, 'ROUGE-L Recall': 0.439100135975136, 'Token-Level Accuracy': 0.25614702294148145}


In [None]:
wandb.finish()

In [None]:
# free up unused GPU
import gc
import torch

gc.collect()
torch.cuda.empty_cache()