In [1]:
import sys

sys.path.insert(0, '..')

In [2]:
from data.dataset import Dataset
from sklearn.metrics import mean_squared_error

In [3]:
TRN = Dataset.load_csv("ds/TRN_LARGE")
TST1 = Dataset.load_csv("ds/TST_1")
TST2 = Dataset.load_csv("ds/TST_2")

In [4]:
from data.featurization.dgl_Graph import DGL_Graph

featurizer = DGL_Graph(
        graph_type="BI_GRAPH",
        featurize_type="Canonical",
        self_loop=True
)
TRN.X = TRN.featurize(featurizer)
TST1.X = TST1.featurize(featurizer)
TST2.X = TST2.featurize(featurizer)

In [5]:
from model.dgl.Weave import Weave
import torch

MODEL = Weave(
        task_type="regression",
        # Weave Configuration
        n_tasks=1,
        node_in_feats=featurizer.get_node_feat_size(),
        edge_in_feats=featurizer.get_edge_feat_size(),
        num_gnn_layers=3,
        gnn_hidden_feats=50,
        gnn_activation=torch.nn.functional.relu,
        graph_feats=50,
        gaussian_expand=True,
        readout_activation=torch.nn.Tanh(),
        # Abstract DGL Configuration
        lr=0.01,
        y_name="LogS exp (mol/L)",
        weight_decay=0,
        batch_size=4096
)
MODEL.fit(
        dataset=TRN,
        epochs=80,
        extra_eval_set=TST1,
        cv=5
)

print(f"TST1 : RMSE {mean_squared_error(TST1.y, MODEL.predict(TST1).cpu())}")
print(f"TST2 : RMSE {mean_squared_error(TST2.y, MODEL.predict(TST2).cpu())}")
print(f"TST1^: RMSE {mean_squared_error(TST1.y, MODEL.predict(TST1, True).cpu())}")
print(f"TST2^: RMSE {mean_squared_error(TST2.y, MODEL.predict(TST2, True).cpu())}")

[INFO] Expect to use 'DGL_Graph' to featurize SMILES
[INFO] Device cuda


[CV 0]: 100%|██████████| 80/80 [01:48<00:00,  1.36s/it, loss: 4.341, val_rmse: 3.558, extra_rmse: 3.038]
[CV 1]: 100%|██████████| 80/80 [01:42<00:00,  1.28s/it, loss: 3.933, val_rmse: 6.939, extra_rmse: 2.700]
[CV 2]: 100%|██████████| 80/80 [01:36<00:00,  1.20s/it, loss: 3.887, val_rmse: 8.076, extra_rmse: 2.708]
[CV 3]: 100%|██████████| 80/80 [01:48<00:00,  1.35s/it, loss: 4.176, val_rmse: 5.393, extra_rmse: 2.949]
[CV 4]: 100%|██████████| 80/80 [01:46<00:00,  1.33s/it, loss: 4.295, val_rmse: 3.867, extra_rmse: 2.732]


TST1 : RMSE 2.7323691218529773
TST2 : RMSE 9.772260662464246
TST1^: RMSE 2.585474830674696
TST2^: RMSE 9.45174161053901
