In [23]:
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType, ChiralType
import torch
from torch_geometric.data import Dataset, Data, DataLoader
import numpy as np
import os
import networkx as nx
import torch.nn as nn
from torch_geometric.nn import SAGEConv
from torch.utils.data import random_split
import pickle
from torch_geometric.utils import from_networkx
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, precision_score, f1_score, recall_score, jaccard_score

In [24]:
from torch.utils.tensorboard import SummaryWriter
tr_writer = SummaryWriter("./tensorboard/atombased/train")
val_writer = SummaryWriter("./tensorboard/atombased/val")

In [25]:
class Atombased(Dataset):

    def __init__(self, root, filename, test=False,transform=None, pre_transform=None, pre_filter=None):
        self.filename = filename
        self.test = test
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return self.filename

    @property
    def processed_file_names(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        if self.test:
            return [f'data_test_{i}' for i in range(len(self.mols))]
        else:
            return [f'data_{i}.pt' for i in range(len(self.mols))]

    def download(self):
        pass

    def process(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        for idx, mol in enumerate(self.mols):
            # Get node features
            node_feats = self._get_node_features(mol)
            # Get edge features
            edge_feats = self._get_edge_features(mol)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol)
            # Get labels info
            label = self._get_labels(mol)
            # create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        )
            if self.test:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_test_{idx}.pt'))
            else:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_{idx}.pt'))
        
    def _get_node_features(self, mol):
        all_node_feats = []

        identity = {
            'C':[1,0,0,0,0,0,0,0,0,0],
            'N':[0,1,0,0,0,0,0,0,0,0],
            'O':[0,0,1,0,0,0,0,0,0,0],
            'F':[0,0,0,1,0,0,0,0,0,0],
            'P':[0,0,0,0,1,0,0,0,0,0],
            'S':[0,0,0,0,0,1,0,0,0,0],
            'Cl':[0,0,0,0,0,0,1,0,0,0],
            'Br':[0,0,0,0,0,0,0,1,0,0],
            'I':[0,0,0,0,0,0,0,0,1,0],
            'other':[0,0,0,0,0,0,0,0,0,1],
        }
        for atom in mol.GetAtoms():
            node_feats = []
            # atom number
            idx = atom.GetIdx()
            # atom type one-hot 10
            node_feats.extend(identity.get(atom.GetSymbol(),[0,0,0,0,0,0,0,0,0,1]))
            # implicit valence
            node_feats.append(atom.GetImplicitValence())
            # formal charge
            node_feats.append(atom.GetFormalCharge())
            # radical electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            
            # aromatic 0 or 1
            if atom.GetIsAromatic():
                node_feats.append(1)
            else:
                node_feats.append(0)

            # chirality
            chirality = atom.GetChiralTag()
            if chirality == ChiralType.CHI_TETRAHEDRAL_CCW: temp = [1, 0, 0, 0]
            if chirality == ChiralType.CHI_TETRAHEDRAL_CW: temp = [0, 1, 0, 0]
            if chirality == ChiralType.CHI_OTHER: temp = [0, 0, 1, 0]
            if chirality == ChiralType.CHI_UNSPECIFIED: temp = [0, 0, 0, 1]
            node_feats.extend(temp)
            # hybridization
            hybridization = atom.GetHybridization()
            if hybridization == HybridizationType.S: tmp = [1, 0, 0, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP: tmp = [0, 1, 0, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP2: tmp = [0, 0, 1, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP3: tmp = [0, 0, 0, 1, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP3D: tmp = [0, 0, 0, 0, 1, 0, 0, 0]
            if hybridization == HybridizationType.SP3D2: tmp = [0, 0, 0, 0, 0, 1, 0, 0]
            if hybridization == HybridizationType.OTHER: tmp = [0, 0, 0, 0, 0, 0, 1, 0]
            if hybridization == HybridizationType.UNSPECIFIED: tmp = [0, 0, 0, 0, 0, 0, 0, 1]
            node_feats.extend(tmp)
            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, mol):
        _y = []
        som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
            'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
            'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
            ]
        result = []
        for k in som:
            try:
                _res = mol.GetProp(k)
                if ' ' in _res:
                    res = _res.split(' ')
                    for s in res:
                        result.append(int(s))
                else:
                    result.append(int(_res))
            except:
                pass

        for data in result:
            _y.append(data)
        _y = list(set(_y))

        y = np.zeros(len(mol.GetAtoms()))
        for i in _y:
            y[i-1] = 1
        # int64 or float32? past is float32
        return torch.tensor(y, dtype=torch.int64)

    def len(self):
        return len(self.mols)

    def get(self, idx):
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [26]:
train_dataset = Atombased('../Dataset/atom_based/', 'train.sdf')
test_dataset = Atombased('../Dataset/atom_based/', 'test.sdf', test=True)

Processing...
Done!


In [27]:
training_set, validation_set  = random_split(train_dataset, [435, 109], generator=torch.Generator().manual_seed(42))
batch_size = 1
train_loader = DataLoader(training_set, batch_size, shuffle=True)
val_loader = DataLoader(validation_set, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)



In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch.nn.modules.batchnorm import _BatchNorm
import torch_geometric.nn as gnn
from torch import Tensor
from collections import OrderedDict

In [29]:
class NodeLevelBatchNorm(_BatchNorm):
    r"""
    Applies Batch Normalization over a batch of graph data.
    Shape:
        - Input: [batch_nodes_dim, node_feature_dim]
        - Output: [batch_nodes_dim, node_feature_dim]
    batch_nodes_dim: all nodes of a batch graph
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(NodeLevelBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        return torch.functional.F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={eps}, ' \
               'affine={affine}'.format(**self.__dict__)

In [30]:
class GraphConvBn(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = gnn.GraphConv(in_channels, out_channels)
        self.norm = NodeLevelBatchNorm(out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        data.x = F.relu(self.norm(self.conv(x, edge_index)))

        return data

In [31]:
class DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate=128, bn_size=4):
        super().__init__()
        self.conv1 = GraphConvBn(num_input_features, int(growth_rate * bn_size))
        self.conv2 = GraphConvBn(int(growth_rate * bn_size), growth_rate)

    def bn_function(self, data):
        concated_features = torch.cat(data.x, 1)
        data.x = concated_features

        data = self.conv1(data)

        return data
    
    def forward(self, data):
        if isinstance(data.x, Tensor):
            data.x = [data.x]

        data = self.bn_function(data)
        data = self.conv2(data)

        return data

In [32]:
class DenseBlock(nn.ModuleDict):
    def __init__(self, num_layers, num_input_features, growth_rate=128, bn_size=4):
        super().__init__()
        for i in range(num_layers):
            layer = DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size)
            self.add_module('layer%d' % (i + 1), layer)


    def forward(self, data):
        features = [data.x]
        for name, layer in self.items():
            data = layer(data)
            features.append(data.x)
            data.x = features

        data.x = torch.cat(data.x, 1)

        return data

In [33]:
class GraphDenseNet(nn.Module):
    def __init__(self, num_input_features, out_dim, growth_rate=128, block_config = (3, 3, 3, 3), bn_sizes=[2, 3, 4, 4]):
        super().__init__()
        self.features = nn.Sequential(OrderedDict([('conv0', GraphConvBn(num_input_features, 32))]))
        num_input_features = 32

        for i, num_layers in enumerate(block_config):
            block = DenseBlock(
                num_layers, num_input_features, growth_rate=growth_rate, bn_size=bn_sizes[i]
            )
            self.features.add_module('block%d' % (i+1), block)
            num_input_features += int(num_layers * growth_rate)

            trans = GraphConvBn(num_input_features, num_input_features // 2)
            self.features.add_module("transition%d" % (i+1), trans)
            num_input_features = num_input_features // 2

        self.classifer = nn.Linear(num_input_features, out_dim)

    def forward(self, data):
        data = self.features(data)
        # x = gnn.global_mean_pool(data.x, data.batch)
        x = self.classifer(data.x)

        return x


In [56]:
class MGraphDTA(nn.Module):
    def __init__(self, filter_num=128, out_dim=2):
        super().__init__()
        self.ligand_encoder = GraphDenseNet(num_input_features=26, out_dim=filter_num*3, block_config=[1,1,1,1,1,1], bn_sizes=[1,1,1,1,1,1])

        self.classifier = nn.Sequential(
            nn.Linear(filter_num * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            # nn.Linear(1024, 1024),
            # nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, out_dim)
        )

    def forward(self, data):
        target = data.y
        ligand_x = self.ligand_encoder(data)
        # print(f'model output after gnn {ligand_x}\n')
        x = self.classifier(ligand_x)
        # print(f'model output after linear {x}\n')
        return x

In [57]:
import random
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
def seed_torch(seed=42):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化，使得实验可复现
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

In [58]:
# evaluation
def top2(output, label):
    sf = nn.Softmax(dim=1)
    preds = sf(output)
    preds = preds[:, 1]
    _, indices = torch.topk(preds, 2)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)  
    for li in pos_index:
        if li in indices:
            return True
    return False

def MCC(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

def metrics(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    mcc = up / down
    selectivity = tn / (tn + fp)
    recall = tp / (tp + fn)
    g_mean = (selectivity * recall) ** 0.5
    balancedAccuracy = (recall + selectivity) / 2
    return mcc, selectivity, recall, g_mean, balancedAccuracy

In [59]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in training_set:
        mol = mol.to(device)
        mol.x = mol.x.to(torch.float32)
        target = mol.y
        optimizer.zero_grad()
        output= model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())

    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()

    mcc = MCC(all_pred, all_labels)
    tr_writer.add_scalar('Ave Loss', total_loss / len(training_set), epoch)
    tr_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    tr_writer.add_scalar('Top2', top2n / len(training_set), epoch)
    try:
        auc = roc_auc_score(all_labels, all_pred_raw)
        tr_writer.add_scalar('AUC', auc, epoch)
    except:
        auc = -1
    tr_writer.add_scalar('MCC', mcc, epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / len(training_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(training_set)} AUC: {auc} MCC: {mcc}')

In [60]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in val_set:
        mol = mol.to(device)
        mol.x = mol.x.to(torch.float32)
        target = mol.y
        optimizer.zero_grad()
        output = model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    val_writer.add_scalar('Ave Loss', total_loss / len(val_set), epoch)
    val_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    val_writer.add_scalar('Top2', top2n / len(val_set), epoch)
    try:
        auc = roc_auc_score(all_labels, all_pred_raw)
        tr_writer.add_scalar('AUC', auc, epoch)
    except:
        auc = -1
    val_writer.add_scalar('MCC', mcc, epoch)
    print(f'Val Epoch: {epoch}, Ave Loss: {total_loss / len(val_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(val_set)} AUC: {auc} MCC: {mcc}')
    return top2n / len(val_set)

In [61]:
def main(args):
    seed_torch(args['seed'])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args['seed'])

    model = MGraphDTA(filter_num=32, out_dim=2).to(device)
    print(model)
    weights = torch.tensor([1, args['pos_weight']], dtype=torch.float32).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    # optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    max_top2 = 0
    for epoch in range(1, args['epoch'] + 1):
        train(args, model, device, train_loader, optimizer, loss_fn, epoch)
        top2acc = val(args, model, device, val_loader, optimizer, loss_fn, epoch)
        scheduler.step()
        if top2acc > max_top2:
            max_top2 = top2acc
            print('Saving model (epoch = {:4d}, top2acc = {:.4f})'
                .format(epoch, max_top2))
            torch.save(model.state_dict(), args['save_path'])

In [62]:
args = {
    'lr': 0.0005,
    'epoch': 35,
    'seed': 42,
    'save_path': './model/atom_based_block6_bn1',
    'pos_weight': 3,
    'conv_hidden': 1024, 
    'cls_hidden': 1024,
    'n_layers': 3
}

In [64]:
# 每个block 333 就会梯度消失
main(args)

MGraphDTA(
  (ligand_encoder): GraphDenseNet(
    (features): Sequential(
      (conv0): GraphConvBn(
        (conv): GraphConv(26, 32)
        (norm): NodeLevelBatchNorm(num_features=32, eps=1e-05, affine=True)
      )
      (block1): DenseBlock(
        (layer1): DenseLayer(
          (conv1): GraphConvBn(
            (conv): GraphConv(32, 128)
            (norm): NodeLevelBatchNorm(num_features=128, eps=1e-05, affine=True)
          )
          (conv2): GraphConvBn(
            (conv): GraphConv(128, 128)
            (norm): NodeLevelBatchNorm(num_features=128, eps=1e-05, affine=True)
          )
        )
      )
      (transition1): GraphConvBn(
        (conv): GraphConv(160, 80)
        (norm): NodeLevelBatchNorm(num_features=80, eps=1e-05, affine=True)
      )
      (block2): DenseBlock(
        (layer1): DenseLayer(
          (conv1): GraphConvBn(
            (conv): GraphConv(80, 128)
            (norm): NodeLevelBatchNorm(num_features=128, eps=1e-05, affine=True)
          )


  return up / down


Val Epoch: 1, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan
Saving model (epoch =    1, top2acc = 0.1651)


  return up / down


Train Epoch: 2, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 2, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 3, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 3, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 4, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 4, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 5, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 5, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 6, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 6, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 7, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 7, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 8, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 8, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 9, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 9, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 10, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


  return up / down


Val Epoch: 10, Ave Loss: nan ACC: 0.8689862542955327 Top2: 0.1651376146788991 AUC: -1 MCC: nan


  return up / down


Train Epoch: 11, Ave Loss: nan ACC: 0.8925841768229704 Top2: 0.15632183908045977 AUC: -1 MCC: nan


KeyboardInterrupt: 

In [53]:
def test(model, device, test_set):
    model.eval()
    sf = nn.Softmax(dim=1)
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    with torch.no_grad():
        for mol in test_set:
            mol = mol.to(device)
            mol.x = mol.x.to(torch.float32)
            mol.edge_attr = mol.edge_attr.to(torch.float32)
            target = mol.y
            output = model(mol)
            # squeeze
            output = torch.squeeze(output)
            # tracking
            top2n += top2(output, target)
            all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
            all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc, selectivity, recall, g_mean, balancedAcc = metrics(all_pred, all_labels)
    print(f'ACC: {accuracy_score(all_labels, all_pred)} \
        Top2: {top2n / len(test_set)} \
        AUC: {roc_auc_score(all_labels, all_pred_raw)}\
        MCC: {mcc} selectivity {selectivity} recall {recall} \
        g_mean {g_mean} balanced acc {balancedAcc} f1score {f1_score(all_labels, all_pred)} \
        precision score {precision_score(all_labels, all_pred)} jaccard score {jaccard_score(all_labels, all_pred)}')

In [54]:
model = MGraphDTA(filter_num=32, out_dim=2).to("cuda")
model.load_state_dict(torch.load(args['save_path']))

<All keys matched successfully>

In [55]:
test(model, "cuda", test_loader)

ACC: 0.8824792881251918         Top2: 0.6911764705882353         AUC: 0.7910855938255269        MCC: 0.3663908148910571 selectivity 0.9282800815771584 recall 0.45741324921135645         g_mean 0.6516192203214934 balanced acc 0.6928466653942574 f1score 0.43090638930163444         precision score 0.40730337078651685 jaccard score 0.2746212121212121
