### Set Path (Won't be needed once `setup.py` is finished)

In [1]:
import sys
sys.path.insert(0, sys.path[0][:-8])

In [2]:
import torch
import warnings
from tqdm import tqdm
from torch.autograd import Variable
from sklearn.metrics import mean_absolute_error

### Auglichem imports

In [3]:
from auglichem.crystal import Compose, RotationTransformation, SupercellTransformation
from auglichem.crystal.data import CrystalDatasetWrapper
from auglichem.crystal.models import SchNet, GINet

### Set up dataset

In [4]:
#help(CrystalDatasetWrapper)

In [5]:
#help(CrystalDatasetWrapper.__init__)

In [6]:
# Create transformation
transform = [
    SupercellTransformation(),
]

# Initialize dataset object
dataset = CrystalDatasetWrapper("lanthanides", batch_size=128,
                                valid_size=0.1, test_size=0.1)

# Get train/valid/test splits as loaders
train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transform)

100%|████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 183732.40it/s]


### Initialize model with task from data

In [17]:
# Get model
model = SchNet() # Note: SchNet and GINet are interchangeable in use cases

# Uncomment the following line to use cuda
#model.cuda()

SchNet(hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0)

### Initialize traning loop

In [18]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

### Train the model

In [25]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for epoch in range(1):
        for bn, data in tqdm(enumerate(train_loader)):        
            optimizer.zero_grad()

            # Comment out the following line and uncomment the line after for cuda
            pred = model(data)
            #pred = model(data.cuda())
            
            loss = criterion(pred, data.y)

            loss.backward()
            optimizer.step()

53it [00:02, 23.30it/s]


### Test the model

In [26]:
def evaluate(model, test_loader, validation=False):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with torch.no_grad():
            model.eval()
            preds = torch.Tensor([])
            targets = torch.Tensor([])
            for data in test_loader:
                #pred = model(data)
                pred = model(data.cuda())
                preds = torch.cat((preds, pred.cpu()))
                targets = torch.cat((targets, data.y.cpu()))

            mae = mean_absolute_error(preds, targets)   
        
        set_str = "VALIDATION" if(validation) else "TEST"
        print("{0} MAE: {1:.3f}".format(set_str, mae))

In [27]:
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

VALIDATION MAE: 4.307
TEST MAE: 4.708


### Model saving/loading example

In [28]:
# Save model
torch.save(model.state_dict(), "./saved_models/example_gin")

In [29]:
# Instantiate new model and evaluate
model = GINet()

evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

In [30]:
# Load saved model and evaluate
model.load_state_dict(torch.load("./saved_models/example_gin"))
evaluate(model, valid_loader, validation=True)
evaluate(model, test_loader)

