In [1]:
import sys

sys.path.append("..")

from methods.simple_rag import SimpleRAG
from dataloader.load_datasets import *
from helpers import llm
from helpers import pubmed
from tqdm import tqdm
import time
import os

In [2]:
df = pd.read_csv("zero_shot_results.csv")
df.head()

Unnamed: 0,model,method,claim,documents,answer
0,deepseek-r1:32b,zero_shot,40mg/day dosage of folic acid and 2mg/day dosa...,NAN,"<think>\nOkay, so I need to figure out whether..."
1,deepseek-r1:32b,zero_shot,A breast cancer patient's capacity to metaboli...,NAN,"<think>\nOkay, so I'm trying to figure out whe..."
2,deepseek-r1:32b,zero_shot,A deficiency of folate increases blood levels ...,NAN,"<think>\nOkay, so I need to determine if the c..."
3,deepseek-r1:32b,zero_shot,A diminished ovarian reserve does not solely i...,NAN,"<think>\nOkay, so I need to determine if the c..."
4,deepseek-r1:32b,zero_shot,A high microerythrocyte count protects against...,NAN,"<think>\nOkay, so I need to assess whether the..."


## Extract Classification from Model Responses

This function parses the model's response to determine if it classified the claim as SUPPORTED or CONTRADICT(ED).

In [3]:
import re


def extract_classification(response):
    """
    Extract whether the model classified the claim as SUPPORTED or CONTRADICT(ED).
    
    Args:
        response (str): The model's response text
        
    Returns:
        str: 'SUPPORTED', 'CONTRADICT', or 'UNKNOWN' if classification cannot be determined
    """
    if pd.isna(response) or response == "NAN":
        return "UNKNOWN"
    
    response = str(response).upper()
    
    # Strategy 1: Look for explicit "Final Answer:" pattern (common in CoT responses)
    final_answer_match = re.search(r'FINAL\s+ANSWER\s*:?\s*(\w+)', response)
    if final_answer_match:
        answer = final_answer_match.group(1)
        if 'SUPPORT' in answer:
            return 'SUPPORTED'
        elif 'CONTRADICT' in answer:
            return 'CONTRADICT'
    
    # Strategy 2: Look for "Answer:" pattern (common in zero-shot responses)
    answer_match = re.search(r'ANSWER\s*:?\s*(\w+)', response)
    if answer_match:
        answer = answer_match.group(1)
        if 'SUPPORT' in answer:
            return 'SUPPORTED'
        elif 'CONTRADICT' in answer:
            return 'CONTRADICT'
    
    # Strategy 3: Check after </think> tag (for models like deepseek-r1)
    if '</THINK>' in response:
        after_think = response.split('</THINK>')[-1].strip()
        # Look for SUPPORTED or CONTRADICT in the conclusion
        if 'SUPPORTED' in after_think[:500]:  # Check first 500 chars after </think>
            return 'SUPPORTED'
        elif 'CONTRADICT' in after_think[:500]:
            return 'CONTRADICT'
    
    # Strategy 4: Look for these keywords anywhere in the response
    # Count occurrences to handle cases where both appear
    supported_count = len(re.findall(r'\bSUPPORT(?:ED)?\b', response))
    contradict_count = len(re.findall(r'\bCONTRADICT(?:ED|S)?\b', response))
    
    # If one clearly dominates, use that
    if supported_count > contradict_count:
        return 'SUPPORTED'
    elif contradict_count > supported_count:
        return 'CONTRADICT'
    elif supported_count == contradict_count and supported_count > 0:
        # If equal, look at the last occurrence
        last_supported = response.rfind('SUPPORT')
        last_contradict = response.rfind('CONTRADICT')
        if last_supported > last_contradict:
            return 'SUPPORTED'
        else:
            return 'CONTRADICT'
    
    return "UNKNOWN"


# Test the function on the loaded data
df['classification'] = df['answer'].apply(extract_classification)

# Show distribution of classifications
print("\nClassification Distribution:")
print(df['classification'].value_counts())
print(f"\nTotal rows: {len(df)}")
print(f"Successfully classified: {len(df[df['classification'] != 'UNKNOWN'])}")
print(f"Unknown classifications: {len(df[df['classification'] == 'UNKNOWN'])}")



Classification Distribution:
classification
SUPPORTED     685
CONTRADICT    315
Name: count, dtype: int64

Total rows: 1000
Successfully classified: 1000
Unknown classifications: 0


In [8]:
claims = pd.read_csv("../dataloader/scifact_medical_causal_claims.csv")
claims.head()

Unnamed: 0.1,Unnamed: 0,id,claim,evidence_doc_id,evidence_label,evidence_sentences,cited_doc_ids,causal_result_raw,is_medical_causal
0,7,12,40mg/day dosage of folic acid and 2mg/day dosa...,33409100,SUPPORT,[8],[33409100],Yes,True
1,24,30,A breast cancer patient's capacity to metaboli...,24341590,SUPPORT,[10],[24341590],Yes,True
2,29,34,A deficiency of folate increases blood levels ...,11705328,SUPPORT,[4],[11705328],Yes,True
3,32,39,A diminished ovarian reserve does not solely i...,13497630,SUPPORT,[7],[13497630],Yes,True
4,42,41,A high microerythrocyte count protects against...,18174210,SUPPORT,[1 9],[18174210],Yes,True


