In [None]:
import pandas as pd

# Unpickle the MuSiQue SoT dataset and convert the stored object to a pandas
# DataFrame via its own .to_pandas() method.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
dataset_path = '/kaggle/input/building-musique-sots-dataset/musique-sots-dataset.pkl'
df = pd.read_pickle(dataset_path).to_pandas()
df.shape

In [None]:
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA

# Logger used by the later cells; created with a logfile name.
logger = Doraemon.get_logger(
    name=__name__,
    logfile="relaxed_FDA_on_MuSiQue_qa.log",
)

# One row per question, in order of first appearance (sort=False):
# every 'reason' value is collected into a list, while 'evidence' and
# 'ground_truth' keep the first value pandas finds for that question.
per_question = df.groupby('question', sort=False)
grouped = per_question.agg(
    reasoning_paths=('reason', list),
    evidence=('evidence', 'first'),
    ground_truth=('ground_truth', 'first'),
).reset_index()

In [None]:
import torch
from sketch_of_thought import SoT
from sentence_transformers import SentenceTransformer

# Ask the SoT helper which device to use and record the choice in the log.
device = SoT.get_device()
logger.info(device)

# Load the sentence encoder from the local model directory, then move it onto
# the selected device (the final expression is also the cell's displayed output).
encoder_path = '/kaggle/input/encoder-l6-v2/transformers/v0.1.4/1'
encoder = SentenceTransformer(encoder_path)
encoder.to(device)

In [None]:
# Sample row (positional index 3) used to inspect the data and, in the next
# cell, to classify the question type.
first = grouped.iloc[3]

# Log the sample's fields in a fixed order.
for column in ('question', 'evidence', 'reasoning_paths', 'ground_truth'):
    logger.info(first[column])

In [None]:
from typing import List, Dict, Tuple, Optional

# Classify the sample row (context + question) to decide which prompt to use.
sample_query = f"Context:{first.evidence}\nQuestion:{first.question}"
DATASET_TYPE = SoT.classify_question(sample_query)
logger.info(DATASET_TYPE)

# English system instruction registered for the detected type.
SYSTEM_INSTRUCTION_EN = SoT.get_prompts_en()[DATASET_TYPE]

# Chat message list, opened with the system instruction.
PROMT = [{"role": "system", "content": SYSTEM_INSTRUCTION_EN}]

