### Set Path (Won't be needed once `setup.py` is finished)

In [1]:
import sys
sys.path.append(sys.path[0][:-8])

In [2]:
import torch
from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem import Draw
from matplotlib import pyplot as plt

### Auglichem imports

In [3]:
from auglichem.molecule import Compose, RandomAtomMask, RandomBondDelete
from auglichem.molecule.data import MoleculeDatasetWrapper
from auglichem.molecule.models import GCN

### Set up dataset

In [None]:
# Create transformation
transform = Compose([
    RandomAtomMask([0.1, 0.3]),
    RandomBondDelete([0.1, 0.3])
])

# Initialize dataset object
dataset = MoleculeDatasetWrapper("MUV", transform=transform, batch_size=512)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders()

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/muv.csv.gz to ./data_download/muv.csv.gz


1756160it [00:00, 3188462.85it/s]                                                                    


DATASET: MUV


74117it [01:48, 671.60it/s]

### Initialize model with task from data

In [5]:
# Get model
model = GCN(task=dataset.task)

### Initialize traning loop

In [6]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

### Train the model

In [7]:
for epoch in tqdm(range(10)):
    for bn, data in enumerate(train_loader):
        optimizer.zero_grad()
        
        _, pred = model(data)
        loss = criterion(pred, data.y.flatten())
        
        loss.backward()
        optimizer.step()

100%|████████████████████████████████████████████████████████████████| 10/10 [01:44<00:00, 10.44s/it]


### Test the model

In [24]:
def evaluate(model, test_loader):
    with torch.no_grad():
        model.eval()
        data = next(iter(test_loader))
        _, pred = model(next(iter(test_loader)))
        loss = criterion(pred, data.y.flatten())

    print("TEST LOSS: {0:.3f}".format(loss.detach()))

In [25]:
evaluate(model, test_loader)

TEST LOSS: 2.098


### Model saving/loading example

In [27]:
# Save model
torch.save(model.state_dict(), "./saved_models/example_gcn")

In [41]:
# Instantiate new model and evaluate
model = GCN(task=dataset.task)
evaluate(model, test_loader)
data = next(iter(test_loader))
_, pred = model(next(iter(test_loader)))
ras(data.y, pred.detach()[:,1])

TEST LOSS: 4.111


0.36557971014492757

In [42]:
# Load saved model and evaluate
model.load_state_dict(torch.load("./saved_models/example_gcn"))
evaluate(model, test_loader)

TEST LOSS: 2.098


In [43]:
from sklearn.metrics import roc_auc_score as ras

In [44]:
data = next(iter(test_loader))
_, pred = model(next(iter(test_loader)))
ras(data.y, pred.detach()[:,1])

0.39221014492753625

In [45]:
import deepchem

ModuleNotFoundError: No module named 'deepchem'