### Set Path (Won't be needed once `setup.py` is finished)

In [1]:
import sys
# Prepend a modified copy of the first sys.path entry: the entry with its last
# 8 characters removed.
# NOTE(review): presumably this strips a trailing 8-character directory name so
# that the repository root becomes importable -- TODO confirm. The slice is
# purely positional, so it gives a wrong path if the notebook lives in (or is
# launched from) a directory with a different name length.
sys.path.insert(0, sys.path[0][:-8])

In [2]:
import torch
import warnings
from tqdm import tqdm
from torch.autograd import Variable
from sklearn.metrics import mean_absolute_error

### Auglichem imports

In [3]:
from auglichem.crystal import Compose, RotationTransformation, SupercellTransformation
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import SchNet, GINet
from auglichem.crystal.models import GINet

### Set up dataset

In [4]:
# Uncomment to view the CrystalDatasetWrapper documentation
#help(CrystalDatasetWrapper)

In [5]:
# Uncomment to view the CrystalDatasetWrapper constructor arguments
#help(CrystalDatasetWrapper.__init__)

In [6]:
# Augmentations handed to the data loaders when they are built
transform = [SupercellTransformation()]

# Dataset wrapper for the "lanthanides" data: 4 folds, batches of 128
dataset = CrystalDatasetWrapper(
    "lanthanides",
    batch_size=128,
    kfolds=4,
    valid_size=0.1,
    test_size=0.1,
)


### Train the model

In [7]:
def train(model, train_loader, epochs=1, optimizer=None, criterion=None):
    """Fit ``model`` on ``train_loader`` and return it.

    Parameters
    ----------
    model : callable module
        Called as ``model(data)`` on every batch; its parameters are updated
        in place by ``optimizer``.
    train_loader : iterable
        Yields batches. Each batch must expose its targets as ``data.y``.
    epochs : int, optional
        Number of full passes over ``train_loader``. Defaults to 1, the value
        that was previously hard-coded.
    optimizer : optional
        Optimizer to step. When omitted, the notebook-level ``optimizer``
        variable is used (the original behaviour).
    criterion : optional
        Loss called as ``criterion(pred, data.y)``. When omitted, the
        notebook-level ``criterion`` variable is used (the original behaviour).

    Returns
    -------
    The same ``model`` object, after training.
    """
    # Fall back to the notebook globals so existing `train(model, loader)`
    # calls keep working unchanged.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]

    # evaluate() switches the model to eval mode; make sure training always
    # runs in training mode even if the model was evaluated beforehand.
    model.train()

    with warnings.catch_warnings():
        # NOTE(review): this silences every warning raised during training,
        # including PyTorch's "target size differs from input size" warning
        # from MSELoss -- confirm that `pred` and `data.y` have equal shapes.
        warnings.simplefilter("ignore")
        for _ in range(epochs):
            # Wrap the loader itself (not an enumerate object) so tqdm can
            # read its length and show real progress instead of a bare count.
            for data in tqdm(train_loader):
                optimizer.zero_grad()

                pred = model(data)
                loss = criterion(pred, data.y)

                loss.backward()
                optimizer.step()
    return model

### Test the model

In [8]:
def evaluate(model, test_loader, validation=False):
    """Print and return the mean absolute error of ``model`` on a data loader.

    Parameters
    ----------
    model : callable module
        Switched to eval mode with ``model.eval()`` and called as
        ``model(data)`` on every batch. It is left in eval mode afterwards.
    test_loader : iterable
        Yields batches. Each batch must expose its targets as ``data.y``.
    validation : bool, optional
        Only affects the printed label: "VALIDATION" when true, "TEST"
        otherwise.

    Returns
    -------
    float
        The mean absolute error over all batches (also printed with three
        decimals).

    Raises
    ------
    ValueError
        If ``test_loader`` yields no batches.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model.eval()

        # Collect per-batch results and concatenate once at the end, instead
        # of re-concatenating a growing tensor on every iteration.
        batch_preds = []
        batch_targets = []
        with torch.no_grad():
            for data in test_loader:
                # Move to CPU so results from any device can be combined and
                # handed to scikit-learn.
                batch_preds.append(model(data).detach().cpu())
                batch_targets.append(data.y.detach().cpu())

        if not batch_preds:
            raise ValueError("evaluate() received an empty data loader")

        preds = torch.cat(batch_preds)
        targets = torch.cat(batch_targets)

        # scikit-learn's signature is (y_true, y_pred).
        mae = float(mean_absolute_error(targets.numpy(), preds.numpy()))

    set_str = "VALIDATION" if validation else "TEST"
    print("{0} MAE: {1:.3f}".format(set_str, mae))
    # Returned so callers can aggregate the metric across folds.
    return mae

### Initialize model, train, test for first fold

In [9]:
# Get model: a newly constructed GINet for this fold
model = GINet() # Note: SchNet and GINet are interchangeable in use cases

# Loss and optimizer are kept as notebook-level names on purpose:
# train() reads `criterion` and `optimizer` from the global namespace.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders for fold 0 (the dataset was built with kfolds=4)
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform, fold=0)

# Train on this fold's training split; train() returns the same model object
model = train(model, train_loader)

# Report MAE on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

Ignoring splitting. Using pre-split k folds.


100%|████████████████████████████████████████████████████████| 4166/4166 [00:00<00:00, 185886.01it/s]
44it [01:45,  2.40s/it]


VALIDATION MAE: 0.016
TEST MAE: 0.018


In [10]:
# Get model: a newly constructed GINet for this fold
model = GINet() # Note: SchNet and GINet are interchangeable in use cases

# Loss and optimizer are kept as notebook-level names on purpose:
# train() reads `criterion` and `optimizer` from the global namespace.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders for fold 1 (the dataset was built with kfolds=4)
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform, fold=1)

# Train on this fold's training split; train() returns the same model object
model = train(model, train_loader)

# Report MAE on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

Ignoring splitting. Using pre-split k folds.
Augmentation has already been done.


44it [01:19,  1.81s/it]


VALIDATION MAE: 0.441
TEST MAE: 0.451


In [11]:
# Get model: a newly constructed GINet for this fold
model = GINet() # Note: SchNet and GINet are interchangeable in use cases

# Loss and optimizer are kept as notebook-level names on purpose:
# train() reads `criterion` and `optimizer` from the global namespace.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders for fold 2 (the dataset was built with kfolds=4)
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform, fold=2)

# Train on this fold's training split; train() returns the same model object
model = train(model, train_loader)

# Report MAE on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

Ignoring splitting. Using pre-split k folds.
Augmentation has already been done.


44it [01:05,  1.48s/it]


VALIDATION MAE: 0.144
TEST MAE: 0.134


In [12]:
# Get model: a newly constructed GINet for this fold
model = GINet() # Note: SchNet and GINet are interchangeable in use cases

# Loss and optimizer are kept as notebook-level names on purpose:
# train() reads `criterion` and `optimizer` from the global namespace.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Get train/valid/test splits as loaders for fold 3 (the dataset was built with kfolds=4)
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform, fold=3)

# Train on this fold's training split; train() returns the same model object
model = train(model, train_loader)

# Report MAE on the validation split, then on the test split
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

Ignoring splitting. Using pre-split k folds.
Augmentation has already been done.


44it [00:56,  1.27s/it]


VALIDATION MAE: 0.047
TEST MAE: 0.043
