### Set Path (Won't be needed once `setup.py` is finished)

In [1]:
import sys
# Put the path obtained by dropping the last 8 characters of sys.path[0] at
# the front of the import search path.
# NOTE(review): presumably this strips a trailing 8-character folder name so
# that the repository root (and therefore `auglichem`) is importable - confirm.
sys.path.insert(0, sys.path[0][:-8])

In [2]:
import os
import warnings

import torch
from sklearn.metrics import mean_absolute_error
from torch.autograd import Variable
from tqdm import tqdm

### Auglichem imports

In [3]:
from auglichem.crystal import Compose, RotationTransformation, SupercellTransformation
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import CrystalGraphConvNet as CGCNN

### Set up dataset

In [4]:
# Uncomment to print the documentation of CrystalDatasetWrapper
#help(CrystalDatasetWrapper)

In [5]:
# Uncomment to print the documentation of the CrystalDatasetWrapper constructor
#help(CrystalDatasetWrapper.__init__)

In [6]:
# Transformations handed to the data loaders below
transform = [SupercellTransformation()]

# Dataset wrapper for the "lanthanides" data set
dataset = CrystalDatasetWrapper(
    "lanthanides",
    batch_size=256,
    valid_size=0.1,
    test_size=0.1,
    cgcnn=True,
)

# One loader per split, in train / valid / test order
(train_loader,
 valid_loader,
 test_loader) = dataset.get_data_loaders(transform=transform)

100%|████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 236959.90it/s]


### Initialize model with feature sizes from the data

In [7]:
# Read the two input feature sizes off the first dataset entry
structures, _, _ = dataset[0]
orig_atom_fea_len, nbr_fea_len = structures[0].shape[-1], structures[1].shape[-1]

# Build the network with those sizes
model = CGCNN(orig_atom_fea_len, nbr_fea_len)

# Uncomment to move the model to a GPU
#model.cuda()

  model = CGCNN(orig_atom_fea_len, nbr_fea_len)


### Initialize training loop

In [8]:
# Mean-squared-error loss
criterion = torch.nn.MSELoss()

# Adam over all model parameters
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

### Train the model

In [9]:
# Train for one pass over the training data. Warnings raised inside the loop
# are silenced so they do not interleave with the progress bar.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Put the model in training mode explicitly so that re-running this cell
    # after the model was switched to eval mode still trains correctly.
    model.train()

    for epoch in range(1):
        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length and show a real progress bar instead of a bare counter.
        for data, target, _ in tqdm(train_loader):
            optimizer.zero_grad()

            # GPU: call model.cuda() beforehand and move data[0], data[1],
            # data[2] and target to the device with .cuda().
            pred = model(data[0], data[1], data[2], data[3])
            loss = criterion(pred, target)

            loss.backward()
            optimizer.step()

27it [02:15,  5.02s/it]


### Test the model

In [10]:
def evaluate(model, test_loader, validation=False):
    """Print the mean absolute error of ``model`` over all batches of a loader.

    Args:
        model: Network invoked as ``model(data[0], data[1], data[2], data[3])``.
        test_loader: Iterable yielding ``(data, target, _)`` triples; the third
            element of each triple is ignored.
        validation: If True the printed line is labelled "VALIDATION",
            otherwise "TEST".

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # Remember the current mode so that evaluating between training steps does
    # not leave the model stuck in eval mode afterwards.
    was_training = model.training
    model.eval()

    preds, targets = [], []
    try:
        with torch.no_grad():
            for data, target, _ in test_loader:
                # GPU: move data[0], data[1] and data[2] to the device with
                # .cuda() before calling the model.
                pred = model(data[0], data[1], data[2], data[3])

                preds.append(pred.cpu().detach())
                targets.append(target)
    finally:
        model.train(was_training)

    if not preds:
        raise ValueError("test_loader produced no batches to evaluate")

    # Concatenate once at the end instead of re-allocating on every batch.
    mae = mean_absolute_error(torch.cat(targets), torch.cat(preds))
    set_str = "VALIDATION" if validation else "TEST"
    print("{0} MAE: {1:.3f}".format(set_str, mae))

In [11]:
# Report the error on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader, validation=False)



VALIDATION MAE: 0.878




TEST MAE: 0.897


### Model saving/loading example

In [12]:
# Save model
# torch.save does not create missing parent directories, so make sure the
# output folder exists first (a fresh checkout may not contain it).
os.makedirs("./saved_models", exist_ok=True)
torch.save(model.state_dict(), "./saved_models/example_cgcnn")

In [13]:
# Build a fresh, untrained model with the same feature sizes as before
structures, _, _ = dataset[0]
orig_atom_fea_len, nbr_fea_len = structures[0].shape[-1], structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

# Evaluate it on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader, validation=False)

VALIDATION MAE: 20.384


  model = CGCNN(orig_atom_fea_len, nbr_fea_len)


TEST MAE: 20.446


In [14]:
# Restore the saved weights into the model, then evaluate on both splits
model.load_state_dict(
    torch.load("./saved_models/example_cgcnn")
)
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader, validation=False)

VALIDATION MAE: 0.878
TEST MAE: 0.897
