In [2]:
import base64
import glob
import os.path as osp
import pickle
import random

import numpy as np

import networkx as nx

from sklearn.utils import shuffle

import torch
import torch.jit as jit
import torch.nn.functional as F
from torch_scatter import scatter

import torch_geometric as tg
from torch_geometric.data import Dataset, Data, DataLoader, InMemoryDataset, Batch
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.separate import separate
from torch_geometric.utils import degree

from tqdm import tqdm

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem.rdchem import ChiralType
from rdkit.Chem.rdchem import HybridizationType
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# SMARTS query: four atoms of any kind joined in a chain by bonds of any kind.
dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')
# Numeric value assigned to each RDKit chiral tag.
chirality = {
    ChiralType.CHI_TETRAHEDRAL_CW: -1.,
    ChiralType.CHI_TETRAHEDRAL_CCW: 1.,
    ChiralType.CHI_UNSPECIFIED: 0,
    ChiralType.CHI_OTHER: 0,
}






class qm9_data(Dataset):
    """Conformer dataset built from per-molecule pickle files.

    Every ``*.pickle`` file found under ``<root>/qm9/`` is handed to
    ``featurization``; each ``Data`` object it returns becomes one sample.
    All samples are concatenated into a single storage object plus per-key
    slice boundaries, which are saved to ``processed_paths[0]`` and loaded
    back in ``__init__``.

    Slice layout: ``slices[key]`` is a cumulative-size tensor that starts with
    0, so sample ``i`` occupies ``[slices[key][i], slices[key][i + 1])``.

    NOTE: a cache file written without the leading 0 is still readable, but it
    exposes one sample fewer and pairs the tensors of sample ``i + 1`` with the
    ``name``/``mol`` of sample ``i``. Delete such a file so it is rebuilt.
    """

    # Attributes copied from every sample, in storage order.
    _KEYS = ('x', 'z', 'pos', 'edge_index', 'edge_attr', 'chiral_tag', 'name',
             'boltzmann_weight', 'degeneracy', 'mol', 'pos_mask')
    # Attributes kept as plain python lists (one entry per sample) rather
    # than being concatenated into a tensor.
    _OBJECT_KEYS = ('name', 'mol')

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        """Collect the raw pickle paths, process if needed, and load the cache.

        Args:
            root: Dataset root; raw pickles are searched in ``<root>/qm9/``.
            transform: Forwarded to ``Dataset``.
            pre_transform: Applied to each sample in :meth:`process`.
            pre_filter: Called on each sample in :meth:`process`; samples for
                which it returns a falsy value are dropped.
        """
        self.all_files = sorted(glob.glob(osp.join(root, 'qm9', '*.pickle')))
        self.pickle_files = list(self.all_files)
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def process(self):
        """Featurize every pickle file and write the collated cache to disk."""
        data_list = []
        for path in tqdm(self.pickle_files):
            featurized = featurization(path)
            if featurized is None:
                # Molecule rejected by featurization.
                continue
            samples = featurized if isinstance(featurized, list) else [featurized]
            # Filter/transform each sample on its own: featurization may
            # return several samples for one molecule.
            for sample in samples:
                if self.pre_filter is not None and not self.pre_filter(sample):
                    continue
                if self.pre_transform is not None:
                    sample = self.pre_transform(sample)
                if sample is not None:
                    data_list.append(sample)

        processed_data, slices = self.collate(data_list)
        torch.save((processed_data, slices), self.processed_paths[0])
        self.data, self.slices = processed_data, slices

    def get(self, idx: int) -> Data:
        """Gets the data object at index :obj:`idx`."""
        sample = {}
        for key, bounds in self.slices.items():
            start, end = int(bounds[idx]), int(bounds[idx + 1])
            stored = self.data[key]
            if key == 'edge_index':
                # Edges are concatenated along the column dimension.
                sample[key] = stored[:, start:end]
            elif isinstance(stored, torch.Tensor):
                sample[key] = stored[start:end]
            else:
                # Python lists hold exactly one entry per sample.
                sample[key] = stored[idx]
        return Data(**sample)

    def collate(self, data_list: List[Data]) -> Tuple[Data, Optional[Dict[str, torch.Tensor]]]:
        """Concatenate samples into one storage object and slice boundaries.

        Args:
            data_list: Samples providing every attribute listed in ``_KEYS``.

        Returns:
            ``(data, slices)`` where ``data`` holds the concatenated tensors
            (``edge_index`` along dim 1, everything else along dim 0) and the
            per-sample lists for ``name``/``mol``, and ``slices[key]`` is the
            cumulative size tensor with a leading 0.

        Raises:
            RuntimeError: If ``data_list`` is empty.
        """
        if not data_list:
            raise RuntimeError('collate() received no samples to concatenate')

        columns = {key: [] for key in self._KEYS}
        sizes = {key: [] for key in self._KEYS}
        for item in data_list:
            for key in self._KEYS:
                value = item[key]
                if key in self._OBJECT_KEYS:
                    columns[key].append(value)
                    sizes[key].append(1)
                    continue
                if not isinstance(value, torch.Tensor):
                    # Python / numpy scalars (e.g. weights) become tensors.
                    value = torch.as_tensor(value)
                if value.dim() == 0:
                    value = value.unsqueeze(0)
                columns[key].append(value)
                # Exactly one size entry per sample and key.
                sizes[key].append(value.size(1 if key == 'edge_index' else 0))

        storage = {}
        for key in self._KEYS:
            if key in self._OBJECT_KEYS:
                storage[key] = columns[key]
            else:
                storage[key] = torch.cat(columns[key], dim=1 if key == 'edge_index' else 0)
        data = Data(**storage)

        # The leading 0 makes slices[key][i] the start offset of sample i.
        slices = {key: torch.tensor([0] + sizes[key]).cumsum(0) for key in self._KEYS}

        return data, slices

    def len(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self.slices['x']) - 1

    def open_pickle(self, mol_path):
        """Read the pickle file at ``mol_path`` and return the loaded object.

        NOTE: ``pickle.load`` can execute arbitrary code; only open trusted
        files.
        """
        with open(mol_path, "rb") as f:
            dic = pickle.load(f)
        return dic

    @property
    def raw_file_names(self) -> List[str]:
        """Paths of all raw pickle files found under ``<root>/qm9/``."""
        return self.all_files

    @property
    def processed_file_names(self):
        """Name of the single processed cache file."""
        return ['data_v3.pt']

    def get_idx_split(self, data_size, train_size, valid_size, seed):
        """Shuffle ``range(data_size)`` with ``seed`` and split it in three.

        Returns:
            Dict with index tensors under ``'train'`` (first ``train_size``
            ids), ``'valid'`` (next ``valid_size`` ids) and ``'test'`` (rest).
        """
        ids = shuffle(range(data_size), random_state=seed)
        train_idx = torch.tensor(ids[:train_size])
        val_idx = torch.tensor(ids[train_size:train_size + valid_size])
        test_idx = torch.tensor(ids[train_size + valid_size:])
        split_dict = {'train': train_idx, 'valid': val_idx, 'test': test_idx}
        return split_dict