# Few-shot demonstrations used when building intervention prompts.
# Each entry is a dict with four string fields:
#   question   - the question being asked
#   evidence   - the context passage the answer must come from
#   correct_rs - a reasoning sketch that ends in the intended \boxed{} answer
#   wrong_rs   - a flawed reasoning sketch that ends in a wrong \boxed{} answer
# construct_interven_prompt presents wrong_rs as the "provided" path and
# correct_rs as the "improved" path.
D: List[Dict] = [
    {
        "question": "Bancroft's county borders what county?",
        "evidence": "North Hastings High School (NHHS) is a high school located in Bancroft, Ontario, Canada serving students in the northern portion of Hastings County and part of the Hastings and Prince Edward District School Board. NHHS offers specialized 4-credit courses which allow students to learn principles of resource management and environmental studies, which help them to gain employment in resource-based careers. The York River is a river in Renfrew County, Hastings County, and Haliburton County in Ontario, Canada. The river is in the Saint Lawrence River drainage basin, and flows from the southern extension of Algonquin Provincial Park to the Madawaska River.",
        "correct_rs": "<think>\nLet’s think through this step by step\n#Bancroft → #Hastings_County → borders #Haliburton_County\n</think>\n\\boxed{Haliburton County}",
        "wrong_rs": "<think>\nLet’s think through this step by step\n#Bancroft → #Hastings_County → borders\n</think>\n\\boxed{Hastings}"
    },
    {
        "question": "What place does the administrative territorial entity that Juba is located in share a border with?",
        "evidence": "The Citizen is a newspaper based in Juba, the national capital of South Sudan and the state capital of Central Equatoria. The Badigeru swamp (or Bedigeru, Badingilu) swamp lies in South Sudan, in the Central Equatoria and Eastern Equatoria states between Terekeka and Lafon.",
        "correct_rs": "<think>\nLet’s think through this step by step\n#Juba → #Central_Equatoria → Badigeru swamp\nBadigeru swamp → Eastern Equatoria\n</think>\n\\boxed{Eastern Equatoria}",
        "wrong_rs": "<think>\nLet’s think through this step by step\n#Juba → #Central_Equatoria → #Badigeru_swamp\n</think>\n\\boxed{Badigeru swamp}"
    },
    {
        "question": "Who is the child of the person who followed Tihomir of Serbia?",
        "evidence": "In 1166 Stefan Nemanja overthrew Tihomir in a coup and had him and his brothers, Stracimir and Miroslav, expelled to Byzantium in 1167/1168. Stefan Nemanja defeated Tihomir and his Byzantine army. Tihomir drowned in a river and the other brothers were stripped of their titles, with Nemanja becoming ruler of All Serbia. He pardoned his brothers and Stracimir continued to rule his lands. When Stefan Nemanja besieged and retook control of Duklja in the 1180s, Stracimir and Miroslav attacked the forces of Doclean ruler and kinsman Mihailo. He is widely considered as one of the most important figures of Serbian history. Saint Sava is venerated by the Serbian Orthodox Church as its founder on. Many artistic works from the Middle Ages to modern times have interpreted his career. He is the patron saint of Serbia, Serbs, and Serbian education. The Church of Saint Sava in Belgrade is dedicated to him, built where the Ottomans burnt his remains in 1594 during an uprising in which the Serbs used icons of Sava as their war flags; the church is one of the largest church buildings in the world.",
        "correct_rs": "<think>\nLet’s think through this step by step\n#Tihomir → overthrown by #Stefan_Nemanja\n#Stefan_Nemanja → ruler of All Serbia\n#Stracimir → continued to rule his lands\n#Miroslav → attacked forces of #Mihailo\n#Stefan_Nemanja → widely considered important figure in Serbian history\n#Saint_Sava → venerated by Serbian Orthodox Church\n#Child of Stefan Nemanja → #Saint_Sava\n</think>\n\\boxed{Saint Sava}",
        "wrong_rs": "<think>\nLet’s think through this step by step\n#Stefan_Nemanja → #Stracimir\n</think>\n\\boxed{Stracimir}"
    },
    {
        "question": "What record label did the person who is part of The Bruce Lee Band start?",
        "evidence": "The Bruce Lee Band (or B. Lee Band) is the name given to the releases of Mike Park and his backing band which has so far included Less Than Jake and the Rx Bandits. Asian Man Records is a DIY record label run by Mike Park in Monte Sereno, California. Park started a record label and began releasing music in 1989 under the name Dill Records, with the Asian Man label established May 1996.",
        "correct_rs": "<think>\nLet’s think through this step by step\n#Mike_Park → #The_Bruce_Lee_Band → #Asian_Man_Records\n</think>\n\\boxed{Asian Man Records}",
        "wrong_rs": "<think>\nLet’s think through this step by step\n#The_Bruce_Lee_Band → Mike Park → #record_label\n#record_label → #Dill_Records → #Asian_Man_Records\n</think>\n\\boxed{Dill Records}"
    },
    {
        "question": "What award did the author of The Red Tree receive?",
        "evidence": "The Red Tree (2001), written and illustrated by Shaun Tan, is a picture book that presents a fragmented journey through a dark world. The illustrations are surreal. The text is sparse and matches the dark illustrations. The company's core business is in commercial and animation output, which includes work for Cartoon Network, music videos for Gorillaz, and the Compare the Market.com commercial campaign featuring Aleksandr Orlov (meerkat). Passion Australia produced 'The Lost Thing', directed by Andrew Ruhemann and Shaun Tan, which won an Academy Award for Best Animated Short Film in 2011.",
        "correct_rs": "<think>\nLet’s think through this step by step\n#The_Red_Tree → illustrated by Shaun Tan\n#Shaun_Tan → won Academy Award\n</think>\n\\boxed{Academy Award for Best Animated Short Film}",
        "wrong_rs": "<think>\nLet’s think through this step by step\n#The_Red_Tree → Shaun_Tan → Academy_Award\n</think>\n\\boxed{Best Animated Short Film}"
    }
]

