In [1]:
import os

# The notebook lives one directory below the project root; step up once so that
# the `model` package (imported below) and relative checkpoint paths such as
# "task_a_model.pth" resolve. Guarded so re-running this cell does not keep
# climbing the directory tree.
if not os.path.isdir("model"):
    os.chdir("..")
print(os.getcwd())
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import relbench
from relbench.base import Table, Database, Dataset, EntityTask
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch
from torch import Tensor
import torch_geometric
import torch_frame

from torch_frame.config.text_embedder import TextEmbedderConfig
from typing import List, Optional
from sentence_transformers import SentenceTransformer

from torch.nn import L1Loss, BCEWithLogitsLoss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Seed once, before any loader or model is built, so runs are repeatable.
seed_everything(42)

# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
# Open the rel-trial database and the two tasks used for transfer
# (task A = study-outcome, task B = site-success), downloading if required.
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()
task_a, task_b = (
    get_task("rel-trial", task_name, download=True)
    for task_name in ("study-outcome", "site-success")
)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.85 seconds.


In [4]:
# Ask relbench for a per-column semantic-type proposal for every table in `db`;
# the result is passed to make_pkey_fkey_graph as `col_to_stype_dict` below.
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    """Callable text embedder built on the GloVe 6B-300d SentenceTransformer.

    An instance is handed to ``TextEmbedderConfig`` as ``text_embedder``: it takes
    a list of sentences and returns their embeddings as a torch tensor converted
    from the NumPy array produced by ``SentenceTransformer.encode``.
    """

    def __init__(self, device: Optional[torch.device] = None):
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Text columns are embedded with GloVe, 256 sentences per encode call.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

# Project data directory. Set the EMBEDDING_FUSION_DATA environment variable to
# point elsewhere instead of editing this machine-specific default.
root_dir = os.environ.get("EMBEDDING_FUSION_DATA", "/home/lingze/embedding_fusion/data")

# Build the heterogeneous primary-key/foreign-key graph from `db`; the
# materialized graph is stored under `cache_dir` for convenience.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(root_dir, "rel-trial_materialized_cache"),
)

  return self.fget.__get__(instance, owner)()


In [5]:
def generate_loader_dict(
    task: BaseTask,
    graph=None,
    num_neighbors: Optional[List[int]] = None,
    batch_size: int = 512,
) -> dict:
    """Build one temporal NeighborLoader per split ("train", "val", "test") for `task`.

    Args:
        task: relbench task whose split tables define the seed nodes, seed times
            and target transform (via `get_node_train_table_input`).
        graph: heterogeneous graph to sample from. Defaults to the notebook-level
            `data` produced by `make_pkey_fkey_graph` earlier in this notebook.
        num_neighbors: neighbours sampled per hop; its length is the sampling
            depth. Defaults to 128 neighbours for each of 2 hops.
        batch_size: seed nodes per mini-batch (default 512).

    Returns:
        dict mapping split name to its NeighborLoader. Only the train loader
        shuffles; the evaluation cells compare val/test predictions with the split
        tables row by row, which relies on that unshuffled order.
    """
    if graph is None:
        graph = data  # notebook-level graph built by make_pkey_fkey_graph
    if num_neighbors is None:
        num_neighbors = [128, 128]  # depth-2 subgraphs, 128 neighbours per node

    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            graph,
            num_neighbors=list(num_neighbors),
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict


taska_loader_dict = generate_loader_dict(task_a)
taskb_loader_dict = generate_loader_dict(task_b)

In [6]:
# Shared building blocks: both task models receive the *same* temporal and
# feature encoder instances, which is what lets weights learned on one task be
# transferred to the other. Construction order is left unchanged because module
# constructors may draw from the seeded RNG.
channels = 128
temporal_encoder = HeteroTemporalEncoder(
    node_types=[nt for nt in data.node_types if "time" in data[nt]],
    channels=channels,
)
feat_encoder = FeatureEncodingPart(
    data=data, node_to_col_stats=col_stats_dict, channels=channels
)

