### Imports

In [1]:
import os
import random

import numpy as np
import torch
from tqdm import tqdm

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')                                                                                                                                                       

from rdkit.Chem import Draw
from matplotlib import pyplot as plt

from sklearn.metrics import roc_auc_score as ras
from sklearn.metrics import mean_squared_error

### AugLiChem imports

In [2]:
from auglichem.molecule import Compose, RandomAtomMask, RandomBondDelete, MotifRemoval
from auglichem.molecule.data import MoleculeDatasetWrapper
from auglichem.molecule.models import GCN, AttentiveFP, GINE, DeepGCN

### Set up dataset

In [3]:
# Seed every RNG the pipeline may draw from so that Restart & Run All
# reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Choose the augmentation. Default: atom masking only. Set
# USE_COMPOSITE_TRANSFORM = True to chain atom masking, bond deletion and
# motif removal instead. (Previously the composite transform was built and
# then silently overwritten, so it was never used.)
USE_COMPOSITE_TRANSFORM = False
if USE_COMPOSITE_TRANSFORM:
    transform = Compose([
        RandomAtomMask([0.1, 0.3]),
        RandomBondDelete([0.1, 0.3]),
        MotifRemoval(),
    ])
else:
    transform = RandomAtomMask(0.1)

# Initialize dataset object (data files live under ./data_download)
dataset = MoleculeDatasetWrapper("ClinTox", data_path="./data_download", transform=transform, batch_size=128)

# Get train/valid/test splits as loaders.
# NOTE(review): "all" presumably selects every target column — confirm in MoleculeDatasetWrapper.
train_loader, valid_loader, test_loader = dataset.get_data_loaders("all")

Using: ./data_download/clintox.csv.gz
DATASET: ClinTox


1484it [00:00, 8222.46it/s]

Generating scaffolds...
Generating scaffold 0/1477





Generating scaffold 1000/1477
About to sort in scaffold sets




### Initialize model with task from data

In [4]:
# Build the model: output_dim is the number of label entries the dataset exposes,
# and the model is told which task type (classification/regression) it is solving.
label_names = dataset.labels.keys()
num_outputs = len(label_names)
model = AttentiveFP(output_dim=num_outputs, task=dataset.task)

# GPU: call `model.cuda()` here, and move every batch with `data.cuda()` as well.

### Set up the loss function and optimizer for training

In [5]:
# Pick the loss that matches the dataset's task type. Fail fast on anything else;
# otherwise `criterion` would be undefined (or silently reused from an earlier run).
if dataset.task == 'classification':
    criterion = torch.nn.CrossEntropyLoss()
elif dataset.task == 'regression':
    criterion = torch.nn.MSELoss()
else:
    raise ValueError(f"Unsupported task type: {dataset.task!r}")

LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

### Train the model

In [6]:
NUM_EPOCHS = 2
MISSING_LABEL = -999999999  # sentinel stored in data.y where a target has no label

task = train_loader.dataset.task
targets = train_loader.dataset.target

for epoch in range(NUM_EPOCHS):
    # An earlier evaluation may have left the model in eval mode; switch back so
    # training-only layer behaviour is active even when this cell is re-run.
    model.train()
    for data in tqdm(train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}"):
        optimizer.zero_grad()

        # Get prediction for all data.
        # For GPU use, cast the batch first: _, pred = model(data.cuda())
        _, pred = model(data)

        target_losses = []
        for idx in range(len(targets)):
            # Rows where this target has a label.
            # When the data is on GPU, bring the labels back first: data.y.cpu()
            good_idx = np.where(data.y[:, idx] != MISSING_LABEL)
            if len(good_idx[0]) == 0:
                # A mean-reduced loss over zero rows is NaN and would poison the batch.
                continue

            if task == 'classification':
                # Classification predictions hold two columns per target, side by side.
                current_preds = pred[:, 2 * idx:2 * (idx + 1)][good_idx]
            elif task == 'regression':
                current_preds = pred[:, idx][good_idx]
            else:
                raise ValueError(f"Unsupported task type: {task!r}")
            current_labels = data.y[:, idx][good_idx]

            target_losses.append(criterion(current_preds, current_labels))

        if not target_losses:
            # Nothing labelled in this batch: there is no gradient to apply.
            continue

        loss = sum(target_losses)
        loss.backward()
        optimizer.step()

10it [00:36,  3.65s/it]
10it [00:30,  3.06s/it]


### Test the model

