In [1]:
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = get_dataset(name = "rel-trial", download = True)
db = dataset.get_db()
cache_path = "data/rel-trial-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 8.01 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# get infer_type in cache
type_path = os.path.join(cache_path,"col_type_dict.pkl")
col_type_dict = pickle.load(open(type_path, "rb"))
len(col_type_dict)

# add "compress_text" in each table in case 
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            # "all-MiniLM-L12-v2",
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table:Table) -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature."""
    if table.pkey_col is not None:
        if table.pkey_col in col_to_stype:
            col_to_stype.pop(table.pkey_col)
    for fkey in table.fkey_col_to_pkey_table.keys():
        if fkey in col_to_stype:
            col_to_stype.pop(fkey)

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds)."""
    assert ser.dtype in [np.dtype("datetime64[s]"), np.dtype("datetime64[ns]")]
    unix_time = ser.astype("int64").values
    if ser.dtype == np.dtype("datetime64[ns]"):
        unix_time //= 10**9
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph

# start build graph
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    col_to_stype = col_type_dict[table_name]
    
    # remove pkey, fkey
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )
    
    # Add edges normal edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Filter out dangling foreign keys
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))
        
        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]
        
        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)
    
data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start to fine-train on the task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
task_a = get_task("rel-trial", "study-outcome", download = True)
entity_table = task_a.entity_table

