In [1]:
# Run in a notebook cell (prefix with !)
!pip install -q tf-keras


In [2]:
import os, re, random
import numpy as np
import pandas as pd
import torch

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from transformers import EarlyStoppingCallback

PROJECT_ROOT = "/Users/anudeep/Documents/glaucoma_detection"
DATA_PATH = os.path.join(PROJECT_ROOT, "data", "clinical_notes.csv")
OUT_BASE = os.path.join(PROJECT_ROOT, "models")
os.makedirs(OUT_BASE, exist_ok=True)

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)


<torch._C.Generator at 0x129134e30>

In [3]:
df = pd.read_csv(DATA_PATH, low_memory=False)
# minimal checks
assert 'gpt4_summary' in df.columns and 'glaucoma' in df.columns

# clean summaries inplace
def _clean(s): return re.sub(r'\s+',' ', str(s).strip().lower())
df['gpt4_summary'] = df['gpt4_summary'].astype(str).apply(_clean)

# map labels yes/no -> 1/0 (assumes these values exist)
df['glaucoma'] = df['glaucoma'].map({'yes':1, 'no':0})
df = df.dropna(subset=['glaucoma']).reset_index(drop=True)
df['glaucoma'] = df['glaucoma'].astype(int)

print("Rows:", len(df), "Pos:", df['glaucoma'].sum())
df.head()

Rows: 10000 Pos: 5048


Unnamed: 0,age,gender,race,ethnicity,language,maritalstatus,note,gpt4_summary,glaucoma,use
0,56.56,female,black,non-hispanic,english,single,ms. PERSON is a 56 yo woman presenting to esta...,the 56 y/o female patient has optic nerve head...,1,training
1,53.91,female,white,non-hispanic,english,single,referred for evaluation of narrow angles ou #p...,patient was referred for narrow angle evaluati...,1,training
2,46.3,female,white,non-hispanic,english,single,1. left upper lid ptosis: occurred after botox...,"patient experienced ptosis, ear and eye pain, ...",0,training
3,66.52,male,white,non-hispanic,english,single,right plano +0.50 082 left LOCATION -0.50 83 a...,the patient has primary open angle glaucoma - ...,1,training
4,82.52,female,black,non-hispanic,english,divorced,in step. os with nonspecific peripheral defect...,the patient has nonspecific peripheral defects...,1,training


In [4]:
# === FIXED STRATIFIED SPLIT: ensure Asian, Black, White appear in test ===

# Create combined stratification key
df['strat_key'] = df['glaucoma'].astype(str) + "_" + df['race'].astype(str)

# First split: train_val vs test
train_val, test = train_test_split(
    df,
    test_size=0.15,
    stratify=df['strat_key'],
    random_state=SEED
)

# Second split: train vs val
train, val = train_test_split(
    train_val,
    test_size=0.1764706,  # â‰ˆ 0.15 overall val
    stratify=train_val['strat_key'],
    random_state=SEED
)

train = train.reset_index(drop=True)
val = val.reset_index(drop=True)
test = test.reset_index(drop=True)

print("Train/Val/Test:", len(train), len(val), len(test))
print("Races in test:", test['race'].value_counts())

Train/Val/Test: 6999 1501 1500
Races in test: race
white    1153
black     224
asian     123
Name: count, dtype: int64


In [5]:

MODEL_NAME = "nlpie/tiny-biobert"
print("MODEL_NAME =", MODEL_NAME)

MODEL_NAME = nlpie/tiny-biobert


In [6]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
MAX_LEN = 128


In [7]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
import torch
from sklearn.metrics import roc_auc_score

