In [1]:
import relbench
from relbench.base import Table, Database, Dataset, EntityTask
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

import os
import math
import numpy as np
from tqdm import tqdm

import torch
from torch import Tensor
import torch_geometric
import torch_frame

from torch_frame.config.text_embedder import TextEmbedderConfig
from typing import List, Optional
from sentence_transformers import SentenceTransformer

from torch.nn import L1Loss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()
task_a = get_task("rel-trial", "study-outcome", download = True)
task_b = get_task("rel-trial", "site-success", download = True)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 8.15 seconds.


In [3]:
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)


root_dir = "/home/lingze/embedding_fusion/data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-trial_materialized_cache"
    ),  # store materialized graph for convenience
)

  return self.fget.__get__(instance, owner)()


In [4]:
def generate_loader_dict(task: BaseTask) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128 for i in range(2)
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [5]:
taska_loader_dict = generate_loader_dict(task_a)
taskb_loader_dict = generate_loader_dict(task_b)

In [6]:
# declare the model
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_frame.nn.models.resnet import FCResidualBlock
from torch_geometric.typing import NodeType
from torch_frame.nn.encoder import StypeWiseFeatureEncoder
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

In [7]:
class FeatureEncodingPart(torch.nn.Module):
    def __init__(
        self,
        data:HeteroData,
        node_to_col_stats: Dict[str, Dict[str, Dict[StatType, Tensor]]],
        channels: int
    ):
        super().__init__()
        self.encoders = torch.nn.ModuleDict()
        # node_type : StypeWiseFeatureEncoder
        
        node_to_col_names_dict = {
            node_type: data[node_type].tf.col_names_dict
            for node_type in data.node_types
        }
        # node_type:  {stype: [col_name]}
        
        default_stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = {
            torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
            torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
            torch_frame.multicategorical: (
                torch_frame.nn.MultiCategoricalEmbeddingEncoder,
                {},
            ),
            torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
            torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
        }
                
        for node_type in node_to_col_names_dict.keys():
            stype_encoder_dict = {
                stype: default_stype_encoder_cls_kwargs[stype][0](
                    **default_stype_encoder_cls_kwargs[stype][1]
                )
                for stype in node_to_col_names_dict[node_type].keys()
            }
            self.encoders.update({node_type: StypeWiseFeatureEncoder(
                out_channels=channels,
                col_stats = node_to_col_stats[node_type],
                col_names_dict=node_to_col_names_dict[node_type],
                stype_encoder_dict=stype_encoder_dict
            )})
    
    def reset_parameters(self):
        for encoder in self.encoders.values():
            encoder.reset_parameters()
    
    def forward(
        self,
        tf_dict: Dict[NodeType, torch_frame.TensorFrame],
    )-> Dict[NodeType, Tensor]:
        x_dict = {}
        for node_type, tf in tf_dict.items():
            x, _ = self.encoders[node_type](tf)
            x_dict[node_type] = x
        return x_dict

In [8]:
class NodeRepresentationPart(torch.nn.Module):
    def __init__(
        self,
        data: HeteroData,
        channels:int,
        num_layers: int,
        normalization: Optional[str] = "layer_norm",
        dropout_prob: float = 0.0
    ):
        super().__init__()
        
        self.mappers = torch.nn.ModuleDict()
        
        node_to_col_names_dict = {
            node_type: data[node_type].tf.col_names_dict
            for node_type in data.node_types
        } 
        # node_type:  {stype: [col_name]}
        
        for node_type, type_to_col_names in node_to_col_names_dict.items():
            col_cnt = 0
            for cols in type_to_col_names.values():
                col_cnt += len(cols)
            in_channels = col_cnt * channels
            backbone = torch.nn.Sequential(*[
                FCResidualBlock(
                    in_channels if i == 0 else channels,
                    channels,
                    normalization=normalization,
                    dropout_prob=dropout_prob
                )
            for i in range(num_layers)], 
            torch.nn.LayerNorm(channels),
            torch.nn.ReLU(),
            torch.nn.Linear(channels, channels)
            )
            self.mappers.update(
                {
                    node_type: backbone
                }
            )
    def reset_parameters(self):
        for mapper in self.mappers.values():
            for layer in mapper:
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
    
    def forward(self, 
                x_dict: Dict[NodeType, Tensor]
        ) -> Dict[NodeType, Tensor]:
        out_dict = {}
        for node_type, x in x_dict.items():
            # Flattening the encoder output
            x = x.view(x.size(0), math.prod(x.shape[1:]))
            out_dict[node_type] = self.mappers[node_type](x)
        return out_dict

