### Set Path (Won't be needed once `setup.py` is finished)

In [None]:
import sys
# Make the package importable without installing it: append this notebook's
# directory (sys.path[0]) with its last 8 characters removed.
# NOTE(review): presumably this strips a trailing "examples" directory name to
# reach the repository root — confirm; it silently appends a wrong path if the
# notebook is run from a directory whose name is not exactly 8 characters.
sys.path.append(sys.path[0][:len(sys.path[0]) - 8] if len(sys.path[0]) >= 8 else "")

In [None]:
import torch
from tqdm import tqdm

from torch.autograd import Variable

### Auglichem imports

In [None]:
from auglichem.crystal import Compose, RandomRotationTransformation, SupercellTransformation
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import CrystalGraphConvNet as CGCNN

### Set up dataset

In [None]:
# Augmentations to apply to the dataset: currently a single
# SupercellTransformation with its default settings.
transform = [SupercellTransformation()]

# Wrap the "Lanthanides" dataset with a batch size of 1024 and apply the
# augmentation list above to it.
dataset = CrystalDatasetWrapper("Lanthanides", batch_size=1024)
dataset.data_augmentation(transform)

# Split into train / validation / test DataLoaders.
(
    train_loader,
    valid_loader,
    test_loader,
) = dataset.get_data_loaders()

### Initialize model with task from data

In [None]:
# Build the model with input sizes read from the first dataset sample.
# dataset[0] unpacks into three items; only the first (the model inputs) is used.
structures, _, _ = dataset[0]

# Feature widths are the size of the last axis of the first two input tensors.
# NOTE(review): presumably per-atom and per-neighbour features respectively —
# confirm against CrystalDatasetWrapper.
orig_atom_fea_len, nbr_fea_len = (
    structures[0].shape[-1],
    structures[1].shape[-1],
)

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

### Initialize training loop (loss function and optimizer)

In [None]:
# Mean-squared-error objective.
criterion = torch.nn.MSELoss()

# Adam optimizer over all model parameters, with weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

### Train the model

In [None]:
# Train for a fixed number of epochs, taking one optimizer step per mini-batch.
NUM_EPOCHS = 10

# The evaluation cell calls model.eval(); switch back explicitly so that
# re-running this cell after evaluating still trains in training mode.
model.train()
for epoch in tqdm(range(NUM_EPOCHS)):
    for data, target, _ in train_loader:
        optimizer.zero_grad()

        # The model consumes the first four entries of each batch's input.
        # Tensors are passed directly: torch.autograd.Variable is a deprecated
        # no-op wrapper since PyTorch 0.4.
        pred = model(data[0], data[1], data[2], data[3])
        # NOTE(review): MSELoss broadcasts silently if pred and target shapes
        # differ (e.g. (N, 1) vs (N,)) — confirm they match.
        loss = criterion(pred, target)

        loss.backward()
        optimizer.step()

### Test the model

In [None]:
# Evaluate on the whole test split, not just its first batch. Each batch's
# mean loss is weighted by its batch size, so the result is the per-sample
# mean over the entire test set even when the last batch is smaller.
model.eval()
total_loss = 0.0
n_samples = 0
with torch.no_grad():
    for data, target, _ in test_loader:
        # Same four model inputs as in training; no Variable wrapping needed.
        pred = model(data[0], data[1], data[2], data[3])
        batch_size = target.size(0)
        total_loss += criterion(pred, target).item() * batch_size
        n_samples += batch_size

if n_samples == 0:
    raise RuntimeError("test_loader yielded no samples; cannot compute test loss")

loss = total_loss / n_samples
print("TEST LOSS: {0:.3f}".format(loss))