In [None]:
import pandas as pd
import numpy as np
import pickle
import os
import torch
from molSimplify.Classes.mol3D import mol3D
from molSimplify.Classes.ligand import ligand_breakdown,get_lig_symmetry,ligand
from torch_geometric.data import Data

In [None]:
# Build one record per (complex, ligand) pair from the XYZ files in XYZ_DIR.
# Each record carries the per-atom features of the whole complex plus masks
# that single out one ligand and its metal-binding atoms.

# Element symbol -> one-hot column. 'metal' must stay the LAST index: its
# column is sliced off below, so the metal atom ends up as an all-zero row.
GEOM_ATOM2IDX = {'H':0,'C': 1, 'N': 2, 'O': 3, 'S': 4, 'Br': 5, 'Cl': 6, 'P': 7, 'F': 8,'metal':9}

# Element symbol -> nuclear charge for the supported ligand elements.
GEOM_CHARGES = {'H':1,'C': 6, 'O': 8, 'N': 7, 'S': 16, 'Cl': 17, 'P': 15, 'Br': 35, 'F': 9}
# Supported metal centres and their nuclear charges.
metals={'Ru':44,'Pt':78,'Pd':46}

XYZ_DIR='../xyz'
# The group one-hots have MAX_LIGANDS+1 columns (column 0 = "no group"),
# so at most MAX_LIGANDS ligands per complex can be encoded.
MAX_LIGANDS=6


def _parse_xyz(path):
    """Read an XYZ file and return (num_atoms, one_hot, pos, nuclear_charges).

    Line 0 holds the atom count, line 1 is skipped, and the first atom line
    is treated as the metal centre (its symbol must be a key of ``metals``).
    Blank lines after the header are ignored.

    Raises:
        ValueError: if the number of atom lines differs from the declared count.
        KeyError: if an element symbol is not in ``metals`` / ``GEOM_ATOM2IDX``.
    """
    # Plain read mode: the file is never written, so 'r+' is not needed.
    with open(path,'r') as f:
        lines=f.readlines()
    num_atoms=int(lines[0])
    #total_charge=float(lines[1].split()[0])
    # Skip blank lines so a trailing newline does not crash the parser.
    atom_fields=[ln.split() for ln in lines[2:] if ln.strip()]
    if len(atom_fields)!=num_atoms:
        raise ValueError(
            f'{path}: header declares {num_atoms} atoms but {len(atom_fields)} atom lines were found'
        )

    metal_fields=atom_fields[0]
    ele=[GEOM_ATOM2IDX['metal']]
    nuclear_charges=[metals[metal_fields[0]]]
    pos=[[float(c) for c in metal_fields[1:]]]
    for fields in atom_fields[1:]:
        symbol=fields[0]
        ele.append(GEOM_ATOM2IDX[symbol])
        nuclear_charges.append(GEOM_CHARGES[symbol])
        pos.append([float(c) for c in fields[1:]])

    one_hot=torch.zeros(num_atoms,len(GEOM_ATOM2IDX))
    one_hot[range(num_atoms),ele]=1
    # Drop the trailing 'metal' column.
    one_hot=one_hot[:,:-1]
    return num_atoms,one_hot,torch.tensor(pos),torch.tensor(nuclear_charges)


data=[]
l_num=0  # not used in this cell; kept in case a later cell reads it
# sorted() makes the record order reproducible; os.listdir order is arbitrary.
for filename in sorted(os.listdir(XYZ_DIR)):
    # Ignore anything that is not an XYZ file (.DS_Store, checkpoints, ...).
    if not filename.lower().endswith('.xyz'):
        continue
    path=os.path.join(XYZ_DIR,filename)
    label=filename[:-4]
    num_atoms,one_hot,pos,nuclear_charges=_parse_xyz(path)

    my_mol=mol3D()
    my_mol.readfromxyz(path)
    # liglist / ligcon are indexed per ligand and used below as lists of atom
    # indices -- presumably in file order; verify against molSimplify.
    liglist,ligdents,ligcon=ligand_breakdown(my_mol,silent=True,BondedOct=True)
    if len(liglist)>MAX_LIGANDS or len(ligcon)>MAX_LIGANDS:
        raise ValueError(
            f'{filename}: {len(liglist)} ligands found, at most {MAX_LIGANDS} are supported'
        )

    # Per-atom ligand id: 0 = not in any ligand, k+1 = atom of ligand k.
    f_group=torch.zeros(num_atoms)
    for i in range(len(liglist)):
        f_group[liglist[i]]=i+1

    ligand_group=torch.zeros((num_atoms,MAX_LIGANDS+1))
    ligand_group[range(num_atoms),f_group.long()]=1

    # Same encoding restricted to the atoms listed in ligcon.
    anchor_group=torch.zeros(num_atoms)
    for i in range(len(ligcon)):
        anchor_group[ligcon[i]]=i+1
    anchors_group=torch.zeros((num_atoms,MAX_LIGANDS+1))
    anchors_group[range(num_atoms),anchor_group.long()]=1

    # One record per ligand; the complex-level tensors are shared between records.
    for k in range(len(liglist)):
        anchors=torch.zeros(num_atoms)
        # Named ligand_mask so the imported molSimplify `ligand` class is not shadowed.
        ligand_mask=torch.zeros(num_atoms)
        for i in ligcon[k]:
            anchors[i]=1
        for i in liglist[k]:
            ligand_mask[i]=1

        # Column 0 ("no group") is dropped from both group one-hots.
        dicts={'label':label,'natoms':num_atoms,'one_hot':one_hot,'pos':pos,'nuclear_charges':nuclear_charges,'anchors':anchors,'ligand_diff':ligand_mask,'ligand_group':ligand_group[:,1:],'anchor_group':anchors_group[:,1:]}
        data.append(dicts)

In [None]:
# Convert every record in `data` into a torch_geometric Data object and save the list.
data_list=[]
for entry in data:
    positions = entry['pos']
    label = entry['label']
    one_hot = entry['one_hot']
    # Complement of the ligand mask: 1 for atoms outside the selected ligand.
    context = 1-entry['ligand_diff']
    # The records store these under 'nuclear_charges' and 'natoms';
    # the former 'charges' / 'num_atoms' lookups raised KeyError.
    nuclear_charges = entry['nuclear_charges']
    coord_site = entry['anchors']
    ligand_diff = entry['ligand_diff']
    num_atoms = entry['natoms']
    ligand_group = entry['ligand_group']
    single_TMC = Data(pos=positions,label=label,  context=context,  nuclear_charges=nuclear_charges, coord_site=coord_site, ligand_diff=ligand_diff, num_atoms=num_atoms, one_hot=one_hot, ligand_group=ligand_group)
    data_list.append(single_TMC)

torch.save(data_list,'ppr.pt')