# Evaluating RAG (Actually Reverser-Reverse RAG)

In [1]:
import os
import json
import dspy
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dspy import ColBERTv2

llama = dspy.GROQ(model="llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"])
turbo = dspy.OpenAI(model="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"])
mini = dspy.OpenAI(model="gpt-4o-mini", api_key=os.environ["OPENAI_API_KEY"])

colber_wiki = ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")

dspy.settings.configure(lm=mini, rm=colber_wiki)

In [4]:
trivia_df = pd.read_csv("./input_data/pollution_trivia_gpt4omini.csv")
trivia_df = trivia_df[['question', 'answer']]
trivia_df.head()

Unnamed: 0,question,answer
0,What is the primary organ affected by lead poi...,Brain
1,What is the primary condition caused by lead p...,Avian plumbism
2,What type of pollution is commonly associated ...,Water contamination
3,What is the main environmental concern with hy...,Water contamination
4,What is the name of the famous oil spill that ...,Exxon Valdez


In [5]:
train_set = trivia_df.iloc[:-35].to_dict(orient="records")
test_set = trivia_df.iloc[-35:].to_dict(orient="records")

print(len(train_set), len(test_set))

train_set = [dspy.Example(question=ex["question"], answer=ex["answer"]).with_inputs('question') for ex in train_set]
test_set = [dspy.Example(question=ex["question"], answer=ex["answer"]).with_inputs('question') for ex in test_set]



81 35


In [6]:
from dsp.utils import deduplicate
from pydantic import BaseModel

class GenerateAnswer(dspy.Signature):
    """ Answer question with short factual answer """
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")
    
class GenrerateSearchQuestion(dspy.Signature):
    """ Write a simple search question that will help answer a complex query """
    question = dspy.InputField()
    query = dspy.OutputField(desc="Name of a concept, entity, event or topic that can be used to search for more information")
    