In [6]:
df["classification"]

0       SUPPORTED
1       SUPPORTED
2       SUPPORTED
3       SUPPORTED
4       SUPPORTED
          ...    
995    CONTRADICT
996    CONTRADICT
997    CONTRADICT
998    CONTRADICT
999    CONTRADICT
Name: classification, Length: 1000, dtype: object

## Compare Classifications with Ground Truth

Now let's merge the predictions with the ground truth labels and evaluate performance

In [9]:
# Prepare ground truth labels
# The evidence_label column contains SUPPORT or CONTRADICT
claims['ground_truth'] = claims['evidence_label'].str.upper()

# Merge predictions with ground truth based on claim text
df_eval = df.merge(
    claims[['claim', 'ground_truth']], 
    on='claim', 
    how='left'
)

print("Merged data shape:", df_eval.shape)
print("\nGround truth distribution:")
print(df_eval['ground_truth'].value_counts())
print("\nSample merged data:")
df_eval[['claim', 'model', 'method', 'classification', 'ground_truth']].head(10)

Merged data shape: (1000, 7)

Ground truth distribution:
ground_truth
SUPPORT       500
CONTRADICT    500
Name: count, dtype: int64

Sample merged data:


Unnamed: 0,claim,model,method,classification,ground_truth
0,40mg/day dosage of folic acid and 2mg/day dosa...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
1,A breast cancer patient's capacity to metaboli...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
2,A deficiency of folate increases blood levels ...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
3,A diminished ovarian reserve does not solely i...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
4,A high microerythrocyte count protects against...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
5,A mutation in HNF4A leads to an increased risk...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
6,A single nucleotide variant the gene DGKK is s...,deepseek-r1:32b,zero_shot,CONTRADICT,SUPPORT
7,A strong bias in the phage genome locations wh...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
8,ALDH1 expression is associated with poorer pro...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT
9,AMP-activated protein kinase (AMPK) activation...,deepseek-r1:32b,zero_shot,SUPPORTED,SUPPORT


In [10]:
# Calculate accuracy for each model
# Filter out UNKNOWN classifications for fair comparison
df_eval_valid = df_eval[df_eval['classification'] != 'UNKNOWN'].copy()

# Check if prediction matches ground truth
df_eval_valid['correct'] = df_eval_valid['classification'] == df_eval_valid['ground_truth']

# Overall accuracy
overall_accuracy = df_eval_valid['correct'].mean()
print(f"Overall Accuracy (excluding UNKNOWN): {overall_accuracy:.2%}")
print(f"Total valid predictions: {len(df_eval_valid)}")
print(f"Correct predictions: {df_eval_valid['correct'].sum()}")
print(f"Incorrect predictions: {(~df_eval_valid['correct']).sum()}")


Overall Accuracy (excluding UNKNOWN): 24.90%
Total valid predictions: 1000
Correct predictions: 249
Incorrect predictions: 751


In [11]:
# Accuracy by model
print("\n" + "="*80)
print("ACCURACY BY MODEL")
print("="*80)

model_accuracy = df_eval_valid.groupby('model').agg({
    'correct': ['sum', 'count', 'mean']
}).round(4)

model_accuracy.columns = ['Correct', 'Total', 'Accuracy']
model_accuracy = model_accuracy.sort_values('Accuracy', ascending=False)
print(model_accuracy)



ACCURACY BY MODEL
                 Correct  Total  Accuracy
model                                    
qwen3:30b             79    200     0.395
deepseek-r1:32b       61    200     0.305
llama3.1:70b          45    200     0.225
mistral:7b            41    200     0.205
llama3.1:8b           23    200     0.115


In [12]:
# Accuracy by method (if multiple methods in the dataset)
if 'method' in df_eval_valid.columns and df_eval_valid['method'].nunique() > 1:
    print("\n" + "="*80)
    print("ACCURACY BY METHOD")
    print("="*80)
    
    method_accuracy = df_eval_valid.groupby('method').agg({
        'correct': ['sum', 'count', 'mean']
    }).round(4)
    
    method_accuracy.columns = ['Correct', 'Total', 'Accuracy']
    method_accuracy = method_accuracy.sort_values('Accuracy', ascending=False)
    print(method_accuracy)


In [13]:
# Accuracy by model AND method
print("\n" + "="*80)
print("ACCURACY BY MODEL AND METHOD")
print("="*80)

model_method_accuracy = df_eval_valid.groupby(['model', 'method']).agg({
    'correct': ['sum', 'count', 'mean']
}).round(4)

model_method_accuracy.columns = ['Correct', 'Total', 'Accuracy']
model_method_accuracy = model_method_accuracy.sort_values('Accuracy', ascending=False)
print(model_method_accuracy)



ACCURACY BY MODEL AND METHOD
                           Correct  Total  Accuracy
model           method                             
qwen3:30b       zero_shot       79    200     0.395
deepseek-r1:32b zero_shot       61    200     0.305
llama3.1:70b    zero_shot       45    200     0.225
mistral:7b      zero_shot       41    200     0.205
llama3.1:8b     zero_shot       23    200     0.115
