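Input format: `[q] </s> [sent1] </s> [sent2] </s> ... </s> [sentn]` — the question and every candidate sentence of a case are packed into one sequence, separated by `</s>` tokens; the model scores each sentence from the embedding at its `</s>` position.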

In [4]:
# Connect Google Drive so the dataset and checkpoint paths used below resolve.
from google.colab import drive

drive.mount('/content/gdrive', force_remount=True)
# Every path in this notebook uses the ".../MyDrive/..." spelling (no space); keep root_dir
# consistent with them.
root_dir = "/content/gdrive/MyDrive"

Install the dependencies

*Hardware: NVIDIA A100 GPU on Google Colab Pro*

In [2]:
!pip install --quiet torch==2.2.2+cu121 torchvision==0.17.2+cu121 torchaudio==2.2.2+cu121 --index-url https://download.pytorch.org/whl/cu121
!pip install --quiet --upgrade transformers sentence-transformers
!pip install numpy==1.26.4

In [3]:
!pip install ninja packaging
!MAX_JOBS=8 pip install flash-attn --no-build-isolation

In [None]:
import torch

# Sanity check that the pinned CUDA build of torch is the one the kernel picked up.
print(f"Torch version: {torch.__version__}")

Torch version: 2.2.2+cu121


Load the augmented dataset

In [None]:
import pandas as pd

# Load the augmented training dataset. The path must not contain extra quote characters inside
# the string literal, or read_excel looks for a file literally named "'/content/...'".
TRAIN_DATA_PATH = "/content/gdrive/MyDrive/dataset_excel_bionlp/final_dataset/qa_train_dataset_structured_05-05.xlsx"
df = pd.read_excel(TRAIN_DATA_PATH)

In [15]:
# # Preview
# df

In [16]:
# print("Number of rows:", len(df))

Fine-tuning

Training on the full training set (no validation split)

In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report
import numpy as np

from tqdm import tqdm

# -----------------------
# CONFIG
# -----------------------

MODEL_NAME = "jinaai/jina-embeddings-v3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
EPOCHS = 6

# -----------------------
# DATA LOADING
# -----------------------
# df = pd.read_excel("/content/gdrive/MyDrive/dataset_excel_bionlp/final_dataset/qa_results_structured_03-05.xlsx")
# Binary target: 1 for "essential" sentences, 0 for everything else. Casting to str first means a
# blank (NaN) relevance cell becomes 0 instead of raising AttributeError on .strip().
df["binary_relevance"] = (
    df["relevance"].astype(str).str.strip().str.lower().eq("essential").astype(int)
)

# Group by case: one example per case with its question and its sentences/labels in row order.
# NOTE(review): sentence order is the row order inside each case — assumes the sheet is already
# sorted by sentence id; confirm.
grouped_data = []
for case_id, group in df.groupby("case_id"):
    grouped_data.append({
        "question": group["question_generated"].iloc[0],
        # astype(str) matches the evaluation cells and keeps .strip() from failing on NaN cells.
        "sentences": [s.strip() for s in group["ref_excerpt"].astype(str)],
        "labels": group["binary_relevance"].tolist(),
    })

# -----------------------
# DATASET
# -----------------------
class SentenceClassificationDataset(Dataset):
    """One example per case: ``question </s> sent_1 </s> ... </s> sent_n`` as a single sequence.

    Each item returns the padded ``input_ids`` / ``attention_mask``, the token positions of every
    ``</s>`` token (used by the model as per-sentence anchors) and the binary labels, cut down to
    the number of separators that survived truncation.

    Args:
        data: list of dicts with keys ``"question"`` (str), ``"sentences"`` (list[str]) and
            ``"labels"`` (list[int]), as built by the grouping step.
        tokenizer: Hugging Face-style tokenizer (callable, ``tokenize``,
            ``convert_tokens_to_ids``, ``num_special_tokens_to_add``).
        max_length: sequences are truncated and padded to this many tokens.
        warn_on_truncation: print a notice when an example is longer than ``max_length``.
    """

    SEPARATOR = "</s>"

    def __init__(self, data, tokenizer, max_length=512, warn_on_truncation=True):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.warn_on_truncation = warn_on_truncation
        # Resolve the separator id once, from this dataset's own tokenizer (not a global).
        self.sep_token_id = tokenizer.convert_tokens_to_ids(self.SEPARATOR)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]
        sentences = item["sentences"]
        labels = item["labels"]

        sep = f" {self.SEPARATOR} "
        text = question + sep + sep.join(sentences)

        if self.warn_on_truncation:
            # tokenize() does not count the special tokens added around the text; include them
            # so inputs that are only slightly too long are reported too.
            tokens_total = len(self.tokenizer.tokenize(text)) + self.tokenizer.num_special_tokens_to_add(pair=False)
            if tokens_total > self.max_length:
                print(f"[!] Truncated from {tokens_total} → {self.max_length} tokens")

        encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)

        input_ids = encoding["input_ids"][0]
        sep_positions = (input_ids == self.sep_token_id).nonzero(as_tuple=True)[0]

        if len(sep_positions) == 0:
            # No separator found: point at the first content token with a dummy negative label
            # so the item still has one prediction/label pair.
            sep_positions = torch.tensor([1])
            labels = [0]

        # Sentences whose separator was truncated away lose their labels.
        # NOTE(review): if the tokenizer appends its own "</s>" as EOS, there is one more position
        # than sentences; the training loop trims predictions to len(labels).
        labels = labels[:len(sep_positions)]

        return {
            "input_ids": input_ids,
            "attention_mask": encoding["attention_mask"][0],
            "sep_positions": sep_positions,
            "labels": torch.tensor(labels),
        }


# -----------------------
# MODEL WRAPPER
# -----------------------
class MultiSentenceClassifier(nn.Module):
    """Per-sentence relevance scorer over a packed ``question </s> s1 </s> ...`` sequence.

    A SentenceTransformer encoder produces token embeddings. The embedding at each separator
    position goes through one linear head, which outputs one logit per sentence.
    """

    def __init__(self, base_model_name):
        super().__init__()
        self.encoder = SentenceTransformer(
            base_model_name,
            trust_remote_code=True,
            model_kwargs={"default_task": "classification", "lora_main_params_trainable": True})
        self.encoder[0].default_task = "classification"
        # NOTE(review): forward() calls auto_model directly, bypassing the SentenceTransformer
        # module pipeline; confirm the classification adapter is actually applied on that path.
        #self.encoder = self.encoder.float()
        hidden_size = self.encoder.get_sentence_embedding_dimension()
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, sep_positions):
        """Return a list with one logit tensor per batch row.

        Args:
            input_ids, attention_mask: (batch, seq_len) tensors.
            sep_positions: sequence of length ``batch``; entry ``i`` holds the separator token
                indices for row ``i``.

        Returns:
            list of length ``batch``. Entry ``i`` has one logit per separator; rows without
            separators get an empty tensor, so entries stay aligned with per-row labels.
        """
        output = self.encoder[0].auto_model(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = output.last_hidden_state

        preds = []
        for i in range(input_ids.shape[0]):
            sep_pos = sep_positions[i]
            if len(sep_pos) == 0:
                # Keep a placeholder rather than skipping. Skipping would shift every later
                # prediction onto the wrong row's labels when the caller zips them.
                preds.append(token_embeddings.new_zeros((0,), dtype=torch.float32))
                continue
            sentence_embs = token_embeddings[i, sep_pos, :].float()
            logits = self.classifier(sentence_embs).squeeze(-1)
            preds.append(logits)

        return preds


class FocalLoss(nn.Module):
    """Binary focal loss computed from logits.

    Args:
        alpha: weight for positive targets; negatives get ``1 - alpha``.
        gamma: exponent of the modulating factor that down-weights well-classified elements.
        reduction: ``'mean'``, ``'sum'``, or anything else for the unreduced per-element loss.
    """

    def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_elem_bce = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p = torch.sigmoid(inputs)
        neg = 1 - targets
        # Class-balancing weight: alpha on positives, (1 - alpha) on negatives.
        weight_alpha = self.alpha * targets + (1 - self.alpha) * neg
        # Probability mass NOT assigned to the true class, raised to gamma.
        one_minus_pt = (1 - p) * targets + p * neg
        per_elem = weight_alpha * one_minus_pt ** self.gamma * per_elem_bce

        if self.reduction == 'sum':
            return per_elem.sum()
        if self.reduction == 'mean':
            return per_elem.mean()
        return per_elem




# -----------------------
# TRAINING
# -----------------------
# Seed torch's global RNG before building the model so the classifier-head init and the
# DataLoader shuffle order are reproducible across runs.
torch.manual_seed(42)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

model = MultiSentenceClassifier(MODEL_NAME).to(DEVICE)

dataset = SentenceClassificationDataset(grouped_data, tokenizer, max_length=4096)

# The default collate_fn stacks each key across the batch. sep_positions and labels differ in
# length between cases, so this setup relies on BATCH_SIZE = 1.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Up-weight positive ("essential") sentences by the negative/positive ratio.
label_counts = df["binary_relevance"].value_counts()
pos = label_counts.get(1, 0)
neg = label_counts.get(0, 0)
pos_weight_value = neg / max(pos, 1)  # no KeyError / infinite weight when a class is absent

pos_weight = torch.tensor([pos_weight_value], device=DEVICE)

print("Using pos weight value:", pos_weight_value)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# criterion = FocalLoss(alpha=0.65, gamma=2)

# bf16 autocast on CUDA only; torch.autocast replaces the deprecated torch.cuda.amp.autocast.
USE_AMP = DEVICE == "cuda"
SAVE_PATH = "/content/gdrive/MyDrive/qlora_outputs/fine_tuned_jina_v3_multisent.pt"


def _batch_to_device(batch):
    """Move one collated batch to DEVICE; per-row sep_positions/labels are returned as lists."""
    input_ids = batch["input_ids"].to(DEVICE)
    attention_mask = batch["attention_mask"].to(DEVICE)
    sep_positions = [x.to(DEVICE) for x in batch["sep_positions"]]
    labels = [x.to(DEVICE).float() for x in batch["labels"]]
    return input_ids, attention_mask, sep_positions, labels


def _aligned_pairs(outputs, labels):
    """Yield (logits, targets) per row, trimmed to a common length; empty rows are skipped."""
    for pred, target in zip(outputs, labels):
        if pred.numel() == 0 or target.numel() == 0:
            continue
        n = min(pred.shape[0], target.shape[0])
        yield pred[:n], target[:n]


model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    n_steps = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        input_ids, attention_mask, sep_positions, labels = _batch_to_device(batch)

        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=USE_AMP):
            outputs = model(input_ids, attention_mask, sep_positions)
            sample_losses = [criterion(pred, target) for pred, target in _aligned_pairs(outputs, labels)]
            if not sample_losses:
                continue  # nothing valid to learn from in this batch
            loss = torch.stack(sample_losses).mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_steps += 1

    print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {total_loss:.4f} (mean/step: {total_loss / max(n_steps, 1):.4f})")

    # NOTE: this evaluates on the TRAINING data (there is no validation split in this setup), so
    # it measures fit, not generalization. Swap in a held-out DataLoader when one is available.
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, sep_positions, labels = _batch_to_device(batch)
            outputs = model(input_ids, attention_mask, sep_positions)
            for pred, target in _aligned_pairs(outputs, labels):
                probs = torch.sigmoid(pred).cpu().numpy()
                all_preds.extend((probs > 0.5).astype(int).tolist())
                all_targets.extend(target.cpu().numpy().tolist())

    print("\nEvaluation Metrics:")
    print(classification_report(all_targets, all_preds, digits=3))
    model.train()


# -----------------------
# SAVE
# -----------------------
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)


Alternative: train/validation split with early stopping (currently commented out)


In [None]:
# import os
# os.environ["WANDB_DISABLED"] = "true"

# import pandas as pd
# import torch
# import torch.nn as nn
# from torch.utils.data import Dataset, DataLoader
# from transformers import AutoTokenizer
# from sentence_transformers import SentenceTransformer
# from sklearn.metrics import classification_report
# from sklearn.model_selection import train_test_split
# import numpy as np
# from tqdm import tqdm

# # -----------------------
# # CONFIG
# # -----------------------
# MODEL_NAME = "jinaai/jina-embeddings-v3"
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# BATCH_SIZE = 1
# EPOCHS = 6
# PATIENCE = 2  # for early stopping

# # -----------------------
# # DATA LOADING
# # -----------------------


# df["binary_relevance"] = df["relevance"].apply(lambda x: 1 if x.strip().lower() == "essential" else 0)

# grouped_data = []
# for case_id, group in df.groupby("case_id"):
#     question = group["question_generated"].iloc[0]
#     sentences = group["ref_excerpt"].tolist()
#     labels = group["binary_relevance"].tolist()
#     grouped_data.append({
#         "question": question,
#         "sentences": [s.strip() for s in sentences],
#         "labels": labels
#     })

# # -----------------------
# # DATASET CLASS
# # -----------------------
# class SentenceClassificationDataset(Dataset):
#     def __init__(self, data, tokenizer, max_length=512):
#         self.data = data
#         self.tokenizer = tokenizer
#         self.max_length = max_length

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         item = self.data[idx]
#         question = item["question"]
#         sentences = item["sentences"]
#         labels = item["labels"]
#         separator = "</s>"
#         text = question + f" {separator} " + f" {separator} ".join(sentences)

#         encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
#         input_ids = encoding["input_ids"][0]
#         sep_token_id = tokenizer.convert_tokens_to_ids("</s>")
#         sep_positions = (input_ids == sep_token_id).nonzero(as_tuple=True)[0]

#         if len(sep_positions) == 0:
#             sep_positions = torch.tensor([1])
#             labels = [0]

#         labels = labels[:len(sep_positions)]

#         return {
#             "input_ids": input_ids,
#             "attention_mask": encoding["attention_mask"][0],
#             "sep_positions": sep_positions,
#             "labels": torch.tensor(labels)
#         }

# # -----------------------
# # MODEL
# # -----------------------
# class MultiSentenceClassifier(nn.Module):
#     def __init__(self, base_model_name):
#         super().__init__()
#         self.encoder = SentenceTransformer(
#             base_model_name,
#             trust_remote_code=True,
#             model_kwargs={"default_task": "classification", "lora_main_params_trainable": True})
#         self.encoder[0].default_task = "classification"
#         hidden_size = self.encoder.get_sentence_embedding_dimension()
#         self.classifier = nn.Linear(hidden_size, 1)

#     def forward(self, input_ids, attention_mask, sep_positions):
#         output = self.encoder[0].auto_model(input_ids=input_ids, attention_mask=attention_mask)
#         token_embeddings = output.last_hidden_state
#         preds = []
#         for i in range(input_ids.shape[0]):
#             sep_pos = sep_positions[i]
#             if len(sep_pos) == 0:
#                 continue
#             sentence_embs = token_embeddings[i, sep_pos, :].float()
#             logits = self.classifier(sentence_embs).squeeze(-1)
#             preds.append(logits)
#         return preds


# class FocalLoss(nn.Module):
#     def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.reduction = reduction

#     def forward(self, inputs, targets):
#         BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
#         probs = torch.sigmoid(inputs)
#         alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha)
#         focal_weight = (targets * (1 - probs) + (1 - targets) * probs) ** self.gamma
#         loss = alpha_factor * focal_weight * BCE_loss

#         if self.reduction == 'mean':
#             return loss.mean()
#         elif self.reduction == 'sum':
#             return loss.sum()
#         else:
#             return loss

# # -----------------------
# # TRAINING SETUP
# # -----------------------
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# model = MultiSentenceClassifier(MODEL_NAME).to(DEVICE)

# train_data, val_data = train_test_split(grouped_data, test_size=0.2, random_state=42)

# train_dataset = SentenceClassificationDataset(train_data, tokenizer, max_length=4096)
# val_dataset = SentenceClassificationDataset(val_data, tokenizer, max_length=4096)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# pos = df["binary_relevance"].value_counts()[1]
# neg = df["binary_relevance"].value_counts()[0]
# # pos_weight_value = (neg / pos)
# # pos_weight = torch.tensor([pos_weight_value], device=DEVICE)

# # print("Using pos weight value:", pos_weight_value)
# # criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# criterion = FocalLoss(alpha=0.75, gamma=2)

# # -----------------------
# # TRAINING LOOP W/ EARLY STOPPING
# # -----------------------
# best_f1 = 0.0
# no_improve_epochs = 0

# for epoch in range(EPOCHS):
#     model.train()
#     total_loss = 0
#     for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
#         input_ids = batch["input_ids"].to(DEVICE)
#         attention_mask = batch["attention_mask"].to(DEVICE)
#         sep_positions = [x.to(DEVICE) for x in batch["sep_positions"]]
#         labels = [x.to(DEVICE).float() for x in batch["labels"]]

#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#             outputs = model(input_ids, attention_mask, sep_positions)
#             sample_losses = []
#             for pred, target in zip(outputs, labels):
#                 if pred.numel() == 0 or target.numel() == 0:
#                     continue
#                 min_len = min(pred.shape[0], target.shape[0])
#                 sample_losses.append(criterion(pred[:min_len], target[:min_len]))
#             if sample_losses:
#                 loss = torch.stack(sample_losses).mean()
#             else:
#                 continue
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()

#     print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {total_loss:.4f}")

#     # -----------------------
#     # VALIDATION
#     # -----------------------
#     model.eval()
#     all_preds, all_targets = [], []

#     with torch.no_grad():
#         for batch in val_loader:
#             input_ids = batch["input_ids"].to(DEVICE)
#             attention_mask = batch["attention_mask"].to(DEVICE)
#             sep_positions = [x.to(DEVICE) for x in batch["sep_positions"]]
#             labels = [x.to(DEVICE).float() for x in batch["labels"]]

#             outputs = model(input_ids, attention_mask, sep_positions)

#             for pred, target in zip(outputs, labels):
#                 if pred.numel() == 0 or target.numel() == 0:
#                     continue
#                 min_len = min(pred.shape[0], target.shape[0])
#                 probs = torch.sigmoid(pred[:min_len]).cpu().numpy()
#                 binarized = (probs > 0.5).astype(int)
#                 all_preds.extend(binarized.tolist())
#                 all_targets.extend(target[:min_len].cpu().numpy().tolist())

#     print("\nValidation Metrics:")
#     report = classification_report(all_targets, all_preds, digits=3, output_dict=True)
#     print(classification_report(all_targets, all_preds, digits=3))

#     f1 = report.get(1.0, report.get('1.0', {})).get('f1-score', 0.0)
#     if f1 > best_f1:
#         best_f1 = f1
#         no_improve_epochs = 0
#         torch.save(model.state_dict(), "/content/gdrive/MyDrive/qlora_outputs/best_model.pt")
#         print("New best model saved.")
#     else:
#         no_improve_epochs += 1
#         print(f"No improvement for {no_improve_epochs} epoch(s)")

#     if no_improve_epochs >= PATIENCE:
#         print("Early stopping triggered.")
#         break

Inference

In [None]:
# Connect Google Drive so the checkpoint and dev-set paths used below resolve.
from google.colab import drive

drive.mount('/content/gdrive', force_remount=True)
# Every path in this notebook uses the ".../MyDrive/..." spelling (no space); keep root_dir
# consistent with them.
root_dir = "/content/gdrive/MyDrive"

Mounted at /content/gdrive


Evaluate the fine-tuned model

Using the patient narrative as the question

In [5]:
import torch
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from sklearn.metrics import classification_report
import pandas as pd
from tqdm import tqdm

# -----------------------
# CONFIG
# -----------------------
MODEL_PATH = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/fine_tuned_jina_v3_multisent.pt"
# MODEL_PATH = "/content/gdrive/MyDrive/qlora_outputs/fine_tuned_jina_v3_multisent.pt"
MODEL_NAME = "jinaai/jina-embeddings-v3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 4096  # same max_length the training dataset used

# -----------------------
# LOAD TEST DATA
# -----------------------
df_test = pd.read_excel("/content/gdrive/MyDrive/dataset_excel_bionlp/merged_notes_cases.xlsx")

# Expose the chosen question text under the "question_generated" column the training data used.
# question_column = 'clinician_question'
question_column = 'patient_narrative'
df_test = df_test.rename(columns={question_column: 'question_generated'})


df_test["ref_excerpt"] = df_test["ref_excerpt"].astype(str)
# 1 for "essential", 0 otherwise; astype(str) stops blank (NaN) relevance cells from crashing .strip().
df_test["binary_relevance"] = (
    df_test["relevance"].astype(str).str.strip().str.lower().eq("essential").astype(int)
)

# One example per case (same grouping as training): question + sentences + labels in row order.
grouped_data = []
for case_id, group in df_test.groupby("case_id"):
    grouped_data.append({
        "question": group["question_generated"].iloc[0],
        "sentences": [s.strip() for s in group["ref_excerpt"]],
        "labels": group["binary_relevance"].tolist(),
    })

# -----------------------
# TOKENIZER AND MODEL
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

class MultiSentenceClassifier(torch.nn.Module):
    """Inference copy of the training model: encoder plus a linear head over separator positions.

    It must build the same module tree as the training class so the saved state_dict loads.
    """

    def __init__(self, base_model_name):
        super().__init__()
        self.encoder = SentenceTransformer(
            base_model_name,
            trust_remote_code=True,
            model_kwargs={"default_task": "classification", "use_flash_attn": True}
        )
        # Mirror the training-time setup, which also sets default_task on the first module.
        self.encoder[0].default_task = "classification"
        hidden_size = self.encoder.get_sentence_embedding_dimension()
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, sep_positions):
        """Return one logit tensor per batch row (one logit per separator position).

        Rows whose separator tensor is empty get an empty placeholder, so the output list stays
        aligned with the input rows.
        """
        output = self.encoder[0].auto_model(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = output.last_hidden_state
        preds = []
        for i in range(input_ids.shape[0]):
            sep_pos = sep_positions[i]
            if len(sep_pos) == 0:
                preds.append(token_embeddings.new_zeros((0,), dtype=torch.float32))
                continue
            sentence_embs = token_embeddings[i, sep_pos, :].float()  # ensure float32
            logits = self.classifier(sentence_embs).squeeze(-1)
            preds.append(logits)
        return preds

model = MultiSentenceClassifier(MODEL_NAME).to(DEVICE)
# The checkpoint is a plain state_dict of tensors, so load it with weights_only=True instead of
# unpickling arbitrary objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.eval()

# -----------------------
# INFERENCE
# -----------------------
PRINT_PER_SENTENCE = True  # set to False to skip the per-sentence dump below

sent_token_id = tokenizer.convert_tokens_to_ids("</s>")  # loop-invariant, resolve once

all_preds = []
all_targets = []

with torch.no_grad():
    for item in tqdm(grouped_data):
        question = item["question"]
        sentences = item["sentences"]
        true_labels = item["labels"]

        # Same packing as training: question </s> sent_1 </s> ... </s> sent_n
        text = question + " </s> " + " </s> ".join(sentences)
        encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LENGTH)
        input_ids = encoding["input_ids"].to(DEVICE)
        attention_mask = encoding["attention_mask"].to(DEVICE)

        sep_positions = (input_ids[0] == sent_token_id).nonzero(as_tuple=True)[0]
        if len(sep_positions) == 0:
            sep_positions = torch.tensor([1]).to(DEVICE)
        sep_positions = sep_positions.unsqueeze(0)

        outputs = model(input_ids, attention_mask, [sep_positions])
        if not outputs:
            continue

        probs = torch.sigmoid(outputs[0]).cpu().numpy().reshape(-1)
        preds = (probs > 0.5).astype(int)
        # Only the first `scored` sentences have a prediction; surplus positions beyond the
        # number of labels are ignored.
        scored = min(len(probs), len(true_labels))

        # Score every gold sentence. Sentences truncated away before scoring count as predicted 0
        # instead of being dropped, which would silently inflate recall.
        all_preds.extend(preds[:scored].tolist() + [0] * (len(true_labels) - scored))
        all_targets.extend(true_labels)

        if PRINT_PER_SENTENCE:
            for i, (sent, pred, label) in enumerate(zip(sentences[:scored], preds[:scored], true_labels[:scored])):
                print(f"Q: {question}")
                print(f"SENT {i+1}: {sent}")
                print(f" → Predicted: {pred}, True: {label}\n")

# -----------------------
# GLOBAL METRICS
# -----------------------
print("\nOverall Evaluation on Test Set:")
print(classification_report(all_targets, all_preds, digits=3))

# average='binary' reports the positive ("essential") class only; support is None in this mode.
precision, recall, f1, support = precision_recall_fscore_support(all_targets, all_preds, average='binary', pos_label=1)

print("\nBinary-Averaged Metrics (positive class = 1):")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")
# print(f"Support:   {support}")

In [6]:
# -----------------------
# CREATE EXPORTABLE RESULTS
# -----------------------
import os

# grouped_data was built by iterating df_test.groupby("case_id") (sorted keys). Take the case ids
# from the same iteration: df_test["case_id"].unique() is in order of first appearance, so zipping
# it with grouped_data mislabels cases whenever the sheet is not sorted by case_id.
case_ids = [case_id for case_id, _ in df_test.groupby("case_id")]
assert len(case_ids) == len(grouped_data), "case_ids and grouped_data are out of sync"

sent_token_id = tokenizer.convert_tokens_to_ids("</s>")
rows = []

for case_id, item in tqdm(zip(case_ids, grouped_data), total=len(grouped_data), desc="Generating report"):
    question = item["question"]
    labels = item["labels"]
    sentences = item["sentences"]

    # Rerun encoding for the item (same packing as inference)
    text = question + " </s> " + " </s> ".join(sentences)
    encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LENGTH)
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)

    sep_positions = (input_ids[0] == sent_token_id).nonzero(as_tuple=True)[0]
    if len(sep_positions) == 0:
        sep_positions = torch.tensor([1]).to(DEVICE)
    sep_positions = sep_positions.unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask, [sep_positions])
        probs = torch.sigmoid(outputs[0]).cpu().numpy().reshape(-1)
        preds = (probs > 0.5).astype(int)

    # Sentence ids are 0-based positions within the case.
    min_len = min(len(preds), len(labels))
    cited_ids = [str(i) for i, p in enumerate(preds[:min_len]) if p == 1]
    # Use every gold label, including sentences truncated away before scoring, so they count as
    # missed downstream instead of disappearing from the evaluation.
    gold_ids = [str(i) for i, l in enumerate(labels) if l == 1]

    rows.append({
        "case_id": case_id,
        "question": question,
        "cited_sentence_ids": ", ".join(cited_ids),
        "gold_essential_sentence_ids": ", ".join(gold_ids)
    })

# Create DataFrame
df_out = pd.DataFrame(rows)

# Save to Excel (create the output folder on first run)
output_path = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_fine_tuned_patient_narrative.xlsx"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df_out.to_excel(output_path, index=False)

print(f"\nSaved output to: {output_path}")

In [7]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

# -----------------------
# CONFIG
# -----------------------
PREDICTION_FILE = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_fine_tuned_patient_narrative.xlsx"


def _parse_id_cell(cell):
    """Parse a comma-separated id cell such as "0, 3, 5" into a set of ints.

    Blank cells (NaN) give an empty set. Excel can return a lone id as a number, including a
    float such as 3.0, so each token goes through float() before int().
    """
    if pd.isna(cell):
        return set()
    return {int(float(tok)) for tok in str(cell).split(",") if tok.strip()}


# -----------------------
# LOAD
# -----------------------
df = pd.read_excel(PREDICTION_FILE)

# Expand each case into 0/1 vectors over ids 0..(largest mentioned id). Trailing ids that neither
# list mentions are true negatives, which do not affect positive-class precision/recall/F1.
y_true = []
y_pred = []

for _, row in df.iterrows():
    gold_ids = _parse_id_cell(row["gold_essential_sentence_ids"])
    pred_ids = _parse_id_cell(row["cited_sentence_ids"])

    max_id = max(gold_ids | pred_ids, default=-1)

    for i in range(max_id + 1):
        y_true.append(1 if i in gold_ids else 0)
        y_pred.append(1 if i in pred_ids else 0)

# -----------------------
# METRICS
# -----------------------
# zero_division=0: report 0 instead of warning when there are no predicted or gold positives.
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", pos_label=1, zero_division=0)

print("Evaluation Results of fine tuned model on dev with patient narrative:")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")

Using the clinician question

In [8]:
import torch
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from sklearn.metrics import classification_report
import pandas as pd
from tqdm import tqdm

# -----------------------
# CONFIG
# -----------------------
MODEL_PATH = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/fine_tuned_jina_v3_multisent.pt"
MODEL_NAME = "jinaai/jina-embeddings-v3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 4096  # same max_length the training dataset used

# -----------------------
# LOAD TEST DATA
# -----------------------
df_test = pd.read_excel("/content/gdrive/MyDrive/dataset_excel_bionlp/merged_notes_cases.xlsx")

# Expose the chosen question text under the "question_generated" column the training data used.
question_column = 'clinician_question'
# question_column = 'patient_narrative'
df_test = df_test.rename(columns={question_column: 'question_generated'})


df_test["ref_excerpt"] = df_test["ref_excerpt"].astype(str)
# 1 for "essential", 0 otherwise; astype(str) stops blank (NaN) relevance cells from crashing .strip().
df_test["binary_relevance"] = (
    df_test["relevance"].astype(str).str.strip().str.lower().eq("essential").astype(int)
)

# One example per case (same grouping as training): question + sentences + labels in row order.
grouped_data = []
for case_id, group in df_test.groupby("case_id"):
    grouped_data.append({
        "question": group["question_generated"].iloc[0],
        "sentences": [s.strip() for s in group["ref_excerpt"]],
        "labels": group["binary_relevance"].tolist(),
    })

# -----------------------
# TOKENIZER AND MODEL
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

class MultiSentenceClassifier(torch.nn.Module):
    """Inference copy of the training model: encoder plus a linear head over separator positions.

    It must build the same module tree as the training class so the saved state_dict loads.
    """

    def __init__(self, base_model_name):
        super().__init__()
        self.encoder = SentenceTransformer(
            base_model_name,
            trust_remote_code=True,
            model_kwargs={"default_task": "classification", "use_flash_attn": True}
        )
        # Mirror the training-time setup, which also sets default_task on the first module.
        self.encoder[0].default_task = "classification"
        hidden_size = self.encoder.get_sentence_embedding_dimension()
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, sep_positions):
        """Return one logit tensor per batch row (one logit per separator position).

        Rows whose separator tensor is empty get an empty placeholder, so the output list stays
        aligned with the input rows.
        """
        output = self.encoder[0].auto_model(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = output.last_hidden_state
        preds = []
        for i in range(input_ids.shape[0]):
            sep_pos = sep_positions[i]
            if len(sep_pos) == 0:
                preds.append(token_embeddings.new_zeros((0,), dtype=torch.float32))
                continue
            sentence_embs = token_embeddings[i, sep_pos, :].float()  # ensure float32
            logits = self.classifier(sentence_embs).squeeze(-1)
            preds.append(logits)
        return preds

model = MultiSentenceClassifier(MODEL_NAME).to(DEVICE)
# The checkpoint is a plain state_dict of tensors, so load it with weights_only=True instead of
# unpickling arbitrary objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.eval()

# -----------------------
# INFERENCE
# -----------------------
PRINT_PER_SENTENCE = True  # set to False to skip the per-sentence dump below

sent_token_id = tokenizer.convert_tokens_to_ids("</s>")  # loop-invariant, resolve once

all_preds = []
all_targets = []

with torch.no_grad():
    for item in tqdm(grouped_data):
        question = item["question"]
        sentences = item["sentences"]
        true_labels = item["labels"]

        # Same packing as training: question </s> sent_1 </s> ... </s> sent_n
        text = question + " </s> " + " </s> ".join(sentences)
        encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LENGTH)
        input_ids = encoding["input_ids"].to(DEVICE)
        attention_mask = encoding["attention_mask"].to(DEVICE)

        sep_positions = (input_ids[0] == sent_token_id).nonzero(as_tuple=True)[0]
        if len(sep_positions) == 0:
            sep_positions = torch.tensor([1]).to(DEVICE)
        sep_positions = sep_positions.unsqueeze(0)

        outputs = model(input_ids, attention_mask, [sep_positions])
        if not outputs:
            continue

        probs = torch.sigmoid(outputs[0]).cpu().numpy().reshape(-1)
        preds = (probs > 0.5).astype(int)
        # Only the first `scored` sentences have a prediction; surplus positions beyond the
        # number of labels are ignored.
        scored = min(len(probs), len(true_labels))

        # Score every gold sentence. Sentences truncated away before scoring count as predicted 0
        # instead of being dropped, which would silently inflate recall.
        all_preds.extend(preds[:scored].tolist() + [0] * (len(true_labels) - scored))
        all_targets.extend(true_labels)

        if PRINT_PER_SENTENCE:
            for i, (sent, pred, label) in enumerate(zip(sentences[:scored], preds[:scored], true_labels[:scored])):
                print(f"Q: {question}")
                print(f"SENT {i+1}: {sent}")
                print(f" → Predicted: {pred}, True: {label}\n")

# -----------------------
# GLOBAL METRICS
# -----------------------
print("\nOverall Evaluation on Test Set:")
print(classification_report(all_targets, all_preds, digits=3))

# average='binary' reports the positive ("essential") class only; sklearn returns support=None
# in this mode, so report the number of gold positives instead.
precision, recall, f1, support = precision_recall_fscore_support(all_targets, all_preds, average='binary', pos_label=1)

print("\nBinary-Averaged Metrics")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")
print(f"Support:   {int(sum(all_targets))}")

In [9]:
# -----------------------
# CREATE EXPORTABLE RESULTS
# -----------------------
import os

# grouped_data was built by iterating df_test.groupby("case_id") (sorted keys). Take the case ids
# from the same iteration: df_test["case_id"].unique() is in order of first appearance, so zipping
# it with grouped_data mislabels cases whenever the sheet is not sorted by case_id.
case_ids = [case_id for case_id, _ in df_test.groupby("case_id")]
assert len(case_ids) == len(grouped_data), "case_ids and grouped_data are out of sync"

sent_token_id = tokenizer.convert_tokens_to_ids("</s>")
rows = []

for case_id, item in tqdm(zip(case_ids, grouped_data), total=len(grouped_data), desc="Generating report"):
    question = item["question"]
    labels = item["labels"]
    sentences = item["sentences"]

    # Rerun encoding for the item (same packing as inference)
    text = question + " </s> " + " </s> ".join(sentences)
    encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LENGTH)
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)

    sep_positions = (input_ids[0] == sent_token_id).nonzero(as_tuple=True)[0]
    if len(sep_positions) == 0:
        sep_positions = torch.tensor([1]).to(DEVICE)
    sep_positions = sep_positions.unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask, [sep_positions])
        probs = torch.sigmoid(outputs[0]).cpu().numpy().reshape(-1)
        preds = (probs > 0.5).astype(int)

    # Sentence ids are 0-based positions within the case.
    min_len = min(len(preds), len(labels))
    cited_ids = [str(i) for i, p in enumerate(preds[:min_len]) if p == 1]
    # Use every gold label, including sentences truncated away before scoring, so they count as
    # missed downstream instead of disappearing from the evaluation.
    gold_ids = [str(i) for i, l in enumerate(labels) if l == 1]

    rows.append({
        "case_id": case_id,
        "question": question,
        "cited_sentence_ids": ", ".join(cited_ids),
        "gold_essential_sentence_ids": ", ".join(gold_ids)
    })

# Create DataFrame
df_out = pd.DataFrame(rows)

# Save to Excel (create the output folder on first run)
output_path = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_fine_tuned_clinician_question.xlsx"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df_out.to_excel(output_path, index=False)

print(f"\nSaved output to: {output_path}")

In [10]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

# Recompute binary precision / recall / F1 from the exported fine-tuned predictions
# (clinician-question queries).

# -----------------------
# CONFIG
# -----------------------
PREDICTION_FILE = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_fine_tuned_clinician_question.xlsx"

# -----------------------
# LOAD
# -----------------------
df = pd.read_excel(PREDICTION_FILE)

y_true = []
y_pred = []

for gold_cell, pred_cell in zip(df["gold_essential_sentence_ids"], df["cited_sentence_ids"]):
    # Each cell holds comma-separated sentence indices, and an empty cell reads back as NaN.
    # Blank tokens are skipped. Going through float() also accepts an index that came back
    # numeric (e.g. 3.0), which int(str(...)) would reject.
    gold_ids = {int(float(tok)) for tok in str(gold_cell).split(",") if tok.strip()} if pd.notna(gold_cell) else set()
    pred_ids = {int(float(tok)) for tok in str(pred_cell).split(",") if tok.strip()} if pd.notna(pred_cell) else set()

    # Score positions 0..max id seen in either set. Positions beyond it are true negatives,
    # which do not affect binary precision / recall / F1.
    max_id = max(gold_ids | pred_ids, default=-1)
    for i in range(max_id + 1):
        y_true.append(1 if i in gold_ids else 0)
        y_pred.append(1 if i in pred_ids else 0)

# -----------------------
# METRICS
# -----------------------
# zero_division=0 returns the same 0 that the default "warn" would, without the warning.
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", pos_label=1, zero_division=0)

print("Evaluation Results of fine tuned model on dev with clinician question:")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")

Evaluate the untuned (base) model with cosine-similarity thresholding

Using the patient narrative as the query

In [11]:
from sentence_transformers import SentenceTransformer, util
import torch
from sklearn.metrics import classification_report
import pandas as pd
from tqdm import tqdm

# Baseline: embed the patient narrative and each reference sentence with the pretrained
# (not fine-tuned) Jina model. Predict a sentence "essential" when its cosine similarity to
# the query exceeds a fixed threshold.

# Load pretrained model
MODEL_NAME = "jinaai/jina-embeddings-v3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer(MODEL_NAME, trust_remote_code=True).to(DEVICE)

# Similarity cut-off for predicting "essential". Fixed, not tuned: consider tuning it on a
# validation split.
threshold = 0.5
# Printing every (question, sentence) pair floods the notebook output. Set True to inspect.
PRINT_PER_SENTENCE = False

# Load test data and rename the chosen query column to the name used below.
df_test = pd.read_excel("/content/gdrive/MyDrive/dataset_excel_bionlp/merged_notes_cases.xlsx")
question_column = 'patient_narrative'
df_test = df_test.rename(columns={question_column: 'question_generated'})
df_test["ref_excerpt"] = df_test["ref_excerpt"].astype(str)
# str() keeps an empty (NaN) relevance cell from raising on .strip(); such a cell maps to 0.
df_test["binary_relevance"] = df_test["relevance"].apply(lambda x: 1 if str(x).strip().lower() == "essential" else 0)

# Group test cases: one item per case_id (groupby sorts the keys), with rows kept in file order.
grouped_data = []
for case_id, group in df_test.groupby("case_id"):
    grouped_data.append({
        "question": group["question_generated"].iloc[0],
        "sentences": [s.strip() for s in group["ref_excerpt"].tolist()],
        "labels": group["binary_relevance"].tolist(),
    })

# Inference
all_preds = []
all_targets = []

for item in tqdm(grouped_data):
    question = item["question"]
    sentences = item["sentences"]
    labels = item["labels"]

    # NOTE(review): jina-embeddings-v3 accepts a task adapter via encode(..., task=...).
    # None is passed here; confirm that is the intended baseline.
    q_emb = model.encode(question, convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)

    # Cosine similarity of the query against every sentence of the case
    sims = util.cos_sim(q_emb, sent_embs)[0]  # shape: [#sentences]
    preds = (sims > threshold).int().tolist()

    all_preds.extend(preds)
    all_targets.extend(labels)

    if PRINT_PER_SENTENCE:
        for i, (sent, pred, label) in enumerate(zip(sentences, preds, labels)):
            print(f"Q: {question}")
            print(f"SENT {i+1}: {sent}")
            print(f" → Predicted: {pred}, True: {label}\n")

# Evaluation
print("\nBaseline (untuned) model evaluation:")
print(classification_report(all_targets, all_preds, digits=3))

In [12]:
# -----------------------
# CREATE EXPORTABLE RESULTS
# -----------------------
# Re-score every grouped test case with the untuned model. For each case, export the
# predicted and gold essential sentence indices.
import os

OUTPUT_PATH = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_untuned_base_patient_narrative.xlsx"
threshold = 0.5

# grouped_data was built by iterating df_test.groupby("case_id"), which yields sorted keys.
# df_test["case_id"].unique() keeps first-appearance order instead, so use the groupby order
# to label each result row with the case it was actually computed from.
case_ids = [case_id for case_id, _ in df_test.groupby("case_id")]
assert len(case_ids) == len(grouped_data), "case_ids and grouped_data are out of sync"

results = []

for case_id, item in zip(case_ids, tqdm(grouped_data, desc="Running inference", total=len(grouped_data))):
    question = item["question"]
    sentences = item["sentences"]
    labels = item["labels"]

    # Encode question and sentences
    q_emb = model.encode(question, convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)

    # Cosine similarity, thresholded into 0/1 predictions
    sims = util.cos_sim(q_emb, sent_embs)[0]
    preds = (sims > threshold).int().tolist()

    # Record sentence IDs (indices) that were predicted or truly essential
    cited_ids = [str(i) for i, p in enumerate(preds) if p == 1]
    gold_ids = [str(i) for i, g in enumerate(labels) if g == 1]

    results.append({
        "case_id": case_id,
        "question": question,
        "cited_sentence_ids": ",".join(cited_ids),
        "gold_essential_sentence_ids": ",".join(gold_ids),
    })

# Save to Excel, creating the output folder on first use.
df_results = pd.DataFrame(results)
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
df_results.to_excel(OUTPUT_PATH, index=False)

In [13]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

# Recompute binary precision / recall / F1 from the exported UNTUNED-baseline predictions
# (patient-narrative queries).

# -----------------------
# CONFIG
# -----------------------
PREDICTION_FILE = "/content/gdrive/MyDrive/qlora_outputs/05-05_BCEWithLogitsLoss_5_epochs/prediction_outputs/jina_citation_results_untuned_base_patient_narrative.xlsx"

# -----------------------
# LOAD
# -----------------------
df = pd.read_excel(PREDICTION_FILE)

y_true = []
y_pred = []

for gold_cell, pred_cell in zip(df["gold_essential_sentence_ids"], df["cited_sentence_ids"]):
    # Each cell holds comma-separated sentence indices, and an empty cell reads back as NaN.
    # Blank tokens are skipped. Going through float() also accepts an index that came back
    # numeric (e.g. 3.0), which int(str(...)) would reject.
    gold_ids = {int(float(tok)) for tok in str(gold_cell).split(",") if tok.strip()} if pd.notna(gold_cell) else set()
    pred_ids = {int(float(tok)) for tok in str(pred_cell).split(",") if tok.strip()} if pd.notna(pred_cell) else set()

    # Score positions 0..max id seen in either set. Positions beyond it are true negatives,
    # which do not affect binary precision / recall / F1.
    max_id = max(gold_ids | pred_ids, default=-1)
    for i in range(max_id + 1):
        y_true.append(1 if i in gold_ids else 0)
        y_pred.append(1 if i in pred_ids else 0)

# -----------------------
# METRICS
# -----------------------
# zero_division=0 returns the same 0 that the default "warn" would, without the warning.
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", pos_label=1, zero_division=0)

# These are the untuned base model's predictions (see PREDICTION_FILE), not the fine-tuned model's.
print("Evaluation Results of untuned base model on dev with patient narrative question:")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")

Using the clinician question as the query

In [14]:
from sentence_transformers import SentenceTransformer, util
import torch
from sklearn.metrics import classification_report
import pandas as pd
from tqdm import tqdm

# Baseline: embed the clinician question and each reference sentence with the pretrained
# (not fine-tuned) Jina model. Predict a sentence "essential" when its cosine similarity to
# the query exceeds a fixed threshold.

# Load pretrained model
MODEL_NAME = "jinaai/jina-embeddings-v3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer(MODEL_NAME, trust_remote_code=True).to(DEVICE)

# Similarity cut-off for predicting "essential". Fixed, not tuned: consider tuning it on a
# validation split.
threshold = 0.5
# Printing every (question, sentence) pair floods the notebook output. Set True to inspect.
PRINT_PER_SENTENCE = False

# Load test data and rename the chosen query column to the name used below.
df_test = pd.read_excel("/content/gdrive/MyDrive/dataset_excel_bionlp/merged_notes_cases.xlsx")
question_column = 'clinician_question'
df_test = df_test.rename(columns={question_column: 'question_generated'})
df_test["ref_excerpt"] = df_test["ref_excerpt"].astype(str)
# str() keeps an empty (NaN) relevance cell from raising on .strip(); such a cell maps to 0.
df_test["binary_relevance"] = df_test["relevance"].apply(lambda x: 1 if str(x).strip().lower() == "essential" else 0)

# Group test cases: one item per case_id (groupby sorts the keys), with rows kept in file order.
grouped_data = []
for case_id, group in df_test.groupby("case_id"):
    grouped_data.append({
        "question": group["question_generated"].iloc[0],
        "sentences": [s.strip() for s in group["ref_excerpt"].tolist()],
        "labels": group["binary_relevance"].tolist(),
    })

# Inference
all_preds = []
all_targets = []

for item in tqdm(grouped_data):
    question = item["question"]
    sentences = item["sentences"]
    labels = item["labels"]

    # NOTE(review): jina-embeddings-v3 accepts a task adapter via encode(..., task=...).
    # None is passed here; confirm that is the intended baseline.
    q_emb = model.encode(question, convert_to_tensor=True)
    sent_embs = model.encode(sentences, convert_to_tensor=True)

    # Cosine similarity of the query against every sentence of the case
    sims = util.cos_sim(q_emb, sent_embs)[0]  # shape: [#sentences]
    preds = (sims > threshold).int().tolist()

    all_preds.extend(preds)
    all_targets.extend(labels)

    if PRINT_PER_SENTENCE:
        for i, (sent, pred, label) in enumerate(zip(sentences, preds, labels)):
            print(f"Q: {question}")
            print(f"SENT {i+1}: {sent}")
            print(f" → Predicted: {pred}, True: {label}\n")

# Evaluation
print("\nBaseline (untuned) model evaluation:")
print(classification_report(all_targets, all_preds, digits=3))