In [1]:
import pickle
import numpy as np
import os
import gzip
from tqdm import tqdm
from data import load_mortality_dataset, pad_text_data, normalise_time
from torch.utils.data import DataLoader
from strats_text_model import load_Bert
import torch

# Data Check

In [2]:
# Load the pickled training notes and their matching note timestamps.
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# use it on files produced by this project's own preprocessing.
with open('./mortality_mimic_3_benchmark/train_texts.pkl', 'rb') as texts_file:
    text = pickle.load(texts_file)

# NOTE(review): the name `time` would shadow the stdlib module if it were ever
# imported in this notebook; it is kept because later cells read it.
with open('./mortality_mimic_3_benchmark/train_text_times.pkl', 'rb') as times_file:
    time = pickle.load(times_file)

In [3]:
# Number of entries loaded from train_texts.pkl (14681 in the recorded run).
len(text)

14681

In [4]:
# Number of entries loaded from train_text_times.pkl; the recorded run shows
# the same count as the texts (14681), i.e. one time entry per text entry.
len(time)

14681

# Text Dataset

In [5]:
# Only the tokenizer is needed here. The other two return values (named
# text_model and config in the commented-out call further down) are discarded
# into `_` so they are not kept alive.
_, _, tokenizer = load_Bert(text_encoder_model='bioLongformer')

Some weights of the model checkpoint at yikuan8/Clinical-Longformer were not used when initializing LongformerModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a dow

In [6]:
# Options for the text-augmented mortality dataset, gathered in one place.
text_dataset_options = dict(
    data_dir='./mortality_mimic_3_benchmark',
    with_text=True,
    tokenizer=tokenizer,
    text_padding=True,
    text_max_len=1024,             # tokens per note (last dim of X_text_tokens)
    text_model='bioLongformer',
    period_length=48,              # presumably hours of observation -- TODO confirm
    num_notes=3,                   # notes per sample (second dim of X_text_tokens)
    debug=True,                    # recorded run processes 100 items per split
)

# Returns the three splits plus V and D, which are later passed to the model
# constructor.
train, val, test, V, D = load_mortality_dataset(**text_dataset_options)

100it [00:00, 179.60it/s]
100it [00:00, 208.71it/s]
100it [00:00, 222.33it/s]


In [7]:
# Small batches for inspection; pad_text_data collates each batch into tensors.
train_dataloader = DataLoader(
    dataset=train,
    batch_size=2,
    collate_fn=pad_text_data,
)

In [8]:
# Pull one collated batch for inspection.
# Use the built-in next() rather than the iterator's Python-2-style .next()
# method: .next() is not part of the iterator protocol and is missing from
# the DataLoader iterator in recent PyTorch releases.
(
    X_demos,
    X_times,
    X_values,
    X_varis,
    Y,
    X_text_tokens,
    X_text_attention_mask,
    X_text_times,
    X_text_time_mask,
    X_text_feature_varis,
) = next(iter(train_dataloader))

  X_text_times = pad_sequence([torch.tensor(time, dtype=torch.float) for time in X_text_times],batch_first=True,padding_value=0)
  X_text_time_mask = pad_sequence([torch.tensor(time_mask, dtype=torch.long) for time_mask in X_text_time_mask],batch_first=True,padding_value=0)


In [9]:
# Shape of the observation-time tensor: (2, 500) in the recorded run with
# batch_size=2 -- presumably (batch, padded sequence length).
X_times.shape

torch.Size([2, 500])

In [10]:
# Displays the whole tensor.
# NOTE(review): consider a slice such as X_times[:, :10] to keep output short.
X_times