class RAG(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_query = dspy.Predict(GenrerateSearchQuestion)
        self.retrieve = dspy.Retrieve(k=2)
        self.generate_answer = dspy.Predict(GenerateAnswer)
        
    def forward(self, question):
        query = self.generate_query(question=question).query
        context = self.retrieve(query).passages
        answer = self.generate_answer(question=question, context=context).answer
        return dspy.Prediction(
            query=query, 
            context=context, 
            answer=answer
        )

    


In [7]:
class JudgeQA(dspy.Signature):
    """ Given the question, determine if the two answers below mean the same thing """
    question = dspy.InputField()
    answer1 = dspy.InputField()
    answer2 = dspy.InputField()
    is_same = dspy.OutputField(desc="Yes or No")

import random
def t():
    return dict(temperature=0.7 + 0.0001 * random.uniform(-1, 1))

judge = dspy.ChainOfThought(JudgeQA, **t()) # tweaking temperature to avoid caching and can try different llms

def metric(example, pred, trace = None):
    return judge(question=example.question, answer1=example.answer, answer2=pred.answer).is_same.lower() == 'yes'


## How does llama perform on the trivia quiz

In [48]:
from dspy.evaluate import Evaluate

evaluate = Evaluate(
    metric = metric,
    devset = test_set,
    display_progress = True,
    display_table = 10,
    return_outputs=True,
)
eval = evaluate(RAG())
print(eval)

                                                                                

ERROR:dspy.evaluate.evaluate:2024-08-12T14:33:41.184618Z [error    ] Error for example in dev set: 		 Connection error. [dspy.evaluate.evaluate] filename=evaluate.py lineno=180


                                                                                  

ERROR:dspy.evaluate.evaluate:2024-08-12T14:33:44.009086Z [error    ] Error for example in dev set: 		 Connection error. [dspy.evaluate.evaluate] filename=evaluate.py lineno=180


                                                                                  

ERROR:dspy.evaluate.evaluate:2024-08-12T14:33:46.983327Z [error    ] Error for example in dev set: 		 Connection error. [dspy.evaluate.evaluate] filename=evaluate.py lineno=180


                                                                                  

ERROR:dspy.evaluate.evaluate:2024-08-12T14:33:49.942664Z [error    ] Error for example in dev set: 		 Connection error. [dspy.evaluate.evaluate] filename=evaluate.py lineno=180


Average Metric: 18.0 / 32  (56.2):  91%|█████████▏| 32/35 [03:22<00:31, 10.59s/it]

APIConnectionError: Connection error.

## How does gpt perform on the trivia quiz

In [None]:
from dspy.evaluate import Evaluate

evaluate = Evaluate(
    metric = metric,
    devset = test_set,
    display_progress = True,
    display_table = 10,
    return_outputs=True,
)
eval = evaluate(RAG())
print(eval)

### Results; Llama: 40.0, Mini: 42.9


## Evaluating on T5 Model (without fine tuning)

In [53]:
import dspy 
from dspy import ColBERTv2

t5 = dspy.HFModel(model="google-t5/t5-base")
colber_wiki = ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")

dspy.settings.configure(lm=t5, rm=colber_wiki)

Average Metric: 18.0 / 32  (56.2):  91%|█████████▏| 32/35 [1:15:01<07:02, 140.68s/it]




In [54]:
from rich import print as rprint

t5_rag = RAG()
query = "What is the herbicide that endocrine disruptor, according to Tyrone Hayes' research?"
actual_answer = "Atrazine"

answer = t5_rag(question=query)

rprint("Question: ", query)
rprint("Actual Answer: ", actual_answer)
print()
rprint("Input Qesry: ", answer.query)
rprint("Context", answer.context[0])
rprint("Predicted Answer: ", answer.answer)




### Pretty Bad huh

## Fine Tuning T5-base with the outputs of GPT 4o Mini

In [8]:
from dspy.teleprompt import  BootstrapFinetune

tp = BootstrapFinetune(metric=None)

unlabeled_train_data = [dspy.Example(question=x["question"]).with_inputs('question') for x in train_set]
gpt_rag = RAG()

t5_base = dspy.HFModel(model="google-t5/t5-base")
t5_base.device = 'mps'

tf_rag = RAG()
for p in tf_rag.predictors():
    p.lm = t5_base

configurations = {
    "target": 't5-base',
    "epochs": 10,
    "bsize": 8,
    "accumsteps": 2,
    "lr": 5e-5,
}

rag_finetuned = tp.compile(
    student = tf_rag,
    teacher = gpt_rag,
    trainset = unlabeled_train_data,
    **configurations,
)

100%|██████████| 81/81 [00:00<00:00, 110.24it/s]


Bootstrapped 81 full traces after 81 examples in round 0.
all 162
local_cache/compiler\all.0d05e675e6e2c3c0.jsonl


Map: 100%|██████████| 162/162 [00:00<00:00, 5225.78 examples/s]


# examples skipped due to parsing error: 0 / 162


Filter: 100%|██████████| 162/162 [00:00<00:00, 7955.38 examples/s]
Map: 100%|██████████| 162/162 [00:00<00:00, 1814.49 examples/s]
Map: 100%|██████████| 162/162 [00:00<00:00, 7962.84 examples/s]
Map: 100%|██████████| 162/162 [00:00<00:00, 3097.69 examples/s]


Dataset statistics: {'max_source_length': 33, 'max_target_length': 40}
Keys of tokenized dataset: ['prompt', 'completion', 'input_ids', 'attention_mask', 'labels']
Finetuning dataset: DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 145
    })
    test: Dataset({
        features: ['prompt', 'completion', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 17
    })
})


 10%|█         | 9/90 [06:54<54:41, 40.51s/it]  
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:07<00:03,  3.98s/it]
                                              
                                             

