# Text Re-identification Evaluation with Retrieval-Augmented Generation

# Initialization

## Imports

In [1]:
import os, csv, json
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from datetime import datetime
from typing import Optional, List
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig

from langchain_core.documents import Document as LangchainDocument
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy

from cappr.huggingface.classify import cache_model, predict_proba_examples
from cappr import Example

## Settings

In [2]:
# Input
# Keys used in the per-document records built while loading the dataset.
ID_KEY = "doc_id"                 # Document identifier field of each corpus record
TEXT_KEY = "text"                 # Original (non-anonymized) text field of each corpus record
BK_KEY = "background_knowledge"   # Field under which the background knowledge text is stored
USE_SHORT_IDS = False             # If True, the loaded IDs are replaced by short two-character codes


DATASET_NAME = "wiki553"          # Used to build the results CSV file name
CORPUS_FILE_PATH = "data/wiki553/corpora/Wiki553_Corpus.json"
# Anonymization name -> JSON file with the masked spans of each document.
# Each name becomes a column holding the corresponding masked texts.
ANONYMIZATIONS_FILE_PATHS = {
    "St.NER3":"data/wiki553/anonymizations/Wiki553_St.NER3.json",
    "St.NER4":"data/wiki553/anonymizations/Wiki553_St.NER4.json",
    "St.NER7":"data/wiki553/anonymizations/Wiki553_St.NER7.json",
    "spaCy":"data/wiki553/anonymizations/Wiki553_spaCy.json",
    "Presidio":"data/wiki553/anonymizations/Wiki553_Presidio.json",        
    "Word2Vec_t=0.5":"data/wiki553/anonymizations/Wiki553_Word2Vec_t=0.5.json",
    "Word2Vec_t=0.25":"data/wiki553/anonymizations/Wiki553_Word2Vec_t=0.25.json",
    "k-anonymity_Random":"data/wiki553/anonymizations/Wiki553_k-anonymity_Random.json",
    "k-anonymity_Greedy":"data/wiki553/anonymizations/Wiki553_k-anonymity_Greedy.json",
    "Manual":"data/wiki553/anonymizations/Wiki553_Manual.json",
    "Student-LLM":"data/wiki553/anonymizations/Wiki553_gpt-4o-2024-09-03_Student.json",
    "Sparks-LLM":"data/wiki553/anonymizations/Wiki553_gpt-4o-2024-09-03_Sparks.json",        
    "MvM-LLM":"data/wiki553/anonymizations/Wiki553_gpt-4o-2024-09-03_MvM.json",
    "Attributes-LLM":"data/wiki553/anonymizations/Wiki553_gpt-4o-2024-09-03_Attributes.json",        
}
# JSON file mapping each document ID to its background knowledge text
BK_FILE_PATH = "data/wiki553/bks/Wiki553_BK=Public.json"

# Retriever
RETRIEVER_USE_DENSE = True  # Dense=FAISS | Sparse=BM25
RETRIEVER_NAME = "FAISS" if RETRIEVER_USE_DENSE else "BM25"
RETRIEVER_K = 10  # Number of chunks retrieved per query
RETRIEVER_CHUNK_SIZE = 128  # Maximum chunk size, in tokens of the embedding model tokenizer
# Separators handed to the text splitter, tried from first to last
RETRIEVER_MARKDOWN_SEPARATORS = [
    "\n#{1,6} ",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n___+\n",
    "\n\n",
    "\n",
    " ",
    "",
]
RETRIEVER_REMOVE_MASKING_MARKS = True
if RETRIEVER_REMOVE_MASKING_MARKS:
    RETRIEVER_NAME += "_NoMasks"
# Placeholders that anonymizers leave in the text. They are removed with plain
# `str.replace` (literal matching, no regex), so every mark must be written
# literally: "***" and not an escaped "\*\*\*", which would only match text
# containing backslashes. The first entry is also used as the default
# replacement for masked spans that carry no explicit replacement.
RETRIEVER_MASKING_MARKS = [
    "SENSITIVE", "PERSON", "DEM", "LOC",
    "ORG", "DATETIME", "QUANTITY", "MISC",
    "NORP", "FAC", "GPE", "PRODUCT", "EVENT",
    "WORK_OF_ART", "LAW", "LANGUAGE", "DATE",
    "TIME", "ORDINAL", "CARDINAL", "DATE_TIME",
    "NRP", "LOCATION", "ORGANIZATION", "***",
]
# Embedding model of the dense retriever; its tokenizer also drives the chunk splitting
RETRIEVER_EMBEDDING_MODEL_NAME = "thenlper/gte-small"
RETRIEVER_NAME += f"_{RETRIEVER_EMBEDDING_MODEL_NAME}"