In [8]:
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
device = "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=312, out_features=312, bias=True)
              (key): Linear(in_features=312, out_features=312, bias=True)
              (value): Linear(in_features=312, out_features=312, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=312, out_features=312, bias=True)
              (LayerNorm): LayerNorm((312,), eps=1e-1

In [9]:
def tokenize(texts):
    return tokenizer(texts, truncation=True, padding='max_length', max_length=MAX_LEN)

train_enc = tokenize(train['gpt4_summary'].tolist())
val_enc   = tokenize(val['gpt4_summary'].tolist())
test_enc  = tokenize(test['gpt4_summary'].tolist())


In [10]:
class SimpleDS(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __len__(self): 
        return len(self.labels)
    def __getitem__(self, i):
        item = {k: torch.tensor(v[i]) for k,v in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[i], dtype=torch.long)  # Changed to long!
        return item

# Recreate datasets with new label type
train_dataset = SimpleDS(train_enc, train['glaucoma'].tolist())
val_dataset = SimpleDS(val_enc, val['glaucoma'].tolist())
test_dataset = SimpleDS(test_enc, test['glaucoma'].tolist())

In [11]:
#convert train_dataset to dataframe for inspection
inspection_df = pd.DataFrame({
    'input_ids': [item['input_ids'].tolist() for item in train_dataset],
    'attention_mask': [item['attention_mask'].tolist() for item in train_dataset],
    'labels': [item['labels'].item() for item in train_dataset]
})
inspection_df.head()

Unnamed: 0,input_ids,attention_mask,labels
0,"[101, 5351, 2786, 11534, 1112, 170, 176, 15554...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0
1,"[101, 2623, 1884, 6602, 22259, 7589, 1114, 366...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0
2,"[101, 5465, 194, 119, 184, 119, 2130, 1114, 22...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",1
3,"[101, 5351, 1110, 170, 176, 15554, 8178, 1161,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0
4,"[101, 2588, 194, 119, 184, 119, 2130, 1114, 16...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0


In [12]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import numpy as np
import torch

def compute_metrics(pred):
    labels = pred.label_ids
    logits = pred.predictions

    # convert logits to a single prob-per-sample
    if getattr(logits, "ndim", None) == 2 and logits.shape[1] == 2:
        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
        preds = np.argmax(logits, axis=1)
    else:
        probs = torch.sigmoid(torch.tensor(np.asarray(logits).reshape(-1))).numpy()
        preds = (probs > 0.5).astype(int)

    acc = accuracy_score(labels, preds)
    p, r, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)

    try:
        auc = roc_auc_score(labels, probs)
    except Exception:
        auc = float("nan")

    return {"accuracy": acc, "precision": p, "recall": r, "f1": f1, "auc": auc}

In [16]:
MODEL_TAG = MODEL_NAME.split("/")[-1].replace(".", "_")
OUT_DIR_MODEL = os.path.join(OUT_BASE, f"transformer_{MODEL_TAG}")

In [23]:
# -------------------------
# Compute class weights and use WeightedTrainer
# -------------------------
from sklearn.utils.class_weight import compute_class_weight
import numpy as np, torch
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorWithPadding as data_collator
data_collator = DataCollatorWithPadding(tokenizer)
# Compute weights
classes = np.unique(train['glaucoma'])
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train['glaucoma'].values)
weight_tensor = torch.tensor(class_weights, dtype=torch.float).to(model.device)
print("classes:", classes, "class_weights:", class_weights)

# Safe WeightedTrainer
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight_tensor)
        loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# TrainingArgs (compatible)
training_args = TrainingArguments(
    output_dir=OUT_DIR_MODEL,
    num_train_epochs=6,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    save_total_limit=2,
    seed=SEED,
    logging_dir=os.path.join(OUT_DIR_MODEL, "logs"),
    report_to="none"
)

# Instantiate WeightedTrainer
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("WeightedTrainer ready.")

