# Turning some mochas into periodic graphs!
Adapted from https://docs.e3nn.org/en/latest/guide/periodic_boundary_conditions.html

## Importing libraries and setting the default dtype

In [17]:
import torch
from ase.io import read
import ase.neighborlist
import torch_geometric
import torch_geometric.data
import os

# Make float64 the default floating-point dtype, so tensors created without an
# explicit dtype agree with the tensors built from `default_dtype` below.
default_dtype = torch.float64
torch.set_default_dtype(default_dtype)

## Loading the MOChas from the structure directory

In [19]:
# Folder where the MOCha structure files (.cif / .vasp) are stored.
root = "/Users/adrianaladera/Desktop/MIT/research/mochas/ALL_STRUCTURES/mochas_opendac/"
# Fixed mapping from chemical symbol to one-hot column. It is deliberately not
# extended on the fly: `type_onehot` is sized from it once, so every graph gets
# node features of the same width.
type_encoding = {'Ag': 0, 'S': 1, 'C': 2, 'H': 3, 'O': 4, 'Cu': 5, 'N': 6, 'Se': 7, 'Na': 8, 'Te': 9}
atomic_symbol = 0  # not used by this cell; kept in case a later cell reads it
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms
type_onehot = torch.eye(len(type_encoding))
dataset = []

# os.listdir returns entries in arbitrary order; sort them so the order of
# `dataset` is reproducible between runs and machines.
for mocha in sorted(os.listdir(root)):
    if not mocha.endswith((".cif", ".vasp")):
        continue

    # os.path.join avoids the doubled separator that "{root}/{mocha}" produces
    # when `root` already ends with "/".
    structure = read(os.path.join(root, mocha))

    # Fail with the file name and the offending species instead of a bare
    # KeyError raised from inside the one-hot lookup.
    symbols = structure.get_chemical_symbols()
    unknown = sorted(set(symbols).difference(type_encoding))
    if unknown:
        raise KeyError(f"{mocha}: species {unknown} missing from type_encoding")

    # "ijS": source index, destination index and integer cell shift of every
    # neighbor pair within `radial_cutoff`, periodic images included.
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list(
        "ijS", a=structure, cutoff=radial_cutoff, self_interaction=True
    )

    # pos, lattice and edge_shift all use `default_dtype` explicitly so they can
    # be combined later without a dtype mismatch.
    data = torch_geometric.data.Data(
        pos=torch.tensor(structure.get_positions(), dtype=default_dtype),
        lattice=torch.tensor(structure.cell.array, dtype=default_dtype).unsqueeze(0),  # We add a dimension for batching
        x=type_onehot[[type_encoding[atom] for atom in symbols]],
        edge_index=torch.stack(
            [
                torch.as_tensor(edge_src, dtype=torch.long),
                torch.as_tensor(edge_dst, dtype=torch.long),
            ],
            dim=0,
        ),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
    )
    dataset.append(data)

print(dataset)

[Data(x=[120, 10], edge_index=[2, 1936], pos=[120, 3], lattice=[1, 3, 3], edge_shift=[1936, 3]), Data(x=[105, 10], edge_index=[2, 2129], pos=[105, 3], lattice=[1, 3, 3], edge_shift=[2129, 3]), Data(x=[136, 10], edge_index=[2, 2000], pos=[136, 3], lattice=[1, 3, 3], edge_shift=[2000, 3]), Data(x=[204, 10], edge_index=[2, 4224], pos=[204, 3], lattice=[1, 3, 3], edge_shift=[4224, 3]), Data(x=[68, 10], edge_index=[2, 946], pos=[68, 3], lattice=[1, 3, 3], edge_shift=[946, 3]), Data(x=[152, 10], edge_index=[2, 2276], pos=[152, 3], lattice=[1, 3, 3], edge_shift=[2276, 3]), Data(x=[210, 10], edge_index=[2, 4326], pos=[210, 3], lattice=[1, 3, 3], edge_shift=[4326, 3]), Data(x=[180, 10], edge_index=[2, 3114], pos=[180, 3], lattice=[1, 3, 3], edge_shift=[3114, 3]), Data(x=[240, 10], edge_index=[2, 4008], pos=[240, 3], lattice=[1, 3, 3], edge_shift=[4008, 3]), Data(x=[104, 10], edge_index=[2, 1568], pos=[104, 3], lattice=[1, 3, 3], edge_shift=[1568, 3]), Data(x=[272, 10], edge_index=[2, 5364], pos

## Processing relative distance vectors of edges with periodic boundaries

In [20]:
def get_relative_distance_vectors(data):
    """Return the endpoints and displacement vector of every edge in `data`.

    The displacement of an edge is ``pos[dst] - pos[src]`` plus the periodic
    offset ``edge_shift @ lattice``, so edges that cross a cell boundary point
    to the correct periodic image.

    Args:
        data: a torch_geometric.data.Data object (or any mapping indexable by
            the same keys) holding ``edge_index`` (row 0 = source, row 1 =
            destination), ``pos``, ``edge_shift`` and ``lattice``. ``lattice``
            has a leading graph dimension. If a ``batch`` entry is present it
            is used as the graph index of each node, so batched graphs with
            different cells are handled.

    Returns:
        Tuple ``(edge_src, edge_dst, edge_vec)``; ``edge_vec`` has one row per
        edge.
    """
    edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]

    lattice = data['lattice']
    if 'batch' in data and data['batch'] is not None:
        # Batched input: pick, for every edge, the lattice of the graph its
        # source node belongs to. Without this the einsum below only works
        # when the leading lattice dimension is 1 (a single graph).
        lattice = lattice[data['batch'][edge_src]]

    edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
                + torch.einsum('ni,nij->nj', data['edge_shift'], lattice))

    return edge_src, edge_dst, edge_vec