# Reader
# Alternatives: "HuggingFaceH4/zephyr-7b-beta" "google/gemma-3-4b-it" "jet-ai/Jet-Nemotron-2B" "mistralai/Mistral-7B-Instruct-v0.3"
READER_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
READER_QUANTIZATION_BITS = 8

# No quantization unless 8 or 4 bits are requested
READER_QUANTIZATION_CONFIG = None
if READER_QUANTIZATION_BITS == 8:
    READER_QUANTIZATION_CONFIG = BitsAndBytesConfig(load_in_8bit=True)
elif READER_QUANTIZATION_BITS == 4:
    READER_QUANTIZATION_CONFIG = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

READER_BATCH_SIZE = 1
# Prompt 0: a single user message holding the task description, the inputs and
# the answer lead-in. The {background}, {document} and {names} placeholders are
# filled with `str.format` when the prompt for a document is built.
READER_PROMPT_0_USER = [ # Only User role
    {
        "role": "user",
        "content": """Task: Predict the ID of the person corresponding to a document.

Input:
BACKGROUND: A list of people and their details, indexed by ID.
DOCUMENT: A text related to one person from the BACKGROUND.
NAMES: A summary list of the IDs from people in BACKGROUND.

Output:
The ID of the person that best fits the DOCUMENT based on the BACKGROUND. Important: No reasoning, directly return the ID.

----------------

BACKGROUND:{background}


DOCUMENT: {document}


NAMES: {names}

----------------

Based on the input provided, the ID corresponding to the document is:"""
    },
]
# Prompt 1: same task text as prompt 0, but the answer lead-in is moved to a
# trailing assistant message. Placeholders: {background}, {document}, {names}.
READER_PROMPT_1_USER_ASSISTANT = [ # User, Assistant roles
    {
        "role": "user",
        "content": """Task: Predict the ID of the person corresponding to a document.

Input:
BACKGROUND: A list of people and their details, indexed by ID.
DOCUMENT: A text related to one person from the BACKGROUND.
NAMES: A summary list of the IDs from people in BACKGROUND.

Output:
The ID of the person that best fits the DOCUMENT based on the BACKGROUND. Important: No reasoning, directly return the ID.

----------------

BACKGROUND:{background}


DOCUMENT: {document}


NAMES: {names}
"""
    },
    {
        "role": "assistant",
        "content": """Based on the input provided, the ID corresponding to the document is:"""
    },
]
# Prompt 2: task description in a system message, inputs in the user message
# and the answer lead-in in a trailing assistant message.
# Placeholders: {background}, {document}, {names}.
READER_PROMPT_2_SYSTEM_USER_ASSISTANT = [ # System, User, Assistant roles
    {
        "role": "system",
        "content": """Task: Predict the ID of the person corresponding to a document.

Input:
BACKGROUND: A list of people and their details, indexed by ID.
DOCUMENT: A text related to one person from the BACKGROUND.
NAMES: A summary list of the IDs from people in BACKGROUND.

Output:
The ID of the person that best fits the DOCUMENT based on the BACKGROUND. Important: No reasoning, directly return the ID.""",
    },
    {
        "role": "user",
        "content": """BACKGROUND:{background}

----------------

DOCUMENT: {document}

----------------

NAMES: {names}

"""
    },
    {
        "role": "assistant",
        "content": """Based on the input provided, the ID corresponding to the document is:"""
    },
]
# Prompt 3: task description in a system message; the user message holds the
# inputs followed by the answer lead-in.
# Placeholders: {background}, {document}, {names}.
READER_PROMPT_3_SYSTEM_USER = [ # System, User roles
    {
        "role": "system",
        "content": """Task: Predict the ID of the person corresponding to a document.

Input:
BACKGROUND: A list of people and their details, indexed by ID.
DOCUMENT: A text related to one person from the BACKGROUND.
NAMES: A summary list of the IDs from people in BACKGROUND.

Output:
The ID of the person that best fits the DOCUMENT based on the BACKGROUND. Important: No reasoning, directly return the ID.""",
    },
    {
        "role": "user",
        "content": """BACKGROUND:{background}

----------------

DOCUMENT: {document}

----------------

NAMES: {names}

----------------

Based on the input provided, the ID corresponding to the document is:"""
    },
]
# Prompt 4: system + user messages with reworded instructions and "---" as the
# section separator (instead of the 16-dash separator of prompts 0-3).
# Placeholders: {background}, {document}, {names}.
READER_PROMPT_4_SYSTEM_USER = [
    {
        "role": "system",
        "content": """You are an assistant that identifies which person from a given BACKGROUND corresponds to a DOCUMENT.

Your task:
- Read the BACKGROUND, which lists people and their details by ID.
- Read the DOCUMENT, which relates to exactly one person.
- Read NAMES, which lists all available IDs.
- Decide which ID best matches the DOCUMENT.

Output format:
- Respond with only the matching ID, nothing else (no reasoning, no explanation)."""
    },
    {
        "role": "user",
        "content": """BACKGROUND:
{background}

---

DOCUMENT:
{document}

---

NAMES:
{names}

---

Return only the ID that best matches the DOCUMENT:"""
    },
]
# Prompt 5: a single user message using "---" as the section separator; it
# tells the model that the document might be pseudo-anonymized.
# Placeholders: {background}, {document}, {names}.
READER_PROMPT_5_USER = [
    {
        "role": "user",
        "content": """Your task is to select the ID that best matches the DOCUMENT based on the BACKGROUND.

Inputs:
- BACKGROUND lists people and their details by ID.
- DOCUMENT is related to exactly one person. It might be pseudo-anonymized.
- IDs lists all available identifiers.

Output:
- Respond with only the matching ID. Do not include explanations or reasoning.

---

BACKGROUND:
{background}

---

DOCUMENT:
{document}

---

IDs:
{names}

---

The ID that best matches the DOCUMENT is:"""
    },
]
# Available prompt variants, addressed by position
READER_PROMPTS_LIST = [
    READER_PROMPT_0_USER,
    READER_PROMPT_1_USER_ASSISTANT,
    READER_PROMPT_2_SYSTEM_USER_ASSISTANT,
    READER_PROMPT_3_SYSTEM_USER,
    READER_PROMPT_4_SYSTEM_USER,
    READER_PROMPT_5_USER,
]
READER_PROMPT_IDX = 5
READER_SELECTED_PROMPT = READER_PROMPTS_LIST[READER_PROMPT_IDX]
# Prompts 4 and above use the short separator; earlier ones use the 16-dash one
READER_PROMPT_SEPARATOR = "---" if READER_PROMPT_IDX >= 4 else "----------------"
READER_REMOVE_ROLE_END_STR = True
READER_ROLE_END_STR = "[/INST]" if "zephyr" not in READER_MODEL_NAME else "</s>"
READER_USE_PRIOR_PROBABILITIES = False
READER_USE_CAPPR_CACHE = True
# Name encoding every reader setting, used to label the results
READER_NAME = (
    f"{READER_MODEL_NAME}_Quant={READER_QUANTIZATION_BITS}_Prompt={READER_PROMPT_IDX}{not READER_REMOVE_ROLE_END_STR}"
    f"_CAPPR_Prior={READER_USE_PRIOR_PROBABILITIES}_Cache={READER_USE_CAPPR_CACHE}_ShortIDs={USE_SHORT_IDS}"
)

