In [None]:
import ast

import pandas as pd
import torch
from torch import tensor
import torch.nn.functional as F

from src.ddx_data_gen.json_extraction import load_schemes_and_labelspace
from src.ddx_data_gen.prompt_args import PromptArgs
from src.eval.pipeline import extract_json_and_pred_from_text
from src.utils import init_notebook, convert_codes_to_short_codes, load_sbert_model
from src.wandb.data_loader import load_cross_validation_patients
from src.wandb.run import init_wandb
from src.exp_args import ExpArgs

init_notebook()

from src.eval.classification_metrics import (calculate_icd_metrics, calculate_disease_metrics,
                                             get_valid_json_pct, summarize_cv_results,
                                             calculate_in_domain_score)


%load_ext autoreload
%autoreload 2

In [None]:
# Open the W&B run (evaluation mode); it is passed to load_cross_validation_patients
# when the result frames are fetched further down.
wandb_run = init_wandb('Evaluation', eval_mode=True)

# Cache of loaded result frames, keyed by experiment name.
# It is created only when missing: re-running this cell keeps everything that was
# already downloaded, while a fresh kernel (Restart & Run All) no longer fails with
# a NameError in the cells that read or fill `all_dfs`.
if 'all_dfs' not in globals():
    all_dfs = {}

# Separate prompt-argument objects for the in-domain test set and the OOD set.
prompt_args = PromptArgs()
ood_prompt_args = PromptArgs()
exp_args = ExpArgs()

test_patients = pd.read_parquet('data/reasoning/abdominal_pain/test_dataset.pq')
ood_patients = pd.read_parquet('data/reasoning/patients_ood_700.pq')

# The return values are unused, so these calls presumably populate the prompt-args
# objects in place -- TODO confirm against load_schemes_and_labelspace.
load_schemes_and_labelspace(test_patients, prompt_args, exp_args)
load_schemes_and_labelspace(ood_patients, ood_prompt_args, exp_args)

# Sentence-embedding model handed to extract_json_and_pred_from_text.
mapping_model = load_sbert_model()

In [None]:
# Qwen3-0.6B experiments: display name -> name of the stored evaluation result.
# In this dict every key ending in "-G" maps to a "guid_decoding" result name.
# NOTE(review): values look like W&B result/artifact names -- confirm against the loader.
base_model = "eval_results_Qwen3-0.6B"
names_06 = {
    f"{base_model}" : f"{base_model}-think-cc64",
    f"{base_model}-G" : f"{base_model}-think-guid_decoding-cc64",
    f"{base_model}-Mimic" : f"{base_model}-lora-1mimic-think-cc64",
    f"{base_model}-Mimic-G" : f"{base_model}-lora-1mimic-guid_decoding-cc64",
    f"{base_model}-Lora" : f"{base_model}-lora-1epoch-think-cc64",
    f"{base_model}-Lora-G" : f"{base_model}-lora-1epoch-think-guid_decoding-cc64",
    f"{base_model}-2Mimic" : f"{base_model}-lora-2mimic-cc64",
    f"{base_model}-2Mimic-G" : f"{base_model}-lora-2mimic-guid_decoding-cc64",
    f"{base_model}-2Lora" : f"{base_model}-lora-2epoch-cc32",
    f"{base_model}-2Lora-G" : f"{base_model}-lora-2epoch-guid_decoding-cc64",
}

