In [1]:
!pip install -q transformers

In [2]:
import torch
import sklearn
import numpy as np
import pandas as pd
from torch import nn
from sklearn import metrics
from tqdm.notebook import tqdm
from transformers import AdamW
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [3]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dev

device(type='cpu')

In [4]:
class PatDataset(Dataset):
    def __init__(self, px, dx, rx, static_features, y):
        self.px = px
        self.dx = dx
        self.rx = rx
        self.s = static_features
        self.y = y
            
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.px[idx], self.dx[idx], self.rx[idx], self.s[idx], self.y[idx]


class MyModel(nn.Module):
    def __init__(self, path=None):
        super().__init__()
    
        self.px_seq = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.dx_seq = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.rx_seq = nn.Sequential(
            nn.Linear(806, 128),
            nn.ReLU(),
            nn.LayerNorm(128, eps=1e-12, elementwise_affine=True)
            )
        self.static_seq = nn.Sequential(
            nn.Linear(182, 64),
            nn.ReLU(),
            nn.LayerNorm(64, eps=1e-12, elementwise_affine=True)
            )
        self.concat_seq = nn.Sequential(
            nn.Linear(448, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 2),  
            )

    def forward(self, px_embed, dx_embed, rx_embed, static_embed):
        px_features = self.px_seq(px_embed)
        dx_features = self.px_seq(dx_embed)
        rx_features = self.rx_seq(rx_embed)
        static_features = self.static_seq(static_embed)
        concat = torch.cat((px_features, dx_features, rx_features, static_features), dim=1)
        return self.concat_seq(concat)


class Trainer:
    def __init__(self, train_data, val_data):
        self.train_data = train_data
        self.val_data = val_data
        if torch.cuda.is_available():
            self.dev = torch.device('cuda')
        else:
            self.dev = torch.device('cpu')
        print("Using: ", self.dev)

    @staticmethod
    def _roc_auc(y_true, y_raw_logits, class_idx=1):
        y_score = F.softmax(y_raw_logits, dim=1).cpu().detach().numpy()[:, class_idx]
        return metrics.roc_auc_score(y_true.cpu().detach().numpy(), y_score)

    @staticmethod
    def _accuracy(true, pred_proba, class_idx=None):
        preds = torch.argmax(pred_proba, dim=1)
        if class_idx is None:
            score = (true == preds).float().mean()
        else:
            score = (true==preds)[true==class_idx].float().mean()
        return score

    @staticmethod
    def _get_optimizer_with_decay(model, lr):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01, 'lr':lr},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr':lr}
             ]
        return AdamW(optimizer_grouped_parameters, lr=lr)

    def _get_loss_func(self, class_wts=None):
        if class_wts is not None:
            loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_wts).type(torch.float).to(self.dev))
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
        return loss_fn
    
    @staticmethod
    def get_class_wts(y):
        return compute_class_weight('balanced', np.unique(y), y)

    def train_model(self, model, n_epochs=1, lr=1e-5, class_wts=None,
                    path_to_save_model='./model.tar', logging_step=5,
                    patience=3):
        loss_fn = self._get_loss_func(class_wts=class_wts)
        optimizer = Trainer._get_optimizer_with_decay(model, lr)

        model.to(self.dev)
        trn_loss = []
        val_loss = []
        val_accu = []
        max_val_accu = 0
        count = 0
        for epoch in range(n_epochs):
            print(f"Epoch: {epoch}")
            trn_loss_per_epoch = []
            model.train()
            for i, batch in enumerate(tqdm(self.train_data)):
                batch = [x.to(self.dev) for x in batch]
                outputs = model(*batch[:-1])
                loss = loss_fn(outputs, batch[-1])
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                del batch
                torch.cuda.empty_cache()
                trn_loss_per_epoch.append(float(loss))
                if i % logging_step == 0:
                    print("train_loss: ", np.mean(trn_loss_per_epoch))
            trn_loss.append(np.mean(trn_loss_per_epoch))
            
            model.eval()
            valid_loss_per_epoch = []
            yhat_lst = []
            y_lst = []
            with torch.no_grad():
                for batch in tqdm(self.val_data):
                    batch = [x.to(self.dev) for x in batch]
                    outputs = model(*batch[:-1])
                    valid_loss_per_epoch.append(float(loss_fn(outputs, batch[-1])))
                    yhat_lst.append(outputs)
                    y_lst.append(batch[-1])   
                tmp = Trainer._roc_auc(torch.cat(y_lst, dim=0), torch.cat(yhat_lst, dim=0))
            val_loss.append(np.mean(valid_loss_per_epoch))
            val_accu.append(tmp)
            
            print(f"Training Loss for epoch {epoch}: ", trn_loss[-1])
            print("Validation Loss: ", val_loss[-1], "| Validation roc_auc: ", val_accu[-1])

            if val_accu[-1] > max_val_accu:
                count = 0
                max_val_accu = val_accu[-1]
                torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'training_loss': trn_loss[-1],
                            'val_loss': val_loss[-1],
                            'val_accuracy': val_accu[-1]
                            }, path_to_save_model)
                print(f"Checkpoint saved!")
            else:
                count = count + 1

            if count > patience:
                print("Stopping early, restoring best checkpoint..")
                model_parameters = torch.load(path_to_save_model, map_location = self.dev)
                model.load_state_dict(model_parameters['model_state_dict'])
                optimizer.load_state_dict(model_parameters['optimizer_state_dict'])
                break 
            print("Max roc_auc till now: ", max_val_accu)
        return model.eval(), optimizer

