In [None]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from torch import nn
from transformers import AutoTokenizer, AutoModel

import sys
sys.path.append('./src')

from models import (
    Classifier,
    ClassifierConfig,
    ClinicalEncoder,
    ClinicalEncoderConfig,
    TextEncoder,
    TextEncoderConfig,
    MultimodalEncoder,
    MultimodalConfig,
    freeze_model,
    unfreeze_model,
)

In [None]:
# Shared model width: the clinical encoder, the multimodal fusion config and the
# classifier head are all built with it, so it is defined once here.
# NOTE(review): presumably this must also match the pretrained text encoder's
# hidden size (768 for a BERT-base checkpoint) — confirm against `mpath`.
HIDDEN_SIZE = 768
INITIALIZER_RANGE = 0.02

cln_enc_cfg = ClinicalEncoderConfig(
    feature_names=['HR', 'RR', 'SBP', 'DBP', 'SpO2', 'GCS'],
    hidden_size=HIDDEN_SIZE,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    initializer_range=INITIALIZER_RANGE,
    type_vocab_size=1,  # no multi-sentence type for clinical encoder
)

# No explicit text-encoder config is passed: the text side is loaded later via
# `enc.text_from_pretrained(mpath)` instead of being built from a TextEncoderConfig.
mm_enc_cfg = MultimodalConfig(
    hidden_size=HIDDEN_SIZE,
    initializer_range=INITIALIZER_RANGE,
    cln_enc_cfg=cln_enc_cfg,
)

cls_cfg = ClassifierConfig(
    num_labels=2,  # two classes for the `60D_MORTALITY` label
    hidden_size=HIDDEN_SIZE,
    initializer_range=INITIALIZER_RANGE,
)

In [None]:
# Root of the aligned MIMIC-III export.
# NOTE(review): absolute machine-local path; consider making it configurable.
DATA_DIR = '/mnt/data1/mimic/iii/aligned'
N_BATCH = 64  # size of the single fixed batch used for the sanity-check run

all_records = pd.read_csv(f'{DATA_DIR}/all_records.csv')
batch_records = all_records.sample(n=N_BATCH, random_state=42)

clinical_features = []
notes = []
# iterrows() has no len(), so pass `total` to give tqdm a real progress bar.
for _, row in tqdm(batch_records.iterrows(), total=len(batch_records), desc="loading"):
    # Each record maps to one feature file and one note file named
    # "<SUBJECT_ID>-<HADM_ID>-<NOTE_NUM>.csv".
    stem = f"{row['SUBJECT_ID']}-{row['HADM_ID']}-{row['NOTE_NUM']}"
    feats = pd.read_csv(f'{DATA_DIR}/feats/{stem}.csv')[cln_enc_cfg.feature_names].to_numpy()
    # NOTE(review): the note text is the first cell of the note CSV; an empty
    # file or blank cell would yield NaN here — confirm files are always populated.
    note_text = pd.read_csv(f'{DATA_DIR}/notes/{stem}.csv').iloc[0, 0]
    clinical_features.append(feats)
    notes.append(note_text)

# Each `feats` is 2-D (rows x len(feature_names)); np.stack requires every
# feature file to have the same number of rows and yields a 3-D array.
clinical_features = np.stack(clinical_features)
mortality = batch_records['60D_MORTALITY'].to_numpy()
assert clinical_features.shape[0] == len(notes) == len(mortality), "features/notes/labels misaligned"

In [None]:
# Pretrained biomedical BERT checkpoint. `mpath` is reused later by
# `enc.text_from_pretrained`, so the tokenizer and text weights come from the same name.
mpath = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=mpath)

In [None]:
# Seed torch before building the models so weight initialisation (and dropout
# in the training loop) is reproducible; 42 matches the batch-sampling seed.
torch.manual_seed(42)

enc = MultimodalEncoder(mm_enc_cfg)
enc.text_from_pretrained(mpath)
cls = Classifier(cls_cfg)

enc.to("cuda")
cls.to("cuda")
clinical_features = torch.as_tensor(clinical_features, dtype=torch.float32, device='cuda')
# CrossEntropyLoss expects integer class indices (int64) for 1-D targets; cast
# explicitly so a bool- or float-typed label column does not break the loss.
mortality = torch.as_tensor(mortality, dtype=torch.long, device='cuda')
notes_tok = tok(notes, return_tensors="pt", padding="max_length", truncation=True, max_length=512).to("cuda")

# NOTE(review): lr=0.01 is far above the ~1e-5..5e-5 range usually used when
# fine-tuning a pretrained BERT encoder with Adam; kept as-is — confirm intent.
opt = torch.optim.Adam(nn.ModuleList([enc, cls]).parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [None]:
N_STEPS = 4  # a few optimisation steps on the single fixed batch (sanity check)

# enc/cls are nn.Modules (they are wrapped in nn.ModuleList above); make sure
# dropout is active even if they were switched to eval() elsewhere.
enc.train()
cls.train()
for step in range(N_STEPS):
    # Clear gradients before the backward pass so a previously interrupted run
    # cannot leak stale gradients into this step.
    opt.zero_grad()
    enc_out = enc(
        clinical_features=clinical_features,
        input_ids=notes_tok["input_ids"],
        attention_mask=notes_tok["attention_mask"],
    )
    logits = cls(enc_out)
    loss = loss_fn(logits, mortality)
    loss.backward()
    opt.step()
    print(f"step {step}: loss {loss.item():.4f}")

In [None]:
# def train(X, y):
#     cln_model = ClinicalEncoder(cln_enc_cfg, feat_names)
#     opt = torch.optim.Adam(cln_model.parameters(), lr=0.01)
#     loss_fn = nn.CrossEntropyLoss()