In [2]:
import pandas as pd
import ollama
import numpy as np
from rank_bm25 import BM25Okapi

In [3]:
# initialize embedding and language models
embedding_model = 'snowflake-arctic-embed'
language_model = 'llama3.1:8b'

In [4]:
# load dataset
df = pd.read_excel('../combined_slang_data_multi.xlsx')

# create a string per unique word by combining all rows for that word
combined_rows_as_strings_dict = {}
for i, row in df.iterrows():
    # append or create string entry for each word
    word = row['slang']
    row_string = "\n".join(f"{k}: {v}" for k, v in row.items())
    combined_rows_as_strings_dict[word] = combined_rows_as_strings_dict.get(word, "") + "\n" + row_string
combined_rows_as_strings = list(combined_rows_as_strings_dict.values())

VECTOR_DB = [] # a list of (text, embedding) tuples

def add_text_to_db(text):
    embedding = ollama.embed(model=embedding_model, input=text)['embeddings'][0]
    VECTOR_DB.append((text, embedding))

for row_text in combined_rows_as_strings:
    add_text_to_db(row_text)

In [5]:
df_eval = pd.read_csv('eval_reverse.csv')

In [6]:
def tokenize(text: str):
    # simple tokenizer; you can make this smarter if you want
    return text.lower().split()

# texts in the same order as VECTOR_DB
doc_texts = [text for text, _ in VECTOR_DB]
tokenized_docs = [tokenize(t) for t in doc_texts]

bm25 = BM25Okapi(tokenized_docs)

In [7]:
EMB_MATRIX = np.vstack([np.asarray(emb, dtype=np.float32) for _, emb in VECTOR_DB])
EMB_NORMS = np.linalg.norm(EMB_MATRIX, axis=1) + 1e-9
DOC_TEXTS = [text for text, _ in VECTOR_DB]  # keep alignment

def hybrid_search(query_text, top_k=10, alpha=0.4):
    q_embedding = np.asarray(
        ollama.embed(model=embedding_model, input=query_text)['embeddings'][0],
        dtype=np.float32
    )
    q_norm = np.linalg.norm(q_embedding) + 1e-9

    # vectorized dense scores
    dense_scores = (EMB_MATRIX @ q_embedding) / (EMB_NORMS * q_norm)

    # BM25
    query_tokens = tokenize(query_text)
    bm25_scores = np.array(bm25.get_scores(query_tokens), dtype=np.float32)

    # min-max normalize
    def minmax_normalize(x):
        x_min, x_max = x.min(), x.max()
        if x_max == x_min:
            return np.zeros_like(x)
        return (x - x_min) / (x_max - x_min)

    dense_norm = minmax_normalize(dense_scores)
    bm25_norm = minmax_normalize(bm25_scores)

    hybrid_scores = alpha * dense_norm + (1 - alpha) * bm25_norm

    # fast top-k with argpartition
    top_idx = np.argpartition(-hybrid_scores, top_k - 1)[:top_k]
    # sort those top-k
    top_idx = top_idx[np.argsort(-hybrid_scores[top_idx])]

    results = []
    for rank, idx in enumerate(top_idx, start=1):
        text = DOC_TEXTS[idx]
        results.append({
            "rank": rank,
            "index": int(idx),
            "hybrid_score": float(hybrid_scores[idx]),
            "dense_score": float(dense_scores[idx]),
            "bm25_score": float(bm25_scores[idx]),
            "text": text,
        })
    return results

In [8]:
def row_to_prompt(row: pd.Series) -> str:
    """
    Build a prompt in the SAME style as your finetune prompts,
    but for the reverse task: sentence with slang -> choose meaning.
    """
    return f"""You will be given a sentence that contains a modern slang term and four possible interpretations.
Choose the option that best explains the meaning of the sentence in standard English.

Sentence: "{row['sentence']}"

Options:
A) {row['option_A']}
B) {row['option_B']}
C) {row['option_C']}
D) {row['option_D']}

Answer with just the letter."""

def build_rag_prompt(pre_rag_prompt, context):
    prompt = f"\nContext:\n{context}\n{pre_rag_prompt}"
    return prompt

def build_rag_context(prompt, top_k=5):
    results = hybrid_search(prompt, top_k=top_k)
    context = "\n".join([r['text'] for r in results])
    return context

def evaluate_rag_model(df_eval, top_k=5):
    prompts = []
    responses = []
    num_correct = 0
    num_questions = len(df_eval)
    for i in range(len(df_eval)):
        row = df_eval.iloc[i]
        pre_rag_prompt = row_to_prompt(row)
        context = build_rag_context(pre_rag_prompt, top_k=top_k)

        rag_prompt = build_rag_prompt(pre_rag_prompt, context)
        response = ollama.chat(
            model=language_model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that answers multiple choice questions with just the correct letter."},
                {"role": "user", "content": rag_prompt}, # <-- change this line back to `rag_prompt` to use the context
            ]
        )
        answer = response.message.content[0]
        correct_answer = df_eval.iloc[i]['correct']
        if answer.lower() == correct_answer.lower():
            num_correct += 1

        # logging
        if i % 10 == 0:
            print(f"Processed {i + 1} questions, accuracy: {num_correct / (i + 1):.2%}")

        prompts.append(rag_prompt)
        responses.append(answer)

    accuracy = num_correct / num_questions
    return accuracy, prompts, responses

In [9]:
eval_accuracy, prompts, responses = evaluate_rag_model(df_eval, top_k=5)
print(f"Final accuracy: {eval_accuracy:.2%}")

Processed 1 questions, accuracy: 100.00%
Processed 11 questions, accuracy: 100.00%
Processed 21 questions, accuracy: 100.00%
Processed 31 questions, accuracy: 100.00%
Processed 41 questions, accuracy: 100.00%
Processed 51 questions, accuracy: 100.00%
Processed 61 questions, accuracy: 100.00%
Processed 71 questions, accuracy: 100.00%
Processed 81 questions, accuracy: 100.00%
Processed 91 questions, accuracy: 100.00%
Processed 101 questions, accuracy: 100.00%
Processed 111 questions, accuracy: 100.00%
Processed 121 questions, accuracy: 100.00%
Processed 131 questions, accuracy: 100.00%
Processed 141 questions, accuracy: 100.00%
Processed 151 questions, accuracy: 100.00%
Processed 161 questions, accuracy: 100.00%
Processed 171 questions, accuracy: 99.42%
Processed 181 questions, accuracy: 98.90%
Processed 191 questions, accuracy: 97.91%
Processed 201 questions, accuracy: 98.01%
Processed 211 questions, accuracy: 98.10%
Processed 221 questions, accuracy: 97.74%
Processed 231 questions, acc