# Output
RESULTS_FILEPATH = f"results_{DATASET_NAME}.csv"
RAG_RISK_NAME = " ".join(
    (
        "RAG RISK",
        f"k={RETRIEVER_K}",
        f"chunk={RETRIEVER_CHUNK_SIZE}",
        f"{RETRIEVER_NAME}",
        f"{READER_NAME}",
    )
)
print(RAG_RISK_NAME)

RAG RISK k=10 chunk=128 FAISS_NoMasks_thenlper/gte-small mistralai/Mistral-7B-Instruct-v0.3_Quant=8_Prompt=5False_CAPPR_Prior=False_Cache=True_ShortIDs=False


# Dataset loading

In [3]:
def get_masked_text(masked_spans:list, original_text:str) -> str:
    """
    Return `original_text` with every span in `masked_spans` replaced.

    Each span is indexable as (start, end) or (start, end, replacement). Spans
    with exactly three items use their own replacement; any other span is
    replaced by the first entry of RETRIEVER_MASKING_MARKS.
    """
    result = original_text

    # Replace from the rightmost span to the leftmost one so that the offsets
    # of the spans still to be processed remain valid.
    ordered_spans = sorted(masked_spans, key=lambda span: span[0])
    for span in ordered_spans[::-1]:
        if len(span) == 3:
            substitute = span[2]
        else:
            substitute = RETRIEVER_MASKING_MARKS[0]
        head = result[:span[0]]
        tail = result[span[1]:]
        result = head + substitute + tail

    return result

In [4]:
# Per-document records, keyed by document ID
data = {}

# Load corpus (ids and texts)
with open(CORPUS_FILE_PATH, "r", encoding="utf-8") as f:
    corpus = json.load(f)