In [9]:
class CompositeModel(torch.nn.Module):
    
    def __init__(
        self,
        data: HeteroData,
        channels: int,
        out_channels:int,
        aggr:str,
        norm:str,
        num_layer:int,
        feature_encoder: torch.nn.Module,
        node_encoder: torch.nn.Module,
        temporal_encoder: torch.nn.Module,
    ):
        super().__init__()
        self.gnn = HeteroGraphSAGE(
            node_types = data.node_types,
            edge_types= data.edge_types,
            channels=channels,
            aggr = aggr,
            num_layers=num_layer
        )
        
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1
        )
        
        self.feature_encoder = feature_encoder
        self.node_encoder = node_encoder
        self.temporal_encoder = temporal_encoder
    
    
    def reset_parameters(self):
        self.gnn.reset_parameters()
        self.head.reset_parameters()
    
    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.feature_encoder(batch.tf_dict)
        x_dict = self.node_encoder(x_dict)
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time
        
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )
        return self.head(x_dict[entity_table][: seed_time.size(0)]) 

In [27]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    pred_hat_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        pred_hat_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(pred_hat_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [26]:
channels = 128

task_a_temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
)

task_a_feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)


task_a_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)



task_a_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=task_a_feat_encoder,
    node_encoder=task_a_node_encoder,
    temporal_encoder=task_a_temporal_encoder
)


In [28]:
# Train Task A
optimizer = torch.optim.Adam(task_a_model.parameters(), lr=0.005)
epochs = 15
loss_fn = BCEWithLogitsLoss()
tune_metric = "auroc"
higher_is_better = True


In [29]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_a_model.to(device)
best_epoch = 0
task_a_model.reset_parameters()
for epoch in range(1, epochs + 1):
    task_a_model.train()
    loss_accum = count_accum = 0
    for batch in tqdm(taska_loader_dict["train"]):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_a_model(
            batch,
            task_a.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_a.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)
    
    train_loss = loss_accum / count_accum
    val_logits = test(taska_loader_dict["val"], task_a_model, task_a)
    val_logits = torch.sigmoid(val_logits).numpy()
    
    val_pred = (val_logits > 0.5).astype(int)
    val_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_metrics = {
            "auroc": roc_auc_score(val_pred_hat, val_logits),
        "accuracy": accuracy_score(val_pred_hat, val_pred),
        "precision": precision_score(val_pred_hat, val_pred),
        "recall": recall_score(val_pred_hat, val_pred),
        "f1": f1_score(val_pred_hat, val_pred),
    }
    
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_a_model.state_dict())

best_epoch

  0%|          | 0/24 [00:00<?, ?it/s]

100%|██████████| 24/24 [00:03<00:00,  6.05it/s]


Epoch: 01, Train loss: 0.6837852612244639, Val metrics: {'auroc': 0.5895911793744611, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:03<00:00,  6.19it/s]


Epoch: 02, Train loss: 0.6419359451336564, Val metrics: {'auroc': 0.6175353714053404, 'accuracy': 0.584375, 'precision': 0.584375, 'recall': 1.0, 'f1': 0.73767258382643}


100%|██████████| 24/24 [00:03<00:00,  6.54it/s]


Epoch: 03, Train loss: 0.618132569490919, Val metrics: {'auroc': 0.6466076063599283, 'accuracy': 0.628125, 'precision': 0.6588785046728972, 'recall': 0.7540106951871658, 'f1': 0.7032418952618454}


100%|██████████| 24/24 [00:03<00:00,  6.30it/s]


