# Manual data processing if issues

### Note: used this before in top folder. 
- May have issues with imports.
- Can move back to top folder if using.

In [1]:
from numpy.core.fromnumeric import product
from scipy.sparse import data
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.data import InMemoryDataset, DataLoader # , Data
from torch_geometric.data.data import Data
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm

def process_geometry_file(geometry_file, list = None):
    """ Code mostly lifted from QM9 dataset creation https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html 
        Transforms molecules to their atom features and adjacency lists.
    """
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
    limit = 100

    data_list = list if list else []
    full_path = r'data' + geometry_file
    geometries = Chem.SDMolSupplier(full_path, removeHs=False, sanitize=False)

    # get atom and edge features for each geometry
    for i, mol in enumerate(tqdm(geometries)):

        # temp soln cos of split edge memory issues
        if i == limit:
            break
        
        N = mol.GetNumAtoms()
        # get atom positions as matrix w shape [num_nodes, num_dimensions] = [num_atoms, 3]
        atom_data = geometries.GetItemText(i).split('\n')[4:4 + N] 
        atom_positions = [[float(x) for x in line.split()[:3]] for line in atom_data]
        atom_positions = torch.tensor(atom_positions, dtype=torch.float)
        # all the features
        type_idx = []
        atomic_number = []
        aromatic = []
        sp = []
        sp2 = []
        sp3 = []
        num_hs = []

        # atom/node features
        for atom in mol.GetAtoms():
            type_idx.append(types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridisation = atom.GetHybridization()
            sp.append(1 if hybridisation == HybridizationType.SP else 0)
            sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)
            # !!! should do the features that lucky does: whether bonded, 3d_rbf

        # bond/edge features
        row, col, edge_type = [], [], []
        for bond in mol.GetBonds(): 
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge type for each bond type; *2 because both ways
            edge_type += 2 * [bonds[bond.GetBondType()]]
        # edge_index is graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        # edge_attr is edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) 

        # order edges based on combined ascending order
        perm = (edge_index[0] * N + edge_index[1]).argsort() # TODO
        edge_index = edge_index[:, perm]
        edge_type = edge_type[perm]
        edge_attr = edge_attr[perm]

        row, col = edge_index
        z = torch.tensor(atomic_number, dtype=torch.long)
        hs = (z == 1).to(torch.float) # hydrogens
        num_hs = scatter(hs[row], col, dim_size=N).tolist() # scatter helps with one-hot
        
        x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
        x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        data = Data(x=x, z=z, pos=atom_positions, edge_index=edge_index, edge_attr=edge_attr, idx=i)
        
        data_list.append(data)

    return data_list

In [2]:
# concat train r and test r
reactants = []
reactants = process_geometry_file('/raw/train_reactants.sdf', reactants)
reactants = process_geometry_file('/raw/test_reactants.sdf', reactants)

# concat train ts and test ts
ts = []
ts = process_geometry_file('/raw/train_ts.sdf', ts)
ts = process_geometry_file('/raw/test_ts.sdf', ts) 

# concat train p and test p
products = []
products = process_geometry_file('/raw/train_products.sdf', products)
products = process_geometry_file('/raw/test_products.sdf', products) 

assert len(reactants) == len(ts) == len(products)

print(type(reactants[0]), type(ts[0]), type(products[0]))

  1%|▏         | 100/6739 [00:00<00:20, 322.49it/s]
 12%|█▏        | 100/842 [00:00<00:01, 630.53it/s]
  1%|▏         | 100/6739 [00:00<00:04, 1348.46it/s]
 12%|█▏        | 100/842 [00:00<00:01, 580.31it/s]
  1%|▏         | 100/6739 [00:00<00:04, 1388.79it/s]
 12%|█▏        | 100/842 [00:00<00:00, 773.73it/s]

<class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'>