In [9]:
def generate_loader_dict(task: BaseTask, data:HeteroData) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128 for i in range(2)
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    pred_hat_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        pred_hat_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(pred_hat_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [11]:
# add the additional edges
from utils.util import load_np_dict
edge_dict = load_np_dict("./edges/rel-trail-edges.npz")

In [12]:
for edge_name, edge_np in edge_dict.items():
    src_table, dst_table = edge_name.split('-')[0], edge_name.split('-')[1]
    edge_index = torch.from_numpy(edge_np.astype(int)).t()
    # [2, edge_num]
    edge_type = (src_table, f"appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)
data.validate()

True

In [13]:
# read the pre-extracted sample
from utils.util import load_np_dict
sample_dict = load_np_dict("./samples/rel-trail-samples.npz")
sample_dict.keys()

dict_keys(['interventions', 'sponsors', 'eligibilities', 'designs', 'studies', 'conditions'])

In [14]:
from relbench.base import Database
def neighborsample_batch(
    db: Database,
    entity_table: str,
    node_idxs: np.ndarray,
    num_neighbors: List[int] = [128,128],
):
    # node_idxs: [n]
    nodes = (entity_table, torch.from_numpy(node_idxs))
    n = node_idxs.shape[0]
    input_time = torch.from_numpy(
        to_unix_time(pd.Series([db.max_timestamp] * n)))

    if db.table_dict[entity_table].time_col:
        time_col = db.table_dict[entity_table].time_col
        time_values = db.table_dict[entity_table].df[time_col].loc[node_idxs.tolist(
        )]
        input_time = torch.from_numpy(to_unix_time(time_values))

    loader = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        input_nodes=nodes,
        time_attr = "time",
        input_time=input_time,
        batch_size=n,
        temporal_strategy="uniform",
        shuffle=False,
        disjoint=True,
        num_workers=0,
        persistent_workers=False,
    )
    return next(iter(loader))
    

In [15]:
# construct bottom model
channels = 64
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [16]:
from model.utils import InfoNCE
lr = 5e-4
negative_num = 20
negative_sample_pool_size = 512
temprature = 0.01
net.reset_parameters()
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
epoches = 20
early_restart_steps = 20
loss_fn = InfoNCE(temperature=temprature, negative_mode='paired')

In [17]:
net.to(device)
for epoch in range(1, epoches + 1):
    net.train()
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    for sample_table, sample_np in sample_dict.items():
        loss_accum = count_accum = 0
        shuffle_sample_np = sample_np[np.random.permutation(len(sample_np))]
        anchor_nodes_np = shuffle_sample_np[:, 0]
        positive_pool_np = shuffle_sample_np[:, 1:]
        # choose the positive samples
        n = sample_np.shape[0]
        batch_size = 256
        for batch_idx in tqdm(range(0, n, batch_size), leave=False):
            anchor_nodes = anchor_nodes_np[batch_idx:batch_idx+batch_size]
            positive_pool_batch_np = positive_pool_np[batch_idx:batch_idx+batch_size]

            positive_nodes = []
            # random select the positive samples
            for row in positive_pool_batch_np:
                valid = row[row != -1]
                random_choice = np.random.choice(valid, 1)[0]
                positive_nodes.append(random_choice)

            positive_nodes = np.array(positive_nodes)
            B = positive_nodes.size
            # random select the negative sample, negative ratio is 1:20
            # for one batch, we still extract batch_size negative samples
            # for each positive-negative pair, we extract 20 from this 256 batch as negative samples
            excluded = set(positive_nodes.tolist()).union(
                set(anchor_nodes.tolist()))
            negative_candidates = list(set(range(n)) - excluded)
            sample_size = min(negative_sample_pool_size, n - 2 * batch_size)
            negative_nodes = np.random.choice(
                negative_candidates, size=sample_size, replace=False)
            # [batch_size]

            # neighbor hood loader
            anchor_nodes_batch = neighborsample_batch(
                db, sample_table, anchor_nodes)
            positive_nodes_batch = neighborsample_batch(
                db, sample_table, positive_nodes)
            negative_nodes_batch = neighborsample_batch(
                db, sample_table, negative_nodes)

            optimizer.zero_grad()

            anchor_nodes_batch, positive_nodes_batch, negative_nodes_batch = \
                anchor_nodes_batch.to(device), positive_nodes_batch.to(
                    device), negative_nodes_batch.to(device)

            anchor_nodes_embedding = net.get_node_embedding(
                anchor_nodes_batch, sample_table)[sample_table][:B]
            positive_nodes_embedding = net.get_node_embedding(
                positive_nodes_batch, sample_table)[sample_table][:B]
            negative_nodes_embedding = net.get_node_embedding(
                negative_nodes_batch, sample_table)[sample_table][:negative_sample_pool_size]

            # negative_nodes_embedding = net.get_node_embedding(negative_nodes_batch, sample_table)[sample_table][:B]
            # [B, D]

            negative_indices = torch.stack([torch.randperm(negative_sample_pool_size)[
                                           :negative_num] for _ in range(B)]).to(device)
            negative_nodes_embedding = negative_nodes_embedding[negative_indices]
            # [B, negative_num, D]

            loss = loss_fn(anchor_nodes_embedding,
                           positive_nodes_embedding, negative_nodes_embedding)
            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().item()
            count_accum += B

        train_loss = loss_accum / count_accum

        print(f"====> In {sample_table}, Train loss: {train_loss}")

******************************<Epoch: 01>******************************


  0%|          | 0/4 [00:00<?, ?it/s]

                                             

====> In interventions, Train loss: 0.062238225488529426


                                               

====> In sponsors, Train loss: 0.03788315336574823


                                               

====> In eligibilities, Train loss: 0.04196471298903837


                                               

====> In designs, Train loss: 0.04364552577870004


                                               

====> In studies, Train loss: 0.02888180746511677


                                             

====> In conditions, Train loss: 0.06378715051147367
******************************<Epoch: 02>******************************


                                             

====> In interventions, Train loss: 0.0391655974042764


                                               

====> In sponsors, Train loss: 0.01864755450015985


                                               

====> In eligibilities, Train loss: 0.024557861732412468


                                               

====> In designs, Train loss: 0.023279800191021605


                                               

====> In studies, Train loss: 0.0156848581551436


                                             

====> In conditions, Train loss: 0.03460731021801814
******************************<Epoch: 03>******************************


                                             

====> In interventions, Train loss: 0.031730862705952934


                                               

====> In sponsors, Train loss: 0.01524015001570858


                                               

====> In eligibilities, Train loss: 0.01632812740045683


                                               

====> In designs, Train loss: 0.016565733627984986


                                               

====> In studies, Train loss: 0.00952724252632041


                                             

====> In conditions, Train loss: 0.027529434941198572
******************************<Epoch: 04>******************************


                                             

====> In interventions, Train loss: 0.025704143916788064


                                               

====> In sponsors, Train loss: 0.013695520836952553


                                               

====> In eligibilities, Train loss: 0.009656839608093286


                                               

====> In designs, Train loss: 0.0061946647479230125


                                               

====> In studies, Train loss: 0.0031405676216317837


                                             

====> In conditions, Train loss: 0.024628428884246333
******************************<Epoch: 05>******************************


                                             

====> In interventions, Train loss: 0.025394053138044617


                                               

====> In sponsors, Train loss: 0.012902770049677229


                                               

====> In eligibilities, Train loss: 0.0028193266886585165


                                               

====> In designs, Train loss: 0.0018383530262332634


                                               

====> In studies, Train loss: 0.0016183833492457778


                                             

====> In conditions, Train loss: 0.023859276599017723
******************************<Epoch: 06>******************************


                                             

====> In interventions, Train loss: 0.021717425071299454


                                               

====> In sponsors, Train loss: 0.01208125776835712


                                               

====> In eligibilities, Train loss: 0.0015788547895080065


                                               

====> In designs, Train loss: 0.0012537793800134786


                                               

====> In studies, Train loss: 0.0012949057162293248


                                             

====> In conditions, Train loss: 0.02080405134710557
******************************<Epoch: 07>******************************


                                             

====> In interventions, Train loss: 0.021524788615361255


                                               

====> In sponsors, Train loss: 0.011355800957682167


                                               

====> In eligibilities, Train loss: 0.0009664445636933801


                                               

====> In designs, Train loss: 0.001062937080860138


                                               

====> In studies, Train loss: 0.0014687424547590276


                                             

====> In conditions, Train loss: 0.019473901430031627
******************************<Epoch: 08>******************************


                                             

====> In interventions, Train loss: 0.019673824613248833


                                               

====> In sponsors, Train loss: 0.01078202402694524


                                               

====> In eligibilities, Train loss: 0.0011012016745850246


                                               

====> In designs, Train loss: 0.0008631700077937954


                                               

====> In studies, Train loss: 0.0013275404252925174


                                             

====> In conditions, Train loss: 0.01856587298380769
******************************<Epoch: 09>******************************


                                             

====> In interventions, Train loss: 0.018570683962811054


                                               

====> In sponsors, Train loss: 0.010406878622336656


                                               

====> In eligibilities, Train loss: 0.0009634974717328176


                                               

====> In designs, Train loss: 0.0007515948715826009


                                               

====> In studies, Train loss: 0.001319632879878461


                                             

====> In conditions, Train loss: 0.01804514680117988
******************************<Epoch: 10>******************************


                                             

====> In interventions, Train loss: 0.018710287886623204


                                               

====> In sponsors, Train loss: 0.009778061534837077


                                               

====> In eligibilities, Train loss: 0.0006986238071354743


                                               

====> In designs, Train loss: 0.000680764713293354


                                               

====> In studies, Train loss: 0.0012997585854908223


                                             

====> In conditions, Train loss: 0.018075867558186378
******************************<Epoch: 11>******************************


                                             

====> In interventions, Train loss: 0.017767976169513474


                                               

====> In sponsors, Train loss: 0.00941772469404928


                                               

====> In eligibilities, Train loss: 0.0007967145680567673


                                               

====> In designs, Train loss: 0.0006336165067033479


                                               

====> In studies, Train loss: 0.001151041461494916


                                             

====> In conditions, Train loss: 0.01767837441453941
******************************<Epoch: 12>******************************


                                             

====> In interventions, Train loss: 0.016658049526481193


                                               

====> In sponsors, Train loss: 0.008751378332764168


                                               

====> In eligibilities, Train loss: 0.0008700251016976699


                                               

====> In designs, Train loss: 0.0007002147152119835


                                               

====> In studies, Train loss: 0.0011366646830317467


                                             

====> In conditions, Train loss: 0.017512676438705292
******************************<Epoch: 13>******************************


                                             

====> In interventions, Train loss: 0.017623140972481598


                                               

====> In sponsors, Train loss: 0.008527663774988726


                                               

====> In eligibilities, Train loss: 0.000606059821697183


                                               

====> In designs, Train loss: 0.0007075952503505169


                                               

====> In studies, Train loss: 0.0009912110556893063


                                             

====> In conditions, Train loss: 0.016573648988695858
******************************<Epoch: 14>******************************


                                             

====> In interventions, Train loss: 0.016811940994165754


                                               

====> In sponsors, Train loss: 0.007981601975884673


                                               

====> In eligibilities, Train loss: 0.0005705205911328543


                                               

====> In designs, Train loss: 0.0007040010152647159


                                               

====> In studies, Train loss: 0.0012163012912848163


                                             

====> In conditions, Train loss: 0.016163392467072966
******************************<Epoch: 15>******************************


                                             

====> In interventions, Train loss: 0.016744648486262205


                                               

====> In sponsors, Train loss: 0.00790906959472726


                                               

====> In eligibilities, Train loss: 0.0005886197757141889


                                               

====> In designs, Train loss: 0.0006747069656098849


                                               

====> In studies, Train loss: 0.0009869442682346248


                                             

====> In conditions, Train loss: 0.016376922642661572
******************************<Epoch: 16>******************************


                                             

====> In interventions, Train loss: 0.01571742643092154


                                               

====> In sponsors, Train loss: 0.0074534168224054086


                                               

====> In eligibilities, Train loss: 0.00048562384833776535


                                               

====> In designs, Train loss: 0.0005909282626621675


                                               

====> In studies, Train loss: 0.0010269254824809802


                                             

====> In conditions, Train loss: 0.016226033048504954
******************************<Epoch: 17>******************************


                                             

====> In interventions, Train loss: 0.016449072763910742


                                               

====> In sponsors, Train loss: 0.007185626066257239


                                               

====> In eligibilities, Train loss: 0.0004889796305162176


                                               

====> In designs, Train loss: 0.0005459495763850692


                                               

====> In studies, Train loss: 0.0011367124922649074


                                             

====> In conditions, Train loss: 0.015860129357118436
******************************<Epoch: 18>******************************


                                             

====> In interventions, Train loss: 0.014715813621026587


                                               

====> In sponsors, Train loss: 0.007021475691047796


                                               

====> In eligibilities, Train loss: 0.0006199043733151932


                                               

====> In designs, Train loss: 0.0005360900643187882


                                               

====> In studies, Train loss: 0.0008975079654161865


                                             

====> In conditions, Train loss: 0.0157164600466287
******************************<Epoch: 19>******************************


                                             

====> In interventions, Train loss: 0.01574979107140586


                                               

====> In sponsors, Train loss: 0.006867866235478526


                                               

====> In eligibilities, Train loss: 0.000461334376833325


                                               

====> In designs, Train loss: 0.0005054765562183105


                                               

====> In studies, Train loss: 0.0008103459934832586


                                             

====> In conditions, Train loss: 0.015730470946607816
******************************<Epoch: 20>******************************


                                             

====> In interventions, Train loss: 0.01580954658333804


                                               

====> In sponsors, Train loss: 0.006476289242551388


                                               

====> In eligibilities, Train loss: 0.0005597110466814002


                                               

====> In designs, Train loss: 0.00043464797555200206


                                               

====> In studies, Train loss: 0.0007309331328870055


                                             

====> In conditions, Train loss: 0.01573642315178123




In [18]:
# pre-trained state
# record
pre_trained_state = copy.deepcopy(net.state_dict())

In [19]:
# construct a new bottom model
channels = 64
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [20]:
# training
task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.0005
epoches = 20
loss_fn = BCEWithLogitsLoss()
tune_metric = "auroc"
higher_is_better = True
early_stop = 10

In [21]:


# reload the pre-trained state
net.load_state_dict(pre_trained_state)

flag = True
# first freeze the bottom model
for param in net.feature_encoder.parameters():
    param.requires_grad = flag
for param in net.node_encoder.parameters():
    param.requires_grad = flag
for param in net.temporal_encoder.parameters():
    param.requires_grad = flag
for param in net.gnn.parameters():
    param.requires_grad = flag
    
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)

