In [21]:
import os
# Force CPU inference entirely
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# 0)
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report
from transformers import (
    AutoTokenizer,
    BertModel,
    BertPreTrainedModel,
    Trainer,
    TrainingArguments
)
from transformers.modeling_outputs import SequenceClassifierOutput
from tqdm import tqdm

# 1) 
df = pd.read_csv("SDOH_MIMICIII_physio_release.csv")
label_sums = df.drop(columns=["provider_type","patient_id","note_id","sentence_index","text"]).sum()
top4_labels = label_sums.sort_values(ascending=False).head(4).index.tolist()

df = df.reset_index(drop=True)  # keep all rows
df["text"] = df["text"].fillna("").astype(str)
df[top4_labels] = df[top4_labels].fillna(0).astype(int)

# 2) Train/val/test split: 70/15/15
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    df["text"], df[top4_labels], test_size=0.3, random_state=42
)
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels, test_size=0.5, random_state=42
)

# 3) Tokenize
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
def tokenize(texts):
    return tokenizer(texts, truncation=True, padding=True, max_length=512)
train_enc = tokenize(train_texts.tolist())
val_enc   = tokenize(val_texts.tolist())
test_enc  = tokenize(test_texts.tolist())

# 4) 
class SDOHDataset(Dataset):
    def __init__(self, enc, labels):
        self.encodings = enc
        self.labels = labels.values.tolist()
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k,v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx]).float()
        return item

train_ds = SDOHDataset(train_enc, train_labels)
val_ds   = SDOHDataset(val_enc,   val_labels)
test_ds  = SDOHDataset(test_enc,  test_labels)

# 5) Custom BERT + Focal Loss head
class BertWithFocal(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert       = BertModel(config)
        self.dropout    = nn.Dropout(0.3)
        self.classifier = nn.Linear(config.hidden_size, len(top4_labels))
        self.init_weights()
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.pooler_output)
        logits = self.classifier(pooled)
        loss   = None
        if labels is not None:
            bce   = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
            pt    = torch.exp(-bce)
            focal = ((1 - pt)**2 * bce).mean()
            loss  = focal
        return SequenceClassifierOutput(loss=loss, logits=logits)

# 6) 
model = BertWithFocal.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.to("cpu")

# 7)
training_args = TrainingArguments(
    output_dir="./results_top4",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs_top4",
    logging_steps=50,
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
)

# 8) 
def compute_metrics(p):
    logits, labels = p
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    preds = (probs >= 0.25).astype(int)
    return {
        "accuracy": accuracy_score(labels, preds),
        "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
    }

# 9) 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

# 10) Train
trainer.train()

# 11) 
preds_out = trainer.predict(test_ds)
logits = preds_out.predictions
trues  = preds_out.label_ids
probs  = torch.sigmoid(torch.tensor(logits)).numpy()
preds  = (probs >= 0.25).astype(int)

# 12) 
df_out = pd.DataFrame(probs, columns=[f"{lbl}_prob" for lbl in top4_labels])
for i, lbl in enumerate(top4_labels):
    df_out[f"{lbl}_pred"] = preds[:, i]
    df_out[f"{lbl}_true"] = trues[:, i]
df_out["text"] = test_texts.reset_index(drop=True)
cols = ["text"] + \
       [f"{lbl}_prob" for lbl in top4_labels] + \
       [f"{lbl}_pred" for lbl in top4_labels] + \
       [f"{lbl}_true" for lbl in top4_labels]
df_out = df_out[cols]
out_path = "sdoh_test_predictions_export.csv"
df_out.to_csv(out_path, index=False)
print("Predictions exported to", out_path)

# 13) Final test metrics
print("\nTest Set Metrics:")
print(classification_report(trues, preds, target_names=top4_labels, zero_division=0))


Some weights of BertWithFocal were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Micro F1,Macro F1
1,0.0147,0.011583,0.893617,0.301205,0.188537
2,0.01,0.008872,0.891114,0.360465,0.439289
3,0.0042,0.010686,0.951189,0.52,0.434066
4,0.0011,0.013474,0.95995,0.568182,0.524456
5,0.003,0.013234,0.961202,0.561798,0.509343


Predictions exported to sdoh_test_predictions_export.csv

Test Set Metrics:
                      precision    recall  f1-score   support

RELATIONSHIP_married       0.59      0.95      0.73        21
        SUPPORT_plus       0.14      0.69      0.23        16
 EMPLOYMENT_employed       0.57      0.57      0.57         7
       SUPPORT_minus       0.00      0.00      0.00         3

           micro avg       0.28      0.74      0.40        47
           macro avg       0.33      0.55      0.38        47
        weighted avg       0.40      0.74      0.49        47
         samples avg       0.03      0.04      0.03        47