In [58]:
class ReactionTriple(Data):
    def __init__(self, r = None, ts = None, p = None):
        super(ReactionTriple, self).__init__()
        self.r = r
        self.ts = ts
        self.p = p

    def __inc__(self, key, value):
        if key == 'r':
            return self.r.edge_index.size(0)
        elif key == 'ts':
            return self.ts.edge_index.size(0)
        elif key == 'p':
            return self.p.edge_index.size(0)
        else:
            return super().__inc__(key, value)

class OtherReactionTriple(Data):
    # seeing if this works

    def __init__(self, r, ts, p):
        super(OtherReactionTriple, self).__init__()

        # initial checks
        if r and ts and p:
            assert r.idx == ts.idx == p.idx, \
                "The IDs of each mol don't match. Are you sure your data processing is correct?"
            assert len(r.z) == len(ts.z) == len(p.z), \
                "The mols have different number of atoms."
            self.idx = r.idx
            self.num_atoms = len(r.z)

            # reactant
            self.edge_attr_r = r.edge_attr
            self.edge_index_r = r.edge_index
            self.pos_r = r.pos
            self.x_r = r.x

            # ts
            self.edge_attr_ts = ts.edge_attr
            self.edge_index_ts = ts.edge_index
            self.pos_ts = ts.pos
            self.x_ts = ts.x

            # product
            self.edge_attr_p = p.edge_attr
            self.edge_index_p = p.edge_index
            self.pos_p = p.pos
            self.x_p = p.x
        else:
            NameError("Reactant, TS, or Product not defined for this reaction.")

    def __inc__(self, key, value):
        if key == 'edge_index_r' or key == 'edge_attr_r':
            return self.x_r.size(0)
        if key == 'edge_index_ts' or key == 'edge_attr_ts':
            return self.x_ts.size(0)
        if key == 'edge_index_p' or key == 'edge_attr_p':
            return self.x_p.size(0)
        else:
            return super().__inc__(key, value)
    
    def __cat_dim__(self, key, item):
        # NOTE: automatically figures out .x and .pos
        if key == 'edge_attr_r' or key == 'edge_attr_ts' or key == 'edge_attr_p':
            return 0
        if key == 'edge_index_r' or key == 'edge_index_ts' or key == 'edge_index_p':
            return 1
        else:
            return super().__cat_dim__(key, item)



In [72]:
rxns = []
for rxn_id in range(len(reactants)):
    rxn = OtherReactionTriple(reactants[rxn_id], ts[rxn_id], products[rxn_id])
    rxns.append(rxn)

to_follow = ['edge_index_r', 'edge_index_ts', 'edge_index_p', 'edge_attr_r', 'edge_attr_ts', 'edge_attr_p'
             'pos_r', 'pos_ts', 'pos_p', 'x_r', 'x_ts', 'x_p']

loader = DataLoader(rxns, batch_size = 2, follow_batch = to_follow)
batch = next(iter(loader))

TypeError: __init__() missing 3 required positional arguments: 'r', 'ts', and 'p'

## Data functions

In [51]:
def edge2adj(z, edge_index, sigmoid = True):
    value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim = 1)
    return torch.sigmoid(value) if sigmoid else value

In [92]:
# their model 
# so they take their nodes, edges, edge_attr and actual adj
# adj_pred, z = model(nodes, edges, edge_attr)
# bce, kl = loss(adj_pred, adj_gt)

26

In [153]:
from torch_geometric.utils import to_dense_adj

node_fs = mol_graph.x
edge_index = mol_graph.edge_index
edge_attr = mol_graph.edge_attr
num_nodes = len(mol_graph.z)
latent_dim = 3
max_num_nodes = 21

def sparse_to_dense_adj(num_nodes, edge_index):
    # edge_index is sparse_adj matrix (given in coo format for graph connectivity)
    sparse_adj = torch.cat([edge_index[0].unsqueeze(0), edge_index[1].unsqueeze(0)])
    # the values we put in at each tuple; that's why length of sparse_adj
    ones = torch.ones(sparse_adj.size(1)) 
    # FloatTensor() creates sparse coo tensor in torch format, then to_dense()
    dense_adj = torch.sparse.FloatTensor(sparse_adj, ones, torch.Size([num_nodes, num_nodes])).to_dense() # to_dense adds the zeroes needed
    return dense_adj


