In [1]:
import h5py
import numpy as np

In [2]:
# Open the train / test / validation HDF5 files read-only (f1, f2, f3 in that order).
f1, f2, f3 = (
    h5py.File(path, 'r')
    for path in (
        'postera_protease2_pos_neg_train.hdf5',
        'postera_protease2_pos_neg_test.hdf5',
        'postera_protease2_pos_neg_val.hdf5',
    )
)
# Top-level group names of the training file.
ligand_names = list(f1.keys())

In [3]:
# How many top-level groups the training file contains.
n_train_groups = len(ligand_names)
print(n_train_groups)

19533


In [4]:
# Peek at the first 31 group names to see the naming scheme.
first_names = ligand_names[:31]
print(first_names)

['104679481_protease2_1', '104679481_protease2_10', '104679481_protease2_2', '104679481_protease2_3', '104679481_protease2_4', '104679481_protease2_5', '104679481_protease2_6', '104679481_protease2_7', '104679481_protease2_8', '104679481_protease2_9', '104680060_protease2_1', '104680060_protease2_10', '104680060_protease2_2', '104680060_protease2_3', '104680060_protease2_4', '104680060_protease2_5', '104680060_protease2_6', '104680060_protease2_7', '104680060_protease2_8', '104680060_protease2_9', '1161709_protease2_1', '1161709_protease2_10', '1161709_protease2_2', '1161709_protease2_3', '1161709_protease2_4', '1161709_protease2_5', '1161709_protease2_6', '1161709_protease2_7', '1161709_protease2_8', '1161709_protease2_9', '11748279_protease2_1']


In [6]:
# Inspect only the first training group: its label and the coordinate / feature shapes.
for name in ligand_names[:1]:
    group = f1[name]
    ligand = group['ligand']
    label = group.attrs['label']

    # Columns 0-2 are split off as coordinates, the remaining columns as features.
    atom_coordinates, atom_features = ligand[:, :3], ligand[:, 3:]

    print(f'{name}: bind={label}({type(label)}) coordinates={atom_coordinates.shape} features={atom_features.shape}')

104679481_protease2_1: bind=0(<class 'numpy.int64'>) coordinates=(100, 3) features=(100, 19)


In [13]:
# Summary statistics for the training split (f1): label balance, mean atom
# position, distinct feature rows, and per-feature value counts.
atoms = []
labels = {'pos': 0, 'neg': 0}
mean_position = []
for name in f1.keys():
    # Keep only groups whose name ends in the character '1'.
    # NOTE(review): presumably this selects the first pose of each ligand;
    # it would also match suffixes such as '_11' -- confirm none exist.
    if name[-1] != '1':
        continue
    group = f1[name]
    ligand = group['ligand'][()]  # read the dataset once, then slice in memory

    # A label greater than zero counts as positive, anything else as negative.
    labels['pos' if group.attrs['label'] > 0 else 'neg'] += 1

    mean_position.append(ligand[:, :3].mean(axis=0))  # per-ligand mean of columns 0-2
    atoms.append(ligand[:, 3:])                       # feature columns

print(labels)

# np.vstack replaces np.row_stack, which is a deprecated alias since NumPy 2.0.
mean_position = np.vstack(mean_position)
assert mean_position.shape[-1] == 3, mean_position.shape
print(f'mean position of atoms: {mean_position.mean(axis=0)}')

atoms = np.vstack(atoms)
print(f'total atoms in dataset: {atoms.shape}')
unique_atoms, counts = np.unique(atoms, return_counts=True, axis=0)

# Indices of the distinct rows ordered from least to most frequent.
asc = np.argsort(counts)
print(f'number of distinct atoms: {counts.shape}')
print(f'greatest_freq: {counts.max()}, next_greatest_freq: {counts[asc[-2]]}, mean: {counts.mean()}, least: {counts.min()}')
# Three most frequent rows, each prefixed with its count.
print(np.hstack([counts[asc[-3:]].reshape([-1, 1]), unique_atoms[asc[-3:]]]))