{'eval_loss': 2.9240455627441406, 'eval_rouge1': 21.5216, 'eval_rouge2': 5.6863, 'eval_rougeL': 20.3792, 'eval_rougeLsum': 20.2869, 'eval_gen_len': 8.882352941176471, 'eval_runtime': 18.0397, 'eval_samples_per_second': 0.942, 'eval_steps_per_second': 0.166, 'epoch': 0.95}


 10%|█         | 9/90 [07:15<54:41, 40.51s/it]
100%|██████████| 3/3 [00:10<00:00,  2.79s/it]
 21%|██        | 19/90 [13:37<32:27, 27.44s/it]  
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:05<00:02,  2.89s/it]
                                               
                                             

{'eval_loss': 2.6114232540130615, 'eval_rouge1': 21.9608, 'eval_rouge2': 7.451, 'eval_rougeL': 21.4146, 'eval_rougeLsum': 21.3809, 'eval_gen_len': 6.294117647058823, 'eval_runtime': 14.8666, 'eval_samples_per_second': 1.144, 'eval_steps_per_second': 0.202, 'epoch': 2.0}


 21%|██        | 19/90 [13:52<32:27, 27.44s/it]
100%|██████████| 3/3 [00:07<00:00,  2.08s/it]
 31%|███       | 28/90 [18:35<33:02, 31.98s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:06<00:03,  3.35s/it]
                                               
                                             

{'eval_loss': 2.2309906482696533, 'eval_rouge1': 22.9456, 'eval_rouge2': 7.451, 'eval_rougeL': 22.0677, 'eval_rougeLsum': 22.0477, 'eval_gen_len': 7.411764705882353, 'eval_runtime': 16.1149, 'eval_samples_per_second': 1.055, 'eval_steps_per_second': 0.186, 'epoch': 2.95}


 31%|███       | 28/90 [18:54<33:02, 31.98s/it]
100%|██████████| 3/3 [00:08<00:00,  2.38s/it]
 42%|████▏     | 38/90 [22:48<18:57, 21.88s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:11<00:05,  5.56s/it]
                                               
                                             

{'eval_loss': 1.9732593297958374, 'eval_rouge1': 24.566, 'eval_rouge2': 5.6863, 'eval_rougeL': 23.7346, 'eval_rougeLsum': 23.6898, 'eval_gen_len': 6.470588235294118, 'eval_runtime': 21.335, 'eval_samples_per_second': 0.797, 'eval_steps_per_second': 0.141, 'epoch': 4.0}


 42%|████▏     | 38/90 [23:10<18:57, 21.88s/it]
100%|██████████| 3/3 [00:13<00:00,  3.80s/it]
 52%|█████▏    | 47/90 [29:55<27:25, 38.27s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:06<00:03,  3.39s/it]
                                               
                                             

{'eval_loss': 1.826461911201477, 'eval_rouge1': 20.9239, 'eval_rouge2': 5.8824, 'eval_rougeL': 20.0804, 'eval_rougeLsum': 20.162, 'eval_gen_len': 8.058823529411764, 'eval_runtime': 20.556, 'eval_samples_per_second': 0.827, 'eval_steps_per_second': 0.146, 'epoch': 4.95}


 52%|█████▏    | 47/90 [30:20<27:25, 38.27s/it]
100%|██████████| 3/3 [00:08<00:00,  2.41s/it]
 63%|██████▎   | 57/90 [38:11<24:27, 44.47s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:05<00:02,  2.78s/it]
                                               
                                             

{'eval_loss': 1.8260215520858765, 'eval_rouge1': 20.5084, 'eval_rouge2': 5.8824, 'eval_rougeL': 19.6882, 'eval_rougeLsum': 19.7537, 'eval_gen_len': 8.235294117647058, 'eval_runtime': 13.7829, 'eval_samples_per_second': 1.233, 'eval_steps_per_second': 0.218, 'epoch': 6.0}


 63%|██████▎   | 57/90 [38:25<24:27, 44.47s/it]
