In [1]:
import tempfile
import shutil
import os
import unittest
import pytest
import logging
import warnings
import torch
from deeprankcore.trainer import Trainer
from deeprankcore.dataset import HDF5DataSet
from deeprankcore.ginet import GINet
from deeprankcore.foutnet import FoutNet
from deeprankcore.naive_gnn import NaiveNetwork
from deeprankcore.sGAT import sGAT
from deeprankcore.models.metrics import (
    OutputExporter,
    TensorboardBinaryClassificationExporter,
    ScatterPlotExporter
)
from deeprankcore.domain.features import groups, edgefeats
from deeprankcore.domain.features import nodefeats as Nfeat
from deeprankcore.domain import targettypes as targets

_log = logging.getLogger(name=__name__)

# Node features handed to `_model_base_test` as the default selection.
default_node_features = [
    Nfeat.RESTYPE,
    Nfeat.POLARITY,
    Nfeat.BSA,
    Nfeat.RESDEPTH,
    Nfeat.HSE,
    Nfeat.INFOCONTENT,
    Nfeat.PSSM,
]


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# new

def _model_base_test( # pylint: disable=too-many-arguments, too-many-locals
    train_hdf5_path,
    val_hdf5_path,
    test_hdf5_path,
    model_class,
    node_features,
    edge_features,
    task,
    target,
    transform_sigmoid,
    clustering_method,
    use_cuda = False,
    nepoch = 10,
    batch_size = 64,
    model_path = "test.pth.tar",
):
    """Train `model_class` on HDF5 data, save it, then reload it as a pretrained model.

    Args:
        train_hdf5_path: path passed to `HDF5DataSet` for the training set.
        val_hdf5_path: path for the validation set, or None to train without one.
        test_hdf5_path: path for the test set, or None to skip it.
        model_class: network class handed to `Trainer` (e.g. GINet).
        node_features, edge_features: feature selections forwarded to `Trainer`.
        task, target: task type and target name forwarded to `Trainer`.
        transform_sigmoid: forwarded to `Trainer`.
        clustering_method: forwarded to every `HDF5DataSet`.
        use_cuda: accepted for interface compatibility; not used by this function.
        nepoch: number of training epochs (default 10).
        batch_size: batch size for the first `Trainer` (default 64).
        model_path: where the trained model is saved and reloaded from
            (default "test.pth.tar", relative to the working directory).

    Returns:
        None. Side effects: writes the model checkpoint to `model_path`.
    """
    # NOTE(review): the datasets below are built without any feature selection,
    # while the Trainer is told to use only `node_features` / `edge_features`.
    # The recorded run failed with "mat1 and mat2 shapes cannot be multiplied
    # (17554x5 and 1x1)", which suggests the data carries more edge features
    # than the model was sized for. Confirm where features must be selected.

    def _load_optional(hdf5_path):
        # Validation/test sets are optional: a None path means "no dataset".
        if hdf5_path is None:
            return None
        return HDF5DataSet(hdf5_path, clustering_method=clustering_method)

    dataset_train = HDF5DataSet(train_hdf5_path, clustering_method=clustering_method)
    dataset_val = _load_optional(val_hdf5_path)
    dataset_test = _load_optional(test_hdf5_path)

    trainer = Trainer(
        model_class,
        dataset_train,
        dataset_val,
        dataset_test,
        node_features=node_features,
        edge_features=edge_features,
        target=target,
        task=task,
        batch_size=batch_size,
        transform_sigmoid=transform_sigmoid,
    )

    trainer.train(nepoch=nepoch, validate=True)
    trainer.save_model(model_path)

    # Reload the saved checkpoint. The positional order must match the Trainer
    # call above (model_class first); the previous order passed model_class in
    # the training-dataset slot.
    Trainer(
        model_class,
        dataset_train,
        dataset_val,
        dataset_test,
        pretrained_model=model_path)


# Regression run of GINet on the 1ATN complex, using the same file for
# train / validation / test.
_model_base_test(
    train_hdf5_path="./data/hdf5/1ATN_ppi.hdf5",
    val_hdf5_path="./data/hdf5/1ATN_ppi.hdf5",
    test_hdf5_path="./data/hdf5/1ATN_ppi.hdf5",
    model_class=GINet,
    node_features=default_node_features,
    edge_features=[edgefeats.DISTANCE],
    task=targets.REGRESS,
    target=targets.IRMSD,
    transform_sigmoid=True,
    clustering_method="mcl",
)
    

   ['./data/hdf5/1ATN_ppi.hdf5'] dataset                 : 100%|██████████| 1/1 [00:00<00:00, 244.41it/s, mol=1ATN_ppi.hdf5]
   ['./data/hdf5/1ATN_ppi.hdf5'] dataset                 : 100%|██████████| 1/1 [00:00<00:00, 321.16it/s, mol=1ATN_ppi.hdf5]
   ['./data/hdf5/1ATN_ppi.hdf5'] dataset                 : 100%|██████████| 1/1 [00:00<00:00, 296.63it/s, mol=1ATN_ppi.hdf5]


100%|██████████| 4/4 [00:00<00:00, 10.52it/s]
100%|██████████| 4/4 [00:00<00:00, 13.20it/s]
100%|██████████| 4/4 [00:00<00:00,  9.83it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (17554x5 and 1x1)