Epoch: 04, Train loss: 0.6067917583305596, Val metrics: {'auroc': 0.655538132318318, 'accuracy': 0.63125, 'precision': 0.6396761133603239, 'recall': 0.8449197860962567, 'f1': 0.728110599078341}


100%|██████████| 24/24 [00:03<00:00,  6.37it/s]


Epoch: 05, Train loss: 0.590242871467682, Val metrics: {'auroc': 0.6488815621942556, 'accuracy': 0.63125, 'precision': 0.6308470290771175, 'recall': 0.8894830659536542, 'f1': 0.7381656804733728}


100%|██████████| 24/24 [00:03<00:00,  6.38it/s]


Epoch: 06, Train loss: 0.5707895852574433, Val metrics: {'auroc': 0.6368371910167576, 'accuracy': 0.5947916666666667, 'precision': 0.6365079365079365, 'recall': 0.714795008912656, 'f1': 0.6733837111670865}


100%|██████████| 24/24 [00:03<00:00,  6.36it/s]


Epoch: 07, Train loss: 0.5549517495047356, Val metrics: {'auroc': 0.631954217093536, 'accuracy': 0.6052083333333333, 'precision': 0.6307471264367817, 'recall': 0.7825311942959001, 'f1': 0.6984884645982498}


100%|██████████| 24/24 [00:03<00:00,  6.35it/s]


Epoch: 08, Train loss: 0.5455251353292321, Val metrics: {'auroc': 0.641662087482521, 'accuracy': 0.5895833333333333, 'precision': 0.6602687140115163, 'recall': 0.6131907308377896, 'f1': 0.6358595194085028}


100%|██████████| 24/24 [00:03<00:00,  6.33it/s]


Epoch: 09, Train loss: 0.526058482362446, Val metrics: {'auroc': 0.6344292102806035, 'accuracy': 0.61875, 'precision': 0.6644182124789207, 'recall': 0.7023172905525846, 'f1': 0.682842287694974}


100%|██████████| 24/24 [00:04<00:00,  5.74it/s]


Epoch: 10, Train loss: 0.5063523848702436, Val metrics: {'auroc': 0.6382087125121182, 'accuracy': 0.6135416666666667, 'precision': 0.6655052264808362, 'recall': 0.6809269162210339, 'f1': 0.6731277533039648}


100%|██████████| 24/24 [00:04<00:00,  5.82it/s]


Epoch: 11, Train loss: 0.48740576590043777, Val metrics: {'auroc': 0.6311321976956652, 'accuracy': 0.609375, 'precision': 0.6709558823529411, 'recall': 0.6506238859180036, 'f1': 0.6606334841628959}


100%|██████████| 24/24 [00:03<00:00,  6.13it/s]


Epoch: 12, Train loss: 0.47294296029807925, Val metrics: {'auroc': 0.6230147561416911, 'accuracy': 0.5927083333333333, 'precision': 0.6349206349206349, 'recall': 0.7130124777183601, 'f1': 0.6717044500419815}


100%|██████████| 24/24 [00:03<00:00,  6.24it/s]


Epoch: 13, Train loss: 0.4575477092936374, Val metrics: {'auroc': 0.6239886704282989, 'accuracy': 0.6125, 'precision': 0.6666666666666666, 'recall': 0.6737967914438503, 'f1': 0.6702127659574468}


100%|██████████| 24/24 [00:03<00:00,  6.28it/s]


Epoch: 14, Train loss: 0.43261125400023837, Val metrics: {'auroc': 0.6264457936284562, 'accuracy': 0.60625, 'precision': 0.6414219474497682, 'recall': 0.7397504456327986, 'f1': 0.6870860927152318}


100%|██████████| 24/24 [00:03<00:00,  6.37it/s]


Epoch: 15, Train loss: 0.42252855523439253, Val metrics: {'auroc': 0.6376458079244457, 'accuracy': 0.6145833333333334, 'precision': 0.6547811993517018, 'recall': 0.7201426024955436, 'f1': 0.6859083191850595}


4

