In [3]:
import torch

In [4]:
charge_dict ={
    
    'Ag': 47, 'Eu': 63,'Al': 13, 'As': 33, 'Au': 79, 'B': 5, 'Ba': 56, 'Be': 4, 'Br': 35, 'C': 6, 
    'Ca': 20, 'Cd': 48, 'Ce': 58, 'Cl': 17, 'Co': 27, 'Cr': 24, 'Cs': 55, 'Cu': 29, 'F': 9, 
    'Fe': 26, 'Ga': 31, 'Gd': 64, 'Ge': 32, 'H': 1, 'He': 2, 'Hg': 80, 'I': 53, 'Ir': 77, 
    'K': 19, 'La': 57, 'Li': 3, 'Mg': 12, 'Mn': 25, 'Mo': 42, 'N': 7, 'Na': 11, 'Nb': 41, 
    'Ni': 28, 'O': 8, 'Os': 76, 'P': 15, 'Pb': 82, 'Pd': 46, 'Po': 84, 'Pr': 59, 'Pt': 78, 
    'Rb': 37, 'Ru': 44, 'S': 16, 'Sb': 51, 'Sc': 21, 'Se': 34, 'Si': 14, 'Sn': 50, 'Te': 52, 
    'Ti': 22, 'Tl': 81, 'V': 23, 'W': 74, 'Xe': 54, 'Y': 39, 'Yb': 70, 'Zn': 30, 'Zr': 40
}


In [5]:


def process_xyz(datafile):
    """
    Read xyz file and return a molecular dict with number of atoms, energy, forces, coordinates and atom-type for the gdb9 dataset.

    Parameters
    ----------
    datafile : python file object
        File object containing the molecular data in the MD17 dataset.

    Returns
    -------
    molecule : dict
        Dictionary containing the molecular properties of the associated file object.

    Notes
    -----
    TODO : Replace breakpoint with a more informative failure?
    """
    xyz_lines = [line.decode('UTF-8') for line in datafile.readlines()]

    num_atoms = int(xyz_lines[0])
    #mol_props = xyz_lines[1].split()
    mol_xyz = xyz_lines[2:num_atoms+2]
   

    atom_charges, atom_positions = [], []
    for line in mol_xyz:
        atom, posx, posy, posz = line.replace('*^', 'e').split()
        atom_charges.append(charge_dict[atom])
        atom_positions.append([float(posx), float(posy), float(posz)])

    #prop_strings = ['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
    #prop_strings = prop_strings[1:]
    #mol_props = [int(mol_props[1])] + [float(x) for x in mol_props[2:]]
    #mol_props = dict(zip(prop_strings, mol_props))
    #mol_props['omega1'] = max(float(omega) for omega in mol_freq.split())

    molecule = {'num_atoms': num_atoms, 'charges': atom_charges, 'positions': atom_positions}
    #molecule.update(mol_props)
    molecule = {key: torch.tensor(val) for key, val in molecule.items()}

    return molecule

In [6]:
filename = '/Users/shangchao/Desktop/ChEBI_data/train/5280437.xyz'
# Open the file and process it
with open(filename, 'rb') as datafile:
    molecule_data = process_xyz(datafile)

print(molecule_data)

