In [2]:
from global_utils import CFG, load_pkl, write_to_pkl, load_eval_llm
from datasets_manipulations import load_datasets

from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.llms import HuggingFaceEndpoint, HuggingFacePipeline
from langchain_core.messages import HumanMessage, BaseMessage
from langchain.evaluation.qa import QAEvalChain

from transformers import AutoTokenizer, pipeline
from huggingface_hub import login

from datasets import Dataset
import accelerate
import warnings
import sqlite3
import pickle
import shutil
import torch
import re
import os

# Authenticate with the Hugging Face Hub using the 'llama3.2' token from the project config.
login(
    token=CFG.credentials['llama3.2'],
)
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides library deprecation warnings.
warnings.filterwarnings(action='ignore')

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Load Datasets

In [3]:
# Supported dataset keys: CSQA | GSM8K | SQuAD_v1 | SQuAD_v2 | HotpotQA
datasets_to_load = CFG.supported_datasets
qa_lists = load_datasets(datasets_to_load)

# Per-dataset containers, all keyed by dataset name.
all_questions, all_gold_answers, all_examples, all_configs = {}, {}, {}, {}

for key in datasets_to_load:
    qa = qa_lists[key]
    # Only the first CFG.n entries of each dataset are used.
    head = qa[:CFG.n]
    questions = [row['question'] for row in head]
    gold_answers = [row['correct_answer'] for row in head]
    # Chain inputs: one {"question": ...} dict per question.
    examples = [dict(question=text) for text in questions]
    # One chat session per question; session ids are the strings "1", "2", ...
    configs = [
        {"configurable": {"session_id": str(number)}}
        for number in range(1, len(examples) + 1)
    ]

    all_questions[key] = questions
    all_gold_answers[key] = gold_answers
    all_examples[key] = examples
    all_configs[key] = configs

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

# Define LLMs

## Llama 3.2

In [4]:
tokenizer = AutoTokenizer.from_pretrained(CFG.model)

# Fall back to EOS as the padding token when the checkpoint defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Decoding settings (no sampling, single sequence, capped length) forwarded to the pipeline.
generation_kwargs = dict(
    do_sample=False,
    max_new_tokens=150,
    num_return_sequences=1,
    no_repeat_ngram_size=3,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
)

pl = pipeline(
    "text-generation",
    model=CFG.model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # NOTE(review): this allows executing code from the model repo; confirm CFG.model needs it.
    trust_remote_code=True,
    # Return only the generated continuation, not the prompt.
    return_full_text=False,
    **generation_kwargs,
)

# Wrap the HF pipeline so it can be composed into LangChain runnables.
llm = HuggingFacePipeline(pipeline=pl)

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

### Define DB Functions

In [5]:
def get_session_history(session_id):
    """Return the persistent chat history for ``session_id``.

    All sessions are stored in the same SQLite file, ``memory.db``, resolved
    relative to the current working directory.
    """
    connection_string = "sqlite:///memory.db"
    return SQLChatMessageHistory(session_id, connection_string)

def update_db(human_messages, ai_messages, dataset_key=None):
    """Populate (or restore) the local chat-history store for one dataset.

    If a pre-built ``memory.db`` exists for the dataset under
    ``/kaggle/input/filtered-data-before-doubt/<dataset_key>/``, it is copied to
    ``memory.db`` in the current working directory. That is the file
    ``get_session_history`` and ``delete_except_first_two`` open. Each session
    is then trimmed back to its first two messages. Otherwise one session per
    question is created (ids "1", "2", ...), holding the question followed by
    the matching answer.

    Args:
        human_messages: question strings, in session order.
        ai_messages: items aligned with ``human_messages``; each item's
            ``'text'`` value is stored as the AI reply.
        dataset_key: dataset name used to locate the cached database. Defaults
            to the notebook-global ``key`` (the main-loop variable) so existing
            two-argument callers keep working.

    Raises:
        ValueError: if there are fewer AI messages than human messages (checked
            before anything is written).
    """
    if dataset_key is None:
        # Backward compatibility: two-argument callers rely on the global loop variable.
        dataset_key = key

    cached_db = f'/kaggle/input/filtered-data-before-doubt/{dataset_key}/memory.db'
    if os.path.exists(cached_db):
        # Restore to the same cwd-relative path the history helpers use,
        # then drop everything after each session's first two messages.
        shutil.copy(cached_db, 'memory.db')
        delete_except_first_two()
        return

    if len(ai_messages) < len(human_messages):
        raise ValueError(
            f"Need one AI message per question: got {len(ai_messages)} answers "
            f"for {len(human_messages)} questions."
        )

    for session_id, (question, answer) in enumerate(zip(human_messages, ai_messages), start=1):
        history = get_session_history(str(session_id))
        history.add_messages([
            BaseMessage(content=question, type='human'),
            BaseMessage(content=answer['text'], type='ai'),
        ])
    
