In [86]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
import re
import os
import wandb
import string
import numpy as np
import unicodedata
import pandas as pd
from pathlib import Path
from datasets import load_dataset
from collections import defaultdict
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import flash_attn
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split

from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, precision_score, recall_score

import sys
sys.path.append("../../src")
import util.preprocessing_util as util

# Setup

In [137]:
# Location of the processed dev split and the CSV file read below.
DATA_DIR = Path("..", "..", "data", "dev", "processed")
DATASET_NAME = "medical_data.csv"

In [138]:
# Sentence-level rows: one row per (case_id, sentence_id).
data = pd.read_csv(DATA_DIR.joinpath(DATASET_NAME))
data.head(n=5)

Unnamed: 0,case_id,patient_question,note_excerpt,sentence_id,sentence_text,relevance,start_char_index,length
0,1,my question is if the sludge was there does no...,brief hospital course during the ercp a pancre...,0,brief hospital course,not-relevant,0,22
1,1,my question is if the sludge was there does no...,brief hospital course during the ercp a pancre...,1,during the ercp a pancreatic stent was require...,essential,0,243
2,1,my question is if the sludge was there does no...,brief hospital course during the ercp a pancre...,2,however due to the patients elevated inr no sp...,not-relevant,244,93
3,1,my question is if the sludge was there does no...,brief hospital course during the ercp a pancre...,3,frank pus was noted to be draining from the co...,not-relevant,338,151
4,1,my question is if the sludge was there does no...,brief hospital course during the ercp a pancre...,4,the vancomycin was discontinued,not-relevant,490,32


In [139]:
# Per-case reducer used with groupby("case_id").apply(...) below.
def aggregate_case(group):
    """Collapse all sentence rows of one case into a single record.

    Args:
        group: the rows belonging to one ``case_id``; must provide the
            ``patient_question``, ``sentence_text`` and ``relevance`` columns.

    Returns:
        pd.Series with ``question`` (the first row's patient question),
        ``sentences`` (sentence texts in row order) and ``labels`` (1 when the
        relevance is "essential" or "relevant", else 0; aligned with
        ``sentences``).
    """
    positive_relevance = ("essential", "relevant")
    binary_labels = [int(relevance in positive_relevance) for relevance in group["relevance"]]
    record = {
        "question": group["patient_question"].iloc[0],
        "sentences": group["sentence_text"].tolist(),
        "labels": binary_labels,
    }
    return pd.Series(record)

# Collapse the sentence-level rows into one row per case.
# The value columns are selected *after* `groupby`, so the grouping column
# (`case_id`) is not part of the frames handed to `aggregate_case`. Selecting
# it before grouping still passes it to `apply`, which keeps pandas' deprecation
# warning about `apply` operating on grouping columns alive.
# NOTE: this rebinds `data`, so re-running this cell without first re-running
# the CSV-loading cell fails (the raw sentence-level columns are gone).
data = (
    data
    .groupby("case_id")[["patient_question", "sentence_text", "relevance"]]
    .apply(aggregate_case)
    .reset_index()
)

  .apply(aggregate_case)


In [140]:
data.head(n=5)  # one row per case: question, sentence list, label list

