In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import math
print(torch.cuda.is_available())

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


True


In [2]:
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

In [4]:
static = pd.read_csv('static-pre-process.csv')

In [5]:
from sklearn.preprocessing import MinMaxScaler
static_1 = static.copy()
scaler = MinMaxScaler(feature_range=(0, 1))
static_1['race_encoded'] = scaler.fit_transform(static_1[['race_encoded']])
static_1['icu_outtime'] = pd.to_datetime(static_1['icu_outtime'])
static_1['icu_intime'] = pd.to_datetime(static_1['icu_intime'])


# Calculate ICU stay duration in hours and keep only the total number of hours
#static_1['icu_hours'] = (static_1['icu_outtime'] - static_1['icu_intime']).dt.total_seconds() / 3600

static_1.drop(['admission_type','first_careunit'], axis=1, inplace=True)

static_1.sort_values(by='id', ascending=True, inplace=True)


In [6]:
static_2 = static_1.loc[:,'admission_age':'gender_encoded']
numpy_array = static_2.to_numpy()
static_data = torch.tensor(numpy_array, dtype=torch.float32)
static_data.shape

torch.Size([20414, 16])

In [7]:
ts = pd.read_csv('dynamic-pre-process.csv')
ts1 = ts.copy()

In [8]:
grouped = ts1.groupby('id')
tensor_list = []

for name, group in grouped:
    values = group.loc[:,'albumin':].values
    tensor_list.append(torch.tensor(values, dtype=torch.float))

In [9]:
lst_ts = [t.shape[0] for t in tensor_list]
average_length = sum(lst_ts) / len(lst_ts)
math.ceil(average_length)

5

In [10]:
ave_length = 5
num_features = len(ts1.columns) - list(ts1.columns).index('albumin')
final_tensors = []
for t in tensor_list:
    if t.shape[0] > ave_length:
        t = t[:ave_length, :]
    elif t.shape[0] < ave_length:
        padding_needed = ave_length - t.shape[0]
        padding = torch.zeros(padding_needed, num_features, dtype=t.dtype)
        t = torch.cat([t, padding], dim=0)
    final_tensors.append(t)

In [11]:
ts_data = torch.stack(final_tensors)
ts_data.shape

torch.Size([20414, 5, 68])

In [12]:
label_1 = static_1['los_icu']
labels = torch.tensor(label_1.values.reshape(-1, 1), dtype=torch.float32)
labels.shape

torch.Size([20414, 1])

In [13]:
text = pd.read_csv('cleaned_notes.csv')
text = text[['id','text']]
text['text'] = text['text'].str.replace('\n', ' ', regex=False)
text.sort_values(by='id', ascending=True, inplace=True)
text

Unnamed: 0,id,text
49326,20001305,unilat lower ext veins right year old woman w...
49325,20001305,with copd in resp distress intubated evaluate...
37633,20001361,renal ultrasound portable man with acute rena...
37629,20001361,chest xray dated none male with seizure and ...
37630,20001361,ct of the head dated yo male with overdose a...
...,...,...
2741,29999625,ct head wo contrast with pfossa bleed seen on...
2744,29999625,chest xray pmh of htn presented with medial r...
2743,29999625,chest portable ap pmh of htn presented with m...
2742,29999625,ct head wo contrast pmh of htn presented with...


In [14]:
text_merged = text.groupby('id')['text'].agg(' '.join).reset_index()
text_merged

Unnamed: 0,id,text
0,20001305,unilat lower ext veins right year old woman w...
1,20001361,renal ultrasound portable man with acute rena...
2,20001770,dx chest portable picc line placement year ol...
3,20002506,us renal artery doppler year old man with iph...
4,20003425,ct chest wo contrast year old man with htn hl...
...,...,...
20409,29997500,year old woman with poorly differentiated epi...
20410,29997616,history with asthma exacerbation preceding pr...
20411,29998399,pelvis ap inlet and outlet left acetabular fra...
20412,29999498,male with metastatic melanoma now with hyperc...


In [15]:
text_data = text_merged['text'].tolist()

In [16]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, static_data, ts_data, texts, labels):
        self.static_data = static_data
        self.ts_data = ts_data
        self.texts = texts
        self.labels = labels
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text_tokens = self.tokenizer(self.texts[idx], return_tensors='pt', padding=True, truncation=True, max_length=128, add_special_tokens=True)
        return (
            self.static_data[idx],
            self.ts_data[idx],
            text_tokens['input_ids'].squeeze(0),
            text_tokens['attention_mask'].squeeze(0),
            self.labels[idx]
        )

In [17]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    static_data, ts_data, input_ids, attention_masks, labels = zip(*batch)

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    static_data = torch.stack(static_data)
    ts_data = torch.stack(ts_data)
    labels = torch.tensor(labels)

    return static_data, ts_data, input_ids, attention_masks, labels

In [18]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

(train_static, test_static, 
 train_ts, test_ts, 
 train_texts, test_texts, 
 train_labels, test_labels) = train_test_split(static_data, ts_data, text_data, labels, test_size=0.2, random_state=42)

