In [1]:
%cd ..
from model.base import CompositeModel, FeatureEncodingPart, NodeRepresentationPart
from relbench.modeling.nn import HeteroTemporalEncoder

/home/lingze/embedding_fusion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import relbench
from relbench.base import Table, Database, Dataset, EntityTask
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.base import BaseTask
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

import os
import math
import numpy as np
from tqdm import tqdm
import copy

import torch
from torch import Tensor
import torch_geometric
import torch_frame

from torch_frame.config.text_embedder import TextEmbedderConfig
from typing import List, Optional
from sentence_transformers import SentenceTransformer

from torch.nn import L1Loss, BCEWithLogitsLoss
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
dataset = get_dataset(name="rel-trial", download=True)
db = dataset.get_db()
task_a = get_task("rel-trial", "study-outcome", download = True)
task_b = get_task("rel-trial", "site-success", download = True)

Loading Database object from /home/lingze/.cache/relbench/rel-trial/db...
Done in 8.01 seconds.


In [4]:
col_to_stype_dict = get_stype_proposal(db)

class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)


root_dir = "/home/lingze/embedding_fusion/data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-trial_materialized_cache"
    ),  # store materialized graph for convenience
)

  return self.fget.__get__(instance, owner)()


### Multi-task joint train with shared feature encoder.


In [5]:
def generate_loader_dict(task: BaseTask) -> dict:
    loader_dict = {}
    for split, table in [
        ("train", task.get_table("train")),
        ("val",task.get_table("val")),
        ("test", task.get_table("test")),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )
        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[
                128 for i in range(2)
            ],  # we sample subgraphs of depth 2, 128 neighbors per node.
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )
    return loader_dict

In [6]:
taska_loader_dict = generate_loader_dict(task_a)
taskb_loader_dict = generate_loader_dict(task_b)

In [8]:
channels = 128
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
)

feat_encoder = FeatureEncodingPart(
    data=data,
    node_to_col_stats=col_stats_dict,
    channels=channels
)

task_a_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.2
)

task_b_node_encoder = NodeRepresentationPart(
    data=data,
    channels=channels,
    num_layers=1,
    normalization="layer_norm",
    dropout_prob=0.4
)

task_a_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.2,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=task_a_node_encoder,
    temporal_encoder=temporal_encoder
)

task_b_model =  CompositeModel(
    data=data,
    channels=channels,
    out_channels=1,
    dropout=0.3,
    aggr="mean",
    norm="layer_norm",
    num_layer=2,
    feature_encoder=feat_encoder,
    node_encoder=task_b_node_encoder,
    temporal_encoder=temporal_encoder
)

In [9]:
@torch.no_grad()
def valid(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    pred_hat_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        pred_hat_list.append(batch[task.entity_table].y.detach().cpu())
    return torch.cat(pred_list, dim=0), torch.cat(pred_hat_list, dim=0)

@torch.no_grad()
def test(loader: NeighborLoader, model: torch.nn.Module, task: BaseTask)-> np.ndarray:
    model.eval()
    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0)

In [27]:
params = set(list(task_a_model.parameters()) + list(task_b_model.parameters()))
optimizer = torch.optim.Adam(params, lr = 0.005)
epochs = 20
task_a_loss_fn = BCEWithLogitsLoss()
task_b_loss_fn = L1Loss()

task_a_model.to(device)
task_b_model.to(device)

batch_size = 512
higher_is_better = False
state_dict_a = state_dict_b = None
best_val_metric = float("-inf") if higher_is_better else float("inf")
best_epoch = None
early_stop = 20

task_a_model.reset_parameters()
task_b_model.reset_parameters()


