# Using RAG with Word files



In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### imports

In [2]:
!pip install llama-index
!pip install llama-index-embeddings-huggingface
!pip install peft
!pip install auto-gptq
!pip install optimum
!pip install bitsandbytes

Collecting llama-index
  Downloading llama_index-0.11.23-py3-none-any.whl.metadata (11 kB)
Collecting llama-index-agent-openai<0.4.0,>=0.3.4 (from llama-index)
  Downloading llama_index_agent_openai-0.3.4-py3-none-any.whl.metadata (728 bytes)
Collecting llama-index-cli<0.4.0,>=0.3.1 (from llama-index)
  Downloading llama_index_cli-0.3.1-py3-none-any.whl.metadata (1.5 kB)
Collecting llama-index-core<0.12.0,>=0.11.23 (from llama-index)
  Downloading llama_index_core-0.11.23-py3-none-any.whl.metadata (2.5 kB)
Collecting llama-index-embeddings-openai<0.3.0,>=0.2.4 (from llama-index)
  Downloading llama_index_embeddings_openai-0.2.5-py3-none-any.whl.metadata (686 bytes)
Collecting llama-index-indices-managed-llama-cloud>=0.3.0 (from llama-index)
  Downloading llama_index_indices_managed_llama_cloud-0.4.0-py3-none-any.whl.metadata (3.8 kB)
Collecting llama-index-legacy<0.10.0,>=0.9.48 (from llama-index)
  Downloading llama_index_legacy-0.9.48.post4-py3-none-any.whl.metadata (8.5 kB)
Collecti

In [3]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor

### Define Settings

In [4]:
# import any embedding model on HF hub (https://huggingface.co/spaces/mteb/leaderboard)
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Settings.embed_model = HuggingFaceEmbedding(model_name="thenlper/gte-large") # alternative model

Settings.llm = None             #  in this case, the system will only use an embedding model and not a full-fledged LLM
Settings.chunk_size = 1000      #  input text is being divided into chunks of 256 tokens/characters (depending on the implementation)
Settings.chunk_overlap = 250   # there is an overlap of 25 tokens/characters between consecutive chunks. Overlap helps maintain context continuity between chunks

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/94.8k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

LLM is explicitly disabled. Using MockLLM.


### Read and Store Docs into Vector DB

In [5]:
import os

# check if storage already exists
PERSIST_DIR = "/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/Storage"

if not os.path.exists(PERSIST_DIR):
    # load the documents and create the index
    documents = SimpleDirectoryReader("/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/documents").load_data()
    index = VectorStoreIndex.from_documents(documents)
    # store it for later
    index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
    # load the existing index
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)


### Set Up Search Function

In [6]:
# set number of docs to retreive
top_k = 1

# configure retriever
retriever = VectorIndexRetriever(
    index=index,                                # retrives the similar documents from created index
    similarity_top_k=top_k,
)

In [7]:
# assemble query engine
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.6)],
)

### Retrieve Relevant Docs

In [8]:
# query documents
query = "Where is Taleju Temple?"
response = query_engine.query(query)
response

