In [1]:
# Step up one directory before anything else runs. Later cells use relative
# paths (./data, ./edges, ./static) and local packages (utils, model), so this
# is presumably the project root -- confirm when running from another location.
%cd ..

/home/lingze/embedding_fusion


In [2]:
import numpy as np
import pandas as pd
from relbench.datasets import get_dataset
from relbench.base import Table
from tqdm import tqdm
from typing import Any,Dict

import torch
import pickle
import os
from torch import Tensor
from torch_frame import stype
from torch_frame.config import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import sort_edge_index

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory holding the materialized tensor frames and the cached column types.
cache_path = "./data/rel-avito-tensor-frame"

# Fetch the rel-avito dataset through relbench and get its database object.
dataset = get_dataset(name="rel-avito", download=True)
db = dataset.get_db()

Loading Database object from /home/lingze/.cache/relbench/rel-avito/db...
Done in 4.75 seconds.


In [4]:
# [NOTE]: the dataset has been materialized

# Load the cached per-table column -> stype mapping.
# NOTE(review): pickle.load can execute arbitrary code; only load a cache file
# that was produced by a trusted local run.
type_path = os.path.join(cache_path, "col_type_dict.pkl")
with open(type_path, "rb") as type_file:  # closes the handle deterministically
    col_type_dict = pickle.load(type_file)

# Add an all-NaN "text_compress" column to every table.
# NOTE(review): presumably so the column always exists when col_type_dict
# lists an stype for it -- confirm against how the cache was generated.
for table in db.table_dict.values():
    table.df["text_compress"] = np.nan

