### Imports

In [1]:
import torch
import warnings
from tqdm import tqdm
from torch.autograd import Variable
from sklearn.metrics import mean_absolute_error

### Auglichem imports

In [2]:
from auglichem.crystal import (PerturbStructureTransformation,
                               RotationTransformation,
                               SwapAxesTransformation,
                               TranslateSitesTransformation,
                               SupercellTransformation,
)
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import GINet, GCN

### Set up dataset

In [3]:
# Augmentations handed to the data loaders below
transforms = [
    PerturbStructureTransformation(distance=0.1, min_distance=0.01),
    RotationTransformation(axis=[0, 0, 1], angle=90),
    SwapAxesTransformation(),
    TranslateSitesTransformation(
        indices_to_move=[0],
        translation_vector=[1, 0, 0],
        vector_in_frac_coords=True,
    ),
    SupercellTransformation(
        scaling_matrix=[
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ],
    ),
]

# Wrap the "lanthanides" dataset; 10% is held out for validation and 10% for testing
dataset = CrystalDatasetWrapper(
    "lanthanides",
    batch_size=128,
    valid_size=0.1,
    test_size=0.1,
)

# One loader per split, each built with the transformations defined above
train_loader, valid_loader, test_loader = dataset.get_data_loaders(
    transform=transforms,
)

Downloading data to: ./data_download/lanths...


99it [00:00, 253.36it/s]


Extracting zipfile...
Removing zipfile...


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 3332/3332 [01:20<00:00, 41.35it/s]


### Initialize model with task from data

In [4]:
# Build the network (GCN and GINet are interchangeable in use cases)
model = GINet()

# For GPU execution, uncomment the line below
#model.cuda()

### Initialize training loop

In [5]:
# Mean-squared-error loss between predictions and targets
criterion = torch.nn.MSELoss()

# Adam over every model parameter, with weight decay as regularization
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

### Train the model

In [6]:
with warnings.catch_warnings():
    # NOTE(review): silencing all warnings also hides the warning MSELoss emits
    # when `pred` and `data.y` have different shapes -- confirm the shapes agree.
    warnings.simplefilter("ignore")

    # Put the model in training mode explicitly: evaluate() switches it to eval
    # mode, so re-running this cell after an evaluation would otherwise train
    # with eval-mode layer behavior.
    model.train()

    for epoch in range(1):  # one pass over the training data
        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length, when it defines one, and show real progress instead of a
        # bare iteration counter.
        for data in tqdm(train_loader):
            optimizer.zero_grad()

            # Comment out the following line and uncomment the line after for cuda
            pred = model(data)
            #pred = model(data.cuda())

            loss = criterion(pred, data.y)

            loss.backward()
            optimizer.step()

56it [01:09,  1.24s/it]


ValueError: Invalid cif file with no structures!

### Test the model

In [None]:
def evaluate(model, test_loader, validation=False):
    """Compute, print and return the mean absolute error of ``model`` on a loader.

    Args:
        model: torch module, called as ``model(batch)`` for every batch.
        test_loader: iterable of batches; each batch exposes its targets as
            ``batch.y``.
        validation: label the printed report "VALIDATION" when true,
            "TEST" otherwise.

    Returns:
        The mean absolute error computed by ``mean_absolute_error`` over all
        batches (the same number that is printed).

    Raises:
        ValueError: if ``test_loader`` yields no batches.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # Gather per-batch results and concatenate once at the end; calling
        # torch.cat inside the loop re-copies everything gathered so far.
        batch_preds = []
        batch_targets = []
        with torch.no_grad():
            model.eval()
            for data in test_loader:
                pred = model(data)
                #pred = model(data.cuda())
                batch_preds.append(pred.cpu())
                batch_targets.append(data.y.cpu())

        if not batch_preds:
            raise ValueError("evaluate() received a data loader with no batches")

        preds = torch.cat(batch_preds)
        targets = torch.cat(batch_targets)

        # sklearn's signature is (y_true, y_pred); MAE is symmetric, so this
        # only makes the call match the documented argument order.
        mae = mean_absolute_error(targets, preds)

        set_str = "VALIDATION" if(validation) else "TEST"
        print("{0} MAE: {1:.3f}".format(set_str, mae))

    return mae

In [None]:
# Report the trained model's error on the validation split, then on the test split
evaluate(model, test_loader=valid_loader, validation=True)
evaluate(model, test_loader=test_loader, validation=False)

### Model saving/loading example

In [None]:
# Write the model's state dict (not the module object itself) to disk
torch.save(
    model.state_dict(),
    "./example_ginet",
)

In [None]:
# Rebind `model` to a newly constructed GINet; the previous weights now exist
# only in the file saved above
model = GINet()

# Evaluate the new instance on both held-out splits
evaluate(model, test_loader=valid_loader, validation=True)
evaluate(model, test_loader=test_loader, validation=False)

In [None]:
# Restore the saved parameters into the current model, then evaluate again.
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(
    torch.load("./example_ginet")
)
evaluate(model, test_loader=valid_loader, validation=True)
evaluate(model, test_loader=test_loader, validation=False)