In [22]:
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
for epoch in range(1, epoches + 1):
    net.train()
    cnt = 0
    loss_accum = count_accum = 0
    for batch in tqdm(task_loader_dict["train"], leave=False):
        cnt += 1
        if cnt > early_stop:
            break
        
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)
    
    train_loss = loss_accum / count_accum
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = torch.sigmoid(val_logits).numpy()
    
    val_pred = (val_logits > 0.5).astype(int)
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_metrics = {
            "auroc": roc_auc_score(val_pred_hat, val_logits),
        "accuracy": accuracy_score(val_pred_hat, val_pred),
        "precision": precision_score(val_pred_hat, val_pred),
        "recall": recall_score(val_pred_hat, val_pred),
        "f1": f1_score(val_pred_hat, val_pred),
    }
    
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    print(f", Train loss: {train_loss}, Val metrics: {val_metrics}")

    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())

# print the best epoch
best_epoch

                                               

******************************<Epoch: 01>******************************
, Train loss: 0.6708765923976898, Val metrics: {'auroc': 0.5534245596165102, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


                                               

******************************<Epoch: 02>******************************
, Train loss: 0.6522306442260742, Val metrics: {'auroc': 0.6025759586131103, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


                                               

******************************<Epoch: 03>******************************
, Train loss: 0.6387920558452607, Val metrics: {'auroc': 0.6201555582360536, 'accuracy': 0.5854166666666667, 'precision': 0.5851619644723093, 'recall': 0.9982174688057041, 'f1': 0.7378129117259552}


                                               

******************************<Epoch: 04>******************************
, Train loss: 0.6254864037036896, Val metrics: {'auroc': 0.6421222396454596, 'accuracy': 0.6, 'precision': 0.5956756756756757, 'recall': 0.982174688057041, 'f1': 0.7415881561238223}


                                               

******************************<Epoch: 05>******************************
, Train loss: 0.6237622559070587, Val metrics: {'auroc': 0.643221243840439, 'accuracy': 0.63125, 'precision': 0.6439499304589708, 'recall': 0.8253119429590018, 'f1': 0.7234375}


                                               

******************************<Epoch: 06>******************************
, Train loss: 0.6236659288406372, Val metrics: {'auroc': 0.6556230147561417, 'accuracy': 0.6364583333333333, 'precision': 0.6464088397790055, 'recall': 0.8342245989304813, 'f1': 0.7284046692607004}


                                               

******************************<Epoch: 07>******************************
, Train loss: 0.6069568693637848, Val metrics: {'auroc': 0.6590272472625414, 'accuracy': 0.634375, 'precision': 0.6381578947368421, 'recall': 0.8645276292335116, 'f1': 0.7342922028766087}


                                               

******************************<Epoch: 08>******************************
, Train loss: 0.6083886563777924, Val metrics: {'auroc': 0.6611091007375838, 'accuracy': 0.6260416666666667, 'precision': 0.6304909560723514, 'recall': 0.8698752228163993, 'f1': 0.7310861423220973}


                                               

******************************<Epoch: 09>******************************
, Train loss: 0.6138168692588806, Val metrics: {'auroc': 0.6721572201448363, 'accuracy': 0.6395833333333333, 'precision': 0.6482758620689655, 'recall': 0.8377896613190731, 'f1': 0.7309486780715396}


                                               

******************************<Epoch: 10>******************************
, Train loss: 0.6016106963157654, Val metrics: {'auroc': 0.6767855467545869, 'accuracy': 0.646875, 'precision': 0.6661676646706587, 'recall': 0.7932263814616756, 'f1': 0.7241659886086249}


                                               

******************************<Epoch: 11>******************************
, Train loss: 0.5997978866100311, Val metrics: {'auroc': 0.681735533128722, 'accuracy': 0.65625, 'precision': 0.6928213689482471, 'recall': 0.7397504456327986, 'f1': 0.7155172413793104}


                                               

******************************<Epoch: 12>******************************
, Train loss: 0.5925839900970459, Val metrics: {'auroc': 0.6797653670718686, 'accuracy': 0.64375, 'precision': 0.6729857819905213, 'recall': 0.7593582887700535, 'f1': 0.7135678391959799}


                                               

******************************<Epoch: 13>******************************
, Train loss: 0.5841094017028808, Val metrics: {'auroc': 0.6847064184525485, 'accuracy': 0.6604166666666667, 'precision': 0.6874003189792663, 'recall': 0.768270944741533, 'f1': 0.7255892255892256}


                                               

******************************<Epoch: 14>******************************
, Train loss: 0.5776736319065094, Val metrics: {'auroc': 0.6850727531842082, 'accuracy': 0.653125, 'precision': 0.6748466257668712, 'recall': 0.7843137254901961, 'f1': 0.7254740313272877}


                                               

******************************<Epoch: 15>******************************
, Train loss: 0.5891318023204803, Val metrics: {'auroc': 0.6781570682499476, 'accuracy': 0.6489583333333333, 'precision': 0.6777777777777778, 'recall': 0.7611408199643493, 'f1': 0.7170445004198153}


                                               

******************************<Epoch: 16>******************************
, Train loss: 0.5723324716091156, Val metrics: {'auroc': 0.6800780918427977, 'accuracy': 0.6489583333333333, 'precision': 0.680064308681672, 'recall': 0.7540106951871658, 'f1': 0.7151310228233305}


                                               

******************************<Epoch: 17>******************************
, Train loss: 0.5741039216518402, Val metrics: {'auroc': 0.6822671652393015, 'accuracy': 0.6479166666666667, 'precision': 0.6925734024179621, 'recall': 0.714795008912656, 'f1': 0.7035087719298245}


                                               

******************************<Epoch: 18>******************************
, Train loss: 0.582523399591446, Val metrics: {'auroc': 0.6843043437470682, 'accuracy': 0.6416666666666667, 'precision': 0.6947935368043088, 'recall': 0.6898395721925134, 'f1': 0.6923076923076923}


                                               

******************************<Epoch: 19>******************************
, Train loss: 0.565915846824646, Val metrics: {'auroc': 0.6792069299809238, 'accuracy': 0.6447916666666667, 'precision': 0.6883561643835616, 'recall': 0.7165775401069518, 'f1': 0.7021834061135371}


                                               

******************************<Epoch: 20>******************************
, Train loss: 0.5477791368961334, Val metrics: {'auroc': 0.672219765099022, 'accuracy': 0.6364583333333333, 'precision': 0.677257525083612, 'recall': 0.7219251336898396, 'f1': 0.6988783433994823}


14

In [23]:
net.load_state_dict(state_dict)
test_logits = test(task_loader_dict["test"], net, task_a)
test_logits =  torch.sigmoid(test_logits).numpy()

test_pred = (test_logits > 0.5).astype(int)
test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()
test_metrics = {
    "auroc": roc_auc_score(test_pred_hat, test_logits),
    "accuracy": accuracy_score(test_pred_hat, test_pred),
    "precision": precision_score(test_pred_hat, test_pred),
    "recall": recall_score(test_pred_hat, test_pred),
    "f1score": f1_score(test_pred_hat, test_pred),
}
test_metrics

{'auroc': 0.7292990931434866,
 'accuracy': 0.6896969696969697,
 'precision': 0.705989110707804,
 'recall': 0.8053830227743272,
 'f1score': 0.7524177949709865}