### Set import path (temporary — won't be needed once `setup.py` is finished)

In [1]:
import os
import sys

# Temporary: put the directory one level above the notebook's folder on the
# import path so `auglichem` can be imported without installing the package.
# Taking the parent directory (instead of slicing a fixed number of characters
# off the path) keeps working whatever the notebook's folder is called.
# NOTE(review): assumes sys.path[0] is the notebook's directory, as it is when
# the kernel is started from there — confirm for your launcher.
sys.path.append(os.path.dirname(os.path.abspath(sys.path[0])))

In [2]:
import torch
from tqdm import tqdm

### AugLiChem imports

In [3]:
from auglichem.molecule import Compose, RandomAtomMask, RandomBondDelete
from auglichem.molecule.data import MoleculeDataset
from auglichem.models import GCN

### Set up dataset

In [4]:
# Augmentation pipeline combining random atom masking and random bond deletion.
# NOTE(review): [0.1, 0.3] presumably bounds the fraction of atoms/bonds
# affected per molecule — confirm against the auglichem documentation.
augmentations = [
    RandomAtomMask([0.1, 0.3]),
    RandomBondDelete([0.1, 0.3]),
]
transform = Compose(augmentations)

# BACE dataset with the augmentations attached (batch_size=512).
data = MoleculeDataset("BACE", transform=transform, batch_size=512)

# Train / validation / test loaders, in that order.
loaders = data.get_data_loaders()
train_loader, valid_loader, test_loader = loaders

TRANSFORM IN MOLECULE DATASET: <auglichem.molecule._compositions.Compose object at 0x1433e09a0>
Using: ./data_download/bace.csv
1512
About to generate scaffolds
Generating scaffold 0/1512
Generating scaffold 1000/1512
About to sort in scaffold sets


### Initialize the model with the dataset's task

In [5]:
# Get model
# The dataset reports which task it poses; the GCN is configured to match.
dataset_task = data.task
model = GCN(task=dataset_task)

### Initialize training components (loss function and optimizer)

In [6]:
# Optimizer hyperparameters, named so they are easy to find and tune.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# Classification loss, plus Adam over every model parameter.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

### Train the model

In [7]:
N_EPOCHS = 10

# The evaluation cell calls model.eval(). Switch back to training mode here so
# re-running this cell after evaluation does not train with train/eval-dependent
# layers (if GCN has any) stuck in eval mode.
model.train()

for epoch in tqdm(range(N_EPOCHS)):
    # Call the loop variable `batch`, not `data`, so the dataset object from the
    # setup cell is not overwritten. Cells that read `data.task` stay re-runnable.
    for batch in train_loader:
        optimizer.zero_grad()

        # The model returns a pair; only the second element (predictions) is used.
        _, pred = model(batch)
        loss = criterion(pred, batch.y.flatten())

        loss.backward()
        optimizer.step()

100%|████████████████████████████████████████████████████████████████| 10/10 [01:35<00:00,  9.50s/it]


### Test the model

In [8]:
# Evaluate on the whole test split. Each batch is drawn once and used for both
# the forward pass and the targets. Predictions and labels therefore always
# refer to the same molecules, even if the loader shuffles or a random
# transform is re-applied on every draw.
model.eval()
total_loss = 0.0
n_samples = 0
with torch.no_grad():
    for batch in test_loader:
        _, pred = model(batch)
        targets = batch.y.flatten()
        # CrossEntropyLoss averages over the batch by default (reduction="mean").
        # Weighting by batch size makes the final value the mean over all test samples.
        batch_size = targets.size(0)
        total_loss += criterion(pred, targets).item() * batch_size
        n_samples += batch_size

if n_samples == 0:
    raise ValueError("test_loader yielded no samples")
test_loss = total_loss / n_samples

print("TEST LOSS: {0:.3f}".format(test_loss))

TEST LOSS: 0.730