tensor([[ 0.0667,  0.0667,  0.0667,  0.0667,  0.0667,  0.0667,  0.1500,  0.1500,
          0.1500,  0.1500,  0.1500,  0.1500,  0.2333,  0.2333,  0.2333,  0.2333,
          0.2333,  0.2333,  0.3167,  0.3167,  0.3167,  0.3167,  0.3167,  0.3167,
          0.4000,  0.4000,  0.4000,  0.4000,  0.4000,  0.4000,  0.4833,  0.4833,
          0.4833,  0.4833,  0.4833,  0.4833,  0.5167,  0.5167,  0.5167,  0.5167,
          0.5167,  0.5167,  0.5667,  0.5667,  0.5667,  0.5667,  0.5667,  0.5667,
          0.6500,  0.6500,  0.6500,  0.6500,  0.6500,  0.6500,  0.7333,  0.7333,
          0.7333,  0.7333,  0.7333,  0.7333,  0.8167,  0.8167,  0.8167,  0.8167,
          0.8167,  0.8167,  0.9000,  0.9000,  0.9000,  0.9000,  0.9000,  0.9000,
          0.9833,  0.9833,  0.9833,  0.9833,  0.9833,  0.9833,  1.0667,  1.0667,
          1.0667,  1.0667,  1.0667,  1.0667,  1.1500,  1.1500,  1.1500,  1.1500,
          1.1500,  1.1500,  1.2333,  1.2333,  1.2333,  1.2333,  1.2333,  1.2333,
          1.2333,  1.2333,  

In [11]:
# Per-note times for each sample. In the recorded run the first sample is all
# zeros -- presumably it has no notes and is pure padding (TODO confirm
# against X_text_time_mask).
X_text_times

tensor([[ 0.0000,  0.0000,  0.0000],
        [22.6000, 34.8167, 46.1333]])

In [12]:
# (2, 3) in the recorded run: batch_size=2 by num_notes=3.
X_text_times.shape

torch.Size([2, 3])

In [13]:
# (2, 3, 1024) in the recorded run: batch_size=2, num_notes=3, text_max_len=1024.
X_text_tokens.shape

torch.Size([2, 3, 1024])

In [14]:
# (2, 3) in the recorded run: one entry per note, same shape as X_text_times.
X_text_feature_varis.shape

torch.Size([2, 3])

In [15]:
# (2, 500) in the recorded run: same shape as X_times.
X_varis.shape

torch.Size([2, 500])

In [16]:
# Labels for the batch: one float per sample (tensor([1., 0.]) in the recorded run).
Y

tensor([1., 0.])

In [17]:
# NOTE(review): repeats an earlier cell that already displays this shape.
X_text_tokens.shape

torch.Size([2, 3, 1024])

In [18]:
# (2, 3, 1024) in the recorded run: the attention mask matches X_text_tokens
# element for element.
X_text_attention_mask.shape

torch.Size([2, 3, 1024])

In [19]:
# Token ids per note. In the recorded run the first sample is all zeros
# (matching its all-zero note times) and the second sample's notes end in
# runs of 1s -- presumably the tokenizer's pad id (TODO confirm).
X_text_tokens

tensor([[[    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0]],

        [[    0,   282, 35857,  ...,     1,     1,     1],
         [    0,    90,    73,  ...,     1,     1,     1],
         [    0,  7048,   575,  ...,     1,     1,     1]]])

# Numerical Dataset

In [2]:
# Numerical-only variant: no text, full dataset (debug off).
# NOTE(review): this rebinds train/val/test/V/D and train_dataloader from the
# text section above; everything below uses these numerical versions.
train, val, test, V, D = load_mortality_dataset(
    data_dir='./mortality_datasets', with_text=False, period_length=48, debug=False,
)

# normalise_time collates each batch of four samples.
train_dataloader = DataLoader(dataset=train, batch_size=4, collate_fn=normalise_time)

# Model Check

In [3]:
from orig_model import STraTS
from new_model import custom_STraTS

# text_model, config, tokenizer = load_Bert(
#     text_encoder_model = 'bioLongformer'
# )

# model = STraTS(
#     D=D, # No. of static variables
#     V=V+1, # No. of variables / features
#     d=64, # Input size of attention layer
#     N=2, # No. of Encoder blocks
#     he=4, # No. of heads in multi headed encoder blocks
#     dropout=0, 
#     with_text=False,
#     forecast=False, 
#     return_embeddings=False
# )

# Build the custom model. The parameter meanings below come from the comments
# on the commented-out STraTS call -- presumably the same for custom_STraTS.
# NOTE(review): that STraTS call passes V+1 while this one passes V; confirm
# custom_STraTS accounts for any padding index itself.
model = custom_STraTS(
    D=D,               # no. of static variables
    V=V,               # no. of variables / features
    d=64,              # input size of attention layer
    N=2,               # no. of encoder blocks
    he=4,              # no. of attention heads
    dropout=0.1,
    time_2_vec=False,
)

print(model)

# Walk the parameters once, recording (element count, trainable?) for each.
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
total_trainable_params = sum(count for count, trainable in param_counts if trainable)
print(f'Total number of parameters: {total_params}, Total trainable parameters: {total_trainable_params}')

custom_STraTS(
  (values_stack): CVE(
    (stack): Sequential(
      (0): Linear(in_features=1, out_features=8, bias=True)
      (1): Tanh()
      (2): Linear(in_features=8, out_features=64, bias=False)
    )
  )
  (times_stack): CVE(
    (stack): Sequential(
      (0): Linear(in_features=1, out_features=8, bias=True)
      (1): Tanh()
      (2): Linear(in_features=8, out_features=64, bias=False)
    )
  )
  (varis_stack): Embedding(130, 64)
  (mTAND): MultiTimeAttention(
    (linears): ModuleList(
      (0-1): 2 x Linear(in_features=64, out_features=64, bias=True)
      (2): Linear(in_features=256, out_features=64, bias=True)
    )
  )
  (CTE): STraTS_Transformer(
    (dropout_layer): Dropout(p=0.1, inplace=False)
    (identity): Identity()
  )
  (atten_stack): Attention(
    (stack): Sequential(
      (0): Linear(in_features=64, out_features=128, bias=True)
      (1): Tanh()
      (2): Linear(in_features=128, out_features=1, bias=False)
    )
    (softmax): Softmax(dim=-2)
  )
  (dem

In [4]:
# from accelerate import Accelerator

# accelerator = Accelerator()
# accelerator.device

In [5]:
from utils import mortality_loss

# Adam over every model parameter, with a fixed learning rate.
LEARNING_RATE = 5e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Alias so the training cells below do not depend on the concrete loss name.
loss_fn = mortality_loss

In [6]:
# model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

In [7]:
# Take the first collated batch from the numerical loader.
batch_iterator = iter(train_dataloader)
batch = next(batch_iterator)

# Text-enabled variant, kept for reference:
# X_demos, X_times, X_values, X_varis, Y, X_text_tokens, X_text_attention_mask, X_text_times, X_text_time_mask, X_text_feature_varis = next(iter(train_dataloader))
# Y_pred = model(X_demos, X_times, X_values, X_varis, X_text_tokens, X_text_attention_mask, X_text_times, X_text_feature_varis)

# The numerical batch has five parts: four model inputs followed by the labels.
X_demos, X_times, X_values, X_varis, Y = batch

# Forward pass. The recorded output below this cell is debug printing that
# happens during the call.
Y_pred = model(X_demos, X_times, X_values, X_varis)

4
tensor([[ 1.9333, 13.9667,  8.8333,  ...,  0.0000,  0.0000,  0.0000],
        [16.9000, 10.9000, 11.9000,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.9000,  7.9000,  6.4000,  ...,  0.0000,  0.0000,  0.0000],
        [17.7833, 20.7833, 19.7833,  ...,  0.0000,  0.0000,  0.0000]])
ts_varis_emb: torch.Size([4, 880, 64])
ts_values_emb: torch.Size([4, 880, 64])
ts_times_emb: torch.Size([4, 880, 64])
mask: torch.Size([4, 880])
 MultiTime Attention query: torch.Size([4, 4, 880, 16])
 MultiTime Attention key: torch.Size([4, 4, 880, 16])
 MultiTime Attention value: torch.Size([4, 1, 880, 64])
 MultiTime Attention scores: torch.Size([4, 4, 880, 880])
 MultiTime Attention scores: torch.Size([4, 4, 880, 880, 1])
 MultiTime Attention mask: torch.Size([4, 1, 880, 1])
 MultiTime Attention p_attn: torch.Size([4, 4, 880, 880, 1])
 MultiTime Attention value.unsqueeze(-3): torch.Size([4, 1, 1, 880, 64])
 MultiTime Attention output: torch.Size([4, 4, 880, 64])
 MultiTime Attention output: torch.Size([4

In [8]:
# Loss for the single batch above; the recorded grad_fn shows it is a binary
# cross-entropy. Arguments are passed as (labels, predictions).
# NOTE(review): confirm this order matches mortality_loss's signature in utils.
loss = loss_fn(Y, Y_pred)
loss

tensor(0.7143, grad_fn=<BinaryCrossEntropyBackward0>)

In [9]:
# One optimisation step for the batch above.
# backward() must run before step(): without it every parameter's .grad is
# None, so the optimizer has nothing to apply and the weights never change.
# NOTE: re-running this cell without first re-running the forward and loss
# cells raises, because the autograd graph is freed by the first backward().
loss.backward()
optimizer.step()
# Clear gradients so they do not accumulate into the next batch.
optimizer.zero_grad()

In [10]:
# Predictions from the forward pass above. They were computed before the
# optimizer step, so they do not reflect any weight update. One value per
# sample, all close to 0.5 in the recorded run.
Y_pred

tensor([0.5504, 0.5166, 0.5275, 0.5432], grad_fn=<ViewBackward0>)

In [11]:
# Labels for the same batch, for comparison with Y_pred.
Y

tensor([1., 0., 0., 0.])