In [31]:
# test task A
task_a_model.load_state_dict(state_dict)
test_logits = test(taska_loader_dict["test"], task_a_model, task_a)
test_logits =  torch.sigmoid(test_logits).numpy()

test_pred = (test_logits > 0.5).astype(int)
test_pred_hat = task_a.get_table("test", mask_input_cols = False).df[task_a.target_col].to_numpy()
test_metrics = {
    "auroc": roc_auc_score(test_pred_hat, test_logits),
    "accuracy": accuracy_score(test_pred_hat, test_pred),
    "precision": precision_score(test_pred_hat, test_pred),
    "recall": recall_score(test_pred_hat, test_pred),
    "f1score": f1_score(test_pred_hat, test_pred),
}
test_metrics

{'auroc': 0.7006465438959718,
 'accuracy': 0.6472727272727272,
 'precision': 0.6486068111455109,
 'recall': 0.8674948240165632,
 'f1score': 0.7422497785651019}

In [None]:
# train Task B

In [50]:
task_b_temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
)

task_b_feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)


task_b_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.3
)

task_b_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=task_b_feat_encoder,
    node_encoder=task_b_node_encoder,
    temporal_encoder=task_b_temporal_encoder
)

In [53]:
optimizer = torch.optim.Adam(task_b_model.parameters(), lr=0.002)
epochs = 20
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