In [5]:
px_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/trn_px_journey.pkl")
dx_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/trn_dx_journey.pkl")
rx_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_trn.pkl")
y_trn = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_trn.pkl")
rx_trn.shape

(23868, 806)

In [6]:
px_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/val_px_journey.pkl")
dx_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/val_dx_journey.pkl")
rx_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_val.pkl")
y_val = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_val.pkl")
rx_val.shape

(4212, 806)

In [7]:
px_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/test_px_journey.pkl")
dx_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/test_dx_journey.pkl")
rx_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/imp_rx_test.pkl")
y_test = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/y_test.pkl")
rx_test.shape

(7020, 806)

In [8]:
trn_static = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/trn_static.pkl")
val_static = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/val_static.pkl")
test_static = pd.read_pickle("/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/test_static.pkl")
trn_static.shape, val_static.shape, test_static.shape

((23868, 182), (4212, 182), (7020, 182))

In [9]:
bs = 1024

trn_ds = PatDataset(px_trn, dx_trn, rx_trn.values, trn_static.values, y_trn['switch_flag'].tolist())
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=bs, num_workers=0, pin_memory=False)

val_ds = PatDataset(px_val, dx_val, rx_val.values, val_static.values, y_val['switch_flag'].tolist())
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=bs, num_workers=0, pin_memory=False)

test_ds = PatDataset(px_test, dx_test, rx_test.values, test_static.values, y_test['switch_flag'].tolist())
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=bs, num_workers=0, pin_memory=False)

In [10]:
class_wts = compute_class_weight('balanced', np.unique(y_trn['switch_flag'].values), y_trn['switch_flag'].values)
class_wts

array([ 0.51734004, 14.9175    ])

In [11]:
trainer = Trainer(trn_dl, val_dl)

Using:  cpu


In [12]:
model = MyModel()

model, optimizer = trainer.train_model(
    model, n_epochs=200, lr=1e-5, class_wts=class_wts,
    path_to_save_model='/content/drive/MyDrive/ColabData/saved_models/PatientBERT/only_top/sepsis-readmission-only-top-with-static-features.tar',
    logging_step=46, patience=10
    )