# Qwen3-8B experiments: display name -> name of the stored evaluation result.
# `base_model` is re-assigned here, so the f-strings below expand with the 8B prefix.
# NOTE(review): the keys use "2mimic" / "2-Lora" while names_06 uses "2Mimic" / "2Lora";
# the spelling ends up in the experiment keys -- confirm whether the difference is intended.
base_model = "eval_results_Qwen3-8B"
names_8 = {
    f"{base_model}" : f"{base_model}-think-cc32",  # Works!
    f"{base_model}-G" : f"{base_model}-think-guid_decoding-cc32",  # Works!
    f"{base_model}-Mimic" : f"{base_model}-lora-1mimic-cc32",  # Works!
    f"{base_model}-Mimic-G" : f"{base_model}-lora-1mimic-guid_decoding-cc32",  # Works!
    f"{base_model}-Lora" : f"{base_model}-lora-1epoch-cc32",
    f"{base_model}-Lora-G" : f"{base_model}-lora-1epoch-guid_decoding-cc32", # Works!
    f"{base_model}-2mimic" : f"{base_model}-lora-2mimic-cc32",  # Works!
    f"{base_model}-2mimic-G" : f"{base_model}-lora-2mimic-guid_decoding-cc32",  # Works!
    f"{base_model}-2-Lora" : f"{base_model}-lora-2epoch-cc32",  # Works!
    f"{base_model}-2-Lora-G" : f"{base_model}-lora-2epoch-guid_decoding-cc64",  # Works!
}
# Qwen3-14B experiments: display name -> stored evaluation result name(s).
# Some values are lists of two result names; how several names are combined is decided
# by the loader, which is not visible in this notebook.
base_model = "eval_results_Qwen3-14B"
names_14 = {
    f"{base_model}" : [f"{base_model}-cc32", f"{base_model}-think-cc32"],  # Works!
    f"{base_model}-G" : f"{base_model}-guid_decoding-cc32",  # Works!
    f"{base_model}-Lora" : [f"{base_model}-lora-1epoch-cc32", f"{base_model}-lora-1epoch-think-cc32"],
    f"{base_model}-Lora-G" : f"{base_model}-lora-1epoch-guid_decoding-cc32",
    # NOTE(review): the two "2Lora" result names contain a double hyphen ("lora--2lora").
    f"{base_model}-2Lora" : f"{base_model}-lora--2lora-cc32",
    f"{base_model}-2Lora-G" : f"{base_model}-lora--2lora-think-guid_decoding-cc32",  # Wrong name. Run no thinking
    f"{base_model}-3Lora" : f"{base_model}-lora-3epoch-think-cc32",
    f"{base_model}-3Lora-G" : f"{base_model}-lora-3epoch-guid_decoding-cc32",  
}
# Qwen3-32B experiments: display name -> stored evaluation result name(s).
# The "2Lora" entry lists two result names (cc32 and cc64) for one experiment.
base_model = "eval_results_Qwen3-32B"
names_32 = {
    f"{base_model}" : f"{base_model}-cc32",  # Works!
    f"{base_model}-G" : f"{base_model}-guid_decoding-cc16",  # Works!
    f"{base_model}-Lora" : f"{base_model}-lora-1epoch-think-cc32",
    f"{base_model}-Lora-G" : f"{base_model}-lora-1epoch-guid_decoding-cc32",
    f"{base_model}-2Lora" : [f"{base_model}-lora-2epoch-cc32", f"{base_model}-lora-2epoch-cc64"],
    f"{base_model}-2Lora-G" : f"{base_model}-lora-2epoch-guid_decoding-cc64",
}

# Non-Qwen baselines: display name -> stored evaluation result name(s).
# Unlike the Qwen dicts, the keys here do not carry the "eval_results_" prefix.
other_base_models = {
    "MedReason-8B": "eval_results_MedReason-8B-cc32",
    "MedReason-8B-G": "eval_results_MedReason-8B-guid_decoding-cc32",
    "Medgemma-27b": "eval_results_medgemma-27b-text-it-cc32",
    "Medgemma-27b-G": ["eval_results_medgemma-27b-text-it-guid_decoding-cc16", "eval_results_medgemma-27b-text-it-guid_decoding-cc32"],
    "Llama-3.3-70B": "eval_results_Llama-3.3-70B-Instruct-cc16",
    "Llama-3.3-70B-G": "eval_results_Llama-3.3-70B-Instruct-guid_decoding-cc32",
}
# Out-of-distribution (OOD) Qwen3-8B experiments: display name -> stored result name(s).
# Every key contains "OOD"; the download loop uses that substring to pick the OOD prompt args.
base_model = "eval_results_Qwen3-8B-OOD"
ood_models = {
    "Qwen3-8B-OOD": f"{base_model}-cc32",
    # NOTE(review): both list entries below are the same result name -- confirm whether
    # a second, different run (e.g. another cc setting) was meant.
    "Qwen3-8B-OOD-G": [f"{base_model}-guid_decoding-cc32", f"{base_model}-guid_decoding-cc32"],
    "Qwen3-8B-OOD-Mimic": f"{base_model}-lora-1mimic-cc32",
    "Qwen3-8B-OOD-Mimic-G": f"{base_model}-lora-1mimic-guid_decoding-cc32",
    "Qwen3-8B-OOD-Lora": f"{base_model}-lora-1epoch-cc32",
    "Qwen3-8B-OOD-Lora-G": f"{base_model}-lora-1epoch-guid_decoding-cc32",
    "Qwen3-8B-OOD-2Mimic": f"{base_model}-lora-2mimic-cc32",
    "Qwen3-8B-OOD-2Mimic-G": f"{base_model}-lora-2mimic-guid_decoding-cc32",
    "Qwen3-8B-OOD-2Lora": f"{base_model}-lora-2epoch-cc32",
    "Qwen3-8B-OOD-2Lora-G": f"{base_model}-lora-2epoch-guid_decoding-cc32",
}

