In [1]:
!pip install -q transformers

In [2]:
import torch
import sklearn
import numpy as np
import pandas as pd
from torch import nn
from sklearn import metrics
from tqdm.notebook import tqdm
from transformers import AdamW
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [3]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dev

device(type='cpu')

In [4]:
class PatDataset(Dataset):
    def __init__(self, px, dx, rx, y):
        self.px = px
        self.dx = dx
        self.rx = rx
        self.y = y
            
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.px[idx], self.dx[idx], self.rx[idx], self.y[idx]


class MyModel(nn.Module):
    def __init__(self, path=None):
        super().__init__()
        self.px_seq = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.dx_seq = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.rx_seq = nn.Sequential(
            nn.Linear(806, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.concat_seq = nn.Sequential(
            nn.Linear(384, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),  
            )

    def forward(self, px_embed, dx_embed, rx_embed):
        px_features = self.px_seq(px_embed)
        dx_features = self.px_seq(dx_embed)
        rx_features = self.rx_seq(rx_embed)
        concat = torch.cat((px_features, dx_features, rx_features), dim=1)
        return self.concat_seq(concat)


class Trainer:
    def __init__(self, train_data, val_data):
        self.train_data = train_data
        self.val_data = val_data
        if torch.cuda.is_available():
            self.dev = torch.device('cuda')
        else:
            self.dev = torch.device('cpu')
        print("Using: ", self.dev)

    @staticmethod
    def _roc_auc(y_true, y_raw_logits, class_idx=1):
        y_score = F.softmax(y_raw_logits, dim=1).cpu().detach().numpy()[:, class_idx]
        return metrics.roc_auc_score(y_true.cpu().detach().numpy(), y_score)

    @staticmethod
    def _accuracy(true, pred_proba, class_idx=None):
        preds = torch.argmax(pred_proba, dim=1)
        if class_idx is None:
            score = (true == preds).float().mean()
        else:
            score = (true==preds)[true==class_idx].float().mean()
        return score

    @staticmethod
    def _get_optimizer_with_decay(model, lr):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01, 'lr':lr},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr':lr}
             ]
        return AdamW(optimizer_grouped_parameters, lr=lr)

    def _get_loss_func(self, class_wts=None):
        if class_wts is not None:
            loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_wts).type(torch.float).to(self.dev))
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn
    
    @staticmethod
    def get_class_wts(y):
        return compute_class_weight('balanced', np.unique(y), y)

    def train_model(self, model, n_epochs=1, lr=1e-5, class_wts=None,
                    path_to_save_model='./model.tar', logging_step=5,
                    patience=3):
        loss_fn = self._get_loss_func(class_wts=class_wts)
        optimizer = Trainer._get_optimizer_with_decay(model, lr)

        model.to(self.dev)
        trn_loss = []
        val_loss = []
        val_accu = []
        max_val_accu = 0
        count = 0
        for epoch in range(n_epochs):
            print(f"Epoch: {epoch}")
            trn_loss_per_epoch = []
            model.train()
            for i, batch in enumerate(tqdm(self.train_data)):
                batch = [x.to(self.dev) for x in batch]
                outputs = model(*batch[:-1])
                loss = loss_fn(outputs, batch[-1])
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                del batch
                torch.cuda.empty_cache()
                trn_loss_per_epoch.append(float(loss))
                if i % logging_step == 0:
                    print("train_loss: ", np.mean(trn_loss_per_epoch))
            trn_loss.append(np.mean(trn_loss_per_epoch))

            model.eval()
            valid_loss_per_epoch = []
            yhat_lst = []
            y_lst = []
            with torch.no_grad():
                for batch in tqdm(self.val_data):
                    batch = [x.to(self.dev) for x in batch]
                    outputs = model(*batch[:-1])
                    valid_loss_per_epoch.append(float(loss_fn(outputs, batch[-1])))
                    yhat_lst.append(outputs)
                    y_lst.append(batch[-1])   
                tmp = Trainer._roc_auc(torch.cat(y_lst, dim=0), torch.cat(yhat_lst, dim=0))
            val_loss.append(np.mean(valid_loss_per_epoch))
            val_accu.append(tmp)
            
            print(f"Training Loss for epoch {epoch}: ", trn_loss[-1])
            print("Validation Loss: ", val_loss[-1], "| Validation roc_auc: ", val_accu[-1])

            if val_accu[-1] > max_val_accu:
                count = 0
                max_val_accu = val_accu[-1]
                torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'training_loss': trn_loss[-1],
                            'val_loss': val_loss[-1],
                            'val_accuracy': val_accu[-1]
                            }, path_to_save_model)
                print(f"Checkpoint saved!")
            else:
                count = count + 1

            if count > patience:
                print("Stopping early, restoring best checkpoint..")
                model_parameters = torch.load(path_to_save_model, map_location = self.dev)
                model.load_state_dict(model_parameters['model_state_dict'])
                optimizer.load_state_dict(model_parameters['optimizer_state_dict'])
                break 
            print("Max roc_auc till now: ", max_val_accu)
        return model.eval(), optimizer