for record in corpus:
    doc_id = record[ID_KEY]
    data[doc_id] = {ID_KEY: doc_id, TEXT_KEY: record[TEXT_KEY]}
del corpus

# Load background knowledge
with open(BK_FILE_PATH, "r", encoding="utf-8") as f:
    bk = json.load(f)
for doc_id, bk_text in bk.items():
    data[doc_id][BK_KEY] = bk_text

# Load anonymizations (one masked version of the text per anonymization name)
for anon_name, anon_file_path in ANONYMIZATIONS_FILE_PATHS.items():
    with open(anon_file_path, "r", encoding="utf-8") as f:
        anon = json.load(f)
    for doc_id, masked_spans in anon.items():
        data[doc_id][anon_name] = get_masked_text(masked_spans, data[doc_id][TEXT_KEY])

if USE_SHORT_IDS:
    # Transform IDs (full names) to short two-character codes ("AA", "AB", ...).
    # A list (not a set) keeps the codes in generation order, so the mapping
    # from documents to codes is the same on every run.
    # NOTE: codes only stay within A-Z for up to 26*26 = 676 documents.
    new_ids = [chr(65 + i // 26) + chr(65 + i % 26) for i in range(len(data))]
    new_data = {}
    for new_id, value in zip(new_ids, data.values()):
        value[ID_KEY] = new_id
        new_data[new_id] = value
    data = new_data

#print(data)

In [5]:
# Load from dataframe (one row per document)
df = pd.DataFrame.from_dict(data.values())

# Background knowledge entries used to build the retrieval knowledge base.
# Only non-blank strings are kept: a document without background knowledge
# appears as None or NaN (a float) in the dataframe, and calling `.strip()`
# on NaN would raise an AttributeError.
ds = [
    {"text": bk_text, "source": doc_id}
    for doc_id, bk_text in zip(df[ID_KEY], df[BK_KEY])
    if isinstance(bk_text, str) and bk_text.strip() != ""
]

#print(ds[0])

In [6]:
# Wrap every background knowledge entry in a LangchainDocument,
# keeping the ID of the person it belongs to as the "source" metadata
raw_knowledge_base = []
for entry in tqdm(ds):
    raw_knowledge_base.append(
        LangchainDocument(
            page_content=entry["text"],
            metadata={"source": entry["source"]},
        )
    )

#print(len(raw_knowledge_base))
#print(raw_knowledge_base[0])

  0%|          | 0/548 [00:00<?, ?it/s]

# Retriever

## Splitting

In [10]:
def split_documents(
    chunk_size: int,
    knowledge_base: List[LangchainDocument],
    tokenizer_name: Optional[str] = RETRIEVER_EMBEDDING_MODEL_NAME,
) -> List[LangchainDocument]:
    """
    Chunk every document of `knowledge_base` and return the unique chunks.

    The splitter is built from the Hugging Face tokenizer named `tokenizer_name`,
    with `chunk_size` as the maximum chunk size and a tenth of it as overlap.
    Chunks whose text was already produced (by any document) are discarded, so
    only the first occurrence of each text is returned, in splitting order.
    """
    splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size / 10),
        add_start_index=True,
        strip_whitespace=True,
        separators=RETRIEVER_MARKDOWN_SEPARATORS,
    )

    # Split document by document
    all_chunks = []
    for source_doc in tqdm(knowledge_base):
        all_chunks.extend(splitter.split_documents([source_doc]))

    # Keep only the first chunk for each distinct text
    seen_texts = set()
    unique_chunks = []
    for chunk in all_chunks:
        if chunk.page_content in seen_texts:
            continue
        seen_texts.add(chunk.page_content)
        unique_chunks.append(chunk)

    return unique_chunks

# Chunk the whole knowledge base with the tokenizer of the embedding model
docs_processed = split_documents(
    chunk_size=RETRIEVER_CHUNK_SIZE,
    knowledge_base=raw_knowledge_base,
    tokenizer_name=RETRIEVER_EMBEDDING_MODEL_NAME,
)

  0%|          | 0/548 [00:00<?, ?it/s]

## Instantiate

In [11]:
if not RETRIEVER_USE_DENSE:
    # Sparse retrieval: BM25 over the processed chunks
    retriever = BM25Retriever.from_documents(docs_processed)

