In [1]:
"""
File: Bi-LSTM.ipynb
Code to train and evaluate a bi-directional LSTM model on MIMIC-IV FHIR dataset.
"""


def Project():
    """
    __Objectives__
    0. Import data and tokenizer
    1. Train the tokenizer on all sequences of the dataset
    2. Tokenize different sequences and join them together
    3. Prepare actual labels for one, six, twelve month death after discharge
    4. Define the model architecture for bidrectional LSTM
    5. Train Bi-LSTM model and evaluate on test dataset
    6. Compare performance across new tasks to XGBoost
    """
    return ProjectObjectives.__doc__

In [2]:
import os


ROOT = "/fs01/home/afallah/odyssey/odyssey"
os.chdir(ROOT)
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from torch import nn, optim
from torch.nn.functional import sigmoid
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from models.big_bird_cehr.data import FinetuneDataset
from models.big_bird_cehr.embeddings import Embeddings
from models.big_bird_cehr.tokenizer import ConceptTokenizer


%matplotlib inline

DATA_ROOT = f"{ROOT}/data/slurm_data/512/two_weeks"
DATA_PATH = f"{DATA_ROOT}/pretrain.parquet"
FINE_TUNE_PATH = f"{DATA_ROOT}/fine_tune.parquet"
TEST_DATA_PATH = f"{DATA_ROOT}/fine_test.parquet"

2024-02-08 23:50:17.949692: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-08 23:50:18.008367: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-08 23:50:18.008419: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-08 23:50:18.010630: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-08 23:50:18.022348: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-08 23:50:18.023340: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [3]:
# save parameters and configurations
class config:
    seed = 23
    data_dir = DATA_ROOT
    test_size = 0.2
    batch_size = 64
    num_workers = 3
    vocab_size = None
    embedding_size = 768
    time_embeddings_size = 32
    type_vocab_size = 8
    max_len = 512
    padding_idx = None
    device = torch.device("cuda")