In [None]:
# Persist up to three loaded Qwen3-8B base result frames as local parquet files.
# The entry is looked up defensively: `all_dfs` only holds what the download loop has
# loaded so far, so indexing `all_dfs[key][i]` blindly raises KeyError/IndexError when
# the experiment is missing or has fewer than three runs.
qwen3_8b_base_runs = all_dfs.get('eval_results_Qwen3-8B', [])[:3]
if len(qwen3_8b_base_runs) < 3:
    print(f"Only {len(qwen3_8b_base_runs)} of 3 'eval_results_Qwen3-8B' runs are loaded; exporting those.")
for i, df in enumerate(qwen3_8b_base_runs):
    df.to_parquet(f'data/results/qwen3_8b/base/qwen3_8b_base_{i}.pq')

In [None]:
# Download and parse the result frames of every experiment in `ood_models`,
# storing the list of frames per experiment in `all_dfs`.
# FIXME: ensure 3 runs exist for every experiment you want to report.
required_runs = 3

for name, file in ood_models.items():
    # Experiments that already have enough frames cached are not fetched again.
    if len(all_dfs.get(name, [])) >= required_runs:
        print(f'Skip {name} already downloaded')
        continue

    # `file` is a single result name or a list of names (see the dicts above).
    try:
        dfs = load_cross_validation_patients(wandb_run, file)
    except Exception as e:
        # Best effort: a failing experiment must not stop the remaining downloads,
        # but the message names the experiment so the failure can be traced.
        print(f'Failed to load {name} ({file}): {e!r}')
        continue

    for df in dfs:
        if 'OOD' not in name:
            # In-domain results: set the chief complaint column and use the in-domain
            # prompt args. (No key of `ood_models` takes this branch; it matters when
            # the loop is pointed at one of the other experiment dicts.)
            df['Chief Complaint'] = 'abdominal_pain'
            extract_json_and_pred_from_text(df, prompt_args, mapping_model)
        else:
            extract_json_and_pred_from_text(df, ood_prompt_args, mapping_model)

    # The return value of extract_json_and_pred_from_text is unused, so the frames are
    # presumably enriched in place before being cached -- TODO confirm.
    all_dfs[name] = dfs

In [None]:
def cosine_similarity_torch(prediction, label, scale_to_unit=True, device=None):
    """Return the mean row-wise cosine similarity between two batches of vectors.

    Args:
        prediction: Nested sequence / array convertible by ``torch.tensor``; the
            similarity is taken along ``dim=1``, so at least two dimensions are needed
            (rows are samples).
        label: Same layout as ``prediction``.
        scale_to_unit: If True, map each similarity from [-1, 1] to [0, 1] via
            ``(cos + 1) / 2`` before averaging.
        device: Torch device to compute on. ``None`` (default) picks ``'mps'`` when it
            is available, otherwise ``'cuda'``, otherwise ``'cpu'``. Previously the
            device was hard-coded to ``'mps'``, which fails on machines without it.

    Returns:
        The mean similarity over all rows as a Python float.
    """
    if device is None:
        # `torch.backends.mps` does not exist on older torch builds, hence getattr.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            device = 'mps'
        elif torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

    # Convert to float32 tensors on the selected device
    prediction = tensor(prediction, dtype=torch.float32, device=device)
    label = tensor(label, dtype=torch.float32, device=device)

    # Compute cosine similarity along the vector dimension (one value per row)
    cos_sim = F.cosine_similarity(prediction, label, dim=1)

    # Optionally scale from [-1, 1] to [0, 1]
    if scale_to_unit:
        cos_sim = (cos_sim + 1) / 2

    return cos_sim.mean().item()


def calculate_metrics(df: pd.DataFrame, generation=False) -> dict:
    """Collect the evaluation scores of one result frame in a flat dict.

    Args:
        df: Result frame. Reads the columns ``v2_preds``, ``disease``, ``ICD_CODES`` and
            ``v4_preds``; with ``generation=True`` also ``v1_score``, ``v1_preds``,
            ``disease_vector`` and ``v3_preds``.
        generation: When True, the V1 and V3 scores are added as well.

    Returns:
        Dict with ``V2_MRR``, the entries returned by ``calculate_icd_metrics`` and
        ``Valid JSON``; in generation mode additionally ``old_V1``, ``V1_CosSim`` and
        ``V3_VRR``.
    """
    scores = {}

    if generation:
        scores['old_V1'] = df['v1_score'].mean()
        # v1_preds holds Python literals as text; parse them back before comparing
        # against the label vectors.
        parsed_v1_preds = [ast.literal_eval(raw_pred) for raw_pred in df['v1_preds']]
        label_vectors = df['disease_vector'].to_list()
        scores['V1_CosSim'] = cosine_similarity_torch(parsed_v1_preds, label_vectors)

    scores['V2_MRR'] = calculate_disease_metrics(df['v2_preds'], df['disease'])

    if generation:
        # NOTE(review): key is spelled "VRR" while the V2 key says "MRR" -- confirm the name.
        scores['V3_VRR'] = calculate_disease_metrics(df['v3_preds'], df['disease'])

    # ICD labels are converted with convert_codes_to_short_codes before scoring.
    short_code_labels = df['ICD_CODES'].map(convert_codes_to_short_codes).to_list()
    icd_scores = calculate_icd_metrics(df['v4_preds'].to_list(), short_code_labels)
    scores.update(icd_scores)

    scores['Valid JSON'] = get_valid_json_pct(df)
    # Disabled: scores.update({'In Domain ICDs': calculate_in_domain_score(df)})
    return scores