else:
    # Dense retrieval: embedding model in GPU
    embedding_model_kwargs = {"device": "cuda"}
    embedding_encode_kwargs = {
        "normalize_embeddings": True,  # Set `True` for cosine similarity
        "batch_size": 32,  # Process embeddings in batches
    }
    embedding_model = HuggingFaceEmbeddings(
        model_name=RETRIEVER_EMBEDDING_MODEL_NAME,
        model_kwargs=embedding_model_kwargs,
        encode_kwargs=embedding_encode_kwargs,
        multi_process=False,
    )

    # FAISS index over the processed chunks
    retriever = FAISS.from_documents(
        docs_processed,
        embedding_model,
        distance_strategy=DistanceStrategy.COSINE,
    )

## Functions

In [12]:
def retrieval_df(df, retriever, k)->dict:
    """
    Run the retrieval for every text column of `df`.

    The background knowledge and ID columns are skipped. Returns a dictionary
    mapping each remaining column name to the list produced by `retrieval_col`.
    """
    skipped_columns = [BK_KEY, ID_KEY]  # Not texts to be re-identified
    return {
        col_name: retrieval_col(df, col_name, retriever, k)
        for col_name in df.columns
        if col_name not in skipped_columns
    }

def retrieval_col(df, col_name, retriever, k)->list:
    """
    Retrieve the `k` best chunks for every text in column `col_name` of `df`.

    Returns a list with one entry per row: the retrieved documents, or None
    when the cell holds no usable text. A cell is unusable when it is not a
    string (None, or the NaN that pandas stores for a missing value) or when
    it is blank.
    """
    retrievals = []
    # Iterating through tqdm advances the bar for every row, skipped ones included
    for text in tqdm(df[col_name], total=len(df), desc=f"Retrievals for {col_name=}"):
        if not isinstance(text, str) or text.strip() == "":
            retrievals.append(None)
            continue
        retrievals.append(retrieval_doc(text, retriever, k))

    return retrievals

def retrieval_doc(text, retriever, k, use_masking_marks_removal:bool=RETRIEVER_REMOVE_MASKING_MARKS):
    """
    Retrieve the `k` chunks that best match `text`.

    When `use_masking_marks_removal` is True, the masking marks are removed
    from `text` before querying. A BM25Retriever (or a subclass) is queried
    through `invoke`; any other retriever is assumed to be a FAISS vector
    store and queried through `similarity_search`.
    """
    if use_masking_marks_removal:
        text = remove_masking_marks(text)
    if isinstance(retriever, BM25Retriever):
        retriever.k = k  # Force the proper k
        return retriever.invoke(text)
    # FAISS
    return retriever.similarity_search(query=text, k=k)

def remove_masking_marks(original_text:str, masking_marks:Optional[list]=None):
    """
    Return `original_text` without masking marks and with normalized whitespace.

    `masking_marks` defaults to RETRIEVER_MASKING_MARKS (looked up at call time).
    Marks are removed from longest to shortest, so that a mark containing a
    shorter one is removed as a whole: removing "LOC" before "LOCATION" would
    leave "ATION" behind, and removing "DATE"/"TIME" before "DATE_TIME" would
    leave "_".
    """
    if masking_marks is None:
        masking_marks = RETRIEVER_MASKING_MARKS
    new_text = original_text
    # `sorted` is stable: marks of equal length keep their given order
    for mark in sorted(masking_marks, key=len, reverse=True):
        new_text = new_text.replace(mark, "").strip()
        new_text = ' '.join(new_text.split())
    return new_text

# Reader

In [7]:
# Model and tokenizer
# Load the reader LLM. READER_QUANTIZATION_CONFIG is None when neither 4-bit
# nor 8-bit quantization is selected; device placement is left to
# `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    READER_MODEL_NAME,
    config=AutoConfig.from_pretrained(READER_MODEL_NAME),
    quantization_config=READER_QUANTIZATION_CONFIG,
    device_map="auto"
)
# Wrap the model's forward so that every call runs with `use_cache=True`,
# overriding whatever the caller passed.
# NOTE(review): presumably needed so the key/value cache is returned for the
# CAPPr prefix caching - confirm.
original_forward = model.forward
def cached_forward(*args, **kwargs):
    kwargs["use_cache"] = True
    return original_forward(*args, **kwargs)
model.forward = cached_forward

# NOTE(review): `trust_remote_code=True` permits running code shipped with the
# model repository - confirm it is required for the selected reader.
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME, trust_remote_code=True)

# Pair in the form expected by the CAPPr functions used in this notebook
model_and_tokenizer = (model, tokenizer)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Prompt
# Render the selected chat messages into one template string (not token IDs),
# without appending a generation prompt
chat_template_options = {"tokenize": False, "add_generation_prompt": False}
rag_prompt_template = tokenizer.apply_chat_template(
    READER_SELECTED_PROMPT,
    **chat_template_options,
)
print(rag_prompt_template)

