In [61]:
import numpy as np
import deepchem as dc
import numpy as np
from deepchem.feat import GraphData
import pandas as pd
from deepchem.feat import DMPNNFeaturizer
import deepchem as dc


In [62]:
def get_graphs(weighted_adj_matrix):
    node_features = weighted_adj_matrix.diagonal().reshape(-1, 1)
    edges = np.array(np.nonzero(weighted_adj_matrix))
    edge_weights = weighted_adj_matrix[edges.T[:, 0], edges.T[:, 1]]
    graph = GraphData(node_features=node_features, edge_index=edges, edge_features=edge_weights.reshape(-1, 1))
    return graph

In [63]:
TASKS = [
'alcoholic', 'aldehydic', 'alliaceous', 'almond', 'amber', 'animal',
'anisic', 'apple', 'apricot', 'aromatic', 'balsamic', 'banana', 'beefy',
'bergamot', 'berry', 'bitter', 'black currant', 'brandy', 'burnt',
'buttery', 'cabbage', 'camphoreous', 'caramellic', 'cedar', 'celery',
'chamomile', 'cheesy', 'cherry', 'chocolate', 'cinnamon', 'citrus', 'clean',
'clove', 'cocoa', 'coconut', 'coffee', 'cognac', 'cooked', 'cooling',
'cortex', 'coumarinic', 'creamy', 'cucumber', 'dairy', 'dry', 'earthy',
'ethereal', 'fatty', 'fermented', 'fishy', 'floral', 'fresh', 'fruit skin',
'fruity', 'garlic', 'gassy', 'geranium', 'grape', 'grapefruit', 'grassy',
'green', 'hawthorn', 'hay', 'hazelnut', 'herbal', 'honey', 'hyacinth',
'jasmin', 'juicy', 'ketonic', 'lactonic', 'lavender', 'leafy', 'leathery',
'lemon', 'lily', 'malty', 'meaty', 'medicinal', 'melon', 'metallic',
'milky', 'mint', 'muguet', 'mushroom', 'musk', 'musty', 'natural', 'nutty',
'odorless', 'oily', 'onion', 'orange', 'orangeflower', 'orris', 'ozone',
'peach', 'pear', 'phenolic', 'pine', 'pineapple', 'plum', 'popcorn',
'potato', 'powdery', 'pungent', 'radish', 'raspberry', 'ripe', 'roasted',
'rose', 'rummy', 'sandalwood', 'savory', 'sharp', 'smoky', 'soapy',
'solvent', 'sour', 'spicy', 'strawberry', 'sulfurous', 'sweaty', 'sweet',
'tea', 'terpenic', 'tobacco', 'tomato', 'tropical', 'vanilla', 'vegetable',
'vetiver', 'violet', 'warm', 'waxy', 'weedy', 'winey', 'woody'
]
print("No of tasks: ", len(TASKS))

No of tasks:  138


In [64]:
import deepchem as dc
from openpom.feat.graph_featurizer import GraphFeaturizer, GraphConvConstants
from openpom.utils.data_utils import get_class_imbalance_ratio

In [None]:
from deepchem.feat import DMPNNFeaturizer
import deepchem as dc


featurizer = DMPNNFeaturizer()

smiles_field = 'nonStereoSMILES'


loader = dc.data.CSVLoader(
    tasks=TASKS,  
    feature_field=smiles_field,
    featurizer=featurizer
)


In [67]:

dataset = loader.create_dataset(inputs=["DMPNNpruned_without hydrogen_curated_GS_LF_merged_4812_QM_cleaned.csv"])
n_tasks = len(dataset.tasks)
dataset.X
len(dataset.X)


4812

In [68]:
data = np.load("DMPNNpruned_graph_data_cleaned.npz",allow_pickle=True)
mtx_list = data['mtx']
qm_graphs = []
for mtx in mtx_list:
    qm_graphs.append(get_graphs(mtx))
QM_X = np.asarray(qm_graphs)
QM_X
len(QM_X)


4812

In [69]:
X = np.asarray([i for i in zip(dataset.X, QM_X)])
new_dataset = dc.data.NumpyDataset(X=X, y=dataset.y, n_tasks=138, ids=dataset.ids)