adj_egnn = sparse_to_dense_adj(num_nodes, edge_index)
# with edge_attr, we get a [1, num_nodes, num_nodes] for each edge_type
adj_pyg = to_dense_adj(edge_index, edge_attr = edge_attr, max_num_nodes = num_nodes)

# get_dense_graph(): returns self.nodes, self.edges_dense, self.edge_attr_dense, self.adj
# adj = sparse2dense(n_nodes, self.edges); adjust for loops
# compare sparse2dense (egnn) vs to_dense_adj (pyg)

# adj_egnn.shape
# (adj_pyg == adj_egnn).all()

# gcn = GCNConv(num_nodes, latent_dim)
# z = gcn(node_fs, edge_index)

# adj_pred = adj_pred * (1 - torch.eye(num_nodes).to(self.device)) # removes self_loops
# * is hadamard product

In [None]:
# coords always same, maybe node and edge features too? need to pad adj matrix

# dataset dims
elements = "HCNO"
num_elements = len(elements)
max_n_atoms = max([r.GetNumAtoms() for r,ts,p in data])
num_coords = 3
num_bond_fs

# want to pad exist features

def prepare_batch(batch_mols):

    # initialise batch
    batch_size = len(batch_mols)
    atom_fs = torch.zeros((batch_size, max_n_atoms, num_elements + 1), dtype = torch.float32) # num_atoms, max_num_atoms, 
    bond_fs = torch.zeros((batch_size, max_n_atoms, max_n_atoms, num_bond_fs), dtype = torch.float32)
    sizes = torch.zeros(batch_size, dtype = torch.float32)
    coords = torch.zeros((batch_size, max_size, num_coords), dtype = torch.float32)
    
    pass

def pad_sequence(sequences: List[torch.Tensor], max_length: int, padding_value=0) -> torch.Tensor:
    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    out_dims = (len(sequences), max_length) + trailing_dims

    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)  # type: ignore
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        out_tensor[i, :length, ...] = tensor

    return out_tensor

# ts_gen processing

## Testing

In [1]:
from rdkit import Chem
import numpy as np
import torch
from torch_geometric.data import DataLoader
from torch_geometric.data.data import Data
import tqdm

In [20]:
class TSGenData(Data):
    # seeing if this works

    def __init__(self, x = None, pos = None, edge_attr = None, idx = None):
        super(TSGenData, self).__init__(x = x, pos = pos, edge_attr = edge_attr)
        self.idx = idx

    def __inc__(self, key, value):
        if key == 'edge_attr':
            return self.x.size(0)
        else:
            return super().__inc__(key, value)
    
    def __cat_dim__(self, key, item):
        # NOTE: automatically figures out .x and .pos
        if key == 'edge_attr':
            return (0, 1) # since N x N x edge_attr
        else:
            return super().__cat_dim__(key, item)

In [21]:
# constants
MAX_D = 10.
COORD_DIM = 3
ELEM_TYPES = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
NUM_EDGE_ATTR = 3
TEMP_MOLS_LIMIT = 10

def process():

    # reactants
    r_train = Chem.SDMolSupplier('data/raw/train_reactants.sdf', removeHs = False, sanitize = False)
    r_test = Chem.SDMolSupplier('data/raw/test_reactants.sdf', removeHs = False, sanitize = False)
    rs = []
    for mol in r_train:
        rs.append(mol)
    for mol in r_test:
        rs.append(mol)
    
    # transition states
    ts_train = Chem.SDMolSupplier('data/raw/train_ts.sdf', removeHs = False, sanitize = False)
    ts_test = Chem.SDMolSupplier('data/raw/test_ts.sdf', removeHs = False, sanitize = False)
    tss = []
    for mol in ts_train:
        tss.append(mol)
    for mol in ts_test:
        tss.append(mol)
    
    # products
    p_train = Chem.SDMolSupplier('data/raw/train_products.sdf', removeHs = False, sanitize = False)
    p_test = Chem.SDMolSupplier('data/raw/test_products.sdf', removeHs = False, sanitize = False)
    ps = []
    for mol in p_train:
        ps.append(mol)
    for mol in p_test:
        ps.append(mol)
    
    assert len(rs) == len(tss) == len(ps), f"Lengths of reactants ({len(rs)}), transition states \
                                            ({len(tss)}), products ({len(ps)}) don't match."

    geometries = list(zip(rs, tss, ps))
    data_list = process_geometries(geometries)
    return data_list
    # torch.save(self.collate(data_list), self.processed_paths[0])