# Per-task node encoders; they differ only in dropout.
node_encoder_kwargs = dict(
    data=data, channels=channels, num_layers=1, normalization="layer_norm"
)
task_a_node_encoder = NodeRepresentationPart(**node_encoder_kwargs, dropout_prob=0.2)
task_b_node_encoder = NodeRepresentationPart(**node_encoder_kwargs, dropout_prob=0.3)

# Full per-task models with a single output each, again differing only in
# dropout and in which node encoder they own.
model_kwargs = dict(
    data=data,
    channels=channels,
    out_channels=1,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    temporal_encoder=temporal_encoder,
)
task_a_model = CompositeModel(**model_kwargs, dropout=0.2, node_encoder=task_a_node_encoder)
task_b_model = CompositeModel(**model_kwargs, dropout=0.3, node_encoder=task_b_node_encoder)

In [7]:
@torch.no_grad()
def valid(
    loader: NeighborLoader,
    model: torch.nn.Module,
    task: BaseTask,
    target_device: Optional[torch.device] = None,
) -> tuple[Tensor, Tensor]:
    """Run `model` over every batch of `loader`, collecting predictions and targets.

    Args:
        loader: loader over the split to evaluate (e.g. from generate_loader_dict).
        model: called as ``model(batch, task.entity_table)``; switched to eval mode.
        task: relbench task; its entity table's ``y`` supplies the targets.
        target_device: device each batch is moved to; defaults to the
            notebook-level ``device``.

    Returns:
        ``(predictions, targets)`` as CPU tensors concatenated over all batches.
        Predictions of shape ``(N, 1)`` are flattened to ``(N,)``.
    """
    if target_device is None:
        target_device = device
    model.eval()
    pred_list = []
    target_list = []
    for batch in loader:
        batch = batch.to(target_device)
        pred = model(batch, task.entity_table)
        # Flatten single-column outputs; leave 1-D or multi-column outputs as-is.
        if pred.dim() == 2 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
        target_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(target_list, dim=0)

@torch.no_grad()
def test(
    loader: NeighborLoader,
    model: torch.nn.Module,
    task: BaseTask,
    target_device: Optional[torch.device] = None,
) -> Tensor:
    """Run `model` over every batch of `loader` and return its raw outputs.

    No activation is applied: callers apply ``torch.sigmoid`` for the
    classification task and use the values directly for regression.

    Args:
        loader: loader over the split to predict (e.g. from generate_loader_dict).
        model: called as ``model(batch, task.entity_table)``; switched to eval mode.
        task: relbench task providing ``entity_table``.
        target_device: device each batch is moved to; defaults to the
            notebook-level ``device``.

    Returns:
        CPU tensor of predictions concatenated over all batches, in loader order.
        Predictions of shape ``(N, 1)`` are flattened to ``(N,)``.
    """
    if target_device is None:
        target_device = device
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(target_device)
        pred = model(batch, task.entity_table)
        # Flatten single-column outputs; leave 1-D or multi-column outputs as-is.
        if pred.dim() == 2 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [10]:
# Restore the pretrained task-A weights. `map_location` lets a checkpoint saved
# from a GPU run load on a CPU-only machine; the relative path resolves against
# the working directory set in the first cell.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load("task_a_model.pth", map_location=device)
task_a_model.load_state_dict(state_dict)
task_a_model.to(device)