# System prompt for the refinement step: asks for terse step-by-step reasoning
# that ends in a \boxed{...} answer. Each fragment keeps its own trailing space,
# so joining on the empty string yields one continuous paragraph.
OPTIMIZE_SYSTEM_PROMPT = "".join([
    "You are a concise and helpful assistant for conceptual chaining. ",
    "Provide step-by-step solutions using minimal tokens, ensuring accuracy. ",
    "Conclude with the final answer in the format: \\boxed{answer}. ",
    "When given a reasoning process, refine it to be both correct and succinct.",
])

In [7]:
import re

def construct_interven_prompt(selected_demos: List[Dict], r_k: str, test_question: str, test_evidence: str) -> str:
    """Build the few-shot prompt asking the model to refine one reasoning path.

    Args:
        selected_demos: Demonstrations to show before the test example. Each
            dict must provide the keys 'question', 'evidence', 'wrong_rs'
            (shown as the provided path) and 'correct_rs' (shown as the
            improved path).
        r_k: The reasoning path for the test question that should be improved.
        test_question: The question to answer.
        test_evidence: The context passage for the test question.

    Returns:
        The prompt text: one "Demo i" section per demonstration followed by
        the "Test Example" section with [improved_rs] / [answer] placeholders.
    """
    sections = []

    for i, demo in enumerate(selected_demos, start=1):
        sections.append(
            f"Demo {i}:\n"
            f"Q: The question is: {demo['question']}\n"
            f"E: The context is: {demo['evidence']}\n"
            f"The provided reasoning path is: {demo['wrong_rs']}\n"
            f"A: The improved reasoning path is: {demo['correct_rs']}\n"
            # The demo answer is extracted from the correct reasoning sketch.
            f"Therefore, the correct answer is: {RelaxedFDA.get_answer(demo['correct_rs'])}\n\n"
        )

    # Append the test question. The labels mirror the demo sections exactly so
    # the model sees one consistent format; "<think>" is followed by a real
    # newline (the original emitted a literal "/n").
    sections.append(
        "Test Example:\n"
        f"Q: The question is: {test_question}\n"
        f"E: The context is: {test_evidence}\n"
        "<think>\nLet us think step by step.\n"
        f"The provided reasoning path is: {r_k}\n"
        "A: The improved reasoning path is: [improved_rs]\n"
        "Therefore, the correct answer is: [answer]\n"
    )

    return "".join(sections)

def normalize(s):
    """Return a canonical form of ``s`` for answer comparison.

    The value is converted with ``str()``, trimmed of leading/trailing
    whitespace and lowercased; then every character that is neither a word
    character nor whitespace (punctuation, quotes, ...) is deleted.
    """
    text = str(s).strip().lower()
    return re.sub(r'[^\w\s]', '', text)

def is_substring_match(a:str, b:str)->bool:
    """Return True if one normalized string contains the other.

    Both inputs are passed through ``normalize`` first. An empty normalized
    string is only considered a match for another empty string: Python treats
    "" as a substring of every string, which would otherwise make an empty
    prediction match any ground truth.
    """
    na = normalize(a)
    nb = normalize(b)
    if not na or not nb:
        # Avoid the vacuous `"" in x` match; two empty strings still agree.
        return na == nb
    return na in nb or nb in na

def extract_match(a, b, threshold=0.5):
    """
    Token-overlap test between two answers.

    Both strings are normalized and split on whitespace into sets of unique
    tokens. Returns True when the number of shared tokens, divided by the size
    of the smaller token set, is at least ``threshold``. Returns False when
    either side has no tokens.
    """
    vocab_a = set(normalize(a).split())
    vocab_b = set(normalize(b).split())

    smaller = min(len(vocab_a), len(vocab_b))
    if not smaller:
        return False

    shared = len(vocab_a & vocab_b)
    return shared / smaller >= threshold

