### Set Path (Won't be needed once `setup.py` is finished)

In [1]:
import sys
sys.path.insert(0, sys.path[0][:-8])

In [2]:
import torch
from tqdm import tqdm
from torch.autograd import Variable
from sklearn.metrics import mean_absolute_error

### Auglichem imports

In [3]:
from auglichem.crystal import Compose, RandomRotationTransformation, SupercellTransformation
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import CrystalGraphConvNet as CGCNN

### Set up dataset

In [4]:
help(CrystalDatasetWrapper)

Help on class CrystalDatasetWrapper in module auglichem.crystal.data._crystal_dataset:

class CrystalDatasetWrapper(CrystalDataset)
 |  CrystalDatasetWrapper(*args, **kwds)
 |  
 |  The CIFData dataset is a wrapper for a dataset where the crystal structures
 |  are stored in the form of CIF files. The dataset should have the following
 |  directory structure:
 |  
 |  root_dir
 |  ├── id_prop.csv
 |  ├── atom_init.json
 |  ├── 0.cif
 |  ├── 1.cif
 |  ├── ...
 |  
 |  id_prop.csv: a CSV file with two columns. The first column recodes a
 |  unique ID for each crystal, and the second column recodes the value of
 |  target property.
 |  
 |  atom_init.json: a JSON file that stores the initialization vector for each
 |  element.
 |  
 |  ID.cif: a CIF file that recodes the crystal structure, where ID is the
 |  unique ID for the crystal.
 |  
 |  Parameters
 |  ----------
 |  
 |  root_dir: str
 |      The path to the root directory of the dataset
 |  max_num_nbr: int
 |      The maximum nu

In [5]:
help(CrystalDatasetWrapper.__init__)

Help on function __init__ in module auglichem.crystal.data._crystal_dataset:

__init__(self, dataset, transform=None, split='random', batch_size=64, num_workers=0, valid_size=0.1, test_size=0.1, data_path=None, target=None, **kwargs)
    Wrapper Class to handle splitting dataset into train, validation, and test sets
    
    inputs:
    -------------------------
    dataset (str): One of our dataset: lanthanides, perovskites, band_gap, fermi_energy,
                                       or formation_energy
    transform (AbstractTransformation, optional): A crystal transformation
    split (str, default=random): Method of splitting data into train, validation, and
                                 test
    batch_size (int, default=64): Data batch size for train_loader
    num_workers (int, default=0): Number of worker processes for parallel data loading
    valid_size (float, optional, between [0, 1]): Fraction of data used for validation
    test_size (float, optional, between [0, 1])

In [6]:
# Create transformation
transform = [
    SupercellTransformation()
]

# Initialize dataset object
dataset = CrystalDatasetWrapper("lanthanides", batch_size=256, valid_size=0.1, test_size=0.1)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform)

100%|███████████████████████████████████████████████████████████| 3332/3332 [00:33<00:00, 100.25it/s]


### Initialize model with task from data

In [7]:
# Get model
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

### Initialize traning loop

In [8]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

### Train the model

In [9]:
for epoch in range(1):
    for bn, (data, target, _) in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        input_var = (Variable(data[0]),
                     Variable(data[1]),
                     data[2],
                     data[3])
        
        pred = model(*input_var)
        loss = criterion(pred, target)
        
        loss.backward()
        optimizer.step()

26it [02:01,  4.68s/it]


### Test the model

In [10]:
def evaluate(model, test_loader):
    with torch.no_grad():
        model.eval()
        data, target, _ = next(iter(test_loader))
        input_var = (Variable(data[0]),
                     Variable(data[1]),
                     data[2],
                     data[3])

        pred = model(*input_var)
        mae = mean_absolute_error(pred, target)   
        
    print("TEST MAE: {0:.3f}".format(loss.detach()))

In [11]:
evaluate(model, test_loader)

TEST MAE: 0.168


### Model saving/loading example

In [12]:
# Save model
torch.save(model.state_dict(), "./saved_models/example_cgcnn")

In [13]:
# Instantiate new model and evaluate
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model = CGCNN(orig_atom_fea_len, nbr_fea_len)

evaluate(model, test_loader)

TEST MAE: 0.168


In [14]:
# Load saved model and evaluate
model.load_state_dict(torch.load("./saved_models/example_cgcnn"))
evaluate(model, test_loader)

TEST MAE: 0.168