Epoch: 0


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6937960386276245



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 0:  0.6951270923018456
Validation Loss:  0.6925402522087097 | Validation roc_auc:  0.5039511437934117
Checkpoint saved!
Max roc_auc till now:  0.5039511437934117
Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6898974180221558



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 1:  0.691310927271843
Validation Loss:  0.6893953561782837 | Validation roc_auc:  0.550254263420039
Checkpoint saved!
Max roc_auc till now:  0.550254263420039
Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6885543465614319



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 2:  0.6887704605857531
Validation Loss:  0.6873053669929504 | Validation roc_auc:  0.5754297391513403
Checkpoint saved!
Max roc_auc till now:  0.5754297391513403
Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6871179938316345



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 3:  0.6864022836089134
Validation Loss:  0.685102915763855 | Validation roc_auc:  0.5909503476414215
Checkpoint saved!
Max roc_auc till now:  0.5909503476414215
Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6857180595397949



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 4:  0.6840886250138283
Validation Loss:  0.6830718517303467 | Validation roc_auc:  0.6015407370242034
Checkpoint saved!
Max roc_auc till now:  0.6015407370242034
Epoch: 5


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6844781041145325



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 5:  0.6817747751871744
Validation Loss:  0.6809695482254028 | Validation roc_auc:  0.6103088616768668
Checkpoint saved!
Max roc_auc till now:  0.6103088616768668
Epoch: 6


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6833009719848633



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 6:  0.6794217179218928
Validation Loss:  0.6788039803504944 | Validation roc_auc:  0.6171458386685968
Checkpoint saved!
Max roc_auc till now:  0.6171458386685968
Epoch: 7


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.68184494972229



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 7:  0.6769945472478867
Validation Loss:  0.6764772653579711 | Validation roc_auc:  0.6227859744848094
Checkpoint saved!
Max roc_auc till now:  0.6227859744848094
Epoch: 8


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.680197536945343



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 8:  0.6744861056407293
Validation Loss:  0.6740912675857544 | Validation roc_auc:  0.6272789197419562
Checkpoint saved!
Max roc_auc till now:  0.6272789197419562
Epoch: 9


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.678756058216095



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 9:  0.6719295084476471
Validation Loss:  0.6715882539749145 | Validation roc_auc:  0.6317265696998838
Checkpoint saved!
Max roc_auc till now:  0.6317265696998838
Epoch: 10


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.677094042301178



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 10:  0.6692441826065382
Validation Loss:  0.6690178871154785 | Validation roc_auc:  0.6355357301515128
Checkpoint saved!
Max roc_auc till now:  0.6355357301515128
Epoch: 11


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6753959059715271



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 11:  0.6663969854513804
Validation Loss:  0.6663041472434997 | Validation roc_auc:  0.6396471496190839
Checkpoint saved!
Max roc_auc till now:  0.6396471496190839
Epoch: 12


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6738784313201904



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 12:  0.6634348034858704
Validation Loss:  0.6631531119346619 | Validation roc_auc:  0.6448308481893203
Checkpoint saved!
Max roc_auc till now:  0.6448308481893203
Epoch: 13


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.672589898109436



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 13:  0.6602950220306715
Validation Loss:  0.6597923636436462 | Validation roc_auc:  0.6503777802167555
Checkpoint saved!
Max roc_auc till now:  0.6503777802167555
Epoch: 14


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6712974905967712



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 14:  0.6569371571143469
Validation Loss:  0.6561268210411072 | Validation roc_auc:  0.6552714146593009
Checkpoint saved!
Max roc_auc till now:  0.6552714146593009
Epoch: 15


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6701982617378235



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 15:  0.6535290802518526
Validation Loss:  0.6528013586997986 | Validation roc_auc:  0.6594089660302677
Checkpoint saved!
Max roc_auc till now:  0.6594089660302677
Epoch: 16


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6689692139625549



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 16:  0.6499959131081899
Validation Loss:  0.64949471950531 | Validation roc_auc:  0.6624402668241549
Checkpoint saved!
Max roc_auc till now:  0.6624402668241549
Epoch: 17


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.667859673500061



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 17:  0.6463523333271345
Validation Loss:  0.6462290167808533 | Validation roc_auc:  0.6651266264932205
Checkpoint saved!
Max roc_auc till now:  0.6651266264932205
Epoch: 18


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6671286225318909



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 18:  0.6426794330279032
Validation Loss:  0.643043828010559 | Validation roc_auc:  0.66786960528631
Checkpoint saved!
Max roc_auc till now:  0.66786960528631
Epoch: 19


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6661220192909241



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 19:  0.6389425496260325
Validation Loss:  0.6399169087409973 | Validation roc_auc:  0.6708573529078711
Checkpoint saved!
Max roc_auc till now:  0.6708573529078711
Epoch: 20


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.665207028388977



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 20:  0.6354474127292633
Validation Loss:  0.637092673778534 | Validation roc_auc:  0.6731752527390589
Checkpoint saved!
Max roc_auc till now:  0.6731752527390589
Epoch: 21


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6639889478683472



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 21:  0.631942426164945
Validation Loss:  0.6344955325126648 | Validation roc_auc:  0.6754835708723351
Checkpoint saved!
Max roc_auc till now:  0.6754835708723351
Epoch: 22


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6626506447792053



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 22:  0.6286483531196912
Validation Loss:  0.6326361775398255 | Validation roc_auc:  0.6764469670441856
Checkpoint saved!
Max roc_auc till now:  0.6764469670441856
Epoch: 23


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6611722111701965



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 23:  0.6254981035987536
Validation Loss:  0.6308885812759399 | Validation roc_auc:  0.6783005900583786
Checkpoint saved!
Max roc_auc till now:  0.6783005900583786
Epoch: 24


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6596466302871704



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 24:  0.6222454483310381
Validation Loss:  0.6293566584587097 | Validation roc_auc:  0.6799207680689047
Checkpoint saved!
Max roc_auc till now:  0.6799207680689047
Epoch: 25


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6582753658294678



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 25:  0.6191010524829229
Validation Loss:  0.6277293205261231 | Validation roc_auc:  0.6813101142661029
Checkpoint saved!
Max roc_auc till now:  0.6813101142661029
Epoch: 26


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6562740206718445



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 26:  0.6160600756605467
Validation Loss:  0.6265491008758545 | Validation roc_auc:  0.6823301295619769
Checkpoint saved!
Max roc_auc till now:  0.6823301295619769
Epoch: 27


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6548274159431458



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 27:  0.6131700724363327
Validation Loss:  0.6252938628196716 | Validation roc_auc:  0.683370179317121
Checkpoint saved!
Max roc_auc till now:  0.683370179317121
Epoch: 28


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6532202959060669



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 28:  0.610299194852511
Validation Loss:  0.6244127035140992 | Validation roc_auc:  0.6844363609756607
Checkpoint saved!
Max roc_auc till now:  0.6844363609756607
Epoch: 29


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6513802409172058



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 29:  0.6075276136398315
Validation Loss:  0.6233052253723145 | Validation roc_auc:  0.6854816371114838
Checkpoint saved!
Max roc_auc till now:  0.6854816371114838
Epoch: 30


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6494227647781372



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 30:  0.6048060605923334
Validation Loss:  0.6223694682121277 | Validation roc_auc:  0.6866035668306009
Checkpoint saved!
Max roc_auc till now:  0.6866035668306009
Epoch: 31


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6472102403640747



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 31:  0.6020162403583527
Validation Loss:  0.6214535474777222 | Validation roc_auc:  0.6875495417335207
Checkpoint saved!
Max roc_auc till now:  0.6875495417335207
Epoch: 32


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6451950073242188



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 32:  0.5992129271229109
Validation Loss:  0.6208553791046143 | Validation roc_auc:  0.6881627703998704
Checkpoint saved!
Max roc_auc till now:  0.6881627703998704
Epoch: 33


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6429067254066467



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 33:  0.5964640478293101
Validation Loss:  0.6199907541275025 | Validation roc_auc:  0.6888143258578667
Checkpoint saved!
Max roc_auc till now:  0.6888143258578667
Epoch: 34


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6409158706665039



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 34:  0.5936672985553741
Validation Loss:  0.6191612243652344 | Validation roc_auc:  0.6896522888934183
Checkpoint saved!
Max roc_auc till now:  0.6896522888934183
Epoch: 35


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6385937929153442



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 35:  0.5908220137159029
Validation Loss:  0.6185246706008911 | Validation roc_auc:  0.6903961770767459
Checkpoint saved!
Max roc_auc till now:  0.6903961770767459
Epoch: 36


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6365326046943665



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 36:  0.5879839956760406
Validation Loss:  0.6178821325302124 | Validation roc_auc:  0.6912263005412788
Checkpoint saved!
Max roc_auc till now:  0.6912263005412788
Epoch: 37


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6338967084884644



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 37:  0.585089107354482
Validation Loss:  0.6174597144126892 | Validation roc_auc:  0.6917306462768136
Checkpoint saved!
Max roc_auc till now:  0.6917306462768136
Epoch: 38


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6316492557525635



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 38:  0.5822546258568764
Validation Loss:  0.6167985200881958 | Validation roc_auc:  0.6923952676865076
Checkpoint saved!
Max roc_auc till now:  0.6923952676865076
Epoch: 39


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6291640400886536



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 39:  0.5794280295570692
Validation Loss:  0.6162272930145264 | Validation roc_auc:  0.6930529205886299
Checkpoint saved!
Max roc_auc till now:  0.6930529205886299
Epoch: 40


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6266763210296631



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 40:  0.576561450958252
Validation Loss:  0.6156578183174133 | Validation roc_auc:  0.6936818283970169
Checkpoint saved!
Max roc_auc till now:  0.6936818283970169
Epoch: 41


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6240836977958679



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 41:  0.5736940751473109
Validation Loss:  0.6152122616767883 | Validation roc_auc:  0.694277635794436
Checkpoint saved!
Max roc_auc till now:  0.694277635794436
Epoch: 42


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6216534376144409



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 42:  0.57075293486317
Validation Loss:  0.6148988246917725 | Validation roc_auc:  0.6949535810289349
Checkpoint saved!
Max roc_auc till now:  0.6949535810289349
Epoch: 43


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6188102960586548



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 43:  0.567757590363423
Validation Loss:  0.6143231391906738 | Validation roc_auc:  0.6955110616347073
Checkpoint saved!
Max roc_auc till now:  0.6955110616347073
Epoch: 44


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6160098910331726



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 44:  0.564831322679917
Validation Loss:  0.6139925360679627 | Validation roc_auc:  0.6961120954128057
Checkpoint saved!
Max roc_auc till now:  0.6961120954128057
Epoch: 45


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6133149862289429



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 45:  0.5618168984850248
Validation Loss:  0.6136355757713318 | Validation roc_auc:  0.6967723615052673
Checkpoint saved!
Max roc_auc till now:  0.6967723615052673
Epoch: 46


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.610317051410675



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 46:  0.5586404117445151
Validation Loss:  0.6133479833602905 | Validation roc_auc:  0.6972026668478477
Checkpoint saved!
Max roc_auc till now:  0.6972026668478477
Epoch: 47


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6074705123901367



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 47:  0.5555092332263788
Validation Loss:  0.6130722880363464 | Validation roc_auc:  0.6976974308854708
Checkpoint saved!
Max roc_auc till now:  0.6976974308854708
Epoch: 48


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6040779948234558



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 48:  0.5523014217615128
Validation Loss:  0.6129795193672181 | Validation roc_auc:  0.6980040452186456
Checkpoint saved!
Max roc_auc till now:  0.6980040452186456
Epoch: 49


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.6015542149543762



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 49:  0.5491180010139942
Validation Loss:  0.612708842754364 | Validation roc_auc:  0.6984430611956914
Checkpoint saved!
Max roc_auc till now:  0.6984430611956914
Epoch: 50


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5980469584465027



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 50:  0.5458631801108519
Validation Loss:  0.6127484917640686 | Validation roc_auc:  0.6987845180667269
Checkpoint saved!
Max roc_auc till now:  0.6987845180667269
Epoch: 51


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5951128602027893



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 51:  0.5426112413406372
Validation Loss:  0.6125032544136048 | Validation roc_auc:  0.6991242328108694
Checkpoint saved!
Max roc_auc till now:  0.6991242328108694
Epoch: 52


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5915716290473938



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 52:  0.5391827325026194
Validation Loss:  0.6126161575317383 | Validation roc_auc:  0.6997061031931444
Checkpoint saved!
Max roc_auc till now:  0.6997061031931444
Epoch: 53


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5883401036262512



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 53:  0.5357726017634074
Validation Loss:  0.6125383734703064 | Validation roc_auc:  0.7001660246929066
Checkpoint saved!
Max roc_auc till now:  0.7001660246929066
Epoch: 54


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5847545862197876



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 54:  0.5322506489853064
Validation Loss:  0.6129452109336853 | Validation roc_auc:  0.700232225514842
Checkpoint saved!
Max roc_auc till now:  0.700232225514842
Epoch: 55


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5814189314842224



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 55:  0.5288016758859158
Validation Loss:  0.6128016591072083 | Validation roc_auc:  0.7007208921083394
Checkpoint saved!
Max roc_auc till now:  0.7007208921083394
Epoch: 56


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5775743126869202



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 56:  0.5251840961476167
Validation Loss:  0.6130951762199401 | Validation roc_auc:  0.700676467872567
Max roc_auc till now:  0.7007208921083394
Epoch: 57


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5739640593528748



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 57:  0.5216491855680943
Validation Loss:  0.6130768537521363 | Validation roc_auc:  0.7010545094083562
Checkpoint saved!
Max roc_auc till now:  0.7010545094083562
Epoch: 58


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5702645182609558



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 58:  0.5178672596812248
Validation Loss:  0.6136904835700989 | Validation roc_auc:  0.7011224523571847
Checkpoint saved!
Max roc_auc till now:  0.7011224523571847
Epoch: 59


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5662861466407776



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 59:  0.5141099840402603
Validation Loss:  0.6138909697532654 | Validation roc_auc:  0.7016207006485938
Checkpoint saved!
Max roc_auc till now:  0.7016207006485938
Epoch: 60


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5624226331710815



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 60:  0.5103986288110415
Validation Loss:  0.6143542408943177 | Validation roc_auc:  0.7015632104611235
Max roc_auc till now:  0.7016207006485938
Epoch: 61


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5585261583328247



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 61:  0.5065870707233747
Validation Loss:  0.614983081817627 | Validation roc_auc:  0.7019238307279825
Checkpoint saved!
Max roc_auc till now:  0.7019238307279825
Epoch: 62


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5544409155845642



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 62:  0.5027115816871325
Validation Loss:  0.6157829761505127 | Validation roc_auc:  0.7019151200935174
Max roc_auc till now:  0.7019238307279825
Epoch: 63


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5503010153770447



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 63:  0.4988335371017456
Validation Loss:  0.616418433189392 | Validation roc_auc:  0.7020091949457414
Checkpoint saved!
Max roc_auc till now:  0.7020091949457414
Epoch: 64


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5462738275527954



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 64:  0.49481184408068657
Validation Loss:  0.617592716217041 | Validation roc_auc:  0.7017443916579996
Max roc_auc till now:  0.7020091949457414
Epoch: 65


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5419653058052063



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 65:  0.4908744643131892
Validation Loss:  0.6182222723960876 | Validation roc_auc:  0.7018837618094426
Max roc_auc till now:  0.7020091949457414
Epoch: 66


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5376713871955872



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 66:  0.48683421686291695
Validation Loss:  0.6194225788116455 | Validation roc_auc:  0.7017252282621761
Max roc_auc till now:  0.7020091949457414
Epoch: 67


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5333271026611328



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 67:  0.48282158002257347
Validation Loss:  0.6203128576278687 | Validation roc_auc:  0.7017513601655717
Max roc_auc till now:  0.7020091949457414
Epoch: 68


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5289542078971863



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 68:  0.47870752215385437
Validation Loss:  0.6217910528182984 | Validation roc_auc:  0.701582373856947
Max roc_auc till now:  0.7020091949457414
Epoch: 69


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5246140956878662



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 69:  0.4746418495972951
Validation Loss:  0.6228973507881165 | Validation roc_auc:  0.70172174400839
Max roc_auc till now:  0.7020091949457414
Epoch: 70


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5201217532157898



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 70:  0.4704853432873885
Validation Loss:  0.6247684240341187 | Validation roc_auc:  0.701315828442312
Max roc_auc till now:  0.7020091949457414
Epoch: 71


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5153403282165527



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 71:  0.4663159114619096
Validation Loss:  0.6259745717048645 | Validation roc_auc:  0.7012025901942647
Max roc_auc till now:  0.7020091949457414
Epoch: 72


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5110105276107788



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 72:  0.46213118607799214
Validation Loss:  0.6277498245239258 | Validation roc_auc:  0.7010998047075753
Max roc_auc till now:  0.7020091949457414
Epoch: 73


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5063905715942383



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 73:  0.4579450649519761
Validation Loss:  0.6294668197631836 | Validation roc_auc:  0.7007600899634328
Max roc_auc till now:  0.7020091949457414
Epoch: 74


HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

train_loss:  0.5019139647483826



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


Training Loss for epoch 74:  0.4537044018507004
Validation Loss:  0.631260085105896 | Validation roc_auc:  0.7006451095884922
Stopping early, restoring best checkpoint..


In [13]:
model_parameters = torch.load(
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/only_top/sepsis-readmission-only-top-with-static-features.tar",
    map_location = dev
    )
model.load_state_dict(model_parameters['model_state_dict'])
model = model.eval()

In [14]:
dl = trn_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))


0.8553900966707126


In [15]:
dl = val_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))


0.7020091949457414


In [16]:
dl = test_dl
yhat_lst = []
y_lst = []
with torch.no_grad():
    for batch in tqdm(dl):
        yhat = model(*batch[:-1])
        yhat_lst.append(yhat)
        y_lst += list(batch[-1].cpu().numpy())

print(Trainer._roc_auc(torch.tensor(y_lst), torch.cat(yhat_lst, dim=0)))

HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))


0.709803540350272


# **XGBoost** 

## Without Static

In [17]:
from xgboost import XGBClassifier

In [23]:
data_trn = pd.concat([pd.DataFrame(px_trn), pd.DataFrame(dx_trn), rx_trn.reset_index(drop=True)], axis=1)
data_val = pd.concat([pd.DataFrame(px_val), pd.DataFrame(dx_val), rx_val.reset_index(drop=True)], axis=1)
data_test = pd.concat([pd.DataFrame(px_test), pd.DataFrame(dx_test), rx_test.reset_index(drop=True)], axis=1)