def calculate_mean_metrics(experiments: dict, data: dict, std=False) -> pd.DataFrame:
    """Build one summary table over all runs of every requested experiment.

    Args:
        experiments: Mapping whose keys are experiment names; only the keys are used.
        data: Mapping from experiment name to the list of result frames of its runs.
            Experiments that have no entry in ``data`` are skipped silently.
        std: Forwarded unchanged to ``summarize_cv_results``.

    Returns:
        The per-experiment summaries stacked along axis 0 and rounded to 3 decimals;
        an empty DataFrame when none of the experiments is present in ``data``.
    """
    summaries = []
    for experiment_name in experiments:
        if experiment_name not in data:
            continue
        # One metrics dict per run of this experiment.
        run_metrics = [calculate_metrics(df) for df in data[experiment_name]]
        summaries.append(summarize_cv_results(run_metrics, experiment_name, std))

    if not summaries:
        # pd.concat refuses an empty list; keep returning an empty frame instead.
        return pd.DataFrame()

    # Concatenate once instead of growing a frame inside the loop: no repeated copying
    # and no concat against an empty placeholder frame.
    # Assumes summarize_cv_results returns a DataFrame -- TODO confirm.
    return pd.concat(summaries, axis=0).round(3)

In [None]:
# Summary table for the non-Qwen baseline models, with the std option enabled.
calculate_mean_metrics(other_base_models, all_dfs, std=True)

In [None]:
# Summary table for the out-of-distribution experiments, with the std option enabled.
calculate_mean_metrics(ood_models, all_dfs, std=True)

In [None]:
# Summary table for the Qwen3-0.6B experiments, with the std option enabled.
calculate_mean_metrics(names_06, all_dfs, std=True)

In [None]:
# Summary table for the Qwen3-8B experiments, with the std option enabled.
calculate_mean_metrics(names_8, all_dfs, std=True)

In [None]:
# Summary table for the Qwen3-14B experiments, with the std option enabled.
calculate_mean_metrics(names_14, all_dfs, std=True)

In [None]:
# Summary table for the Qwen3-32B experiments, with the std option enabled.
calculate_mean_metrics(names_32, all_dfs, std=True)

In [None]:
# Load the stored generation datasets, one parquet file per generator model.
gen_dfs = {
    model_name: pd.read_parquet(parquet_path)
    for model_name, parquet_path in (
        ('Qwen3-32B', 'data/results/generation/dataset_3055_qwen3_32b.pq'),
        ('Llama-3.3-70B', 'data/results/generation/dataset_3055_Llama_3.3_70B_Instruct.pq'),
        ('Medgemma-27b', 'data/results/generation/dataset_3055_medgemma_27b_text_it.pq'),
    )
}

# A dedicated PromptArgs object for the generation data; the Qwen3-32B frame is the
# one handed to load_schemes_and_labelspace.
gen_prompt_args = PromptArgs()
load_schemes_and_labelspace(gen_dfs['Qwen3-32B'], gen_prompt_args, exp_args)

# Set the chief complaint column on every frame, then run the JSON/prediction
# extraction with its fourth positional argument set to True.
for name, df in gen_dfs.items():
    df['Chief Complaint'] = 'abdominal_pain'
    extract_json_and_pred_from_text(df, gen_prompt_args, mapping_model, True)


In [None]:
# Score every generation dataset with the generation-only metrics switched on,
# then show one row per model rounded to 3 decimals.
gen_metrics = {
    model_name: calculate_metrics(model_df, generation=True)
    for model_name, model_df in gen_dfs.items()
}
pd.DataFrame(gen_metrics).T.round(3)

In [None]:
# Same table as the previous cell, shown with 4 decimals.
# `gen_metrics` was computed in the previous cell and holds unrounded values, so it is
# reused here instead of re-running calculate_metrics on every generation frame.
pd.DataFrame(gen_metrics).T.round(4)