In [7]:
import json
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms

# Parse the supporting-information file into a nested dict.
with open('supporting-information.json', 'rb') as f:
    d = json.load(f)

# The top-level keys are the oxides. The polymorph names are read from the
# TiO2 entry only; the cells below index d[oxide][polymorph] for every oxide,
# so this assumes all oxides share the same polymorph keys.
oxides = list(d)
polymorphs = list(d['TiO2'])
oxides, polymorphs

(['SnO2', 'IrO2', 'RuO2', 'TiO2', 'VO2'],
 ['rutile', 'pyrite', 'columbite', 'brookite', 'fluorite', 'anatase'])

In [8]:
# Look at a single entry (first rutile TiO2 PBE EOS calculation) to see how
# the geometry and the computed data are laid out.
c = d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
geometry = c['atoms']
atoms = Atoms(
    symbols=geometry['symbols'],
    positions=geometry['positions'],
    cell=geometry['cell'],
    pbc=geometry['pbc'],
)
# Every atom gets tag 1.
# NOTE(review): presumably required by the model used later - confirm.
atoms.set_tags(np.ones(len(atoms)))
atoms, c['data']['total_energy'], c['data']['forces']

(Atoms(symbols='Ti2O4', pbc=True, cell=[4.3789762519649225, 4.3789762519649225, 2.864091775985314], tags=...),
 -56.230672,
 [[0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0],
  [-0.001264, -0.001264, 0.0],
  [0.001264, 0.001264, 0.0],
  [-0.001264, 0.001264, 0.0],
  [0.001264, -0.001264, 0.0]])

In [9]:
from ase.db import connect
from ase.calculators.singlepoint import SinglePointCalculator

# Remove any oxides.db left over from an earlier run; the write loop below
# appends rows, so a stale file would end up with duplicates.
! rm -fr oxides.db
# Database handle used by the following cells.
db = connect('oxides.db')

In [10]:
# Store every PBE equation-of-state calculation, for each oxide/polymorph
# pair, as one row of oxides.db. Each row carries the geometry together with
# the total energy and forces read from the JSON file.
for oxide in oxides:
    for polymorph in polymorphs:
        for c in d[oxide][polymorph]['PBE']['EOS']['calculations']:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            # Every atom gets tag 1.
            # NOTE(review): presumably required by the model used later - confirm.
            atoms.set_tags(np.ones(len(atoms)))
            # SinglePointCalculator only holds the precomputed results so that
            # they are saved along with the structure.
            calc = SinglePointCalculator(atoms,
                                         energy=c['data']['total_energy'],
                                         forces=c['data']['forces'])
            # Attach via the `calc` attribute; Atoms.set_calculator() is
            # deprecated in ASE.
            atoms.calc = calc
            db.write(atoms)

Let's see what we made.

In [11]:
# Print the rows of oxides.db as a table with the ase command-line tool.
! ase db oxides.db

id|age|formula|calculator| energy|natoms| fmax|pbc| volume|charge|   mass
 1| 4s|Sn2O4  |unknown   |-41.359|     6|0.045|TTT| 64.258| 0.000|301.416
 2| 4s|Sn2O4  |unknown   |-41.853|     6|0.025|TTT| 66.526| 0.000|301.416
 3| 4s|Sn2O4  |unknown   |-42.199|     6|0.010|TTT| 68.794| 0.000|301.416
 4| 4s|Sn2O4  |unknown   |-42.419|     6|0.006|TTT| 71.062| 0.000|301.416
 5| 4s|Sn2O4  |unknown   |-42.534|     6|0.011|TTT| 73.330| 0.000|301.416
 6| 4s|Sn2O4  |unknown   |-42.562|     6|0.029|TTT| 75.598| 0.000|301.416
 7| 4s|Sn2O4  |unknown   |-42.518|     6|0.033|TTT| 77.866| 0.000|301.416
 8| 4s|Sn2O4  |unknown   |-42.415|     6|0.010|TTT| 80.134| 0.000|301.416
 9| 4s|Sn2O4  |unknown   |-42.266|     6|0.006|TTT| 82.402| 0.000|301.416
10| 4s|Sn2O4  |unknown   |-42.083|     6|0.017|TTT| 84.670| 0.000|301.416
11| 4s|Sn4O8  |unknown   |-81.424|    12|0.012|TTT|117.473| 0.000|602.832
12| 4s|Sn4O8  |unknown   |-82.437|    12|0.005|TTT|121.620| 0.000|602.832
13| 4s|Sn4O8  |unknown   

In [6]:
# Download the model checkpoint (.pt) file. -q keeps wget quiet and -nc skips
# the download when the file already exists locally.
! wget -q -nc https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_oc22_oc20_all_s2ef.pt

We need to split the ASE database into three separate databases: one for training (80%), one for testing (10%), and one for validation (the remaining ~10%). We generate a list of row ids and shuffle them. Then we write the first 80% into train.db, the next 10% into test.db, and the remainder into val.db.

In [47]:
import numpy as np
# Fixed seed so the train/test/val split is reproducible.
rng = np.random.default_rng(seed=42)

# Read the row ids from oxides.db instead of hard-coding the row count, so
# the split stays correct if the number of calculations changes.
ids = np.array([row.id for row in db.select()])
rng.shuffle(ids)  # in-place permutation
ids

array([271,  97, 101,  27,  20, 261, 194, 264,  71, 133, 267, 162, 160,
        76, 214, 246,   1, 243, 119, 293,  51,  90, 185, 136, 281, 189,
       280,  69, 218,   4, 292,  13,  93, 253, 270,   3,  81,  65,  47,
       169, 198,  52,  39,  62, 269, 222,  25, 144, 237, 226,  57, 166,
       286, 154, 254, 170, 181, 139,  99, 228, 145, 244, 279, 223, 186,
       248, 192, 132,  83, 233, 277,  55, 282,  63, 146, 240, 105, 147,
       262, 221,  31, 256, 199, 116,  48,  59,   8,  28, 245, 149,  61,
        85, 150, 135, 274, 128, 294, 272, 165,  38, 102, 263, 257, 172,
       182, 206,  21, 258, 127, 220, 151,  34,   9,  89, 203, 295, 204,
       113, 208,  26, 167,  30,  78,  80, 287,  96, 108, 193,  67,  10,
       249, 173,  60, 250, 229,   5, 212,  49,  68, 190,  95, 153, 180,
       125,  73, 134, 291, 159,  33, 156, 157, 231, 174, 122, 288, 131,
       109,  82, 163, 210,  56, 175,  53, 236,  87,  29,  12, 107, 255,
       121, 141, 290, 217, 129, 260, 289,  92, 179,  40, 202, 14

In [49]:
train_end = int(len(ids) * 0.8)
test_end = train_end + int(len(ids) * 0.1)

! rm -fr train.db test.db val.db

train = connect('train.db')
test = connect('test.db')
val = connect('val.db')

for _id in ids[0:train_end]:
    row = db.get(id=int(_id))
    train.write(row.toatoms())
    
for _id in ids[train_end:test_end]:
    row = db.get(id=int(_id))
    test.write(row.toatoms())
    
for _id in ids[test_end:]:
    row = db.get(id=int(_id))
    val.write(row.toatoms())
    