In [5]:
from typing import List, Optional
from torch_frame.config.text_embedder import TextEmbedderConfig
from sentence_transformers import SentenceTransformer
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GloveTextEmbedding:
    """Callable text embedder backed by a SentenceTransformer GloVe model.

    The model is loaded once at construction time; calling the instance with
    a list of sentences returns their embeddings as a ``torch.Tensor``.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the embedding model, optionally onto ``device``."""
        # Alternative model that was tried: "all-MiniLM-L12-v2"
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and wrap the resulting numpy array as a tensor."""
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)

# Embedder configuration handed to torch_frame when materializing each table:
# the GloVe embedder on the selected device, with batch_size=512.
text_embedder_cfg = TextEmbedderConfig(
    batch_size=512,
    text_embedder=GloveTextEmbedding(device=device),
)

def remove_pkey_fkey(col_to_stype: Dict[str, Any], table: "Table") -> dict:
    r"""Remove pkey, fkey columns since they will not be used as input feature.

    Args:
        col_to_stype: Mapping from column name to semantic type. It is
            modified **in place**.
        table: Object exposing ``pkey_col`` (a column name or ``None``) and
            ``fkey_col_to_pkey_table`` (a mapping keyed by foreign-key
            column names).

    Returns:
        The same ``col_to_stype`` object, without the key columns. The
        previous implementation was annotated ``-> dict`` but returned
        ``None``; callers that ignore the return value are unaffected.
    """
    key_columns = []
    if table.pkey_col is not None:
        key_columns.append(table.pkey_col)
    key_columns.extend(table.fkey_col_to_pkey_table)

    for column in key_columns:
        # Columns that are absent from the mapping are skipped silently.
        col_to_stype.pop(column, None)
    return col_to_stype

def to_unix_time(ser: pd.Series) -> np.ndarray:
    r"""Converts a :class:`pandas.Timestamp` series to UNIX timestamp (in seconds).

    Args:
        ser: Series with a timezone-naive ``datetime64`` dtype of resolution
            ``s``, ``ms``, ``us`` or ``ns``.

    Returns:
        ``int64`` array of whole seconds since the epoch (floor division for
        sub-second resolutions). The input series is not modified.

    Raises:
        TypeError: If ``ser`` does not have a supported ``datetime64`` dtype.
    """
    dtype = ser.dtype
    if not (isinstance(dtype, np.dtype) and dtype.kind == "M"):
        raise TypeError(f"Expected a datetime64 series, got dtype {dtype!r}")

    unit, _ = np.datetime_data(dtype)
    ticks_per_second = {"s": 1, "ms": 10**3, "us": 10**6, "ns": 10**9}.get(unit)
    if ticks_per_second is None:
        raise TypeError(f"Unsupported datetime64 resolution: {unit!r}")

    unix_time = ser.astype("int64").values
    if ticks_per_second != 1:
        # Out-of-place division: the array returned by `.values` can be
        # read-only (pandas Copy-on-Write), which breaks an in-place `//=`.
        unix_time = unix_time // ticks_per_second
    return unix_time

  return self.fget.__get__(instance, owner)()


In [6]:
# Build the heterogeneous graph: one node type per database table, plus a
# forward and a reverse edge type for every foreign-key column.
cache_dir = cache_path
if cache_dir is not None:
    os.makedirs(cache_dir, exist_ok=True)
data = HeteroData()
col_stats_dict = {}
for table_name, table in db.table_dict.items():
    df = table.df
    # (important for foreignKey value) Ensure the pkey is consecutive:
    # foreign-key values are used directly as node indices further down.
    if table.pkey_col is not None:
        assert (df[table.pkey_col].values == np.arange(len(df))).all()

    col_to_stype = col_type_dict[table_name]

    # remove pkey, fkey (this mutates the col_type_dict entry in place)
    remove_pkey_fkey(col_to_stype, table)

    if len(col_to_stype) == 0:
        # for example, relationship table which only contains pkey and fkey
        raise KeyError(f"{table_name} has no column to build graph")

    path = None if cache_dir is None else os.path.join(cache_dir, f"{table_name}.pt")

    print(f"-----> Materialize {table_name} Tensor Frame")
    # Bound to `table_dataset` (not `dataset`) so the relbench dataset object
    # loaded earlier in the notebook is not overwritten by a torch_frame one.
    table_dataset = Dataset(
        df=df,
        col_to_stype=col_to_stype,
        col_to_text_embedder_cfg=text_embedder_cfg,
    ).materialize(path=path)

    data[table_name].tf = table_dataset.tensor_frame
    col_stats_dict[table_name] = table_dataset.col_stats

    # Add time attribute
    if table.time_col is not None:
        data[table_name].time = torch.from_numpy(
            to_unix_time(df[table.time_col])
        )

    # Add the regular foreign-key edges
    for fkey_col_name, pkey_table_name in table.fkey_col_to_pkey_table.items():
        fkey_values = df[fkey_col_name]
        # Filter out dangling foreign keys (rows whose fkey is NaN)
        mask = ~fkey_values.isna()

        # Destination: referenced pkey value; source: row position in this table.
        pkey_index = torch.from_numpy(fkey_values[mask].astype(int).values)
        fkey_index = torch.arange(len(fkey_values))[torch.from_numpy(mask.values)]

        # fkey -> pkey edges
        edge_index = torch.stack([fkey_index, pkey_index], dim=0)
        edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

        # pkey -> fkey edges.
        # "rev_" is added so that PyG loader recognizes the reverse edges
        edge_index = torch.stack([pkey_index, fkey_index], dim=0)
        edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
        data[edge_type].edge_index = sort_edge_index(edge_index)

data.validate()

-----> Materialize PhoneRequestsStream Tensor Frame
-----> Materialize Location Tensor Frame
-----> Materialize SearchInfo Tensor Frame


-----> Materialize UserInfo Tensor Frame
-----> Materialize SearchStream Tensor Frame
-----> Materialize VisitStream Tensor Frame
-----> Materialize AdsInfo Tensor Frame
-----> Materialize Category Tensor Frame


True

In [7]:
# Attach the additional edges stored on disk to the graph.
from utils.util import load_np_dict

edge_dict = load_np_dict("./edges/rel-avito-edges.npz")
for edge_name, edge_np in edge_dict.items():
    # The first two "-"-separated parts of the key are used as the source
    # and destination table names.
    name_parts = edge_name.split('-')
    src_table, dst_table = name_parts[0], name_parts[1]

    # Transpose to the [2, edge_num] layout.
    edge_index = torch.from_numpy(edge_np.astype(int)).t()

    edge_type = (src_table, "appendix", dst_table)
    data[edge_type].edge_index = sort_edge_index(edge_index)

data.validate()

True

In [8]:
from relbench.tasks import get_task
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader
from relbench.base import BaseTask
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder
# imports for fine-tuning on task a
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import math
import copy


In [9]:
# Task used for fine-tuning: "ad-ctr" on rel-avito (download=True is passed to relbench).
task_a = get_task("rel-avito", "ad-ctr", download = True)
# The task's entity table; later cells pass task.entity_table to the model
# and read the labels from batch[task.entity_table].y.
entity_table = task_a.entity_table

In [10]:
def generate_loader_dict(
    task: BaseTask,
    data: HeteroData,
    num_neighbors: Optional[List[int]] = None,
    batch_size: int = 512,
) -> Dict[str, NeighborLoader]:
    """Build one temporal ``NeighborLoader`` per split of ``task``.

    Args:
        task: Task whose "train", "val" and "test" tables supply the seed
            nodes, seed times and label transform.
        data: Heterogeneous graph to sample from; node stores are expected to
            expose a ``time`` attribute (``time_attr="time"`` is used).
        num_neighbors: Neighbors to sample per hop. Defaults to
            ``[128, 128]``, i.e. subgraphs of depth 2 with 128 neighbors per
            node, which is the value that used to be hard-coded.
        batch_size: Number of seed nodes per batch. Defaults to 512, the
            value that used to be hard-coded.

    Returns:
        Dict mapping "train" / "val" / "test" to its loader. Only the
        training loader shuffles.
    """
    if num_neighbors is None:
        num_neighbors = [128, 128]

    loader_dict = {}
    for split in ("train", "val", "test"):
        table_input = get_node_train_table_input(
            table=task.get_table(split),
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            # Copy so every loader owns its own list.
            num_neighbors=list(num_neighbors),
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [11]:
@torch.no_grad()
def valid(loader: "NeighborLoader", model: torch.nn.Module, task: "BaseTask") -> tuple:
    """Run inference over ``loader`` and collect predictions with their labels.

    The model is put into eval mode for the duration of the call, so layers
    such as BatchNorm use their running statistics and are not updated by
    evaluation data. Afterwards every sub-module is put back into the exact
    mode it had on entry, so a mixed setup (e.g. dropout in eval mode while
    the rest trains) survives the call.

    Args:
        loader: Iterable of batches; each batch must support ``.to(device)``
            and ``batch[task.entity_table].y``.
        model: Called as ``model(batch, task.entity_table)``.
        task: Object exposing ``entity_table``.

    Returns:
        ``(predictions, labels)``: two CPU tensors concatenated over all
        batches. Predictions with one output column are flattened to 1-D.
    """
    # Snapshot per-module modes: a plain model.train() afterwards would also
    # switch on modules that were deliberately left in eval mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    pred_list = []
    label_list = []
    try:
        for batch in loader:
            # `device` is the notebook-level torch device.
            batch = batch.to(device)
            pred = model(
                batch,
                task.entity_table,
            )
            pred = pred.view(-1) if pred.size(1) == 1 else pred
            pred_list.append(pred.detach().cpu())
            label_list.append(batch[task.entity_table].y.detach().cpu())
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return torch.cat(pred_list, dim=0), torch.cat(label_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    # model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [12]:
# Assemble the model from its parts (same construction order as before).
channels = 128

# Temporal encoder: only node types that carry a "time" attribute.
temporal_encoder = HeteroTemporalEncoder(
    channels=channels,
    node_types=[nt for nt in data.node_types if "time" in data[nt]],
)

# Feature encoder, built from the per-table column statistics.
feat_encoder = FeatureEncodingPart(
    channels=channels,
    data=data,
    node_to_col_stats=col_stats_dict,
)

# Node representation part.
node_encoder = NodeRepresentationPart(
    channels=channels,
    data=data,
    num_layers=1,
    dropout_prob=0.2,
    normalization="layer_norm",
)

# Full model with a single output channel, wrapping the three parts above.
net = CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    num_layer=2,
    aggr="sum",
    norm="batch_norm",
    dropout=0.2,
    feature_encoder=feat_encoder,
    node_encoder=node_encoder,
    temporal_encoder=temporal_encoder,
)

In [13]:
# For this regression task, fine-tune with dropout switched off: put the whole
# network in training mode, then force every dropout layer into eval mode.
# NOTE: only dropout is deactivated. Normalization layers stay in training
# mode; the broader tuple below is kept for reference.
# NOTE: any later net.train() call would switch the dropout layers back on.
# freeze_instances = (torch.nn.BatchNorm1d, torch.nn.LayerNorm, torch.nn.Dropout, torch.nn.BatchNorm2d)
deactive_nn_instances = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
net.train()
for module in net.modules():
    if isinstance(module, deactive_nn_instances):
        module.eval()
        # Stock torch dropout layers have no parameters; this only affects
        # subclasses that add some.
        for param in module.parameters():
            param.requires_grad = False


In [14]:
# read the pre-trained model
pre_trained_model_param_path = './static/rel-avito-pre-trained-channel128.pth'
# map_location="cpu" lets the checkpoint load on hosts without the GPU it was
# saved from. load_state_dict copies the values into net's own parameters, so
# the model's device placement is unaffected.
pre_trained_state_dict = torch.load(pre_trained_model_param_path, map_location="cpu")
net.load_state_dict(pre_trained_state_dict)

<All keys matched successfully>

In [15]:
# training for fine-tune
from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

# Data loaders for the three splits of task a.
task_loader_dict = generate_loader_dict(task_a, data)

# Optimization schedule.
lr = 0.005
epoches = 80            # upper bound on the number of epochs
max_round_epoch = 30    # at most this many training batches per epoch
early_stop = 10         # stop after this many epochs without improvement

# Objective and model-selection criterion (lower MAE on validation is better).
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

# Only parameters that still require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    [param for param in net.parameters() if param.requires_grad],
    lr=lr,
)

In [16]:
# Fine-tuning loop with early stopping on the validation metric.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
net.to(device)
best_epoch = 0
patience = 0

# Ground-truth targets never change between epochs, so read them once instead
# of rebuilding the task tables on every iteration.
val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()

# train
for epoch in range(1, epoches + 1):
    loss_accum = count_accum = 0
    # net.train() stays commented out on purpose: module modes were configured
    # in an earlier cell, and calling it here would presumably switch the
    # dropout layers back on.
    # net.train()
    for cnt, batch in enumerate(tqdm(task_loader_dict["train"], leave=False), start=1):
        # Cap the number of optimization steps per epoch.
        if cnt > max_round_epoch:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = net(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())

        loss.backward()
        optimizer.step()

        # Sample-weighted running loss.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum

    # Validation metrics (used for model selection).
    val_logits = test(task_loader_dict["val"], net, task_a)
    val_logits = val_logits.numpy()
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }

    # Test metrics are printed for monitoring only; they do not influence
    # which checkpoint is kept.
    logits = test(task_loader_dict["test"], net, task_a)
    logits = logits.numpy()
    test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
    }

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")
    print(f"Test metrics: {test_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        # New best: reset patience and snapshot the weights.
        patience = 0
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(net.state_dict())
    else:
        patience += 1

    if patience >= early_stop:
        print(f"Early stop at epoch {epoch}")
        break

best_epoch

                                               

Epoch: 01, Train loss: 0.47573836901608635, Val metrics: {'mae': 0.25028142399815656, 'r2': -6.304137296402032, 'rmse': 0.2681026358903974}
Test metrics: {'mae': 0.2566136865265749, 'r2': -5.479734269843668, 'rmse': 0.27937984714126707}


                                               

Epoch: 02, Train loss: 0.14885231155975193, Val metrics: {'mae': 0.17189945869636908, 'r2': -2.908757634043558, 'rmse': 0.19612637033784672}
Test metrics: {'mae': 0.17780601318339248, 'r2': -2.548813918877427, 'rmse': 0.2067559856065674}


                                               

Epoch: 03, Train loss: 0.08712891533094294, Val metrics: {'mae': 0.037760414654561716, 'r2': 0.08559120980391888, 'rmse': 0.09486080721272039}
Test metrics: {'mae': 0.041580664493269234, 'r2': 0.048155535046696896, 'rmse': 0.10707780791021888}


                                               

Epoch: 04, Train loss: 0.04809360148859959, Val metrics: {'mae': 0.06055096164774562, 'r2': -0.17568260264106406, 'rmse': 0.10756269268347753}
Test metrics: {'mae': 0.0676102863907361, 'r2': -0.2390723495790581, 'rmse': 0.12217006647153807}


                                               

Epoch: 05, Train loss: 0.04071297124320385, Val metrics: {'mae': 0.05161158436962772, 'r2': -0.1803699813706725, 'rmse': 0.10777690251135301}
Test metrics: {'mae': 0.05701106529219341, 'r2': -0.20292766896995396, 'rmse': 0.12037498182025023}


                                               

Epoch: 06, Train loss: 0.03621089461095193, Val metrics: {'mae': 0.03745037293284587, 'r2': 0.013277704005745128, 'rmse': 0.09854034776077297}
Test metrics: {'mae': 0.04135135622308889, 'r2': -0.010714531979928221, 'rmse': 0.11033942904630104}


                                               

Epoch: 07, Train loss: 0.03387122192219192, Val metrics: {'mae': 0.03679166310759556, 'r2': 0.009447575539491493, 'rmse': 0.09873141298995014}
Test metrics: {'mae': 0.04108739911803004, 'r2': -0.019066333913007272, 'rmse': 0.11079437311548966}


                                               

Epoch: 08, Train loss: 0.03304056388198161, Val metrics: {'mae': 0.03660284107049402, 'r2': 0.01737129539951232, 'rmse': 0.09833572931411876}
Test metrics: {'mae': 0.041118777799730194, 'r2': -0.016526232238107896, 'rmse': 0.11065620519004452}


                                               

Epoch: 09, Train loss: 0.03307254550503749, Val metrics: {'mae': 0.03666927982170407, 'r2': 0.011341632488388487, 'rmse': 0.09863697457715744}
Test metrics: {'mae': 0.04141873566159237, 'r2': -0.025870897801365533, 'rmse': 0.11116365872619882}


                                               

Epoch: 10, Train loss: 0.03257967543368246, Val metrics: {'mae': 0.038934460859233085, 'r2': -0.04022797351390217, 'rmse': 0.10117678708199117}
Test metrics: {'mae': 0.04420562360357148, 'r2': -0.07442389530831983, 'rmse': 0.11376385670121263}


                                               

Epoch: 11, Train loss: 0.034037454902541404, Val metrics: {'mae': 0.045533457959398636, 'r2': -0.11716273506223529, 'rmse': 0.1048515464185373}
Test metrics: {'mae': 0.05088351050496674, 'r2': -0.14069310066876461, 'rmse': 0.11721977535687808}


                                               

Epoch: 12, Train loss: 0.034545053304994806, Val metrics: {'mae': 0.037439503567449205, 'r2': 0.0592154286194474, 'rmse': 0.09621919364844697}
Test metrics: {'mae': 0.04132883592709528, 'r2': 0.026115512122818596, 'rmse': 0.10831041043435626}


                                               

Epoch: 13, Train loss: 0.03429093444756433, Val metrics: {'mae': 0.03637306706013337, 'r2': 0.004223897580114877, 'rmse': 0.09899140072003475}
Test metrics: {'mae': 0.04124372627558909, 'r2': -0.030468552870811516, 'rmse': 0.11141248184721088}


                                               

Epoch: 14, Train loss: 0.03576565713128623, Val metrics: {'mae': 0.03805454024027814, 'r2': 0.08268020402324383, 'rmse': 0.09501168116435721}
Test metrics: {'mae': 0.04175887640971783, 'r2': 0.048496594728517306, 'rmse': 0.1070586224253884}


                                               

Epoch: 15, Train loss: 0.03708952447655154, Val metrics: {'mae': 0.03572434148274806, 'r2': 0.04204524720976399, 'rmse': 0.09709326779044432}
Test metrics: {'mae': 0.040354263510594006, 'r2': 0.0025622086965854107, 'rmse': 0.10961232426731443}


                                               

Epoch: 16, Train loss: 0.034176560885765976, Val metrics: {'mae': 0.03591234574186462, 'r2': 0.08577394281310435, 'rmse': 0.09485132837402524}
Test metrics: {'mae': 0.040356528112263486, 'r2': 0.03937797878235305, 'rmse': 0.10757039080697468}


                                               

Epoch: 17, Train loss: 0.03265136337747761, Val metrics: {'mae': 0.03532133763139946, 'r2': 0.039703214466660586, 'rmse': 0.09721188341169222}
Test metrics: {'mae': 0.04017639967079705, 'r2': -0.001275445902615191, 'rmse': 0.10982298923429355}


                                               

Epoch: 18, Train loss: 0.031217896219562082, Val metrics: {'mae': 0.034528026083505685, 'r2': 0.08713212288337235, 'rmse': 0.094780846317495}
Test metrics: {'mae': 0.03957065203088825, 'r2': 0.03498972974701586, 'rmse': 0.10781580877016088}


                                               

Epoch: 19, Train loss: 0.031022809480919556, Val metrics: {'mae': 0.034325557284476664, 'r2': 0.05366311497003862, 'rmse': 0.09650270872407102}
Test metrics: {'mae': 0.03943978926764654, 'r2': 0.012886650968034985, 'rmse': 0.10904355201420018}


                                               

Epoch: 20, Train loss: 0.03130897231546103, Val metrics: {'mae': 0.03346221030469157, 'r2': 0.12626354671010964, 'rmse': 0.09272713626594202}
Test metrics: {'mae': 0.03858695646399065, 'r2': 0.06869750092751259, 'rmse': 0.10591607087156941}


                                               

Epoch: 21, Train loss: 0.032220972498842315, Val metrics: {'mae': 0.03342635488540516, 'r2': 0.09202894975851894, 'rmse': 0.09452629168796355}
Test metrics: {'mae': 0.039056724905089005, 'r2': 0.03836245226349, 'rmse': 0.10762723508406057}


                                               

Epoch: 22, Train loss: 0.04074946342437875, Val metrics: {'mae': 0.042867892120243806, 'r2': 0.17127752233554938, 'rmse': 0.09030695088414759}
Test metrics: {'mae': 0.045986504816571044, 'r2': 0.13412081817897958, 'rmse': 0.10212807038947758}


                                               

Epoch: 23, Train loss: 0.03566379632435593, Val metrics: {'mae': 0.0353428277831466, 'r2': 0.1625766841822902, 'rmse': 0.09077978378940006}
Test metrics: {'mae': 0.03978967317492475, 'r2': 0.10659101879409927, 'rmse': 0.10373889898321838}


                                               

Epoch: 24, Train loss: 0.03384426147037861, Val metrics: {'mae': 0.0329563273253996, 'r2': 0.1556543570205886, 'rmse': 0.09115421452971084}
Test metrics: {'mae': 0.03824623935873669, 'r2': 0.09261362002817897, 'rmse': 0.10454724801782579}


                                               

Epoch: 25, Train loss: 0.03289856454905342, Val metrics: {'mae': 0.03300945764944866, 'r2': 0.14870376447695255, 'rmse': 0.09152863302809995}
Test metrics: {'mae': 0.03788493204269869, 'r2': 0.0975095455851075, 'rmse': 0.10426481714262557}


                                               

Epoch: 26, Train loss: 0.031498379619682534, Val metrics: {'mae': 0.03259208093110141, 'r2': 0.16673890349228904, 'rmse': 0.09055390278596781}
Test metrics: {'mae': 0.037684309815156775, 'r2': 0.11608376397209874, 'rmse': 0.10318629832277194}


                                               

Epoch: 27, Train loss: 0.03369457587599754, Val metrics: {'mae': 0.03359239896352716, 'r2': 0.16489420479331018, 'rmse': 0.09065408286002492}
Test metrics: {'mae': 0.03777251734063394, 'r2': 0.13708744517139337, 'rmse': 0.10195296752757872}


                                               

Epoch: 28, Train loss: 0.03029138790480062, Val metrics: {'mae': 0.03386902113485288, 'r2': 0.17878895414895846, 'rmse': 0.08989675409838355}
Test metrics: {'mae': 0.037847615453577184, 'r2': 0.15047075067547222, 'rmse': 0.10115926040684523}


                                               

Epoch: 29, Train loss: 0.03088045498027521, Val metrics: {'mae': 0.0360836450188762, 'r2': 0.20685448656300975, 'rmse': 0.08834725408128141}
Test metrics: {'mae': 0.039534460259355676, 'r2': 0.1699066478654777, 'rmse': 0.0999953823317788}


                                               

Epoch: 30, Train loss: 0.0361207197285166, Val metrics: {'mae': 0.03332254383393115, 'r2': 0.13038863501883724, 'rmse': 0.09250798541825832}
Test metrics: {'mae': 0.038815104006399925, 'r2': 0.08463243258626219, 'rmse': 0.10500602963421259}


                                               

Epoch: 31, Train loss: 0.038154437647146336, Val metrics: {'mae': 0.054254864476533916, 'r2': 0.16155272580677227, 'rmse': 0.0908352672709762}
Test metrics: {'mae': 0.05528643223023615, 'r2': 0.17544557261064275, 'rmse': 0.0996612067044024}


                                               

Epoch: 32, Train loss: 0.053525577687165315, Val metrics: {'mae': 0.05862980436377108, 'r2': -0.18659847828339649, 'rmse': 0.10806088333596361}
Test metrics: {'mae': 0.06447187521713735, 'r2': -0.18813166788263658, 'rmse': 0.11963238559724163}


                                               

Epoch: 33, Train loss: 0.03404914962894776, Val metrics: {'mae': 0.03346959707908678, 'r2': 0.17997548406494668, 'rmse': 0.08983178678858686}
Test metrics: {'mae': 0.037152954469010846, 'r2': 0.1554810079417237, 'rmse': 0.10086051647389808}


                                               

Epoch: 34, Train loss: 0.025923537712763338, Val metrics: {'mae': 0.035324595443470094, 'r2': 0.0950367798558227, 'rmse': 0.09436959347531572}
Test metrics: {'mae': 0.0406127803126531, 'r2': 0.06381464264667269, 'rmse': 0.1061933690885116}


                                               

Epoch: 35, Train loss: 0.026431215515031536, Val metrics: {'mae': 0.036981215407199855, 'r2': 0.07720980754765305, 'rmse': 0.09529455906093627}
Test metrics: {'mae': 0.04213141070100604, 'r2': 0.05033806100815974, 'rmse': 0.10695497575466169}


                                               

Epoch: 36, Train loss: 0.028359655743720485, Val metrics: {'mae': 0.03773515786613063, 'r2': 0.18649520894552263, 'rmse': 0.08947396375669657}
Test metrics: {'mae': 0.03929539086384168, 'r2': 0.19195440875542236, 'rmse': 0.09865847767254511}
Early stop at epoch 36


26

In [17]:
# Final evaluation: restore the weights snapshotted at the best validation
# epoch and score them on the test split.
net.load_state_dict(state_dict)

pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
logits = test(task_loader_dict["test"], net, task_a).numpy()

# Every metric is called as metric(ground_truth, prediction).
test_metrics = {
    metric_name: metric_fn(pred_hat, logits)
    for metric_name, metric_fn in (
        ("mae", mean_absolute_error),
        ("r2", r2_score),
        ("rmse", root_mean_squared_error),
    )
}
test_metrics

{'mae': 0.03767852758049445,
 'r2': 0.1161066020326007,
 'rmse': 0.10318496528348037}