def featurization(mol_path: str, max_confs: int = 10):
    """Turn one molecule pickle into a list of per-conformer ``Data`` objects.

    The pickle is expected to hold a dict with a ``"smiles"`` string and a
    ``'conformers'`` list whose entries carry ``'rd_mol'``,
    ``'boltzmannweight'`` and ``'degeneracy'``. The conformers are shuffled in
    place and up to ``max_confs`` of them are accepted; a conformer is
    rejected when an atom has more than 4 neighbours, when its canonical
    SMILES cannot be computed, or when that SMILES differs from the
    molecule's.

    Args:
        mol_path: Path of the pickle file.
        max_confs: Maximum number of conformers to keep (also the length of
            ``pos_mask``).

    Returns:
        ``None`` when the molecule is rejected (unparsable or fragmented
        SMILES, no conformers, fewer than 4 atoms or bonds, no 4-atom chain,
        no accepted conformer); otherwise a list with one ``Data`` per
        accepted conformer. All samples share the graph features, the molecule
        of the last accepted conformer (``mol``) and ``pos_mask``; ``pos``,
        ``boltzmann_weight`` and ``degeneracy`` are those of the sample's own
        conformer.

    Raises:
        ValueError: If ``max_confs`` is smaller than 1.
    """
    if max_confs < 1:
        raise ValueError('max_confs must be at least 1')

    chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1.,
                 ChiralType.CHI_TETRAHEDRAL_CCW: 1.,
                 ChiralType.CHI_UNSPECIFIED: 0,
                 ChiralType.CHI_OTHER: 0}
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
    dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')

    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    with open(mol_path, 'rb') as handle:
        mol_dic = pickle.load(handle)

    name = mol_dic["smiles"]
    confs = mol_dic['conformers']
    random.shuffle(confs)  # pick a random subset when there are > max_confs

    # filter mols rdkit can't intrinsically handle
    mol_ = Chem.MolFromSmiles(name)
    if not mol_:
        return None
    canonical_smi = Chem.MolToSmiles(mol_)

    # skip molecules made of several fragments
    if '.' in name:
        return None

    # nothing to featurize without at least one conformer
    if not confs:
        return None

    # skip molecules without dihedrals
    first_mol = confs[0]['rd_mol']
    N = first_mol.GetNumAtoms()
    if N < 4:
        return None
    if first_mol.GetNumBonds() < 4:
        return None
    if not first_mol.HasSubstructMatch(dihedral_pattern):
        return None

    pos = torch.zeros([max_confs, N, 3])
    pos_mask = torch.zeros(max_confs, dtype=torch.int64)
    accepted = []  # conformer records whose coordinates were stored in `pos`
    for conf in confs:
        if len(accepted) == max_confs:
            break
        mol = conf['rd_mol']

        # skip mols with atoms with more than 4 neighbors for now
        if any(len(a.GetNeighbors()) > 4 for a in mol.GetAtoms()):
            continue

        # filter for conformers that may have reacted; RDKit can raise
        # several exception types here, so this stays a best-effort skip
        try:
            conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(mol))
        except Exception:
            continue

        if conf_canonical_smi != canonical_smi:
            continue

        slot = len(accepted)
        pos[slot] = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float)
        pos_mask[slot] = 1
        accepted.append(conf)

    # return None if no non-reactive conformers were found
    if not accepted:
        return None
    correct_mol = accepted[-1]['rd_mol']

    type_idx = []
    atomic_number = []
    atom_features = []
    chiral_tag = []
    ring = correct_mol.GetRingInfo()
    for i, atom in enumerate(correct_mol.GetAtoms()):
        type_idx.append(types[atom.GetSymbol()])
        chiral_tag.append(chirality[atom.GetChiralTag()])
        atomic_number.append(atom.GetAtomicNum())
        atom_features.extend([atom.GetAtomicNum(),
                              1 if atom.GetIsAromatic() else 0])
        atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]))
        atom_features.extend(one_k_encoding(atom.GetHybridization(), [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2]))
        atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]))
        atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1]))
        atom_features.extend([int(ring.IsAtomInRingOfSize(i, size))
                              for size in (3, 4, 5, 6, 7, 8)])
        atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3]))

    z = torch.tensor(atomic_number, dtype=torch.long)
    chiral_tag = torch.tensor(chiral_tag, dtype=torch.float)

    # every bond is stored in both directions
    row, col, edge_type = [], [], []
    for bond in correct_mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [bonds[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type, dtype=torch.long)
    edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    # order edges by (source, target)
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_attr = edge_attr[perm]

    x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
    x2 = torch.tensor(atom_features).view(N, -1)
    x = torch.cat([x1.to(torch.float), x2], dim=-1)

    data_list = []
    dihedral_pairs = None
    for slot, conf in enumerate(accepted):
        data = Data(x=x, z=z, pos=pos[slot], edge_index=edge_index, edge_attr=edge_attr,
                    chiral_tag=chiral_tag, name=name, boltzmann_weight=conf['boltzmannweight'],
                    degeneracy=conf['degeneracy'], mol=correct_mol, pos_mask=pos_mask)
        if dihedral_pairs is None:
            # The graph is identical for every conformer, so compute once.
            dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data)
        data.edge_index_dihedral_pairs = dihedral_pairs
        data_list.append(data)
    return data_list