train_dataset = CustomDataset(train_static, train_ts, train_texts, train_labels)
test_dataset = CustomDataset(test_static, test_ts, test_texts, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [19]:
class MultiInputModel(nn.Module):
    def __init__(self):
        super(MultiInputModel, self).__init__()
        self.static_layer = nn.Sequential(
            nn.Linear(16, 16), # 16 features
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=68, hidden_size=16, batch_first=True) #68 features
        self.bert = bert_model
        # self.bert_to_hidden = nn.Linear(768, 256)
        self.fc = nn.Sequential(
            nn.Linear(16 + 16 + 768, 16),  # LSTM and static 32 dim，BERT output 768 dim
            # nn.Linear(32 + 64 + 256, 64)
            nn.ReLU(),
            nn.Linear(16, 1)
        )
    
    def forward(self, static_data, ts_data, input_ids, attention_mask):
        static_features = self.static_layer(static_data)
        _, (hidden, _) = self.lstm(ts_data)
        lstm_features = hidden[-1]
        text_features = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        # text_features = self.bert_to_hidden(text_features)
        combined_features = torch.cat([static_features, lstm_features, text_features], dim=1)
        output = self.fc(combined_features)
        return output

In [20]:
torch.cuda.empty_cache()

In [23]:
for i in [(16,0.0001),(32,0.0001)]:  
    class MultiInputModel(nn.Module):
        def __init__(self):
            super(MultiInputModel, self).__init__()
            self.static_layer = nn.Sequential(
                nn.Linear(16, i[0]), # 16 features
                nn.ReLU()
            )
            self.lstm = nn.LSTM(input_size=68, hidden_size=i[0], batch_first=True) #68 features
            self.bert = bert_model
            # self.bert_to_hidden = nn.Linear(768, 256)
            self.fc = nn.Sequential(
                nn.Linear(i[0] + i[0] + 768, i[0]),  # LSTM and static 32 dim，BERT output 768 dim
                # nn.Linear(32 + 64 + 256, 64)
                nn.ReLU(),
                nn.Linear(i[0], 1)
            )
        
        def forward(self, static_data, ts_data, input_ids, attention_mask):
            static_features = self.static_layer(static_data)
            _, (hidden, _) = self.lstm(ts_data)
            lstm_features = hidden[-1]
            text_features = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
            # text_features = self.bert_to_hidden(text_features)
            combined_features = torch.cat([static_features, lstm_features, text_features], dim=1)
            output = self.fc(combined_features)
            return output

    torch.cuda.empty_cache()
    model = MultiInputModel()

    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=i[1])
    num_epochs = 5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    for epoch in range(num_epochs):
        for data in train_loader:
            static_data, ts_data, input_ids, attention_mask, labels = data

            static_data = static_data.to(device)
            ts_data = ts_data.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            output = model(static_data, ts_data, input_ids, attention_mask)
            labels = labels.unsqueeze(1)
            loss = loss_function(output, labels.float())
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}, Loss: {loss.item()}',i)
    model.eval()

    total_mse = 0
    num_samples = 0

    with torch.no_grad():
        for data in test_loader:
            static_data, ts_data, input_ids, attention_mask, labels = data

            static_data = static_data.to(device)
            ts_data = ts_data.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            output = model(static_data, ts_data, input_ids, attention_mask)
            labels = labels.unsqueeze(1)
            loss = loss_function(output, labels.float())

            total_mse += loss.item() * labels.size(0)
            num_samples += labels.size(0)

    average_mse = total_mse / num_samples
    print(f'Test MSE: {average_mse}',i)

Epoch 1, Loss: 0.002147998893633485 (16, 0.0001)
Epoch 2, Loss: 0.0015150908147916198 (16, 0.0001)
Epoch 3, Loss: 0.0005369987338781357 (16, 0.0001)
Epoch 4, Loss: 0.0033159013837575912 (16, 0.0001)
Epoch 5, Loss: 0.0020625744946300983 (16, 0.0001)
Test MSE: 0.004019141474485757 (16, 0.0001)
Epoch 1, Loss: 0.00019445612269919366 (32, 0.0001)
Epoch 2, Loss: 0.0010307560442015529 (32, 0.0001)
Epoch 3, Loss: 0.0049176146276295185 (32, 0.0001)
Epoch 4, Loss: 0.013914750888943672 (32, 0.0001)
Epoch 5, Loss: 0.0030324032995849848 (32, 0.0001)
Test MSE: 0.004314338567876787 (32, 0.0001)


In [22]:
model.eval()

total_mse = 0
num_samples = 0

with torch.no_grad():
    for data in test_loader:
        static_data, ts_data, input_ids, attention_mask, labels = data

        static_data = static_data.to(device)
        ts_data = ts_data.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        output = model(static_data, ts_data, input_ids, attention_mask)
        labels = labels.unsqueeze(1)
        loss = loss_function(output, labels.float())

        total_mse += loss.item() * labels.size(0)
        num_samples += labels.size(0)

average_mse = total_mse / num_samples
print(f'Test MSE: {average_mse}')

Test MSE: 0.003921828034032568