def enhanced_exact_match(predicted: str, ground_truth: str, token_threshold: float = 0.5) -> bool:
    """
    Lenient answer comparison.

    Checks, in order and stopping at the first success: equality after
    normalization, containment of one normalized answer in the other, and
    token overlap of at least ``token_threshold`` (see ``extract_match``).
    Returns False only when all three checks fail.
    """
    return (
        normalize(predicted) == normalize(ground_truth)
        or is_substring_match(predicted, ground_truth)
        or extract_match(predicted, ground_truth, token_threshold)
    )


# Precision: The ratio of overlapping tokens to the total number of tokens in the predicted answer.
# Recall: The ratio of overlapping tokens to the total number of tokens in the ground truth answer.
def f1_score_custom(predicted:str, ground_truth:str):
    """Compute the token-level F1 score between a prediction and the ground truth.

    Both strings are normalized and split on whitespace. The overlap is counted
    as a multiset intersection, so a repeated token is credited at most as many
    times as it occurs on both sides.

    Returns:
        The harmonic mean of precision and recall, or 0.0 when the two answers
        share no token (which also covers an empty prediction or ground truth).
    """
    # Imported here so the function does not depend on a later notebook cell
    # having already imported Counter.
    from collections import Counter

    pred_tokens = normalize(predicted).split()
    gt_tokens = normalize(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_common = sum(common.values())
    if num_common == 0:
        # No overlap; returning early also avoids dividing by an empty token list.
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(gt_tokens)
    return 2 * (precision * recall) / (precision + recall)

In [None]:
from collections import Counter

# Running totals. `total` counts only rows that produced a prediction and have
# a ground truth; rows without either are tallied in `skipped` and do not
# enter the EM / F1 averages.
correct_em = 0
total = 0
f1_total = 0.0
skipped = 0

for _, row in grouped.iterrows():
    q = row['question']
    r_p = row['reasoning_paths']
    evidence = row['evidence']
    ground_truth = row['ground_truth']

    # Run the relaxed-FDA inference for this question over its candidate
    # reasoning paths, using D as the pool of demonstrations.
    final_answer, o_p = RelaxedFDA.causal_infer_with_fdr(
        q,
        evidence,
        logger,
        construct_interven_prompt,
        OPTIMIZE_SYSTEM_PROMPT,
        encoder,
        r_p,
        D)

    logger.info(f"The final answer: {final_answer}")

    if final_answer is None or ground_truth is None:
        # Nothing to score; keep count so the reduced denominator is visible.
        logger.warning(f"Missing data: predicted={final_answer}, expected={ground_truth}")
        skipped += 1
        continue

    final_answer_clean = RelaxedFDA.get_answer(final_answer)
    logger.info(f"Question is {q}")
    logger.info(f"Final answer is {str(final_answer_clean)}, ground truth is {ground_truth}")

    # Lenient match: equality, containment, or token overlap after normalization.
    if enhanced_exact_match(final_answer_clean, ground_truth):
        correct_em += 1
    else:
        logger.warning(f"Mismatch: predicted={final_answer_clean}, expected={ground_truth}")

    f1_total += f1_score_custom(final_answer_clean, ground_truth)
    total += 1

if skipped:
    logger.warning(f"Skipped {skipped} row(s) with missing data; they are excluded from EM/F1.")

In [None]:
# Final metrics over the scored rows; both default to 0.0 when nothing was scored.
if total > 0:
    em_score = correct_em / total
    avg_f1 = f1_total / total
else:
    em_score = 0.0
    avg_f1 = 0.0

# Report both metrics as percentages with two decimals.
logger.info(f"Exact Match (EM): {em_score:.2%}")
logger.info(f"Average F1 Score: {avg_f1:.2%}")