In [54]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
task_b_model.to(device)
best_epoch = 0
early_stop = 20
# train
for epoch in range(1, epochs + 1):
    task_b_model.train()
    
    cnt = 0
    loss_accum = count_accum = 0
    for batch in tqdm(taskb_loader_dict["train"]):
        cnt += 1
        if cnt > early_stop:
            break
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = task_b_model(
            batch,
            task_b.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        loss = loss_fn(pred, batch[task_b.entity_table].y.float())
        
        loss.backward()
        optimizer.step()
        
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    train_loss = loss_accum / count_accum
    val_pred_hat = task_b.get_table("val").df[task_b.target_col].to_numpy()
    val_logits = test(taskb_loader_dict["val"], task_b_model, task_b)
    val_logits = val_logits.numpy()
    
    val_metrics = {
        "mae": mean_absolute_error(val_pred_hat, val_logits),
        "r2": r2_score(val_pred_hat, val_logits),
        "rmse": root_mean_squared_error(val_pred_hat, val_logits),
    }
    
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_epoch = epoch
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(task_b_model.state_dict())

best_epoch

  7%|▋         | 20/296 [00:02<00:32,  8.39it/s]


Epoch: 01, Train loss: 0.41037669628858564, Val metrics: {'mae': 0.4475658469513964, 'r2': -0.4558490914968616, 'rmse': 0.5763787357807583}


  7%|▋         | 20/296 [00:02<00:29,  9.35it/s]


Epoch: 02, Train loss: 0.3218929409980774, Val metrics: {'mae': 0.43447137506095423, 'r2': -0.4276799351072853, 'rmse': 0.5707753357201678}


  7%|▋         | 20/296 [00:02<00:31,  8.69it/s]


Epoch: 03, Train loss: 0.28878341168165206, Val metrics: {'mae': 0.4268082009795004, 'r2': -0.3815311949320759, 'rmse': 0.5614746047235288}


  7%|▋         | 20/296 [00:02<00:33,  8.35it/s]


Epoch: 04, Train loss: 0.2742943085730076, Val metrics: {'mae': 0.42401254670573446, 'r2': -0.4522040582154665, 'rmse': 0.5756567391251746}


  7%|▋         | 20/296 [00:02<00:32,  8.54it/s]


Epoch: 05, Train loss: 0.2792063236236572, Val metrics: {'mae': 0.4285064555773019, 'r2': -0.4072439802881451, 'rmse': 0.5666755437221486}


  7%|▋         | 20/296 [00:02<00:32,  8.59it/s]


Epoch: 06, Train loss: 0.25812099650502207, Val metrics: {'mae': 0.4335978127386829, 'r2': -0.47608721061830117, 'rmse': 0.5803711007978084}


  7%|▋         | 20/296 [00:02<00:37,  7.31it/s]


Epoch: 07, Train loss: 0.2523935079574585, Val metrics: {'mae': 0.42417674058638466, 'r2': -0.482422613477671, 'rmse': 0.581615250814505}


  7%|▋         | 20/296 [00:02<00:33,  8.28it/s]


Epoch: 08, Train loss: 0.24488811194896698, Val metrics: {'mae': 0.43002036074424815, 'r2': -0.5023935676136224, 'rmse': 0.5855198568549347}


  7%|▋         | 20/296 [00:02<00:29,  9.35it/s]


Epoch: 09, Train loss: 0.25629867538809775, Val metrics: {'mae': 0.4144787055295133, 'r2': -0.4362478551514435, 'rmse': 0.5724854677116512}


  7%|▋         | 20/296 [00:02<00:31,  8.66it/s]


Epoch: 10, Train loss: 0.24291283264756203, Val metrics: {'mae': 0.40582686573382387, 'r2': -0.41682869668256406, 'rmse': 0.5686020779534545}


  7%|▋         | 20/296 [00:02<00:29,  9.23it/s]


Epoch: 11, Train loss: 0.24397690892219542, Val metrics: {'mae': 0.41573333441555393, 'r2': -0.40494048115120607, 'rmse': 0.5662115619059531}


  7%|▋         | 20/296 [00:02<00:29,  9.34it/s]


Epoch: 12, Train loss: 0.23071654587984086, Val metrics: {'mae': 0.40817946900933894, 'r2': -0.4379565315229117, 'rmse': 0.5728259039486684}


  7%|▋         | 20/296 [00:02<00:29,  9.30it/s]


Epoch: 13, Train loss: 0.22962623834609985, Val metrics: {'mae': 0.40849792874504287, 'r2': -0.3449821935057402, 'rmse': 0.5539977967526727}


  7%|▋         | 20/296 [00:02<00:30,  9.17it/s]


Epoch: 14, Train loss: 0.2282417193055153, Val metrics: {'mae': 0.4050899251179319, 'r2': -0.40803072336446844, 'rmse': 0.5668339262671477}


  7%|▋         | 20/296 [00:02<00:31,  8.71it/s]


Epoch: 15, Train loss: 0.22397611513733864, Val metrics: {'mae': 0.3990583078126809, 'r2': -0.3794551417348908, 'rmse': 0.5610525768339572}


  7%|▋         | 20/296 [00:02<00:32,  8.41it/s]


Epoch: 16, Train loss: 0.2209170401096344, Val metrics: {'mae': 0.4019329465990645, 'r2': -0.3392639645544897, 'rmse': 0.5528188738935781}


  7%|▋         | 20/296 [00:02<00:32,  8.44it/s]


Epoch: 17, Train loss: 0.21357216611504554, Val metrics: {'mae': 0.4003906602231768, 'r2': -0.36502269153562517, 'rmse': 0.5581098735253793}


  7%|▋         | 20/296 [00:02<00:32,  8.40it/s]


Epoch: 18, Train loss: 0.22007163912057875, Val metrics: {'mae': 0.41247513787363294, 'r2': -0.43537786880026563, 'rmse': 0.5723120540537806}


  7%|▋         | 20/296 [00:02<00:32,  8.41it/s]


Epoch: 19, Train loss: 0.2103615455329418, Val metrics: {'mae': 0.40700702387336735, 'r2': -0.4448564994177324, 'rmse': 0.574198598516727}


  7%|▋         | 20/296 [00:02<00:32,  8.39it/s]


Epoch: 20, Train loss: 0.20522889345884324, Val metrics: {'mae': 0.4036564826480426, 'r2': -0.3692540406407825, 'rmse': 0.558974229278604}


15

In [55]:
# test
task_b_model.load_state_dict(state_dict)
logits = test(taskb_loader_dict["test"], task_b_model, task_b)
logits = logits.numpy()
pred_hat = task_b.get_table("test", mask_input_cols=False).df[task_b.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 0.4027192653319429,
 'r2': -0.43432623996859543,
 'rmse': 0.5763256675733064}

: 