Response(response='Context information is below.\n---------------------\npage_label: 2\nfile_path: /content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/documents/Taleju Temple-1.pdf\n\nThrough\nthese\ntales,\nthe\ntemple\ncontinues\nto\ninspire\nreverence\nand\nawe,\npreserving\nits\nplace\nas\na\npowerful\ncenter\nof\ndevotion\nin\nthe\nKathmandu\nValley.The\nTaleju\nTemple\nlocated\nin\nBasantapur,\nwithin\nthe\nKathmandu\nDurbar\nSquare,\nis\none\nof\nthe\nmost\nrevered\nand\nhistorically\nsignificant\ntemples\nin\nNepal.\nBuilt\nin\n1564\nby\nKing\nMahendra\nMalla,\nthis\ntemple\nis\ndedicated\nto\nTaleju\nBhawani,\nthe\nroyal\ndeity\nof\nthe\nMalla\ndynasty.\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: Where is Taleju Temple?\nAnswer: ', source_nodes=[NodeWithScore(node=TextNode(id_='c072e1ed-2866-4acc-b2af-461b0fbf5952', embedding=None, metadata={'page_label': '2', 'file_name': 'Taleju Temple-1.pdf', 'file_p

In [9]:
# Reformat response to form a paragraph
context = "Context:\n"
for i in range(top_k):
    # Append each text from the source nodes, stripping excess newlines and spaces
    context += response.source_nodes[i].text.replace("\n", " ").strip() + " "

# Remove extra spaces at the start/end and print the final formatted context
context = context.strip()

# Print the reformatted response as a paragraph
print(context)


Context:
Through these tales, the temple continues to inspire reverence and awe, preserving its place as a powerful center of devotion in the Kathmandu Valley.The Taleju Temple located in Basantapur, within the Kathmandu Durbar Square, is one of the most revered and historically significant temples in Nepal. Built in 1564 by King Mahendra Malla, this temple is dedicated to Taleju Bhawani, the royal deity of the Malla dynasty.


# Retrive the context


In [10]:
import pandas as pd
import numpy as np

In [13]:
df = pd.read_csv('/content/drive/MyDrive/VQA-Final/part B/cleaned_second_data.csv')
df

# If 'Unnamed: 0' column still exists, you can drop it
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
df

Unnamed: 0,Question,Object,Answer
0,what is a yali ?,yali,a yali is a mythical creature found predominan...
1,what animals is a yali typically composed of ?,yali,a yali is typically depicted as a composite of...
2,what attributes does a yali symbolize ?,yali,"a yali symbolizes attributes like strength, pr..."
3,why are yalis considered unique ?,yali,yalis are unique because they do not adhere to...
4,which cultures have similar mythical creatures...,yali,cultures that have similar mythical creatures ...
...,...,...,...
11973,embody yalis ever depicted as vehicles in anci...,yali,"in some artistic depictions, yalis are express..."
11974,constitute yalis allowed inside homes or templ...,yali,"yalis are primarily outside guardians, symboli..."
11975,what symbolic bid or 'food' were given to yalis ?,yali,yalis were honored with offerings such as flow...
11976,did people weigh yalis as protectors within th...,yali,"yes, yalis are viewed as protectors, often pla..."


In [14]:
df.shape

(11978, 3)

In [15]:
values = df['Object'].unique()
values

array(['yali', 'taleju temple', 'taleju bell', 'swet bhairava',
       'prayer wheel', 'nyatopola temple', 'kala bhairav', 'hanuman idol',
       'hanging pala', 'garuda', 'boudhanath', 'ankhi jhyal'],
      dtype=object)

In [16]:
# check for missing values

missing_values = df.isnull().sum()
missing_values

Unnamed: 0,0
Question,0
Object,0
Answer,0


In [None]:
# drop the missing values

df = df.dropna()

In [17]:
values = df['Object'].value_counts()
values

Unnamed: 0_level_0,count
Object,Unnamed: 1_level_1
boudhanath,1428
hanuman idol,1083
taleju bell,1075
nyatopola temple,1039
kala bhairav,1038
swet bhairava,1012
prayer wheel,1002
yali,997
garuda,969
hanging pala,915


In [18]:
# Placeholder function for retrieving context using RAG
def retrieve_context(question):
    # Here, you would implement your RAG querying logic
    # For example, using the query engine you've set up
    context = ""
    response = query_engine.query(question)  # Replace with actual querying method
    # Check if source_nodes exist in the response and get the number of available nodes
    num_nodes = len(response.source_nodes) if response.source_nodes else 0

    # Iterate only up to the available number of nodes or top_k, whichever is smaller
    for i in range(min(top_k, num_nodes)):
        # Append each text from the source nodes, stripping excess newlines and spaces
        context += response.source_nodes[i].text.replace("\n", " ").strip() + " "

    # Remove extra spaces at the start/end and print the final formatted context
    context = context.strip() # Adjust based on how your response object is structured
    return context

In [19]:
# Add a new column for context
df['Context'] = ''

# Iterate through each question and retrieve the context
for index, row in df.iterrows():
    question = row['Question']
    # Retrieve context for the current question
    context = retrieve_context(question)
    # Update the DataFrame with the retrieved context
    df.at[index, 'Context'] = context

# Save the updated DataFrame to a new CSV file
output_file_path = 'rag_second_dataset'  # Replace with your desired output file path
# Define an escape character (e.g., backslash)
df.to_csv("/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/new_Rag.csv", index=False, escapechar='\\')

print("Context added and saved to:", output_file_path)

Context added and saved to: rag_second_dataset


# BART QA with context Training

In [None]:
!pip install wandb
!pip install rouge-score

Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=b68c116efbef78f5a64a2c5eeeb831d04d5d2c3d659a587bd78cafc6ba931d32
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2


In [None]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import math
import wandb
from rouge_score import rouge_scorer
# from dotenv import load_dotenv
import os
from tabulate import tabulate
from nltk.translate.bleu_score import corpus_bleu
import sympy as sp
# load_dotenv()

In [None]:
df = pd.read_csv('/content/drive/MyDrive/VQA-Final/miscellanous dataset/RAG_Context/new_Rag.csv')

# # If 'Unnamed: 0' column still exists, you can drop it
# df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

# 80% -> Training Data, 20% -> Testing Data
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# 90% -> Training Data, 10% -> Validation Data
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

In [None]:
train_df

Unnamed: 0,Question,Object,Answer,Context
5503,What were the effects of Vayu's action of taki...,hanuman idol,The depletion of air resulted in significant a...,"Hanuman's father, Vayu, became upset and withd..."
6675,In what way did the Garudas first succeed in c...,garuda,"At one point, the Garudas managed to capture ...",The Garudas at one time caught the nāgas by se...
2161,What are the historical roots of the Kumari-Pu...,taleju temple,"Dating back to the 17th century, the traditio...",Whilst the veneration of a living Kumari in Ne...
2070,How did Taleju Bhawani influence the Malla rul...,taleju temple,She offered them spiritual mentorship and divi...,One popular legend recounts how Taleju Bhawani...
9113,In what way are large prayer wheels different ...,prayer wheel,Large prayer wheels are distinguished by their...,strength and many repetitions. We offer differ...
...,...,...,...,...
3896,How should one maintain their concentration wh...,prayer wheel,"It is recommended to focus body, speech, and ...","On rare occasions, advanced Tantric practition..."
10107,Can you recount the tale of Kala Bhairava’s en...,kala bhairava,"According to legend, Kala Bhairava arose durin...",As the two of them argued over their supremacy...
7331,How did Jadzima's journey progress after she f...,boudhanath,She passed on after successfully completing th...,"They were an extremely poor family, I think. S..."
1346,What are a few well-known types of Yalis ?,yali,Yalis are commonly seen in multiple variation...,It shares similarities with other mythical cre...


In [None]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

def calculate_max_length(column_name):
    df[column_name] = df[column_name].astype(str)
    return df[column_name].apply(lambda x: len(tokenizer.tokenize(x))).max()

max_length_question = calculate_max_length('Question')
max_length_object = calculate_max_length('Object')
max_length_context = calculate_max_length('Context')
max_length_answer = calculate_max_length('Answer')

print(f"Maximum token length in the 'Input' : {max_length_question + max_length_object + max_length_context}")
print(f"Maximum token length in the 'Output' : {max_length_answer}")

Maximum token length in the 'Input' : 180
Maximum token length in the 'Output' : 63


### Vocabulary Check

In [None]:
all_text = ' '.join(df['Question']) + ' ' + ' '.join(df['Answer']) + ' ' + ' '.join(df['Context'])
words = list(all_text.split())

len(words)

1574657

In [None]:
from collections import Counter

word_counts = Counter(all_text.split())
word_counts

In [None]:
unique_words = list(word_counts.keys())
unique_words

In [None]:
print(len(unique_words))

14929


In [None]:
def compare_words_with_bart_vocab(word_list):
    # Tokenize the words in the list
    tokenized_words = [token for word in word_list for token in tokenizer.tokenize(word)]

    # Get BART's vocabulary
    bart_vocab = set(tokenizer.get_vocab().keys())

    # Check which tokens are in or not in BART's vocabulary
    words_in_vocab = {word for word in tokenized_words if word in bart_vocab}
    words_not_in_vocab = {word for word in tokenized_words if word not in bart_vocab}

    return words_in_vocab, words_not_in_vocab

In [None]:
words_in_vocab, words_not_in_vocab = compare_words_with_bart_vocab(unique_words)

print("Words in BART vocabulary:", words_in_vocab)
print("Words not in BART vocabulary:", words_not_in_vocab)

Words not in BART vocabulary: set()


In [None]:
print(f"Words to be added: {len(words_not_in_vocab)}")

Words to be added: 0


In [None]:
def add_new_tokens_to_vocab(new_tokens):
    num_added_tokens = tokenizer.add_tokens(new_tokens)
    model.resize_token_embeddings(len(tokenizer))
    vocab = tokenizer.get_vocab()
    print("Vocabulary size:", len(vocab))

new_tokens = list(words_not_in_vocab)
num_added_tokens = add_new_tokens_to_vocab(new_tokens)

Vocabulary size: 50265


### Training

In [None]:
class QADataset(Dataset):
    '''For Loading the Dataset for Question Answering'''
    def __init__(self, data, tokenizer, question_max_length=200, context_max_length=200, answer_max_length=80):
        self.data = data
        self.tokenizer = tokenizer
        self.question_max_length = question_max_length
        self.context_max_length = context_max_length
        self.answer_max_length = answer_max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Extract the question, object (context), and answer from the dataset
        question = self.data.iloc[idx]['Question']
        object_name = self.data.iloc[idx]['Object']
        answer = self.data.iloc[idx]['Answer']
        context = self.data.iloc[idx]['Context']

        # Combine the object name (context) and question
        combined = 'Object :' + str(object_name) + ' Question: ' + str(question) + ' Context: ' + str(context)            # if object_name is the float

        # Tokenize the combined input (question + context)
        inputs = self.tokenizer(
            combined,
            max_length=self.context_max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        targets = self.tokenizer(
            answer,
            max_length=self.answer_max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        input_ids = inputs.input_ids.squeeze()
        attention_mask = inputs.attention_mask.squeeze()
        target_ids = targets.input_ids.squeeze()

        return {

            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': target_ids
        }

In [None]:
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
train_dataset = QADataset(train_df, tokenizer=tokenizer)
val_dataset = QADataset(val_df, tokenizer=tokenizer)
test_dataset = QADataset(test_df, tokenizer=tokenizer)

### Code changed with freezing the encoder layers, gradient accumulation and low learning rate

In [None]:
class VQA_Trainer:
    '''Class for Trainer Setup to Train the BART Model for VQA'''

    def __init__(self, model, train_dataloader, eval_dataloader, device, config, gradient_accumulation_steps=4, learning_rate=1e-5):
        ''' Constructor '''
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.device = device
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

        # Freeze initial encoder layers to prevent overfitting
        for param in self.model.model.encoder.parameters():
            param.requires_grad = False

        # Optimizer with a lower learning rate
        self.optimizer = AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=learning_rate)

        # Learning rate scheduler (adjust if necessary)
        total_steps = config["epochs"] * (len(train_dataloader) // gradient_accumulation_steps)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=total_steps)

        # WandB setup
        wandb.init(project=config['project_name'], config=config)
        wandb.watch(self.model, log="all")

        self.gradient_accumulation_steps = gradient_accumulation_steps

    def evaluate(self):
        ''' For Evaluation at the End of Each Epoch '''
        self.model.eval()
        total_loss = 0
        predictions, references, token_level_accuracies = [], [], []

        progress_bar = tqdm(self.eval_dataloader, desc="Evaluating")
        for batch in progress_bar:
            with torch.no_grad():
                inputs = {key: val.to(self.device) for key, val in batch.items()}
                outputs = self.model(**inputs)
                total_loss += outputs.loss.item()

                # Generate predictions and decode
                summary_ids = self.model.generate(inputs['input_ids'], min_length=10, max_length=100, num_beams=4, early_stopping=True)
                decoded_preds = self.tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

                labels = batch['labels']
                labels = torch.where(labels != -100, labels, self.tokenizer.pad_token_id)
                decoded_refs = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

                predictions.extend([pred.split() for pred in decoded_preds])
                references.extend([[ref.split()] for ref in decoded_refs])

                # Token-level accuracy
                for pred, ref in zip(decoded_preds, decoded_refs):
                    pred_tokens, ref_tokens = pred.split(), ref.split()
                    token_accuracy = sum(1 for p, r in zip(pred_tokens, ref_tokens) if p == r) / max(len(ref_tokens), 1)
                    token_level_accuracies.append(token_accuracy)

        avg_loss = total_loss / len(self.eval_dataloader)
        perplexity = math.exp(avg_loss)
        bleu_score = corpus_bleu(references, predictions)

        # ROUGE-L scores
        rouge_l_f1_scores, rouge_l_precision_scores, rouge_l_recall_scores = [], [], []
        for pred, ref in zip(decoded_preds, decoded_refs):
            rouge_l = self.scorer.score(ref, pred)['rougeL']
            rouge_l_f1_scores.append(rouge_l.fmeasure)
            rouge_l_precision_scores.append(rouge_l.precision)
            rouge_l_recall_scores.append(rouge_l.recall)

        avg_token_level_accuracy = np.mean(token_level_accuracies)
        avg_rouge_l_f1 = np.mean(rouge_l_f1_scores)
        avg_rouge_l_precision = np.mean(rouge_l_precision_scores)
        avg_rouge_l_recall = np.mean(rouge_l_recall_scores)

        # Log metrics
        metrics = {
            "Validation Loss": avg_loss,
            "Perplexity": perplexity,
            "BLEU": bleu_score,
            "ROUGE-L F1": avg_rouge_l_f1,
            "ROUGE-L Precision": avg_rouge_l_precision,
            "ROUGE-L Recall": avg_rouge_l_recall,
            "Token-Level Accuracy": avg_token_level_accuracy,
        }
        wandb.log(metrics)
        print(tabulate(pd.DataFrame([metrics]), headers="keys", tablefmt="psql"))

        return metrics

    def train_epoch(self):
        ''' To Train for Single Epoch with Gradient Accumulation '''
        self.model.train()
        total_loss = 0
        for i, batch in enumerate(tqdm(self.train_dataloader, desc="Training")):
            inputs = {key: val.to(self.device) for key, val in batch.items()}
            outputs = self.model(**inputs)
            loss = outputs.loss / self.gradient_accumulation_steps
            loss.backward()
            total_loss += loss.item()

            # Step only after gradient accumulation steps
            if (i + 1) % self.gradient_accumulation_steps == 0:
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            wandb.log({"train_loss": loss.item()})

        print(f"Total training loss for epoch: {total_loss / len(self.train_dataloader)}")

    def train(self, epochs):
        ''' To Train for N Number of Epochs Passed from User '''
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            self.train_epoch()
            metrics = self.evaluate()
            print(f"Metrics: {metrics}")

        torch.save(self.model.state_dict(), f"model_final.pth")


In [None]:
api_key = os.getenv('API_KEY')
!wandb login --relogin $api_key

Traceback (most recent call last):
  File "/usr/local/bin/wandb", line 8, in <module>
    sys.exit(cli())
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/cli/cli.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/cli/cli.py", line 253, in login
    wandb.login(
  File "/usr/local/lib/python3.1

In [None]:
config = {
     "batch_size":8,
     "epochs" : 10,
     "model_name": "facebook/bart-large-cnn",
     "project_name": "RAG_Conditional_BART_latest_dataset",
}

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BartForConditionalGeneration.from_pretrained(config['model_name']).to(device)

In [None]:
trainer = VQA_Trainer(model, train_dataloader, val_dataloader,device,config)



VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [None]:
trainer.train(config['epochs'])

Epoch 1/10


Training:  19%|█▉        | 191/1016 [01:29<06:28,  2.12it/s]


KeyboardInterrupt: 

In [None]:
wandb.finish()

In [None]:
# free up unused GPU
import gc
import torch

gc.collect()
torch.cuda.empty_cache()