classes: [0 1] class_weights: [1.00966532 0.99051797]
WeightedTrainer ready.


  trainer = WeightedTrainer(


In [24]:
import torch
print("train_dataset type:", type(train_dataset))
print("is torch Dataset:", isinstance(train_dataset, torch.utils.data.Dataset))

print("data_collator type:", type(data_collator))
print("data_collator callable?:", callable(data_collator))

# show a single item from train_dataset
try:
    item0 = train_dataset[0]
    print("train_dataset[0] type:", type(item0))
    if isinstance(item0, dict):
        print("keys:", list(item0.keys()))
    else:
        print(repr(item0)[:500])
except Exception as e:
    print("Error reading train_dataset[0]:", repr(e))

train_dataset type: <class '__main__.SimpleDS'>
is torch Dataset: True
data_collator type: <class 'transformers.data.data_collator.DataCollatorWithPadding'>
data_collator callable?: True
train_dataset[0] type: <class 'dict'>
keys: ['input_ids', 'token_type_ids', 'attention_mask', 'labels']


In [28]:
from torch.utils.data import DataLoader
dl = DataLoader(train_dataset, batch_size=training_args.per_device_train_batch_size, collate_fn=data_collator)
batch = next(iter(dl))
print(type(batch), list(batch.keys()) if isinstance(batch, dict) else repr(batch)[:200])

<class 'transformers.tokenization_utils_base.BatchEncoding'> {'input_ids': tensor([[  101,  5351,  2786,  ...,     0,     0,     0],
        [  101,  2623,  1884,  ...,     0,     0,     0],
        [  101,  5465,   194,  ...,     0,     0,     0],
        ...,


In [25]:
trainer.train()



Step,Training Loss
500,0.6328
1000,0.5749
1500,0.5299
2000,0.5118
2500,0.4824
3000,0.4496
3500,0.438
4000,0.3968
4500,0.3762
5000,0.3569




TrainOutput(global_step=5250, training_loss=0.46808145286923364, metrics={'train_runtime': 263.62, 'train_samples_per_second': 159.298, 'train_steps_per_second': 19.915, 'total_flos': 150537951857664.0, 'train_loss': 0.46808145286923364, 'epoch': 6.0})

In [29]:
pred_out = trainer.predict(test_dataset)
logits = pred_out.predictions

import numpy as np, torch
from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_fscore_support, confusion_matrix

# Convert logits
if logits.ndim == 2 and logits.shape[1] == 2:
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    preds = np.argmax(logits, axis=1)
else:
    probs = torch.sigmoid(torch.tensor(logits).reshape(-1)).numpy()
    preds = (probs > 0.5).astype(int)

y_true = np.array(test_dataset.labels)

auc = roc_auc_score(y_true, probs)
acc = accuracy_score(y_true, preds)
p, r, f1, _ = precision_recall_fscore_support(y_true, preds, average='binary')
tn, fp, fn, tp = confusion_matrix(y_true, preds).ravel()
sens = tp / (tp + fn)
spec = tn / (tn + fp)

print(f"AUC: {auc:.4f}")
print(f"Accuracy: {acc:.4f}")
print(f"Precision: {p:.4f}")
print(f"Recall/Sensitivity: {r:.4f}")
print(f"F1: {f1:.4f}")
print(f"Sensitivity: {sens:.4f}")
print(f"Specificity: {spec:.4f}")



AUC: 0.8156
Accuracy: 0.7347
Precision: 0.7442
Recall/Sensitivity: 0.7226
F1: 0.7332
Sensitivity: 0.7226
Specificity: 0.7470


In [26]:
# === PREDICT ON TEST SET & COMPUTE METRICS ===
import numpy as np
import torch
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
)

# Run predictions (fast if already computed)
pred_out = trainer.predict(test_dataset)
logits = pred_out.predictions
print("pred_out.predictions.shape =", getattr(logits, "shape", None))

# Convert logits -> probabilities for positive class
if getattr(logits, "ndim", None) == 2 and logits.shape[1] == 2:
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    preds = np.argmax(logits, axis=1)
else:
    probs = torch.sigmoid(torch.tensor(np.asarray(logits).reshape(-1))).numpy()
    preds = (probs > 0.5).astype(int)

# Ground-truth labels from SimpleDS (or fallback to pandas test DataFrame)
if hasattr(test_dataset, "labels"):
    y_true = np.asarray(test_dataset.labels).reshape(-1)
else:
    y_true = np.asarray(test['glaucoma']).reshape(-1)

# Basic metrics
auc = roc_auc_score(y_true, probs) if len(np.unique(y_true)) > 1 else float('nan')
acc = accuracy_score(y_true, preds)
p, r, f1, _ = precision_recall_fscore_support(y_true, preds, average='binary', zero_division=0)

# Sensitivity (recall for positive class) and Specificity (recall for negative class)
tn, fp, fn, tp = confusion_matrix(y_true, preds).ravel()
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
specificity = tn / (tn + fp) if (tn + fp) > 0 else float('nan')

print("\n=== TEST METRICS ===")
print(f"AUC:         {auc:.4f}")
print(f"Accuracy:    {acc:.4f}")
print(f"Precision:   {p:.4f}")
print(f"Recall:      {r:.4f}")
print(f"F1:          {f1:.4f}")
print(f"Sensitivity: {sensitivity:.4f}")
print(f"Specificity: {specificity:.4f}")
print("Confusion matrix (tn, fp, fn, tp):", (tn, fp, fn, tp))



pred_out.predictions.shape = (1500, 2)

=== TEST METRICS ===
AUC:         0.8156
Accuracy:    0.7347
Precision:   0.7442
Recall:      0.7226
F1:          0.7332
Sensitivity: 0.7226
Specificity: 0.7470
Confusion matrix (tn, fp, fn, tp): (np.int64(555), np.int64(188), np.int64(210), np.int64(547))


In [31]:
import pandas as pd
from sklearn.metrics import roc_auc_score

df_test = test.reset_index(drop=True).copy()
df_test['prob_pos'] = probs
df_test['pred'] = preds

# Normalize to lowercase so comparisons always work
df_test['race_norm'] = df_test['race'].astype(str).str.strip().str.lower()

groups = ["asian", "black", "white"]   # use lowercase

print("\n=== AUC by Race Group ===")
for g in groups:
    mask = df_test['race_norm'] == g
    n = mask.sum()
    if n == 0:
        print(f"{g}: no samples in test set (n=0)")
        continue

    y_true_g = df_test.loc[mask, 'glaucoma'].astype(int).values
    probs_g = df_test.loc[mask, 'prob_pos'].values
    
    auc_g = roc_auc_score(y_true_g, probs_g) if len(set(y_true_g)) > 1 else float('nan')
    print(f"{g.capitalize()}: AUC = {auc_g:.4f}  (n={n})")


=== AUC by Race Group ===
Asian: AUC = 0.8267  (n=123)
Black: AUC = 0.7766  (n=224)
White: AUC = 0.8236  (n=1153)


In [32]:
import os, json
# OUT_DIR_MODEL should already be set (e.g. /.../models/transformer_tiny-biobert)
os.makedirs(OUT_DIR_MODEL, exist_ok=True)

# 1) Save model & tokenizer
trainer.save_model(OUT_DIR_MODEL)            # saves model + config
tokenizer.save_pretrained(OUT_DIR_MODEL)    # saves tokenizer files
print("Saved model & tokenizer to:", OUT_DIR_MODEL)

# 2) Save test predictions dataframe (df_test must have prob_pos, pred, glaucoma, race)
preds_csv = os.path.join(OUT_DIR_MODEL, "test_predictions_transformer.csv")
df_test.to_csv(preds_csv, index=False)
print("Saved test predictions to:", preds_csv)

# 3) Save metrics JSON (fill in with your computed values)
metrics = {
    "auc": float(auc),
    "accuracy": float(acc),
    "precision": float(p),
    "recall": float(r),
    "f1": float(f1),
    "sensitivity": float(sensitivity),
    "specificity": float(specificity),
    "auc_asian": float(df_test.loc[df_test['race_norm']=='asian','prob_pos'].pipe(lambda p: float('nan') if p.size==0 else roc_auc_score(df_test.loc[df_test['race_norm']=='asian','glaucoma'], p))),
    "auc_black": float(df_test.loc[df_test['race_norm']=='black','prob_pos'].pipe(lambda p: float('nan') if p.size==0 else roc_auc_score(df_test.loc[df_test['race_norm']=='black','glaucoma'], p))),
    "auc_white": float(df_test.loc[df_test['race_norm']=='white','prob_pos'].pipe(lambda p: float('nan') if p.size==0 else roc_auc_score(df_test.loc[df_test['race_norm']=='white','glaucoma'], p))),
}
with open(os.path.join(OUT_DIR_MODEL, "transformer_metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)
print("Saved metrics:", os.path.join(OUT_DIR_MODEL,"transformer_metrics.json"))

Saved model & tokenizer to: /Users/anudeep/Documents/glaucoma_detection/models/transformer_tiny-biobert
Saved test predictions to: /Users/anudeep/Documents/glaucoma_detection/models/transformer_tiny-biobert/test_predictions_transformer.csv
Saved metrics: /Users/anudeep/Documents/glaucoma_detection/models/transformer_tiny-biobert/transformer_metrics.json


In [33]:
# --- Minimal Inference Program for Glaucoma Detection ---

import torch
import numpy as np

def predict_text(text):
    """
    Input: any natural language text (string)
    Output: predicted_label (0/1), probability_glaucoma (0.0-1.0)
    Works with your trained Transformer model.
    """

    model.eval()  # set model to eval mode

    # tokenize
    inputs = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )

    # move tensors to the right device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # 2-class softmax
    if logits.shape[-1] == 2:
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        pred = int(np.argmax(probs))
        prob_glaucoma = float(probs[1])
    else:
        # 1-logit sigmoid case
        prob_glaucoma = float(torch.sigmoid(logits).cpu().numpy().reshape(-1)[0])
        pred = 1 if prob_glaucoma > 0.5 else 0

    return pred, prob_glaucoma


# --- Try some examples ---
examples = [
    "The patient has no issues. Vision normal. No glaucoma symptoms.",   # should be 0
    "Optic disc cupping noted with elevated IOP. Possible glaucoma.",     # should be 1
    "Patient complains of stomach pain only. No eye complaints.",         # should be 0
    "Severe optic nerve damage and high intraocular pressure.",           # should be 1
]

for text in examples:
    pred, prob = predict_text(text)
    print(f"\nInput: {text}")
    print(f"Prediction: {'Glaucoma' if pred==1 else 'No Glaucoma'}")
    print(f"Probability: {prob:.4f}")


Input: The patient has no issues. Vision normal. No glaucoma symptoms.
Prediction: No Glaucoma
Probability: 0.0094

Input: Optic disc cupping noted with elevated IOP. Possible glaucoma.
Prediction: Glaucoma
Probability: 0.8778

Input: Patient complains of stomach pain only. No eye complaints.
Prediction: No Glaucoma
Probability: 0.0113

Input: Severe optic nerve damage and high intraocular pressure.
Prediction: Glaucoma
Probability: 0.7794
