In [11]:
import ase.io
import numpy as np
import torch
import os
from pymatgen.io.ase import AseAtomsAdaptor
from ase.neighborlist import NeighborList, NewPrimitiveNeighborList, PrimitiveNeighborList

In [12]:
# Read the path list file into one string per line (presumably paths to
# ASE-readable files, judging by the variable name).
# The context manager closes the handle as soon as the read finishes; `k`
# remains bound afterwards, now as a closed file object.
with open("/checkpoint/mshuaibi/ocpdata_728k_raw.txt", "r") as k:
    ase_paths = k.read().splitlines()

In [13]:
# Load every frame of the trajectory (index ":" selects all images), then
# pick frame 768 as the working structure for the cells below.
images = ase.io.read(
    "/checkpoint/sidgoyal/electro_done/random1623153.traj",
    index=":",
)
image = images[768]

In [4]:
def get_neighbors_pymatgen(atoms, cutoff):
    """Run pymatgen's neighbor search on an ASE structure.

    Args:
        atoms: ASE atoms object; converted to a pymatgen structure with
            ``AseAtomsAdaptor.get_structure``.
        cutoff: Search radius, forwarded as ``r`` to ``get_neighbor_list``.

    Returns:
        The four arrays produced by ``get_neighbor_list`` (called with
        ``exclude_self=True``), in order: center indices, neighbor indices,
        cell offsets and distances. The arrays are flat (one entry per edge),
        so different atoms may own different numbers of entries.
    """
    structure = AseAtomsAdaptor.get_structure(atoms)
    neighbor_data = structure.get_neighbor_list(r=cutoff, exclude_self=True)
    centers, neighbors, cell_shifts, distances = neighbor_data
    return centers, neighbors, cell_shifts, distances

def _reshape_features(c_index, n_index, offsets, n_distance):
    edge_index = torch.LongTensor(np.vstack((c_index, n_index)))
    edge_distances = torch.FloatTensor(n_distance)
    cell_offsets = torch.LongTensor(offsets)
    
    # remove distances smaller than a tolerance~0. The small tolerance is introduced
    # to correct for pymatgen's neighbor list returning self atoms in a few edge cases
    nonzero = torch.nonzero(edge_distances >= 1e-8).flatten()
    edge_index = edge_index[:, nonzero]
    edge_distances = edge_distances[nonzero]
    cell_offsets = cell_offsets[nonzero]
    
    return edge_index, edge_distances, cell_offsets

def get_neighbors_ase(image, cutoff):
    """Collect the ASE neighbor-list result for every atom in ``image``.

    Each atom is given a radius of ``cutoff / 2`` so that two radii add up to
    ``cutoff``. The list is built without self interaction, with zero skin
    and with ``bothways=True``, using ``NewPrimitiveNeighborList``.

    Args:
        image: ASE atoms object to search.
        cutoff: Pair distance cutoff.

    Returns:
        A list with one entry per atom, each being whatever
        ``NeighborList.get_neighbors(index)`` returns for that atom.
    """
    n_atoms = len(image)
    radii = [cutoff / 2.0 for _ in range(n_atoms)]
    neighbor_list = NeighborList(
        cutoffs=radii,
        self_interaction=False,
        skin=0,
        bothways=True,
        primitive=NewPrimitiveNeighborList,
    )
    neighbor_list.update(image)
    return list(map(neighbor_list.get_neighbors, range(n_atoms)))

In [30]:
# Build the pymatgen neighbor list of `image` within a radius of 6.
struct = AseAtomsAdaptor.get_structure(image)
_c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(r=6, exclude_self=True)

# For every atom, keep only the positions of its 5 shortest edges.
_nonmax_idx = []
for i in range(len(image)):
    # positions in the flat neighbor list whose first index is atom i
    idx_i = np.flatnonzero(_c_index == i)
    # order those edges by distance and keep the 5 closest
    idx_sorted = np.argsort(n_distance[idx_i])[:5]
    _nonmax_idx.append(idx_i[idx_sorted])
_nonmax_idx = np.concatenate(_nonmax_idx)

# Index arrays restricted to the retained edges.
_c_index1, _n_index1 = _c_index[_nonmax_idx], _n_index[_nonmax_idx]

In [31]:
# Peek at the first 17 entries of the unfiltered first index array returned
# by get_neighbor_list (r=6).
_c_index[:17]

array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int64)

In [32]:
# Peek at the first 17 entries of the unfiltered second index array returned
# by get_neighbor_list (r=6).
_n_index[:17]

array([40,  8, 64, 46, 22, 72, 61, 41, 33,  9, 46, 22, 79, 65, 47, 23, 72],
      dtype=int64)

In [34]:
# Same peek after filtering: at most 5 entries remain per atom.
_c_index1[:17]

array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3], dtype=int64)

In [35]:
# Second index array after filtering, ordered by increasing distance within
# each atom's group of retained edges.
_n_index1[:17]

array([ 8, 64, 40, 61, 46,  9, 65, 41, 33, 79, 10, 39, 34, 53, 66, 11, 43],
      dtype=int64)

In [5]:
# Raw pymatgen neighbor list for `image` with a cutoff of 6.
c, n, o, dist = get_neighbors_pymatgen(image, 6)

# Convert to tensors and drop near-zero-length edges. Note that `dist` is
# rebound here to the filtered distances, while `c`, `n` and `o` stay unfiltered.
ed, dist, co = _reshape_features(c, n, o, dist)

# Edge indices and distances go back to numpy; `co` remains a torch tensor.
ed, dist = ed.numpy(), dist.numpy()

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