def one_k_encoding(value, choices):
    """
    Build a one-hot list with a trailing "other" slot.
    :param value: The value to encode.
    :param choices: The known values, one slot each, in order.
    :return: A list of length :code:`len(choices) + 1` containing a single 1: at the
             position of :code:`value` in :code:`choices`, or in the last slot when
             :code:`value` is not one of the :code:`choices`.
    """
    try:
        hot_slot = choices.index(value)
    except ValueError:
        # Unknown value: fall back to the extra final slot.
        hot_slot = -1

    one_hot_vector = [0 for _ in range(len(choices) + 1)]
    one_hot_vector[hot_slot] = 1
    return one_hot_vector

def get_cycle_values(cycle_list, start_at=None):
    """Yield the items of ``cycle_list`` forever, wrapping around at the end.

    Iteration begins at the first occurrence of ``start_at`` in the list, or
    at the first item when ``start_at`` is None.
    """
    position = cycle_list.index(start_at) if start_at is not None else 0
    while True:
        yield cycle_list[position]
        position = (position + 1) % len(cycle_list)


def get_cycle_indices(cycle, start_idx):
    """Return the consecutive pairs met while walking once around ``cycle``.

    The walk starts at ``start_idx`` and stops as soon as a value equal to the
    starting value is reached again; each step contributes one 2-element
    tensor ``[previous, current]``.
    """
    walker = get_cycle_values(cycle, start_idx)
    first = next(walker)

    pairs = []
    previous = first
    last_seen = 9e99  # placeholder so the loop body runs at least once
    while first != last_seen:
        current = next(walker)
        pairs.append(torch.tensor([previous, current]))
        previous = last_seen = current

    return pairs
