In [23]:
from rdkit import Chem
import torch
from torch_geometric.data import Dataset, Data, DataLoader
import numpy as np
import os
import pickle
from rdkit.Chem.rdchem import HybridizationType

In [24]:
class QM9GraphDataset(Dataset):

    def __init__(self, root, filename, test=False,transform=None, pre_transform=None, pre_filter=None):
        self.filename = filename
        self.test = test
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return self.filename

    @property
    def processed_file_names(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        if self.test:
            return [f'data_test_{i}' for i in range(len(self.mols))]
        else:
            return [f'data_{i}.pt' for i in range(len(self.mols))]

    def download(self):
        # Download to `self.raw_dir`.
        pass

    def process(self):
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        for idx, mol in enumerate(self.mols):
            # Get node features
            node_feats = self._get_node_features(mol)
            # Get edge features
            edge_feats = self._get_edge_features(mol)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol)
            # Get labels info
            label = self._get_labels(mol)
            # create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        )
            if self.test:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_test_{idx}.pt'))
            else:
                torch.save(data, os.path.join(self.processed_dir, \
                f'data_{idx}.pt'))
        
    def _get_node_features(self, mol):
        all_node_feats = []
        # 我们包含更多类型，需要改这里，只留下H C N O， F 改为other
        types = {'H': [1,0,0,0,0], 'C': [0,1,0,0,0], 'N': [0,0,1,0,0], 'O': [0,0,0,1,0]}
        for atom in mol.GetAtoms():
            node_feats = []
            node_feats.extend(types.get(atom.GetSymbol(), [0,0,0,0,1]))
            node_feats.append(atom.GetAtomicNum())
            node_feats.append(atom.GetIsAromatic())
            sp = []
            sp2 = []
            sp3 = []
            hybridization = atom.GetHybridization()
            sp.append(1 if hybridization == HybridizationType.SP else 0)
            sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
            node_feats.extend(sp)
            node_feats.extend(sp2)
            node_feats.extend(sp3)
            node_feats.append(atom.GetTotalNumHs())
            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, mol):
        _y = []
        som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
            'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
            'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
            ]
        result = []
        for k in som:
            try:
                _res = mol.GetProp(k)
                if ' ' in _res:
                    res = _res.split(' ')
                    for s in res:
                        result.append(int(s))
                else:
                    result.append(int(_res))
            except:
                pass

        for data in result:
            _y.append(data)
        _y = list(set(_y))

        y = np.zeros(len(mol.GetAtoms()))
        for i in _y:
            y[i-1] = 1
        # int64 or float32? past is float32
        return torch.tensor(y, dtype=torch.int64)

    def len(self):
        return len(self.mols)

    def get(self, idx):
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [25]:
train_dataset = QM9GraphDataset('../../Dataset/qm9featuredataset/', 'train.sdf')
test_set = QM9GraphDataset('../../Dataset/qm9featuredataset/', 'test.sdf', test=True)

Processing...
Done!


In [26]:
train_dataset[0]

Data(x=[35, 11], edge_index=[2, 74], edge_attr=[74, 2], y=[35])

In [27]:
import random
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
def seed_torch(seed=42):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化，使得实验可复现
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

In [28]:
import torch
from torch_geometric.data import Dataset, Data, DataLoader
import numpy as np
import os
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, precision_score, f1_score, recall_score,jaccard_score
import networkx as nx
import torch.nn as nn
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch.utils.data import random_split
import pickle
from torch_geometric.utils import from_networkx

In [29]:
class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        num_classses = 1

        conv_hidden = args['conv_hidden']
        cls_hidden = args['cls_hidden']
        self.n_layers = args['n_layers']

        self.conv_layers = nn.ModuleList([])

        self.conv1 = SAGEConv(11, conv_hidden)

        for i in range(self.n_layers):
            self.conv_layers.append(
                SAGEConv(conv_hidden, conv_hidden)
            )

        self.linear1 = nn.Linear(conv_hidden, cls_hidden)
        self.linear2 = nn.Linear(cls_hidden, num_classses)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(p=0.5)

    def forward(self, mol):

        res = self.conv1(mol.x, mol.edge_index)
        res = self.relu(res)
        for i in range(self.n_layers):
            res = self.relu(self.conv_layers[i](res, mol.edge_index))

        # res = global_mean_pool(res, mol.batch)
        res = self.linear1(res)
        res = self.relu(res)
        res = self.drop1(res)
        res = self.linear2(res)

        return res