def process_geometries(geometries):
    """Process all geometries in same manner as ts_gen."""
    
    data_list = []
    
    for rxn_id, rxn in enumerate(geometries):

        if rxn_id == TEMP_MOLS_LIMIT:
            break

        r, ts, p = rxn
        num_atoms = r.GetNumAtoms()

        # dist matrices
        D = (Chem.GetDistanceMatrix(r) + Chem.GetDistanceMatrix(p)) / 2
        D[D > MAX_D] = MAX_D
        D_3D_rbf = np.exp(-((Chem.Get3DDistanceMatrix(r) + Chem.Get3DDistanceMatrix(p)) / 2))  

        # node feats, edge attr init
        type_ids, atomic_ns = [], [] # TODO: init of vec N
        edge_attr = torch.zeros(num_atoms, num_atoms, NUM_EDGE_ATTR)
        
        # ts ground truth coords
        ts_gt_pos = torch.zeros((num_atoms, COORD_DIM))
        ts_conf = ts.GetConformer()
        for i in range(num_atoms):

            # node feats
            atom = r.GetAtomWithIdx(i)
            type_ids.append(ELEM_TYPES[atom.GetSymbol()])
            atomic_ns.append(atom.GetAtomicNum() / 10.)

            # ts coordinates: atom positions as matrix w shape [num_atoms, 3]
            pos = ts_conf.GetAtomPosition(i)
            ts_gt_pos[i] = torch.tensor([pos.x, pos.y, pos.z])
            
            # edge attrs
            for j in range(num_atoms):
                if D[i][j] == 1: # if stays bonded
                    edge_attr[i][j][0] = 1 # bonded?
                    if r.GetBondBetweenAtoms(i, j).GetIsAromatic():
                        edge_attr[i][j][1] = 1 # aromatic?
                edge_attr[i][j][2] = D_3D_rbf[i][j] # 3d rbf
        
        node_feats = torch.tensor([type_ids, atomic_ns], dtype = torch.float).t().contiguous()
        atomic_ns = torch.tensor(atomic_ns, dtype = torch.long)
        # edge_attr = torch.tensor([bonded, aromatic, rbf], dtype = torch.float).t().contiguous()

        data = TSGenData(x = node_feats, pos = ts_gt_pos, edge_attr = edge_attr, idx = rxn_id)
        data_list.append(data) 

    return data_list

In [22]:
data_list = process()

In [7]:
from torch import Tensor
from itertools import product