# Value distribution of every feature column on its own.
for i, feature in enumerate(atoms.T):
    print(feature.shape)
    unique_values, counts = np.unique(feature, return_counts=True)
    print(f'feature_{i} has {unique_values.size} unique values {unique_values} frequency {counts}')

{'pos': 10470, 'neg': 9063}
mean position of atoms: [ 2.6895690e-09 -1.0401038e-10  1.3006661e-09]
total atoms in dataset: (1953300, 19)
number of distinct atoms: (4552,)
greatest_freq: 1409342, next_greatest_freq: 12730, mean: 429.10808435852374, least: 1
[[9.00000000e+03 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 9.99999978e-03 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [1.27300000e+04 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 1.00000005e-03 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [1.40934200e+06 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+

feature_14 has 2 unique values [0. 1.] frequency [1688349  264951]
(1953300,)
feature_15 has 2 unique values [0. 1.] frequency [1936710   16590]
(1953300,)
feature_16 has 2 unique values [0. 1.] frequency [1840361  112939]
(1953300,)
feature_17 has 2 unique values [0. 1.] frequency [1925607   27693]
(1953300,)
feature_18 has 2 unique values [0. 1.] frequency [1586680  366620]


In [9]:
# Summary statistics for the test split (f2): label balance, mean atom
# position, distinct feature rows, and per-feature value counts.
atoms = []
labels = {'pos': 0, 'neg': 0}
mean_position = []
for name in f2.keys():
    # Keep only groups whose name ends in the character '1'.
    # NOTE(review): presumably this selects the first pose of each ligand;
    # it would also match suffixes such as '_11' -- confirm none exist.
    if name[-1] != '1':
        continue
    group = f2[name]
    ligand = group['ligand'][()]  # read the dataset once, then slice in memory

    # A label greater than zero counts as positive, anything else as negative.
    labels['pos' if group.attrs['label'] > 0 else 'neg'] += 1

    mean_position.append(ligand[:, :3].mean(axis=0))  # per-ligand mean of columns 0-2
    atoms.append(ligand[:, 3:])                       # feature columns

print(labels)

# np.vstack replaces np.row_stack, which is a deprecated alias since NumPy 2.0.
mean_position = np.vstack(mean_position)
assert mean_position.shape[-1] == 3, mean_position.shape
print(f'mean position of atoms: {mean_position.mean(axis=0)}')

atoms = np.vstack(atoms)
print(f'total atoms in dataset: {atoms.shape}')
unique_atoms, counts = np.unique(atoms, return_counts=True, axis=0)

# Indices of the distinct rows ordered from least to most frequent.
asc = np.argsort(counts)
print(f'number of distinct atoms: {counts.shape}')
print(f'greatest_freq: {counts.max()}, next_greatest_freq: {counts[asc[-2]]}, mean: {counts.mean()}, least: {counts.min()}')
# Three most frequent rows, each prefixed with its count.
print(np.hstack([counts[asc[-3:]].reshape([-1, 1]), unique_atoms[asc[-3:]]]))

# Value distribution of every feature column on its own.
for i, feature in enumerate(atoms.T):
    print(feature.shape)
    unique_values, counts = np.unique(feature, return_counts=True)
    print(f'feature_{i} has {unique_values.size} unique values {unique_values} frequency {counts}')

{'pos': 68, 'neg': 60}
mean position of atoms: [ 1.3820829e-08  8.2119360e-09 -4.2170281e-08]
total atoms in dataset: (12800, 19)
number of distinct atoms: (1414,)
greatest_freq: 9320, next_greatest_freq: 87, mean: 9.052333804809052, least: 1
[[4.60000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 1.00000005e-03 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [8.70000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 0.00000000e+00 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [9.32000000e+03 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.0000000

In [10]:
# Summary statistics for the validation split (f3): label balance, mean atom
# position, distinct feature rows, and per-feature value counts.
atoms = []
labels = {'pos': 0, 'neg': 0}
mean_position = []
for name in f3.keys():
    # Keep only groups whose name ends in the character '1'.
    # NOTE(review): presumably this selects the first pose of each ligand;
    # it would also match suffixes such as '_11' -- confirm none exist.
    if name[-1] != '1':
        continue
    group = f3[name]
    ligand = group['ligand'][()]  # read the dataset once, then slice in memory

    # A label greater than zero counts as positive, anything else as negative.
    labels['pos' if group.attrs['label'] > 0 else 'neg'] += 1

    mean_position.append(ligand[:, :3].mean(axis=0))  # per-ligand mean of columns 0-2
    atoms.append(ligand[:, 3:])                       # feature columns

print(labels)

# np.vstack replaces np.row_stack, which is a deprecated alias since NumPy 2.0.
mean_position = np.vstack(mean_position)
assert mean_position.shape[-1] == 3, mean_position.shape
print(f'mean position of atoms: {mean_position.mean(axis=0)}')

atoms = np.vstack(atoms)
print(f'total atoms in dataset: {atoms.shape}')
unique_atoms, counts = np.unique(atoms, return_counts=True, axis=0)

# Indices of the distinct rows ordered from least to most frequent.
asc = np.argsort(counts)
print(f'number of distinct atoms: {counts.shape}')
print(f'greatest_freq: {counts.max()}, next_greatest_freq: {counts[asc[-2]]}, mean: {counts.mean()}, least: {counts.min()}')
# Three most frequent rows, each prefixed with its count.
print(np.hstack([counts[asc[-3:]].reshape([-1, 1]), unique_atoms[asc[-3:]]]))

# Value distribution of every feature column on its own.
for i, feature in enumerate(atoms.T):
    print(feature.shape)
    unique_values, counts = np.unique(feature, return_counts=True)
    print(f'feature_{i} has {unique_values.size} unique values {unique_values} frequency {counts}')

{'pos': 62, 'neg': 51}
mean position of atoms: [ 1.5127979e-08 -1.1185102e-08  5.2325517e-09]
total atoms in dataset: (11300, 19)
number of distinct atoms: (1168,)
greatest_freq: 8223, next_greatest_freq: 67, mean: 9.674657534246576, least: 1
[[5.60000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 1.00000005e-03 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [6.70000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 2.00000000e+00 2.00000000e+00
  0.00000000e+00 0.00000000e+00 1.00000000e+00 1.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]
 [8.22300000e+03 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.0000000

In [6]:
# List the rows of the first training ligand whose feature vector is entirely zero.
name = ligand_names[0]
# Read from f1 (the training file that ligand_names was built from); the
# original referenced `f`, which is never defined in this notebook.
ligand = f1[name]['ligand']
label = f1[name].attrs['label']

atom_coordinates = ligand[:, :3]
atom_features = ligand[:, 3:]

for i, features in enumerate(atom_features):
    # Test every element rather than the sum, so that non-zero values which
    # happen to cancel out are not reported as an empty row.
    if not features.any():
        print(f'{i}:\t coords={atom_coordinates[i]} features={features}')

37:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
38:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
39:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
40:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
41:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
42:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
43:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
44:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
45:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
46:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
47:	 coords=[0. 0. 0.] features=[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

After analyzing the data, the atom type is one-hot encoded in features 0-8, so there are at most 9 different atom types in this dataset. However, there may not be enough data to learn good embeddings for every atom type: atoms 4 and 8 rarely appear, and atoms 0 and 6 never do.

It may be necessary to make separate embeddings based on the other 10 descriptors (features 9-18). However, feature 12 is continuous (it has an infinite range of values), so it cannot be used to train embeddings.

After having embeddings for both the atom type and their descriptors, they can be concatenated. Additionally, feature 12 can be concatenated with the embeddings too.

# Example Dataloader

In [14]:
import torch
from torch import tensor
from torch.utils.data import Dataset, DataLoader

class LigandsDataset(Dataset):
    """Map-style dataset over the top-level groups of one HDF5 file.

    Each item is a ``(coordinates, features, label)`` tuple of tensors:
    columns 0-2 and columns 3 onward of the group's ``ligand`` dataset, plus
    the group's ``label`` attribute.
    """

    def __init__(self, hdf5_path):
        """Open ``hdf5_path`` read-only and cache the list of group names."""
        # Zero-argument super(): the original passed `AtomsDataset`, a class
        # that does not exist here, which raises NameError on a fresh kernel.
        super().__init__()
        self.File = h5py.File(hdf5_path, 'r')
        self.ligand_names = list(self.File.keys())

    def __len__(self):
        """Return the number of groups in the file."""
        return len(self.ligand_names)

    def __getitem__(self, x):
        """Return ``(coordinates, features, label)`` for the group at index ``x``."""
        group = self.File[self.ligand_names[x]]
        ligand = group['ligand'][()]  # load the dataset once, then slice in memory
        label = tensor(group.attrs['label'])

        coordinates = tensor(ligand[:, :3])
        features = tensor(ligand[:, 3:])

        return coordinates, features, label

In [15]:
# Example dataset backed by the training split.
dataset = LigandsDataset(hdf5_path='postera_protease2_pos_neg_train.hdf5')

In [18]:
# Fetch the first item and show the shape of each returned tensor.
sample = dataset[0]
coords, features, label = sample

print(*(part.shape for part in sample))

torch.Size([100, 3]) torch.Size([100, 19]) torch.Size([])


# Embedding Dicts and IDs

Creates 2 embedding dictionary matrices (saved to atoms.pt and descriptors.pt): one with a row for each unique atom type and one with a row for each unique set of descriptors.

Then creates a new HDF5 file with all ligand graphs from the postera dataset, storing the embedding ID of each atom and of its descriptors (as their own datasets).

###  Dictionaries

In [1]:
import h5py
import numpy as np

In [11]:
# Open the train / test / validation HDF5 files read-only (f1, f2, f3 in that order).
f1, f2, f3 = (
    h5py.File(path, 'r')
    for path in (
        'postera_protease2_pos_neg_train.hdf5',
        'postera_protease2_pos_neg_test.hdf5',
        'postera_protease2_pos_neg_val.hdf5',
    )
)

Just in case, check the number of unique atoms in the datasets.

In [19]:
# Gather the one-hot atom rows and the descriptor rows of every selected
# ligand across the train, test and validation files (in that order).
atoms = []
atom_descriptors = []
for split_file in (f1, f2, f3):
    for name in split_file.keys():
        # Keep only groups whose name ends in the character '1'.
        if name[-1] != '1':
            continue
        ligand = split_file[name]['ligand'][:, 3:]  # drop the three coordinate columns
        atoms.append(ligand[:, :9])                 # features 0-8: atom one-hot
        # skip feature 12 because it has an infinite domain
        atom_descriptors.append(np.hstack((ligand[:, 9:12], ligand[:, 13:])))

atoms = np.vstack(atoms)
print(f'total atoms in dataset: {atoms.shape}')
unique_atoms, counts = np.unique(atoms, return_counts=True, axis=0)

# Indices of the distinct rows ordered from least to most frequent.
asc = np.argsort(counts)
print(f'number of distinct atoms: {counts.shape}')
print(f'greatest_freq: {counts.max()}, next_greatest_freq: {counts[asc[-2]]}, mean: {counts.mean()}, least: {counts.min()}')

# The most frequent row is left out of the dictionary. Check that it really is
# the all-zero row, so that no real atom type is silently discarded.
assert not unique_atoms[asc[-1]].any(), unique_atoms[asc[-1]]
# Remaining rows ordered from most to least frequent; the row index is the atom ID.
atoms_sorted = np.flipud(unique_atoms[asc[:-1]])
print(atoms_sorted)

total atoms in dataset: (219600, 9)
number of distinct atoms: (8,)
greatest_freq: 158581, next_greatest_freq: 45384, mean: 27450.0, least: 1
[[0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]]


Also, check the number of unique descriptors in the dataset.

In [20]:
# Stack the per-ligand descriptor blocks into a single 2-D matrix.
atom_descriptors = np.vstack(atom_descriptors)
print(f'total descriptors in dataset: {atom_descriptors.shape}')

# Distinct descriptor rows together with how often each one occurs.
unique_descriptors, counts = np.unique(atom_descriptors, axis=0, return_counts=True)
asc = np.argsort(counts)  # least frequent first

print(f'number of distinct descriptors: {counts.shape}')
print(f'greatest_freq: {counts.max()}, next_greatest_freq: {counts[asc[-2]]}, mean: {counts.mean()}, least: {counts.min()}')

# Leave out the single most frequent row, then order the rest from most to
# least frequent.
descriptors_sorted = unique_descriptors[asc[:-1]][::-1]
print(descriptors_sorted)

total descriptors in dataset: (219600, 9)
number of distinct descriptors: (139,)
greatest_freq: 158581, next_greatest_freq: 15284, mean: 1579.8561151079136, least: 1
[[2. 2. 0. ... 0. 0. 1.]
 [2. 3. 0. ... 0. 0. 1.]
 [2. 3. 1. ... 0. 0. 1.]
 ...
 [2. 3. 2. ... 1. 0. 1.]
 [2. 3. 1. ... 1. 1. 0.]
 [6. 6. 0. ... 0. 0. 0.]]


In [None]:
import torch
from torch import tensor

Save both matrices (which are embedding dictionaries) for later use.

In [None]:
# Persist both dictionaries as tensors. Each matrix is copied first so the
# tensor is built from a fresh array instead of the flipped view.
for matrix, out_path in ((atoms_sorted, 'atoms.pt'), (descriptors_sorted, 'descriptors.pt')):
    torch.save(tensor(matrix.copy()), out_path)

Write a helper function that returns the row index of the given vector in an embedding dictionary. An all-zero vector maps to -1; any other vector is expected to appear exactly once in the dictionary.

In [23]:
def get_id(embed, vec):
    """Return the row index of ``vec`` in the embedding dictionary ``embed``.

    Parameters
    ----------
    embed : numpy.ndarray
        2-D array whose rows are the dictionary entries.
    vec : numpy.ndarray
        1-D array with as many elements as ``embed`` has columns.

    Returns
    -------
    int
        ``-1`` when ``vec`` is all zeros, otherwise the index of the single
        row of ``embed`` that equals ``vec`` element-wise.

    Raises
    ------
    AssertionError
        If the widths differ, or if ``vec`` does not match exactly one row.
    """
    if np.all(vec == 0):
        return -1
    assert vec.shape[-1] == embed.shape[-1], (vec.shape, embed.shape)
    # A row matches only when every element is equal.
    row_matches = np.all(np.equal(embed, vec), axis=1)
    indices = np.flatnonzero(row_matches)
    # The message reports `vec`; the original referenced the undefined name
    # `onehot`, which turned a failed lookup into a NameError.
    assert indices.size == 1, f'vec {vec}\nindices {indices}\n'
    return indices.item()

Create the HDF5 file to write in

In [None]:
# Output file for the ID-encoded ligand graphs; mode 'w' creates or truncates it.
outfile = h5py.File('postera_protease2_id_graph.hdf5', mode='w')

Translate the ligand graphs from the training, testing, and validation HDF5 files into graphs of embedding IDs.

In [29]:
# Re-encode every selected training ligand (f1) as coordinates plus one atom ID
# and one descriptor ID per row, and write it to `outfile` under the same name.
for i, name in enumerate(f1.keys()):
    # Keep only groups whose name ends in the character '1'.
    if name[-1] != '1':
        continue
    grp = outfile.create_group(name)

    old_ligand = f1[name]['ligand'][()]  # read the dataset once, then slice in memory
    grp.attrs.create('label', f1[name].attrs['label'])

    n_atoms = old_ligand.shape[0]
    # Local names are used on purpose so the notebook-level `atoms` and
    # `atom_descriptors` built by the dictionary cells are not overwritten.
    onehot_rows = old_ligand[:, 3:12]  # features 0-8
    # Features 9-11 and 13-18; column 15 (feature 12) is skipped.
    descriptor_rows = np.hstack((old_ligand[:, 12:15], old_ligand[:, 16:]))

    # dtype 'f4' matches what create_dataset uses when only a shape is given.
    grp.create_dataset('coordinates', data=old_ligand[:, :3], dtype='f4')

    print('.', end='')  # initialized ligand group and created atom coordinates dataset

    atom_ids = np.array(
        [get_id(atoms_sorted, vec) for vec in onehot_rows], dtype='i8'
    ).reshape(-1, 1)
    assert atom_ids.shape == (n_atoms, 1), atom_ids.shape
    grp.create_dataset('atoms', data=atom_ids)

    print('.', end='')  # created ligand's atom dataset with atom ids

    descriptor_ids = np.array(
        [get_id(descriptors_sorted, vec) for vec in descriptor_rows], dtype='i8'
    ).reshape(-1, 1)
    assert descriptor_ids.shape == (n_atoms, 1), descriptor_ids.shape
    grp.create_dataset('descriptors', data=descriptor_ids)

    print('.', end='')  # created ligand's descriptors dataset with descriptor ids

    print(i)

...0
...10
...20
...30
...40
...50
...60
...70
...80
...90
...100
...110
...120
...130
...140
...150
...160
...170
...180
...190
...200
...210
...220
...230
...240
...250
...260
...270
...280
...290
...300
...310
...320
...330
...340
...350
...360
...370
...380
...390
...400
...410
...420
...430
...440
...450
...460
...470
...480
...490
...500
...510
...520
...530
...540
...550
...560
...570
...580
...590
...600
...610
...620
...630
...640
...650
...660
...670
...680
...690
...700
...710
...720
...730
...740
...750
...760
...770
...780
...790
...800
...810
...820
...830
...840
...850
...860
...870
...880
...890
...900
...910
...920
...930
...940
...950
...960
...970
...980
...990
...1000
...1010
...1020
...1030
...1040
...1050
...1060
...1070
...1080
...1090
...1100
...1110
...1120
...1130
...1140
...1150
...1160
...1170
...1180
...1190
...1200
...1210
...1220
...1230
...1240
...1250
...1260
...1270
...1280
...1290
...1300
...1310
...1320
...1330
...1340
...1350
...1360
...1370
...1380

...10492
...10502
...10512
...10522
...10532
...10542
...10552
...10562
...10572
...10582
...10592
...10602
...10612
...10622
...10632
...10642
...10652
...10662
...10672
...10682
...10692
...10702
...10712
...10722
...10732
...10742
...10752
...10762
...10772
...10782
...10792
...10802
...10812
...10822
...10832
...10842
...10852
...10862
...10872
...10882
...10892
...10902
...10912
...10922
...10932
...10942
...10952
...10962
...10972
...10982
...10992
...11002
...11012
...11022
...11032
...11042
...11052
...11062
...11072
...11082
...11092
...11102
...11112
...11122
...11132
...11142
...11152
...11162
...11172
...11182
...11192
...11202
...11212
...11222
...11232
...11242
...11252
...11262
...11272
...11282
...11292
...11302
...11312
...11322
...11332
...11342
...11352
...11362
...11372
...11382
...11392
...11402
...11412
...11422
...11432
...11442
...11452
...11462
...11472
...11482
...11492
...11502
...11512
...11522
...11532
...11542
...11552
...11562
...11572
...11582
...11592
.

In [31]:
# Re-encode every selected test ligand (f2) as coordinates plus one atom ID
# and one descriptor ID per row, and write it to `outfile` under the same name.
for i, name in enumerate(f2.keys()):
    # Keep only groups whose name ends in the character '1'.
    if name[-1] != '1':
        continue
    grp = outfile.create_group(name)

    old_ligand = f2[name]['ligand'][()]  # read the dataset once, then slice in memory
    grp.attrs.create('label', f2[name].attrs['label'])

    n_atoms = old_ligand.shape[0]
    # Local names are used on purpose so the notebook-level `atoms` and
    # `atom_descriptors` built by the dictionary cells are not overwritten.
    onehot_rows = old_ligand[:, 3:12]  # features 0-8
    # Features 9-11 and 13-18; column 15 (feature 12) is skipped.
    descriptor_rows = np.hstack((old_ligand[:, 12:15], old_ligand[:, 16:]))

    # dtype 'f4' matches what create_dataset uses when only a shape is given.
    grp.create_dataset('coordinates', data=old_ligand[:, :3], dtype='f4')

    print('.', end='')  # initialized ligand group and created atom coordinates dataset

    atom_ids = np.array(
        [get_id(atoms_sorted, vec) for vec in onehot_rows], dtype='i8'
    ).reshape(-1, 1)
    assert atom_ids.shape == (n_atoms, 1), atom_ids.shape
    grp.create_dataset('atoms', data=atom_ids)

    print('.', end='')  # created ligand's atom dataset with atom ids

    descriptor_ids = np.array(
        [get_id(descriptors_sorted, vec) for vec in descriptor_rows], dtype='i8'
    ).reshape(-1, 1)
    assert descriptor_ids.shape == (n_atoms, 1), descriptor_ids.shape
    grp.create_dataset('descriptors', data=descriptor_ids)

    print('.', end='')  # created ligand's descriptors dataset with descriptor ids

    print(i)

...0
...10
...20
...30
...40
...50
...60
...70
...80
...90
...100
...110
...120
...130
...140
...150
...160
...170
...180
...190
...200
...210
...220
...230
...240
...250
...260
...270
...280
...290
...300
...310
...320
...330
...340
...350
...360
...370
...380
...390
...400
...410
...420
...430
...440
...450
...460
...470
...480
...490
...500
...510
...520
...530
...540
...550
...560
...570
...580
...590
...600
...610
...620
...630
...640
...650
...660
...670
...680
...690
...700
...710
...720
...730
...740
...750
...760
...770
...780
...790
...800
...810
...820
...830
...840
...850
...860
...870
...880
...890
...900
...910
...920
...930
...940
...950
...960
...970
...980
...990
...1000
...1010
...1020
...1030
...1040
...1050
...1060
...1070
...1080
...1090
...1100
...1110
...1120
...1130
...1140
...1150
...1160
...1170
...1180
...1190
...1200
...1210
...1220
...1230
...1240
...1250
...1260
...1270


In [32]:
# Re-encode every selected validation ligand (f3) as coordinates plus one atom
# ID and one descriptor ID per row, and write it to `outfile` under the same name.
for i, name in enumerate(f3.keys()):
    # Keep only groups whose name ends in the character '1'.
    if name[-1] != '1':
        continue
    grp = outfile.create_group(name)

    old_ligand = f3[name]['ligand'][()]  # read the dataset once, then slice in memory
    grp.attrs.create('label', f3[name].attrs['label'])

    n_atoms = old_ligand.shape[0]
    # Local names are used on purpose so the notebook-level `atoms` and
    # `atom_descriptors` built by the dictionary cells are not overwritten.
    onehot_rows = old_ligand[:, 3:12]  # features 0-8
    # Features 9-11 and 13-18; column 15 (feature 12) is skipped.
    descriptor_rows = np.hstack((old_ligand[:, 12:15], old_ligand[:, 16:]))

    # dtype 'f4' matches what create_dataset uses when only a shape is given.
    grp.create_dataset('coordinates', data=old_ligand[:, :3], dtype='f4')

    print('.', end='')  # initialized ligand group and created atom coordinates dataset

    atom_ids = np.array(
        [get_id(atoms_sorted, vec) for vec in onehot_rows], dtype='i8'
    ).reshape(-1, 1)
    assert atom_ids.shape == (n_atoms, 1), atom_ids.shape
    grp.create_dataset('atoms', data=atom_ids)

    print('.', end='')  # created ligand's atom dataset with atom ids

    descriptor_ids = np.array(
        [get_id(descriptors_sorted, vec) for vec in descriptor_rows], dtype='i8'
    ).reshape(-1, 1)
    assert descriptor_ids.shape == (n_atoms, 1), descriptor_ids.shape
    grp.create_dataset('descriptors', data=descriptor_ids)

    print('.', end='')  # created ligand's descriptors dataset with descriptor ids

    print(i)

...0
...10
...20
...30
...40
...50
...60
...70
...80
...90
...100
...110
...120
...130
...140
...150
...160
...170
...180
...190
...200
...210
...220
...230
...240
...250
...260
...270
...280
...290
...300
...310
...320
...330
...340
...350
...360
...370
...380
...390
...400
...410
...420
...430
...440
...450
...460
...470
...480
...490
...500
...510
...520
...530
...540
...550
...560
...570
...580
...590
...600
...610
...620
...630
...640
...650
...660
...670
...680
...690
...700
...710
...720
...730
...740
...750
...760
...770
...780
...790
...800
...810
...820
...830
...840
...850
...860
...870
...880
...890
...900
...910
...920
...930
...940
...950
...960
...970
...980
...990
...1000
...1010
...1020
...1030
...1040
...1050
...1060
...1070
...1080
...1090
...1100
...1110
...1120


Last sanity checks

In [33]:
# Names of all groups written to the output file.
ligand_names = [*outfile.keys()]
print(ligand_names)

['104679481_protease2_1', '104680060_protease2_1', '1161709_protease2_1', '11748279_protease2_1', '11798390_protease2_1', '11808502_protease2_1', '12604832_protease2_1', '1321130_protease2_1', '13987632_protease2_1', '16016872_protease2_1', '16135771_protease2_1', '16292231_protease2_1', '1679888_protease2_1', '17179139_protease2_1', '171822134_protease2_1', '171835375_protease2_1', '171966708_protease2_1', '1812803_protease2_1', '18857566_protease2_1', '24322185_protease2_1', '24346622_protease2_1', '24424670_protease2_1', '24425568_protease2_1', '24437634_protease2_1', '25392386_protease2_1', '25545621_protease2_1', '26174112_protease2_1', '2943672_protease2_1', '300146207_protease2_1', '30536096_protease2_1', '3172642_protease2_1', '32494230_protease2_1', '32782830_protease2_1', '35936444_protease2_1', '4204459_protease2_1', '43679490_protease2_1', '43680058_protease2_1', '43756189_protease2_1', '43838718_protease2_1', '43872339_protease2_1', '43926720_protease2_1', '43966586_protea

In [30]:
# Read back the first group of the output file and show what was stored for it.
for name in outfile.keys():
    print(name)
    grp = outfile[name]
    # Load each dataset fully into memory.
    coordinates, atoms, descriptors = (
        grp[key][()] for key in ('coordinates', 'atoms', 'descriptors')
    )
    label = grp.attrs['label']
    print(coordinates.shape)
    print(atoms.shape, atoms)
    print(descriptors.shape, descriptors)
    print(label.shape, label)
    break  # one group is enough for a sanity check

104679481_protease2_1
(100, 3)
(100, 1) [[ 0]
 [ 0]
 [ 0]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 1]
 [ 0]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 2]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]]
(100, 1) [[23]
 [11]
 [18]
 [ 7]
 [ 2]
 [ 1]
 [ 2]
 [ 0]
 [ 0]
 [ 0]
 [17]
 [ 4]
 [ 1]
 [ 1]
 [ 0]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 2]
 [ 7]
 [18]
 [17]
 [ 6]
 [11]
 [11]
 [11]
 [ 6]
 [26]
 [10]
 [ 1]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [ 0]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [-1]
 [

In [35]:
# Close the output file first, then the three input files.
for handle in (outfile, f1, f2, f3):
    handle.close()

### Sequence Generation

Now to write a dataloader that generates sequences from our new embedding dataset, which can be used to learn graph embeddings.