## Initialization

In [1]:
import logging
import os
import textwrap
from dataclasses import asdict, dataclass
from datetime import date, datetime
from pathlib import Path

import altair as alt
import numpy as np
import polars as pl
import torch
from polars import col as c
from polars import lit
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
)
from sklearn.preprocessing import label_binarize
from tqdm import tqdm

# Use the "vegafusion" data transformer for Altair charts.
alt.data_transformers.enable("vegafusion")

# Configure logging to show timestamp, log level and message.
# The hour uses %I (12-hour clock) because the format ends with %p (AM/PM);
# pairing %p with the 24-hour %H would print timestamps such as "16:05:00 PM".
logging.basicConfig(
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %I:%M:%S %p",
)

# set the number of rows to display
pl.Config.set_tbl_rows(1000)

# Record where the notebook is running; the data paths used below are relative.
work_dir = os.getcwd()
logging.info("Working directory: %s", work_dir)

DataTransformerRegistry.enable('vegafusion')

polars.config.Config

[2025-02-03 04:11:39 AM] [INFO] Working directory: /home/yu/chaoyang/projects/Call/call/code/v4/reproduce-finetune


## Benchmarking

In [2]:
# define a function to evaluate the classification performance

def evaluate_classification(yt, models, n_classes=5):
    """
    Evaluate classification performance with multiple metrics for multiple models.
    Returns results as a polars DataFrame.

    For each model, the rows of `yt` belonging to it are selected, rows whose
    prediction is missing (NaN after conversion to numpy) are dropped, and
    overall and per-class metrics are computed from the remaining hard labels.

    Args:
        yt (pl.DataFrame): DataFrame containing predictions and ground truth.
            Columns read: "model", "y" (predicted label) and "t" (true label).
        models (list): List of model names to evaluate
        n_classes (int): Number of classes (default=5). Labels are taken to be
            the integers 1..n_classes.

    Returns:
        pl.DataFrame: One row per computed metric, with the columns "model",
        "metric_type" ("overall" or "per_class"), "metric", "value" and
        "class" (None for overall metrics).
    """
    class_labels = range(1, n_classes + 1)
    results = []

    def add_row(model, metric_type, metric, value, class_num=None):
        # Every metric is stored as one long-format row with identical keys.
        results.append({
            "model": model,
            "metric_type": metric_type,
            "metric": metric,
            "value": value,
            "class": class_num,
        })

    for model in models:
        data = yt.filter(pl.col("model") == model)
        t = np.array(data['t'])
        y = np.array(data['y'])

        # Remove null values and corresponding targets
        valid_mask = ~np.isnan(y)
        y = y[valid_mask]
        t = t[valid_mask]

        # number of non-null predictions
        n_instances = len(y)

        # Overall accuracy. Micro-F1 is reported as the same quantity
        # (share of exact matches between target and prediction).
        accuracy = np.mean(t == y)
        micro_f1 = accuracy

        # For micro-AUC: binarize labels for all classes at once.
        # Note the "scores" are one-hot hard predictions, not probabilities.
        y_bin = label_binarize(y, classes=class_labels)
        t_bin = label_binarize(t, classes=class_labels)
        micro_auc = roc_auc_score(t_bin, y_bin, average='micro')

        f1 = []
        support = []
        auc_scores = []
        accuracies = []

        for class_idx in class_labels:
            # Per-class F1 and support directly from the multi-class labels
            _, _, f1_score, sup = precision_recall_fscore_support(
                t, y, labels=[class_idx], average=None
            )
            f1.append(float(f1_score[0]))
            support.append(sup[0])

            # Per-class "accuracy": share of instances whose target is this
            # class that were predicted correctly (0.0 if the class is absent).
            class_mask = t == class_idx
            class_accuracy = (
                np.mean(y[class_mask] == t[class_mask]) if class_mask.any() else 0.0
            )
            accuracies.append(class_accuracy)

            # One-vs-rest AUC from the hard predictions. roc_auc_score raises
            # ValueError when the binary target contains a single class only;
            # that case is recorded as NaN instead of aborting the evaluation.
            t_binary = (t == class_idx).astype(int)
            y_scores = (y == class_idx).astype(float)
            try:
                auc = roc_auc_score(t_binary, y_scores)
            except ValueError:
                auc = np.nan
            auc_scores.append(auc)

        # Macro metrics: unweighted means over classes (undefined AUCs skipped)
        macro_f1 = np.mean(f1)
        macro_auc = np.nanmean(auc_scores)

        # Add overall metrics
        add_row(model, "overall", "n_instances", n_instances)
        add_row(model, "overall", "accuracy", accuracy)
        add_row(model, "overall", "macro_f1", macro_f1)
        add_row(model, "overall", "micro_f1", micro_f1)
        add_row(model, "overall", "macro_auc", macro_auc)
        add_row(model, "overall", "micro_auc", micro_auc)

        # Add per-class metrics
        per_class = zip(class_labels, f1, auc_scores, accuracies, support)
        for class_num, class_f1, class_auc, class_acc, class_sup in per_class:
            add_row(model, "per_class", "f1", class_f1, class_num)
            add_row(model, "per_class", "auc", class_auc, class_num)
            add_row(model, "per_class", "accuracy", class_acc, class_num)
            add_row(model, "per_class", "support", float(class_sup), class_num)

    # Convert to polars DataFrame
    results_df = pl.DataFrame(results)
    return results_df

