# Conversion to Batch and Saving as HDF5 Datasets

In [None]:
%load_ext autoreload
%autoreload 2
import h5py
import os
import torch
import numpy as np
from tqdm import tqdm, trange
import ase
from e3nn import o3
from e3nn.o3 import Irreps
from e3_layers.data import Batch

## Hamiltonian From ASE DB to Batch

In [None]:
#!/usr/bin/env python3
import numpy as np
from base64 import b64decode
from ase.db import connect

db = connect('../wavefunc/schnorb_hamiltonian_water.db')

# One dict per database row, later collated into a Batch.
lst = []
for i, row in enumerate(db.select()):
    if i % 100 == 0:
        print(i)
    # Positions are deliberately kept in the database's units: no
    # angstrom -> bohr conversion (factor 1.8897261258369282) is applied.
    R = row['positions']
    # Species come from the row itself rather than a hard-coded [8, 1, 1],
    # so element labels always line up with the stored atom order.
    # Stored as an int32 column vector of shape (n_atoms, 1).
    species = np.asarray(row['numbers'], dtype=np.int32).reshape(-1, 1)
    lst.append(dict(
        pos=R,
        energy=row.data['energy'],
        forces=row.data['forces'],
        # Flattened so each molecule carries one fixed-length vector.
        hamiltonian=row.data['hamiltonian'].reshape(-1),
        species=species,
    ))

In [None]:
import numpy as np
from e3_layers.data import Batch
path = 'h2o.hdf5'
# Field name -> (level the field lives on, irreps / feature size).
attrs = {
    'pos': ('node', '1x1o'),
    'species': ('node', '1x0e'),
    'energy': ('graph', '1x0e'),
    'forces': ('node', '1x1o'),
    # The 24 x 24 Hamiltonian, flattened to a single vector per graph.
    'hamiltonian': ('graph', 24 * 24),
}

batch = Batch.from_data_list(lst, attrs)
batch.dumpHDF5(path)

## Multipole from HDF5 to Batch

In [None]:
from functools import lru_cache
import pickle


@lru_cache(maxsize=None)
def get_clebsch_gordon(i: int, j: int, k: int, device):
    """Return the Wigner 3j coupling tensor for degrees (i, j, k) as a NumPy array.

    The tensor is built in float64 by ``o3.wigner_3j`` and memoised per
    argument tuple, so repeated calls do no recomputation.

    Args:
        i, j, k: degrees of the three irreps being coupled.
        device: torch device the tensor is built on. The result is always
            copied to host memory, because NumPy cannot wrap device memory.

    Returns:
        Read-only float64 array of shape (2i+1, 2j+1, 2k+1).
    """
    basis = o3.wigner_3j(i, j, k, dtype=torch.float64, device=device)
    # ``.numpy()`` only accepts CPU tensors; move first so a GPU device works.
    array = basis.cpu().numpy()
    # lru_cache hands the same object to every caller. Freeze it so an
    # in-place edit by one caller cannot corrupt the value all others see.
    array.setflags(write=False)
    return array

def irreps2matrix(a, b, c):
    """Rebuild 3x3 Cartesian matrices from degree-0, 1 and 2 coefficients.

    Each coefficient set is contracted with the Wigner 3j basis that couples
    two vectors (degrees 1 x 1) to that degree. This is the same basis
    ``matrix2irreps`` projects onto, and the three contributions are summed.
    NOTE(review): whether this exactly inverts ``matrix2irreps`` depends on the
    normalisation of ``o3.wigner_3j``; confirm before relying on round trips.

    Args:
        a: degree-0 coefficients, shape (1,) or (1, n).
        b: degree-1 coefficients, shape (3,) or (3, n).
        c: degree-2 coefficients, shape (5,) or (5, n).

    Returns:
        NumPy array of shape (3, 3) or (3, 3, n).
    """
    result = 0
    for degree, coeffs in enumerate((a, b, c)):
        # Cached float64 NumPy basis. This avoids rebuilding a float32 torch
        # tensor on every call, and the float32/float64 dtype clash when the
        # coefficients are float64 tensors.
        basis = get_clebsch_gordon(1, 1, degree, 'cpu')
        # "k..." keeps both 1-D and 2-D coefficient inputs working.
        result = result + np.einsum("ijk,k...->ij...", basis, np.asarray(coeffs))
    return result