<s>[INST] Your task is to select the ID that best matches the DOCUMENT based on the BACKGROUND.

Inputs:
- BACKGROUND lists people and their details by ID.
- DOCUMENT is related to exactly one person. It might be pseudo-anonymized.
- IDs lists all available identifiers.

Output:
- Respond with only the matching ID. Do not include explanations or reasoning.

---

BACKGROUND:
{background}

---

DOCUMENT:
{document}

---

IDs:
{names}

---

The ID that best matches the DOCUMENT is:[/INST]


In [9]:
# Testing it does not return NaN
# (prompt, candidate completions) pairs used as a quick sanity check of the reader
sanity_cases = [
    (
        "Jodie Foster played",
        ["Clarice Starling", "Trinity in The Matrix"],
    ),
    (
        "Batman, from Batman: The Animated Series, was played by",
        ("Pete Holmes", "Kevin Conroy", "Spongebob!"),
    ),
    (
        "Scott Andrew Caan (born August 23, 1976) is an American actor. He currently stars as Detective Danny \"Danno\" Williams in the CBS television series Hawaii Five-0 (2010–present), for which he was nominated for a Golden Globe Award. Caan is also known for his recurring role as manager Scott Lavin in the HBO television series Entourage (2009–2011). He was also a part of 1990s rap group The Whooliganz with The Alchemist. The duo went by the names Mad Skillz and Mudfoot, respectively.",
        ['Scott Caan', 'Dwayne Johnson', 'Tom Hanks', 'Patrick Stewart'],
    ),
]
examples = [
    Example(prompt=case_prompt, completions=case_completions)
    for case_prompt, case_completions in sanity_cases
]

# Run CAPPR and show the probabilities obtained for each example
pred_probs = predict_proba_examples(
    examples,
    model_and_tokenizer=(model, tokenizer),
    batch_size=READER_BATCH_SIZE,
)
for example_probs in pred_probs:
    print(example_probs)

[0.87264718 0.12735282]
[8.05619281e-06 9.83392846e-01 1.65990978e-02]
[8.74034594e-01 1.04346800e-01 2.14904272e-02 1.28178625e-04]


# Re-identification risk assessment

## Functions

In [17]:
def rag_linkage_df(df:pd.DataFrame, retrievals:dict, id_to_label:dict, prompt_template) -> dict:
    """
    Compute the reader predictions for every column present in `retrievals`.

    Uses the notebook-level `model` and `tokenizer`. Returns a dictionary
    mapping each column name to the array returned by `rag_linkage_col`.
    """
    reader = (model, tokenizer)
    return {
        col_name: rag_linkage_col(df[col_name], col_retrievals, reader, prompt_template, id_to_label)
        for col_name, col_retrievals in retrievals.items()
    }

def rag_linkage_col(col_documents, col_retrievals, model_and_tokenizer, prompt_template, id_to_label) -> np.ndarray:
    """
    Predict, for each document of a column, a score for every known individual.

    Returns an array with one row per document and one column per label of
    `id_to_label`. Rows of documents without retrievals stay at zero. When the
    retrievals of a document belong to a single individual, that individual
    gets score 1 without querying the reader. All other documents are turned
    into CAPPr examples and scored together at the end.
    """
    col_probs = np.zeros((len(col_documents), len(id_to_label)))
    prefix_is_cached = not READER_USE_CAPPR_CACHE  # The prefix is cached at most once

    # Collect every reader example before predicting
    examples = []
    example_idx_to_doc_idx = {}
    for doc_idx, (document, doc_retrievals) in enumerate(zip(col_documents, col_retrievals)):
        # No retrievals: the row stays at zero
        if doc_retrievals is None:
            continue

        prompt, retrieved_ids, retrieved_ids_counts = rag_linkage_construct_prompt(doc_retrievals, document, prompt_template)

        # A single retrieved individual is directly the answer
        if len(retrieved_ids) == 1:
            only_label = id_to_label[retrieved_ids[0]]
            col_probs[doc_idx][only_label] = 1
            continue

        retrieved_labels = [id_to_label[retrieved_id] for retrieved_id in retrieved_ids]

        # Optional prior: softmax over the number of chunks retrieved per individual
        prior = None
        if READER_USE_PRIOR_PROBABILITIES:
            for retrieved_id, count in retrieved_ids_counts.items():
                col_probs[doc_idx][id_to_label[retrieved_id]] += count
            exp_counts = np.exp(col_probs[doc_idx][retrieved_labels])
            col_probs[doc_idx][retrieved_labels] = exp_counts / np.sum(exp_counts)
            prior = col_probs[doc_idx][retrieved_labels]

        if READER_USE_CAPPR_CACHE:
            if not prefix_is_cached:
                # The text up to (and including) the first separator is cached in
                # the model; the prefix of this first prompt is reused for all
                # the following prompts
                prompt_prefix = prompt.split(READER_PROMPT_SEPARATOR)[0] + READER_PROMPT_SEPARATOR
                model_and_tokenizer = cache_model(model_and_tokenizer, prompt_prefix)
                prefix_is_cached = True
            # Only the text after the cached prefix is sent as prompt
            prompt = prompt[len(prompt_prefix):]

        # Create the example for CAPPR and remember which document it belongs to
        example_idx_to_doc_idx[len(examples)] = doc_idx
        examples.append(Example(prompt=prompt, completions=retrieved_ids, prior=prior))

    # Performing CAPPR predictions (written into col_probs)
    rag_linkage_cappr_predict(col_probs, model_and_tokenizer, examples, example_idx_to_doc_idx, id_to_label)

    return col_probs