def collate(data_list):
    keys = data_list[0].keys
    data = data_list[0].__class__()

    for key in keys:
        data[key] = []
    slices = {key: [0] for key in keys}

    for item, key in product(data_list, keys):
        data[key].append(item[key])
        if isinstance(item[key], Tensor) and (item[key].dim() == 1 or item[key].dim() == 2):
            cat_dim = item.__cat_dim__(key, item[key])
            cat_dim = 0 if cat_dim is None else cat_dim
            s = slices[key][-1] + item[key].size(cat_dim)
        elif isinstance(item[key], Tensor) and (item[key].dim() > 2):
            cat_dims = item.__cat_dim__(key, item[key])
            # print(cat_dims)
            s = slices[key][-1]
            for cat_dim in cat_dims:
                s += item[key].size(cat_dim)
        else:
            s = slices[key][-1] + 1
        slices[key].append(s)
    # print(slices)

    if hasattr(data_list[0], '__num_nodes__'):
        data.__num_nodes__ = []
        for item in data_list:
            data.__num_nodes__.append(item.num_nodes)

    for key in keys:
        item = data_list[0][key]
        if isinstance(item, Tensor) and len(data_list) > 1:
            if item.dim() == 1 or item.dim() == 2:
                cat_dim = data.__cat_dim__(key, item)
                cat_dim = 0 if cat_dim is None else cat_dim
                data[key] = torch.cat(data[key], dim=cat_dim)
            elif item.dim() > 2:
                print(item.dim())
                cat_dim = data.__cat_dim__(key, item)
                # size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                # print(len(data[key]))
                data[key] = torch.stack(data[key])
                # data[key] = torch.cat(data[key], dim = 0)
                # data[key] = torch.cat(data[key], dim = 1)
                continue
            else:
                data[key] = torch.stack(data[key])
        elif isinstance(item, Tensor):  # Don't duplicate attributes...
            data[key] = data[key][0]
        elif isinstance(item, int) or isinstance(item, float):
            data[key] = torch.tensor(data[key])

        slices[key] = torch.tensor(slices[key], dtype=torch.long)

    return data, slices

collate(data_list)

NameError: name 'data_list' is not defined

In [9]:
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
a = torch.ones(2, 2, 10)
b = torch.ones(4, 4, 10)
c = torch.ones(6, 6, 10)
pack_padded_sequence(torch.cat([a, b, c]), [2, 4, 6])

RuntimeError: Sizes of tensors must match except in dimension 0. Got 2 and 4 in dimension 1 (The offending index is 1)

In [13]:
loader = DataLoader(data_list, batch_size = 5)

## Masks

In [15]:
# you would need to do this for each batch
import torch
def sequence_mask(sizes, max_size = None, dtype = torch.bool):
    if max_size is None:
        max_size = sizes.max()
    row_vector = torch.arange(0, max_size, 1)
    matrix = torch.unsqueeze(sizes, dim = -1)
    mask = row_vector < matrix

    mask.type(dtype)
    return mask

sizes = torch.tensor([10, 12, 20, 18]) # num_atoms in each graph
max_size = 21
mask = sequence_mask(sizes, max_size)

mask_n = torch.unsqueeze(mask, 2)
mask_v = torch.unsqueeze(mask_n, 1) * torch.unsqueeze(mask_n, 2)
mask_n.shape, mask_v.shape

(torch.Size([4, 21, 1]), torch.Size([4, 21, 21, 1]))

## Final ts_gen code from .py file

In [10]:
from ts_vae.data_processors.ts_gen_processor import TSGenDataset
from torch_geometric.data import DataLoader
import numpy as np
from ts_vae.utils import remove_files
remove_files() 
rxns = TSGenDataset(r'data')

Files removed.
Processing...


  0%|          | 10/7581 [00:00<01:07, 112.66it/s]

Done!





In [11]:
tt_split = 0.8
num_rxns = len(rxns)
num_train = int(np.floor(tt_split * num_rxns))
batch_size = 5
train_loader = DataLoader(rxns[: num_train], batch_size = batch_size)
test_loader = DataLoader(rxns[num_train: ], batch_size = batch_size)

In [14]:
train_loader.dataset[0].edge_attr.size(2)

3

In [6]:
rxns[0]
batch = next(iter(train_loader))
batch

Batch(batch=[105], edge_attr=[21, 21, 15], idx=[5], pos=[105, 3], ptr=[6], x=[105, 5])

In [1]:
from experiments.building_on_mit.meta_eval.meta_eval import ablation_experiment
# from ts_vae.utils import remove_files
# remove_files()
# have to use batch_size = 1 right now
train_log, test_log = ablation_experiment(0.8, 1, 2, 2)

Processing...


  0%|          | 10/7581 [00:00<00:42, 178.77it/s]

torch.Size([15, 15, 3])

torch.Size([13, 13, 3])

