In [3]:
import numpy as np
import pandas as pd
from pymatgen.io.cif import CifParser
from tqdm import tqdm
from joblib import Parallel, delayed

In [19]:
df = pd.read_pickle('pilot/data.pickle.gz')
structures = pd.Series(data=df['initial_structure'].values, index=df['descriptor_id'].values, name='structures')

In [20]:
descriptors = pd.read_csv('descriptors.csv', index_col='_id')
descriptors = descriptors[['base', 'cell']]

In [21]:
prepared = pd.merge(structures, descriptors, left_index=True, right_index=True)
prepared

Unnamed: 0,structures,base,cell
6141cf10b842c2e72e2f2d42,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf10cc0e69a0cf28ab33,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf10cc0e69a0cf28ab33,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf123ac25c70a5c6c835,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf123ac25c70a5c6c835,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
...,...,...,...
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"
6141cf184e27a1844a5efff8,[[1.27612629e-07 1.84192955e+00 3.71975100e+00...,MoS2,"[8, 8, 1]"


In [22]:
unit_cells = {
    'MoS2': CifParser("MoS2.cif").get_structures(primitive=False)[0],
    'WSe2': CifParser("WSe2.cif").get_structures(primitive=False)[0]
}



In [23]:
prepared = prepared.values.tolist()
prepared = [[p[0], p[1], eval(p[2])] for p in prepared]

In [24]:
from MEGNetSparse.dense2sparse import convert_to_sparse_representation

dataset = Parallel(n_jobs=-1)(
            delayed(convert_to_sparse_representation)(p[0], unit_cells[p[1]], p[2], True) for p in tqdm(prepared))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:26<00:00,  7.44it/s]


In [32]:
targets = pd.read_csv('pilot/targets.csv.gz')
targets = targets['formation_energy_per_site'].values

In [34]:
train_data, test_data = dataset[:100], dataset[100:]
train_targets, test_targets = targets[:100], targets[100:]

In [35]:
config = {
    'model': {
        'train_batch_size': 50,
        'test_batch_size': 50,
        'add_z_bond_coord': True,
        'atom_features': 'werespecies',
        'state_input_shape': 2,
        'cutoff': 10,
        'edge_embed_size': 10,
        'vertex_aggregation': 'mean',
        'global_aggregation': 'mean',
        'embedding_size': 32,
        'nblocks': 3,
    },
    'optim': {
        'factor': 0.5,
        'patience': 30,
        'threshold': 5e-2,
        'min_lr': 1e-5,
        'lr_initial': 1e-3,
        'scheduler': 'ReduceLROnPlateau',
    }
}

In [36]:
from MEGNetSparse.trainer import MEGNetTrainer

trainer = MEGNetTrainer(config, 'cpu')

In [37]:
trainer.prepare_data(train_data, train_targets, test_data, test_targets, 'formation_energy')

adding targets to data
converting data


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:03<00:00,  1.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 694.72it/s]


In [38]:
trainer.train_one_epoch()

target: formation_energy device: cpu


RuntimeError: Found dtype Double but expected Float