def rag_linkage_construct_prompt(doc_retrievals, document, prompt_template):
    """
    Build the reader prompt for `document` from its retrieved chunks.

    The chunks are grouped by the "source" ID in their metadata, in order of
    first appearance, to form the background section of the prompt.

    Returns a tuple `(prompt, retrieved_names, retrieved_ids_counts)`:
    the formatted prompt, the list of distinct retrieved IDs and a dictionary
    with the number of retrieved chunks per ID.
    """
    # Group the retrieved chunks by id
    chunks_by_id = {}
    retrieved_ids_counts = {}
    for retrieved in doc_retrievals:
        source_id = retrieved.metadata["source"]
        chunks_by_id[source_id] = chunks_by_id.get(source_id, "") + f"\t{retrieved.page_content}\n"
        retrieved_ids_counts[source_id] = retrieved_ids_counts.get(source_id, 0) + 1

    # List of retrieved names
    retrieved_names = list(chunks_by_id)

    # Background: one "ID=..." section per retrieved individual
    background = "".join(
        f"\nID={source_id}\n{text}" for source_id, text in chunks_by_id.items()
    )

    # Generate the prompt
    prompt = prompt_template.format(
        document=document, background=background, names=retrieved_names
    )

    if READER_REMOVE_ROLE_END_STR:
        # Remove the end role/sequence mark only when the prompt really ends
        # with it (ignoring trailing whitespace). Blindly cutting
        # len(mark) + 1 characters also removes the last character of the
        # prompt when the template adds nothing after the mark.
        trimmed = prompt.rstrip()
        if READER_ROLE_END_STR and trimmed.endswith(READER_ROLE_END_STR):
            prompt = trimmed[:-len(READER_ROLE_END_STR)]

    return prompt, retrieved_names, retrieved_ids_counts

def rag_linkage_cappr_predict(col_probs, model_and_tokenizer, examples, example_idx_to_doc_idx, id_to_label):
    """
    Score `examples` with CAPPr and write the probabilities into `col_probs`.

    For each example, the probabilities of its completions are stored in the
    row of its document (given by `example_idx_to_doc_idx`), at the columns of
    the completions' labels (given by `id_to_label`). `col_probs` is modified
    in place; nothing is returned.
    """
    # Nothing to predict (e.g. every document was resolved without the reader)
    if not examples:
        return

    pred_probs = predict_proba_examples(examples, model_and_tokenizer=model_and_tokenizer, batch_size=READER_BATCH_SIZE)
    for example_idx, (probs, example) in enumerate(zip(pred_probs, examples)):
        doc_idx = example_idx_to_doc_idx[example_idx]
        retrieved_labels = [id_to_label[name] for name in example.completions]
        col_probs[doc_idx][retrieved_labels] = probs

In [18]:
def eval_rag_linkage_df(df:pd.DataFrame, k:int, id_to_label:dict, retrievals:dict, model_and_tokenizer, prompt_template, verbose:bool=True):
    """
    Predict and evaluate every column present in `retrievals`.

    Returns a tuple `(top_1_accuracies, top_k_accuracies, predictions)` of
    dictionaries keyed by column name. With `verbose`, the progress and the
    accuracies of each column are printed.
    """
    top_1_accuracies, top_k_accuracies, predictions = {}, {}, {}

    for col_name, col_retrievals in retrievals.items():
        if verbose:
            print(f"Evaluation of {col_name} documents")

        col_predictions = rag_linkage_col(df[col_name], col_retrievals, model_and_tokenizer, prompt_template, id_to_label)
        predictions[col_name] = col_predictions

        col_top_1, col_top_k = eval_linkage_col(df, col_predictions, k, id_to_label)
        top_1_accuracies[col_name] = col_top_1
        top_k_accuracies[col_name] = col_top_k

        if verbose:
            print(f"Top-1 accuracies for {col_name}: {top_1_accuracies[col_name]}")
            print(f"Top-{k} accuracies for {col_name}: {top_k_accuracies[col_name]}")

    return top_1_accuracies, top_k_accuracies, predictions