In [7]:
def evaluate(model, test_loader, validation=False, missing_label=-999999999):
    """Score ``model`` on every target of ``test_loader`` and print one line per target.

    Classification targets are scored with ROC-AUC using the second of the two
    prediction columns belonging to each target. Regression targets are scored with
    RMSE. Rows whose label equals ``missing_label`` are ignored.

    Args:
        model: torch module whose forward pass returns a ``(_, pred)`` pair for a batch.
        test_loader: iterable of batches whose ``dataset`` attribute exposes ``target``
            (the target names) and ``task`` (``'classification'`` or ``'regression'``).
        validation: only changes the printed split name ("VALIDATION" vs "TEST").
        missing_label: sentinel marking an unlabelled entry in ``data.y``.

    Returns:
        dict mapping each target name to its score. The score is ``nan`` when it is
        undefined: no labelled rows, or (for ROC-AUC) only one class present.

    Raises:
        ValueError: if the dataset task is neither classification nor regression.
    """
    set_str = "VALIDATION" if validation else "TEST"

    # Use the task of the loader being evaluated, not a notebook-global loader.
    task = test_loader.dataset.task
    if task not in ('classification', 'regression'):
        raise ValueError(f"Unsupported task type: {task!r}")

    # All targets we're evaluating, with per-target prediction/label accumulators
    target_list = test_loader.dataset.target
    all_preds = {target: [] for target in target_list}
    all_labels = {target: [] for target in target_list}

    # Remember the caller's mode so a later training cell is not left in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in test_loader:
                # Get prediction for all data.
                # For GPU use, cast the batch first: _, pred = model(data.cuda())
                _, pred = model(data)

                for idx, target in enumerate(target_list):
                    # Rows where this target has a label
                    # (when the data is on GPU, compare against data.y.cpu()).
                    good_idx = np.where(data.y[:, idx] != missing_label)

                    if task == 'classification':
                        # Two columns per target; column 1 is the score passed to ROC-AUC.
                        current_preds = pred[:, 2 * idx:2 * (idx + 1)][good_idx][:, 1]
                    else:
                        current_preds = pred[:, idx][good_idx]
                    current_labels = data.y[:, idx][good_idx]

                    all_preds[target].extend(current_preds.cpu().numpy().tolist())
                    all_labels[target].extend(current_labels.cpu().numpy().tolist())
    finally:
        model.train(was_training)

    scores = {}
    for target in target_list:
        labels, preds = all_labels[target], all_preds[target]
        if task == 'classification':
            # roc_auc_score raises unless both classes are present; report nan instead.
            score = ras(labels, preds) if len(np.unique(labels)) > 1 else float('nan')
            print("{0} {1} ROC: {2:.5f}".format(target, set_str, score))
        else:
            # sqrt(MSE) rather than squared=False, which newer scikit-learn removed.
            score = float(np.sqrt(mean_squared_error(labels, preds))) if labels else float('nan')
            print("{0} {1} RMSE: {2:.5f}".format(target, set_str, score))
        scores[target] = score
    return scores

In [8]:
# Report scores on the validation split first, then on the held-out test split.
for split_loader, is_validation in ((valid_loader, True), (test_loader, False)):
    evaluate(model, split_loader, validation=is_validation)

CT_TOX VALIDATION ROC: 0.47011
FDA_APPROVED VALIDATION ROC: 0.30634
CT_TOX TEST ROC: 0.50000
FDA_APPROVED TEST ROC: 0.47002


### Model saving/loading example

In [9]:
# Save model
# Save the trained weights (state_dict only).
# torch.save does not create missing parent directories, so create it first.
# NOTE: the file name says "gcn" but the model here is AttentiveFP; the loading
# cell reads this same path, so keep the two in sync if renaming.
os.makedirs("./saved_models", exist_ok=True)
torch.save(model.state_dict(), "./saved_models/example_gcn")

In [10]:
# Instantiate new model and evaluate
# Baseline: a freshly initialised AttentiveFP with the same configuration,
# scored on the test split before any saved weights are loaded into it.
model = AttentiveFP(output_dim=num_outputs, task=dataset.task)
evaluate(model, test_loader)

CT_TOX TEST ROC: 0.51159
FDA_APPROVED TEST ROC: 0.47002


In [11]:
# Load saved model and evaluate
# Load the saved weights into the fresh model and evaluate again; the scores
# should match the trained model's TEST scores.
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies them into the model's own parameters.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
state_dict = torch.load("./saved_models/example_gcn", map_location="cpu")
model.load_state_dict(state_dict)
evaluate(model, test_loader)

CT_TOX TEST ROC: 0.50000
FDA_APPROVED TEST ROC: 0.47002