torch.Size([10, 10, 3])

torch.Size([9, 9, 3])

torch.Size([11, 11, 3])

torch.Size([14, 14, 3])

torch.Size([15, 15, 3])

torch.Size([15, 15, 3])

torch.Size([15, 15, 3])

torch.Size([14, 14, 3])






TypeError: size() received an invalid combination of arguments - got (tuple), but expected one of:
 * (int dim)
      didn't match because some of the arguments have invalid types: (!tuple!)
 * ()
      didn't match because some of the arguments have invalid types: (!tuple!)
 * (name dim)
      didn't match because some of the arguments have invalid types: (!tuple!)


# Redoing with GraphDataLoader

- GraphDataLoader: takes collate_fn given by GraphCollater()
- GraphCollater -> Collater for ABC
- GraphBatch
<br/><br/>
- CustomDataLoader, CustomBatch, CustomCollater
- Then create my own collate() and Batch.from_data_list() funcs
- CustomDataLoader is super simple, the main logic would be in CustomCollater which defines the collate() func for the DataLoader
<br/><br/>
- All I need to do is create a DataLoader (which I have), then overwrite the collate() and 



## Initial

In [14]:
data_list[0].__dict__.keys()

dict_keys(['x', 'edge_index', 'edge_attr', 'y', 'pos', 'normal', 'face', 'idx'])

In [10]:
from torch_geometric.data import Data, Batch
from collections.abc import Mapping, Sequence


class TSGenBatch(TSGenData): # Data

    def __init__(self, batch = None, ptr = None, **kwargs):
        super(Batch, self).__init__(**kwargs)

        for key, item in kwargs.items():
            if key == 'num_nodes':
                self.__num_nodes__ = item
            else:
                self[key] = item
        
        self.batch = batch
        self.ptr = ptr
        self.__data_class__ = TSGenData # Data
        self.__slices__ = None
        self.__cumsum__ = None
        self.__cat_dims__ = None
        self.__num_nodes_list__ = None
        self.__num_graphs__ = None
    
    @classmethod
    def from_data_list(cls, data_list, follow_batch = [], exclude_keys = []):
        # construct batch from TSGenData objects
        
        # get relevant graph keys
        keys = list(set(data_list[0].keys) - set(exclude_keys))
        assert 'batch' not in keys and 'ptr' not in keys

        batch = cls()
        for key in data_list[0].__dict__.keys():
            # no batch for those intrinsic class fs
            if key[:2] != '__' and key[-2:] != '__':
                batch[key] = None
            
        batch.__num_graphs__ = len(data_list)
        batch.__data_class__ = data_list[0].__class__
        # init all keys for the batch
        for key in keys + ['batch']:
            batch[key] = []
        batch['ptr'] = [0] # pointer to this batch

        device = None
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # increase values by cumsum value
                cum = cumsum[key][-1]
                



class TSGenCollater(object):

    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
    
    def collate(self, batch):
        # dgl: collate(self, items): items is list of data points or tuples; elems in list same length
        # pyg: collate(self, batch)
        elem = batch[0]
        if isinstance(elem, TSGenData):
            return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)
        if isinstance(elem, Data):
            return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self.collate(samples) for ssamples in zip(*batch)]

        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))

    def __call__(self, batch):
        return self.collate(batch)


class TSGenDataLoader(torch.utils.data.DataLoader):

    def __init__(self, dataset, batch_size = 1, shuffle = False, \
        follow_batch = [], exclude_keys = [], **kwargs):

        if "collate_fn" in kwargs:
            del kwargs["collate_fn"]
        
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super(TSGenDataLoader, self).__init__(dataset, batch_size, shuffle, \
            collate_fn = TSGenCollater(follow_batch, exclude_keys), **kwargs)




## Collate function

In [None]:
import torch

# specific collate_fn in DataLoader

def collate_fn(batch):

    batch = {key: batch_stack([graph[key] for graph in batch]) for key in batch[0].keys()}
    batch = {key: drop_z}