In [30]:
training_set, validation_set  = random_split(train_dataset, [int(len(train_dataset) * 0.8), len(train_dataset) - int(len(train_dataset) * 0.8)], generator=torch.Generator().manual_seed(42))
batch_size = 1
tune_train_loader = DataLoader(training_set, batch_size, shuffle=True)
tune_val_loader = DataLoader(validation_set, batch_size, shuffle=True)
tune_test_loader = DataLoader(test_set, batch_size, shuffle=False)



In [31]:
# evaluation
def top2(output, label):
    preds = torch.nn.functional.softmax(output)
    # print('top2 after softmax preds', preds)
    preds = preds[:, 1]
    # print('get single col', preds)
    _, indices = torch.topk(preds, 2)
    # print('indices', indices)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)
    # print(pos_index)      
    for li in pos_index:
        if li in indices:
            return True
    return False

def MCC(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    # print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

def metrics(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    mcc = up / down
    selectivity = tn / (tn + fp)
    recall = tp / (tp + fn)
    g_mean = (selectivity * recall) ** 0.5
    balancedAccuracy = (recall + selectivity) / 2
    return mcc, selectivity, recall, g_mean, balancedAccuracy

In [32]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in training_set:
        mol = mol.to(device)
        mol.edge_attr = mol.edge_attr.to(torch.float32)
        mol.x = mol.x.to(torch.float32)
        
        target = mol.y
        optimizer.zero_grad()
        output= model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(torch.nn.functional.softmax(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())

    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()

    mcc = MCC(all_pred, all_labels)
    # tr_writer.add_scalar('Ave Loss', total_loss / len(training_set), epoch)
    # tr_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    # tr_writer.add_scalar('Top2', top2n / len(training_set), epoch)
    # tr_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    # tr_writer.add_scalar('MCC', mcc, epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / len(training_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(training_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [33]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol in val_set:
        mol = mol.to(device)
        mol.edge_attr = mol.edge_attr.to(torch.float32)
        mol.x = mol.x.to(torch.float32)
        target = mol.y
        optimizer.zero_grad()
        output = model(mol)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # tracking
        top2n += top2(output, target)
        all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
        all_pred_raw.append(torch.nn.functional.softmax(output)[:, 1].cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    # val_writer.add_scalar('Ave Loss', total_loss / len(val_set), epoch)
    # val_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    # val_writer.add_scalar('Top2', top2n / len(val_set), epoch)
    # val_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    # val_writer.add_scalar('MCC', mcc, epoch)
    print(f'Val Epoch: {epoch}, Ave Loss: {total_loss / len(val_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(val_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')
    return top2n / len(val_set)

In [44]:
def test(model, device, test_set):
    model.eval()
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    with torch.no_grad():
        for mol in test_set:
            mol = mol.to(device)
            mol.x = mol.x.to(torch.float32)
            mol.edge_attr = mol.edge_attr.to(torch.float32)
            target = mol.y
            output = model(mol)
            # squeeze
            output = torch.squeeze(output)
            print(f'outpu is  {output}')
            # tracking
            top2n += top2(output, target)
            all_pred.append(np.argmax(output.cpu().detach().numpy(), axis=1))
            all_pred_raw.append(torch.nn.functional.softmax(output)[:, 1].cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc, selectivity, recall, g_mean, balancedAcc = metrics(all_pred, all_labels)
    print(all_pred)
    print(all_pred_raw)
    print(f'ACC: {accuracy_score(all_labels, all_pred)} \
        Top2: {top2n / len(test_set)} \
        AUC: {roc_auc_score(all_labels, all_pred_raw)}\
        MCC: {mcc} selectivity {selectivity} recall {recall} \
        g_mean {g_mean} balanced acc {balancedAcc} f1score {f1_score(all_labels, all_pred)} \
        precision score {precision_score(all_labels, all_pred)} jaccard score {jaccard_score(all_labels, all_pred)}')

In [45]:
def finetune(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_torch(args['seed'])
    # 创建模型，加载预训练参数
    tune_model = Model(args).to(device)
    tune_model.load_state_dict(torch.load('./model/model'))

    in_fea = tune_model.linear1.in_features
    out_fea = tune_model.linear1.out_features
    # 替换原来的卷积层，参数也变成新的，需要从新训练
    tune_model.linear1 = nn.Linear(in_fea, out_fea, bias=True).to(device)
    tune_model.linear2 = nn.Linear(out_fea, 2, bias=True).to(device)
    # freeze model 只冻结了所有的卷积层
    for name, para in tune_model.named_parameters():
        if "linear" not in name:
            para.requires_grad_(False)
    print(tune_model)

    weights = torch.tensor([1, 7], dtype=torch.float32).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.SGD(filter(lambda p:p.requires_grad, tune_model.parameters()), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    max_top2 = 0
    for epoch in range(1, args['epoch']):
        train(args, tune_model, device, tune_train_loader, optimizer, loss_fn, epoch)
        top2acc = val(args, tune_model, device, tune_val_loader, optimizer, loss_fn, epoch)
        scheduler.step()
        if top2acc > max_top2:
            max_top2 = top2acc
            print('Saving model (epoch = {:4d}, top2acc = {:.4f})'
                .format(epoch, max_top2))
            torch.save(tune_model.state_dict(), args['save_path'])
    test(tune_model, "cuda", tune_test_loader)

In [37]:
args = {
    'lr': 0.0001,
    'epoch': 200,
    'seed': 42,
    'save_path': './model/tune',
    'conv_hidden':1024,
    'cls_hidden':512,
    'n_layers':3,
    'batch_size':128
}

In [None]:
finetune(args)

In [46]:
tune_model = Model(args).to('cuda')
in_fea = tune_model.linear1.in_features
out_fea = tune_model.linear1.out_features
tune_model.linear1 = nn.Linear(in_fea, out_fea, bias=True).to('cuda')
tune_model.linear2 = nn.Linear(out_fea, 2, bias=True).to('cuda')
tune_model.load_state_dict(torch.load('./model/tune'))

<All keys matched successfully>

In [47]:
torch.set_printoptions(threshold=np.inf)
test(tune_model, "cuda", tune_test_loader)

outpu is  tensor([[-1.5487e-02, -3.2846e-01],
        [-1.6872e-02, -2.8103e-01],
        [-1.6872e-02, -2.8103e-01],
        [-4.9071e-02, -3.1584e-01],
        [-1.3949e-02, -2.7962e-01],
        [-1.3949e-02, -2.7962e-01],
        [-1.3949e-02, -2.7962e-01],
        [-1.3949e-02, -2.7962e-01],
        [-5.1122e-05, -4.1075e-01],
        [-5.1122e-05, -4.1075e-01],
        [-5.1122e-05, -4.1075e-01],
        [-1.0292e-02, -2.8427e-01],
        [-1.0292e-02, -2.8427e-01],
        [-1.0292e-02, -2.8427e-01],
        [-1.0292e-02, -2.8427e-01],
        [-1.5162e-02, -2.7124e-01],
        [-1.5162e-02, -2.7124e-01],
        [-9.6260e-03, -2.7527e-01],
        [-9.6260e-03, -2.7527e-01],
        [-2.5218e-02, -2.4701e-01],
        [-2.5218e-02, -2.4701e-01]], device='cuda:0')
outpu is  tensor([[-0.0104, -0.3038],
        [-0.0134, -0.2894],
        [-0.0110, -0.2956],
        [-0.0158, -0.2880],
        [-0.0098, -0.2974],
        [-0.0154, -0.2863],
        [-0.0142, -0.3108],
        [-

  preds = torch.nn.functional.softmax(output)
  all_pred_raw.append(torch.nn.functional.softmax(output)[:, 1].cpu().detach().numpy())


outpu is  tensor([[-0.0292, -0.2255],
        [-0.0257, -0.2324],
        [-0.0199, -0.2577],
        [-0.0306, -0.2282],
        [-0.0289, -0.2329],
        [-0.0300, -0.2202],
        [-0.0214, -0.2680],
        [-0.0309, -0.2354],
        [-0.0334, -0.2396],
        [-0.0277, -0.2395],
        [-0.0202, -0.2541],
        [-0.0244, -0.2493],
        [-0.0280, -0.2330],
        [-0.0125, -0.2885],
        [-0.0184, -0.2835],
        [-0.0189, -0.2572],
        [-0.0286, -0.2320],
        [-0.0324, -0.2366],
        [-0.0278, -0.2337],
        [-0.0274, -0.2464],
        [-0.0150, -0.2914],
        [-0.0195, -0.2706],
        [-0.0265, -0.2472],
        [-0.0248, -0.2467],
        [-0.0115, -0.2863],
        [-0.0190, -0.2756]], device='cuda:0')
outpu is  tensor([[-0.0154, -0.2680],
        [-0.0152, -0.2800],
        [-0.0152, -0.2800],
        [-0.0292, -0.2343],
        [-0.0286, -0.2207],
        [-0.0116, -0.2822],
        [-0.0118, -0.2809],
        [-0.0116, -0.2822],
        [-

  mcc = up / down
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
ACC: 0.902730899048788         Top2: 0.4411764705882353         AUC: 0.6885801628540853  
      MCC: nan     selectivity 1.0       recall 0.0         g_mean 0.0     balanced acc 0.5 f1score 0.0  
             precision score 0.0     jaccard score 0.0