In [None]:
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType, ChiralType
import torch
from torch_geometric.data import Dataset, Data, DataLoader
import numpy as np
import os
import networkx as nx
import torch.nn as nn
from torch_geometric.nn import SAGEConv
from torch.utils.data import random_split
import pickle
from torch_geometric.utils import from_networkx
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, precision_score, f1_score, recall_score, jaccard_score

In [2]:
from torch.utils.tensorboard import SummaryWriter
tr_writer = SummaryWriter("./tensorboard/atombased/train")
val_writer = SummaryWriter("./tensorboard/atombased/val")

In [3]:
class Atombased(Dataset):

    def __init__(self, root, filename, test=False,transform=None, pre_transform=None, pre_filter=None):
        self.filename = filename
        self.test = test
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return self.filename

    @property
    def processed_file_names(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        if self.test:
            return [f'data_test_{i}' for i in range(len(self.mols))]
        else:
            return [f'data_{i}.pt' for i in range(len(self.mols))]

    def download(self):
        pass

    def process(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        for idx, mol in enumerate(self.mols):
            # Get node features
            node_feats = self._get_node_features(mol)
            # Get edge features
            edge_feats = self._get_edge_features(mol)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol)
            # Get labels info
            label = self._get_labels(mol)
            # create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        )
            if self.test:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_test_{idx}.pt'))
            else:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_{idx}.pt'))
        
    def _get_node_features(self, mol):
        all_node_feats = []

        identity = {
            'C':[1,0,0,0,0,0,0,0,0,0],
            'N':[0,1,0,0,0,0,0,0,0,0],
            'O':[0,0,1,0,0,0,0,0,0,0],
            'F':[0,0,0,1,0,0,0,0,0,0],
            'P':[0,0,0,0,1,0,0,0,0,0],
            'S':[0,0,0,0,0,1,0,0,0,0],
            'Cl':[0,0,0,0,0,0,1,0,0,0],
            'Br':[0,0,0,0,0,0,0,1,0,0],
            'I':[0,0,0,0,0,0,0,0,1,0],
            'other':[0,0,0,0,0,0,0,0,0,1],
        }
        for atom in mol.GetAtoms():
            node_feats = []
            # atom number
            idx = atom.GetIdx()
            # atom type one-hot 10
            node_feats.extend(identity.get(atom.GetSymbol(),[0,0,0,0,0,0,0,0,0,1]))
            # implicit valence
            node_feats.append(atom.GetImplicitValence())
            # formal charge
            node_feats.append(atom.GetFormalCharge())
            # radical electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            
            # aromatic 0 or 1
            if atom.GetIsAromatic():
                node_feats.append(1)
            else:
                node_feats.append(0)

            # chirality
            chirality = atom.GetChiralTag()
            if chirality == ChiralType.CHI_TETRAHEDRAL_CCW: temp = [1, 0, 0, 0]
            if chirality == ChiralType.CHI_TETRAHEDRAL_CW: temp = [0, 1, 0, 0]
            if chirality == ChiralType.CHI_OTHER: temp = [0, 0, 1, 0]
            if chirality == ChiralType.CHI_UNSPECIFIED: temp = [0, 0, 0, 1]
            node_feats.extend(temp)
            # hybridization
            hybridization = atom.GetHybridization()
            if hybridization == HybridizationType.S: tmp = [1, 0, 0, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP: tmp = [0, 1, 0, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP2: tmp = [0, 0, 1, 0, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP3: tmp = [0, 0, 0, 1, 0, 0, 0, 0]
            if hybridization == HybridizationType.SP3D: tmp = [0, 0, 0, 0, 1, 0, 0, 0]
            if hybridization == HybridizationType.SP3D2: tmp = [0, 0, 0, 0, 0, 1, 0, 0]
            if hybridization == HybridizationType.OTHER: tmp = [0, 0, 0, 0, 0, 0, 1, 0]
            if hybridization == HybridizationType.UNSPECIFIED: tmp = [0, 0, 0, 0, 0, 0, 0, 1]
            node_feats.extend(tmp)
            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, mol):
        _y = []
        som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
            'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
            'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
            ]
        result = []
        for k in som:
            try:
                _res = mol.GetProp(k)
                if ' ' in _res:
                    res = _res.split(' ')
                    for s in res:
                        result.append(int(s))
                else:
                    result.append(int(_res))
            except:
                pass

        for data in result:
            _y.append(data)
        _y = list(set(_y))

        y = np.zeros(len(mol.GetAtoms()))
        for i in _y:
            y[i-1] = 1
        # int64 or float32? past is float32
        return torch.tensor(y, dtype=torch.int64)

    def len(self):
        return len(self.mols)

    def get(self, idx):
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [4]:
train_dataset = Atombased('../Dataset/atom_based/', 'train.sdf')
test_dataset = Atombased('../Dataset/atom_based/', 'test.sdf', test=True)

Processing...
Done!


In [5]:
training_set, validation_set  = random_split(train_dataset, [435, 109], generator=torch.Generator().manual_seed(42))
batch_size = 1
train_loader = DataLoader(training_set, batch_size, shuffle=True)
val_loader = DataLoader(validation_set, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)



In [8]:
class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        num_classses = 2

        conv_hidden = args['conv_hidden']
        cls_hidden = args['cls_hidden']
        self.n_layers = args['n_layers']
        # cls_drop = ['cls_drop']

        self.conv_layers = nn.ModuleList([])

        self.conv1 = SAGEConv(26, conv_hidden)

        for i in range(self.n_layers):
            self.conv_layers.append(
                SAGEConv(conv_hidden, conv_hidden)
            )

        self.linear1 = nn.Linear(conv_hidden, cls_hidden)
        self.linear2 = nn.Linear(cls_hidden, num_classses)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(p=0.5)

    
    def forward(self, mol):

        res = self.conv1(mol.x, mol.edge_index)
        for i in range(self.n_layers):
            res = self.conv_layers[i](res, mol.edge_index)

        res = self.linear1(res)
        res = self.relu(res)
        res = self.drop1(res)
        res = self.linear2(res)

        return res

In [12]:
import random
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
def seed_torch(seed=42):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化，使得实验可复现
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

In [13]:
# evaluation
def top2(output, label):
    sf = nn.Softmax(dim=1)
    preds = sf(output)
    preds = preds[:, 1]
    _, indices = torch.topk(preds, 2)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)  
    for li in pos_index:
        if li in indices:
            return True
    return False

def MCC(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

def metrics(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    mcc = up / down
    selectivity = tn / (tn + fp)
    recall = tp / (tp + fn)
    g_mean = (selectivity * recall) ** 0.5
    balancedAccuracy = (recall + selectivity) / 2
    return mcc, selectivity, recall, g_mean, balancedAccuracy

In [10]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in training_set:
        mol = mol.to(device)
        mol.x = mol.x.to(torch.float32)
        target = mol.y
        optimizer.zero_grad()
        output= model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())

    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()

    mcc = MCC(all_pred, all_labels)
    tr_writer.add_scalar('Ave Loss', total_loss / len(training_set), epoch)
    tr_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    tr_writer.add_scalar('Top2', top2n / len(training_set), epoch)
    tr_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    tr_writer.add_scalar('MCC', mcc, epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / len(training_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(training_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [11]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    sf = nn.Softmax(dim=1)
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in val_set:
        mol = mol.to(device)
        mol.x = mol.x.to(torch.float32)
        target = mol.y
        optimizer.zero_grad()
        output = model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    val_writer.add_scalar('Ave Loss', total_loss / len(val_set), epoch)
    val_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    val_writer.add_scalar('Top2', top2n / len(val_set), epoch)
    val_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    val_writer.add_scalar('MCC', mcc, epoch)
    print(f'Val Epoch: {epoch}, Ave Loss: {total_loss / len(val_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(val_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')
    return top2n / len(val_set)

In [14]:
def main(args):
    seed_torch(args['seed'])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args['seed'])

    model = Model(args).to(device)
    print(model)
    weights = torch.tensor([1, args['pos_weight']], dtype=torch.float32).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    max_top2 = 0
    for epoch in range(1, args['epoch'] + 1):
        train(args, model, device, train_loader, optimizer, loss_fn, epoch)
        top2acc = val(args, model, device, val_loader, optimizer, loss_fn, epoch)
        scheduler.step()
        if top2acc > max_top2:
            max_top2 = top2acc
            print('Saving model (epoch = {:4d}, top2acc = {:.4f})'
                .format(epoch, max_top2))
            torch.save(model.state_dict(), args['save_path'])

In [15]:
args = {
    'lr': 0.02,
    'epoch': 400,
    'seed': 42,
    'save_path': './model/atom_based',
    'pos_weight': 3,
    'conv_hidden': 1024, 
    'cls_hidden': 1024,
    'n_layers': 3
}

In [None]:
main(args)

In [22]:
def test(model, device, test_set):
    model.eval()
    sf = nn.Softmax(dim=1)
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    with torch.no_grad():
        for mol in test_set:
            mol = mol.to(device)
            mol.x = mol.x.to(torch.float32)
            mol.edge_attr = mol.edge_attr.to(torch.float32)
            target = mol.y
            output = model(mol)
            # squeeze
            output = torch.squeeze(output)
            # tracking
            top2n += top2(output, target)
            all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
            all_pred_raw.append(sf(output)[:, 1].cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc, selectivity, recall, g_mean, balancedAcc = metrics(all_pred, all_labels)
    print(f'ACC: {accuracy_score(all_labels, all_pred)} \
        Top2: {top2n / len(test_set)} \
        AUC: {roc_auc_score(all_labels, all_pred_raw)}\
        MCC: {mcc} selectivity {selectivity} recall {recall} \
        g_mean {g_mean} balanced acc {balancedAcc} f1score {f1_score(all_labels, all_pred)} \
        precision score {precision_score(all_labels, all_pred)} jaccard score {jaccard_score(all_labels, all_pred)}')

In [18]:
model = Model(args).to("cuda")
model.load_state_dict(torch.load(args['save_path']))

<All keys matched successfully>

In [23]:
test(model, "cuda", test_loader)

  preds = torch.nn.functional.softmax(output)
  all_pred_raw.append(torch.nn.functional.softmax(output)[:, 1].cpu().detach().numpy())


ACC: 0.8910708806382326         Top2: 0.75         AUC: 0.9029727196889603        MCC: 0.49595583773285007 selectivity 0.9150237933378654 recall 0.668769716088328         g_mean 0.7822660688567729 balanced acc 0.7918967547130967 f1score 0.5442875481386392         precision score 0.4588744588744589 jaccard score 0.37389770723104054