In [24]:
clf = XGBClassifier(n_estimators=36,
                    scale_pos_weight=(y_trn.shape[0]-800)/y_trn['switch_flag'].sum(),
                    verbosity=2)
clf = clf.fit(data_trn.values, y_trn['switch_flag'].tolist(),
              early_stopping_rounds=10,
              eval_set=[(data_val.values, y_val['switch_flag'].tolist())],
              eval_metric='auc')

[15:15:11] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[0]	validation_0-auc:0.630452
Will train until validation_0-auc hasn't improved in 10 rounds.
[15:15:13] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[1]	validation_0-auc:0.636731
[15:15:16] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[2]	validation_0-auc:0.658897
[15:15:19] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[3]	validation_0-auc:0.664356
[15:15:21] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[4]	validation_0-auc:0.668267
[15:15:24] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[5]	validati

In [25]:
yhat = clf.predict_proba(data_trn.values)
metrics.roc_auc_score(y_trn['switch_flag'].tolist(), yhat[:, 1])

0.7680008561643834

In [26]:
yhat = clf.predict_proba(data_val.values)
metrics.roc_auc_score(y_val['switch_flag'].tolist(), yhat[:, 1])

0.6757788613807052

In [27]:
yhat = clf.predict_proba(data_test.values)
metrics.roc_auc_score(y_test['switch_flag'].tolist(), yhat[:, 1])