RuntimeError: Error(s) in loading state_dict for GINet:
	Missing key(s) in state_dict: "x_embedding1.weight", "x_embedding2.weight", "x_embedding2.bias", "gnns.0.mlp.0.weight", "gnns.0.mlp.0.bias", "gnns.0.mlp.2.weight", "gnns.0.mlp.2.bias", "gnns.0.edge_embedding1.weight", "gnns.1.mlp.0.weight", "gnns.1.mlp.0.bias", "gnns.1.mlp.2.weight", "gnns.1.mlp.2.bias", "gnns.1.edge_embedding1.weight", "gnns.2.mlp.0.weight", "gnns.2.mlp.0.bias", "gnns.2.mlp.2.weight", "gnns.2.mlp.2.bias", "gnns.2.edge_embedding1.weight", "gnns.3.mlp.0.weight", "gnns.3.mlp.0.bias", "gnns.3.mlp.2.weight", "gnns.3.mlp.2.bias", "gnns.3.edge_embedding1.weight", "gnns.4.mlp.0.weight", "gnns.4.mlp.0.bias", "gnns.4.mlp.2.weight", "gnns.4.mlp.2.bias", "gnns.4.edge_embedding1.weight", "batch_norms.0.weight", "batch_norms.0.bias", "batch_norms.0.running_mean", "batch_norms.0.running_var", "batch_norms.1.weight", "batch_norms.1.bias", "batch_norms.1.running_mean", "batch_norms.1.running_var", "batch_norms.2.weight", "batch_norms.2.bias", "batch_norms.2.running_mean", "batch_norms.2.running_var", "batch_norms.3.weight", "batch_norms.3.bias", "batch_norms.3.running_mean", "batch_norms.3.running_var", "batch_norms.4.weight", "batch_norms.4.bias", "batch_norms.4.running_mean", "batch_norms.4.running_var", "feat_lin.weight", "feat_lin.bias", "head.0.weight", "head.0.bias", "head.2.weight", "head.2.bias". 
	Unexpected key(s) in state_dict: "atomic_mass", "embedding.weight", "distance_expansion.offset", "interactions.0.mlp.0.weight", "interactions.0.mlp.0.bias", "interactions.0.mlp.2.weight", "interactions.0.mlp.2.bias", "interactions.0.conv.lin1.weight", "interactions.0.conv.lin2.weight", "interactions.0.conv.lin2.bias", "interactions.0.conv.nn.0.weight", "interactions.0.conv.nn.0.bias", "interactions.0.conv.nn.2.weight", "interactions.0.conv.nn.2.bias", "interactions.0.lin.weight", "interactions.0.lin.bias", "interactions.1.mlp.0.weight", "interactions.1.mlp.0.bias", "interactions.1.mlp.2.weight", "interactions.1.mlp.2.bias", "interactions.1.conv.lin1.weight", "interactions.1.conv.lin2.weight", "interactions.1.conv.lin2.bias", "interactions.1.conv.nn.0.weight", "interactions.1.conv.nn.0.bias", "interactions.1.conv.nn.2.weight", "interactions.1.conv.nn.2.bias", "interactions.1.lin.weight", "interactions.1.lin.bias", "interactions.2.mlp.0.weight", "interactions.2.mlp.0.bias", "interactions.2.mlp.2.weight", "interactions.2.mlp.2.bias", "interactions.2.conv.lin1.weight", "interactions.2.conv.lin2.weight", "interactions.2.conv.lin2.bias", "interactions.2.conv.nn.0.weight", "interactions.2.conv.nn.0.bias", "interactions.2.conv.nn.2.weight", "interactions.2.conv.nn.2.bias", "interactions.2.lin.weight", "interactions.2.lin.bias", "interactions.3.mlp.0.weight", "interactions.3.mlp.0.bias", "interactions.3.mlp.2.weight", "interactions.3.mlp.2.bias", "interactions.3.conv.lin1.weight", "interactions.3.conv.lin2.weight", "interactions.3.conv.lin2.bias", "interactions.3.conv.nn.0.weight", "interactions.3.conv.nn.0.bias", "interactions.3.conv.nn.2.weight", "interactions.3.conv.nn.2.bias", "interactions.3.lin.weight", "interactions.3.lin.bias", "interactions.4.mlp.0.weight", "interactions.4.mlp.0.bias", "interactions.4.mlp.2.weight", "interactions.4.mlp.2.bias", "interactions.4.conv.lin1.weight", "interactions.4.conv.lin2.weight", "interactions.4.conv.lin2.bias", "interactions.4.conv.nn.0.weight", "interactions.4.conv.nn.0.bias", "interactions.4.conv.nn.2.weight", "interactions.4.conv.nn.2.bias", "interactions.4.lin.weight", "interactions.4.lin.bias", "interactions.5.mlp.0.weight", "interactions.5.mlp.0.bias", "interactions.5.mlp.2.weight", "interactions.5.mlp.2.bias", "interactions.5.conv.lin1.weight", "interactions.5.conv.lin2.weight", "interactions.5.conv.lin2.bias", "interactions.5.conv.nn.0.weight", "interactions.5.conv.nn.0.bias", "interactions.5.conv.nn.2.weight", "interactions.5.conv.nn.2.bias", "interactions.5.lin.weight", "interactions.5.lin.bias", "lin1.weight", "lin1.bias", "lin2.weight", "lin2.bias". 