### Ours

In [10]:
# evaluate

def evaluate_one_window(window):
    """
    Build the prediction/target table for one test window and score it.

    Reads `data/test_{window}.feather`, `data/train_{window}.feather` and
    `data/cutoff.feather`, labels the test-window values of
    `y_car_c5_call_0_21_std` with their quintile (1-5) and evaluates them
    against the `rank` column under the model name "ours".

    Args:
        window (str): Window identifier used in the file names, e.g. "22q1".

    Returns:
        tuple: `(yt, results_df)` where `yt` holds docid_idx, y, t, y_num,
        t_num, model and split_id, and `results_df` is the output of
        `evaluate_classification` for this window.
    """
    pred_col = c.y_car_c5_call_0_21_std
    true_col = c.t_car_c5_call_0_21_std

    # get the X (features) and t (target)
    test_data = pl.read_ipc(f"data/test_{window}.feather", memory_map=False)
    train_data = pl.read_ipc(f"data/train_{window}.feather", memory_map=False)

    # table holding the numeric y / t values for every document
    y_all = pl.read_ipc("data/cutoff.feather", memory_map=False)

    # Quintile break points of y over the documents of the training window.
    # NOTE(review): `cutoff` is not used by the active code below (only by the
    # alternative `.cut(cutoff, ...)` labelling), and dropping "the last" bin
    # label relies on the order returned by `.unique()` — confirm before
    # switching to it.
    train_y = (
        y_all.select(c.docid_idx, y=pred_col, t=true_col)
        .join(train_data, on="docid_idx", how="semi")
    )["y"]
    bin_labels = train_y.qcut(5).unique().to_list()
    cutoff = sorted(
        float(label.split(",")[1].strip("]")) for label in bin_labels[:-1]
    )

    # get the yt: y is the quintile of the prediction within the test window
    # (alternative: pred_col.cut(cutoff, labels=[...]) with the train cut-offs)
    test_rows = y_all.join(test_data, on="docid_idx")
    yt = test_rows.select(
        c.docid_idx,
        y=pred_col.qcut(5, labels=["1", "2", "3", "4", "5"]).cast(pl.Int32),
        t=c.rank,
        y_num=pred_col,
        t_num=true_col,
    ).with_columns(model=lit("ours"), split_id=lit(window))

    # evaluate the performance
    results_df = evaluate_classification(yt, ["ours"])

    return yt, results_df

