In [1]:
from pymatgen.entries.computed_entries import ComputedStructureEntry
from ocpmodels.datasets import LmdbDataset
import sys,os
import json
from ase import Atoms
import ase
import numpy as np
from pymatgen.io.ase import AseAtomsAdaptor
from ase.visualize import view
from copy import deepcopy
import torch
import lmdb
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the local Shell_repo checkout importable (the next cell imports
# `structure_generation` from it). Set SHELL_REPO_PATH to override this
# machine-specific default instead of editing the notebook.
SHELL_REPO = os.environ.get('SHELL_REPO_PATH', 'C:/Users/lhuang37/Desktop/Shell_repo/')
# Guard so re-running the cell does not keep appending duplicate entries.
if SHELL_REPO not in sys.path:
    sys.path.append(SHELL_REPO)

In [3]:
from structure_generation.lmdb_generator import generate_lmdb, convert_atoms_data

In [4]:
# Load the raw bulk-oxide entries, then index them by entry id as pymatgen
# ComputedStructureEntry objects.
with open('Data/bulk_oxides_20220621.json') as f:
    bulks = json.load(f)

bulk_oxides_dict = {
    raw['entry_id']: ComputedStructureEntry.from_dict(raw)
    for raw in bulks
}

In [5]:
# Entry ids (e.g. 'mp-1219547') of every bulk entry, in file order.
mplist = [entry['entry_id'] for entry in bulks]

# Classes: graph-edge builder for the bulk data objects

In [6]:

class Edgedataset():
    """Add OC20-style graph edges to existing (OC22-style) data objects.

    Each data object is rebuilt as a periodic ASE ``Atoms``, a radius-limited
    pymatgen neighbor search is run on it, and the resulting ``edge_index`` /
    ``cell_offsets`` (and optionally ``distances``) are written back onto the
    same data object, which is returned.

    Args:
        max_neigh: maximum number of neighbors kept per atom (closest first).
        radius: cutoff passed to pymatgen's ``Structure.get_neighbor_list``.
        r_energy: store ``y`` from the rebuilt ``Atoms``' potential energy.
        r_forces: store ``force`` from the rebuilt ``Atoms``' forces.
            NOTE(review): ``convert`` builds the ``Atoms`` without a
            calculator, so ``r_energy``/``r_forces`` presumably raise from
            ASE as written — TODO attach a calculator or drop these flags.
        r_distances: also store ``distances`` (only when ``r_edges`` is set).
        r_edges: compute and store ``edge_index`` and ``cell_offsets``.
        r_fixed: store an all-zeros ``fixed`` mask when the object has none.
    """

    def __init__(
        self,
        max_neigh=200,
        radius=6,
        r_energy=False,
        r_forces=False,
        r_distances=False,
        r_edges=True,
        r_fixed=True,
    ):
        self.max_neigh = max_neigh
        self.radius = radius
        self.r_energy = r_energy
        self.r_forces = r_forces
        self.r_distances = r_distances
        self.r_fixed = r_fixed
        self.r_edges = r_edges

    def get_neighbors_pymatgen(self, atoms):
        """Run a radius neighbor search, keeping at most ``max_neigh`` per atom.

        Args:
            atoms: ASE ``Atoms`` to search.

        Returns:
            Tuple ``(c_index, n_index, n_distance, offsets)`` of numpy arrays
            from pymatgen's ``get_neighbor_list``, grouped by center atom in
            index order, each center keeping only its ``max_neigh`` closest
            neighbors.
        """
        struct = AseAtomsAdaptor.get_structure(atoms)
        c_index, n_index, offsets, n_distance = struct.get_neighbor_list(
            r=self.radius, numerical_tol=0, exclude_self=True
        )

        keep = []
        for i in range(len(atoms)):
            idx_i = np.flatnonzero(c_index == i)
            # sort this atom's neighbors by distance and truncate to max_neigh
            keep.append(idx_i[np.argsort(n_distance[idx_i])[: self.max_neigh]])
        # np.concatenate rejects an empty list (structure with no atoms)
        keep = np.concatenate(keep) if keep else np.zeros(0, dtype=np.int64)

        return c_index[keep], n_index[keep], n_distance[keep], offsets[keep]

    def reshape_features(self, c_index, n_index, n_distance, offsets):
        """Stack neighbor/center indices and convert everything to tensors.

        Args:
            c_index, n_index: center and neighbor atom indices (numpy arrays).
            n_distance: edge lengths (numpy array).
            offsets: periodic image offsets per edge (numpy array).

        Returns:
            ``(edge_index, edge_distances, cell_offsets)`` where ``edge_index``
            has rows ``[neighbor, center]``; edges shorter than 1e-8 are dropped.
        """
        edge_index = torch.LongTensor(np.vstack((n_index, c_index)))
        edge_distances = torch.FloatTensor(n_distance)
        cell_offsets = torch.LongTensor(offsets)

        # remove distances smaller than a tolerance ~ 0. The small tolerance is
        # needed to correct for pymatgen's neighbor_list returning self atoms
        # in a few edge cases.
        nonzero = torch.where(edge_distances >= 1e-8)[0]
        edge_index = edge_index[:, nonzero]
        edge_distances = edge_distances[nonzero]
        cell_offsets = cell_offsets[nonzero]

        return edge_index, edge_distances, cell_offsets

    def convert(self, obj):
        """Write graph features onto ``obj`` (mutated in place) and return it.

        Args:
            obj: data object exposing ``atomic_numbers``, ``pos``, ``tags`` and
                ``cell`` (``cell`` is squeezed; presumably shaped (1, 3, 3) —
                TODO confirm).

        Returns:
            The same ``obj``, with ``edge_index``/``cell_offsets`` when
            ``r_edges`` is set, ``distances`` when ``r_distances`` is also set,
            and ``y``/``force``/``fixed`` according to the remaining flags.
        """
        atoms = Atoms(
            obj.atomic_numbers,
            positions=obj.pos,
            tags=obj.tags,
            cell=obj.cell.squeeze(),
            pbc=True,
        )
        natoms = len(atoms)

        if self.r_edges:
            edge_index, edge_distances, cell_offsets = self.reshape_features(
                *self.get_neighbors_pymatgen(atoms)
            )
            obj.edge_index = edge_index
            obj.cell_offsets = cell_offsets
        # NOTE(review): `atoms` has no calculator attached, so the two ASE calls
        # below presumably raise unless one is added — TODO.
        if self.r_energy:
            obj.y = atoms.get_potential_energy(apply_constraint=False)
        if self.r_forces:
            obj.force = torch.Tensor(atoms.get_forces(apply_constraint=False))
        if self.r_distances and self.r_edges:
            obj.distances = edge_distances
        if self.r_fixed and getattr(obj, "fixed", None) is None:
            # `atoms` is rebuilt above without constraints, so there is no
            # fixed-atom information to recover from it; only fill in an
            # all-zeros mask rather than overwrite one the object already has.
            obj.fixed = torch.zeros(natoms)

        return obj

    def convert_all(self, dataobj):
        """Convert deep copies of every object in ``dataobj``.

        The caller's objects are left untouched because ``convert`` mutates
        its argument in place.

        Returns:
            List of converted copies, in input order.
        """
        return [self.convert(deepcopy(system)) for system in dataobj]

# Build the bulk LMDB and the train/validation split

In [7]:
# Scratch example (not executed): convert one entry to a data object and
# store its energy as `y`; the loop further below does this for every entry.
# pp=AseAtomsAdaptor.get_atoms(bulk_oxides_dict['mp-1219547'].structure)
# data=convert_atoms_data(pp)
# data.y=bulk_oxides_dict['mp-1219547'].energy

In [8]:
# Edge builder: up to 50 neighbors within a cutoff radius of 6, keeping edge
# distances; energies, forces and fixed masks are not requested.
a2g = Edgedataset(
    radius=6,
    max_neigh=50,
    r_distances=True,
    r_energy=False,    # False for test data
    r_forces=False,
    r_fixed=False,
)

In [9]:
# Convert every bulk entry to a data object (energy stored as `y`) and write
# them to the LMDB in batches of 51 objects.
# NOTE(review): the batching assumes generate_lmdb appends to an existing
# database — TODO confirm; if so, re-running this cell adds a second copy of
# every entry, so delete ./Data/bulks_22.lmdb first.
# Graph edges are not added at this stage (a2g.convert is not applied).
datas = []
pathname = './Data/bulks_22.lmdb'
for i in mplist:
    pp = AseAtomsAdaptor.get_atoms(bulk_oxides_dict[i].structure)
    data = convert_atoms_data(pp)
    data.y = bulk_oxides_dict[i].energy
    datas.append(data)
    if len(datas) > 50:
        generate_lmdb(datas, pathname)
        datas = []
# Flush the remainder; skip the call when the last batch was already flushed.
if datas:
    generate_lmdb(datas, pathname)

In [10]:
# Validation indices: VAL_FRACTION of the N_ENTRIES dataset positions, drawn
# WITHOUT replacement so the split holds that many distinct entries
# (np.random.randint samples with replacement and can repeat indices), and
# seeded so the same split is produced on every run.
N_ENTRIES = 4732  # NOTE(review): hard-coded; presumably len(seeit) — confirm
VAL_FRACTION = 0.15
SPLIT_SEED = 0
rng = np.random.default_rng(SPLIT_SEED)
idx = np.sort(rng.choice(N_ENTRIES, size=int(N_ENTRIES * VAL_FRACTION), replace=False))

In [11]:
# Re-open the combined LMDB written above as a dataset so entries can be read back.
bulks_lmdb_config = {"src": "./Data/bulks_22.lmdb"}
seeit = LmdbDataset(bulks_lmdb_config)

In [12]:
def generate_lmdb_single(data: object, pathname: str):
    """Append one object (or each item of a list/tuple) to an LMDB file.

    Entries are stored under consecutive ASCII keys starting at the number of
    entries already in the database, so repeated calls append. Values are
    pickled with protocol 0.

    Args:
        data: a single object to store as one entry, or a list/tuple whose
            items are each stored as their own entry.
        pathname: target file; '.lmdb' is appended when it does not end with it.
    """
    # A bare object is exactly one entry: never iterate over it or use len()
    # on it as an item count.
    items = list(data) if isinstance(data, (list, tuple)) else [data]
    if not pathname.endswith('.lmdb'):
        pathname = pathname + '.lmdb'

    db = lmdb.open(
        pathname,
        map_size=1099511627 * 4,
        subdir=False,
        meminit=False,
        map_async=True,
    )
    try:
        # One write transaction for all items: committed on success, aborted
        # if anything raises.
        with db.begin(write=True) as txn:
            start = txn.stat()['entries']
            for offset, item in enumerate(items):
                txn.put(f"{start + offset}".encode('ascii'), pickle.dumps(item, protocol=0))
        db.sync()
    finally:
        db.close()

In [13]:
# Route each entry of the combined LMDB into the validation or the training
# LMDB, according to the indices drawn in `idx`.
# NOTE(review): generate_lmdb_single keys new entries after the existing ones,
# so re-running this cell appends duplicates to existing bulk_val/bulk_train files.
val_idx = set(np.asarray(idx).tolist())  # O(1) membership instead of scanning idx each iteration
for i in range(len(seeit)):
    target = './Data/bulk_val.lmdb' if i in val_idx else './Data/bulk_train.lmdb'
    generate_lmdb_single(seeit[i], target)