### Imports

In [1]:
import torch
import warnings
from tqdm import tqdm
from torch.autograd import Variable
from sklearn.metrics import mean_absolute_error

### Auglichem imports

In [2]:
from auglichem.crystal import (PerturbStructureTransformation,
                               RotationTransformation,
                               SwapAxesTransformation,
                               TranslateSitesTransformation,
                               SupercellTransformation,
)
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import CrystalGraphConvNet as CGCNN

### Set up dataset

In [3]:
# Create transformation
transforms = [
        PerturbStructureTransformation(distance=0.1, min_distance=0.01),
        RotationTransformation(axis=[0,0,1], angle=90),
        SwapAxesTransformation(),
        TranslateSitesTransformation(indices_to_move=[0], translation_vector=[1,0,0],
                                     vector_in_frac_coords=True),
        SupercellTransformation(scaling_matrix=[[1,0,0],[0,1,0],[0,0,1]]),
]

# Initialize dataset object
dataset = CrystalDatasetWrapper("lanthanides", batch_size=256, folds=3,
                                valid_size=0.1, test_size=0.1, cgcnn=True)


Data found at: ./data_download/lanths


### Train the model

In [4]:
def train(model, train_loader):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for epoch in range(1):
            for bn, (data, target, _) in tqdm(enumerate(train_loader)):
                optimizer.zero_grad()
                input_var = (Variable(data[0]),
                             Variable(data[1]),
                             data[2],
                             data[3])

                pred = model(*input_var)
                loss = criterion(pred, target)

                loss.backward()
                optimizer.step()
    return model

### Test the model

In [5]:
def evaluate(model, test_loader, validation=False):
    with torch.no_grad():
        model.eval()
        preds = torch.Tensor([])
        targets = torch.Tensor([])
        for data, target, _ in test_loader:
            input_var = (Variable(data[0]),
                         Variable(data[1]),
                         data[2],
                         data[3])

            pred = model(*input_var)
            
            preds = torch.cat((preds, pred))
            targets = torch.cat((targets, target))
            
        mae = mean_absolute_error(preds, targets)   
    set_str = "VALIDATION" if(validation) else "TEST"
    print("{0} MAE: {1:.3f}".format(set_str, mae))

### Initialize model with task from data

In [6]:
# Get model
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transforms, fold=0)

evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

  model = CGCNN(orig_atom_fea_len, nbr_fea_len)
100%|█████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 53804.60it/s]


VALIDATION MAE: 14.597




TEST MAE: 14.514


In [7]:
# Get model
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transforms, fold=1)

evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

  model = CGCNN(orig_atom_fea_len, nbr_fea_len)
100%|█████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 25026.41it/s]


VALIDATION MAE: 1.366




TEST MAE: 1.286


In [8]:
# Get model
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transforms, fold=2)

evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

  model = CGCNN(orig_atom_fea_len, nbr_fea_len)
100%|█████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 56328.20it/s]


VALIDATION MAE: 1.819




TEST MAE: 1.965