def seed_all(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # pl.seed_everything(seed)


seed_all(config.seed)
print(f"Cuda: {torch.cuda.get_device_name(torch.cuda.current_device())}")

Cuda: NVIDIA A40


In [4]:
# Load data
pretrain_data = pd.read_parquet(DATA_PATH)
pretrain_data = pretrain_data[pretrain_data["event_tokens_512"].notnull()]

finetune_data = pd.read_parquet(FINE_TUNE_PATH)
finetune_data = finetune_data[finetune_data["event_tokens_512"].notnull()]

test_data = pd.read_parquet(TEST_DATA_PATH)
test_data = test_data[test_data["event_tokens_512"].notnull()]
test_length = len(test_data)

train_data = pd.concat((pretrain_data, finetune_data)).reset_index().drop_duplicates(subset="index", keep="first").set_index("index")
del pretrain_data, finetune_data

train_data

Unnamed: 0_level_0,patient_id,num_visits,deceased,death_after_start,death_after_end,length,token_length,event_tokens_512,type_tokens_512,age_tokens_512,time_tokens_512,visit_tokens_512,position_tokens_512,label
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0,f8f3289a-057f-5fcc-a714-5f6109ca16c4,2,0,,,1,4,"[[CLS], [VS], 8938, [VE], [PAD], [PAD], [PAD],...","[1, 2, 7, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 18, 18, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 8262, 8262, 8262, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 2, 2, 2, 513, 513, 513, 513, 513, 513, 513...",0
1,9b62c9f4-3fdc-5020-82b5-ae5b8292445a,4,0,,,43,52,"[[CLS], [VS], 7569, 66689036430, 00904224461, ...","[1, 2, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 3, 4, ...","[0, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28...","[0, 5963, 5963, 5963, 5963, 5963, 5963, 5963, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
2,2ca522eb-dd89-5f79-8155-9599ea46b0b2,2,1,244.0,242.0,51,54,"[[CLS], [VS], 00904629261, 00904642281, 009046...","[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 8016, 8016, 8016, 8016, 8016, 8016, 8016, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
4,02adf8a6-8bc0-55d3-81ae-4d8582094896,9,1,20.0,11.0,640,664,"[[CLS], [VS], 52007_3, 51476_2, 50861_1, 50862...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73...","[0, 8426, 8426, 8426, 8426, 8426, 8426, 8426, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, ...",1
5,744fe3c4-9b03-55ae-ac9f-6bc4e967cde7,3,0,,,80,86,"[[CLS], [VS], 7813, 7813, 7902, 7902, 9604, 00...","[1, 2, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29...","[0, 7582, 7582, 7582, 7582, 7582, 7582, 7582, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65266,433525b9-c7d8-53b7-9fcd-2a682eaf900e,2,1,11.0,0.0,746,749,"[[CLS], [VS], 50931_4, 50960_4, 50970_2, 50971...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91...","[0, 8327, 8327, 8327, 8327, 8327, 8327, 8327, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...",1
143535,2bd932ab-2c87-529f-b2c2-e038ee18d935,25,1,18.0,0.0,4311,4383,"[[CLS], [VS], 51104_3, 51093_2, 51097_3, 51104...","[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...","[0, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86...","[0, 6788, 6788, 6788, 6788, 6788, 6788, 6788, ...","[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26...",1
18343,b72aa2f9-33ca-58d0-acaa-3e3586da9ba0,2,1,7.0,0.0,286,289,"[[CLS], [VS], 966, 9671, 9671, 9604, 9604, 000...","[1, 2, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, ...","[0, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71...","[0, 7521, 7521, 7521, 7521, 7521, 7521, 7521, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1
85586,48aad725-227a-5c75-92c6-531a65c84f89,2,1,0.0,0.0,16,19,"[[CLS], [VS], 9671, 9604, 00409490234, 5107902...","[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, ...","[0, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91...","[0, 8402, 8402, 8402, 8402, 8402, 8402, 8402, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",1


In [5]:
# Load data
# data = pd.read_parquet(DATA_PATH)
# data.rename(
#     columns={
#         "event_tokens": "event_tokens_untruncated",
#         "event_tokens_updated": "event_tokens",
#     },
#     inplace=True,
# )
# data

In [6]:
# Define custom labels, here death in 12 M
# data["label"] = (
#     (data["death_after_end"] >= 0) & (data["death_after_end"] < 365)
# ).astype(int)

In [7]:
# Fit tokenizer on .json vocab files
tokenizer = ConceptTokenizer(data_dir=config.data_dir)
tokenizer.fit_on_vocab()
config.vocab_size = tokenizer.get_vocab_size()
config.padding_idx = tokenizer.get_pad_token_id()
tokenizer

<models.big_bird_cehr.tokenizer.ConceptTokenizer at 0x7f37363d0490>

In [8]:
# Define dataset with token lengths
class DatasetWithTokenLength(Dataset):
    def __init__(self, tokenized_data, length_data):
        super(Dataset, self).__init__()

        self.tokenized_data = tokenized_data
        self.length_data = length_data
        assert len(tokenized_data) == len(
            length_data,
        ), "Datasets have different lengths"

        self.sorted_indices = sorted(
            range(len(length_data)), key=lambda x: length_data[x], reverse=True,
        )
        # self.tokenized_data = [tokenized_data[i] for i in self.sorted_indices]
        # self.length_data = [min(length_data[i], ) for i in self.sorted_indices]

    def __len__(self):
        return len(self.tokenized_data)

    def __getitem__(self, index):
        index = self.sorted_indices[index]
        return self.tokenized_data[index], min(config.max_len, self.length_data[index])


# Get training and test datasets
# train_data, test_data = train_test_split(
#     data, test_size=config.test_size, random_state=config.seed, stratify=data["label"]
# )

train_dataset = FinetuneDataset(
    data=train_data,
    tokenizer=tokenizer,
    max_len=config.max_len,
)

test_dataset = FinetuneDataset(
    data=test_data,
    tokenizer=tokenizer,
    max_len=config.max_len,
)

train_dataset_with_lengths = DatasetWithTokenLength(
    train_dataset, train_data["token_length"].values,
)
test_dataset_with_lengths = DatasetWithTokenLength(
    test_dataset, test_data["token_length"].values,
)

train_loader = DataLoader(
    train_dataset_with_lengths,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True,
)

test_loader = DataLoader(
    test_dataset_with_lengths,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True,
)

print("Data is ready to go!\n")

Data is ready to go!



In [9]:
# Define model architecture
class BiLSTMModel(nn.Module):
    def __init__(
        self, embedding_dim, hidden_size, num_layers, output_size, dropout_rate,
    ):
        super(BiLSTMModel, self).__init__()

        self.embeddings = Embeddings(
            vocab_size=config.vocab_size,
            embedding_size=config.embedding_size,
            time_embedding_size=config.time_embeddings_size,
            type_vocab_size=config.type_vocab_size,
            max_len=config.max_len,
            padding_idx=config.padding_idx,
        )

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,
        )

        self.batch_norm = nn.BatchNorm1d(hidden_size * 2)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, inputs, lengths):
        embed = self.embeddings(*inputs)
        packed_embed = pack_padded_sequence(embed, lengths.cpu(), batch_first=True)

        lstm_out, (hidden_state, cell_state) = self.lstm(packed_embed)
        # output = lstm_out[:, -1, :]
        output = torch.cat((hidden_state[-2, :, :], hidden_state[-1, :, :]), dim=1)

        output = self.dropout(self.batch_norm(output))
        output = self.linear(output)
        return output

    @staticmethod
    def get_inputs_labels(sequences):
        labels = sequences["labels"].view(-1, 1).to(config.device)
        inputs = (
            sequences["concept_ids"].to(config.device),
            sequences["type_ids"].to(config.device),
            sequences["time_stamps"].to(config.device),
            sequences["ages"].to(config.device),
            sequences["visit_orders"].to(config.device),
            sequences["visit_segments"].to(config.device),
        )

        return inputs, labels.float()

    @staticmethod
    def get_balanced_accuracy(outputs, labels):
        predictions = torch.round(sigmoid(outputs))
        predictions = predictions.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        return balanced_accuracy_score(labels, predictions)

In [10]:
# Set hyperparameters for Bi-LSTM model adn training loop
input_size = config.embedding_size  # embedding_dim
hidden_size = config.embedding_size // 2  # output hidden size
num_layers = 5  # number of LSTM layers
output_size = 1  # Binary classification, so output size is 1
dropout_rate = 0.2  # Dropout rate for regularization

epochs = 6
learning_rate = 0.001

In [11]:
# Training Loop
model = BiLSTMModel(input_size, hidden_size, num_layers, output_size, dropout_rate).to(
    config.device,
)
class_weights = torch.tensor([8]).to(config.device)
loss_fcn = nn.BCEWithLogitsLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ExponentialLR(optimizer, gamma=0.75, verbose=True)

for epoch in range(epochs):
    train_total_loss = 0
    train_accuracy = 0
    test_accuracy = 0

    model.train()
    for batch_no, (sequences, lengths) in tqdm(
        enumerate(train_loader),
        file=sys.stdout,
        total=len(train_loader),
        desc=f"Epoch {epoch + 1}/{epochs}",
        unit=" batch",
    ):
        inputs, labels = model.get_inputs_labels(sequences)
        outputs = model(inputs, lengths)
        loss = loss_fcn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_total_loss += loss.item()
        # tqdm.write(f'Batch Loss: {loss.item():.4f}', file=sys.stdout, end='\r')
        # print(f'\nBatch Loss: {loss.item():.4f}', end='\r')

    # model.eval()
    # with torch.no_grad():
    #     for batch_no, (sequences, lengths) in tqdm(
    #         enumerate(train_loader),
    #         file=sys.stdout,
    #         total=len(train_loader),
    #         desc=f"Train Evaluation {epoch + 1}/{epochs}",
    #         unit=" batch",
    #     ):
    #         inputs, labels = model.get_inputs_labels(sequences)
    #         outputs = model(inputs, lengths)
    #         train_accuracy += model.get_balanced_accuracy(outputs, labels)
    #
    #     for batch_no, (sequences, lengths) in tqdm(
    #         enumerate(test_loader),
    #         file=sys.stdout,
    #         total=len(test_loader),
    #         desc=f"Test Evaluation {epoch + 1}/{epochs}",
    #         unit=" batch",
    #     ):
    #         inputs, labels = model.get_inputs_labels(sequences)
    #         outputs = model(inputs, lengths)
    #         test_accuracy += model.get_balanced_accuracy(outputs, labels)

    print(
        f"\nEpoch {epoch + 1}/{epochs}  |  "
        f"Average Train Loss: {train_total_loss / len(train_loader):.5f}  |  "
        f"Train Accuracy: {train_accuracy / len(train_loader):.5f}  |  ",
        # f"Test Accuracy: {test_accuracy / len(test_loader):.5f}\n\n"
    )
    scheduler.step()

# torch.save(model, 'LSTM_V2.pt')

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 1/6: 100%|██████████| 2666/2666 [11:47<00:00,  3.77 batch/s]

Epoch 1/6  |  Average Train Loss: 1.10915  |  Train Accuracy: 0.00000  |  
Adjusting learning rate of group 0 to 7.5000e-04.
Epoch 2/6: 100%|██████████| 2666/2666 [11:44<00:00,  3.79 batch/s]

Epoch 2/6  |  Average Train Loss: 0.84270  |  Train Accuracy: 0.00000  |  
Adjusting learning rate of group 0 to 5.6250e-04.
Epoch 3/6: 100%|██████████| 2666/2666 [11:41<00:00,  3.80 batch/s]

Epoch 3/6  |  Average Train Loss: 0.78582  |  Train Accuracy: 0.00000  |  
Adjusting learning rate of group 0 to 4.2188e-04.
Epoch 4/6: 100%|██████████| 2666/2666 [11:32<00:00,  3.85 batch/s]

Epoch 4/6  |  Average Train Loss: 0.75239  |  Train Accuracy: 0.00000  |  
Adjusting learning rate of group 0 to 3.1641e-04.
Epoch 5/6: 100%|██████████| 2666/2666 [11:27<00:00,  3.88 batch/s]

Epoch 5/6  |  Average Train Loss: 0.72128  |  Train Accuracy: 0.00000  |  
Adjusting learning rate of group 0 

In [None]:
state_dict = torch.load(
    "/fs01/home/afallah/odyssey/slurm/LSTM_V4_Weighted1.pt",
).state_dict()
model = BiLSTMModel(input_size, hidden_size, num_layers, output_size, dropout_rate).to(
    config.device,
)
model.load_state_dict(state_dict)

In [12]:
# y_train_pred = np.array([])
# y_train_labels = np.array([])
y_test_pred = np.array([])
y_test_labels = np.array([])
y_test_prob = np.array([])


model.eval()
with torch.no_grad():
    # for batch_no, (sequences, lengths) in tqdm(
    #     enumerate(train_loader),
    #     file=sys.stdout,
    #     total=len(train_loader),
    #     desc="Train Evaluation",
    #     unit=" batch",
    # ):
    #     inputs, labels = model.get_inputs_labels(sequences)
    #     outputs = model(inputs, lengths)
    #
    #     predictions = torch.round(sigmoid(outputs))
    #     predictions = predictions.detach().cpu().numpy()
    #     labels = labels.detach().cpu().numpy()
    #
    #     y_train_pred = np.append(y_train_pred, predictions)
    #     y_train_labels = np.append(y_train_labels, labels)

    for batch_no, (sequences, lengths) in tqdm(
        enumerate(test_loader),
        file=sys.stdout,
        total=len(test_loader),
        desc="Test Evaluation",
        unit=" batch",
    ):
        inputs, labels = model.get_inputs_labels(sequences)
        outputs = model(inputs, lengths).detach().cpu()

        predictions = torch.round(sigmoid(outputs))
        predictions = predictions.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        y_test_pred = np.append(y_test_pred, predictions)
        y_test_labels = np.append(y_test_labels, labels)
        y_test_prob = np.append(y_test_prob, sigmoid(outputs))


# all_data_pred = np.append(y_train_pred, y_test_pred)
# all_data_labels = np.append(y_train_labels, y_test_labels)

Test Evaluation: 100%|██████████| 47/47 [00:07<00:00,  6.44 batch/s]


In [13]:
np.save(f"{ROOT}/lstm_y_test_pred_two_weeks.npy", y_test_pred)
np.save(f"{ROOT}/lstm_y_test_pred_two_weeks_labels.npy", y_test_labels)
np.save(f"{ROOT}/lstm_y_test_pred_two_weeks_prob.npy", y_test_prob)

In [None]:
### ASSESS MODEL PERFORMANCE ###

# Balanced Accuracy
y_train_accuracy = balanced_accuracy_score(y_train_labels, y_train_pred)
y_test_accuracy = balanced_accuracy_score(y_test_labels, y_test_pred)
all_data_accuracy = balanced_accuracy_score(all_data_labels, all_data_pred)

# F1 Score
y_train_f1 = f1_score(y_train_labels, y_train_pred)
y_test_f1 = f1_score(y_test_labels, y_test_pred)
all_data_f1 = f1_score(all_data_labels, all_data_pred)

# Precision
y_train_precision = precision_score(y_train_labels, y_train_pred)
y_test_precision = precision_score(y_test_labels, y_test_pred)
all_data_precision = precision_score(all_data_labels, all_data_pred)

# Recall
y_train_recall = recall_score(y_train_labels, y_train_pred)
y_test_recall = recall_score(y_test_labels, y_test_pred)
all_data_recall = recall_score(all_data_labels, all_data_pred)

# AUROC
y_train_auroc = roc_auc_score(y_train_labels, y_train_pred)
y_test_auroc = roc_auc_score(y_test_labels, y_test_pred)
all_data_auroc = roc_auc_score(all_data_labels, all_data_pred)

# AUC-PR (Area Under the Precision-Recall Curve)
y_train_p, y_train_r, _ = precision_recall_curve(y_train_labels, y_train_pred)
y_test_p, y_test_r, _ = precision_recall_curve(y_test_labels, y_test_pred)
all_data_p, all_data_r, _ = precision_recall_curve(all_data_labels, all_data_pred)

y_train_auc_pr = auc(y_train_r, y_train_p)
y_test_auc_pr = auc(y_test_r, y_test_p)
all_data_auc_pr = auc(all_data_r, all_data_p)

# Average Precision Score (APS)
y_train_aps = average_precision_score(y_train_labels, y_train_pred)
y_test_aps = average_precision_score(y_test_labels, y_test_pred)
all_data_aps = average_precision_score(all_data_labels, all_data_pred)

# Print Metrics
print(
    f"Balanced Accuracy\nTrain: {y_train_accuracy:.5f}  |  Test: {y_test_accuracy:.5f}  |  All Data: {all_data_accuracy:.5f}\n",
)
print(
    f"F1 Score\nTrain: {y_train_f1:.5f}  |  Test: {y_test_f1:.5f}  |  All Data: {all_data_f1:.5f}\n",
)
print(
    f"Precision\nTrain: {y_train_precision:.5f}  |  Test: {y_test_precision:.5f}  |  All Data: {all_data_precision:.5f}\n",
)
print(
    f"Recall\nTrain: {y_train_recall:.5f}  |  Test: {y_test_recall:.5f}  |  All Data: {all_data_recall:.5f}\n",
)
print(
    f"AUROC\nTrain: {y_train_auroc:.5f}  |  Test: {y_test_auroc:.5f}  |  All Data: {all_data_auroc:.5f}\n",
)
print(
    f"AUC-PR\nTrain: {y_train_auc_pr:.5f}  |  Test: {y_test_auc_pr:.5f}  |  All Data: {all_data_auc_pr:.5f}\n",
)
print(
    f"Average Precision Score\nTrain: {y_train_aps:.5f}  |  Test: {y_test_aps:.5f}  |  All Data: {all_data_aps:.5f}\n",
)

# Plot ROC Curve
fpr_train, tpr_train, _ = roc_curve(y_train_labels, y_train_pred)
fpr_test, tpr_test, _ = roc_curve(y_test_labels, y_test_pred)
fpr_all_data, tpr_all_data, _ = roc_curve(all_data_labels, all_data_pred)

# Plot Information
plt.figure(figsize=(10, 7))
plt.plot(fpr_train, tpr_train, label=f"Train AUROC={y_train_auroc:.2f}")
plt.plot(fpr_test, tpr_test, label=f"Test AUROC={y_test_auroc:.2f}")
plt.plot(fpr_all_data, tpr_all_data, label=f"All Data AUROC={all_data_auroc:.2f}")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend()
plt.show()

In [None]:
"""
1. Evaluate 0.82 model, V4_Weighted_1 -> Still imbalanced
2. Train + Evaluate high lr model, V4_Weighted
3. Train + Evaluate medium lr model with changes, V4_Weighted_2
4. Train + Evaluate low lr model with changes, V4_Weighted_3
"""