CompositeModel(
  (gnn): HeteroGraphSAGE(
    (convs): ModuleList(
      (0-1): 2 x HeteroConv(num_relations=30)
    )
    (norms): ModuleList(
      (0-1): 2 x ModuleDict(
        (interventions): LayerNorm(128, affine=True, mode=node)
        (interventions_studies): LayerNorm(128, affine=True, mode=node)
        (facilities_studies): LayerNorm(128, affine=True, mode=node)
        (sponsors): LayerNorm(128, affine=True, mode=node)
        (eligibilities): LayerNorm(128, affine=True, mode=node)
        (reported_event_totals): LayerNorm(128, affine=True, mode=node)
        (designs): LayerNorm(128, affine=True, mode=node)
        (conditions_studies): LayerNorm(128, affine=True, mode=node)
        (drop_withdrawals): LayerNorm(128, affine=True, mode=node)
        (studies): LayerNorm(128, affine=True, mode=node)
        (outcome_analyses): LayerNorm(128, affine=True, mode=node)
        (sponsors_studies): LayerNorm(128, affine=True, mode=node)
        (outcomes): LayerNorm(128, affine

In [11]:
# test task A
# Score the loaded task-A checkpoint on the held-out test split. Model outputs
# are turned into probabilities with a sigmoid and thresholded at 0.5 for the
# label-based metrics; AUROC uses the probabilities directly.
test_probs = torch.sigmoid(test(taska_loader_dict["test"], task_a_model, task_a)).numpy()
test_labels = np.asarray(test_probs > 0.5, dtype=int)
test_table = task_a.get_table("test", mask_input_cols=False)
test_targets = test_table.df[task_a.target_col].to_numpy()
test_metrics = dict(
    auroc=roc_auc_score(test_targets, test_probs),
    accuracy=accuracy_score(test_targets, test_labels),
    precision=precision_score(test_targets, test_labels),
    recall=recall_score(test_targets, test_labels),
    f1score=f1_score(test_targets, test_labels),
)
test_metrics

{'auroc': 0.6789013596793916,
 'accuracy': 0.5878787878787879,
 'precision': 0.5868772782503038,
 'recall': 1.0,
 'f1score': 0.7396630934150077}

Transfer to the next task: freeze the shared `feat_encoder` and `temporal_encoder` so that only the task-specific parts of `task_b_model` are trained.

In [12]:
# Freeze the encoders shared by both task models so that training on task B
# cannot overwrite what they learned on task A.
for frozen_module in (feat_encoder, temporal_encoder):
    for param in frozen_module.parameters():
        param.requires_grad = False

In [15]:
# check whether the model is frozen in task_b_model
# for name, param in task_b_model.named_parameters():
#     if not param.requires_grad:
#         print(name)

Train task B (site-success) on top of the frozen shared encoders.

In [16]:
# Task B is a regression target: optimise L1 loss and keep the checkpoint with
# the lowest validation MAE. Only parameters that still require gradients are
# handed to Adam, so the frozen shared encoders are explicitly excluded.
optimizer = torch.optim.Adam(
    [param for param in task_b_model.parameters() if param.requires_grad], lr=0.002
)
epochs = 20
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

In [17]:
import itertools

# Fine-tune task B, keeping a copy of the state_dict from the epoch with the
# best validation metric.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_b_model.to(device)
best_epoch = 0
# Budget of mini-batches per epoch. This is NOT validation-based early stopping:
# only the first `max_train_batches` batches of the shuffled train loader are
# used in each epoch.
max_train_batches = 20

# Validation targets do not change between epochs, so read them once.
val_target = task_b.get_table("val").df[task_b.target_col].to_numpy()

# train
for epoch in range(1, epochs + 1):
    task_b_model.train()
    loss_accum = count_accum = 0
    train_loader = taskb_loader_dict["train"]
    # islice stops after exactly `max_train_batches` batches (no extra batch is
    # sampled and then discarded); `total` keeps the progress bar accurate.
    for batch in tqdm(
        itertools.islice(train_loader, max_train_batches),
        total=min(max_train_batches, len(train_loader)),
    ):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_b_model(batch, task_b.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_b.entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Weight by batch size so train_loss is a per-example mean.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / max(count_accum, 1)  # guard an empty train loader
    # Regression outputs are used as-is (no sigmoid).
    val_pred = test(taskb_loader_dict["val"], task_b_model, task_b).numpy()
    val_metrics = {
        "mae": mean_absolute_error(val_target, val_pred),
        "r2": r2_score(val_target, val_pred),
        "rmse": root_mean_squared_error(val_target, val_pred),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    improved = (
        val_metrics[tune_metric] > best_val_metric
        if higher_is_better
        else val_metrics[tune_metric] < best_val_metric
    )
    if improved:
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_b_model.state_dict())

best_epoch

  0%|          | 0/296 [00:00<?, ?it/s]

  7%|▋         | 20/296 [00:02<00:32,  8.61it/s]


Epoch: 01, Train loss: 0.5536862790584565, Val metrics: {'mae': 0.4670303391773532, 'r2': -0.38079538254706535, 'rmse': 0.5613250623192587}


  7%|▋         | 20/296 [00:02<00:31,  8.78it/s]


Epoch: 02, Train loss: 0.45024473965168, Val metrics: {'mae': 0.4661111583558287, 'r2': -0.26958255503306106, 'rmse': 0.538245296443496}


  7%|▋         | 20/296 [00:02<00:33,  8.25it/s]


Epoch: 03, Train loss: 0.4451788246631622, Val metrics: {'mae': 0.4669315650630681, 'r2': -0.48348847423170827, 'rmse': 0.5818243037099271}


  7%|▋         | 20/296 [00:02<00:32,  8.56it/s]


Epoch: 04, Train loss: 0.41880103051662443, Val metrics: {'mae': 0.41400271050683846, 'r2': -0.1293648983073823, 'rmse': 0.5076529484174325}


  7%|▋         | 20/296 [00:01<00:27, 10.01it/s]


Epoch: 05, Train loss: 0.3810374766588211, Val metrics: {'mae': 0.42254258938368433, 'r2': -0.3294033846577986, 'rmse': 0.5507799981981856}


  7%|▋         | 20/296 [00:02<00:32,  8.56it/s]


Epoch: 06, Train loss: 0.3579500004649162, Val metrics: {'mae': 0.41120197855205604, 'r2': -0.30060448276381346, 'rmse': 0.5447815531243743}


  7%|▋         | 20/296 [00:01<00:27, 10.04it/s]


Epoch: 07, Train loss: 0.35562915056943895, Val metrics: {'mae': 0.42772684085532253, 'r2': -0.4230109835497484, 'rmse': 0.5698412661589188}


  7%|▋         | 20/296 [00:01<00:27, 10.15it/s]


Epoch: 08, Train loss: 0.3374470114707947, Val metrics: {'mae': 0.44092946729540966, 'r2': -0.5681434475730074, 'rmse': 0.5981948421486004}


  7%|▋         | 20/296 [00:01<00:27, 10.10it/s]


Epoch: 09, Train loss: 0.32877634167671205, Val metrics: {'mae': 0.43324444231886056, 'r2': -0.46738179516942546, 'rmse': 0.5786571631618113}


  7%|▋         | 20/296 [00:02<00:28,  9.81it/s]


Epoch: 10, Train loss: 0.32383287996053695, Val metrics: {'mae': 0.43045193042991653, 'r2': -0.4798654928277708, 'rmse': 0.5811134025983531}


  7%|▋         | 20/296 [00:01<00:26, 10.27it/s]


Epoch: 11, Train loss: 0.32107689082622526, Val metrics: {'mae': 0.44113570277508013, 'r2': -0.5535773375700734, 'rmse': 0.5954101158791022}


  7%|▋         | 20/296 [00:01<00:27, 10.12it/s]


Epoch: 12, Train loss: 0.31160512268543245, Val metrics: {'mae': 0.453359223537947, 'r2': -0.6399061731582747, 'rmse': 0.6117292833285612}


  7%|▋         | 20/296 [00:02<00:31,  8.66it/s]


Epoch: 13, Train loss: 0.3099339172244072, Val metrics: {'mae': 0.4528615220232609, 'r2': -0.6003115625508668, 'rmse': 0.6042992312488685}


  7%|▋         | 20/296 [00:02<00:31,  8.65it/s]


Epoch: 14, Train loss: 0.29516699761152265, Val metrics: {'mae': 0.43444881092666976, 'r2': -0.5241348213868124, 'rmse': 0.5897411915106154}


  7%|▋         | 20/296 [00:01<00:26, 10.52it/s]


Epoch: 15, Train loss: 0.2993515580892563, Val metrics: {'mae': 0.4458842256669376, 'r2': -0.5833317776610727, 'rmse': 0.601084784048828}


  7%|▋         | 20/296 [00:01<00:25, 10.69it/s]


Epoch: 16, Train loss: 0.28868870362639426, Val metrics: {'mae': 0.45690438143991075, 'r2': -0.6729927606309496, 'rmse': 0.6178695618413682}


  7%|▋         | 20/296 [00:01<00:25, 10.62it/s]


Epoch: 17, Train loss: 0.2758818969130516, Val metrics: {'mae': 0.4592558245716642, 'r2': -0.7207069873651701, 'rmse': 0.626618526981728}


  7%|▋         | 20/296 [00:01<00:26, 10.61it/s]


Epoch: 18, Train loss: 0.27912004888057707, Val metrics: {'mae': 0.44715835399006526, 'r2': -0.5982601712242293, 'rmse': 0.6039117905296998}


  7%|▋         | 20/296 [00:01<00:26, 10.32it/s]


Epoch: 19, Train loss: 0.2834742233157158, Val metrics: {'mae': 0.44384483381062534, 'r2': -0.5849991529565042, 'rmse': 0.6014011960018354}


  7%|▋         | 20/296 [00:01<00:25, 10.74it/s]


Epoch: 20, Train loss: 0.2772685319185257, Val metrics: {'mae': 0.44986738136243254, 'r2': -0.5758220379729724, 'rmse': 0.5996576177671246}


6

In [18]:
# test
# Restore the best-validation checkpoint from the loop above and score task B
# on the held-out test split (regression outputs are used as-is).
task_b_model.load_state_dict(state_dict)
y_pred = test(taskb_loader_dict["test"], task_b_model, task_b).numpy()
y_true = (
    task_b.get_table("test", mask_input_cols=False)
    .df[task_b.target_col]
    .to_numpy()
)
test_metrics = {
    "mae": mean_absolute_error(y_true, y_pred),
    "r2": r2_score(y_true, y_pred),
    "rmse": root_mean_squared_error(y_true, y_pred),
}
test_metrics

{'mae': 0.4060025686323995,
 'r2': -0.32183112975279005,
 'rmse': 0.5532634466844648}

Reset both models, load the saved task-B checkpoint, and fine-tune on task A while the shared encoders stay frozen.

In [31]:
# Re-initialise both task models. They share feat_encoder/temporal_encoder, which
# are presumably reset as part of each model; the next cell restores weights
# from disk.
for model_to_reset in (task_a_model, task_b_model):
    model_to_reset.reset_parameters()

In [32]:
# Load the saved task-B checkpoint. Because the encoders are shared objects,
# this presumably also restores the feat_encoder/temporal_encoder weights used
# by task_a_model. `map_location` lets a GPU-saved checkpoint load on CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load("task_b_model.pth", map_location=device)
task_b_model.load_state_dict(state_dict)

<All keys matched successfully>

In [33]:
# test
# Sanity-check the freshly loaded task-B checkpoint on the test split
# (regression outputs are used as-is).
y_pred = test(taskb_loader_dict["test"], task_b_model, task_b).numpy()
y_true = (
    task_b.get_table("test", mask_input_cols=False)
    .df[task_b.target_col]
    .to_numpy()
)
test_metrics = {
    "mae": mean_absolute_error(y_true, y_pred),
    "r2": r2_score(y_true, y_pred),
    "rmse": root_mean_squared_error(y_true, y_pred),
}
test_metrics

{'mae': 0.4100478495823158,
 'r2': -0.5215441797977363,
 'rmse': 0.5935895870078363}

In [34]:
# check which parameters are frozen in task_a_model
# for name, param in task_a_model.named_parameters():
#     if not param.requires_grad:
#         print(name)

In [35]:
# Train Task A
# Task A is binary classification: BCE on logits, with model selection by
# validation AUROC. Only parameters that still require gradients are handed to
# Adam, so the frozen shared encoders are explicitly excluded.
optimizer = torch.optim.Adam(
    [param for param in task_a_model.parameters() if param.requires_grad], lr=0.005
)
epochs = 15
loss_fn = BCEWithLogitsLoss()
tune_metric = "auroc"
higher_is_better = True

In [36]:
# Fine-tune task A, keeping a copy of the state_dict from the epoch with the
# best validation metric.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_a_model.to(device)
best_epoch = 0

# Validation targets do not change between epochs, so read them once.
val_target = task_a.get_table("val").df[task_a.target_col].to_numpy()

for epoch in range(1, epochs + 1):
    task_a_model.train()
    loss_accum = count_accum = 0
    for batch in tqdm(taska_loader_dict["train"]):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_a_model(batch, task_a.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Weight by batch size so train_loss is a per-example mean.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / max(count_accum, 1)  # guard an empty train loader
    # Sigmoid turns the logits into probabilities; 0.5 gives hard labels.
    val_probs = torch.sigmoid(test(taska_loader_dict["val"], task_a_model, task_a)).numpy()
    val_labels = (val_probs > 0.5).astype(int)
    val_metrics = {
        "auroc": roc_auc_score(val_target, val_probs),
        "accuracy": accuracy_score(val_target, val_labels),
        "precision": precision_score(val_target, val_labels),
        "recall": recall_score(val_target, val_labels),
        "f1": f1_score(val_target, val_labels),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    improved = (
        val_metrics[tune_metric] > best_val_metric
        if higher_is_better
        else val_metrics[tune_metric] < best_val_metric
    )
    if improved:
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_a_model.state_dict())

best_epoch

  4%|▍         | 1/24 [00:00<00:03,  6.17it/s]

100%|██████████| 24/24 [00:03<00:00,  6.64it/s]


Epoch: 01, Train loss: 0.6657383552546976, Val metrics: {'auroc': 0.6363636363636364, 'accuracy': 0.6229166666666667, 'precision': 0.6478454680534919, 'recall': 0.7771836007130125, 'f1': 0.706645056726094}


100%|██████████| 24/24 [00:03<00:00,  6.68it/s]


Epoch: 02, Train loss: 0.6018843909929609, Val metrics: {'auroc': 0.6720008577593717, 'accuracy': 0.6385416666666667, 'precision': 0.6573529411764706, 'recall': 0.7967914438502673, 'f1': 0.7203867848509267}


100%|██████████| 24/24 [00:03<00:00,  7.02it/s]


Epoch: 03, Train loss: 0.5897108795882583, Val metrics: {'auroc': 0.6694499171279358, 'accuracy': 0.6447916666666667, 'precision': 0.6498637602179836, 'recall': 0.8502673796791443, 'f1': 0.7366795366795367}


100%|██████████| 24/24 [00:03<00:00,  7.64it/s]


Epoch: 04, Train loss: 0.5781220834772448, Val metrics: {'auroc': 0.6626771920889568, 'accuracy': 0.615625, 'precision': 0.6371428571428571, 'recall': 0.7950089126559715, 'f1': 0.7073750991276765}


100%|██████████| 24/24 [00:03<00:00,  7.66it/s]


Epoch: 05, Train loss: 0.5663542633496346, Val metrics: {'auroc': 0.665022627870925, 'accuracy': 0.6395833333333333, 'precision': 0.6431424766977364, 'recall': 0.8609625668449198, 'f1': 0.7362804878048781}


100%|██████████| 24/24 [00:03<00:00,  7.44it/s]


Epoch: 06, Train loss: 0.557119338407464, Val metrics: {'auroc': 0.663767261290481, 'accuracy': 0.6145833333333334, 'precision': 0.6423248882265276, 'recall': 0.768270944741533, 'f1': 0.6996753246753247}


100%|██████████| 24/24 [00:03<00:00,  7.59it/s]


Epoch: 07, Train loss: 0.5531355584981701, Val metrics: {'auroc': 0.6667292116208525, 'accuracy': 0.63125, 'precision': 0.6716417910447762, 'recall': 0.7219251336898396, 'f1': 0.6958762886597938}


100%|██████████| 24/24 [00:03<00:00,  6.98it/s]


Epoch: 08, Train loss: 0.5408712819713105, Val metrics: {'auroc': 0.670146846617435, 'accuracy': 0.6416666666666667, 'precision': 0.6472184531886025, 'recall': 0.8502673796791443, 'f1': 0.7349768875192604}


100%|██████████| 24/24 [00:03<00:00,  7.61it/s]


Epoch: 09, Train loss: 0.5303818412089082, Val metrics: {'auroc': 0.659992226555694, 'accuracy': 0.6322916666666667, 'precision': 0.676271186440678, 'recall': 0.7112299465240641, 'f1': 0.6933101650738488}


100%|██████████| 24/24 [00:03<00:00,  7.52it/s]


Epoch: 10, Train loss: 0.5214789944409092, Val metrics: {'auroc': 0.6278664575878199, 'accuracy': 0.603125, 'precision': 0.6505016722408027, 'recall': 0.6934046345811051, 'f1': 0.6712683347713546}


100%|██████████| 24/24 [00:03<00:00,  7.42it/s]


Epoch: 11, Train loss: 0.5181294286352447, Val metrics: {'auroc': 0.6630882017878922, 'accuracy': 0.6239583333333333, 'precision': 0.6773049645390071, 'recall': 0.6809269162210339, 'f1': 0.6791111111111111}


100%|██████████| 24/24 [00:03<00:00,  7.57it/s]


Epoch: 12, Train loss: 0.5037265925262697, Val metrics: {'auroc': 0.6608276484437474, 'accuracy': 0.6197916666666666, 'precision': 0.6380281690140845, 'recall': 0.8074866310160428, 'f1': 0.7128245476003147}


100%|██████████| 24/24 [00:03<00:00,  7.75it/s]


Epoch: 13, Train loss: 0.49312296955510343, Val metrics: {'auroc': 0.6622304424162009, 'accuracy': 0.615625, 'precision': 0.6415929203539823, 'recall': 0.7754010695187166, 'f1': 0.7021791767554479}


100%|██████████| 24/24 [00:03<00:00,  6.17it/s]


Epoch: 14, Train loss: 0.49578280023123517, Val metrics: {'auroc': 0.6600100965426043, 'accuracy': 0.6135416666666667, 'precision': 0.6409495548961425, 'recall': 0.7700534759358288, 'f1': 0.6995951417004048}


100%|██████████| 24/24 [00:03<00:00,  6.68it/s]


Epoch: 15, Train loss: 0.47899952708770377, Val metrics: {'auroc': 0.6566237340231149, 'accuracy': 0.6208333333333333, 'precision': 0.6508422664624809, 'recall': 0.7575757575757576, 'f1': 0.700164744645799}


2

In [37]:
# test task A
# Restore the best-validation checkpoint from the fine-tuning loop and score
# task A on the held-out test split. Probabilities are thresholded at 0.5 for
# the label-based metrics; AUROC uses the probabilities directly.
task_a_model.load_state_dict(state_dict)
test_probs = torch.sigmoid(test(taska_loader_dict["test"], task_a_model, task_a)).numpy()
test_labels = np.asarray(test_probs > 0.5, dtype=int)
test_table = task_a.get_table("test", mask_input_cols=False)
test_targets = test_table.df[task_a.target_col].to_numpy()
test_metrics = dict(
    auroc=roc_auc_score(test_targets, test_probs),
    accuracy=accuracy_score(test_targets, test_labels),
    precision=precision_score(test_targets, test_labels),
    recall=recall_score(test_targets, test_labels),
    f1score=f1_score(test_targets, test_labels),
)
test_metrics

{'auroc': 0.7135955831608005,
 'accuracy': 0.6593939393939394,
 'precision': 0.6778169014084507,
 'recall': 0.7971014492753623,
 'f1score': 0.7326355851569933}