### Imports

In [1]:
import torch
import os
from tqdm import tqdm

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')                                                                                                                                                       

from rdkit.Chem import Draw
from matplotlib import pyplot as plt

from sklearn.metrics import roc_auc_score as ras
from sklearn.metrics import mean_squared_error as mse

### Auglichem imports

In [2]:
from auglichem.molecule import Compose, RandomAtomMask, RandomBondDelete, MotifRemoval
from auglichem.molecule.data import MoleculeDatasetWrapper
from auglichem.molecule.models import GCN, AttentiveFP, GINE, DeepGCN

### Set up dataset

In [3]:
# Augmentations handed to Compose, in this order
# NOTE(review): the [0.1, 0.3] arguments are presumably a range for the
# masking/deletion fraction — confirm against the auglichem documentation.
augmentations = [
    RandomAtomMask([0.1, 0.3]),
    RandomBondDelete([0.1, 0.3]),
    MotifRemoval(),
]
transform = Compose(augmentations)

# FreeSolv dataset wrapper; files are looked up under ./data_download
dataset = MoleculeDatasetWrapper(
    "FreeSolv",
    data_path="./data_download",
    transform=transform,
    batch_size=128,
)

# Loaders come back in (train, valid, test) order
train_loader, valid_loader, test_loader = dataset.get_data_loaders()

Using: ./data_download/FreeSolv/SAMPL.csv
DATASET: FreeSolv


642it [00:00, 27590.43it/s]

Generating scaffolds...
Generating scaffold 0/641
About to sort in scaffold sets



  train_loader, valid_loader, test_loader = dataset.get_data_loaders()


### Initialize model with task from data

In [4]:
# Get model
# The task is read from the dataset object, so the model is built for the
# same task string that the criterion and evaluation cells branch on.
model = GCN(task=dataset.task)

# Uncomment the following line to use GPU
# (the batches must then be moved to the GPU as well: see the commented
# `model(data.cuda())` lines in the training and evaluation cells)
#model.cuda()

### Initialize training (loss function and optimizer)

In [5]:
# Pick the loss that matches the dataset's task. The training cell branches on
# the same two task strings.
if(dataset.task == 'classification'):
    criterion = torch.nn.CrossEntropyLoss()
elif(dataset.task == 'regression'):
    criterion = torch.nn.MSELoss()
else:
    # Fail here, with a clear message, instead of leaving `criterion`
    # undefined and hitting a NameError later inside the training loop.
    raise ValueError("Unsupported task: {!r}".format(dataset.task))

# Adam over all model parameters, with a small L2 penalty (weight_decay)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

### Train the model

In [6]:
# The task does not change between batches, so look it up once.
train_task = train_loader.dataset.task
if train_task not in ("classification", "regression"):
    # Without this check an unknown task would leave `loss` undefined (or
    # silently reuse a stale value from an earlier run of this cell).
    raise ValueError("Unsupported task: {!r}".format(train_task))

for epoch in range(2):
    # Re-enter training mode at the start of every epoch: evaluate() leaves
    # the model in eval mode, so re-running this cell after an evaluation
    # would otherwise train with eval-mode layer behaviour.
    model.train()

    # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
    # length and show a real progress bar with a total.
    for data in tqdm(train_loader, desc="epoch {}".format(epoch)):

        optimizer.zero_grad()

        _, pred = model(data)

        # data -> GPU
        #_, pred = model(data.cuda())

        if(train_task == "classification"):
            loss = criterion(pred, data.y.flatten())
        else:
            # Regression: the first prediction column is compared with the
            # flattened labels.
            loss = criterion(pred[:,0], data.y.flatten())

        loss.backward()
        optimizer.step()

4it [00:00,  5.47it/s]
4it [00:00,  5.73it/s]


### Test the model

