# AthNLP 2024 Retrieval Augmentation

# Part 1 - Prompting
Get a Hugging Face access token, load a pre-trained quantized Llama 3.1, and use it to answer questions from PubMedQA (a small portion of the training set). Try again using [chain-of-thought prompting](https://arxiv.org/abs/2201.11903) and [self-explanations](https://aclanthology.org/2024.findings-acl.19/), [faithful or not](https://arxiv.org/pdf/2407.14487); each student should come up with their own prompt variants. Evaluate on a benchmark: we provide accuracy, and you need to implement [keystroke reduction](https://arxiv.org/pdf/2006.12040). Then try [self-consistency](https://arxiv.org/pdf/2203.11171) instead of greedy decoding, e.g. by ensembling the answers obtained with the different attempts (prompts). Finally, give the same prompts to ChatGPT for the first 10 examples and compare the results using both evaluation metrics.

To continue with this lab, you need to create a [HF access token](https://huggingface.co/docs/hub/en/security-tokens).

In [None]:
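import warnings
warnings.filterwarnings("ignore")

import os
import re
import random
import numpy as np
import torch
import evaluate
import matplotlib.pyplot as plt
from tqdm import tqdm
from huggingface_hub import login

# `LLM`, `MedRAG`, `DataLoaderPQAInferenceWithAugmentations`, `DMMLM` and `write_dict` are not
# defined in this notebook; they are assumed to come from the lab's helper code, loaded beforehand.

# Log in to Hugging Face with your access token (needed to download the gated Llama 3.1 weights)
os.environ["HF_ACCESS_TOKEN"] = "WRITE_YOUR_ACCESS_TOKEN_HERE"
login(token=os.environ["HF_ACCESS_TOKEN"], add_to_git_credential=True)

# Load a quantized Llama 3.1 8B Instruct and run a quick sanity check
llm = LLM(llm_name="meta-llama/Meta-Llama-3.1-8B-Instruct", quantized=True)
toy = llm.answer("When was superconductivity discovered?", num_beams=1)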

In [None]:
# Load the dataset (PubMedQA)
pubmedqa_train_data = DataLoaderPQAInferenceWithAugmentations(boundaries=[0, 990])
pubmedqa_test_data = DataLoaderPQAInferenceWithAugmentations(boundaries=[990, 1000])

# Create appropriate (torch) Data Loaders
train_loader = torch.utils.data.DataLoader(pubmedqa_train_data, batch_size=1, shuffle=False)
test_loader = torch.utils.data.DataLoader(pubmedqa_test_data, batch_size=1, shuffle=False)

save_dir = "./results"
os.makedirs(save_dir, exist_ok=True)  # make sure the output directory exists

In [None]:
# Prompt it to answer the questions (without extra information in the prompt).
# For this demo, only the first test question is answered.
predicted_llm = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    if i == 0:
        answer = llm.answer(question=question[0], num_beams=1)
        predicted_llm[qid[0]] = answer
write_dict(predicted_llm, save_dir, "answers_dict_llm_demo.txt")

predicted_llm[list(predicted_llm.keys())[0]]

In [None]:
# Prompt it to answer the questions, asking for a structured output.
# Structured (JSON-like) output lets us reliably locate the exact answer and the explanation in the generated text.
medrag_cot = MedRAG(llm_name="meta-llama/Meta-Llama-3.1-8B-Instruct", rag=False, retriever_name="MedCPT",
                    corpus_name="Textbooks", quantized=True)

predicted_cot1 = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    answer, snippets, scores = medrag_cot.answer(question=question[0], options=options, k=0, num_beams=1)
    predicted_cot1[qid[0]] = answer
write_dict(predicted_cot1, save_dir, "answers_dict_cot1_demo.txt")

predicted_cot1[list(predicted_cot1.keys())[0]]

In [None]:
# Prompt it with CoT
predicted_cot2 = {}

# ...
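One possible starting point for the CoT cell is a prompt that asks for the reasoning first and the answer second, in the same JSON layout that `eval_prediction` parses. This is a minimal sketch: the template wording, `build_cot_prompt` and `PUBMEDQA_OPTIONS` are hypothetical, and it uses the zero-shot "Let's think step by step" cue rather than the few-shot exemplars of the cited CoT paper.

```python
# Hypothetical option labels, consistent with `mapping = {'yes': 'A', 'no': 'B', 'maybe': 'C'}`
PUBMEDQA_OPTIONS = {"A": "yes", "B": "no", "C": "maybe"}


def build_cot_prompt(question, options=PUBMEDQA_OPTIONS):
    """Zero-shot chain-of-thought prompt asking for reasoning first, then the answer letter."""
    option_lines = "\n".join(f"{key}. {text}" for key, text in options.items())
    return (
        "Answer the following medical question.\n"
        f"Question: {question}\n"
        f"Options:\n{option_lines}\n"
        "Let's think step by step. Reply only with a JSON object of the form "
        '{"step_by_step_thinking": "...", "answer_choice": "A/B/C"}.'
    )


print(build_cot_prompt("Do mitochondria play a role in remodelling lace plant leaves?"))
```

Inside the loop you could then call something like `llm.answer(question=build_cot_prompt(question[0]), num_beams=1)`, assuming `llm.answer` passes the text to the model without adding its own instructions (check the helper code).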

In [None]:
# Define some helpers for model evaluation
import re


def clean(answer):
    """Keep only letters and periods; all other characters (including whitespace) are removed."""
    return re.sub(" +", "", re.sub(r'[^A-Za-z.]+', ' ', answer))


def eval_prediction(prediction):
    """Split each model output into its explanation ("step_by_step_thinking") and answer letter ("answer_choice")."""
    explanations, answers = {}, {}
    for qid, output in prediction.items():
        if isinstance(output, list):  # some generators return a list of outputs: use the first one
            output = output[0]
        elif not isinstance(output, str):
            raise NotImplementedError(f"Unsupported prediction type: {type(output)}")
        explanations[qid] = output.split("\"step_by_step_thinking\":")[-1].split("\"answer_choice\":")[0].strip()
        answers[qid] = clean(output.split("\"answer_choice\":")[-1].strip())[0]

    return explanations, answers

In [None]:
# Prompt it to produce self-explanations
predicted_exp = {}

# ...
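A self-explanation variant can ask for the answer first and the explanation afterwards (a post-hoc explanation), the reverse of the CoT order. Because `eval_prediction` splits on `"step_by_step_thinking"` appearing before `"answer_choice"`, a reply with the opposite key order is easier to read with a JSON parser. The sketch below is one way to do this, under the assumption that the model returns a single JSON object; `build_self_explanation_prompt` and `parse_structured_output` are hypothetical helpers, not part of the lab code.

```python
import json
import re


def build_self_explanation_prompt(question, options=None):
    """Answer-first prompt: the model commits to an answer, then explains it (hypothetical template)."""
    options = options or {"A": "yes", "B": "no", "C": "maybe"}
    option_lines = "\n".join(f"{key}. {text}" for key, text in options.items())
    return (
        "Answer the following medical question, then explain why you chose that answer.\n"
        f"Question: {question}\n"
        f"Options:\n{option_lines}\n"
        'Reply only with a JSON object of the form {"answer_choice": "A/B/C", "explanation": "..."}.'
    )


def parse_structured_output(text):
    """Return (answer_letter, explanation) from a JSON-like reply, whatever the key order.

    Falls back to a regular expression for the answer when the JSON is malformed;
    missing fields are returned as None.
    """
    answer, explanation = None, None
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)  # outermost {...} span in the reply
    if match:
        try:
            obj = json.loads(match.group(0))
            answer = str(obj.get("answer_choice", "")).strip()[:1] or None
            explanation = obj.get("explanation", obj.get("step_by_step_thinking"))
        except json.JSONDecodeError:
            pass
    if answer is None:
        found = re.search(r'"answer_choice"\s*:\s*"?\s*([A-Z])', text)
        answer = found.group(1) if found else None
    return answer, explanation


print(parse_structured_output('Sure! {"answer_choice": "B", "explanation": "No effect was observed."}'))
```

Comparing the answers from this prompt with the CoT answers on the same questions is one simple way to start probing whether the explanations are faithful.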

In [None]:
# Evaluate your results on a benchmark (accuracy / keystroke reduction)
explanations, answers = eval_prediction(predicted_cot1)

mapping = {'yes': 'A', 'no': 'B', 'maybe': 'C'}
ground_truth = {pubmedqa_test_data.qids[idx]: mapping[pubmedqa_test_data.choices[idx]]
                for idx in range(len(pubmedqa_test_data.qids))}
accuracy = sum(answers.get(qid) == gt for qid, gt in ground_truth.items()) / len(ground_truth)

keystroke_reduction = {}
# ...
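Keystroke reduction is left for you to implement. The sketch below uses one simple proxy, which may differ from the exact definition in the cited paper: assume the user starts from the model's text and edits it into the reference, with each character insertion, deletion or substitution costing one keystroke. The reduction is then the fraction of the reference's characters that did not have to be typed, floored at 0. `levenshtein` and `keystroke_reduction` are hypothetical names.

```python
def levenshtein(a, b):
    """Character-level edit distance: insertions, deletions and substitutions each cost 1."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # delete ca
                            curr[j - 1] + 1,           # insert cb
                            prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = curr
    return prev[-1]


def keystroke_reduction(prediction, reference):
    """Share of the reference's keystrokes saved by editing the prediction instead of typing from scratch."""
    if not reference:
        return 0.0
    return max(0.0, 1.0 - levenshtein(prediction, reference) / len(reference))


print(keystroke_reduction("The answer is yes.", "The answer is no."))
```

Averaged over the test set, e.g. comparing each `explanations[qid]` with the gold long answer (assuming the test records expose `LONG_ANSWER` the same way the training records do in Part 2), this gives a score between 0 and 1 to report next to accuracy.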

In [None]:
# Implement self-consistency and compare the results
predicted_selfcon = {}

# ...
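Self-consistency replaces a single greedy decode with a majority vote over several answers to the same question. In the original paper these answers come from sampled reasoning paths; here the lab suggests the answers obtained with the different prompts. A minimal sketch of the vote, assuming each run has already been reduced to a `qid -> answer letter` dict (e.g. the `answers` returned by `eval_prediction` for each prompt); `self_consistency` is a hypothetical helper.

```python
from collections import Counter


def self_consistency(runs):
    """Majority vote per question over several runs.

    runs: list of dicts mapping qid -> answer letter, one dict per prompt or sampled generation.
    Ties go to the answer encountered first, i.e. the one from the earliest run in `runs`.
    """
    voted = {}
    for qid in runs[0]:
        votes = Counter(run[qid] for run in runs if qid in run)
        voted[qid] = votes.most_common(1)[0][0]
    return voted


print(self_consistency([{"q1": "A", "q2": "B"}, {"q1": "A", "q2": "C"}, {"q1": "B", "q2": "C"}]))
```

Ordering `runs` from the most to the least trusted prompt makes the tie-breaking rule meaningful.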

In [None]:
# Prompt ChatGPT and compare the results
predicted_gpt4o = {}

# ...
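To send the same prompts to ChatGPT, one option is the OpenAI Python SDK (v1 or later) and its Chat Completions endpoint. This sketch assumes `pip install openai` and an `OPENAI_API_KEY` environment variable; `to_chat_messages` and `ask_chatgpt` are hypothetical helpers, and `temperature=0` is chosen to keep the decoding close to greedy.

```python
def to_chat_messages(prompt, system=None):
    """Wrap a prompt in the list-of-messages format used by the Chat Completions API."""
    messages = []
    if system is not None:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    return messages


def ask_chatgpt(prompt, model="gpt-4o", system=None):
    """Send one prompt and return the reply text (requires openai>=1.0 and OPENAI_API_KEY)."""
    from openai import OpenAI  # imported lazily so the rest of the notebook runs without the package

    client = OpenAI()  # reads the API key from the OPENAI_API_KEY environment variable
    response = client.chat.completions.create(
        model=model,
        messages=to_chat_messages(prompt, system),
        temperature=0,
    )
    return response.choices[0].message.content
```

Since the test split covers exactly 10 questions, looping over `test_loader` and storing `ask_chatgpt(prompt)` in `predicted_gpt4o[qid[0]]`, with the same prompt strings you gave to Llama, yields replies that `eval_prediction` can parse when they follow the requested JSON format.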

# Part 2 - Learning (in context)
Prompt the LLM again, this time using training examples as few-shot demonstrations selected in different ways: **(a)** retrieve similar questions from the training data using cosine similarity; **(b)** prompt an LLM (e.g. ChatGPT or Llama 3.1) to predict concepts from the provided tag set, then repeat the previous step by retrieving demonstrations based on these tags (assuming that similar examples share the same tags). In each scenario, add the retrieved examples, together with their answers, to the context as demonstrations. We give an example that retrieves similar examples using the gold (human-authored) tags; you have to follow the same methodology (i.e. [MedRAG](https://teddy-xionggz.github.io/benchmark-medical-rag/)), or use simple prompting, with cosine similarity and with tags predicted by an LLM (as tagger).
- Evaluate again using accuracy (provided) and [keystroke reduction](https://arxiv.org/pdf/2006.12040) (implemented above), and compare to vanilla prompting.
- Visualise performance against the number of shots (the number of training examples included in the prompt as demonstrations). Profile the time needed and think of possible efficiency tricks (e.g. how efficiently can cosine similarity be computed with matrix multiplications on the accelerator, and how can this be improved?).

In [None]:
# Retrieve from the training data similar questions using cosine similarity.
neighbors = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    neighbors[qid[0]] = {}  # train qid -> similarity to this test question
    for j, (question_neig, options_neig, concepts_neig, answers_neig, qid_neig) in enumerate(train_loader, 0):
        sim = 0
        # ...
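Instead of comparing one pair at a time inside the nested loop, the cosine similarities between all test and training questions can be computed with a single matrix multiplication once every embedding is L2-normalised. A minimal NumPy sketch, assuming you already have one embedding per question from some text encoder (for example `SentenceTransformer(...).encode(...)` from the sentence-transformers library, or the MedCPT encoder used by MedRAG); `top_k_cosine` is a hypothetical helper.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each row to unit length (all-zero rows stay zero)."""
    x = np.asarray(x, dtype=np.float64)
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def top_k_cosine(query_emb, train_emb, k=3):
    """For each query row, return indices and cosine scores of the k most similar training rows."""
    sims = l2_normalize(query_emb) @ l2_normalize(train_emb).T  # (n_queries, n_train) in one matmul
    k = min(k, sims.shape[1])
    idx = np.argsort(-sims, axis=1, kind="stable")[:, :k]  # highest similarity first
    return idx, np.take_along_axis(sims, idx, axis=1)


idx, scores = top_k_cosine([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]], k=2)
print(idx, scores)
```

Row `i` of `idx` holds positions in the training set, which can be mapped back to question ids (assuming the training data exposes `qids` like the test data) to fill `neighbors`.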

In [None]:
# Prompt an LLM to predict concepts from the tag set.
concepts_unique = sorted({concept
                          for instance in pubmedqa_train_data.concepts + pubmedqa_test_data.concepts
                          for concept in instance})

# ...
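To use an LLM as a tagger, the prompt can list the allowed tags and ask for a comma-separated subset; the reply is then filtered so that only tags from the tag set survive. A minimal sketch with hypothetical helpers `build_tagging_prompt` and `parse_tags`; if the tag set is too large for the model's context window, you may need to shorten it (e.g. keep only the most frequent tags).

```python
def build_tagging_prompt(question, tag_set, max_tags=5):
    """Ask the model to pick up to `max_tags` concepts from a fixed tag set (hypothetical template)."""
    tags = ", ".join(sorted(tag_set))
    return (
        f"Choose up to {max_tags} concepts from the following tag set that describe the question.\n"
        f"Tag set: {tags}\n"
        f"Question: {question}\n"
        "Answer with a comma-separated list of tags only."
    )


def parse_tags(response, tag_set):
    """Keep only tags that belong to the tag set (case-insensitive), without duplicates, in reply order."""
    valid = {tag.lower(): tag for tag in tag_set}
    predicted = []
    for candidate in response.split(","):
        key = candidate.strip().strip(".").lower()
        if key in valid and valid[key] not in predicted:
            predicted.append(valid[key])
    return predicted


print(parse_tags("Humans, aged., Unknown, humans", ["Humans", "Aged"]))
```

For example, `parse_tags(llm.answer(question=build_tagging_prompt(question[0], concepts_unique), num_beams=1), concepts_unique)` would give the predicted tags for one test question, assuming `llm.answer` returns the reply as a string.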

In [None]:
# Repeat the previous step by retrieving demonstrators based on the tags.
neighbors = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    test_tags = set(c[0] for c in concepts)
    neighbors[qid[0]] = []
    for j, (question_neig, options_neig, concepts_neig, answers_neig, qid_neig) in enumerate(train_loader, 0):
        # Keep training examples that share at least one tag with the test question
        if test_tags & set(c[0] for c in concepts_neig):
            neighbors[qid[0]].append(qid_neig[0])
    neighbors[qid[0]] = list(set(neighbors[qid[0]]))

neighbors[list(neighbors.keys())[0]]

In [None]:
# For each of the above scenarios, use the answers to add demonstrators in the context.
# We demonstrate retrieving similar examples using gold (human-authored) tags,
# you have to follow the same methodology using cosine similarity and tags
# predicted by an LLM (as tagger).
num_shots = 3
predicted_fs = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    # Sample up to num_shots demonstrations among the tag-based neighbours
    candidates = neighbors[qid[0]]
    shots_selected = random.sample(candidates, min(num_shots, len(candidates)))
    query = ""
    for shot in shots_selected:
        record = pubmedqa_train_data.data.data[shot]
        # Format each demonstration in the same JSON layout the model is asked to produce
        query += (record['QUESTION'] + " {\"step_by_step_thinking\": \"" + record['LONG_ANSWER'] +
                  "\", \"answer_choice\": \"" + mapping[record['final_decision']] + "\"}\n \n")
    query += question[0]

    # Pass the query with the demonstrations (not just the bare question) to the model
    answer, snippets, scores = medrag_cot.answer(question=query, options=options, k=0, num_beams=1)
    predicted_fs[qid[0]] = answer

predicted_fs[list(predicted_fs.keys())[0]]

# Use a tagger (e.g. LLM-based) to label DB and test;
# retrieve related training records based on the assigned tags;
# compute Keystroke Reduction as an evaluation metric
# ...

In [None]:
# Evaluate and compare to vanilla prompting
explanations, answers = eval_prediction(predicted_fs)

mapping = {'yes': 'A', 'no': 'B', 'maybe': 'C'}
ground_truth = {pubmedqa_test_data.qids[idx]: mapping[pubmedqa_test_data.choices[idx]]
                for idx in range(len(pubmedqa_test_data.qids))}
accuracy_fs = sum(answers.get(qid) == gt for qid, gt in ground_truth.items()) / len(ground_truth)

# ...

titles = ["Vanilla prompting", "Few shot (3 random shots from gold tags)"]
metrics = [accuracy, accuracy_fs] # ...

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.bar(titles, metrics, color='maroon', width=0.4)
plt.xlabel("Method")
plt.ylabel("Accuracy")
plt.title("Factuality comparison (in terms of accuracy)")
plt.show()

# ...

In [None]:
# Visualise performance against the number of shots
import matplotlib.pyplot as plt

# ...
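A line plot of a metric against the number of shots only needs the list of shot counts and the corresponding scores, e.g. obtained by rerunning the few-shot loop above for each value of `num_shots` and computing `accuracy_fs` each time. A minimal sketch with a hypothetical helper `plot_metric_vs_k`, which returns the figure so that you can call `plt.show()` or save it.

```python
import matplotlib.pyplot as plt


def plot_metric_vs_k(ks, scores, xlabel="Number of shots", ylabel="Accuracy", label=None):
    """Line plot of a metric against the number of shots (or of retrieved records)."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(ks, scores, marker="o", label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(list(ks))
    ax.grid(True, alpha=0.3)
    if label is not None:
        ax.legend()
    return fig
```

The same helper can be reused for the plot against the number of retrieved records in Part 3 by passing `xlabel="Number of retrieved records"`.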

In [None]:
# Profile the time needed (efficiency tricks: cosine/dot)
import time

# ...
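A quick way to see why the matrix formulation matters is to time a pair-by-pair cosine loop against normalising once and doing a single matrix multiplication. This sketch uses random vectors as stand-ins for real embeddings (the sizes mirror the 990/10 split; the dimension 768 is an assumption), and `cosine_loop`, `cosine_matmul` and `profile` are hypothetical names.

```python
import time

import numpy as np


def cosine_loop(q, t):
    """Naive version: one cosine per (query, train) pair, computed in a Python loop."""
    sims = np.empty((len(q), len(t)), dtype=np.float32)
    for i, qi in enumerate(q):
        for j, tj in enumerate(t):
            sims[i, j] = qi @ tj / (np.linalg.norm(qi) * np.linalg.norm(tj))
    return sims


def cosine_matmul(q, t):
    """Vectorised version: normalise each row once, then a single matrix multiplication."""
    qn = q / np.linalg.norm(q, axis=1, keepdims=True)
    tn = t / np.linalg.norm(t, axis=1, keepdims=True)
    return qn @ tn.T


def profile(fn, *args, repeats=3):
    """Best wall-clock time in seconds over a few runs."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


rng = np.random.default_rng(0)
queries = rng.standard_normal((10, 768), dtype=np.float32)  # stand-in for test question embeddings
train = rng.standard_normal((990, 768), dtype=np.float32)   # stand-in for training question embeddings
print(f"loop:   {profile(cosine_loop, queries, train):.4f} s")
print(f"matmul: {profile(cosine_matmul, queries, train):.4f} s")
```

Further tricks in the same spirit: normalise the training matrix once and cache it, move the matmul to the GPU (e.g. with `torch` in half precision), and switch to an approximate nearest-neighbour index when the collection grows.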

# Part 3 - Retrieval Augmentation (frozen and with guided decoding)
Set up a FAISS index with dense representations of medical textbooks (no labels) and use it to update the prompt with external information. We provide an implementation example based on the [MedRAG toolkit](https://teddy-xionggz.github.io/benchmark-medical-rag/). Since indexing and retrieval take time, we also provide answers from the LLM obtained with different numbers of retrieved documents and with other document collections (PubMed articles). Evaluate the retrieval-augmented LLM's responses again, compare them to vanilla and few-shot prompting, and visualise performance against the number of retrieved records. Instead of frozen retrieval augmentation, try again with guided decoding based on the tags' representations, where the beam decoder is influenced by the tags' vector representations when producing an output sequence, as described in the [DMMCS framework](https://aclanthology.org/2024.findings-acl.444/). Retrieve related training records based on the assigned tags (as in Part 2), using either the gold tags or those generated by the LLM tagger developed previously, and apply DMMCS by retrieving records with similar tags. Measure the correlation between answers.
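Under the hood, a dense index of this kind stores one normalised vector per document and returns the documents with the highest inner product (i.e. cosine similarity) to the query vector. The sketch below is not MedRAG's code: it assumes document and query embeddings are already computed (e.g. with MedCPT), uses FAISS's exact `IndexFlatIP` when `faiss` is installed, and falls back to the equivalent NumPy search otherwise. `DenseIndex` is a hypothetical class.

```python
import numpy as np

try:
    import faiss  # pip install faiss-cpu (or faiss-gpu)
except ImportError:  # fall back to exact NumPy search so the sketch still runs
    faiss = None


def _normalize(x):
    """Copy to contiguous float32 and scale each row to unit length."""
    x = np.ascontiguousarray(x, dtype=np.float32).copy()
    x /= np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)
    return x


class DenseIndex:
    """Exact inner-product index over L2-normalised vectors (equivalent to cosine similarity)."""

    def __init__(self, doc_embeddings):
        self.docs = _normalize(doc_embeddings)
        self.index = None
        if faiss is not None:
            self.index = faiss.IndexFlatIP(self.docs.shape[1])
            self.index.add(self.docs)

    def search(self, query_embeddings, k=3):
        """Return (scores, ids), each of shape (n_queries, k), best match first."""
        q = _normalize(query_embeddings)
        if self.index is not None:
            return self.index.search(q, k)
        sims = q @ self.docs.T
        ids = np.argsort(-sims, axis=1, kind="stable")[:, :k]
        return np.take_along_axis(sims, ids, axis=1), ids


index = DenseIndex(np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
print(index.search(np.array([[1.0, 0.1]]), k=2))
```

The returned ids point to the snippets whose text is then inserted into the prompt; the number `k` is the "number of retrieved records" to vary in the experiments below.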

In [None]:
# Setup a DB with dense representations (no labels)
medrag = MedRAG(llm_name="meta-llama/Meta-Llama-3.1-8B-Instruct", rag=True, retriever_name="MedCPT",
                corpus_name="Textbooks", quantized=True)

In [None]:
# Use it to update the prompt with external information (the k retrieved snippets)
num_snippets = 3
predicted_rag = {}
for i, (question, options, concepts, answers, qid) in tqdm(enumerate(test_loader, 0)):
    answer, snippets, scores = medrag.answer(question=question[0], options=options, k=num_snippets, num_beams=1)
    predicted_rag[qid[0]] = answer
write_dict(predicted_rag, save_dir, "answers_dict_rag_demo.txt")

predicted_rag[list(predicted_rag.keys())[0]]

In [None]:
# Assess compared to vanilla and few-shot
explanations, answers = eval_prediction(predicted_rag)

mapping = {'yes': 'A', 'no': 'B', 'maybe': 'C'}
ground_truth = {pubmedqa_test_data.qids[idx]: mapping[pubmedqa_test_data.choices[idx]]
                for idx in range(len(pubmedqa_test_data.qids))}
accuracy_rag = sum(answers.get(qid) == gt for qid, gt in ground_truth.items()) / len(ground_truth)

# ...

titles = ["Vanilla prompting", "Few shot (3 random shots from gold tags)", "RAG (3 retrieved snippets)"]
metrics = [accuracy, accuracy_fs, accuracy_rag] # ...

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.bar(titles, metrics, color='maroon', width=0.4)
plt.xlabel("Method")
plt.ylabel("Accuracy")
plt.title("Factuality comparison (in terms of accuracy)")
plt.show()

# ...

In [None]:
# Visualise performance compared to # retrieved records
import matplotlib.pyplot as plt

# ...

In [None]:
# Apply DMMCS by retrieving records with similar tags.
!python datagen_dmmcs.py
!python stats_extraction.py --config ../config/stats_extractor_config_pubmedqa.json

dmm_config = {
    "dataset_name": "pubmedqa_20ex",
    "dataset_concepts_mapper": "./pubmedqa_dmmcs_data/mapping.csv",
    "hist_file_path": "./snapshots/artifacts/hist_train.pkl",
    "mmc_sim_file_path": "./snapshots/artifacts/median_max_cos_c.pkl",
    "word_index_path": "./snapshots/artifacts/word_index.pkl",
    "embedding_matrix_path": "./snapshots/artifacts/embedding_matrix.npy",
    "n_gpu": 1,
    "cuda_nr": 0,
    "seed": 42,
    "num_workers_test": 4,
    "dmmcs_params": {
                      "do_dmmcs": True,
                      "alpha": 0.15
                    },
    "generation_params":  {
                          "do_sample": False,
                          "num_beams": 1,
                          "max_length": 8192,
                          "min_length": 5
                          },
    "logging":  {
                "print_on_screen": True
                }
}

# Instantiate the DMMCS-guided Llama 3.1 model with the config above
llama_dmmcs_config = DMMLM(dmm_config, llm_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
                           token=os.environ["HF_ACCESS_TOKEN"])

# Run!
predicted_dmmcs, actuals = llama_dmmcs_config.generate(test_loader, llama_dmmcs_config.config["hist_file_path"],
                                                   llama_dmmcs_config.config["mmc_sim_file_path"])

write_dict(predicted_dmmcs, save_dir, "answers_dict_dmmcs_demo.txt")

predicted_dmmcs[list(predicted_dmmcs.keys())[0]]

In [None]:
# Measure the correlation between answers.
class Correlation:
    """Agreement between two sets of texts, measured with BERTScore F1."""

    def __init__(self, device='cuda:0'):
        self.metrics = evaluate.load("bertscore")
        self.device = device

    def evaluate(self, refs, hyps):
        """Return the mean BERTScore F1 of the hypotheses `hyps` against the references `refs`."""
        self.metrics.add_batch(predictions=hyps, references=refs)
        results = self.metrics.compute(model_type='albert-base-v2', num_layers=5, all_layers=False, idf=False,
                                       lang='en', rescale_with_baseline=True, baseline_path=None,
                                       device=self.device)
        return np.mean(results["f1"])

# ...
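BERTScore compares the wording of two texts; for the answer letters themselves, a categorical agreement measure is another option. The lab does not prescribe one, so the sketch below uses raw agreement plus Cohen's kappa, which corrects for the agreement expected by chance. It assumes two `qid -> answer letter` dicts, e.g. the `answers` returned by `eval_prediction` for two methods, saved under different names; `answer_agreement` is a hypothetical helper.

```python
from collections import Counter


def answer_agreement(answers_a, answers_b):
    """Raw agreement and Cohen's kappa between two dicts mapping qid -> answer letter."""
    qids = sorted(set(answers_a) & set(answers_b))
    if not qids:
        raise ValueError("The two dicts share no question ids")
    a = [answers_a[q] for q in qids]
    b = [answers_b[q] for q in qids]
    n = len(qids)
    observed = sum(x == y for x, y in zip(a, b)) / n
    count_a, count_b = Counter(a), Counter(b)
    # Chance agreement: probability that both methods pick the same label independently
    expected = sum(count_a[c] * count_b[c] for c in set(a) | set(b)) / (n * n)
    kappa = 1.0 if expected == 1 else (observed - expected) / (1 - expected)
    return observed, kappa


print(answer_agreement({"q1": "A", "q2": "B", "q3": "A", "q4": "B"},
                       {"q1": "A", "q2": "A", "q3": "B", "q4": "B"}))
```

A kappa close to 1 means two methods give the same answers far more often than chance; with only 10 test questions, treat the value as indicative.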

# Part 4 - Concept-based image captioning
Download the [IU-XRAY dataset](https://www.kaggle.com/datasets/raddar/chest-xrays-indiana-university) for medical image captioning ([diagnostic captioning](https://arxiv.org/pdf/2101.07299)) and use DMMCS for guided decoding of a vision-language model (e.g. IDEFICS-2, IDEFICS-3, OpenFlamingo), based on [our implementation](https://github.com/nlpaueb) of the framework on InstructBLIP. Measure BERTScore and compare LM perplexities.
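Perplexity is the exponential of the mean negative log-likelihood per token, so it can be computed from the log-probabilities the language model assigns to the tokens of a caption. A minimal sketch with a hypothetical helper `perplexity`, assuming natural-log token log-probabilities have already been extracted. With a Hugging Face causal LM, passing `labels=input_ids` makes the forward pass return `loss`, the mean token cross-entropy, so `torch.exp(loss)` gives the same quantity; for captioning you would typically set the labels of the prompt/image tokens to -100 so that only the caption tokens are scored.

```python
import math


def perplexity(token_logprobs):
    """exp(mean negative log-likelihood) over the scored tokens (natural-log probabilities)."""
    if not token_logprobs:
        raise ValueError("need at least one token")
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)


print(perplexity([math.log(0.5)] * 4))  # every token has probability 0.5
```

Lower perplexity means the LM finds the caption more likely; comparing it for captions decoded with and without DMMCS shows how far guided decoding moves the outputs away from what the LM would produce on its own.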

In [None]:
%cd ..
!git clone https://github.com/nlpaueb/dmmcs.git
%cd dmmcs

# ...