for epoch in range(1, epochs + 1):
    task_a_model.train(), task_b_model.train()
    loss_a_accum = loss_b_accum = count_accum = loss = 0
    freshape = lambda x: x.view(-1) if x.size(1) == 1 else x
    cnt = 0
    for batch_a, batch_b in tqdm(zip(taska_loader_dict["train"], taskb_loader_dict["train"])):
        cnt += 1
        if cnt > early_stop:
            break
        
        batch_a, batch_b = batch_a.to(device), batch_b.to(device)
        pred_a = task_a_model(
            batch_a,
            task_a.entity_table,
        )
        pred_b = task_b_model(
            batch_b,
            task_b.entity_table,
        )
        pred_a = freshape(pred_a)
        pred_b = freshape(pred_b)
         
        optimizer.zero_grad()
        loss_a = task_a_loss_fn(
            pred_a.float(),
            batch_a[task_a.entity_table].y.float(),
        )
        loss_b = task_b_loss_fn(
            pred_b.float(),
            batch_b[task_b.entity_table].y.float(),
        )
        loss = loss_a + loss_b
            
        loss.backward()
        optimizer.step() 

        loss_a_accum += loss_a.detach().item()* batch_size
        loss_b_accum += loss_b.detach().item()* batch_size
        loss = loss.detach().item()* batch_size
        count_accum += batch_size
    
    train_loss = loss / count_accum
    train_loss_a = loss_a_accum / count_accum
    train_loss_b = loss_b_accum / count_accum
    
    val_a_logits, val_a_hat = valid(taska_loader_dict["val"], task_a_model, task_a)
    val_a_loss = task_a_loss_fn(
        val_a_logits.float(),
        val_a_hat.float()
    )

    val_a_logits = torch.sigmoid(val_a_logits).numpy()
    val_a_pred = (val_a_logits > 0.5).astype(int)
    val_a_pred_hat = task_a.get_table("val").df[task_a.target_col].to_numpy()
    val_a_metrics = {
        "auroc": roc_auc_score(val_a_pred_hat, val_a_logits),
        "accuracy": accuracy_score(val_a_pred_hat, val_a_pred),
        "precision": precision_score(val_a_pred_hat, val_a_pred),
        "recall": recall_score(val_a_pred_hat, val_a_pred),
        "f1score": f1_score(val_a_pred_hat, val_a_pred),
        "loss": val_a_loss.item()
    }
    
    val_b_logits, val_b_hat = valid(taskb_loader_dict["val"], task_b_model, task_b)
    val_b_loss = task_b_loss_fn(
        val_b_logits.float(),
        val_b_hat.float()
    )
    
    val_b_logits = val_b_logits.numpy()
    val_b_pred_hat = task_b.get_table("val").df[task_b.target_col].to_numpy()
    
    val_b_metrics = {
        "mae": mean_absolute_error(val_b_pred_hat, val_b_logits),
        "r2": r2_score(val_b_pred_hat, val_b_logits),
        "rmse": root_mean_squared_error(val_b_pred_hat, val_b_logits),
        "loss": val_b_loss.item()
    }
    
    eval_metric = val_a_loss + val_b_loss
    
    
    print("*"*30 + f"<Epoch: {epoch:02d}>" + "*"*30)
    print(f", Train loss: {train_loss}, Val metrics: {eval_metric}")
    print(f"Task A: Train loss: {train_loss_a}, Val metrics: {val_a_metrics}")
    print(f"Task B: Train loss: {train_loss_b}, Val metrics: {val_b_metrics}")
    
    
    if (higher_is_better and eval_metric > best_val_metric) or (
        not higher_is_better and eval_metric < best_val_metric
    ):
        best_val_metric = eval_metric
        best_epoch = epoch
        state_dict_a = copy.deepcopy(task_a_model.state_dict())
        state_dict_b = copy.deepcopy(task_b_model.state_dict())

0it [00:00, ?it/s]

20it [00:04,  4.01it/s]


******************************<Epoch: 01>******************************
, Train loss: 0.05026651620864868, Val metrics: 1.1287024021148682
Task A: Train loss: 0.5926268965005874, Val metrics: {'auroc': 0.6309490303298353, 'accuracy': 0.6072916666666667, 'precision': 0.6762452107279694, 'recall': 0.6292335115864528, 'f1score': 0.6518928901200369, 'loss': 0.6830201745033264}
Task B: Train loss: 0.5859857603907586, Val metrics: {'mae': 0.44568227610100564, 'r2': -0.13410892681006503, 'rmse': 0.5087180586627122, 'loss': 0.44568225741386414}


20it [00:05,  3.71it/s]


******************************<Epoch: 02>******************************
, Train loss: 0.042929285764694215, Val metrics: 1.1156361103057861
Task A: Train loss: 0.5540477871894837, Val metrics: {'auroc': 0.6301672184025126, 'accuracy': 0.6010416666666667, 'precision': 0.6444805194805194, 'recall': 0.7076648841354723, 'f1score': 0.6745964316057774, 'loss': 0.7034967541694641}
Task B: Train loss: 0.33491129279136655, Val metrics: {'mae': 0.41213944470760594, 'r2': -0.21680619809192425, 'rmse': 0.52693916247157, 'loss': 0.4121394157409668}


20it [00:05,  3.65it/s]