In [7]:
def evaluate(model, test_loader, validation=False):
    """Run ``model`` over every batch of ``test_loader`` and print one metric.

    The task is read from ``test_loader.dataset.task``:

    * ``'classification'`` prints ROC-AUC, using column 1 of the model output
      as the score (presumably the positive-class logit; confirm against the
      model definition).
    * ``'regression'`` prints RMSE. With several target columns the RMSE is
      computed per column and then averaged.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)``; must return a 2-tuple whose second element
        is the prediction tensor. The model is switched to eval mode and is
        left in eval mode when the function returns.
    test_loader : iterable of batches
        Must expose ``.dataset.task``; every batch must expose labels as ``.y``.
    validation : bool, optional
        Only changes the printed prefix: ``"VALIDATION"`` instead of ``"TEST"``.

    Raises
    ------
    ValueError
        If the task is neither ``'classification'`` nor ``'regression'``, or
        if the loader yields no samples.
    """
    task = test_loader.dataset.task
    if task not in ('classification', 'regression'):
        # Previously an unknown task ran the whole loader and printed nothing.
        raise ValueError("Unsupported task: {!r}".format(task))
    set_str = "VALIDATION" if validation else "TEST"

    # Collect per-batch tensors and concatenate once after the loop. Growing a
    # tensor with torch.cat on every batch re-copies everything seen so far.
    batch_preds = []
    batch_labels = []
    with torch.no_grad():
        model.eval()

        for data in test_loader:
            _, pred = model(data)

            # data -> GPU
            #_, pred = model(data.cuda())

            # Hold on to all predictions and labels. Moving them to the CPU
            # keeps the final concatenation valid when the model runs on GPU.
            if(task == 'classification'):
                batch_preds.append(pred[:,1].cpu())
            else:
                batch_preds.append(pred.cpu())
            batch_labels.append(data.y.cpu())

    if not any(labels.numel() for labels in batch_labels):
        raise ValueError("Cannot evaluate: the loader produced no samples")

    all_preds = torch.cat(batch_preds)
    all_labels = torch.cat(batch_labels)

    if(task == 'classification'):
        metric = ras(all_labels, all_preds)
        print("{0} ROC: {1:.3f}".format(set_str, metric))
    else:
        # RMSE is computed directly instead of through
        # mean_squared_error(..., squared=False), because newer scikit-learn
        # releases no longer accept the `squared` argument.
        # Labels are viewed as (n_samples, n_targets) and predictions are
        # reshaped to match, so (N,) and (N, 1) inputs line up without
        # accidental broadcasting.
        labels = all_labels.reshape(all_labels.shape[0], -1).double()
        preds = all_preds.reshape(labels.shape).double()
        per_target_rmse = torch.sqrt(((preds - labels) ** 2).mean(dim=0))
        metric = per_target_rmse.mean().item()
        print("{0} RMSE: {1:.3f}".format(set_str, metric))


In [8]:
# Score the trained model on the validation split first, then the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

VALIDATION RMSE: 33.159
TEST RMSE: 37.131


### Model saving/loading example

In [9]:
# Save model
# Create the output directory if it is missing (exist_ok avoids an error on
# re-runs), then write only the state dict rather than the whole model
# object. Loading therefore needs a model instance to load the state into.
os.makedirs("./saved_models/", exist_ok=True)
torch.save(model.state_dict(), "./saved_models/example_gcn")

In [10]:
# Instantiate new model and evaluate
# No saved weights are loaded in this cell, so the score printed below comes
# from a freshly constructed GCN. Note that this rebinds `model`; the trained
# weights now exist only in the saved state dict.
# NOTE(review): in the recorded outputs this fresh model scores a lower test
# RMSE than the model trained for two epochs. That looks surprising and is
# worth investigating (e.g. learning rate, epoch count, label scaling).
model = GCN(task=dataset.task)

# For GPU, uncomment the following line
#model.cuda()

evaluate(model, test_loader)

TEST RMSE: 6.590


In [11]:
# Load saved model and evaluate
# map_location="cpu" lets the checkpoint load even if it was saved from a GPU
# model and is opened on a machine without CUDA. load_state_dict then copies
# the values into the existing parameters of `model`.
# The printed score should match the one reported for the trained model before
# it was saved.
state_dict = torch.load("./saved_models/example_gcn", map_location="cpu")
model.load_state_dict(state_dict)
evaluate(model, test_loader)

TEST RMSE: 37.131