In [5]:
px_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/trn_px_journey.pkl")
dx_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/trn_dx_journey.pkl")
rx_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_trn.pkl")
y_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_trn.pkl")
rx_trn.shape

(23868, 806)

In [6]:
px_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/val_px_journey.pkl")
dx_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/val_dx_journey.pkl")
rx_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_val.pkl")
y_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_val.pkl")
rx_val.shape

(4212, 806)

In [7]:
px_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/test_px_journey.pkl")
dx_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/test_dx_journey.pkl")
rx_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_test.pkl")
y_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_test.pkl")
rx_test.shape

(7020, 806)

In [8]:
trn_ds = PatDataset(px_trn, dx_trn, rx_trn.values.astype(np.float32), y_trn['switch_flag'].tolist())
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=512, num_workers=0, pin_memory=False)

In [9]:
val_ds = PatDataset(px_val, dx_val, rx_val.values.astype(np.float32), y_val['switch_flag'].tolist())
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=512, num_workers=0, pin_memory=False)

In [10]:
test_ds = PatDataset(px_test, dx_test, rx_test.values.astype(np.float32), y_test['switch_flag'].tolist())
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=512, num_workers=0, pin_memory=False)

In [11]:
class_wts = compute_class_weight('balanced', np.unique(y_trn['switch_flag'].values), y_trn['switch_flag'].values)
class_wts

array([ 0.51734004, 14.9175    ])

In [12]:
trainer = Trainer(trn_dl, val_dl)

Using:  cpu


In [13]:
model = MyModel()

model, optimizer = trainer.train_model(
    model, n_epochs=200, lr=1e-5, class_wts=class_wts,
    path_to_save_model='/content/drive/MyDrive/ColabData/saved_models/PatientBERT/only_top/sepsis-readmission-only-top.tar',
    logging_step=46, patience=10
    )

