In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_path = "/Users/anudeep/Documents/glaucoma_detection/models/transformer_tiny-biobert"  # or the folder where you saved it
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=312, out_features=312, bias=True)
              (key): Linear(in_features=312, out_features=312, bias=True)
              (value): Linear(in_features=312, out_features=312, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=312, out_features=312, bias=True)
              (LayerNorm): LayerNorm((312,), eps=1e-1

In [2]:
def predict_glaucoma(note: str):
    # Tokenize the input
    inputs = tokenizer(note, return_tensors="pt", truncation=True, padding=True, max_length=256)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

    pred_label = "Glaucoma Detected" if probs[1] > 0.5 else "No Glaucoma"
    confidence = probs[1] if probs[1] > 0.5 else probs[0]
    return pred_label, float(confidence)


In [4]:
sample1 = "Patient shows optic nerve thinning and increased cup-to-disc ratio with visual field loss."
sample2 = "no disease."

for text in [sample1, sample2]:
    label, conf = predict_glaucoma(text)
    print(f"Text: {text}\n→ {label} (Confidence: {conf:.3f})\n")


Text: Patient shows optic nerve thinning and increased cup-to-disc ratio with visual field loss.
→ Glaucoma Detected (Confidence: 0.723)

Text: no disease.
→ No Glaucoma (Confidence: 0.874)

