In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dataset import MyDataset
from model import MyModel
import metrics

def dataset_collate(batch):
    """Stack a list of dataset samples into four NumPy arrays.

    Every element of ``batch`` must unpack into exactly four items, which are
    taken in order as the label, the ``dp`` entry, the ``maccs`` entry and the
    ``ecfp4`` entry.

    Returns:
        A 4-tuple ``(label, dp, maccs, ecfp4)`` of NumPy arrays. The label
        array is built with ``dtype=np.int32``; the other three use whatever
        dtype NumPy infers from the collected values.
    """
    columns = ([], [], [], [])
    for sample in batch:
        # Four-way unpacking rejects samples with the wrong number of fields.
        target, descriptors, maccs_keys, fingerprint = sample
        fields = (target, descriptors, maccs_keys, fingerprint)
        for column, value in zip(columns, fields):
            column.append(value)

    targets, dp_rows, maccs_rows, ecfp4_rows = columns
    return (
        np.array(targets, dtype=np.int32),
        np.array(dp_rows),
        np.array(maccs_rows),
        np.array(ecfp4_rows),
    )


def Pre_test(flag, model: nn.Module, test_loader, loss_function, device, show,
             threshold=0.4):
    """Evaluate ``model`` on ``test_loader`` and return metrics plus hard labels.

    Args:
        flag: Not used by this function; kept so existing call sites keep
            working.
        model: Network invoked as ``model(DP, MACCS, ECFP4)``. Its output is
            treated as logits: a sigmoid is applied here before thresholding.
        test_loader: Sized iterable yielding ``(labels, DP, MACCS, ECFP4)``
            batches that ``torch.tensor`` can convert. It must also expose a
            sized ``dataset`` attribute, which is used to normalise the loss.
        loss_function: Callable taking ``(flattened logits, flattened float
            targets)`` and returning a scalar tensor. The per-batch values are
            added up and divided by ``len(test_loader.dataset)``.
        device: Device the batch tensors are moved to before the forward pass.
        show: When falsy, the tqdm progress bar is disabled.
        threshold: Probabilities ``<= threshold`` are labelled 0, everything
            else 1. Defaults to 0.4, the value that used to be hard-coded.

    Returns:
        ``(evaluation, outputs)`` where ``evaluation`` is a dict with the keys
        ``'loss'``, ``'ACC'``, ``'get_F1'`` and ``'AUROC'`` and ``outputs`` is
        a 1-D NumPy array of the predicted 0/1 labels.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    logit_chunks = []
    target_chunks = []
    with torch.no_grad():
        for labels, DP, MACCS, ECFP4 in tqdm(test_loader, disable=not show,
                                             total=len(test_loader)):
            y = torch.tensor(labels).to(device)
            DP = torch.tensor(DP).to(device)
            MACCS = torch.tensor(MACCS).to(device)
            ECFP4 = torch.tensor(ECFP4).to(device)

            y_hat = model(DP, MACCS, ECFP4)
            total_loss += loss_function(y_hat.view(-1), y.view(-1).float()).item()
            logit_chunks.append(y_hat.cpu().numpy().reshape(-1))
            target_chunks.append(y.cpu().numpy().reshape(-1))

    if not logit_chunks:
        # np.concatenate would raise a far less helpful ValueError here.
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    targets = np.concatenate(target_chunks)
    logits = np.concatenate(logit_chunks)
    probabilities = torch.sigmoid(torch.tensor(logits))

    # Vectorised thresholding; the `<=` comparison (rather than `>`) keeps the
    # original rule that anything not `<= threshold` (NaN included) maps to 1.
    outputs = np.where((probabilities <= threshold).numpy(), 0, 1)

    total_loss /= len(test_loader.dataset)

    evaluation = {
        'loss': total_loss,
        'ACC': metrics.get_ACC(targets, outputs),
        'get_F1': metrics.get_F1(targets, outputs),
        # AUROC is computed from the probabilities, not the hard labels.
        'AUROC': metrics.get_ROC(targets, probabilities.numpy()),
    }

    return evaluation, outputs

# ---------------------------------------------------------------------------
# Evaluation script: load the saved checkpoint, score the 'test' split, print
# the metrics and write the predicted labels to ./pre_labels.csv.
# ---------------------------------------------------------------------------
SHOW_PROCESS_BAR = True
data_path = '../data_cleaned.xlsx'
seed = 31861

# Use the first GPU when one is present; otherwise run on the CPU instead of
# failing at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Input widths handed to both MyDataset and MyModel.
dp_len = 194
maccs_len = 167
ecfp4_len = 2048
num_out_dim = 32

embedding_dim = 100
batch_size = 128
n_epoch = 20
interrupt = None
save_best_epoch = 6

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.manual_seed(seed)
np.random.seed(seed)


assert 0<save_best_epoch<n_epoch

model = MyModel(hidden_dim=128, dp_len=dp_len, maccs_len=maccs_len, ecfp4_len=ecfp4_len, num_out_dim=num_out_dim)
model = model.to(device)

data_loaders = {phase_name:
                    DataLoader(MyDataset(phase_name, data_path, dp_len, maccs_len, ecfp4_len),
                            batch_size=batch_size,
                            # Pinned host memory only helps host-to-GPU copies.
                            pin_memory=device.type == 'cuda',
                            shuffle=False,
                            collate_fn=dataset_collate)
                for phase_name in ['test']}

# NOTE(review): the optimizer and scheduler are not used by anything below in
# this cell; they are kept so the module-level names remain defined.
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-2)

scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=4e-4, epochs=n_epoch,
                                        steps_per_epoch=len(data_loaders['test']))

# Sum reduction: Pre_test divides the accumulated loss by the dataset size.
loss_function = nn.BCEWithLogitsLoss(reduction='sum')

# Forward slashes resolve on both Windows and POSIX (the previous backslash
# path only worked on Windows). map_location lets a checkpoint saved on a GPU
# load on whichever device was selected above.
model.load_state_dict(torch.load('./best_model/best_model.pt', map_location=device))
flag = 1
for _p in ['test']:
    performance,output = Pre_test(flag, model, data_loaders[_p], loss_function, device, SHOW_PROCESS_BAR)
    print(f'{_p}:')
    for k, v in performance.items():
        print(f'{k}: {v}\n')
    print()

# Persist the hard 0/1 predictions as a single-column CSV.
pre_lables = pd.DataFrame(output)
pre_lables.columns = ['Pre_lables']
pre_lables.to_csv(r'./pre_labels.csv', index=False)

100%|██████████| 5/5 [00:00<00:00,  6.02it/s]

test:
loss: 0.37197920337684187

ACC: 0.896421845574388

get_F1: 0.9026548672566371

AUROC: 0.9619673551928356







* In this study, as described in the paper, we removed the physicochemical descriptor features whose values are all zero. Besides removing these features, you can also modify how dataset.py reads the data if you replace the data with another dataset.
* The deleted features are: SMR_VSA8, SlogP_VSA9, fr_SH, fr_amidine, fr_azo, fr_barbitur, fr_benzodiazepine, fr_diazo, fr_dihydropyridine, fr_isocyan, fr_lactam, fr_phos_acid, fr_phos_ester, fr_prisulfonamd