In [1]:
import sys

sys.path.insert(0, '..')

In [2]:
from data.dataset import Dataset
from sklearn.metrics import mean_squared_error

In [3]:
TRN = Dataset.load_csv("ds/TRN_LARGE")
TST1 = Dataset.load_csv("ds/TST_1")
TST2 = Dataset.load_csv("ds/TST_2")

In [4]:
from data.featurization.dgl_Graph import DGL_Graph

featurizer = DGL_Graph(
        graph_type="BI_GRAPH",
        featurize_type="Canonical",
        self_loop=True
)
TRN.X = TRN.featurize(featurizer)
TST1.X = TST1.featurize(featurizer)
TST2.X = TST2.featurize(featurizer)

In [6]:
from model.dgl.PAGTN import PAGTN
import torch

MODEL = PAGTN(
        task_type="regression",
        # PAGTN Configuration
        n_tasks=1,
        node_in_feats=featurizer.get_node_feat_size(),
        node_out_feats=featurizer.get_node_feat_size(),
        node_hid_feats=200,
        edge_feats=featurizer.get_edge_feat_size(),
        depth=5,
        nheads=1,
        dropout=0.1,
        activation=torch.nn.functional.leaky_relu,
        mode="mean",
        # Abstract DGL Configuration
        lr=0.01,
        y_name="LogS exp (mol/L)",
        weight_decay=0,
        batch_size=4096
)
MODEL.fit(
        dataset=TRN,
        epochs=80,
        extra_eval_set=TST1,
        cv=5
)

print(f"TST1 : RMSE {mean_squared_error(TST1.y, MODEL.predict(TST1).cpu())}")
print(f"TST2 : RMSE {mean_squared_error(TST2.y, MODEL.predict(TST2).cpu())}")
print(f"TST1^: RMSE {mean_squared_error(TST1.y, MODEL.predict(TST1, True).cpu())}")
print(f"TST2^: RMSE {mean_squared_error(TST2.y, MODEL.predict(TST2, True).cpu())}")

[INFO] Expect to use 'DGL_Graph' to featurize SMILES
[INFO] Device cuda


[CV 0]: 100%|██████████| 80/80 [02:08<00:00,  1.60s/it, loss: 1.709, val_rmse: 1.934, extra_rmse: 1.571]   
[CV 1]: 100%|██████████| 80/80 [02:14<00:00,  1.68s/it, loss: 1.108, val_rmse: 3.533, extra_rmse: 1.988]
[CV 2]: 100%|██████████| 80/80 [02:24<00:00,  1.81s/it, loss: 0.969, val_rmse: 2.342, extra_rmse: 1.750]
[CV 3]: 100%|██████████| 80/80 [02:22<00:00,  1.78s/it, loss: 1.080, val_rmse: 0.803, extra_rmse: 1.609]
[CV 4]: 100%|██████████| 80/80 [02:23<00:00,  1.80s/it, loss: 0.875, val_rmse: 0.963, extra_rmse: 1.186]


TST1 : RMSE 1.1859362750577056
TST2 : RMSE 3.2553124059535175
TST1^: RMSE 1.3620400190203243
TST2^: RMSE 4.221227331709656