100%|██████████| 3/3 [00:07<00:00,  2.04s/it]
 73%|███████▎  | 66/90 [41:40<07:21, 18.41s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:05<00:02,  2.56s/it]
                                               
                                             

{'eval_loss': 1.8166471719741821, 'eval_rouge1': 20.5084, 'eval_rouge2': 5.8824, 'eval_rougeL': 19.6882, 'eval_rougeLsum': 19.7537, 'eval_gen_len': 8.235294117647058, 'eval_runtime': 12.1044, 'eval_samples_per_second': 1.404, 'eval_steps_per_second': 0.248, 'epoch': 6.95}


 73%|███████▎  | 66/90 [41:54<07:21, 18.41s/it]
100%|██████████| 3/3 [00:06<00:00,  1.82s/it]
 84%|████████▍ | 76/90 [45:26<03:50, 16.45s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:05<00:02,  2.63s/it]
                                               
                                             

{'eval_loss': 1.7828607559204102, 'eval_rouge1': 23.8903, 'eval_rouge2': 11.3971, 'eval_rougeL': 23.1472, 'eval_rougeLsum': 22.8887, 'eval_gen_len': 8.176470588235293, 'eval_runtime': 12.3948, 'eval_samples_per_second': 1.372, 'eval_steps_per_second': 0.242, 'epoch': 8.0}


 84%|████████▍ | 76/90 [45:39<03:50, 16.45s/it]
100%|██████████| 3/3 [00:06<00:00,  1.83s/it]
 94%|█████████▍| 85/90 [51:03<02:49, 34.00s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:08<00:04,  4.05s/it]
                                               
                                             

{'eval_loss': 1.7652031183242798, 'eval_rouge1': 25.8964, 'eval_rouge2': 11.3971, 'eval_rougeL': 25.2991, 'eval_rougeLsum': 24.7342, 'eval_gen_len': 8.117647058823529, 'eval_runtime': 19.197, 'eval_samples_per_second': 0.886, 'eval_steps_per_second': 0.156, 'epoch': 8.95}


 94%|█████████▍| 85/90 [51:25<02:49, 34.00s/it]
100%|██████████| 3/3 [00:12<00:00,  2.82s/it]
100%|██████████| 90/90 [55:07<00:00, 36.27s/it]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:05<00:02,  2.84s/it]
                                               
                                             

{'eval_loss': 1.7605310678482056, 'eval_rouge1': 25.8964, 'eval_rouge2': 11.3971, 'eval_rougeL': 25.2991, 'eval_rougeLsum': 24.7342, 'eval_gen_len': 8.117647058823529, 'eval_runtime': 12.9989, 'eval_samples_per_second': 1.308, 'eval_steps_per_second': 0.231, 'epoch': 9.47}


100%|██████████| 90/90 [55:31<00:00, 36.27s/it]
100%|██████████| 3/3 [00:06<00:00,  1.99s/it]
                                               

{'train_runtime': 3341.4886, 'train_samples_per_second': 0.434, 'train_steps_per_second': 0.027, 'train_loss': 1.481752692328559, 'epoch': 9.47}


100%|██████████| 90/90 [55:41<00:00, 37.13s/it]


Best checkpoint of model: ../finetuning_ckpts\V8E8WA6VTUQTT.all\checkpoint-90
#> Best checkpoint path: ../finetuning_ckpts\V8E8WA6VTUQTT.all\checkpoint-90 for all
Assigning the LM of predictor all.
Assigning the LM of predictor all.


In [9]:
from rich import print as rprint

query = "What is the herbicide that endocrine disruptor, according to Tyrone Hayes' research?"
actual_answer = "Atrazine"

answer = rag_finetuned(question=query)

rprint("Question: ", query)
rprint("Actual Answer: ", actual_answer)
print()
rprint("Input Qesry: ", answer.query)
rprint("Context", answer.context[0])
rprint("Predicted Answer: ", answer.answer)




### Not that Dumb or is it???