{'num_atoms': tensor(52), 'charges': tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1]), 'positions': tensor([[ 4.4977e+00, -1.1068e+00, -1.4608e+00],
        [ 3.3103e+00, -1.0918e+00, -5.3256e-01],
        [ 2.2205e+00, -1.8143e+00, -8.7572e-01],
        [ 9.7637e-01, -1.9381e+00, -2.9157e-02],
        [-1.9750e-01, -2.4663e+00, -8.5794e-01],
        [-1.4260e+00, -2.6507e+00,  7.3070e-03],
        [-2.3940e+00, -1.7156e+00,  1.1597e-01],
        [-2.3962e+00, -4.0498e-01, -6.3813e-01],
        [-1.5804e+00,  7.3119e-01,  6.1400e-04],
        [-3.1180e+00,  7.6852e-01,  4.5537e-02],
        [-3.7193e+00,  5.9215e-01,  1.4497e+00],
        [-3.8862e+00,  1.6986e+00, -9.0721e-01],
        [-9.1120e-01,  1.7459e+00, -9.5403e-01],
        [ 6.1113e-01,  1.8254e+00, -7.5425e-01],
        [ 1.0624e+00,  2.3093e+00,  6.2192e-01],
        [ 2.3104e+00,  2.0538e+00

In [19]:
import os

def process_xyz_gdb9(datafile):
    xyz_lines = [line.decode('UTF-8') for line in datafile.readlines()]
    num_atoms = int(xyz_lines[0])
    mol_xyz = xyz_lines[2:num_atoms+2]

    atom_types = set()
    for line in mol_xyz:
        atom, _, _, _ = line.replace('*^', 'e').split()
        atom_types.add(atom)
    return atom_types

def process_all_xyz_in_directory(directory_path):
    all_atom_types = set()
    for filename in os.listdir(directory_path):
        if filename.endswith(".xyz"):
            with open(os.path.join(directory_path, filename), 'rb') as datafile:
                atom_types = process_xyz_gdb9(datafile)
                all_atom_types.update(atom_types)
    return all_atom_types

# Example usage
directory_path = '/Users/shangchao/Desktop/ChEBI_data/train'
unique_atom_types = process_all_xyz_in_directory(directory_path)
print("Unique atom types in the directory:", unique_atom_types)


Unique atom types in the directory: {'As', 'Sc', 'P', 'Mg', 'Cr', 'Al', 'Zn', 'Na', 'Sn', 'Ga', 'Nb', 'Ca', 'W', 'V', 'Si', 'Ti', 'B', 'Ir', 'Ru', 'Os', 'Ge', 'O', 'Pt', 'Pd', 'Tl', 'Cs', 'S', 'Be', 'N', 'K', 'C', 'Sb', 'Cl', 'Pr', 'Y', 'Ce', 'Te', 'Li', 'Co', 'Se', 'He', 'Cd', 'Fe', 'Ba', 'I', 'H', 'Mn', 'Gd', 'Xe', 'Mo', 'La', 'Rb', 'Ag', 'Yb', 'Au', 'Hg', 'Ni', 'Eu', 'Po', 'Zr', 'F', 'Cu', 'Br', 'Pb'}


In [6]:
import os

# Assuming process_xyz and PreprocessQM9 are defined elsewhere
# ...

source_directory = '/Users/shangchao/Desktop/ChEBI_data/train'
output_file = '/Users/shangchao/Desktop/ChEBI_data/train/processed_xyz2tensor.txt'

# Initialize the list to hold all molecule dictionaries
molecule_data_list = []

with open(output_file, 'w') as out_file:
    for filename in os.listdir(source_directory):
        if filename.endswith('.xyz'):
            file_path = os.path.join(source_directory, filename)
            with open(file_path, 'rb') as datafile:
                molecule_data = process_xyz(datafile)
                molecule_data_list.append(molecule_data)

                # Write individual molecule data to file (as you did before)
                out_file.write(f"Filename: {filename}\n")
                out_file.write(f"Number of Atoms: {molecule_data['num_atoms'].item()}\n")
                charges = molecule_data['charges'].tolist()
                positions = molecule_data['positions'].tolist()
                out_file.write("Charges: " + ', '.join(map(str, charges)) + '\n')
                out_file.write("Positions:\n")
                for pos in positions:
                    out_file.write("    " + ', '.join(map(str, pos)) + '\n')
                out_file.write("\n")

# Now process the list of molecule dictionaries as a batch
preprocessor = PreprocessQM9(load_charges=True)
batch_processed_data = preprocessor.collate_fn(molecule_data_list)

print(f"Processed data saved to {output_file}")


NameError: name 'PreprocessQM9' is not defined

In [7]:
import torch


def batch_stack(props):
    """
    Stack a list of torch.tensors so they are padded to the size of the
    largest tensor along each axis.

    Parameters
    ----------
    props : list of Pytorch Tensors
        Pytorch tensors to stack

    Returns
    -------
    props : Pytorch tensor
        Stacked pytorch tensor.

    Notes
    -----
    TODO : Review whether the behavior when elements are not tensors is safe.
    """
    if not torch.is_tensor(props[0]):
        return torch.tensor(props)
    elif props[0].dim() == 0:
        return torch.stack(props)
    else:
        return torch.nn.utils.rnn.pad_sequence(props, batch_first=True, padding_value=0)


def drop_zeros(props, to_keep):
    """
    Function to drop zeros from batches when the entire dataset is padded to the largest molecule size.

    Parameters
    ----------
    props : Pytorch tensor
        Full Dataset


    Returns
    -------
    props : Pytorch tensor
        The dataset with  only the retained information.

    Notes
    -----
    TODO : Review whether the behavior when elements are not tensors is safe.
    """
    if not torch.is_tensor(props[0]):
        return props
    elif props[0].dim() == 0:
        return props
    else:
        return props[:, to_keep, ...]


class PreprocessQM9:
    def __init__(self, load_charges=True):
        self.load_charges = load_charges

    def add_trick(self, trick):
        self.tricks.append(trick)

    def collate_fn(self, batch):
        """
        Collation function that collates datapoints into the batch format for cormorant

        Parameters
        ----------
        batch : list of datapoints
            The data to be collated.

        Returns
        -------
        batch : dict of Pytorch tensors
            The collated data.
        """
        batch = {prop: batch_stack([mol[prop] for mol in batch]) for prop in batch[0].keys()}

        to_keep = (batch['charges'].sum(0) > 0)

        batch = {key: drop_zeros(prop, to_keep) for key, prop in batch.items()}

        atom_mask = batch['charges'] > 0
        batch['atom_mask'] = atom_mask

        #Obtain edges
        batch_size, n_nodes = atom_mask.size()
        edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)

        #mask diagonal
        diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
        edge_mask *= diag_mask

        #edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
        batch['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

        if self.load_charges:
            batch['charges'] = batch['charges'].unsqueeze(2)
        else:
            batch['charges'] = torch.zeros(0)
        return batch

In [8]:
preprocessor = PreprocessQM9(load_charges=True)
processed_data = preprocessor.collate_fn(molecule_data_list)
print(processed_data)

{'num_atoms': tensor([ 45,  67, 124,  ...,   3,  46,  67]), 'charges': tensor([[[ 6],
         [ 7],
         [ 6],
         ...,
         [ 0],
         [ 0],
         [ 0]],

        [[ 6],
         [ 6],
         [ 6],
         ...,
         [ 0],
         [ 0],
         [ 0]],

        [[ 6],
         [ 6],
         [ 6],
         ...,
         [ 0],
         [ 0],
         [ 0]],

        ...,

        [[16],
         [ 1],
         [ 1],
         ...,
         [ 0],
         [ 0],
         [ 0]],

        [[ 6],
         [ 6],
         [ 6],
         ...,
         [ 0],
         [ 0],
         [ 0]],

        [[ 6],
         [ 6],
         [ 6],
         ...,
         [ 0],
         [ 0],
         [ 0]]]), 'positions': tensor([[[-4.0599e+00,  4.1449e-01,  1.2710e-01],
         [-2.7309e+00,  1.0504e+00,  2.6010e-01],
         [-2.1243e+00,  1.3078e+00, -1.0726e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
 

In [14]:
# Define the path for the output file
output_file_path = '/Users/shangchao/Desktop/ChEBI_data/train/preprocessed_tensorexample.txt'

# Write the first ten items of each tensor to the file
with open(output_file_path, 'w') as file:
    for key, value in processed_data.items():
        if torch.is_tensor(value):
            # Write the first ten items of the tensor
            file.write(f"First ten items of '{key}': {value[:10].tolist()}\n")
        else:
            # Handle non-tensor data, if any
            file.write(f"First ten items of '{key}': {value[:10]} (Non-tensor data)\n")

print(f"Data written to {output_file_path}")



Data written to /Users/shangchao/Desktop/ChEBI_data/train/preprocessed_tensorexample.txt


In [11]:
for key, tensor in processed_data.items():
    if torch.is_tensor(tensor):
        print(f"Size of tensor '{key}': {tensor.size()}")

        # If you want to print the length of a specific dimension, for example, the first dimension:
        print(f"Length of the first dimension of tensor '{key}': {tensor.size(0)}")


Size of tensor 'num_atoms': torch.Size([24708])
Length of the first dimension of tensor 'num_atoms': 24708
Size of tensor 'charges': torch.Size([24708, 806, 1])
Length of the first dimension of tensor 'charges': 24708
Size of tensor 'positions': torch.Size([24708, 806, 3])
Length of the first dimension of tensor 'positions': 24708
Size of tensor 'atom_mask': torch.Size([24708, 806])
Length of the first dimension of tensor 'atom_mask': 24708
Size of tensor 'edge_mask': torch.Size([16051206288, 1])
Length of the first dimension of tensor 'edge_mask': 16051206288