def get_current_cycle_indices(cycles, cycle_check, idx):
    """Remove the first flagged cycle from ``cycles`` and return its pairs.

    ``cycle_check`` holds one flag per entry of ``cycles``; the first truthy
    flag selects the cycle, which is popped from ``cycles`` in place. The walk
    over that cycle starts at the entry equal to ``idx.item()``.
    """
    flagged = [position for position, hit in enumerate(cycle_check) if hit]
    ring = cycles.pop(flagged[0])
    matches = (np.array(ring) == idx.item()).nonzero()[0]
    anchor = ring[matches[0]]
    return get_cycle_indices(ring, anchor)

def get_dihedral_pairs(edge_index, data):
    """
    Given edge indices, return pairs of indices that we must calculate dihedrals for

    :param edge_index: Edge tensor that unpacks into a ``start`` row and an ``end`` row.
    :param data: Graph object handed to ``tg.utils.to_networkx`` to find the cycle basis.
    :return: The kept pairs stacked and transposed, i.e. one pair per column, with
             every pair moved to the module-level ``device``.
    """
    start, end = edge_index
    # Per-node degree, counted from the `end` row.
    degrees = degree(end)
    # Positions of the edges whose two endpoints both have degree > 1.
    dihedral_pairs_true = torch.nonzero(torch.logical_and(degrees[start] > 1, degrees[end] > 1))
    # squeeze(-1) drops the trailing dimension introduced by indexing with the nonzero() result.
    dihedral_pairs = edge_index[:, dihedral_pairs_true].squeeze(-1)
    
    # # first method which removes one (pseudo) random edge from a cycle
    # Positions of the columns whose smaller entry sits in the first row.
    # NOTE(review): with edges stored in both directions this presumably keeps one
    # direction per bond - confirm.
    dihedral_idxs = torch.nonzero(dihedral_pairs.sort(dim=0).indices[0, :] == 0).squeeze().detach().cpu().numpy()

    # prioritize rings for assigning dihedrals
    dihedral_pairs = dihedral_pairs.t().cpu().detach()[dihedral_idxs]
    G = nx.to_undirected(tg.utils.to_networkx(data))
    cycles = nx.cycle_basis(G)
    # `keep` collects the selected pairs; `sorted_keep` holds the same pairs with
    # their entries sorted, for the already-seen test at the top of the loop.
    keep, sorted_keep = [], []

    # A single selected pair leaves a 1-D tensor after the squeeze() above;
    # restore the pair dimension so the loop below sees one pair per row.
    if len(dihedral_pairs.shape) == 1:
        dihedral_pairs = dihedral_pairs.unsqueeze(0)

    for pair in dihedral_pairs:
        x, y = pair

        # Already recorded (possibly as part of a cycle handled earlier).
        if sorted(pair) in sorted_keep:
            continue

        # One flag per remaining cycle: does it contain this endpoint?
        y_cycle_check = [y in cycle for cycle in cycles]
        x_cycle_check = [x in cycle for cycle in cycles]

        if any(x_cycle_check) and any(y_cycle_check):  # both in new cycle
            # get_current_cycle_indices pops the chosen cycle from `cycles`, so
            # a cycle is consumed at most once; its pairs replace `pair`.
            cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x)
            keep.extend(cycle_indices)

            sorted_keep.extend([sorted(c.cpu()) for c in cycle_indices])
            continue

        if any(y_cycle_check):
            # Only y is flagged here: keep the pair itself plus the pairs of
            # the cycle containing y.
            cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y)
            keep.append(pair)
            keep.extend(cycle_indices)

            sorted_keep.append(sorted(pair))
            sorted_keep.extend([sorted(c.cpu()) for c in cycle_indices])
            continue

        # Neither branch above applied (y is in no remaining cycle): keep the pair as is.
        keep.append(pair)

    #keep = torch.tensor(keep).to(device) 
    # Copy every kept pair onto `device`.
    keep = [torch.tensor(t).to(device) for t in keep]
    
    # NOTE(review): torch.stack raises when `keep` is empty - confirm callers
    # never pass a graph without qualifying edges.
    return torch.stack(keep).t()