******************************<Epoch: 03>******************************
, Train loss: 0.040868809819221495, Val metrics: 1.1149402856826782
Task A: Train loss: 0.5240947663784027, Val metrics: {'auroc': 0.6306407730556337, 'accuracy': 0.6083333333333333, 'precision': 0.653910149750416, 'recall': 0.7005347593582888, 'f1score': 0.6764199655765921, 'loss': 0.7045515179634094}
Task B: Train loss: 0.28786259666085245, Val metrics: {'mae': 0.41038876944678376, 'r2': -0.39710403674847505, 'rmse': 0.5646302529221623, 'loss': 0.4103887677192688}


20it [00:05,  3.76it/s]


******************************<Epoch: 04>******************************
, Train loss: 0.039092570543289185, Val metrics: 1.110368251800537
Task A: Train loss: 0.5121762081980705, Val metrics: {'auroc': 0.6341522254834948, 'accuracy': 0.6052083333333333, 'precision': 0.6791338582677166, 'recall': 0.6149732620320856, 'f1score': 0.6454630495790459, 'loss': 0.7007052302360535}
Task B: Train loss: 0.2694582425057888, Val metrics: {'mae': 0.4096630835599049, 'r2': -0.3848955009868793, 'rmse': 0.5621578407628203, 'loss': 0.4096630811691284}


20it [00:04,  4.08it/s]


******************************<Epoch: 05>******************************
, Train loss: 0.040565529465675355, Val metrics: 1.147159218788147
Task A: Train loss: 0.49893373250961304, Val metrics: {'auroc': 0.6176671625588034, 'accuracy': 0.6104166666666667, 'precision': 0.647244094488189, 'recall': 0.732620320855615, 'f1score': 0.6872909698996655, 'loss': 0.7396120429039001}
Task B: Train loss: 0.272846706956625, Val metrics: {'mae': 0.4075471571056013, 'r2': -0.36871768351646583, 'rmse': 0.5588647393055884, 'loss': 0.4075471758842468}


20it [00:05,  3.68it/s]


******************************<Epoch: 06>******************************
, Train loss: 0.039993378520011905, Val metrics: 1.1748358011245728
Task A: Train loss: 0.49088720232248306, Val metrics: {'auroc': 0.6338529032027485, 'accuracy': 0.615625, 'precision': 0.6672473867595818, 'recall': 0.6827094474153298, 'f1score': 0.6748898678414097, 'loss': 0.7781368494033813}
Task B: Train loss: 0.27077203094959257, Val metrics: {'mae': 0.39669893150439955, 'r2': -0.27865813275431184, 'rmse': 0.5401656868087185, 'loss': 0.396698921918869}


20it [00:05,  3.66it/s]


******************************<Epoch: 07>******************************
, Train loss: 0.037548774480819704, Val metrics: 1.1010288000106812
Task A: Train loss: 0.4867525905370712, Val metrics: {'auroc': 0.6381640375448425, 'accuracy': 0.6135416666666667, 'precision': 0.6557377049180327, 'recall': 0.7130124777183601, 'f1score': 0.6831767719897524, 'loss': 0.7047721147537231}
Task B: Train loss: 0.2599672593176365, Val metrics: {'mae': 0.3962566658711485, 'r2': -0.3864025388044101, 'rmse': 0.5624636265683598, 'loss': 0.3962566554546356}


20it [00:05,  3.64it/s]


******************************<Epoch: 08>******************************
, Train loss: 0.034996092319488525, Val metrics: 1.1824977397918701
Task A: Train loss: 0.46662483662366866, Val metrics: {'auroc': 0.6308194729247362, 'accuracy': 0.6197916666666666, 'precision': 0.6585760517799353, 'recall': 0.7254901960784313, 'f1score': 0.6904156064461408, 'loss': 0.7664511203765869}
Task B: Train loss: 0.25697497352957727, Val metrics: {'mae': 0.4160467043419221, 'r2': -0.4746560564652458, 'rmse': 0.5800896804548824, 'loss': 0.416046679019928}


20it [00:05,  3.69it/s]


******************************<Epoch: 09>******************************
, Train loss: 0.035824379324913024, Val metrics: 1.1833040714263916
Task A: Train loss: 0.453975573182106, Val metrics: {'auroc': 0.6214287948034078, 'accuracy': 0.5854166666666667, 'precision': 0.6607495069033531, 'recall': 0.5971479500891266, 'f1score': 0.6273408239700374, 'loss': 0.7792614698410034}
Task B: Train loss: 0.2524088375270367, Val metrics: {'mae': 0.40404267772634433, 'r2': -0.3963246119522945, 'rmse': 0.5644727312899427, 'loss': 0.40404266119003296}


