In [1]:
from ase.db import connect
import numpy as np
import xarray as xr

In [2]:
db = connect('/nc/SHNITSEL-data/CH2NH2.db')

In [3]:
row0 = next(db.select())

In [4]:
dir(row0)

['__annotations__',
 '__class__',
 '__contains__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_constrained_forces',
 '_constraints',
 '_data',
 '_keys',
 'cell',
 'charge',
 'constrained_forces',
 'constraints',
 'count_atoms',
 'ctime',
 'data',
 'fmax',
 'formula',
 'get',
 'id',
 'key_value_pairs',
 'mass',
 'mtime',
 'natoms',
 'numbers',
 'pbc',
 'positions',
 'smax',
 'symbols',
 'toatoms',
 'unique_id',
 'user',
 'volume']

In [5]:
row0.data

{'energy': array([-94.66686484, -94.3702908 , -94.32309568]),
 'socs': array([0.]),
 'forces': array([[[-1.41413e-01, -4.19326e-03,  3.65943e-02],
         [ 3.57556e-02,  4.55518e-03,  9.76048e-02],
         [ 1.17359e-01, -7.12822e-04, -5.82614e-02],
         [-3.50216e-02, -7.77292e-04, -5.01839e-02],
         [-4.49923e-03, -2.02663e-03, -5.71260e-02],
         [ 2.78191e-02,  3.15483e-03,  3.13721e-02]],
 
        [[-1.28948e-01,  5.71755e-03,  5.44669e-02],
         [-4.55383e-03, -2.03714e-03,  8.83774e-02],
         [ 1.38570e-01, -1.74265e-03, -9.04250e-02],
         [-1.93396e-02, -1.43049e-03, -1.91034e-02],
         [ 8.62517e-03,  1.67778e-03, -2.54392e-02],
         [ 5.64659e-03, -2.18504e-03, -7.87679e-03]],
 
        [[-1.37716e-01,  2.72603e-03, -2.35016e-01],
         [ 3.17298e-02, -7.68109e-04,  3.72636e-01],
         [ 1.15724e-01, -2.09941e-03, -6.29040e-02],
         [-3.79241e-02, -3.39649e-04, -4.74694e-02],
         [-1.55435e-03, -1.09021e-03, -5.67115e-02],

In [6]:
keys = list(row0.data)
keys

['energy', 'socs', 'forces', 'nacs', 'dipoles']

In [7]:
# Assuming dimensions follow SchNarc order
shapes = {
    'energy': ['frame', 'state'],
    'socs': ['frame', 'soc'],
    'forces': ['frame', 'state', 'atom', 'direction'],
    'nacs': ['frame', 'statecomb', 'atom', 'direction'],
    'dipoles': ['frame', 'not_sure', 'direction'],
    # TODO Not sure what the second dipole dimension means.
    # 3 permanent dipoles + 3 transition dipoles?
}

In [8]:
def stack_rows(name):
    global db
    return np.stack([row.data[name] for row in db.select()])

energy = stack_rows('energy')
energy

array([[-94.66686484, -94.3702908 , -94.32309568],
       [-94.69368015, -94.38693814, -94.3592036 ],
       [-94.69062259, -94.3711848 , -94.34549879],
       ...,
       [-94.40452945, -94.37105986, -94.23179496],
       [-94.42711427, -94.31497935, -94.17673991],
       [-94.49998622, -94.34293763, -94.26442948]])

In [9]:
data_vars = {
    name: (dims, stack_rows(name))
    for name, dims in shapes.items()
}
data_vars

{'energy': (['frame', 'state'],
  array([[-94.66686484, -94.3702908 , -94.32309568],
         [-94.69368015, -94.38693814, -94.3592036 ],
         [-94.69062259, -94.3711848 , -94.34549879],
         ...,
         [-94.40452945, -94.37105986, -94.23179496],
         [-94.42711427, -94.31497935, -94.17673991],
         [-94.49998622, -94.34293763, -94.26442948]])),
 'socs': (['frame', 'soc'],
  array([[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]])),
 'forces': (['frame', 'state', 'atom', 'direction'],
  array([[[[-1.41413e-01, -4.19326e-03,  3.65943e-02],
           [ 3.57556e-02,  4.55518e-03,  9.76048e-02],
           [ 1.17359e-01, -7.12822e-04, -5.82614e-02],
           [-3.50216e-02, -7.77292e-04, -5.01839e-02],
           [-4.49923e-03, -2.02663e-03, -5.71260e-02],
           [ 2.78191e-02,  3.15483e-03,  3.13721e-02]],
  
          [[-1.28948e-01,  5.71755e-03,  5.44669e-02],
           [-4.55383e-03, -2.03714e-03,  8.83774e-02],
 

In [10]:
frames = xr.Dataset(data_vars)
frames

In [11]:
# In case I was right about the dipoles
dipoles = data_vars['dipoles'][1]
dip_perm = dipoles[:,:3,:]
dip_trans = dipoles[:,3:,:]
del(data_vars['dipoles'])

data_vars['dip_perm'] = (['frame', 'state', 'direction'], dip_perm)
data_vars['dip_trans'] = (['frame', 'statecomb', 'direction'], dip_trans)

In [12]:
frames = xr.Dataset(data_vars)
frames

In [13]:
frames.to_netcdf('/tmp/output.nc', engine='h5netcdf')