In [59]:
import torch_geometric
import pickle

class QM9Batch(torch_geometric.data.batch.Batch):
    """Batch subclass with a ``_data`` attribute that is reset on unpickling."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._data = None

    def __getstate__(self):
        # Shallow copy of the instance dict with `_data` stored explicitly.
        snapshot = dict(self.__dict__)
        snapshot['_data'] = self._data
        return snapshot

    def __setstate__(self, state):
        # Restore every saved attribute, then clear `_data` again.
        self.__dict__.update(state)
        self._data = None

    def __repr__(self):
        return "{}(batch_size={})".format(self.__class__.__name__, self.batch_size)

In [3]:
# Dataset root; qm9_data searches `<root>/qm9/` for `*.pickle` files.
path_to_data = '../../../others_approaches/conformation_generation/GeoMol/data/QM9/'
# Instantiating the dataset loads the processed cache from disk.
qm9_set = qm9_data(root= path_to_data)

In [6]:
# Number of samples reported by the dataset's own len() method.
qm9_set.len()

522124

In [4]:
# NOTE(review): code in other cells may rely on `Tensor` being present in the
# notebook's global namespace - confirm before removing this import.
from torch import Tensor
#from torch.utils.data import DataLoader
# Rebinds `DataLoader`, which the first cell imported from torch_geometric.data.
from torch_geometric.loader import DataLoader
# Define the indices of the subset you want to load
# NOTE(review): `subset_indices` is not passed to the DataLoader below, so the
# loader iterates over the whole of `qm9_set`.
subset_indices = [0, 1, 2]  # Example indices, replace with your desired subset

# Create the DataLoader
# Batches of 4 samples, in dataset order (shuffle=False).
dataloader = DataLoader(dataset=qm9_set, batch_size=4, shuffle=False)



In [8]:
# Walk through every batch with a progress bar and print its summary;
# the enumerate index `i` is not used inside the loop.
for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
    print(data)

  0%|          | 17/130531 [00:00<22:41, 95.85it/s] 

Batch(x=[40, 44], edge_index=[2, 72], edge_attr=[72, 4], pos=[40, 3], z=[40], chiral_tag=[40], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[40], ptr=[5])
Batch(x=[51, 44], edge_index=[2, 96], edge_attr=[96, 4], pos=[51, 3], z=[51], chiral_tag=[51], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[51], ptr=[5])
Batch(x=[49, 44], edge_index=[2, 90], edge_attr=[90, 4], pos=[49, 3], z=[49], chiral_tag=[49], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[49], ptr=[5])
Batch(x=[53, 44], edge_index=[2, 98], edge_attr=[98, 4], pos=[53, 3], z=[53], chiral_tag=[53], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[53], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4],

  0%|          | 92/130531 [00:00<09:12, 236.09it/s]

Batch(x=[67, 44], edge_index=[2, 138], edge_attr=[138, 4], pos=[67, 3], z=[67], chiral_tag=[67], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[67], ptr=[5])
Batch(x=[57, 44], edge_index=[2, 114], edge_attr=[114, 4], pos=[57, 3], z=[57], chiral_tag=[57], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[57], ptr=[5])
Batch(x=[71, 44], edge_index=[2, 142], edge_attr=[142, 4], pos=[71, 3], z=[71], chiral_tag=[71], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[71], ptr=[5])
Batch(x=[65, 44], edge_index=[2, 130], edge_attr=[130, 4], pos=[65, 3], z=[65], chiral_tag=[65], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[65], ptr=[5])
Batch(x=[44, 44], edge_index=[2, 80], edge_attr=[80, 4], pos=[44, 3], z=[44], chiral_tag=[44], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[44], ptr=[5])
Batch(x=[53, 44], edge_index=[2, 98], edge_attr=[98,

  0%|          | 139/130531 [00:00<06:56, 313.16it/s]

Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[63, 44], edge_index=[2, 118], edge_attr=[118, 4], pos=[63, 3], z=[63], chiral_tag=[63], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[63], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[

  0%|          | 199/130531 [00:00<08:15, 263.20it/s]

Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[60, 3], z=[60], chiral_tag=[60], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[60], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[70, 44], edge_index=[2, 132], edge_attr=[132, 4], pos=[70, 3], z=[70], chiral_tag=[70], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[70], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[

  0%|          | 254/130531 [00:01<08:34, 253.33it/s]

Batch(x=[68, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[61, 44], edge_index=[2, 122], edge_attr=[122, 4], pos=[61, 3], z=[61], chiral_tag=[61], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[61], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[61, 44], edge_index=[2, 122], edge_attr=[122, 4], pos=[61, 3], z=[61], chiral_tag=[61], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[61], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 128], edge_attr=[

  0%|          | 280/130531 [00:01<09:05, 238.83it/s]

Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[60, 3], z=[60], chiral_tag=[60], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[60], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[60, 3], z=[60], chiral_tag=[60], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[60], ptr=[5])
Batch(x=[62, 44], edge_index=[2, 116], edge_attr=[116, 4], pos=[62, 3], z=[62], chiral_tag=[62], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[62], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[

  0%|          | 331/130531 [00:01<09:19, 232.63it/s]

Batch(x=[55, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[55, 3], z=[55], chiral_tag=[55], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[55], ptr=[5])
Batch(x=[63, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[63, 3], z=[63], chiral_tag=[63], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[63], ptr=[5])
Batch(x=[65, 44], edge_index=[2, 138], edge_attr=[138, 4], pos=[65, 3], z=[65], chiral_tag=[65], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[65], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[60, 3], z=[60], chiral_tag=[60], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[60], ptr=[5])
Batch(x=[58, 44], edge_index=[2, 108], edge_attr=[108, 4], pos=[58, 3], z=[58], chiral_tag=[58], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[58], ptr=[5])
Batch(x=[59, 44], edge_index=[2, 110], edge_attr=[

  0%|          | 381/130531 [00:01<09:11, 236.04it/s]

Batch(x=[44, 44], edge_index=[2, 80], edge_attr=[80, 4], pos=[44, 3], z=[44], chiral_tag=[44], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[44], ptr=[5])
Batch(x=[56, 44], edge_index=[2, 104], edge_attr=[104, 4], pos=[56, 3], z=[56], chiral_tag=[56], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[56], ptr=[5])
Batch(x=[59, 44], edge_index=[2, 110], edge_attr=[110, 4], pos=[59, 3], z=[59], chiral_tag=[59], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[59], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[62, 44], edge_index=[2, 116], edge_attr=[11

  0%|          | 429/130531 [00:01<09:15, 234.35it/s]

Batch(x=[56, 44], edge_index=[2, 104], edge_attr=[104, 4], pos=[56, 3], z=[56], chiral_tag=[56], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[56], ptr=[5])
Batch(x=[53, 44], edge_index=[2, 98], edge_attr=[98, 4], pos=[53, 3], z=[53], chiral_tag=[53], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[53], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[112, 4], pos=[60, 3], z=[60], chiral_tag=[60], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[60], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 128], edge_attr=[128, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[60, 44], edge_index=[2, 112], edge_attr=[11

  0%|          | 481/130531 [00:02<09:07, 237.62it/s]

Batch(x=[84, 44], edge_index=[2, 160], edge_attr=[160, 4], pos=[84, 3], z=[84], chiral_tag=[84], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[84], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[76, 44], edge_index=[2, 144], edge_attr=[144, 4], pos=[76, 3], z=[76], chiral_tag=[76], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[76], ptr=[5])
Batch(x=[70, 44], edge_index=[2, 132], edge_attr=[132, 4], pos=[70, 3], z=[70], chiral_tag=[70], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[70], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[

  0%|          | 527/130531 [00:02<11:02, 196.20it/s]

Batch(x=[63, 44], edge_index=[2, 118], edge_attr=[118, 4], pos=[63, 3], z=[63], chiral_tag=[63], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[63], ptr=[5])
Batch(x=[64, 44], edge_index=[2, 120], edge_attr=[120, 4], pos=[64, 3], z=[64], chiral_tag=[64], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[64], ptr=[5])
Batch(x=[66, 44], edge_index=[2, 124], edge_attr=[124, 4], pos=[66, 3], z=[66], chiral_tag=[66], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[66], ptr=[5])
Batch(x=[72, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[72, 3], z=[72], chiral_tag=[72], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[72], ptr=[5])
Batch(x=[72, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[72, 3], z=[72], chiral_tag=[72], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[72], ptr=[5])
Batch(x=[48, 44], edge_index=[2, 88], edge_attr=[8

  0%|          | 531/130531 [00:02<09:49, 220.61it/s]


Batch(x=[74, 44], edge_index=[2, 148], edge_attr=[148, 4], pos=[74, 3], z=[74], chiral_tag=[74], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[74], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[68, 44], edge_index=[2, 136], edge_attr=[136, 4], pos=[68, 3], z=[68], chiral_tag=[68], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[68], ptr=[5])
Batch(x=[71, 44], edge_index=[2, 142], edge_attr=[142, 4], pos=[71, 3], z=[71], chiral_tag=[71], name=[4], boltzmann_weight=[4], degeneracy=[4], mol=[4], pos_mask=[40], batch=[71], ptr=[5])


KeyboardInterrupt: 