In [70]:
new_dataset

<NumpyDataset X.shape: (4812, 2), y.shape: (4812, 138), w.shape: (4812, 1), task_names: [  0   1   2 ... 135 136 137]>

In [74]:
new_dataset

<NumpyDataset X.shape: (4812, 2), y.shape: (4812, 138), w.shape: (4812, 1), task_names: [  0   1   2 ... 135 136 137]>

In [None]:


randomstratifiedsplitter = dc.splits.RandomStratifiedSplitter()
train_dataset, test_dataset, valid_dataset = randomstratifiedsplitter.train_valid_test_split(new_dataset, frac_train = 0.8, frac_valid = 0.1, frac_test = 0.1, seed = 1)

In [76]:
print(len(train_dataset))
print(len(valid_dataset))
print(len(test_dataset))

3839
485
488


In [77]:
dc.data.DiskDataset.from_numpy(X=train_dataset.X,
                    y= train_dataset.y,
                    w= train_dataset.w,
                    ids=train_dataset.ids,
                    tasks=TASKS,
                    data_dir="./train_split")

<DiskDataset X.shape: (3839, 2), y.shape: (3839, 138), w.shape: (3839, 1), task_names: ['alcoholic' 'aldehydic' 'alliaceous' ... 'weedy' 'winey' 'woody']>

In [78]:
dc.data.DiskDataset.from_numpy(X=valid_dataset.X,
                    y= valid_dataset.y,
                    w= valid_dataset.w,
                    ids=valid_dataset.ids,
                    tasks=TASKS,
                    data_dir="./valid_split")

<DiskDataset X.shape: (485, 2), y.shape: (485, 138), w.shape: (485, 1), ids: ['CCC(=O)C(=O)O' 'O=C(O)CCCCC(=O)O' 'O=C(O)c1ccccc1' ...
 'CC(C)C(NC(=O)CCC(N)C(=O)O)C(=O)NCC(=O)O' 'CCC1SC(CC(C)C)=NC1C'
 'O=c1[nH]cnc2c1ncn2C1OC(COP(=O)(O)O)C(O)C1O'], task_names: ['alcoholic' 'aldehydic' 'alliaceous' ... 'weedy' 'winey' 'woody']>

In [79]:
dc.data.DiskDataset.from_numpy(X=test_dataset.X,
                    y= test_dataset.y,
                    w= test_dataset.w,
                    ids=test_dataset.ids,
                    tasks=TASKS,
                    data_dir="./test_split")

<DiskDataset X.shape: (488, 2), y.shape: (488, 138), w.shape: (488, 1), ids: ['CC(O)CN' 'O=Cc1ccccc1' 'CC(=O)C(C)=O' ... 'CCCCCCCC(C)=NO'
 'CCCC=CC1OCC(O)CO1' 'CC(C)C1CCC2CC(=O)CC(C)C2(C)C1'], task_names: ['alcoholic' 'aldehydic' 'alliaceous' ... 'weedy' 'winey' 'woody']>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union, Optional, Callable, Dict

from deepchem.models.losses import Loss, L2Loss
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.optimizers import Optimizer, LearningRateSchedule

from openpom.layers.pom_ffn import CustomPositionwiseFeedForward
from openpom.utils.loss import CustomMultiLabelLoss
from openpom.utils.optimizer import get_optimizer

try:
    import dgl
    from dgl import DGLGraph
    from dgl.nn.pytorch import Set2Set
    from openpom.layers.pom_mpnn_gnn import CustomMPNNGNN
except (ImportError, ModuleNotFoundError):
    raise ImportError('This module requires dgl and dgllife')


