In [1]:
import ast
import os

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from pymatgen.io.cif import CifParser
from tqdm.auto import tqdm, trange
# Worker/thread budget. When the ROLOS_AVAILABLE_CPU environment variable is
# set, it is parsed through float() first so that fractional values such as
# "3.5" are accepted, truncated to an int, and clamped to at least 1 (a value
# below 1 would otherwise become 0, which torch rejects).
# When the variable is absent, n_cpus stays None and torch keeps its default
# thread count.
if "ROLOS_AVAILABLE_CPU" in os.environ:
    n_cpus = max(1, int(float(os.environ["ROLOS_AVAILABLE_CPU"])))
else:
    n_cpus = None

# torch.set_num_threads() only accepts an int; calling it with None raises
# TypeError, so the call is skipped when no explicit budget was given.
if n_cpus is not None:
    torch.set_num_threads(n_cpus)

In [2]:
def string_to_struct(s):
    """Parse the CIF text ``s`` and return the first structure found in it.

    The text is handed to ``CifParser.from_str`` and the structures are
    requested with ``primitive=False``; only element 0 of the resulting
    list is returned.
    """
    parser = CifParser.from_str(s)
    parsed_structures = parser.get_structures(primitive=False)
    return parsed_structures[0]
# Load the gzipped CSV (first column as index), collapse the single remaining
# column to a Series, and parse every entry with string_to_struct.
structures = (
    pd.read_csv("pilot/data.csv.gz", index_col=0)
    .squeeze("columns")
    .apply(string_to_struct)
)



In [3]:
# Only the `base` and `cell` columns are needed; rows are indexed by `_id`
# so they can be joined against `structures` on the index.
descriptors = pd.read_csv(
    'descriptors.csv',
    index_col='_id',
    usecols=['_id', 'base', 'cell'],
)

In [4]:
# Join every parsed structure with its descriptor row; both sides are matched
# on their index.
# NOTE(review): the recorded output below shows repeated index values - confirm
# the ids are unique in both inputs, since an index join multiplies rows that
# share a key.
prepared = pd.merge(
    structures,
    descriptors,
    left_index=True,
    right_index=True,
)
prepared

Unnamed: 0,structures,base,cell
6141cf10b842c2e72e2f2d42,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf10cc0e69a0cf28ab33,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf10cc0e69a0cf28ab33,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf123ac25c70a5c6c835,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf123ac25c70a5c6c835,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
...,...,...,...
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"


In [5]:
# Reference structure for each base material, keyed by the value that appears
# in the `base` descriptor column. Each CIF file is parsed with
# primitive=False and its first structure is kept.
unit_cells = {
    base_name: CifParser(cif_path).get_structures(primitive=False)[0]
    for base_name, cif_path in (
        ('MoS2', "MoS2.cif"),
        ('WSe2', "WSe2.cif"),
    )
}



In [6]:
# Flatten the merged frame into a plain list of [structure, base, cell] rows.
# The `cell` column arrives from the CSV as text such as "[8, 8, 1]"; it is
# parsed with ast.literal_eval, which accepts Python literals only, instead of
# eval(), which would execute any expression found in the file.
# Note: `prepared` is rebound from a DataFrame to a list here, so the merge
# cell has to be re-run before this cell can be executed again.
prepared = [
    [structure, base, ast.literal_eval(cell)]
    for structure, base, cell in prepared.values.tolist()
]

In [7]:
from MEGNetSparse import convert_to_sparse_representation

# One delayed conversion job per prepared row: the structure, the unit cell
# looked up by base material name, and the parsed cell value.
# The trailing positional True is forwarded as-is; its meaning is defined by
# convert_to_sparse_representation - TODO confirm against its signature.
conversion_jobs = (
    delayed(convert_to_sparse_representation)(structure, unit_cells[base], cell, True)
    for structure, base, cell in tqdm(prepared)
)
dataset = Parallel(n_jobs=n_cpus)(conversion_jobs)

  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