def matrix2irreps(m):
    """Project 3x3 Cartesian matrices onto degree-0, 1 and 2 irrep coefficients.

    For each degree, both matrix indices are contracted with the Wigner 3j
    basis that couples two vectors (degrees 1 x 1) to that degree.

    Args:
        m: array of shape (3, 3, ...). The two matrix indices come first; any
            trailing axes (e.g. one per atom) are carried through unchanged.

    Returns:
        List ``[deg0, deg1, deg2]`` of arrays shaped (1, ...), (3, ...), (5, ...).
    """
    # The cached float64 basis replaces a float32 torch basis rebuilt on every
    # call. This function runs once per molecule, so the rebuild was pure
    # overhead, and float32 also lost precision on float64 inputs.
    return [
        np.einsum("ij...,ijk->k...", m, get_clebsch_gordon(1, 1, degree, 'cpu'))
        for degree in range(3)
    ]

tmp_path = "multipole.pickle"
if not os.path.isfile(tmp_path):
    # No cache yet: gather per-molecule arrays from the raw HDF5 files.
    paths = ['multipole_gdb.hdf5', 'multipole_chembl.hdf5']
    coord, species, dipoles, quadrupoles = [], [], [], []
    for path in paths:
        with h5py.File(path, "r") as f:
            for key in tqdm(f.keys()):
                mol = f[key]
                mol_dipoles = mol['dipoles'][:]
                # Molecules with an empty dipole array are skipped entirely.
                if mol_dipoles.shape[0] == 0:
                    continue
                coord.append(mol['coordinates'][:])
                species.append(mol['elements'][:])
                dipoles.append(mol_dipoles)
                quadrupoles.append(mol['quadrupoles'][:])

    with open(tmp_path, "wb") as f:
        pickle.dump([coord, species, dipoles, quadrupoles], f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load a cache this
    # notebook wrote itself.
    with open(tmp_path, "rb") as file:
        coord, species, dipoles, quadrupoles = pickle.load(file)
        
def reflect(x):
    """Expand packed symmetric 3x3 matrices into full matrices.

    Each row of ``x`` holds six values: the upper triangle in row-major order,
    i.e. entries (0,0), (0,1), (0,2), (1,1), (1,2), (2,2). The lower triangle
    is mirrored from the upper one.

    Args:
        x: 2-D array of shape (n, 6).

    Returns:
        float64 array of shape (3, 3, n); the matrix index comes first.
    """
    n_rows, _ = x.shape
    columns = x.transpose(1, 0)
    # packed_index[r][c] = which packed column fills matrix entry (r, c).
    packed_index = ((0, 1, 2),
                    (1, 3, 4),
                    (2, 4, 5))
    full = np.empty((3, 3, n_rows))
    for r, row in enumerate(packed_index):
        for c, k in enumerate(row):
            full[r, c] = columns[k]
    return full

In [None]:
cnt = len(coord)
print(cnt)
table = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne']
table += ['Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar']
# Element symbol -> atomic number (1-based position in `table`). Keys are
# upper-cased and lookups upper-case the label, so 'Cl', 'CL' and 'cl' all
# resolve to 17.
symbol2idx = {symbol.upper(): idx for idx, symbol in enumerate(table, start=1)}
lst = []
for pos, elements, dipole, quadrupole in tqdm(
        zip(coord, species, dipoles, quadrupoles), total=cnt):
    # One integer type per atom; element labels are stored as bytes.
    symbols = np.zeros((pos.shape[0],), dtype=int)
    for j, item in enumerate(elements):
        symbols[j] = symbol2idx[item.decode("utf-8").strip().upper()]
    # Packed quadrupoles -> symmetric (3, 3, n) matrices -> irreps of degree
    # 0, 1 and 2. The degree-1 part is not stored: reflect() yields symmetric
    # matrices, whose antisymmetric (degree-1) projection vanishes.
    result = matrix2irreps(reflect(quadrupole))
    # Transpose (2l+1, n) -> (n, 2l+1) so the atom axis comes first.
    quadrupole_0 = result[0].transpose(1, 0)
    quadrupole_2 = result[2].transpose(1, 0)
    lst.append({'pos': pos, 'atom_types': symbols, 'dipole': dipole,
                'quadrupole_0': quadrupole_0, 'quadrupole_2': quadrupole_2})

In [None]:
from e3_layers.data import Batch
path = 'multipole.hdf5'
# Field name -> (level the field lives on, irreps) layout consumed by Batch.
attrs = {
    'pos': ('node', '1x1o'),
    'atom_types': ('node', '1x0e'),
    'dipole': ('node', '1x1o'),
    'quadrupole_0': ('node', '1x0e'),
    'quadrupole_2': ('node', '1x2e'),
}

batch = Batch.from_data_list(lst, attrs)
batch.dumpHDF5(path)

## QM9 from npz to Batch

In [None]:
# np.load on an .npz archive returns a lazy NpzFile; each npz[key] reads that array.
npz = np.load(file='qm9_edge.npz')

In [None]:
from e3_layers.data import Batch
path = 'qm9.hdf5'
# Field name -> (level the field lives on, irreps).
attrs = {
    'R': ('node', '1x1o'),
    'Z': ('node', '1x0e'),
    'U0': ('graph', '1x0e'),
    'U': ('graph', '1x0e'),
    '_n_nodes': ('graph', '1x0e'),
}

node_attr = npz['node_attr']
targets = npz['targets']
dic = {
    # Column 5 of node_attr is used as the integer atom type
    # (presumably the atomic number — confirm against the npz producer).
    'Z': node_attr[:, 5].astype(np.int64),
    # Target columns 7 and 8 are stored under 'U0' and 'U'.
    'U0': targets[:, 7],
    'U': targets[:, 8],
    'R': npz['node_pos'],
    '_n_nodes': npz['n_node'],
}

batch = Batch(attrs, **dic)
batch.dumpHDF5(path)

## From Padded HDF5 to Batch

In [None]:
import h5py
path = 'proteintopo.hdf5'
# Opened read-only and left open: the next cell reads from `f`.
f = h5py.File(path, mode="r")

In [None]:
# Unpad every sample of the padded protein-topology file into a dict of
# variable-length arrays. Node rows count as real while species > 0, and edge
# columns count as real while edge_indexs[i, 0] > -1. Padding is assumed to be
# trailing, so each count doubles as the slice end.
coord_ds, dipole_ds, species_ds = f['coord'], f['dipole'], f['species']
quad_ds, bond_ds, edge_ds = f['quadrupole_2'], f['bond_orders'], f['edge_indexs']
cnt = coord_ds.shape[0]
lst = []
for i in range(cnt):
    # Read each padded row once instead of re-reading it for count and slice.
    species_row = species_ds[i]
    edge_row = edge_ds[i]
    # np.count_nonzero runs in C; builtin sum() iterated the mask in Python.
    n_nodes = np.count_nonzero(species_row > 0)
    n_edges = np.count_nonzero(edge_row[0] > -1)
    dic = {
        'coord': coord_ds[i, :n_nodes],
        'dipole': dipole_ds[i, :n_nodes],
        'species': species_row[:n_nodes],
        'quadrupole_2': quad_ds[i, :n_nodes],
        'bond_orders': bond_ds[i, :n_edges],
        'edge_index': edge_row[:, :n_edges],
    }
    lst.append(dic)

In [None]:
# Layout of each unpadded field: (level the field lives on, irreps).
# 'edge_index' has no entry; presumably Batch.from_data_list handles it
# specially — confirm.
attrs = {
    'coord': ('node', '1x1o'),
    'dipole': ('node', '1x1o'),
    'species': ('node', '1x0e'),
    'quadrupole_2': ('node', '1x2e'),
    'bond_orders': ('edge', '1x0e'),
}

In [None]:
# Shape check on the last unpadded sample; presumably (2, n_edges) — confirm.
dic['edge_index'].shape

In [None]:
# Collate the unpadded samples and write them out as a Batch HDF5 file.
batch = Batch.from_data_list(lst, attrs)
out_path = 'protein_topo.hdf5'
batch.dumpHDF5(out_path)

## From PDB to Batch
As a minimal approximation, only C-alpha (CA) atoms are considered: each protein becomes a graph with one node per CA atom.

In [None]:
# Three-letter residue code -> one-letter code for the 20 standard amino acids.
aa_names = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
     'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 
     'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 
     'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
# Three-letter code -> 0-based position in aa_names (insertion order).
aa_ids = {code: idx for idx, code in enumerate(aa_names)}


def name2id(x):
    """Return the 1-based integer id of the three-letter residue name ``x``.

    Raises KeyError for names outside the 20 residues listed in ``aa_names``.
    """
    return 1 + aa_ids[x]

In [None]:
path = 'sampling_result'
proteins = []


def _read_ca_atoms(pdb_path):
    """Collect residue ids and coordinates of the C-alpha atoms in one PDB file.

    Args:
        pdb_path: path of a PDB text file.

    Returns:
        ``(aa_types, coords)``: a list of integer residue ids (via ``name2id``)
        and a list of ``[x, y, z]`` float triples, one entry per CA atom, in
        file order.

    Raises:
        KeyError: if a CA atom belongs to a residue ``name2id`` does not know.
    """
    aa_types = []
    coords = []
    with open(pdb_path) as handle:
        for line in handle:
            # Fixed-column PDB layout: record name in columns 1-6, atom name
            # in 13-16, residue name in 18-20, x/y/z in 31-38/39-46/47-54.
            # Only ATOM/HETATM records hold coordinates there. ANISOU, REMARK
            # etc. can also carry 'CA' in columns 14-15 and must be skipped.
            # The length check guarantees all coordinate columns are present.
            if not line.startswith(('ATOM', 'HETATM')) or len(line) < 54:
                continue
            if line[13:15] != 'CA':
                continue
            aa_types.append(name2id(line[17:20]))
            coords.append([float(line[30:38]), float(line[38:46]),
                           float(line[46:54])])
    return aa_types, coords


for root, dirs, files in tqdm(os.walk(path)):
    # Sort in place so traversal (and hence dataset) order is reproducible.
    dirs.sort()
    for name in sorted(files):
        if not name.endswith('.pdb'):
            continue
        aa_types, coords = _read_ca_atoms(os.path.join(root, name))
        proteins.append({
            '_n_nodes': len(aa_types),
            'aa_type': np.array(aa_types),
            # reshape keeps the (n, 3) layout even for a file with no CA atom
            'pos': np.array(coords, dtype=float).reshape(-1, 3),
        })

In [None]:
path = 'antibody.hdf5'
# Field name -> (level the field lives on, irreps).
attrs = {
    'pos': ('node', '1x1o'),
    'aa_type': ('node', '1x0e'),
    '_n_nodes': ('graph', '1x0e'),
}

batch = Batch.from_data_list(proteins, attrs)
batch.dumpHDF5(path)

In [None]:
# Number of PDB structures collected from sampling_result.
len(proteins)

# Reading HDF5

In [None]:
import h5py
path = 'antibody.hdf5'
# Re-open the freshly written file read-only to inspect its contents.
f = h5py.File(path, mode="r")

In [None]:
# Names of the top-level members stored in the reopened HDF5 file.
f.keys()

In [None]:
# Release the read-only handle on antibody.hdf5.
f.close()

# Compute Statistics

In [None]:
from e3_layers.data import CondensedDataset

In [None]:
# First 20 symbols of ASE's symbol -> atomic-number table, in table order.
# In ASE the first entry is presumably the dummy symbol 'X', which makes
# index == atomic number — confirm.
type_names = [*ase.atom.atomic_numbers][:20]
ds = CondensedDataset('qm9.hdf5', type_names=type_names)

## Atom Reference Energy

In [None]:
# Per-species energy mean/std, presumably the per-element reference energies;
# stride=10 presumably subsamples every 10th entry.
# NOTE(review): qm9.hdf5 as written in this notebook stores 'U0'/'U', not
# 'energy' — confirm which field this statistic reads.
ds.statistics(['energy-per-species-mean_std'], stride=10)

## Position std

In [None]:
from torch_runstats.scatter import scatter

In [None]:
def std(x):
    """Root-mean-square of the row-wise Euclidean norms of the tensor ``x``.

    Squares are summed over dim 1 (per-row squared norm), averaged over dim 0,
    and the square root is taken. For centred coordinates this is the RMS
    distance from the centre.
    """
    squared_norms = (x * x).sum(dim=1)
    return squared_norms.mean(dim=0) ** 0.5

In [None]:
# Centre every graph on its mean position, then report the RMS per-atom
# distance from that centre. node_segment presumably maps each node to its
# graph index.
node_segment = ds.nodeSegment()
coords = ds['R']
# Per-graph sum of positions divided by the atom count gives the centroid.
# NOTE(review): relies on ds['_n_nodes'] broadcasting against the per-graph
# sums (e.g. shape (n_graphs, 1)) — confirm.
center = scatter(coords, node_segment, dim=0, reduce='sum') / ds['_n_nodes']
pos = coords - center[node_segment]
std(pos)