0.6605879678264005

## With Static

In [None]:
data_trn = pd.concat([pd.DataFrame(px_trn), pd.DataFrame(dx_trn), rx_trn.reset_index(drop=True), trn_static.reset_index(drop=True)], axis=1)
data_val = pd.concat([pd.DataFrame(px_val), pd.DataFrame(dx_val), rx_val.reset_index(drop=True), val_static.reset_index(drop=True)], axis=1)
data_test = pd.concat([pd.DataFrame(px_test), pd.DataFrame(dx_test), rx_test.reset_index(drop=True), test_static.reset_index(drop=True)], axis=1)

In [None]:
clf = XGBClassifier(n_estimators=36, scale_pos_weight=(y_trn.shape[0]-800)/y_trn['switch_flag'].sum(), verbosity=2)
clf = clf.fit(data_trn.values, y_trn['switch_flag'].tolist(), early_stopping_rounds=10, eval_set=[(data_val.values, y_val['switch_flag'].tolist())], eval_metric='auc')

[15:02:34] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[0]	validation_0-auc:0.671095
Will train until validation_0-auc hasn't improved in 10 rounds.
[15:02:37] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[1]	validation_0-auc:0.685139
[15:02:39] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[2]	validation_0-auc:0.701385
[15:02:42] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[3]	validation_0-auc:0.695175
[15:02:45] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[4]	validation_0-auc:0.702331
[15:02:47] INFO: /workspace/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 0 pruned nodes, max_depth=3
[5]	validati

In [None]:
yhat = clf.predict_proba(data_trn.values)
metrics.roc_auc_score(y_trn['switch_flag'].tolist(), yhat[:, 1])

0.7862989585139588

In [None]:
yhat = clf.predict_proba(data_val.values)
metrics.roc_auc_score(y_val['switch_flag'].tolist(), yhat[:, 1])

0.7188599173186576

In [None]:
yhat = clf.predict_proba(data_test.values)
metrics.roc_auc_score(y_test['switch_flag'].tolist(), yhat[:, 1])

0.7246382037974881