def evaluate_all_windows(windows):
    """
    Evaluate every window and average the metrics across windows.

    Args:
        windows (iterable): Window identifiers passed one by one to
            `evaluate_one_window`.

    Returns:
        tuple: `(yt_all, results_df)` where `yt_all` is the concatenation of
        the per-window prediction tables and `results_df` has one row per
        (model, class, metric) with the across-window mean rounded to three
        decimals. The count metrics "n_instances" and "support" are scaled by
        the number of windows so that they are totals rather than means.
    """
    # Materialise once: the number of windows is needed after the loop.
    windows = list(windows)
    n_windows = len(windows)

    yt_all = []
    results_df = []
    for split_id in windows:
        yt, results_df_window = evaluate_one_window(split_id)
        results_df.append(results_df_window)
        yt_all.append(yt)

    results_df = pl.concat(results_df)
    yt_all = pl.concat(yt_all)

    # Counts are turned back into totals (mean * number of windows); the
    # factor follows the input instead of being fixed at 4.
    is_count_metric = c.metric.is_in(["n_instances", "support"])

    # average across windows
    results_df = (
        results_df
        .group_by(c.model, c('class'), c.metric)
        .agg(value=c.value.mean().round(3))
        .with_columns(
            value=pl.when(is_count_metric)
            .then(c.value * n_windows)
            .otherwise(c.value)
        )
        .sort(['model', 'class', 'metric'])
    )

    return yt_all, results_df


# Test windows to evaluate; each identifier selects the data/test_{window}.feather
# and data/train_{window}.feather files read by evaluate_one_window.
windows = ['22q1', '22q2', '22q3', '22q4']
yt_all, results_df = evaluate_all_windows(windows)

# show overall metrics (rows without a class)
results_df.filter(c('class').is_null())

# show per-class metrics (rows with a class)
results_df.filter(c('class').is_not_null())


model,class,metric,value
str,i64,str,f64
"""ours""",,"""accuracy""",0.262
"""ours""",,"""macro_auc""",0.539
"""ours""",,"""macro_f1""",0.261
"""ours""",,"""micro_auc""",0.539
"""ours""",,"""micro_f1""",0.262
"""ours""",,"""n_instances""",3867.0


model,class,metric,value
str,i64,str,f64
"""ours""",1,"""accuracy""",0.336
"""ours""",1,"""auc""",0.584
"""ours""",1,"""f1""",0.332
"""ours""",1,"""support""",755.0
"""ours""",2,"""accuracy""",0.22
"""ours""",2,"""auc""",0.513
"""ours""",2,"""f1""",0.212
"""ours""",2,"""support""",697.0
"""ours""",3,"""accuracy""",0.208
"""ours""",3,"""auc""",0.506


### LLM Results

Collect and parse the prediction results of all LLMs, both fine-tuned and non-fine-tuned.