class MPNNPOM(nn.Module):

    def __init__(self,
                 n_tasks: int,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'classification',
                 number_atom_features: int = 133,
                 number_bond_features: int = 14,
                 n_classes: int = 1,
                 nfeat_name: str = 'x',
                 efeat_name: str = 'edge_attr',
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: List = [300],
                 ffn_embeddings: int = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True):

        if mode not in ['classification', 'regression']:
            raise ValueError(
                "mode must be either 'classification' or 'regression'")

        super(MPNNPOM, self).__init__()

        self.n_tasks: int = n_tasks
        self.mode: str = mode
        self.n_classes: int = n_classes
        self.nfeat_name: str = nfeat_name
        self.efeat_name: str = efeat_name
        self.readout_type: str = readout_type
        self.ffn_embeddings: int = ffn_embeddings
        self.ffn_activation: str = ffn_activation
        self.ffn_dropout_p: float = ffn_dropout_p

        if mode == 'classification':
            self.ffn_output: int = n_tasks * n_classes
        else:
            self.ffn_output = n_tasks

        self.mpnn: nn.Module = CustomMPNNGNN(
            node_in_feats=number_atom_features,
            node_out_feats=node_out_feats,
            edge_in_feats=number_bond_features,
            edge_hidden_feats=edge_hidden_feats,
            num_step_message_passing=num_step_message_passing,
            residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type)
        
        QM_node_out_features = 50
        
        self.QM_mpnn: nn.Module = CustomMPNNGNN(
            node_in_feats=1,
            node_out_feats=QM_node_out_features,
            edge_in_feats=1,
            edge_hidden_feats=16,
            num_step_message_passing=2,
            residual=True,
            message_aggregator_type='sum')

        self.project_edge_feats: nn.Module = nn.Sequential(
            nn.Linear(number_bond_features, edge_out_feats), nn.ReLU())

        if self.readout_type == 'set2set':
            self.readout_set2set: nn.Module = Set2Set(
                input_dim=node_out_feats + edge_out_feats,
                n_iters=num_step_set2set,
                n_layers=num_layer_set2set)
            ffn_input: int = 2 * (node_out_feats + edge_out_feats)
        elif self.readout_type == 'global_sum_pooling':
            ffn_input = node_out_feats + edge_out_feats
        else:
            raise Exception("readout_type invalid")

        if ffn_embeddings is not None:
            d_hidden_list: List = ffn_hidden_list + [ffn_embeddings]

        self.ffn: nn.Module = CustomPositionwiseFeedForward(
            d_input=ffn_input + QM_node_out_features,
            
            d_hidden_list=d_hidden_list,
            d_output=self.ffn_output,
            activation=ffn_activation,
            dropout_p=ffn_dropout_p,
            dropout_at_input_no_act=ffn_dropout_at_input_no_act)

    def _readout(self, g: DGLGraph, node_encodings: torch.Tensor,
                 edge_feats: torch.Tensor) -> torch.Tensor:

        g.ndata['node_emb'] = node_encodings
        g.edata['edge_emb'] = self.project_edge_feats(edge_feats)

        def message_func(edges) -> Dict:
            """
            The message function to generate messages
            along the edges for DGLGraph.send_and_recv()
            """
            src_msg: torch.Tensor = torch.cat(
                (edges.src['node_emb'], edges.data['edge_emb']), dim=1)
            return {'src_msg': src_msg}

        def reduce_func(nodes) -> Dict:
            """
            The reduce function to aggregate the messages
            for DGLGraph.send_and_recv()
            """
            src_msg_sum: torch.Tensor = torch.sum(nodes.mailbox['src_msg'],
                                                  dim=1)
            return {'src_msg_sum': src_msg_sum}

        
        g.send_and_recv(g.edges(),
                        message_func=message_func,
                        reduce_func=reduce_func)

        if self.readout_type == 'set2set':
            batch_mol_hidden_states: torch.Tensor = self.readout_set2set(
                g, g.ndata['src_msg_sum'])
        elif self.readout_type == 'global_sum_pooling':
            batch_mol_hidden_states = dgl.sum_nodes(g, 'src_msg_sum')

        
        return batch_mol_hidden_states

    def forward(
        self, graphs: tuple[DGLGraph]
    ) -> Union[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        g, qm_g = graphs
        node_feats: torch.Tensor = g.ndata[self.nfeat_name]
        edge_feats: torch.Tensor = g.edata[self.efeat_name]

        qm_node_feats: torch.Tensor = qm_g.ndata[self.nfeat_name]
        qm_edge_feats: torch.Tensor = qm_g.edata[self.efeat_name]

        node_encodings: torch.Tensor = self.mpnn(g, node_feats, edge_feats)

        QM_encodings: torch.Tensor = self.QM_mpnn(qm_g, qm_node_feats, qm_edge_feats)
        qm_g.ndata['node_emb'] = QM_encodings
        molecular_QM_encodings = dgl.sum_nodes(qm_g, 'node_emb')
        molecular_encodings: torch.Tensor = self._readout(
            g, node_encodings, edge_feats)
        if self.readout_type == 'global_sum_pooling':
            molecular_encodings = F.softmax(molecular_encodings, dim=1)

        embeddings: torch.Tensor
        out: torch.Tensor
        embeddings, out = self.ffn(torch.concat((molecular_encodings, molecular_QM_encodings), dim=1))


        if self.mode == 'classification':
            if self.n_tasks == 1:
                logits: torch.Tensor = out.view(-1, self.n_classes)
            else:
                logits = out.view(-1, self.n_tasks, self.n_classes)
            proba: torch.Tensor = F.sigmoid(
                logits)  
            if self.n_classes == 1:
                proba = proba.squeeze(-1)  
            return proba, logits, embeddings
        else:
            return out


class MPNNPOMModel(TorchModel):

    def __init__(self,
                 n_tasks: int,
                 class_imbalance_ratio: Optional[List] = None,
                 loss_aggr_type: str = 'sum',
                 learning_rate: Union[float, LearningRateSchedule] = 0.001,
                 batch_size: int = 25,
                 node_out_feats: int = 64,
                 edge_hidden_feats: int = 128,
                 edge_out_feats: int = 64,
                 num_step_message_passing: int = 3,
                 mpnn_residual: bool = True,
                 message_aggregator_type: str = 'sum',
                 mode: str = 'regression',
                 number_atom_features: int = 133,
                 number_bond_features: int = 14,
                 n_classes: int = 1,
                 readout_type: str = 'set2set',
                 num_step_set2set: int = 6,
                 num_layer_set2set: int = 3,
                 ffn_hidden_list: List = [300],
                 ffn_embeddings: int = 256,
                 ffn_activation: str = 'relu',
                 ffn_dropout_p: float = 0.0,
                 ffn_dropout_at_input_no_act: bool = True,
                 weight_decay: float = 1e-5,
                 self_loop: bool = False,
                 optimizer_name: str = 'adam',
                 device_name: Optional[str] = None,
                 **kwargs):
        model: nn.Module = MPNNPOM(
            n_tasks=n_tasks,
            node_out_feats=node_out_feats,
            edge_hidden_feats=edge_hidden_feats,
            edge_out_feats=edge_out_feats,
            num_step_message_passing=num_step_message_passing,
            mpnn_residual=mpnn_residual,
            message_aggregator_type=message_aggregator_type,
            mode=mode,
            number_atom_features=number_atom_features,
            number_bond_features=number_bond_features,
            n_classes=n_classes,
            readout_type=readout_type,
            num_step_set2set=num_step_set2set,
            num_layer_set2set=num_layer_set2set,
            ffn_hidden_list=ffn_hidden_list,
            ffn_embeddings=ffn_embeddings,
            ffn_activation=ffn_activation,
            ffn_dropout_p=ffn_dropout_p,
            ffn_dropout_at_input_no_act=ffn_dropout_at_input_no_act)

        if class_imbalance_ratio and (len(class_imbalance_ratio) != n_tasks):
            raise Exception("size of class_imbalance_ratio \
                            should be equal to n_tasks")

        if mode == 'regression':
            loss: Loss = L2Loss()
            output_types: List = ['prediction']
        else:
            loss = CustomMultiLabelLoss(
                class_imbalance_ratio=class_imbalance_ratio,
                loss_aggr_type=loss_aggr_type,
                device=device_name)
            output_types = ['prediction', 'loss', 'embedding']

        optimizer: Optimizer = get_optimizer(optimizer_name)
        optimizer.learning_rate = learning_rate
        if device_name is not None:
            device: Optional[torch.device] = torch.device(device_name)
        else:
            device = None
        super(MPNNPOMModel, self).__init__(model,
                                           loss=loss,
                                           output_types=output_types,
                                           optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           batch_size=batch_size,
                                           device=device,
                                           **kwargs)

        self.weight_decay: float = weight_decay
        self._self_loop: bool = self_loop
        self.regularization_loss: Callable = self._regularization_loss

    def _regularization_loss(self) -> torch.Tensor:
        l1_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        l2_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
        for name, param in self.model.named_parameters():
            if 'bias' not in name:
                l1_regularization = l1_regularization + torch.norm(param, p=1)
                l2_regularization = l2_regularization + torch.norm(param, p=2)
        l1_norm: torch.Tensor = self.weight_decay * l1_regularization
        l2_norm: torch.Tensor = self.weight_decay * l2_regularization
        return l1_norm + l2_norm

    def _prepare_batch(
        self, batch: Tuple[List, List, List]
    ) -> Tuple[DGLGraph, List[torch.Tensor], List[torch.Tensor]]:
        inputs: List
        labels: List
        weights: List

        inputs, labels, weights = batch
        dgl_graphs1: List[DGLGraph] = [
            graphs[0].to_dgl_graph(self_loop=self._self_loop)
            for graphs in inputs[0]
        ]
        g1: DGLGraph = dgl.batch(dgl_graphs1).to(self.device)

        dgl_graphs2: List[DGLGraph] = [
            graphs[1].to_dgl_graph(self_loop=self._self_loop)
            for graphs in inputs[0]
        ]
        g2: DGLGraph = dgl.batch(dgl_graphs2).to(self.device)
        _, labels, weights = super(MPNNPOMModel, self)._prepare_batch(
            ([], labels, weights))
        return (g1, g2), labels, weights


In [81]:
type(train_dataset)
import os
import deepchem as dc

def convert_to_disk_dataset(train_dataset, output_dir):
    """
    Converts a deepchem.data.datasets.NumpyDataset to a deepchem.data.datasets.DiskDataset.
    
    Parameters:
    train_dataset (deepchem.data.datasets.NumpyDataset): The input NumpyDataset to convert.
    output_dir (str): The directory to save the DiskDataset.
    
    Returns:
    deepchem.data.datasets.DiskDataset: The converted DiskDataset.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    X, y, w, ids = train_dataset.X, train_dataset.y, train_dataset.w, train_dataset.ids
    
    disk_dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids, data_dir=output_dir)
    
    return disk_dataset
output_dir = './train_dataset'
train_dataset = convert_to_disk_dataset(train_dataset, output_dir)

In [82]:
train_ratios = get_class_imbalance_ratio(train_dataset)
assert len(train_ratios) == n_tasks

In [None]:
from deepchem.models.optimizers import ExponentialDecay

learning_rate = ExponentialDecay(initial_rate=0.005390333, decay_rate=0.777099289, decay_steps=764, staircase=True)


In [84]:
model = MPNNPOMModel(n_tasks = n_tasks,
                            batch_size=25,
                            learning_rate=learning_rate,
                            class_imbalance_ratio = train_ratios,
                            loss_aggr_type = 'sum',
                            node_out_feats = 25,
                            edge_hidden_feats = 75,
                            edge_out_feats = 25,
                            num_step_message_passing = 5,
                            mpnn_residual = True,
                            message_aggregator_type = 'max',
                            mode = 'classification',
                            number_atom_features = 133,
                            number_bond_features = 14,
                            n_classes = 1,
                            readout_type = 'set2set',
                            num_step_set2set = 3,
                            num_layer_set2set = 2,
                            ffn_hidden_list= [512, 512],
                            ffn_embeddings = 256,
                            ffn_activation = 'relu',
                            ffn_dropout_p = 0.154493949,
                            ffn_dropout_at_input_no_act = False,
                            weight_decay = 1.39e-6,
                            self_loop = False,
                            optimizer_name = 'adam',
                            log_frequency = 32,
                            model_dir = './DMPNN+QM_epoch_40/experiments',
                            device_name='cpu')

In [85]:
nb_epoch = 40

In [86]:
metric = dc.metrics.Metric(dc.metrics.roc_auc_score)

In [None]:

for epoch in range(1, nb_epoch+1):
    loss = model.fit(
            train_dataset,
            nb_epoch=1,
            max_checkpoints_to_keep=1,
            deterministic=False,
            restore=epoch>1)
    train_scores = model.evaluate(train_dataset, [metric])['roc_auc_score']
    valid_scores = model.evaluate(valid_dataset, [metric])['roc_auc_score']
    print(f"epoch {epoch}/{nb_epoch} ; loss = {loss}; train_scores = {train_scores}; valid_scores = {valid_scores}")
model.save_checkpoint()


epoch 1/40 ; loss = 2.900954319880559; train_scores = 0.7539474303461571; valid_scores = 0.7362672047370089


  data = torch.load(checkpoint, map_location=self.device)


epoch 2/40 ; loss = 2.698422431945801; train_scores = 0.7954052594302162; valid_scores = 0.7713126430354335


  data = torch.load(checkpoint, map_location=self.device)


epoch 3/40 ; loss = 2.739802769252232; train_scores = 0.8057719726774482; valid_scores = 0.7813898567049644


  data = torch.load(checkpoint, map_location=self.device)


epoch 4/40 ; loss = 2.669133186340332; train_scores = 0.8199852193633254; valid_scores = 0.7943622148007361


  data = torch.load(checkpoint, map_location=self.device)


epoch 5/40 ; loss = 2.2223150730133057; train_scores = 0.8423599059667038; valid_scores = 0.8097824009741043


  data = torch.load(checkpoint, map_location=self.device)


epoch 6/40 ; loss = 2.5236546652657643; train_scores = 0.854976187336123; valid_scores = 0.8252225462694773


  data = torch.load(checkpoint, map_location=self.device)


epoch 7/40 ; loss = 2.502341877330433; train_scores = 0.8599872506218045; valid_scores = 0.8286931100939381


  data = torch.load(checkpoint, map_location=self.device)


epoch 8/40 ; loss = 2.4916465282440186; train_scores = 0.8500847189676805; valid_scores = 0.8165651166154924


  data = torch.load(checkpoint, map_location=self.device)


epoch 9/40 ; loss = 2.551502227783203; train_scores = 0.865592860490926; valid_scores = 0.8332162255092019


  data = torch.load(checkpoint, map_location=self.device)


epoch 10/40 ; loss = 2.471067428588867; train_scores = 0.8649580642992507; valid_scores = 0.8314217443568149


  data = torch.load(checkpoint, map_location=self.device)


epoch 11/40 ; loss = 2.5040690104166665; train_scores = 0.8739620264857665; valid_scores = 0.838806890817829


  data = torch.load(checkpoint, map_location=self.device)


epoch 12/40 ; loss = 2.5616025924682617; train_scores = 0.869773756113923; valid_scores = 0.8381347661838139


  data = torch.load(checkpoint, map_location=self.device)


epoch 13/40 ; loss = 2.524509006076389; train_scores = 0.871300609836201; valid_scores = 0.8399960839668092


  data = torch.load(checkpoint, map_location=self.device)


epoch 14/40 ; loss = 2.4775209426879883; train_scores = 0.8747747463346388; valid_scores = 0.8419622073419


  data = torch.load(checkpoint, map_location=self.device)


epoch 15/40 ; loss = 2.433234532674154; train_scores = 0.8822471426108863; valid_scores = 0.8483204463891296


  data = torch.load(checkpoint, map_location=self.device)


epoch 16/40 ; loss = 2.4840872287750244; train_scores = 0.8868204418766342; valid_scores = 0.8512749486728954


  data = torch.load(checkpoint, map_location=self.device)


epoch 17/40 ; loss = 2.339239267202524; train_scores = 0.8856300339365939; valid_scores = 0.8490919518850322


  data = torch.load(checkpoint, map_location=self.device)


epoch 18/40 ; loss = 2.4693275451660157; train_scores = 0.8824105266361124; valid_scores = 0.8524532952773201


  data = torch.load(checkpoint, map_location=self.device)


epoch 19/40 ; loss = 2.405521665300642; train_scores = 0.8885968274412459; valid_scores = 0.8505647465205631


  data = torch.load(checkpoint, map_location=self.device)


epoch 20/40 ; loss = 2.4441003799438477; train_scores = 0.8928678952252203; valid_scores = 0.8538003927340209


  data = torch.load(checkpoint, map_location=self.device)


epoch 21/40 ; loss = 1.921987771987915; train_scores = 0.8891639314753493; valid_scores = 0.8554485368692248


  data = torch.load(checkpoint, map_location=self.device)


epoch 22/40 ; loss = 2.39697265625; train_scores = 0.8963000367469661; valid_scores = 0.856494254808292


  data = torch.load(checkpoint, map_location=self.device)


epoch 23/40 ; loss = 2.3600113608620386; train_scores = 0.8951220261775317; valid_scores = 0.8483574838524403


  data = torch.load(checkpoint, map_location=self.device)


epoch 24/40 ; loss = 2.292901039123535; train_scores = 0.8909912843816331; valid_scores = 0.851210211267118


  data = torch.load(checkpoint, map_location=self.device)


epoch 25/40 ; loss = 2.2383358001708986; train_scores = 0.8998161228657832; valid_scores = 0.8574083113884006


  data = torch.load(checkpoint, map_location=self.device)


epoch 26/40 ; loss = 1.958880066871643; train_scores = 0.9006303346752984; valid_scores = 0.8559992206499963


  data = torch.load(checkpoint, map_location=self.device)


epoch 27/40 ; loss = 2.3872111002604166; train_scores = 0.9051033684557085; valid_scores = 0.8594088279180916


  data = torch.load(checkpoint, map_location=self.device)


epoch 28/40 ; loss = 2.2936836878458657; train_scores = 0.9070418247820836; valid_scores = 0.8596510055128336


  data = torch.load(checkpoint, map_location=self.device)


epoch 29/40 ; loss = 2.3477174970838757; train_scores = 0.906792878474337; valid_scores = 0.8576269004854371


  data = torch.load(checkpoint, map_location=self.device)


epoch 30/40 ; loss = 2.3321873346964517; train_scores = 0.9077010721257328; valid_scores = 0.858845779333635


  data = torch.load(checkpoint, map_location=self.device)


epoch 31/40 ; loss = 2.0959439277648926; train_scores = 0.912403857166389; valid_scores = 0.8615374396743033


  data = torch.load(checkpoint, map_location=self.device)


epoch 32/40 ; loss = 2.3055334091186523; train_scores = 0.9129542531961137; valid_scores = 0.8611560788816967


  data = torch.load(checkpoint, map_location=self.device)


epoch 33/40 ; loss = 2.271359370304988; train_scores = 0.9102134421260287; valid_scores = 0.8623254051777638


  data = torch.load(checkpoint, map_location=self.device)


epoch 34/40 ; loss = 2.23126277923584; train_scores = 0.9137537161908621; valid_scores = 0.8626819814771939


  data = torch.load(checkpoint, map_location=self.device)


epoch 35/40 ; loss = 2.247070857456752; train_scores = 0.9125365367136867; valid_scores = 0.8615759971192855


  data = torch.load(checkpoint, map_location=self.device)


epoch 36/40 ; loss = 2.158216953277588; train_scores = 0.9179655239835488; valid_scores = 0.8619955926035753


  data = torch.load(checkpoint, map_location=self.device)


epoch 37/40 ; loss = 1.6928961277008057; train_scores = 0.9136740794580505; valid_scores = 0.8614429507229995


  data = torch.load(checkpoint, map_location=self.device)


epoch 38/40 ; loss = 2.269094467163086; train_scores = 0.9193677780366782; valid_scores = 0.8640925070397791


  data = torch.load(checkpoint, map_location=self.device)


epoch 39/40 ; loss = 2.28258601101962; train_scores = 0.9176144207082821; valid_scores = 0.86330174036341


  data = torch.load(checkpoint, map_location=self.device)


epoch 40/40 ; loss = 2.2318592071533203; train_scores = 0.9183468169740971; valid_scores = 0.863525506039406


In [None]:
model.restore(f"./DMPNN+QM_epoch_40/experiments/checkpoint2.pt")

test_scores = model.evaluate(test_dataset, [metric])['roc_auc_score']

print("test_score: ", test_scores)

  data = torch.load(checkpoint, map_location=self.device)


test_score:  0.8591745991307085
