# Train Classifier Head for Span Model

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

import ast
from tqdm.notebook import tqdm
from transformers import AutoModel, AutoTokenizer
import pandas as pd
from sklearn.metrics import classification_report
from collections import defaultdict
from pathlib import Path
import pandas as pd

# Evaluate Sentence-Level Span Model

# Setup

Check versions of important packages

In [10]:
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)

PyTorch version: 2.6.0+cu124
CUDA version: 12.4


In [11]:
print(sys.executable)

/share/miniforge3/bin/python3.10


## 1. Load sentence-level test dataframe

In [12]:
input_data_path = Path("../../data/dev/processed")
train_df = pd.read_csv(input_data_path / "pubmedqa_train.csv")
test_df = pd.read_csv(input_data_path / "pubmedqa_test.csv")

Sentences and labels are still stringified lists --> back to actual Python lists

In [13]:
# Apply literal_eval to parse strings into actual lists
train_df["sentences"] = train_df["sentences"].apply(ast.literal_eval)
train_df["labels"] = train_df["labels"].apply(ast.literal_eval)

test_df["sentences"] = test_df["sentences"].apply(ast.literal_eval)
test_df["labels"] = test_df["labels"].apply(ast.literal_eval)

In [14]:
train_df

Unnamed: 0,question,sentences,labels
0,is there a functional neural correlate of indi...,[the present study tested whether individuals ...,"[1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0]"
1,can we use the omron t9p automated blood press...,[recent events in our hospital combined with i...,"[1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ..."
2,intraabdominal vascular injury are we getting ...,[intraabdominal vascular injury iavi as a resu...,"[1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, ..."
3,hand grip and pinch strength in patients with ...,[the hand grip strength test and pinch was sig...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, ..."
4,is canada ready for patient accessible electro...,[access to personal health information through...,"[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, ..."
...,...,...,...
9791,does postmastectomy radiotherapy affect the ou...,[the decision to perform immediate deep inferi...,"[1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, ..."
9792,risk of reoperation within 90 days of liver tr...,[overall 90day reoperation rate after lt was 2...,"[1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, ..."
9793,is the menopause rating scale accurate for dia...,[to evaluate the accuracy of the menopause rat...,"[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, ..."
9794,can snow depth be used to predict the distribu...,[the svalbard endemic aphid acyrthosiphon sval...,"[0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, ..."


## 2. Load the pretrained model and tokenizer

In [15]:
MODEL_NAME = "KRLabsOrg/chiliground-base-modernbert-v1"
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

classifier = torch.nn.Linear(model.config.hidden_size, 2)  # attach simple head
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
classifier.to(device)
# switch them into evaluation/inference mode
model.eval()
classifier.eval()


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Some weights of ModernBertModel were not initialized from the model checkpoint at KRLabsOrg/chiliground-base-modernbert-v1 and are newly initialized: ['embeddings.norm.weight', 'embeddings.tok_embeddings.weight', 'final_norm.weight', 'layers.0.attn.Wo.weight', 'layers.0.attn.Wqkv.weight', 'layers.0.mlp.Wi.weight', 'layers.0.mlp.Wo.weight', 'layers.0.mlp_norm.weight', 'layers.1.attn.Wo.weight', 'layers.1.attn.Wqkv.weight', 'layers.1.attn_norm.weight', 'layers.1.mlp.Wi.weight', 'layers.1.mlp.Wo.weight', 'layers.1.mlp_norm.weight', 'layers.10.attn.Wo.weight', 'layers.10.attn.Wqkv.weight', 'layers.10.attn_norm.weight', 'layers.10.mlp.Wi.weight', 'layers.10.mlp.Wo.weight', 'layers.10.mlp_norm.weight', 'layers.11.attn.Wo.weight', 'layers.11.attn.Wqkv.weight', 'layers.11.attn_norm.weight', 'layers.11.mlp.Wi.weight', 'layers

Linear(in_features=768, out_features=2, bias=True)

## 3. Encode sentences with boundaries

In [16]:
def encode_sentences(question, sentences, tokenizer, max_length=512):
    input_ids = []
    attention_mask = []
    sentence_boundaries = []

    q_ids = tokenizer.encode(question, add_special_tokens=True, truncation=True, max_length=max_length)
    input_ids.extend(q_ids[:-1])
    attention_mask.extend([1]*len(q_ids[:-1]))

    for sent in sentences:
        sent_ids = tokenizer.encode(sent, add_special_tokens=False, truncation=True, max_length=max_length)
        if len(input_ids) + len(sent_ids) + 1 > max_length:
            break
        sentence_boundaries.append((len(input_ids), len(input_ids) + len(sent_ids) - 1))
        input_ids.append(tokenizer.sep_token_id)
        attention_mask.append(1)
        input_ids.extend(sent_ids)
        attention_mask.extend([1]*len(sent_ids))

    input_ids.append(tokenizer.sep_token_id)
    attention_mask.append(1)

    return {
        "input_ids": torch.tensor(input_ids).unsqueeze(0).to(device),
        "attention_mask": torch.tensor(attention_mask).unsqueeze(0).to(device),
        "sentence_boundaries": [sentence_boundaries]
    }


## 4. Evaluate the model

In [17]:
all_preds, all_labels = [], []

with torch.no_grad():
    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Evaluating"):
        question = row["question"]
        sentences = row["sentences"]
        labels = row["labels"]

        encoded = encode_sentences(question, sentences, tokenizer)
        outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
        last_hidden = outputs.last_hidden_state

        preds = []
        for start, end in encoded["sentence_boundaries"][0]:
            span_repr = last_hidden[0, start:end+1].mean(dim=0)
            logit = classifier(span_repr)
            pred = torch.argmax(logit).item()
            preds.append(pred)

        all_preds.extend(preds)
        all_labels.extend(labels[:len(preds)])  # account for truncation

report = pd.DataFrame(classification_report(all_labels, all_preds, digits=4, output_dict=True)).transpose()

Evaluating:   0%|          | 0/1225 [00:00<?, ?it/s]

In [18]:
report = pd.DataFrame(classification_report(all_labels, all_preds, digits=4, output_dict=True)).transpose()

In [19]:
display(report)

Unnamed: 0,precision,recall,f1-score,support
0,0.566072,0.51388,0.538715,12104.0
1,0.430507,0.482639,0.455085,9216.0
accuracy,0.500375,0.500375,0.500375,0.500375
macro avg,0.49829,0.498259,0.4969,21320.0
weighted avg,0.507471,0.500375,0.502564,21320.0