In [17]:
# Unfiltered first index array returned by get_neighbors_pymatgen(image, 6).
c

array([ 0,  0,  0, ..., 84, 84, 84], dtype=int64)

In [18]:
# Unfiltered second index array returned by get_neighbors_pymatgen(image, 6).
n

array([40,  8, 64, ..., 73, 29, 42], dtype=int64)

In [21]:
# Edge distances after _reshape_features and .numpy() (the filtered array,
# not the raw pymatgen output).
dist

array([3.9220517, 3.054488 , 3.857755 , ..., 5.830055 , 3.9825108,
       3.816854 ], dtype=float32)

In [7]:
# Cap the number of edges kept per atom.
max_neigh = 10
idx_to_keep = []
for i in range(len(image)):
    # The first `max_neigh` positions whose first index is atom i, taken in
    # neighbor-list order (no distance sorting is applied in this cell).
    idx = np.flatnonzero(c == i)[:max_neigh].tolist()
    idx_to_keep.extend(idx)

In [8]:
# Array version of the per-atom cap: for each atom, the positions of its
# first `max_neigh` edges in neighbor-list order, joined into one index array.
nonmax_idx = np.concatenate(
    [np.flatnonzero(c == atom)[:max_neigh] for atom in range(len(image))]
)

In [9]:
# Number of edges left after capping each atom at `max_neigh` entries.
c[nonmax_idx].shape

(797,)

In [10]:
# Same selection applied to the second index array; should match c's length.
n[nonmax_idx].shape

(797,)

In [11]:
# Same selection applied to the cell offsets (one row per retained edge).
o[nonmax_idx].shape

(797, 3)

In [12]:
# NOTE(review): `nonmax_idx` was computed from the unfiltered `c`, whereas
# `dist` was rebound to the output of _reshape_features. The positions only
# line up if that filter removed no edges — confirm before relying on this.
dist[nonmax_idx].shape

(797,)

In [None]:
# Flatten the per-atom ASE neighbor results into edge-list arrays.
source_idx = []
target_idx = []
ase_neighbors = get_neighbors_ase(image, 6)
neighbors = []
ase_offsets = []
for i, j in enumerate(ase_neighbors):
    # j[0] holds the neighbor indices of atom i, j[1] the matching offsets.
    source_idx += [i] * len(j[0])
    target_idx += j[0].tolist()
    neighbors.append(len(j[0]))
    ase_offsets.append(j[1])

source_idx = np.array(source_idx)
target_idx = np.array(target_idx)
# Concatenate the list of per-atom offset arrays directly. Wrapping it in
# np.array() first builds a ragged array when atoms have different neighbor
# counts, which raises ValueError on NumPy >= 1.24. The result is unchanged
# when the wrapped form did work.
ase_offsets = np.concatenate(ase_offsets)

In [None]:
# Multiply the second element of atom 4's neighbor result by the cell matrix
# (presumably turning cell offsets into Cartesian shift vectors — TODO confirm).
np.dot(ase_neighbors[4][1], image.cell)

In [None]:
# Total number of edges currently held in source_idx.
source_idx.shape

In [None]:
# Atom positions and the cell as tensors; the cell is viewed as a batch of
# one 3x3 matrix.
pos = torch.tensor(image.positions)
cell = torch.tensor(image.cell).view(1, 3, 3)
# `co` is already a tensor: clone().detach() makes the same independent copy
# that torch.tensor(co) did, without the copy-construct UserWarning.
offset = co.clone().detach()

In [None]:
# Take the edge endpoints from the two rows of `ed` (row 0 -> sources,
# row 1 -> targets). This rebinds any earlier source_idx / target_idx.
source_idx, target_idx = ed[0], ed[1]

In [None]:
# Per-edge displacement before the periodic-image correction. The subtraction
# already yields a fresh tensor, so no torch.tensor(...) re-wrap is needed
# (it only produced a copy-construct warning).
distance_vectors = pos[source_idx] - pos[target_idx]

# correct for pbc
# One copy of the (1, 3, 3) cell per edge. The count comes from `offset`
# instead of a hard-coded 1012, so the batched matmul also works when the
# structure or cutoff yields a different number of edges.
cells = torch.repeat_interleave(cell, offset.shape[0], dim=0)
offsets = offset.float().view(-1, 1, 3).bmm(cells.float()).view(-1, 3)
distance_vectors -= offsets

# compute distances
distances = distance_vectors.norm(dim=-1)

In [1]:
from ocpmodels.common.utils import get_pbc_distances
from utils.dataset import get_ocpimage_dataset, data_list_collater
from torch.utils.data import DataLoader

# Open the validation dataset and wrap it in a shuffling loader that yields
# collated batches of 200 using 50 worker processes.
dataset = get_ocpimage_dataset(
    src="/checkpoint/mshuaibi/ocpdata_reset_07_13_20/val/ocpdata_val_200kv4/"
)
loader = DataLoader(
    dataset,
    batch_size=200,
    shuffle=True,
    collate_fn=data_list_collater,
    num_workers=50,
)

# Pull a single batch. With shuffle=True and no seed set in this cell, the
# batch contents can differ from run to run.
data = next(iter(loader))

In [None]:
# Run the project's get_pbc_distances helper on the loaded batch and unpack
# its three results. All inputs are fields of `data`; the trailing 6 is
# presumably the distance cutoff, matching the value used for the neighbor
# searches earlier — TODO confirm against the helper's signature.
edge_index, edge_dist, edge_vec = get_pbc_distances(
    data.pos,
    data.edge_index,
    data.cell,
    data.cell_offsets,
    data.neighbors,
    6
)