In [13]:
# load all the LLM results
def load_llm_results(
    models=("llama-3.1", "mistral"),
    ft_modes=("ft", "noft"),
    splits=("22q1", "22q2", "22q3", "22q4"),
):
    """
    Load the saved LLM generations, parse them into scores and evaluate them.

    For every (model, ft mode, split) combination the file
    `saved_results/results_{model}_{ft}_frtxt_{split}.feather` is read, the
    score following "Score" in `generated_text` is extracted, and the result
    is joined with the ground-truth `rank` from `data/test_{split}.feather`.

    Args:
        models (iterable): Model names used in the result file names.
        ft_modes (iterable): Fine-tuning tags used in the result file names.
        splits (iterable): Split identifiers used in the file names.

    Returns:
        tuple: `(yt, metrics)` where `yt` is the concatenated prediction table
        (with `model`, `split_id`, `y` and `t` columns added) and `metrics` is
        the concatenated output of `evaluate_classification`, one evaluation
        per (model, ft mode, split).
    """
    models = list(models)
    ft_modes = list(ft_modes)
    splits = list(splits)

    # The ground truth depends on the split only, so read it once per split
    # instead of once per (model, ft mode, split).
    truth_by_split = {
        split: (
            pl.read_ipc(f"data/test_{split}.feather", columns=['docid_idx', 'rank'])
            .rename({"rank": "t"})
            .with_columns(c.docid_idx.cast(pl.Int64))
        )
        for split in splits
    }

    # Parse the generated text into a score. The pattern captures a proper
    # decimal number only (digits with an optional fractional part), so
    # captures such as "." or "1.2.3" can no longer reach the float cast.
    score = (
        c.generated_text.str.extract(r"(?i)Score:?\s*(\d+(?:\.\d+)?)")
        .cast(pl.Float32)
        .round(0)
    )
    # If the score doesn't fall into the range of 1-5, set it to null. The
    # range check runs before the integer cast so that absurdly large numbers
    # become null instead of overflowing Int32.
    y_expr = pl.when(score.is_between(1, 5)).then(score).otherwise(None).cast(pl.Int32)

    yt = []
    metrics = []

    for model in models:
        for ft in ft_modes:
            model_name = f"{model}_{ft}_frtxt"
            for split in splits:
                # get yt for one split
                yt_one_split = (
                    pl.read_ipc(f"saved_results/results_{model}_{ft}_frtxt_{split}.feather")
                    .with_columns(model=lit(model_name), split_id=lit(split))
                    # keep one generation per document
                    .unique(subset=['docid_idx'])
                    .with_columns(y=y_expr)
                    # join with the ground truth
                    .join(truth_by_split[split], on='docid_idx', how='inner')
                )
                yt.append(yt_one_split)

                # get metrics for one split
                metrics.append(evaluate_classification(yt_one_split, [model_name]))

    metrics = pl.concat(metrics)
    yt = pl.concat(yt)

    return yt, metrics

# get yt and metrics
yt, metrics = load_llm_results()

# Number of splits that were evaluated. Used below to turn the per-split mean
# of the count metrics back into a total, instead of assuming exactly 4 splits.
# (Assumes every model was evaluated on every split.)
n_splits = yt["split_id"].n_unique()

# average across splits
avg_metrics = (
    metrics
    .group_by(c.model, c('class'), c.metric)
    .agg(value=c.value.mean().round(3))
    .with_columns(
        value=pl.when(c.metric.is_in(["n_instances", "support"]))
        .then(c.value * n_splits)
        .otherwise(c.value)
    )
    .sort(['model', 'class', 'metric'])
)

# show overall metrics (rows without a class)
avg_metrics.filter(c('class').is_null())

# show per-class metrics (rows with a class)
avg_metrics.filter(c('class').is_not_null())

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


model,class,metric,value
str,i64,str,f64
"""llama-3.1_ft_frtxt""",,"""accuracy""",0.209
"""llama-3.1_ft_frtxt""",,"""macro_auc""",0.506
"""llama-3.1_ft_frtxt""",,"""macro_f1""",0.208
"""llama-3.1_ft_frtxt""",,"""micro_auc""",0.506
"""llama-3.1_ft_frtxt""",,"""micro_f1""",0.209
"""llama-3.1_ft_frtxt""",,"""n_instances""",3864.0
"""llama-3.1_noft_frtxt""",,"""accuracy""",0.206
"""llama-3.1_noft_frtxt""",,"""macro_auc""",0.501
"""llama-3.1_noft_frtxt""",,"""macro_f1""",0.129
"""llama-3.1_noft_frtxt""",,"""micro_auc""",0.504


model,class,metric,value
str,i64,str,f64
"""llama-3.1_ft_frtxt""",1,"""accuracy""",0.237
"""llama-3.1_ft_frtxt""",1,"""auc""",0.52
"""llama-3.1_ft_frtxt""",1,"""f1""",0.23
"""llama-3.1_ft_frtxt""",1,"""support""",754.0
"""llama-3.1_ft_frtxt""",2,"""accuracy""",0.207
"""llama-3.1_ft_frtxt""",2,"""auc""",0.498
"""llama-3.1_ft_frtxt""",2,"""f1""",0.193
"""llama-3.1_ft_frtxt""",2,"""support""",696.0
"""llama-3.1_ft_frtxt""",3,"""accuracy""",0.215
"""llama-3.1_ft_frtxt""",3,"""auc""",0.492