def eval_linkage_df(df:pd.DataFrame, predictions:dict, k:int, id_to_label:dict)->tuple:
    """
    Evaluate already computed predictions for every column of `predictions`.

    Returns a tuple `(top_1_accuracies, top_k_accuracies)` with two
    dictionaries keyed by column name.
    """
    top_1_accuracies = {}
    top_k_accuracies = {}

    for col_name, preds in predictions.items():
        top_1_accuracies[col_name], top_k_accuracies[col_name] = eval_linkage_col(df, preds, k, id_to_label)

    return top_1_accuracies, top_k_accuracies

def eval_linkage_col(df, preds:np.ndarray, k:int, id_to_label:dict)->tuple:
    """
    Compute the top-1 and top-k accuracies (in percent) of `preds`.

    Row `i` of `preds` holds the scores of the `i`-th row of `df` (by
    position, whatever the index of `df` is). Documents whose scores sum to
    zero have no prediction and count as failures. Returns
    `(top_1_accuracy, top_k_accuracy)`, which is `(0.0, 0.0)` for an empty `df`.
    """
    n_documents = len(df)
    if n_documents == 0:
        return 0.0, 0.0

    top_1_count = 0
    top_k_count = 0
    for position, doc_id in enumerate(df[ID_KEY]):
        doc_preds = preds[position]
        # Skip documents without prediction
        if doc_preds.sum() == 0:
            continue
        label = id_to_label[doc_id]
        if np.argmax(doc_preds) == label:
            top_1_count += 1
        if label in np.argsort(doc_preds)[-k:]:
            top_k_count += 1

    top_1_accuracy = 100 * top_1_count / n_documents
    top_k_accuracy = 100 * top_k_count / n_documents

    return top_1_accuracy, top_k_accuracy

def accuracies_to_csv(accuracy_data: dict, method_name: str, filename: str):
    """
    Append one row with the accuracies of a method to a CSV file.

    The row holds the current date and time, `method_name` and one accuracy
    per key of `accuracy_data`, formatted with two decimals. A header row is
    written first when the file does not exist yet or is empty. Errors are
    reported with a printed message instead of being raised.
    """
    try:
        # The dictionary keys are the column headers of the accuracy values
        dataset_names = list(accuracy_data.keys())

        # The header must match the data row: timestamp, method and one column per dataset
        header_row = ['Datetime', 'Method'] + dataset_names

        datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        data_row = [datetime_str, method_name]
        data_row += [f"{accuracy_data[dataset]:.2f}" for dataset in dataset_names]

        # The header is only needed when the file holds no content yet
        needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

        # Always append; `newline=''` prevents extra blank rows in CSV files
        with open(filename, 'a', newline='', encoding='utf-8') as csvfile:
            csv_writer = csv.writer(csvfile)
            if needs_header:
                csv_writer.writerow(header_row)
            csv_writer.writerow(data_row)

        print(f"✅ Accuracy data for method '{method_name}' successfully written to '{filename}'.")

    except IOError as e:
        print(f"❌ Error writing to file '{filename}': {e}")
    except Exception as e:
        print(f"❌ An unexpected error occurred: {e}")

## Execution

In [None]:
# Label of each individual = position of its ID in the dataframe
id_to_label = {doc_id: label for label, doc_id in enumerate(df[ID_KEY])}
# Perform all retrievals
retrievals = retrieval_df(df, retriever, k=RETRIEVER_K)

Retrievals for col_name='text':   0%|          | 0/553 [00:00<?, ?it/s]

Retrievals for col_name='St.NER3':   0%|          | 0/553 [00:00<?, ?it/s]

In [None]:
# Evaluate every column with the reader and store the top-1 accuracies
(
    rag_top_1_accuracies,
    rag_top_k_accuracies,
    rag_predictions,
) = eval_rag_linkage_df(
    df,
    RETRIEVER_K,
    id_to_label,
    retrievals,
    model_and_tokenizer,
    rag_prompt_template,
)
accuracies_to_csv(rag_top_1_accuracies, RAG_RISK_NAME, RESULTS_FILEPATH)