Epoch: 0


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6936256885528564
train_loss:  0.6909998870910482



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 0:  0.6909998870910482
Validation Loss:  0.686121490266588 | Validation roc_auc:  0.5982080482778205
Checkpoint saved!
Max roc_auc till now:  0.5982080482778205
Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.684659481048584
train_loss:  0.6863901348824196



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 1:  0.6863901348824196
Validation Loss:  0.6826364133093092 | Validation roc_auc:  0.6201152939577814
Checkpoint saved!
Max roc_auc till now:  0.6201152939577814
Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6825225949287415
train_loss:  0.6827570593103449



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 2:  0.6827570593103449
Validation Loss:  0.6797810329331292 | Validation roc_auc:  0.6333345528221583
Checkpoint saved!
Max roc_auc till now:  0.6333345528221583
Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6807092428207397
train_loss:  0.6793078450446434



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 3:  0.6793078450446434
Validation Loss:  0.6765983700752258 | Validation roc_auc:  0.6447925213976735
Checkpoint saved!
Max roc_auc till now:  0.6447925213976735
Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6786019802093506
train_loss:  0.6757895325092559



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 4:  0.6757895325092559
Validation Loss:  0.673493762811025 | Validation roc_auc:  0.6513429185154987
Checkpoint saved!
Max roc_auc till now:  0.6513429185154987
Epoch: 5


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6770586371421814
train_loss:  0.6721160551334949



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 5:  0.6721160551334949
Validation Loss:  0.6701368358400133 | Validation roc_auc:  0.6566651161737319
Checkpoint saved!
Max roc_auc till now:  0.6566651161737319
Epoch: 6


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.675102174282074
train_loss:  0.6681913999800987



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 6:  0.6681913999800987
Validation Loss:  0.6665861275460985 | Validation roc_auc:  0.6608479628439177
Checkpoint saved!
Max roc_auc till now:  0.6608479628439177
Epoch: 7


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6725721955299377
train_loss:  0.6640641841482608



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 7:  0.6640641841482608
Validation Loss:  0.6626862817340426 | Validation roc_auc:  0.6659088414681948
Checkpoint saved!
Max roc_auc till now:  0.6659088414681948
Epoch: 8


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6697610020637512
train_loss:  0.6596337277838524



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 8:  0.6596337277838524
Validation Loss:  0.6586917042732239 | Validation roc_auc:  0.6698869882284486
Checkpoint saved!
Max roc_auc till now:  0.6698869882284486
Epoch: 9


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6672568321228027
train_loss:  0.654855674885689



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 9:  0.654855674885689
Validation Loss:  0.6545776261223687 | Validation roc_auc:  0.6732815224795344
Checkpoint saved!
Max roc_auc till now:  0.6732815224795344
Epoch: 10


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6658682227134705
train_loss:  0.6498817040565166



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 10:  0.6498817040565166
Validation Loss:  0.6506586339738634 | Validation roc_auc:  0.6750375863877174
Checkpoint saved!
Max roc_auc till now:  0.6750375863877174
Epoch: 11


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6650028228759766
train_loss:  0.6447151937383286



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 11:  0.6447151937383286
Validation Loss:  0.6472687986161973 | Validation roc_auc:  0.6768964357825895
Checkpoint saved!
Max roc_auc till now:  0.6768964357825895
Epoch: 12


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6648613214492798
train_loss:  0.6398224082398922



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 12:  0.6398224082398922
Validation Loss:  0.6442424853642782 | Validation roc_auc:  0.6781742858586333
Checkpoint saved!
Max roc_auc till now:  0.6781742858586333
Epoch: 13


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6638906002044678
train_loss:  0.6349527899255144



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 13:  0.6349527899255144
Validation Loss:  0.6414221194055345 | Validation roc_auc:  0.6792134645503309
Checkpoint saved!
Max roc_auc till now:  0.6792134645503309
Epoch: 14


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6631170511245728
train_loss:  0.6303176296518204



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 14:  0.6303176296518204
Validation Loss:  0.6392895910474989 | Validation roc_auc:  0.6803562997921643
Checkpoint saved!
Max roc_auc till now:  0.6803562997921643
Epoch: 15


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6628766059875488
train_loss:  0.6261327634466455



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 15:  0.6261327634466455
Validation Loss:  0.6377530958917406 | Validation roc_auc:  0.6815548830945749
Checkpoint saved!
Max roc_auc till now:  0.6815548830945749
Epoch: 16


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6621907949447632
train_loss:  0.6221818327903748



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 16:  0.6221818327903748
Validation Loss:  0.636695663134257 | Validation roc_auc:  0.6827447557625201
Checkpoint saved!
Max roc_auc till now:  0.6827447557625201
Epoch: 17


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.660966157913208
train_loss:  0.6184694551407023



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 17:  0.6184694551407023
Validation Loss:  0.6356489194764031 | Validation roc_auc:  0.6836558881275795
Checkpoint saved!
Max roc_auc till now:  0.6836558881275795
Epoch: 18


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.660115659236908
train_loss:  0.6150659918785095



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 18:  0.6150659918785095
Validation Loss:  0.6351889769236246 | Validation roc_auc:  0.6844851405286658
Checkpoint saved!
Max roc_auc till now:  0.6844851405286658
Epoch: 19


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.659465491771698
train_loss:  0.6118239177034256



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 19:  0.6118239177034256
Validation Loss:  0.6345254778862 | Validation roc_auc:  0.6856610761814668
Checkpoint saved!
Max roc_auc till now:  0.6856610761814668
Epoch: 20


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6579400300979614
train_loss:  0.6084139131485148



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 20:  0.6084139131485148
Validation Loss:  0.6339683400260078 | Validation roc_auc:  0.6863962537303292
Checkpoint saved!
Max roc_auc till now:  0.6863962537303292
Epoch: 21


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6565806865692139
train_loss:  0.6052516914428548



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 21:  0.6052516914428548
Validation Loss:  0.6335206031799316 | Validation roc_auc:  0.6865547872775958
Checkpoint saved!
Max roc_auc till now:  0.6865547872775958
Epoch: 22


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6544660329818726
train_loss:  0.6019313614419166



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 22:  0.6019313614419166
Validation Loss:  0.6332721379068162 | Validation roc_auc:  0.6874345613585803
Checkpoint saved!
Max roc_auc till now:  0.6874345613585803
Epoch: 23


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6537126898765564
train_loss:  0.5988627355149452



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 23:  0.5988627355149452
Validation Loss:  0.6328087647755941 | Validation roc_auc:  0.6880355951366786
Checkpoint saved!
Max roc_auc till now:  0.6880355951366786
Epoch: 24


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6517958045005798
train_loss:  0.5955839499514154



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 24:  0.5955839499514154
Validation Loss:  0.632489456070794 | Validation roc_auc:  0.6885303591743015
Checkpoint saved!
Max roc_auc till now:  0.6885303591743015
Epoch: 25


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6496278047561646
train_loss:  0.5924460444044559



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 25:  0.5924460444044559
Validation Loss:  0.6321948568026224 | Validation roc_auc:  0.6892533418349126
Checkpoint saved!
Max roc_auc till now:  0.6892533418349126
Epoch: 26


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6476517915725708
train_loss:  0.5891230068308242



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 26:  0.5891230068308242
Validation Loss:  0.6320434676276313 | Validation roc_auc:  0.6897115212077817
Checkpoint saved!
Max roc_auc till now:  0.6897115212077817
Epoch: 27


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6456090211868286
train_loss:  0.5859269969006802



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 27:  0.5859269969006802
Validation Loss:  0.6320970522032844 | Validation roc_auc:  0.6899379977038768
Checkpoint saved!
Max roc_auc till now:  0.6899379977038768
Epoch: 28


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6420598030090332
train_loss:  0.582686648723927



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 28:  0.582686648723927
Validation Loss:  0.6317232847213745 | Validation roc_auc:  0.6902498384177307
Checkpoint saved!
Max roc_auc till now:  0.6902498384177307
Epoch: 29


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6394121050834656
train_loss:  0.5794325699197486



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 29:  0.5794325699197486
Validation Loss:  0.6310580902629428 | Validation roc_auc:  0.6912742090308374
Checkpoint saved!
Max roc_auc till now:  0.6912742090308374
Epoch: 30


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6355912685394287
train_loss:  0.576116848499217



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 30:  0.576116848499217
Validation Loss:  0.6315648025936551 | Validation roc_auc:  0.6907951241352518
Max roc_auc till now:  0.6912742090308374
Epoch: 31


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6327390074729919
train_loss:  0.572803092763779



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 31:  0.572803092763779
Validation Loss:  0.6311613056394789 | Validation roc_auc:  0.691615665901873
Checkpoint saved!
Max roc_auc till now:  0.691615665901873
Epoch: 32


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6300587058067322
train_loss:  0.5694560154955438



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 32:  0.5694560154955438
Validation Loss:  0.6310137112935384 | Validation roc_auc:  0.6920895244167795
Checkpoint saved!
Max roc_auc till now:  0.6920895244167795
Epoch: 33


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6269778609275818
train_loss:  0.5661578248155877



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 33:  0.5661578248155877
Validation Loss:  0.6308118436071608 | Validation roc_auc:  0.6925668671854721
Checkpoint saved!
Max roc_auc till now:  0.6925668671854721
Epoch: 34


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6235759258270264
train_loss:  0.5626212979884858



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 34:  0.5626212979884858
Validation Loss:  0.6310625010066562 | Validation roc_auc:  0.692876965772433
Checkpoint saved!
Max roc_auc till now:  0.692876965772433
Epoch: 35


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6208276748657227
train_loss:  0.5591192163051443



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 35:  0.5591192163051443
Validation Loss:  0.6307134495841132 | Validation roc_auc:  0.693737576457594
Checkpoint saved!
Max roc_auc till now:  0.693737576457594
Epoch: 36


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6161743402481079
train_loss:  0.5555380119922313



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 36:  0.5555380119922313
Validation Loss:  0.6313497755262587 | Validation roc_auc:  0.6942131770993936
Checkpoint saved!
Max roc_auc till now:  0.6942131770993936
Epoch: 37


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6138911843299866
train_loss:  0.5519047956517402



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 37:  0.5519047956517402
Validation Loss:  0.6309665375285678 | Validation roc_auc:  0.6949413861406837
Checkpoint saved!
Max roc_auc till now:  0.6949413861406837
Epoch: 38


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6084875464439392
train_loss:  0.5482926787214076



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 38:  0.5482926787214076
Validation Loss:  0.6322174668312073 | Validation roc_auc:  0.6947845947203104
Max roc_auc till now:  0.6949413861406837
Epoch: 39


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6058400869369507
train_loss:  0.5444498188952183



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 39:  0.5444498188952183
Validation Loss:  0.6315699153476291 | Validation roc_auc:  0.6958716819015663
Checkpoint saved!
Max roc_auc till now:  0.6958716819015663
Epoch: 40


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.6014986634254456
train_loss:  0.54064009798334



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 40:  0.54064009798334
Validation Loss:  0.6322428849008348 | Validation roc_auc:  0.6962828238483234
Checkpoint saved!
Max roc_auc till now:  0.6962828238483234
Epoch: 41


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5971665382385254
train_loss:  0.5368218941891447



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 41:  0.5368218941891447
Validation Loss:  0.6328723960452609 | Validation roc_auc:  0.6962479813104627
Max roc_auc till now:  0.6962828238483234
Epoch: 42


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5927321314811707
train_loss:  0.5329807996749878



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 42:  0.5329807996749878
Validation Loss:  0.6333517034848531 | Validation roc_auc:  0.6968612099768123
Checkpoint saved!
Max roc_auc till now:  0.6968612099768123
Epoch: 43


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.589512288570404
train_loss:  0.5291017629998795



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 43:  0.5291017629998795
Validation Loss:  0.6335777706570096 | Validation roc_auc:  0.6970963971073725
Checkpoint saved!
Max roc_auc till now:  0.6970963971073725
Epoch: 44


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.584200382232666
train_loss:  0.5250339102237782



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 44:  0.5250339102237782
Validation Loss:  0.6346431175867716 | Validation roc_auc:  0.6971364660259124
Checkpoint saved!
Max roc_auc till now:  0.6971364660259124
Epoch: 45


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5804104804992676
train_loss:  0.5210051993106274



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 45:  0.5210051993106274
Validation Loss:  0.6355874008602567 | Validation roc_auc:  0.6977636317074063
Checkpoint saved!
Max roc_auc till now:  0.6977636317074063
Epoch: 46


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5771164894104004
train_loss:  0.5169697454635133



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 46:  0.5169697454635133
Validation Loss:  0.6369869642787509 | Validation roc_auc:  0.6976294879366423
Max roc_auc till now:  0.6977636317074063
Epoch: 47


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5714839696884155
train_loss:  0.5128760496352581



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 47:  0.5128760496352581
Validation Loss:  0.6381494601567587 | Validation roc_auc:  0.6980040452186456
Checkpoint saved!
Max roc_auc till now:  0.6980040452186456
Epoch: 48


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5683454275131226
train_loss:  0.5086565582042045



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 48:  0.5086565582042045
Validation Loss:  0.6396566695637174 | Validation roc_auc:  0.6980040452186456
Max roc_auc till now:  0.6980040452186456
Epoch: 49


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5638217926025391
train_loss:  0.5045857106117492



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 49:  0.5045857106117492
Validation Loss:  0.6408577362696329 | Validation roc_auc:  0.6983611812317185
Checkpoint saved!
Max roc_auc till now:  0.6983611812317185
Epoch: 50


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5582558512687683
train_loss:  0.500363752562949



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 50:  0.500363752562949
Validation Loss:  0.643409808476766 | Validation roc_auc:  0.6984448033225844
Checkpoint saved!
Max roc_auc till now:  0.6984448033225844
Epoch: 51


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5534646511077881
train_loss:  0.4961348359889172



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 51:  0.4961348359889172
Validation Loss:  0.6441397931840684 | Validation roc_auc:  0.6990684847502923
Checkpoint saved!
Max roc_auc till now:  0.6990684847502923
Epoch: 52


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5487342476844788
train_loss:  0.4917113520997636



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 52:  0.4917113520997636
Validation Loss:  0.6469658878114488 | Validation roc_auc:  0.6990493213544688
Max roc_auc till now:  0.6990684847502923
Epoch: 53


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5451478362083435
train_loss:  0.48742679585801796



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 53:  0.48742679585801796
Validation Loss:  0.648522244559394 | Validation roc_auc:  0.6991834651252328
Checkpoint saved!
Max roc_auc till now:  0.6991834651252328
Epoch: 54


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5393233299255371
train_loss:  0.4831064318088775



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 54:  0.4831064318088775
Validation Loss:  0.6503896713256836 | Validation roc_auc:  0.6992688293429916
Checkpoint saved!
Max roc_auc till now:  0.6992688293429916
Epoch: 55


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5345194339752197
train_loss:  0.4786018475573114



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 55:  0.4786018475573114
Validation Loss:  0.6534696420033773 | Validation roc_auc:  0.6990205762607337
Max roc_auc till now:  0.6992688293429916
Epoch: 56


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5304856896400452
train_loss:  0.47426852900931177



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 56:  0.47426852900931177
Validation Loss:  0.6552195615238614 | Validation roc_auc:  0.6989831205325333
Max roc_auc till now:  0.6992688293429916
Epoch: 57


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5250397324562073
train_loss:  0.4697988749818599



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 57:  0.4697988749818599
Validation Loss:  0.6583165725072225 | Validation roc_auc:  0.6993785833372531
Checkpoint saved!
Max roc_auc till now:  0.6993785833372531
Epoch: 58


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.520837664604187
train_loss:  0.4652393098841322



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 58:  0.4652393098841322
Validation Loss:  0.6608170337147183 | Validation roc_auc:  0.6990231894510732
Max roc_auc till now:  0.6993785833372531
Epoch: 59


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5153020620346069
train_loss:  0.460659458916238



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 59:  0.460659458916238
Validation Loss:  0.6634984811147054 | Validation roc_auc:  0.699146880460479
Max roc_auc till now:  0.6993785833372531
Epoch: 60


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.5099965333938599
train_loss:  0.45615362677168336



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 60:  0.45615362677168336
Validation Loss:  0.6665212445788913 | Validation roc_auc:  0.6986887010876098
Max roc_auc till now:  0.6993785833372531
Epoch: 61


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.504601776599884
train_loss:  0.45164390827747103



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 61:  0.45164390827747103
Validation Loss:  0.6695487764146593 | Validation roc_auc:  0.6985040356369477
Max roc_auc till now:  0.6993785833372531
Epoch: 62


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.49942705035209656
train_loss:  0.44689582637015807



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 62:  0.44689582637015807
Validation Loss:  0.6727604468663534 | Validation roc_auc:  0.6982218110802754
Max roc_auc till now:  0.6993785833372531
Epoch: 63


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.49405980110168457
train_loss:  0.4423176024822479



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 63:  0.4423176024822479
Validation Loss:  0.6765694287088182 | Validation roc_auc:  0.6975336709575253
Max roc_auc till now:  0.6993785833372531
Epoch: 64


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.48898741602897644
train_loss:  0.4376910449342525



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 64:  0.4376910449342525
Validation Loss:  0.6810165908601549 | Validation roc_auc:  0.6975005705465575
Max roc_auc till now:  0.6993785833372531
Epoch: 65


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.4841696619987488
train_loss:  0.43304983986184953



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 65:  0.43304983986184953
Validation Loss:  0.6850747598542107 | Validation roc_auc:  0.6966469283689685
Max roc_auc till now:  0.6993785833372531
Epoch: 66


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.4795304238796234
train_loss:  0.42851033870210037



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 66:  0.42851033870210037
Validation Loss:  0.6880112820201449 | Validation roc_auc:  0.6965755011663539
Max roc_auc till now:  0.6993785833372531
Epoch: 67


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.4732394516468048
train_loss:  0.4237874207344461



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 67:  0.4237874207344461
Validation Loss:  0.6927308572663201 | Validation roc_auc:  0.6960389260832981
Max roc_auc till now:  0.6993785833372531
Epoch: 68


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

train_loss:  0.46867990493774414
train_loss:  0.4191458891046808



HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


Training Loss for epoch 68:  0.4191458891046808
Validation Loss:  0.6973021957609389 | Validation roc_auc:  0.6954396344320928
Stopping early, restoring best checkpoint..


In [14]:
model_parameters = torch.load(
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/only_top/sepsis-readmission-only-top.tar",
    map_location = dev
    )
model.load_state_dict(model_parameters['model_state_dict'])
model = model.eval()

In [15]:
dl = trn_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


0.8742168805271372


In [16]:
dl = val_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))


0.6993785833372531


In [17]:
dl = test_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))


0.679872371783816
