In [1]:
# Move the working directory up one level (the recorded output shows the resulting cwd).
# Presumably done so the relative "data/..." paths and the local `model` package
# used further down resolve from the project root; confirm.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import os
import pickle
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
import torch
from relbench.base import Table
from relbench.datasets import get_dataset
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fetch the "rel-trial" RelBench dataset (download=True is passed) and load its database object.
dataset = get_dataset(name = "rel-trial", download = True)
db = dataset.get_db()
# Relative directory from which the cached column-type dictionary is read in the next cell.
cache_path = "data/rel-trial-tensor-frame"

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 7.56 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# Load the cached table-name -> column-stype mapping (the stored infer_type result).
# NOTE(review): unpickling can execute arbitrary code; only load cache files
# that were produced locally by a trusted run.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
# Use a context manager so the file handle is closed right after loading.
with open(type_path, "rb") as type_file:
    col_type_dict = pickle.load(type_file)

# Add a placeholder "text_compress" column (all NaN) to every table's DataFrame.
for table_name, table in db.table_dict.items():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GloveTextEmbedding:
    """Callable text embedder backed by a SentenceTransformer GloVe model.

    Instances are called with a list of sentences and return their encodings
    as a torch tensor.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the sentence-transformers model onto ``device``.

        Args:
            device: Device handed to ``SentenceTransformer``; ``None`` is
                passed through unchanged.
        """
        # Alternative model kept from the original code for reference:
        # "all-MiniLM-L12-v2"
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and wrap the resulting numpy array as a tensor."""
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)

# Text-embedder configuration; it is later passed to torch_frame's Dataset as
# `col_to_text_embedder_cfg` when the tables are materialized.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=512
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table: "Table") -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    The mapping is modified in place and also returned, so the function can be
    used either for its side effect or for its return value.

    Args:
        col_to_stype: Mapping from column name to its semantic type.
        table: Table whose ``pkey_col`` and ``fkey_col_to_pkey_table`` keys
            name the columns to drop.

    Returns:
        The same ``col_to_stype`` object, with the key columns removed.
    """
    if table.pkey_col is not None:
        # pop(..., None) tolerates a key column that is absent from the mapping.
        col_to_stype.pop(table.pkey_col, None)
    for fkey in table.fkey_col_to_pkey_table:
        col_to_stype.pop(fkey, None)
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Only second- and nanosecond-resolution ``datetime64`` series are accepted;
    any other dtype trips the assertion.
    """
    second_dtype = np.dtype("datetime64[s]")
    nanosecond_dtype = np.dtype("datetime64[ns]")
    assert ser.dtype in [second_dtype, nanosecond_dtype]

    raw_ticks = ser.astype("int64").to_numpy()
    if ser.dtype == nanosecond_dtype:
        # Nanosecond ticks are floored down to whole seconds.
        return raw_ticks // 10**9
    return raw_ticks

  return self.fget.__get__(instance, owner)()


In [6]:
# build graph

# Build a HeteroData graph from the database: each table becomes a node type
# holding its materialized TensorFrame, and each foreign-key column becomes a
# forward ("f2p_*") and a reverse ("rev_f2p_*") edge type.
cache_dir = "./data/rel-trial-tensor-frame"
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
# table name -> column statistics of the materialized torch_frame Dataset
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()
    
    col_to_stype = col_type_dict[table_name]
    
    # remove pkey, fkey
    # NOTE: col_to_stype is the very dict stored in col_type_dict, so the key
    # columns are removed from col_type_dict as well.
    remove_pkey_fkey(col_to_stype, table)
    
    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")
    
    # Per-table "<table_name>.pt" path under cache_dir, handed to materialize()
    # below (presumably used by torch_frame as an on-disk cache; confirm).
    path = (
            None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")
    )
    
    print(f"-----> Materialize {table_name} Tensor Frame")
    # NOTE: this rebinds the name `dataset`, which previously referred to the
    # RelBench dataset returned by get_dataset().
    dataset = Dataset(
        df = df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)
    
    data[table_name].tf = dataset.tensor_frame
    col_stats_dict[table_name] = dataset.col_stats
    
    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )
    
    # Add edges normal edges
    # Row positions of this table serve as its node ids; the foreign-key values
    # are used directly as node ids of the referenced table.
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        pkey_index = df[fkey_col_name]
        # Filter out dangling foreign keys
        mask = ~pkey_index.isna()
        fkey_index = torch.arange(len(pkey_index))
        
        # filter dangling foreign keys:
        pkey_index = torch.from_numpy(pkey_index[mask].astype(int).values)
        fkey_index = fkey_index[torch.from_numpy(mask.values)]
        
        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)
    
# Check the assembled graph; as the last expression, its result is the cell output.
data.validate()

-----> Materialize interventions Tensor Frame
-----> Materialize interventions_studies Tensor Frame
-----> Materialize facilities_studies Tensor Frame
-----> Materialize sponsors Tensor Frame
-----> Materialize eligibilities Tensor Frame
-----> Materialize reported_event_totals Tensor Frame
-----> Materialize designs Tensor Frame
-----> Materialize conditions_studies Tensor Frame
-----> Materialize drop_withdrawals Tensor Frame
-----> Materialize studies Tensor Frame
-----> Materialize outcome_analyses Tensor Frame
-----> Materialize sponsors_studies Tensor Frame
-----> Materialize outcomes Tensor Frame
-----> Materialize conditions Tensor Frame
-----> Materialize facilities Tensor Frame


True

In [7]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# start fine-tuning on task A
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [8]:
# Load the "study-adverse" task of rel-trial and display the name of its entity table.
task_a = get_task("rel-trial", "study-adverse", download = True)
entity_table = task_a.entity_table
entity_table

'studies'

In [9]:
def generate_loader_dict(
    task: BaseTask,
    data: HeteroData,
    num_neighbors: int = 128,
    num_hops: int = 2,
    batch_size: int = 512,
    num_workers: int = 0,
) -> dict:
    """Build one temporal ``NeighborLoader`` per split of ``task``.

    Args:
        task: Task providing the "train", "val" and "test" tables.
        data: Heterogeneous graph to sample from; node types used for temporal
            sampling carry a ``time`` attribute.
        num_neighbors: Number of neighbors sampled per node at every hop.
        num_hops: Depth of the sampled subgraphs.
        batch_size: Number of seed nodes per mini-batch.
        num_workers: Worker processes used by each loader.

    Returns:
        Dict mapping "train" / "val" / "test" to the corresponding loader.
        Only the training loader shuffles its seed nodes.
    """
    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            # Same fan-out at every hop (defaults: depth 2, 128 neighbors per node).
            num_neighbors=[num_neighbors] * num_hops,
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=num_workers,
            persistent_workers=False,
        )
    return loader_dict

In [10]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> Tuple[Tensor, Tensor]:
    """Run ``model`` over ``loader`` and collect predictions together with labels.

    The model's train/eval mode is left untouched (the ``model.eval()`` call is
    deliberately disabled), so the caller decides which mode is used.

    Args:
        loader: Loader yielding mini-batches for the task's entity table.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task whose ``entity_table`` selects the predicted node type.

    Returns:
        ``(predictions, labels)``, both concatenated over all batches on the CPU.
    """
    # model.eval()
    pred_list = []
    label_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten (N, 1) outputs to (N,); outputs that are already 1-D pass through.
        if pred.dim() > 1 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
        label_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(label_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask) -> Tensor:
    """Run ``model`` over ``loader`` and return the concatenated predictions.

    The model's train/eval mode is left untouched (the ``model.eval()`` call is
    deliberately disabled), so the caller decides which mode is used.

    Args:
        loader: Loader yielding mini-batches for the task's entity table.
        model: Model called as ``model(batch, task.entity_table)``.
        task: Task whose ``entity_table`` selects the predicted node type.

    Returns:
        Predictions of all batches concatenated along dim 0, on the CPU.
    """
    # model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        # Flatten (N, 1) outputs to (N,); outputs that are already 1-D pass through.
        if pred.dim() > 1 and pred.size(1) == 1:
            pred = pred.view(-1)
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [20]:
# construct bottom model
# Hidden width shared by all sub-modules below.
channels = 128
# Temporal encoder covering only the node types that carry a `time` attribute.
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
    channels=channels,
)

# Feature encoder built from the graph and the per-table column statistics.
feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

# Node-representation part (project-local module; see model.base for details).
node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)


# Full model combining the three parts above; out_channels=1 gives one output per entity.
net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="batch_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder
)

In [21]:
# for regression task, we need to deactivate the normalization and dropout layer
# NOTE: as written, only the dropout classes are in the active tuple; the
# variant that also listed normalization layers is the commented-out line below.
# Bare expression that is not the last statement of the cell, so nothing is displayed.
task_a.task_type
# freeze_instances = (torch.nn.BatchNorm1d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.BatchNorm2d)
deactive_nn_instances = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
# Put the whole network in training mode, then switch the listed module types
# back to eval mode individually.
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        module.eval()
        for param in module.parameters():
            param.requires_grad = False


In [22]:
# training for fine-tune
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

# Loaders for the train / val / test splits of task_a.
task_loader_dict = generate_loader_dict(task_a,data)
lr = 0.005
# Maximum number of training epochs.
epoches = 80
loss_fn = L1Loss()
# Validation metric used for checkpoint selection; lower MAE is better.
tune_metric = "mae"
higher_is_better = False
# Stop after this many consecutive epochs without validation improvement.
early_stop = 5
# Upper bound on the number of training batches processed per epoch.
max_round_epoch = 30
# Optimize only the parameters that still require gradients.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)

In [23]:
# Fine-tuning loop with early stopping on the validation metric.
# `state_dict` keeps a copy of the weights of the best epoch seen so far; it
# stays None if no epoch ever improves on the initial best_val_metric.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0
# train
for epoch in range(1, epoches + 1):
    cnt = 0
    loss_accum = count_accum = 0
    # net.train()
    for batch in tqdm(task_loader_dict["train"], leave = False):
        # Process at most max_round_epoch batches per epoch.
        cnt += 1
        if cnt > max_round_epoch:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        # Weight the batch loss by the batch size so train_loss is a per-sample mean.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum
    # Naming note: `val_pred_hat` holds the ground-truth targets, while
    # `val_logits` holds the model predictions.
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = val_logits.numpy()
    
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }
    # Test metrics are computed every epoch for reporting only; checkpoint
    # selection below looks at val_metrics[tune_metric] exclusively.
    logits = test(task_loader_dict["test"], net, task_a)
    logits = logits.numpy()
    pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
    test_metrics = {
            "mae": mean_absolute_error(pred_hat, logits),
            "r2": r2_score(pred_hat, logits),
            "rmse": root_mean_squared_error(pred_hat, logits),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    
    # On improvement: remember the epoch, snapshot the weights, reset patience.
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1
    
    if patience >= early_stop:
        print(f"Early stop at epoch {epoch}")
        break   
    
# Display the epoch whose weights were kept.
best_epoch

  0%|          | 0/85 [00:00<?, ?it/s]

                                               

Epoch: 01, Train loss: 36.52751242319743, Val metrics: {'mae': 54.81963467603348, 'r2': -0.015472323831242196, 'rmse': 394.1298211464677}
Test metrics: {'mae': 55.377625785109075, 'r2': -0.03757472990522137, 'rmse': 254.33392240630874}


                                               

Epoch: 02, Train loss: 36.80703983306885, Val metrics: {'mae': 53.49667310225337, 'r2': -0.010696926381572158, 'rmse': 393.2020043778222}
Test metrics: {'mae': 53.79675200911512, 'r2': -0.022507604227854694, 'rmse': 252.48051615231248}


                                               

Epoch: 03, Train loss: 36.45296026865641, Val metrics: {'mae': 53.021189700754086, 'r2': -0.0070519531518362655, 'rmse': 392.49234292586857}
Test metrics: {'mae': 53.202689644509626, 'r2': -0.011625364950215156, 'rmse': 251.13338539019512}


                                               

Epoch: 04, Train loss: 36.82204252878825, Val metrics: {'mae': 52.187234179840914, 'r2': -0.0019183413895920154, 'rmse': 391.49066784224834}
Test metrics: {'mae': 52.10877393304176, 'r2': 0.0019789789217978804, 'rmse': 249.43904812109068}


                                               

Epoch: 05, Train loss: 33.59614013036092, Val metrics: {'mae': 51.434010753171876, 'r2': 0.003109574891460465, 'rmse': 390.50712561373393}
Test metrics: {'mae': 51.161593693287486, 'r2': 0.016668768839899983, 'rmse': 247.5965064409129}


                                               

Epoch: 06, Train loss: 32.40218346913655, Val metrics: {'mae': 51.20356222311733, 'r2': 0.005844832691389468, 'rmse': 389.97102288491476}
Test metrics: {'mae': 50.87201073043303, 'r2': 0.026684581607548497, 'rmse': 246.33232032485134}


                                               

Epoch: 07, Train loss: 35.38134657541911, Val metrics: {'mae': 50.89220805650782, 'r2': 0.009682754933350135, 'rmse': 389.21755613079796}
Test metrics: {'mae': 50.41912302712738, 'r2': 0.03763601704032482, 'rmse': 244.94257356955256}


                                               

Epoch: 08, Train loss: 36.80251242319743, Val metrics: {'mae': 52.11477145307943, 'r2': 0.0059689370897940686, 'rmse': 389.946681297566}
Test metrics: {'mae': 51.48221923331769, 'r2': 0.03609663179287159, 'rmse': 245.13839881129184}


                                               

Epoch: 09, Train loss: 38.417340469360354, Val metrics: {'mae': 50.153470070409, 'r2': 0.019681409214644563, 'rmse': 387.24772038371987}
Test metrics: {'mae': 49.35214248568474, 'r2': 0.06358862520459785, 'rmse': 241.61724964444534}


                                               

Epoch: 10, Train loss: 30.920375887552897, Val metrics: {'mae': 49.870541123459915, 'r2': 0.020757424959648274, 'rmse': 387.0351369178716}
Test metrics: {'mae': 48.84544227987749, 'r2': 0.06824002989029576, 'rmse': 241.01641397438144}


                                               

Epoch: 11, Train loss: 36.784453296661376, Val metrics: {'mae': 50.40344346561847, 'r2': 0.01759402756116457, 'rmse': 387.6597823184335}
Test metrics: {'mae': 49.05585107741922, 'r2': 0.07396123281229972, 'rmse': 240.2753287862326}


                                               

Epoch: 12, Train loss: 34.96241785685221, Val metrics: {'mae': 49.68966293508173, 'r2': 0.024489253078042283, 'rmse': 386.2969504097559}
Test metrics: {'mae': 48.22727500911372, 'r2': 0.08781001345490025, 'rmse': 238.47191893238548}


                                               

Epoch: 13, Train loss: 33.42746903101603, Val metrics: {'mae': 49.49762688909924, 'r2': 0.023520529688118574, 'rmse': 386.488707412843}
Test metrics: {'mae': 48.03292776021825, 'r2': 0.09321709900754027, 'rmse': 237.7640869099766}


                                               

Epoch: 14, Train loss: 38.681722259521486, Val metrics: {'mae': 48.984297405261216, 'r2': 0.031886377613770334, 'rmse': 384.82955281929895}
Test metrics: {'mae': 47.253737094937755, 'r2': 0.10746900525347425, 'rmse': 235.8882179636859}


                                               

Epoch: 15, Train loss: 33.70811545054118, Val metrics: {'mae': 48.87743074027148, 'r2': 0.035965037440058856, 'rmse': 384.0180543233608}
Test metrics: {'mae': 47.0141780762335, 'r2': 0.1239259441339211, 'rmse': 233.70338658483476}


                                               

Epoch: 16, Train loss: 30.4956174214681, Val metrics: {'mae': 48.91704567640738, 'r2': 0.04081748964551857, 'rmse': 383.0503611735308}
Test metrics: {'mae': 47.26053646381456, 'r2': 0.12748057436282434, 'rmse': 233.22878429968998}


                                               

Epoch: 17, Train loss: 34.4558318456014, Val metrics: {'mae': 48.81418725609514, 'r2': 0.03697484134015072, 'rmse': 383.8168766972297}
Test metrics: {'mae': 46.37232580762224, 'r2': 0.13192054685055143, 'rmse': 232.63461407996533}


                                               

Epoch: 18, Train loss: 34.92251834869385, Val metrics: {'mae': 48.79966334631465, 'r2': 0.038304313802331236, 'rmse': 383.55185236447005}
Test metrics: {'mae': 46.43573045335695, 'r2': 0.13506645433750963, 'rmse': 232.2126992273514}


                                               

Epoch: 19, Train loss: 34.754250081380206, Val metrics: {'mae': 48.91010620019858, 'r2': 0.038702927296113354, 'rmse': 383.47235487262276}
Test metrics: {'mae': 46.65552643574462, 'r2': 0.13276244058845432, 'rmse': 232.5217781330004}


                                               

Epoch: 20, Train loss: 37.02654244105021, Val metrics: {'mae': 48.50403081684128, 'r2': 0.0515788170783511, 'rmse': 380.89552752380655}
Test metrics: {'mae': 45.812777869961316, 'r2': 0.16793758761389344, 'rmse': 227.7574270255513}


                                               

Epoch: 21, Train loss: 39.30879201889038, Val metrics: {'mae': 48.15285102084478, 'r2': 0.051852710194415286, 'rmse': 380.8405244313856}
Test metrics: {'mae': 45.487426359954384, 'r2': 0.1714446310930129, 'rmse': 227.276935039492}


                                               

Epoch: 22, Train loss: 31.03713715871175, Val metrics: {'mae': 48.499229139920836, 'r2': 0.052260989677452074, 'rmse': 380.758519179438}
Test metrics: {'mae': 46.111890869252406, 'r2': 0.16714191035330528, 'rmse': 227.86629992944665}


                                               

Epoch: 23, Train loss: 29.814351431528728, Val metrics: {'mae': 48.134354519292096, 'r2': 0.04930636544297773, 'rmse': 381.3515742770467}
Test metrics: {'mae': 45.85725999738494, 'r2': 0.1606033455802356, 'rmse': 228.75901247948772}


                                               

Epoch: 24, Train loss: 30.10293156305949, Val metrics: {'mae': 48.49776167971174, 'r2': 0.047676355686423366, 'rmse': 381.6783570326499}
Test metrics: {'mae': 46.60298798758996, 'r2': 0.15105613044957922, 'rmse': 230.05627544722554}


                                               

Epoch: 25, Train loss: 33.29369748433431, Val metrics: {'mae': 48.68759988995064, 'r2': 0.05085815561675788, 'rmse': 381.04021250552995}
Test metrics: {'mae': 46.692808450441966, 'r2': 0.15881315261498974, 'rmse': 229.00282132511217}


                                               

Epoch: 26, Train loss: 32.427274004618326, Val metrics: {'mae': 47.64143203572682, 'r2': 0.06103844555825921, 'rmse': 378.99122603349457}
Test metrics: {'mae': 45.289035223934285, 'r2': 0.19351639049868485, 'rmse': 224.22930398799434}


                                               

Epoch: 27, Train loss: 29.95704765319824, Val metrics: {'mae': 47.735714598557806, 'r2': 0.07685032298812577, 'rmse': 375.7866188903986}
Test metrics: {'mae': 45.392656533972264, 'r2': 0.22271048923854786, 'rmse': 220.13342954471685}


                                               

Epoch: 28, Train loss: 27.86974500020345, Val metrics: {'mae': 46.98750203780341, 'r2': 0.07291527689629118, 'rmse': 376.5866869748922}
Test metrics: {'mae': 44.21811310875273, 'r2': 0.2159934058538009, 'rmse': 221.08254425507428}


                                               

Epoch: 29, Train loss: 31.46937459309896, Val metrics: {'mae': 46.771845502416305, 'r2': 0.07633705302609783, 'rmse': 375.89107279728364}
Test metrics: {'mae': 44.806638630279004, 'r2': 0.21499190444763905, 'rmse': 221.22370621481488}


                                               

Epoch: 30, Train loss: 30.705997880299886, Val metrics: {'mae': 46.40199460095008, 'r2': 0.08154229475527375, 'rmse': 374.83042168152485}
Test metrics: {'mae': 44.0994550843078, 'r2': 0.23562814248573316, 'rmse': 218.29658443078304}


                                               

Epoch: 31, Train loss: 33.184669431050615, Val metrics: {'mae': 46.43384914067848, 'r2': 0.0889307627959357, 'rmse': 373.3197289752245}
Test metrics: {'mae': 44.14507017205749, 'r2': 0.2404718726463997, 'rmse': 217.60382594316764}


                                               

Epoch: 32, Train loss: 27.388908704121906, Val metrics: {'mae': 46.035465397363375, 'r2': 0.0906494088403138, 'rmse': 372.9674465599213}
Test metrics: {'mae': 43.43619166365587, 'r2': 0.25635037936000205, 'rmse': 215.3172262336928}


                                               

Epoch: 33, Train loss: 30.707509485880532, Val metrics: {'mae': 46.90998697490975, 'r2': 0.07608874029191925, 'rmse': 375.9415957007013}
Test metrics: {'mae': 45.614006136807994, 'r2': 0.19928073000277002, 'rmse': 223.4265277714507}


                                               

Epoch: 34, Train loss: 30.785243701934814, Val metrics: {'mae': 46.440444929361576, 'r2': 0.09147412792854803, 'rmse': 372.7982801522315}
Test metrics: {'mae': 45.24512443300445, 'r2': 0.24174611149852077, 'rmse': 217.42121542952071}


                                               

Epoch: 35, Train loss: 26.90305296579997, Val metrics: {'mae': 46.01446523640027, 'r2': 0.0906854138206381, 'rmse': 372.96006281891283}
Test metrics: {'mae': 43.60146401876919, 'r2': 0.24413917776159932, 'rmse': 217.07785118517876}


                                               

Epoch: 36, Train loss: 27.331729952494303, Val metrics: {'mae': 45.90940722063581, 'r2': 0.10477014516581795, 'rmse': 370.0603276455672}
Test metrics: {'mae': 44.5487483536441, 'r2': 0.27878239090797374, 'rmse': 212.04486365468975}


                                               

Epoch: 37, Train loss: 29.327657604217528, Val metrics: {'mae': 45.678094157578215, 'r2': 0.10479180111923525, 'rmse': 370.0558516678952}
Test metrics: {'mae': 43.183185701421415, 'r2': 0.2882720016007948, 'rmse': 210.64522630743042}


                                               

Epoch: 38, Train loss: 26.346074708302815, Val metrics: {'mae': 46.576294631288036, 'r2': 0.09503448907241052, 'rmse': 372.0670960057263}
Test metrics: {'mae': 44.266974666542474, 'r2': 0.2456152271918861, 'rmse': 216.8657921589526}


                                               

Epoch: 39, Train loss: 26.320768070220947, Val metrics: {'mae': 45.43349600779732, 'r2': 0.11266271067241596, 'rmse': 368.42544507691434}
Test metrics: {'mae': 43.12778926111657, 'r2': 0.29757356876432983, 'rmse': 209.26423910450694}


                                               

Epoch: 40, Train loss: 27.471531867980957, Val metrics: {'mae': 45.46162614777883, 'r2': 0.11717728571009156, 'rmse': 367.4870165098119}
Test metrics: {'mae': 43.21103844690161, 'r2': 0.31816737730709643, 'rmse': 206.1738042086084}


39

In [24]:
# test
# Restore the weights of the best validation epoch and evaluate on the test split.
# NOTE: `state_dict` is still None if the training cell never recorded an improvement.
net.load_state_dict(state_dict)
logits = test(task_loader_dict["test"], net, task_a)
logits = logits.numpy()
# Ground-truth test targets (`mask_input_cols=False` is passed so the target column is present).
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 43.129404742667525,
 'r2': 0.29743311119899796,
 'rmse': 209.28516035335826}