20it [00:05,  3.68it/s]


******************************<Epoch: 10>******************************
, Train loss: 0.034086483716964724, Val metrics: 1.2176690101623535
Task A: Train loss: 0.43890853971242905, Val metrics: {'auroc': 0.61968647107966, 'accuracy': 0.6010416666666667, 'precision': 0.6508474576271186, 'recall': 0.6844919786096256, 'f1score': 0.6672458731537794, 'loss': 0.8135988712310791}
Task B: Train loss: 0.2522283993661404, Val metrics: {'mae': 0.4040701948192769, 'r2': -0.3477721508128624, 'rmse': 0.5545720904110825, 'loss': 0.4040701985359192}


20it [00:05,  3.72it/s]


******************************<Epoch: 11>******************************
, Train loss: 0.03440073728561401, Val metrics: 1.281085729598999
Task A: Train loss: 0.42387436926364896, Val metrics: {'auroc': 0.6323652267924713, 'accuracy': 0.6260416666666667, 'precision': 0.6430594900849859, 'recall': 0.8092691622103387, 'f1score': 0.7166535122336227, 'loss': 0.8480035066604614}
Task B: Train loss: 0.23571831583976746, Val metrics: {'mae': 0.4330823007371085, 'r2': -0.5236658213618688, 'rmse': 0.5896504482539889, 'loss': 0.43308225274086}


20it [00:05,  3.71it/s]


******************************<Epoch: 12>******************************
, Train loss: 0.031147369742393495, Val metrics: 1.2660472393035889
Task A: Train loss: 0.41925454288721087, Val metrics: {'auroc': 0.6179352123624569, 'accuracy': 0.5989583333333334, 'precision': 0.6329305135951662, 'recall': 0.7468805704099821, 'f1score': 0.6852003270645952, 'loss': 0.8510846495628357}
Task B: Train loss: 0.23641354069113732, Val metrics: {'mae': 0.41496252860189964, 'r2': -0.49464717682916626, 'rmse': 0.5840084262442448, 'loss': 0.4149625301361084}


20it [00:05,  3.52it/s]


******************************<Epoch: 13>******************************
, Train loss: 0.03120954632759094, Val metrics: 1.272217035293579
Task A: Train loss: 0.398478202521801, Val metrics: {'auroc': 0.6066726531122817, 'accuracy': 0.5979166666666667, 'precision': 0.655417406749556, 'recall': 0.6577540106951871, 'f1score': 0.6565836298932385, 'loss': 0.8565502166748047}
Task B: Train loss: 0.23834980353713037, Val metrics: {'mae': 0.41566685359734323, 'r2': -0.45937524082106695, 'rmse': 0.5770763246806077, 'loss': 0.4156668484210968}


20it [00:05,  3.72it/s]


******************************<Epoch: 14>******************************
, Train loss: 0.03304958939552307, Val metrics: 1.3390533924102783
Task A: Train loss: 0.38001088351011275, Val metrics: {'auroc': 0.6220676468354487, 'accuracy': 0.6010416666666667, 'precision': 0.6377708978328174, 'recall': 0.7344028520499108, 'f1score': 0.6826843413421707, 'loss': 0.9324828386306763}
Task B: Train loss: 0.2341544009745121, Val metrics: {'mae': 0.40657058279273334, 'r2': -0.49332763699461646, 'rmse': 0.583750575243022, 'loss': 0.40657058358192444}


20it [00:05,  3.66it/s]


******************************<Epoch: 15>******************************
, Train loss: 0.029594209790229798, Val metrics: 1.3662042617797852
Task A: Train loss: 0.36324301213026045, Val metrics: {'auroc': 0.6203968030593418, 'accuracy': 0.6135416666666667, 'precision': 0.6334269662921348, 'recall': 0.803921568627451, 'f1score': 0.7085624509033779, 'loss': 0.9641737341880798}
Task B: Train loss: 0.2387396514415741, Val metrics: {'mae': 0.40203044950752115, 'r2': -0.4471958719281599, 'rmse': 0.574663254050361, 'loss': 0.40203046798706055}


20it [00:05,  3.68it/s]


******************************<Epoch: 16>******************************
, Train loss: 0.030931225419044493, Val metrics: 1.3480823040008545
Task A: Train loss: 0.3579610735177994, Val metrics: {'auroc': 0.62806749494056, 'accuracy': 0.6177083333333333, 'precision': 0.653968253968254, 'recall': 0.7344028520499108, 'f1score': 0.691855583543241, 'loss': 0.9401235580444336}
Task B: Train loss: 0.24827125668525696, Val metrics: {'mae': 0.40795870671471, 'r2': -0.4070242717804833, 'rmse': 0.5666313053740923, 'loss': 0.4079587161540985}