Unnamed: 0,case_id,question,sentences,labels
0,1,my question is if the sludge was there does no...,"[brief hospital course, during the ercp a panc...","[0, 1, 0, 0, 0, 1, 1, 1, 0]"
1,2,dad given multiple shots of lasciks after he w...,"[brief hospital course, acute diastolic heart ...","[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]"
2,3,he is continously irritated and has headache w...,[discharge instructions you were admitted to t...,"[0, 0, 0, 0, 1, 1, 0, 0, 0, 0]"
3,4,my doctor performed a cardiac catherization,"[history of present illness, on the cardiology...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,5,i overdosed october 4th on trihexyphenidyl tho...,"[brief hospital course, bipolar do ptsd schiz...","[0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, ..."


In [141]:
data.iloc[0]["sentences"]  # sentence list of the first case, in row order

['brief hospital course',
 'during the ercp a pancreatic stent was required to facilitate access to the biliary system removed at the end of the procedure and a common bile duct stent was placed to allow drainage of the biliary obstruction caused by stones and sludge',
 'however due to the patients elevated inr no sphincterotomy or stone removal was performed',
 'frank pus was noted to be draining from the common bile duct and postercp it was recommended that the patient remain on iv zosyn for at least a week',
 'the vancomycin was discontinued',
 'on hospital day 4 postprocedure day 3 the patient returned to ercp for reevaluation of her biliary stent as her lfts and bilirubin continued an upward trend',
 'on ercp the previous biliary stent was noted to be acutely obstructed by biliary sludge and stones',
 'as the patients inr was normalized to 12 a sphincterotomy was safely performed with removal of several biliary stones in addition to the common bile duct stent',
 'at the conclusion

# Masking

In [142]:
# Passed as `window=` to util.mask_on_sentence_level below.
WINDOW_SIZE: int = 1

In [143]:
# Expand each case into one example per target sentence. Judging from the
# preview below, each row has question, context (neighbouring sentences with the
# target wrapped in [START] ... [END]), target_sentence, target_index and label.
# NOTE(review): exact windowing semantics live in util.mask_on_sentence_level.
test_df = util.mask_on_sentence_level(data, window=WINDOW_SIZE)

In [144]:
test_df.head(n=5)  # first few sentence-level examples

Unnamed: 0,question,context,target_sentence,target_index,label
0,my question is if the sludge was there does no...,[START] brief hospital course [END]. during th...,brief hospital course,0,0
1,my question is if the sludge was there does no...,brief hospital course. [START] during the ercp...,during the ercp a pancreatic stent was require...,1,1
2,my question is if the sludge was there does no...,during the ercp a pancreatic stent was require...,however due to the patients elevated inr no sp...,2,0
3,my question is if the sludge was there does no...,however due to the patients elevated inr no sp...,frank pus was noted to be draining from the co...,3,0
4,my question is if the sludge was there does no...,frank pus was noted to be draining from the co...,the vancomycin was discontinued,4,0


# Prepare Dataset

In [145]:
BATCH_SIZE = 64
CONTEXT_LENGTH = 512  # max tokens per (question, context) pair after padding/truncation
MODEL_DIR = "../../models"

# The tokenizer is loaded here, before the tokenization cells that use it, so the
# notebook works on a fresh kernel (Restart & Run All). It comes from the same
# directory as the model loaded in the Model section.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

In [146]:
# Wrap the masked DataFrame in a Hugging Face Dataset so it can be .map()-ed.
dataset_test = Dataset.from_pandas(df=test_df)

In [147]:
# Advanced by tokenize_with_progress; re-create it before re-running the map
# cell, otherwise it keeps counting past its total.
n_examples = len(dataset_test)
progress_bar = tqdm(total=n_examples, desc="Tokenizing", position=0, leave=True)

Tokenizing:   0%|          | 0/428 [00:00<?, ?it/s]

In [148]:
def tokenize_batch(batch):
    """Tokenize a batch of (question, context) pairs for sequence classification.

    Args:
        batch: a batched ``datasets`` slice, i.e. a dict mapping column name to
            a list of values; must contain "question", "context" and "label".

    Returns:
        dict with "input_ids" and "attention_mask" (lists of int lists, each
        padded/truncated to CONTEXT_LENGTH tokens) and "labels" (the "label"
        column copied under the key "labels").
    """
    # The question is the first segment of the pair, the masked context the second.
    # Plain Python lists are returned directly: `datasets.map` stores lists, so
    # building tensors only to convert them back with `.tolist()` is unnecessary.
    encodings = tokenizer(
        batch["question"],
        batch["context"],
        padding="max_length",
        truncation=True,
        max_length=CONTEXT_LENGTH,
    )
    # NOTE(review): any other tokenizer outputs (e.g. token_type_ids, if this
    # tokenizer emits them) are dropped here; confirm the model does not need them.
    return {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": batch["label"],
    }


def tokenize_with_progress(batch):
    """Run `tokenize_batch` and advance the module-level `progress_bar` by the batch size."""
    out = tokenize_batch(batch)
    progress_bar.update(len(batch["question"]))
    return out

In [149]:
# Tokenize in BATCH_SIZE chunks; original columns are kept alongside the new ones.
tokenized_dataset_test = dataset_test.map(
    function=tokenize_with_progress,
    batched=True,
    batch_size=BATCH_SIZE,
)

Map:   0%|          | 0/428 [00:00<?, ? examples/s]

In [150]:
# Finalize the manual tokenization progress bar created before the map call.
progress_bar.close()

In [151]:
# Expose only the model inputs and labels, as torch tensors, to the DataLoader.
model_columns = ["input_ids", "attention_mask", "labels"]
tokenized_dataset_test.set_format("torch", columns=model_columns)

In [158]:
# Sequential (unshuffled) batches for evaluation.
test_dataloader = DataLoader(dataset=tokenized_dataset_test, batch_size=BATCH_SIZE, shuffle=False)

In [153]:
# Sanity check: pull the first batch and show the shape of every tensor in it.
batch = next(iter(test_dataloader))
shapes = {name: tensor.shape for name, tensor in batch.items()}
print(shapes)

{'input_ids': torch.Size([64, 512]), 'attention_mask': torch.Size([64, 512]), 'labels': torch.Size([64])}


In [154]:
# Summary of the tokenized test set and its columns (one item per line).
print(
    "----- Test Set -----",
    tokenized_dataset_test,
    tokenized_dataset_test.column_names,
    sep="\n",
)

----- Test Set -----
Dataset({
    features: ['question', 'context', 'target_sentence', 'target_index', 'label', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 428
})
['question', 'context', 'target_sentence', 'target_index', 'label', 'input_ids', 'attention_mask', 'labels']


# Model

In [155]:
# Fine-tuned binary classifier (2 labels) and its tokenizer, from the same directory.
model_dir = "../../models"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=2)

In [156]:
# Use the GPU when available; put the model in inference mode on that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device).eval();

In [159]:
# Run the classifier over the whole test set, collecting per-example predicted
# classes and gold labels, then build a per-class precision/recall/F1 table.
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Predicting"):
        # Only the model inputs go to `device`; labels are only used for scoring
        # on the CPU below, so they are never copied to the GPU.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        preds = logits.argmax(dim=1)

        # `.numpy()` only works on CPU tensors, hence the explicit `.cpu()`.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch["labels"].numpy())

# `digits` is ignored when output_dict=True, so it is omitted. zero_division=0
# keeps the same 0.0 scores sklearn would report but silences its warning
# when a class is never predicted.
report = pd.DataFrame(
    classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
).transpose()

  0%|          | 0/7 [00:00<?, ?it/s]

In [160]:
report  # per-class precision / recall / f1 / support

Unnamed: 0,precision,recall,f1-score,support
0,0.700521,0.927586,0.79822,290.0
1,0.522727,0.166667,0.252747,138.0
accuracy,0.682243,0.682243,0.682243,0.682243
macro avg,0.611624,0.547126,0.525483,428.0
weighted avg,0.643195,0.682243,0.622343,428.0