def delete_except_first_two(db_path='memory.db'):
    """Trim every chat session in ``db_path`` down to its first two messages.

    Rows of the ``message_store`` table are ranked per ``session_id`` by
    ascending ``id``. Every row ranked after the second is deleted. In this
    notebook a session's first two messages are the question and the model's
    first answer, so this resets sessions between doubt experiments.

    Args:
        db_path: SQLite file to modify. Defaults to ``'memory.db'`` in the
            current working directory, the file used by ``get_session_history``.

    Prints how many messages were deleted, or a notice when none were.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Single set-based DELETE (ROW_NUMBER() window functions need SQLite >= 3.25).
        cursor = conn.execute(
            """
            DELETE FROM message_store
            WHERE id IN (
                SELECT id
                FROM (
                    SELECT
                        id,
                        ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY id ASC) AS rn
                    FROM message_store
                ) AS ranked
                WHERE rn > 2
            );
            """
        )
        conn.commit()
        deleted = cursor.rowcount
    finally:
        # Always release the database file, even if the statement fails.
        conn.close()

    if deleted > 0:
        print(f"Deleted {deleted} messages.")
    else:
        print("No messages to delete.")

### Define Chain

In [6]:
def get_llm_chain(key, model=None):
    """Build the history-aware chat chain used for dataset ``key``.

    The prompt consists of three parts, in order:
    - a system message with the dataset's prefix from ``CFG.prefixes_map``;
    - the session's previous turns (the ``history`` placeholder);
    - the new human ``question``.
    Session history is obtained through ``get_session_history``.

    Args:
        key: dataset name used to look up the system prompt.
        model: runnable LLM the prompt is piped into. Defaults to the
            notebook-global ``llm``.

    Returns:
        A ``RunnableWithMessageHistory``. It expects ``{"question": ...}``
        inputs and a ``{"configurable": {"session_id": ...}}`` config.

    Raises:
        ValueError: if the dataset's prefix is ``None``.
    """
    prefix = CFG.prefixes_map[key]
    # Explicit check rather than ``assert`` so it still fires under ``python -O``.
    if prefix is None:
        raise ValueError(CFG.error_messages['prefix'].format(key=key))

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", prefix),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    runnable = prompt | (llm if model is None else model)

    return RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )

# LLM Integration

## Get LLM Answer

In [7]:
def get_answer(llm, questions, configs):
    """Run ``llm`` on all ``questions`` in one batch and wrap each output.

    Args:
        llm: any object with a ``batch(inputs, config=...)`` method.
        questions: batch inputs, one per conversation.
        configs: per-input configs passed through to ``batch``.

    Returns:
        A list of ``{'text': output}`` dicts, in batch order.
    """
    raw_outputs = llm.batch(questions, config=configs)
    answers = []
    for output in raw_outputs:
        answers.append({'text': output})
    return answers

## Conversations Before Doubt     

In [8]:
def get_conversations_before(key, chain, questions, examples, configs):
    """Return the first-turn ("before doubt") answers for dataset ``key``.

    Two paths:
    - No cached pickle under ``/kaggle/input/conversations-before/``: the
      answers are generated by running ``chain`` on ``examples``, one session
      per entry of ``configs``.
    - Cached pickle present: the answers are loaded from it. If the local chat
      store does not exist yet, they are replayed into it with ``update_db`` so
      later turns see the same history.

    Returns:
        A list of ``{'text': answer}`` dicts, as produced by ``get_answer``.
        The cached pickle is assumed to have the same format.
    """
    cached = f'/kaggle/input/conversations-before/conversations_before_{key}.pkl'
    if not os.path.exists(cached):
        return get_answer(chain, examples, configs)

    conversations_before = load_pkl(cached)
    # Check the same cwd-relative file that get_session_history() opens
    # ("sqlite:///memory.db"), independent of the working directory.
    if not os.path.exists('memory.db'):
        update_db(questions, conversations_before)
    return conversations_before

## Extract Only Correct Answers  

### Evaluate conversations before  

In [9]:
def evaluate(key, questions, gold_answers, predictions=None):
    """Grade first-turn answers using the QAEvalChain from ``load_eval_llm()``.

    Args:
        key: dataset name. For ``'GSM8K'`` the reference is the text after the
            last ``'#### '`` marker of each gold answer (the whole string if
            there is no marker).
        questions: question strings.
        gold_answers: reference answers aligned with ``questions``.
        predictions: ``{'text': answer}`` dicts aligned with ``questions``.
            Defaults to the notebook-global ``conversations_before`` so
            existing three-argument callers keep working.

    Returns:
        The result of ``qa_eval_chain.evaluate``. ``filter_data`` reads each
        item's ``'results'`` field.

    Raises:
        ValueError: if there are fewer predictions than examples to grade.
    """
    if predictions is None:
        # Backward compatibility: three-argument callers rely on the global set by the main loop.
        predictions = conversations_before

    # Initialize QAEvalChain
    qa_eval_chain = load_eval_llm()

    if key == 'GSM8K':
        # Grade against the final numeric answer only.
        references = [r.split('#### ')[-1] for r in gold_answers]
    else:
        references = gold_answers
    examples_list = [{"question": q, "answer": r} for q, r in zip(questions, references)]

    if len(predictions) < len(examples_list):
        raise ValueError(
            f"Need one prediction per example: got {len(predictions)} predictions "
            f"for {len(examples_list)} examples."
        )

    # Wrap both sides as datasets.Dataset objects before grading.
    examples_test = Dataset.from_list(examples_list)
    predictions_test = Dataset.from_list(predictions)

    # Pass predictions separately; they are matched to examples by position.
    return qa_eval_chain.evaluate(examples=examples_test,
                                  predictions=predictions_test,
                                  question_key="question",
                                  prediction_key="text")

### Filter Incorrect Responses

In [10]:
def _is_graded_correct(grade_text):
    """Return True when a grader verdict says "correct" but not "incorrect" (case-insensitive)."""
    verdict = grade_text.lower()
    # "incorrect" contains "correct", so it must be excluded explicitly.
    return 'correct' in verdict and 'incorrect' not in verdict


def filter_data(key, conversations_before, eval_results, questions, gold_answers, configs):
    """Keep only the samples whose first answer was graded correct.

    Two paths:
    - No cached directory ``/kaggle/input/filtered-data-before-doubt/<key>``:
      samples are filtered by their ``eval_results`` verdicts.
    - Cached directory present: the filtered lists are loaded from its pickles.
      The matching session configs are recovered from the chat-history store,
      where the first stored message of each session is its question.

    Args:
        key: dataset name.
        conversations_before: first-turn answers aligned with ``questions``.
        eval_results: grader outputs with a ``'results'`` string per sample.
            Only needed when there is no cache.
        questions, gold_answers, configs: aligned per-sample lists. ``configs``
            also defines which sessions are scanned in the cached path.

    Returns:
        ``(filtered_conversations_before, filtered_questions,
        filtered_gold_answers, filtered_configs, filtered_examples)``.

    Raises:
        ValueError: if there is no cache and ``eval_results`` is None.
    """
    cache_dir = f'/kaggle/input/filtered-data-before-doubt/{key}'
    filtered_configs = []

    if not os.path.exists(cache_dir):
        if eval_results is None:
            raise ValueError(
                f"eval_results is required for {key!r}: no cached filtered data at {cache_dir}"
            )
        filtered_conversations_before = []
        filtered_questions = []
        filtered_gold_answers = []
        for conv, res, q, a, conf in zip(conversations_before, eval_results, questions, gold_answers, configs):
            if _is_graded_correct(res['results']):
                filtered_conversations_before.append(conv)
                filtered_questions.append(q)
                filtered_gold_answers.append(a)
                filtered_configs.append(conf)
    else:
        filtered_conversations_before = load_pkl(os.path.join(cache_dir, f'filtered_conversations_before_{key}.pkl'))
        filtered_questions = load_pkl(os.path.join(cache_dir, f'filtered_questions_{key}.pkl'))
        filtered_gold_answers = load_pkl(os.path.join(cache_dir, f'filtered_gold_answers_{key}.pkl'))

        kept_questions = set(filtered_questions)
        # Scan only the sessions that exist for this run (not a fixed 1..CFG.n range).
        for conf in configs:
            session_id = conf['configurable']['session_id']
            messages = get_session_history(session_id).get_messages()
            # Skip empty sessions instead of failing on messages[0].
            if messages and messages[0].content in kept_questions:
                filtered_configs.append({"configurable": {"session_id": session_id}})

    # build filtered examples
    filtered_examples = [{"question": q} for q in filtered_questions]
    return filtered_conversations_before, filtered_questions, filtered_gold_answers, filtered_configs, filtered_examples

## Conversations After Doubt

In [27]:
def get_conversation_after_doubt(llm, configs, experiment, questions, conversations_before=None):
    """Run one induced-doubt experiment over every conversation.

    Each prompt in ``experiment`` is sent, in order, to all conversations as a
    single batch. ``configs`` selects each conversation's session.

    Args:
        llm: runnable with a ``batch`` method (here the history-aware chain).
        configs: one ``{"configurable": {"session_id": ...}}`` per conversation.
        experiment: sequence of follow-up prompts.
        questions: first-turn inputs, used only when ``conversations_before``
            is None. They are passed to ``Dataset.from_list`` and must be
            ``{"question": ...}`` dicts.
        conversations_before: optional first answers as ``{'text': ...}``
            dicts. When given, no first turn is generated.

    Returns:
        One transcript list per conversation:
        ``[first answer, doubt 1, answer 1, doubt 2, answer 2, ...]``.

    Raises:
        ValueError: if ``configs`` does not have exactly one entry per conversation.
    """
    n_conversations = len(questions) if conversations_before is None else len(conversations_before)
    if len(configs) != n_conversations:
        raise ValueError(
            f"Expected one config per conversation: got {len(configs)} configs "
            f"for {n_conversations} conversations."
        )

    def update_history(prompts, history):
        # One batched call; answer i belongs to session configs[i] / history[i].
        answers = get_answer(llm, Dataset.from_list(prompts), configs)
        for transcript, answer in zip(history, answers):
            transcript.append(answer['text'])

    if conversations_before is None:
        history = [[] for _ in range(n_conversations)]  # idx i: history of question i
        update_history(questions, history)
    else:
        history = [[ans['text']] for ans in conversations_before]

    for idx, induced_doubt in enumerate(experiment, start=1):
        print(f"Generating answers for induced doubt question {idx}/{len(experiment)}")
        for transcript in history:
            transcript.append(induced_doubt)
        update_history([{"question": induced_doubt} for _ in history], history)
    return history

## Main

In [None]:
# Datasets excluded from this run.
SKIPPED_DATASETS = ('SQuAD_v1',)
# Upper bound on filtered GSM8K samples used for the doubt experiments.
GSM8K_MAX_SAMPLES = 220

all_conversations = {}

for key in datasets_to_load:
    if key in SKIPPED_DATASETS:
        continue
    print(key)
    questions = all_questions[key]
    gold_answers = all_gold_answers[key]
    examples = all_examples[key]
    configs = all_configs[key]

    # Define chain
    chain = get_llm_chain(key)

    # Get conversations before and populate the memory database
    conversations_before = get_conversations_before(key, chain, questions, examples, configs)

    # Grade the first answers, unless a cached filtered split already exists
    eval_results = None
    if not os.path.exists(f'/kaggle/input/filtered-data-before-doubt/{key}'):
        eval_results = evaluate(key, questions, gold_answers)

    # Keep only the initially-correct samples
    (filtered_conversations_before, filtered_questions, filtered_gold_answers,
     filtered_configs, filtered_examples) = filter_data(
        key, conversations_before, eval_results, questions, gold_answers, configs)

    # Cap GSM8K, truncating every aligned list together. Truncating only the
    # conversations would leave configs/questions/gold answers longer than the
    # histories (batch size mismatch) and misalign the saved pickles.
    if key == 'GSM8K':
        filtered_conversations_before = filtered_conversations_before[:GSM8K_MAX_SAMPLES]
        filtered_questions = filtered_questions[:GSM8K_MAX_SAMPLES]
        filtered_gold_answers = filtered_gold_answers[:GSM8K_MAX_SAMPLES]
        filtered_configs = filtered_configs[:GSM8K_MAX_SAMPLES]
        filtered_examples = filtered_examples[:GSM8K_MAX_SAMPLES]

    # Sanity check
    print('Sanity Check - should print the same lengths:')
    print(len(filtered_questions), len(filtered_gold_answers), len(filtered_configs))
    assert (len(filtered_conversations_before) == len(filtered_questions)
            == len(filtered_gold_answers) == len(filtered_configs)), f'Misaligned filtered data for {key}'

    # Get conversations after doubt: generate them unless a cached pickle exists
    file_path = f'/kaggle/input/conversations-after/conversations_after_{key}.pkl'
    if not os.path.exists(file_path):
        conversations_after = []
        for idx, exp in enumerate(CFG.experiments, start=1):
            print(f'Experiment {idx}/{len(CFG.experiments)}')
            conversations_after.append(get_conversation_after_doubt(chain, filtered_configs, exp, filtered_questions, filtered_conversations_before))
            # Reset every session to its question and first response before the next experiment
            delete_except_first_two()
    else:
        # Loading the data from a pickle file
        conversations_after = load_pkl(file_path)

    # Delete memory database so the next dataset starts with fresh sessions
    if os.path.exists('memory.db'):
        os.remove('memory.db')

    all_conversations[key] = [filtered_conversations_before, conversations_after, filtered_questions, filtered_gold_answers]

### Save Data For Evaluation

In [30]:
# Pickle names, in the same order as the per-dataset list stored in all_conversations.
pickle_names = (
    'filtered_conversations_before',
    'conversations_after',
    'filtered_questions',
    'filtered_gold_answers',
)
for key in datasets_to_load:
    if key == 'SQuAD_v1':
        continue
    saved = all_conversations[key]
    for slot, name in enumerate(pickle_names):
        write_to_pkl(f'{name}_{key}', saved[slot])
print('Done')

Done