# Regression targets as a float32 tensor, taken in file order from the
# `formation_energy_per_site` column.
# NOTE(review): nothing here aligns these rows with `dataset` by id; the
# pairing is purely positional - confirm both files share the same ordering.
targets = torch.tensor(
    pd.read_csv('pilot/targets.csv.gz')['formation_energy_per_site'].values
).float()

In [9]:
# Number of leading samples used for training; everything after them is held
# out for testing. Kept in one place so both splits always agree.
N_TRAIN = 100

# Samples and targets are paired by position only, so a length mismatch would
# silently pair structures with the wrong energies.
assert len(dataset) == len(targets), (
    f"dataset has {len(dataset)} samples but targets has {len(targets)} rows"
)

train_data, test_data = dataset[:N_TRAIN], dataset[N_TRAIN:]
train_targets, test_targets = targets[:N_TRAIN], targets[N_TRAIN:]

In [10]:
# Settings handed to MEGNetTrainer, assembled section by section.
config = {}

# Model section: batch sizes, input featurisation and network size.
config['model'] = {
    'train_batch_size': 50,
    'test_batch_size': 50,
    'add_z_bond_coord': True,
    'atom_features': 'werespecies',
    'state_input_shape': 2,
    'cutoff': 10,
    'edge_embed_size': 10,
    'vertex_aggregation': 'mean',
    'global_aggregation': 'mean',
    'embedding_size': 32,
    'nblocks': 3,
}

# Optimiser section: initial learning rate and the scheduler settings.
config['optim'] = {
    'factor': 0.5,
    'patience': 30,
    'threshold': 5e-2,
    'min_lr': 1e-5,
    'lr_initial': 1e-3,
    'scheduler': 'ReduceLROnPlateau',
}

In [11]:
from MEGNetSparse import MEGNetTrainer

# Build the trainer from the `config` dict defined above, passing 'cpu' as the
# device argument.
trainer = MEGNetTrainer(config, 'cpu')

In [12]:
# Hand the train/test samples and their targets to the trainer; the last
# argument names the target ('formation_energy').
trainer.prepare_data(train_data, train_targets, test_data, test_targets, 'formation_energy')

adding targets to data
converting data


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [13]:
# Run a single training epoch. The cell displays whatever train_one_epoch()
# returns (a pair of numbers in the recorded output).
trainer.train_one_epoch()

target: formation_energy device: cpu


(1.1832267761230468, 1.0245155)

In [15]:
# Evaluate on the held-out data and also return the predictions. In the
# recorded output the result is a (scalar, tensor of predictions) pair.
# NOTE(review): the full prediction tensor is dumped into the output; consider
# displaying only a slice. The execution count also jumps from 13 to 15, so
# confirm the notebook reproduces under Restart & Run All.
trainer.evaluate_on_test(return_predictions=True)

(1.0041792297363281,
 MyTensor([[2.2206],
           [2.2199],
           [2.2194],
           [2.2199],
           [2.2189],
           [2.2193],
           [2.2193],
           [2.2207],
           [2.2198],
           [2.2190],
           [2.2190],
           [2.2188],
           [2.2204],
           [2.2194],
           [2.2217],
           [2.2203],
           [2.2196],
           [2.2199],
           [2.2219],
           [2.2195],
           [2.2197],
           [2.2203],
           [2.2206],
           [2.2200],
           [2.2245],
           [2.2201],
           [2.2196],
           [2.2466],
           [2.2379],
           [2.2377],
           [2.2385],
           [2.2383],
           [2.2380],
           [2.2422],
           [2.2379],
           [2.2381],
           [2.2422],
           [2.2388],
           [2.2381],
           [2.2378],
           [2.2377],
           [2.2381],
           [2.2422],
           [2.2382],
           [2.2422],
           [2.2378],
           [2