In [1]:
import torch
import torch.nn as nn
import transformers
import shap
import numpy as np
import pickle

# Inference device: the model and every input tensor in this notebook are moved
# to it via `.to(device)`, and it is also used as `map_location` for the checkpoint.
device = "cpu"


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load Resources
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load this file when it comes from a trusted source.
# `classes` holds the label names that index the model's outputs. Judging by the
# file name it is presumably the `classes_` of a fitted MultiLabelBinarizer — TODO confirm.
with open('../models/mlb_classes.pkl', 'rb') as f:
    classes = pickle.load(f)

# Tokenizer for the same Bio_ClinicalBERT checkpoint that PMSIModel loads as its encoder.
tokenizer = transformers.AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Model Architecture
class PMSIModel(nn.Module):
    """Bio_ClinicalBERT encoder followed by dropout and a linear classification head.

    The head emits one raw logit per class; callers in this notebook apply a
    sigmoid to turn the logits into independent per-label probabilities.

    Args:
        n_classes: Number of output logits (one per label).
    """

    def __init__(self, n_classes):
        super().__init__()
        # Attribute names (bert / drop / out) must stay as they are: they define
        # the state_dict keys of the saved checkpoint.
        self.bert = transformers.AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.drop = nn.Dropout(0.3)
        # Read the encoder width from its config rather than hard-coding 768, so
        # the head always matches the loaded encoder.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, ids, mask, token_type_ids):
        """Compute class logits for a batch of tokenized inputs.

        Args:
            ids: Token id tensor passed to the encoder as its input ids.
            mask: Attention mask tensor (passed as ``attention_mask``).
            token_type_ids: Segment id tensor.

        Returns:
            Tensor of raw (pre-sigmoid) logits produced by the linear head.
        """
        # With return_dict=False the encoder returns a tuple; the second element
        # is the pooled output, which is the only part the head consumes.
        _, pooled_output = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        return self.out(self.drop(pooled_output))

# One output logit per label in `classes`.
model = PMSIModel(len(classes))
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code. A state dict needs
# nothing beyond that.
state_dict = torch.load(
    '../models/pmsi_model_high_conf.bin', map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.to(device)
# eval() switches off dropout for inference. It returns the module, so as the
# last expression of the cell it also displays the architecture.
model.eval()

PMSIModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [3]:
# Prediction Function
def predict_pmsi(text, top_k=3, max_length=128):
    """Return the ``top_k`` most probable labels for a single text.

    Args:
        text: Input string to classify.
        top_k: Number of (label, probability) pairs to return. ``None`` returns
            every label, sorted.
        max_length: Token length the input is padded/truncated to.

    Returns:
        List of ``(label, probability)`` tuples sorted by descending probability.
        Probabilities are independent per-label sigmoid scores (they do not sum
        to 1) and are numpy float32 scalars.

    Raises:
        ValueError: If ``top_k`` is negative (a negative slice bound would
            silently drop the *lowest* entries instead of limiting the result).
    """
    if top_k is not None and top_k < 0:
        raise ValueError(f"top_k must be non-negative or None, got {top_k}")

    # Call the tokenizer directly (encode_plus is deprecated) and let it build
    # the batched tensors, mirroring what predict_wrapper does.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_tensors='pt',
    )

    ids = inputs['input_ids'].to(device)
    mask = inputs['attention_mask'].to(device)
    token_type_ids = inputs['token_type_ids'].to(device)

    with torch.no_grad():
        outputs = model(ids, mask, token_type_ids)

    # Batch of one: take row 0 of the per-label probabilities.
    probs = torch.sigmoid(outputs).cpu().numpy()[0]

    predictions = sorted(zip(classes, probs), key=lambda pair: pair[1], reverse=True)
    return predictions[:top_k]

# Test Prediction
# Smoke test: run one hand-written clinical note through predict_pmsi and print
# the returned (label, probability) pairs (top 3 by default).
sample_text = "Patient has chronic migraines and numbness in the fingers. MRI reveals a pinched nerve in the cervical spine."
print(f"Input: {sample_text}")
print(f"Prediction: {predict_pmsi(sample_text)}")

Input: Patient has chronic migraines and numbness in the fingers. MRI reveals a pinched nerve in the cervical spine.
Prediction: [('neurology', np.float32(0.9560129)), ('mri', np.float32(0.6861209)), ('neurosurgery', np.float32(0.67500013))]


  return forward_call(*args, **kwargs)


In [4]:

# SHAP Visualization
def predict_wrapper(texts, max_length=128):
    """Batch prediction function in the form SHAP expects.

    Args:
        texts: A single string, a sequence of strings, or a numpy array of
            strings (the explainer passes batches of masked variants).
        max_length: Maximum token length; longer inputs are truncated.

    Returns:
        numpy array with one row per input text and one column per class,
        holding per-label sigmoid probabilities.
    """
    if isinstance(texts, np.ndarray):
        texts = texts.tolist()
    if isinstance(texts, str):
        texts = [texts]
    # Coerce numpy string scalars / other objects to plain str for the tokenizer.
    texts = [str(t) for t in texts]

    # Call the tokenizer directly (batch_encode_plus is deprecated) and request
    # token_type_ids explicitly instead of relying on the tokenizer's default.
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_token_type_ids=True,
        return_tensors='pt',
    )
    ids = inputs['input_ids'].to(device)
    mask = inputs['attention_mask'].to(device)
    token_type_ids = inputs['token_type_ids'].to(device)

    with torch.no_grad():
        outputs = model(ids, mask, token_type_ids)
    return torch.sigmoid(outputs).cpu().numpy()

# Text masker built from the model's tokenizer, so the explainer perturbs the
# input in units derived from the same tokenization the model sees.
masker = shap.maskers.Text(tokenizer)
# shap.Explainer chooses the algorithm itself; the recorded output of the
# explanation cell shows a PartitionExplainer was selected for this setup.
explainer = shap.Explainer(predict_wrapper, masker)

In [8]:
explain_text = ["Patient is a 45-year-old male presenting with severe abdominal pain and bloating. He has a history of hernia and is scheduled for surgery tomorrow."]
shap_values = explainer(explain_text)

# Plot for specific label
target_label = 'hernia'
# np.asarray makes the lookup work whether `classes` is a list or an ndarray.
label_matches = np.flatnonzero(np.asarray(classes) == target_label)
if label_matches.size:
    target_idx = int(label_matches[0])
    print(f"Explaining class: {target_label}")
else:
    # Fall back to the class the model actually scores highest for this text,
    # rather than arbitrarily plotting class index 0.
    target_idx = int(np.argmax(predict_wrapper(explain_text)[0]))
    top_label = np.asarray(classes)[target_idx]
    print(f"Label {target_label!r} not found, showing top prediction: {top_label}")

# Kept outside any try/except so real plotting errors surface instead of being
# misreported as a missing label.
shap.plots.text(shap_values[0, :, target_idx])

PartitionExplainer explainer: 2it [00:15, 15.69s/it]               

Explaining class: hernia



