# Turning some mochas into periodic graphs!
Adopted from https://docs.e3nn.org/en/latest/guide/periodic_boundary_conditions.html

## Importing libraries and shihh

In [2]:
import torch
from ase.io import read
import ase.neighborlist
import torch_geometric
import torch_geometric.data
import os

default_dtype = torch.float64
torch.set_default_dtype(default_dtype)

## Getting the MOChas from the directory thingy

In [7]:
# folder where MOChas are stored
root = "/Users/adrianaladera/Desktop/MIT/research/mochas/ALL_STRUCTURES/mochas_opendac/mochas_with_displacements"
type_encoding = {'Ag': 0, 'S': 1, 'C': 2, 'H': 3, 'O': 4, 'Cu': 5, 'N': 6, 'Se': 7, 'Na': 8, 'Te': 9}
atomic_symbol = 0
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms
type_onehot = torch.eye(len(type_encoding))
dataset = []

for mocha in os.listdir(root):
    for file in os.listdir(f"{root}/{mocha}"):
        if file.endswith(".cif") or file.endswith(".vasp"):
            print(mocha, file)
            structure = read(f"{root}/{mocha}/{file}") # getting structure from cif file

                # for species in structure.get_chemical_symbols():
                #     if species not in type_encoding.keys():
                #         # print(species)
                #         type_encoding[species] = atomic_symbol
                #         atomic_symbol += 1

            edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=structure, cutoff=radial_cutoff, self_interaction=True)

            data = torch_geometric.data.Data(
                    pos=torch.tensor(structure.get_positions()),
                    lattice=torch.tensor(structure.cell.array).unsqueeze(0),  # We add a dimension for batching
                    x=type_onehot[[type_encoding[atom] for atom in structure.symbols]],
                    edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
                    edge_shift=torch.tensor(edge_shift, dtype=default_dtype)
            )
            dataset.append(data)

print(dataset)

tethrene_2D tethrene_2D-003.vasp
tethrene_2D tethrene_2D-054.vasp
tethrene_2D tethrene_2D-042.vasp
tethrene_2D tethrene_2D-015.vasp
tethrene_2D tethrene_2D-039.vasp
tethrene_2D tethrene_2D-078.vasp
tethrene_2D tethrene_2D-058.vasp
tethrene_2D tethrene_2D-019.vasp
tethrene_2D tethrene_2D-062.vasp
tethrene_2D tethrene_2D-035.vasp
tethrene_2D tethrene_2D-023.vasp
tethrene_2D tethrene_2D-074.vasp
tethrene_2D tethrene_2D-075.vasp
tethrene_2D tethrene_2D-022.vasp
tethrene_2D tethrene_2D-034.vasp
tethrene_2D tethrene_2D-063.vasp
tethrene_2D tethrene_2D-018.vasp
tethrene_2D tethrene_2D-059.vasp
tethrene_2D tethrene_2D-038.vasp
tethrene_2D tethrene_2D-014.vasp
tethrene_2D tethrene_2D-043.vasp
tethrene_2D tethrene_2D-055.vasp
tethrene_2D tethrene_2D-002.vasp
tethrene_2D tethrene_2D-009.vasp
tethrene_2D tethrene_2D-048.vasp
tethrene_2D tethrene_2D-064.vasp
tethrene_2D tethrene_2D-033.vasp
tethrene_2D tethrene_2D-025.vasp
tethrene_2D tethrene_2D-072.vasp
tethrene_2D tethrene_2D-005.vasp
tethrene_2

: 

: 

## Processing relative distance vectors of edges with periodic boundaries

In [20]:
def get_relative_distance_vectors(data):
    '''data - a torch_geometric.data.Data object'''
    edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
    edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
            + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice']))
    
    return edge_src, edge_dst, edge_vec