20it [00:05,  3.68it/s]


******************************<Epoch: 17>******************************
, Train loss: 0.032305216789245604, Val metrics: 1.3214051723480225
Task A: Train loss: 0.3680561125278473, Val metrics: {'auroc': 0.6225724739656627, 'accuracy': 0.5989583333333334, 'precision': 0.6433224755700325, 'recall': 0.7040998217468806, 'f1score': 0.6723404255319149, 'loss': 0.9050995111465454}
Task B: Train loss: 0.2316126123070717, Val metrics: {'mae': 0.4163056144827851, 'r2': -0.5202454263971841, 'rmse': 0.5889882392117016, 'loss': 0.41630563139915466}


20it [00:05,  3.69it/s]


******************************<Epoch: 18>******************************
, Train loss: 0.02933640480041504, Val metrics: 1.3044493198394775
Task A: Train loss: 0.37473154366016387, Val metrics: {'auroc': 0.6247838848458043, 'accuracy': 0.6041666666666666, 'precision': 0.6579406631762653, 'recall': 0.6720142602495544, 'f1score': 0.6649029982363316, 'loss': 0.8964056372642517}
Task B: Train loss: 0.22928790375590324, Val metrics: {'mae': 0.4080436767650708, 'r2': -0.509689026452276, 'rmse': 0.5869397454280564, 'loss': 0.40804368257522583}


20it [00:05,  3.68it/s]


******************************<Epoch: 19>******************************
, Train loss: 0.030465489625930785, Val metrics: 1.4638752937316895
Task A: Train loss: 0.3177923306822777, Val metrics: {'auroc': 0.6136910904712762, 'accuracy': 0.6, 'precision': 0.6512820512820513, 'recall': 0.679144385026738, 'f1score': 0.6649214659685864, 'loss': 1.0492538213729858}
Task B: Train loss: 0.2278740756213665, Val metrics: {'mae': 0.41462141720746887, 'r2': -0.4832272814058509, 'rmse': 0.5817730815311216, 'loss': 0.41462141275405884}


20it [00:05,  3.72it/s]


******************************<Epoch: 20>******************************
, Train loss: 0.02811640202999115, Val metrics: 1.4493356943130493
Task A: Train loss: 0.29670246243476867, Val metrics: {'auroc': 0.6069898453799383, 'accuracy': 0.6041666666666666, 'precision': 0.6596119929453262, 'recall': 0.6666666666666666, 'f1score': 0.6631205673758865, 'loss': 1.0379319190979004}
Task B: Train loss: 0.22655694410204888, Val metrics: {'mae': 0.4114037894582114, 'r2': -0.4728432313728941, 'rmse': 0.5797330127046392, 'loss': 0.4114038050174713}


In [28]:
best_epoch

7

In [29]:
# test the task a 
task_a_model.load_state_dict(state_dict_a)
test_logits = test(taska_loader_dict["test"], task_a_model, task_a)
test_logits = torch.sigmoid(test_logits).numpy()

test_pred = (test_logits > 0.5).astype(int)
test_pred_hat = task_a.get_table("test", mask_input_cols=False).df[task_a.target_col].to_numpy()
test_metrics = {
    "auroc": roc_auc_score(test_pred_hat, test_logits),
    "accuracy": accuracy_score(test_pred_hat, test_pred),
    "precision": precision_score(test_pred_hat, test_pred),
    "recall": recall_score(test_pred_hat, test_pred),
    "f1score": f1_score(test_pred_hat, test_pred),
}
test_metrics

{'auroc': 0.640629351155667,
 'accuracy': 0.6181818181818182,
 'precision': 0.6532846715328468,
 'recall': 0.7412008281573499,
 'f1score': 0.6944713870029098}

In [30]:
# test the task b
task_b_model.load_state_dict(state_dict_b)
logits = test(taskb_loader_dict["test"], task_b_model, task_b)
logits = logits.numpy()
pred_hat = task_b.get_table("test", mask_input_cols=False).df[task_b.target_col].to_numpy()
test_metrics = {
        "mae": mean_absolute_error(pred_hat, logits),
        "r2": r2_score(pred_hat, logits),
        "rmse": root_mean_squared_error(pred_hat, logits),
}
test_metrics

{'mae': 0.3995978656478217,
 'r2': -0.41550429739073613,
 'rmse': 0.5725317640168031}

: 