# Collate some site graphs into features

In [1]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteFeaturizer
from metalsitenn.utils import visualize_chain_3d, visualize_protein_data_3d
import numpy as np
import torch

from metalsitenn.placer_modules.cifutils import CIFParser

In [2]:
parser = CIFParser()

In [3]:
ds = MetalSiteDataset(
    cache_folder='../../bonnanzio_metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [4]:
md = ds.get_all_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,...,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance,n_unresolved_removed,coordinating_residues,resolution,max_rczd
0,6fpw,6fpw_0,0,19,158,150,Fe,4,0,0,...,18,4,0,,0,2.9,0,29123,1.35,
1,6fpw,6fpw_1,1,22,138,124,Fe,3,3,0,...,18,3,0,,0,2.9,0,71613,1.35,
2,6fpw,6fpw_2,2,29,203,184,Fe,4,2,0,...,26,6,0,,0,2.9,0,135153196,1.35,
3,6fpw,6fpw_3,3,20,141,123,"Fe,Ni",2,0,0,...,19,4,0,,0,2.9,0,619317,1.35,
4,6fpw,6fpw_4,4,19,101,83,Mg,1,6,0,...,12,3,0,,0,2.9,0,1261,1.35,1.6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
126231,1fyu,1fyu_0,0,15,103,93,Mn,1,1,0,...,12,4,0,,0,2.9,0,2175,2.60,
126232,1fyu,1fyu_1,1,18,130,117,Ca,1,1,1,...,14,4,0,GAL,1,2.9,0,57911,2.60,
126233,1fyu,1fyu_2,2,16,104,93,Mn,1,2,0,...,12,4,0,,0,2.9,0,2157,2.60,
126234,1fyu,1fyu_3,3,18,123,110,Ca,1,2,1,...,13,4,0,GAL,1,2.9,0,57119,2.60,


In [5]:
site_chain = ds[1][1]

In [6]:
visualize_chain_3d(site_chain)

<py3Dmol.view at 0x7fe90ed386a0>

In [7]:
clean_chain = parser.clean_metal_bonding_patterns(site_chain)

In [8]:
visualize_chain_3d(clean_chain)

<py3Dmol.view at 0x7fe90ed383d0>

In [9]:
featurizer = MetalSiteFeaturizer(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic',]
)

### Call featurizer with no special effects, just convert

In [10]:
# first atom stuff
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=True)
# now graph and bond construction
# these are seperatre because in pracice if we are doing any noising or collapsing, it must happen
# between these two steps such that there is no leakage in the form of graph structure
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)


In [11]:
features

ProteinData(
  element: shape=(138, 1),
  charge: shape=(138, 1),
  nhyd: shape=(138, 1),
  hyb: shape=(138, 1),
  positions: shape=(138, 3),
  atom_movable_mask=None,
  atom_name: shape=(138, 1),
  atom_resname: shape=(138, 1),
  atom_resid: shape=(138, 1),
  atom_ishetero: shape=(138, 1),
  distances: shape=(2760, 1),
  bond_order: shape=(2760, 1),
  is_aromatic: shape=(2760, 1),
  is_in_ring: shape=(2760, 1),
  edge_index: shape=(2760, 2),
  topology={'bonds': tensor([[ 29,  30],
        [110, 111],
        [ 21,  23],
        [ 37,  38],
        [ 15,  16],
        [ 51,  52],
        [ 92,  93],
        [101, 103],
        [ 10,  16],
        [ 36,  37],
        [ 61,  64],
        [ 27,  28],
        [ 97, 100],
        [105, 106],
        [100, 101],
        [ 78,  79],
        [ 12,  13],
        [116, 117],
        [ 17,  18],
        [108, 109],
        [ 64,  65],
        [  5,   8],
        [ 72,  75],
        [ 32,  33],
        [ 46,  47],
        [ 11,  12],
        [ 71

### Observe vocab

In [12]:
featurizer.tokenizers['element'].i2d

{0: '<UNK>',
 1: '<MASK>',
 2: 'Ag',
 3: 'Al',
 4: 'Au',
 5: 'Br',
 6: 'C',
 7: 'Ca',
 8: 'Cd',
 9: 'Cl',
 10: 'Co',
 11: 'Cr',
 12: 'Cu',
 13: 'Dy',
 14: 'F',
 15: 'Fe',
 16: 'Ga',
 17: 'Hg',
 18: 'I',
 19: 'Ir',
 20: 'K',
 21: 'La',
 22: 'Li',
 23: 'Mg',
 24: 'Mn',
 25: 'Mo',
 26: 'N',
 27: 'Na',
 28: 'Nd',
 29: 'Ni',
 30: 'O',
 31: 'P',
 32: 'Pb',
 33: 'Pr',
 34: 'Pt',
 35: 'S',
 36: 'Se',
 37: 'Si',
 38: 'Tb',
 39: 'Te',
 40: 'Ti',
 41: 'U',
 42: 'V',
 43: 'W',
 44: 'Zn',
 45: '<METAL>'}

In [13]:
atom_id = np.where(np.array(features.atom_name) == 'CZ2')[0][0]
atom_id

57

In [14]:
features.element

tensor([[26],
        [ 6],
        [ 6],
        [30],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [35],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [30],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [30],
        [26],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [35],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
      

In [15]:
features.element[atom_id]

tensor([6], dtype=torch.int32)

Yay it is carbon

In [16]:
featurizer.tokenizers['element'].metal_token_id

45

In [17]:
# with unknown metal
47 in features.element

False

In [18]:
featurizer.tokenizers['element'].get_vocab()

{'<UNK>': 0,
 '<MASK>': 1,
 'Ag': 2,
 'Al': 3,
 'Au': 4,
 'Br': 5,
 'C': 6,
 'Ca': 7,
 'Cd': 8,
 'Cl': 9,
 'Co': 10,
 'Cr': 11,
 'Cu': 12,
 'Dy': 13,
 'F': 14,
 'Fe': 15,
 'Ga': 16,
 'Hg': 17,
 'I': 18,
 'Ir': 19,
 'K': 20,
 'La': 21,
 'Li': 22,
 'Mg': 23,
 'Mn': 24,
 'Mo': 25,
 'N': 26,
 'Na': 27,
 'Nd': 28,
 'Ni': 29,
 'O': 30,
 'P': 31,
 'Pb': 32,
 'Pr': 33,
 'Pt': 34,
 'S': 35,
 'Se': 36,
 'Si': 37,
 'Tb': 38,
 'Te': 39,
 'Ti': 40,
 'U': 41,
 'V': 42,
 'W': 43,
 'Zn': 44,
 '<METAL>': 45}

In [19]:
assert 16 not in features.element

Indeed no iron remains, converted to unknown metal

In [20]:
# check other metals are properly converted to metal token
featurizer.tokenizers['element'].encode('Ho')

45

In [21]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
# now graph and bond construction
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)

In [22]:
assert 15 in features.element

### Charge

In [23]:
featurizer.tokenizers['charge'].i2d

{0: '<MASK>', 1: -3, 2: -2, 3: -1, 4: 0, 5: 1, 6: 2, 7: 3}

In [24]:
np.array(features.atom_name).reshape(-1,1)[features.charge !=4]

array(['NZ'], dtype='<U3')

Indeed that is a charged atom on a Lysine

### Num hydr

In [25]:
featurizer.tokenizers['nhyd'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4}

In [26]:
features.nhyd[atom_id]

tensor([3], dtype=torch.int32)

atom id is from a CZ2 - an aromatic carbon, which should have 1 hydrogen, indeed it does

Chek metal has been set to known or mask

In [27]:
metal_mask = features.element == featurizer.tokenizers['element'].get_vocab()['Fe']
metal_mask.sum()

tensor(3)

In [28]:
featurizer.tokenizers['nhyd'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6}

In [29]:
features.nhyd[metal_mask]

tensor([0, 0, 0], dtype=torch.int32)

### hybridization

Note in the og code hyb goes i.e., 1 = sp, 2 = sp2, 3 = sp3 ...).

In [30]:
featurizer.tokenizers['hyb'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5}

In [31]:
features.hyb[atom_id]

tensor([4], dtype=torch.int32)

CZ2 in benzene ring is indeed sp2

In [32]:
features.hyb[metal_mask]

tensor([0, 0, 0], dtype=torch.int32)

### Bond order

In [33]:
featurizer.tokenizers['bond_order'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [34]:
src, dst = features.edge_index.T

In [35]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [36]:
bond_order = features.bond_order

In [37]:
bond_order.shape

torch.Size([2760, 1])

In [38]:
features.atom_resname

array([['GLY'],
       ['GLY'],
       ['GLY'],
       ['GLY'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['PHE'],
       ['PHE'],
       [

In [39]:
list(zip(bond_order, np.array(features.atom_resid)[src], np.array(features.atom_name)[src], np.array(features.atom_resid)[dst], np.array(features.atom_name)[dst]))

[(tensor([2], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['CA'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['C'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['O'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['N'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['CA'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['CB'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([3], dtype=int32)

bonds (>2) only occur between atoms in the same resid. Good.

Make sure not bonds for metals

In [40]:
metal_edges = np.where(metal_mask[src] | metal_mask[dst])[0]
metal_edges

array([ 438,  456,  457,  478,  497,  598,  618,  638,  737,  756,  776,
        795,  796,  818,  838,  858,  896,  913,  914,  915, 1098, 1534,
       1535, 1557, 1817, 1837, 1858, 1878, 1897, 1916, 1938, 2056, 2098,
       2118, 2176, 2177, 2197, 2218, 2239, 2256, 2274, 2298, 2560, 2561,
       2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572,
       2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583,
       2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594,
       2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604, 2605,
       2606, 2607, 2608, 2609, 2610, 2611, 2612, 2613, 2614, 2615, 2616,
       2617, 2618, 2619, 2631, 2632, 2633, 2654, 2655, 2656, 2673, 2674,
       2675, 2692, 2693, 2694, 2716, 2717, 2736, 2754, 2755])

In [41]:
featurizer.tokenizers['bond_order'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [42]:
features.bond_order[metal_edges]

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],

### Is aromatic

In [43]:
featurizer.tokenizers['is_aromatic'].i2d

{0: '<MASK>', 1: False, 2: True}

In [44]:
(features.is_aromatic ==2).sum()

tensor(32)

In [45]:
is_aromatic = features.is_aromatic

In [46]:
is_aromatic.shape

torch.Size([2760, 1])

In [47]:
aromatic_edges = np.where(is_aromatic == 2)[0]
aromatic_edges

array([1025, 1026, 1047, 1049, 1065, 1068, 1069, 1087, 1089, 1108, 1109,
       1111, 1127, 1131, 1150, 1153, 1171, 1173, 1194, 1195, 1309, 1310,
       1328, 1330, 1349, 1352, 1368, 1371, 1391, 1393, 1411, 1412])

In [48]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [49]:
for bond_id in aromatic_edges:
    src, dst = features.edge_index[bond_id, :]
    print(is_aromatic[bond_id], features.atom_resname[src], features.atom_name[src], features.atom_resname[dst], features.atom_name[dst])

tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE3']
tensor([2], dtype=torch.int32) ['TRP'] ['NE1'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['NE1'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['CZ2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE3'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE3'] ['TRP'] ['CZ3']
tensor([2], dtype=torch.int32) ['TRP'] ['CZ2'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CZ2'] ['TRP'] ['CH

The tryptophan and phenolalinine in here has the proper aromatic labels.

### Is in ring

In [50]:
featurizer.tokenizers['is_in_ring'].i2d

{0: '<MASK>', 1: False, 2: True}

In [51]:
is_in_ring = features.is_in_ring

In [52]:
ring_edges = np.where(is_in_ring == 2)[0]

In [53]:
for bondid in ring_edges:
    src, dst = features.edge_index[bondid, :]
    print(is_in_ring[bondid], features.atom_resname[src], features.atom_name[src], features.atom_resname[dst], features.atom_name[dst])

tensor([2], dtype=torch.int32) ['PRO'] ['N'] ['PRO'] ['CA']
tensor([2], dtype=torch.int32) ['PRO'] ['N'] ['PRO'] ['CD']
tensor([2], dtype=torch.int32) ['PRO'] ['CA'] ['PRO'] ['N']
tensor([2], dtype=torch.int32) ['PRO'] ['CA'] ['PRO'] ['CB']
tensor([2], dtype=torch.int32) ['PRO'] ['CB'] ['PRO'] ['CA']
tensor([2], dtype=torch.int32) ['PRO'] ['CB'] ['PRO'] ['CG']
tensor([2], dtype=torch.int32) ['PRO'] ['CG'] ['PRO'] ['CB']
tensor([2], dtype=torch.int32) ['PRO'] ['CG'] ['PRO'] ['CD']
tensor([2], dtype=torch.int32) ['PRO'] ['CD'] ['PRO'] ['N']
tensor([2], dtype=torch.int32) ['PRO'] ['CD'] ['PRO'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE2']
tensor([2], dtype=to

Awesome. I see some Ca and Cb of only proline. And all of the tryptophan rings are in rings.

#### Visualize the model post featurization, with atoms in a graph noted

In [54]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7fe8980903a0>

In [55]:
# get the atoms with edges to the iron
iron_id = np.where(features.element == featurizer.tokenizers['element'].d2i['Fe'])[0][0]
iron_id

128

In [56]:
visualize_protein_data_3d(features, focus_atom=iron_id)

<py3Dmol.view at 0x7fe898090940>

## Let's check on the tokenizer methods we will use to help construct global prediction task

In [57]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
# now graph and bond construction
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)

In [58]:
tokenizer = featurizer.tokenizers['element']

In [59]:
#tokenizer.get_metal_vocab()

In [60]:
#tokenizer.get_metal_label_indices()

In [61]:
#tokenizer.get_metal_labels_mapping()

In [62]:
labels = tokenizer.encode_metal_composition_counts(['Fe', 'Fe', 'Zn'])

In [63]:
labels

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 2.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0]

In [64]:
tokenizer.decode_metal_composition_counts(labels)

['Fe', 'Fe', 'Zn']

In [65]:
tokenizer.get_vocab()['Fe']

15

In [66]:
15 in features.element

True

In [67]:
# try with the actual featurized tensor
labels = tokenizer.encode_metal_composition_counts_from_tokens(features.element)
labels

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [68]:
# check that atom features and bonds are completely masked
is_metal = features.element == tokenizer.get_vocab()['Fe']
is_metal.sum()

tensor(3)

In [69]:
featurizer.tokenizers['nhyd'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6}

In [70]:
features.nhyd[is_metal]

tensor([0, 0, 0], dtype=torch.int32)

In [71]:
featurizer.tokenizers['hyb'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7}

In [72]:
features.hyb[is_metal]

tensor([0, 0, 0], dtype=torch.int32)

In [73]:
features = featurizer._anonymize_metals_for_classification(features)

In [74]:
features.global_labels

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [75]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7fe898093970>

In [76]:
features.element[is_metal]

tensor([45, 45, 45], dtype=torch.int32)

In [77]:
featurizer.tokenizers['nhyd'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6}

In [78]:
features.nhyd[is_metal]

tensor([0, 0, 0], dtype=torch.int32)

In [79]:
featurizer.tokenizers['hyb'].get_vocab()    

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7}

In [80]:
features.hyb[is_metal]

tensor([0, 0, 0], dtype=torch.int32)

## Try noising atoms for a resid

In [81]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
# noise before graph construction
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.2, center_atom_noise_sigma=.5)
# now graph and bond construction
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)

In [82]:
features_masked.atom_resname[features_masked.atom_noised_mask]

array([['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP']], dtype='<U3')

In [83]:
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fdbd7130c10>

In [84]:
visualize_protein_data_3d(features_masked, focus_atom=np.where(features_masked.atom_noised_mask)[0][0]) 

<py3Dmol.view at 0x7fe898090a30>

In [85]:
features_masked.positions = features_masked.position_labels.clone()
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe898090af0>

The limb atoms moved and the center atom is not part of the noise mask

Move it the ca somewhere else to prove that the edges change based on post and not pre noise positions
and check that the ca is now part of the noise mas

In [86]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
# noise before graph construction
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=False, limb_atom_noise_sigma=.2, center_atom_noise_sigma=5.0)
# now graph and bond construction
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, focus_atom=np.where(features_masked.atom_noised_mask)[0][0])

<py3Dmol.view at 0x7fe89809b910>

In [87]:
# check that distances for the noised atoms are real close to eachother eg. the distances were computed for the 
edges_mask_with_noised_atoms = features_masked.atom_noised_mask[features_masked.edge_index].any(dim=1)
edges_mask_with_noised_atoms

tensor([False, False, False,  ..., False, False, False])

In [88]:
features_masked.distances[edges_mask_with_noised_atoms]

tensor([[6.2219],
        [6.5059],
        [6.5445],
        [6.4711],
        [6.4251],
        [6.5552],
        [6.4261],
        [5.1410],
        [5.5022],
        [5.4094],
        [5.4592],
        [5.4013],
        [5.2998],
        [5.4181],
        [5.3163],
        [5.6136],
        [5.6864],
        [5.8142],
        [5.8158],
        [6.0415],
        [3.7961],
        [4.1809],
        [4.1158],
        [4.1176],
        [4.0650],
        [3.9939],
        [4.1329],
        [4.0296],
        [4.3095],
        [4.3121],
        [4.3001],
        [3.9534],
        [4.2962],
        [4.1718],
        [4.2625],
        [4.1216],
        [4.2411],
        [4.3522],
        [4.1280],
        [4.4000],
        [2.6600],
        [3.0129],
        [2.9530],
        [2.9663],
        [2.9456],
        [2.7739],
        [2.8920],
        [3.1153],
        [2.8437],
        [3.3685],
        [3.1693],
        [3.3606],
        [3.2077],
        [3.1281],
        [5.2170],
        [4

In [89]:
features_masked.positions = features_masked.position_labels.clone()
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe89809b1f0>

Ca now moves, the edges change, and the ca is part of the noise mask

In [90]:
list(zip(features.atom_resid, features.atom_resname))

[(tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(

In [91]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
# noise before graph construction
features_masked = featurizer._collapse_and_noise_residues(features, resid=19, ca_fixed=False, limb_atom_noise_sigma=.2, center_atom_noise_sigma=1.0)
# now graph and bond construction
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, focus_atom=np.where(features_masked.atom_noised_mask)[0][0])

<py3Dmol.view at 0x7fdbd7130eb0>

The metal residues is successfully noised. Try with flow matching eg. at an interpolation time

In [92]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe89807d3a0>

In [93]:
features_masked.time

tensor([[0.]])

In [94]:
flow = features_masked.position_flow_labels
features_masked.positions = features_masked.positions + flow *(1 - features_masked.time.item())
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe89809bc40>

The flow field recovers the final structure

In [95]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=1.0)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe89807d4c0>

time=1 also prodices the final structure

In [96]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.7)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe89809ba90>

In [97]:
features_masked.positions = features_masked.positions + features_masked.position_flow_labels * (1 - features_masked.time.item())
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7fe8980933d0>

Show the flow field

In [98]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.5)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7fe89809b4c0>

Check that the noising of all other atoms works

In [99]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0, other_atom_noise_sigma=.6)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7fdbd7130d60>

In [100]:
features.atom_name[~features.atom_noised_mask]

array([['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA']], dtype='<U3')

In [101]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=False, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0, other_atom_noise_sigma=.6)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7fe89809bfa0>

In [102]:
features.atom_name[~features.atom_noised_mask]

array([], shape=(0, 1), dtype='<U3')

In [103]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=[8], ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0, other_atom_noise_sigma=1)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7fdbd7130790>

## Node masking for eg. pretraining

In [104]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=True)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)

In [105]:
visualize_protein_data_3d(features, highlight_atoms=[0,1,2,4,5,6,8,9,10])

<py3Dmol.view at 0x7fe89809be50>

In [106]:
features_masked = featurizer._mask_atoms(
    features, indices_to_mask=[0,], indices_to_tweak= [4,5,6], indices_to_keep=[8,9,10]
)
visualize_protein_data_3d(features, highlight_atoms=list(np.where(features_masked.atom_masked_mask)[0]))

<py3Dmol.view at 0x7fe89809b400>

Indeed the one masked atom lost its bonds, the others (eg random or keep) kept their bonds and are included in the loss mask. Last just confirm that the tweaked and same atoms have the appropriate tokens

In [107]:
tokenizer.i2d[features.element[0].item()], tokenizer.i2d[features.element_labels[0].item()]

('<MASK>', 'N')

In [108]:
for id in features.atom_masked_mask.nonzero(as_tuple=True)[0]:
    print(tokenizer.i2d[features.element[id].item()], tokenizer.i2d[features.element_labels[id].item()])

<MASK> N
Fe N
Hg C
Te C
C C
S S
N N


## Topology

In [109]:
features = featurizer._init_atoms_into_protein_data(clean_chain, metal_unknown=True)
features = featurizer._make_graph_and_tokenize_edges(clean_chain, features)
topology_data = features.topology

In [110]:
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [111]:
topology_data['frames'].size()

torch.Size([115, 3])

In [112]:
featurizer.get_feature_vocab_sizes()

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [113]:
features.distances

tensor([[1.4516],
        [2.4188],
        [2.7092],
        ...,
        [5.3327],
        [3.1899],
        [2.7116]])

# Try out the final wrapped call method for each of the above and combinations thereof

## Nothing at all

In [114]:
features_base = featurizer(clean_chain, return_batched=False)[0]

In [115]:
visualize_protein_data_3d(features_base)


<py3Dmol.view at 0x7fdbd7093a30>

## Just metal classification

In [116]:
features = featurizer(
    clean_chain,
    metal_classification=True,
    metal_unknown=False,
    return_batched=False,
)[0]

In [117]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7fe89807d070>

In [118]:
features

ProteinData(
  element: shape=(138, 1),
  charge: shape=(138, 1),
  nhyd: shape=(138, 1),
  hyb: shape=(138, 1),
  positions: shape=(138, 3),
  atom_movable_mask: shape=(138,),
  atom_name: shape=(138, 1),
  atom_resname: shape=(138, 1),
  atom_resid: shape=(138, 1),
  atom_ishetero: shape=(138, 1),
  distances: shape=(2760, 1),
  bond_order: shape=(2760, 1),
  is_aromatic: shape=(2760, 1),
  is_in_ring: shape=(2760, 1),
  edge_index: shape=(2760, 2),
  topology={'bonds': tensor([[ 29,  30],
        [110, 111],
        [ 21,  23],
        [ 37,  38],
        [ 15,  16],
        [ 51,  52],
        [ 92,  93],
        [101, 103],
        [ 10,  16],
        [ 36,  37],
        [ 61,  64],
        [ 27,  28],
        [ 97, 100],
        [105, 106],
        [100, 101],
        [ 78,  79],
        [ 12,  13],
        [116, 117],
        [ 17,  18],
        [108, 109],
        [ 64,  65],
        [  5,   8],
        [ 72,  75],
        [ 32,  33],
        [ 46,  47],
        [ 11,  12],
   

In [119]:
(features.element == featurizer.tokenizers['element'].get_vocab()['Fe']).sum()

tensor(0)

In [120]:
(features.element == featurizer.tokenizers['element'].get_vocab()['<METAL>']).sum()

tensor(3)

In [121]:
# repeat with incoming precomuted features object
features = featurizer(
    features_base,
    metal_classification=True,
    return_batched=False,
)[0]

In [122]:
features.global_labels

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [123]:
features_base = featurizer(clean_chain, metal_unknown=True, return_batched=False)[0]
# repeat with incoming precomuted features object
features = featurizer(
    features_base,
    metal_classification=True,
    return_batched=False,
)[0]
features.global_labels

tensor([[0., 0., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

## Just denoising

In [124]:
# just denoising task
features = featurizer(
    clean_chain,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=False,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.2,
    residue_collapse_other_atom_noise_sigma=0.5,
    residue_collapse_time=0.0,
    return_batched=False,
)[0]

In [125]:
visualize_protein_data_3d(features, velocities='flow')

<py3Dmol.view at 0x7fdbd3cdcc40>

## Just MLM

In [126]:
features = featurizer(
    clean_chain,
    node_mlm_do=True,
    return_batched=False,
)[0]

In [127]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7fdbd3cdcd30>

In [128]:
visualize_protein_data_3d(features, highlight_atoms=np.where(features.atom_masked_mask)[0])

<py3Dmol.view at 0x7fdbd3cdc2b0>

In [129]:
for id in features.atom_masked_mask.nonzero(as_tuple=True)[0]:
    print(
        featurizer.tokenizers['element'].i2d[features.element[id].item()],
        featurizer.tokenizers['element'].i2d[features.element_labels[id].item()]
    )

<MASK> C
<MASK> C
Hg C
La N
<MASK> N
<MASK> C
<MASK> C
<MASK> C
<MASK> C
<MASK> O
<MASK> C
C C
<MASK> C
<MASK> C
<MASK> C
<MASK> N
<MASK> C
N N
<MASK> C
<MASK> Fe
<MASK> S


In [130]:
(features_base.element == featurizer.tokenizers['element'].get_vocab()['<METAL>']).sum()

tensor(3)

In [131]:
# repeat with the base features
features = featurizer(
    features_base,
    node_mlm_do=True,
)[0]
visualize_protein_data_3d(features, highlight_atoms=np.where(features.atom_masked_mask)[0])

<py3Dmol.view at 0x7fdbd70e3eb0>

In [132]:
(features.element == (featurizer.tokenizers['element'].get_vocab()['<METAL>'])).sum()

tensor(3)

## Denoising and MLM

In [133]:
# denoising and mlm tasks together
features = featurizer(
    clean_chain,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=True,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.2,
    residue_collapse_other_atom_noise_sigma=0.5,
    residue_collapse_time=0.0,
    node_mlm_do=True,
    node_mlm_subrate_tweak=0.2,
    node_mlm_subrate_keep=0.2,
    return_batched=False,
)[0]

In [134]:
visualize_protein_data_3d(features, velocities='flow', highlight_atoms=np.where(features.atom_masked_mask)[0])

<py3Dmol.view at 0x7fdbd7130400>

In [135]:
visualize_protein_data_3d(features, velocities='flow', highlight_atoms=np.where(features.atom_noised_mask)[0])

<py3Dmol.view at 0x7fdbd71308b0>

## Do denoising and classification

In [136]:
# Do denoising and classification
features = featurizer(
    clean_chain,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=True,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.2,
    residue_collapse_other_atom_noise_sigma=0.5,
    residue_collapse_time=0.0,
    metal_classification=True,
    return_batched=False,
)[0]

In [137]:
visualize_protein_data_3d(features, velocities='flow')

<py3Dmol.view at 0x7fe8980967c0>

## Simulate a later posibility - where we want to define the classification labels based on a noised structure

In [138]:
# Simulate a later posibility - where we want to define the classification labels based on a noised structure
noised_features = featurizer(
    clean_chain,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=True,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.2,
    residue_collapse_other_atom_noise_sigma=0.5,
    residue_collapse_time=0.5,
    metal_unknown=False,
    metal_classification=True,
    return_batched=False,
)[0]


In [139]:
noised_features.global_labels

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [140]:
noised_features.time

tensor([[0.5000]])

In [141]:
# It got the iron labels and the actual metal tokens should be "metal"
is_metal = noised_features.element == featurizer.tokenizers['element'].get_vocab()['<METAL>']
is_metal.sum()

tensor(3)

Nice. Then we would unroll the pretrained flow matching model up to some point with these unknown metal tokes, update positions, save the tensors to create a dataset for a classifier capable of making predictions along entire diffusion trajectory.

It may also be that we want to take exact positions along the target flow instead of sampling the model. Maybe both.

## Saving and loading data objects

In [142]:
noised_features.save('tmp_features_save.pt')

In [143]:
from metalsitenn.graph_data import ProteinData
loaded_features = ProteinData.load('tmp_features_save.pt')

  state_dict = torch.load(path, map_location=device)


In [144]:
visualize_protein_data_3d(loaded_features, velocities='flow')

<py3Dmol.view at 0x7fdbd3d15400>

In [145]:
loaded_features.time

tensor([[0.5000]])

In [146]:
loaded_features.topology.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [147]:
loaded_features.topology['angles']

tensor([[  0,   1,   2],
        [  1,   2,   3],
        [  4,   5,   6],
        [  4,   5,   8],
        [  5,   6,   7],
        [  5,   8,   9],
        [  6,   5,   8],
        [ 10,  11,  12],
        [ 10,  11,  14],
        [ 10,  16,  15],
        [ 11,  10,  16],
        [ 11,  12,  13],
        [ 11,  14,  15],
        [ 12,  11,  14],
        [ 14,  15,  16],
        [ 17,  18,  19],
        [ 17,  18,  21],
        [ 18,  19,  20],
        [ 18,  21,  22],
        [ 18,  21,  23],
        [ 19,  18,  21],
        [ 21,  22,  24],
        [ 22,  21,  23],
        [ 25,  26,  27],
        [ 25,  26,  29],
        [ 26,  27,  28],
        [ 26,  29,  30],
        [ 26,  29,  31],
        [ 27,  26,  29],
        [ 30,  29,  31],
        [ 32,  33,  34],
        [ 32,  33,  36],
        [ 33,  34,  35],
        [ 33,  36,  37],
        [ 34,  33,  36],
        [ 36,  37,  38],
        [ 36,  37,  39],
        [ 38,  37,  39],
        [ 40,  41,  42],
        [ 40,  41,  44],


In [148]:
loaded_features.topology['torsions']

tensor([[  0,   1,   2,   3],
        [  4,   5,   6,   7],
        [  4,   5,   8,   9],
        [  6,   5,   8,   9],
        [  7,   6,   5,   8],
        [ 10,  11,  12,  13],
        [ 10,  11,  14,  15],
        [ 10,  16,  15,  14],
        [ 11,  10,  16,  15],
        [ 11,  14,  15,  16],
        [ 12,  11,  10,  16],
        [ 12,  11,  14,  15],
        [ 13,  12,  11,  14],
        [ 14,  11,  10,  16],
        [ 17,  18,  19,  20],
        [ 17,  18,  21,  22],
        [ 17,  18,  21,  23],
        [ 18,  21,  22,  24],
        [ 19,  18,  21,  22],
        [ 19,  18,  21,  23],
        [ 20,  19,  18,  21],
        [ 23,  21,  22,  24],
        [ 25,  26,  27,  28],
        [ 25,  26,  29,  30],
        [ 25,  26,  29,  31],
        [ 27,  26,  29,  30],
        [ 27,  26,  29,  31],
        [ 28,  27,  26,  29],
        [ 32,  33,  34,  35],
        [ 32,  33,  36,  37],
        [ 33,  36,  37,  38],
        [ 33,  36,  37,  39],
        [ 34,  33,  36,  37],
        [ 

In [149]:
loaded_features.topology['chirals']

tensor([[  5,   4,   8,   6],
        [ 11,  10,  14,  12],
        [ 18,  17,  21,  19],
        [ 21,  18,  23,  22],
        [ 26,  25,  29,  27],
        [ 29,  26,  31,  30],
        [ 33,  32,  36,  34],
        [ 41,  40,  44,  42],
        [ 47,  46,  50,  48],
        [ 61,  60,  64,  62],
        [ 72,  71,  75,  73],
        [ 79,  78,  82,  80],
        [ 82,  79,  84,  83],
        [ 91,  90,  94,  92],
        [ 97,  96, 100,  98],
        [109, 108, 112, 110],
        [115, 114, 118, 116],
        [120, 119, 123, 121]])

In [150]:
loaded_features.topology['permuts']

[tensor([[66, 67, 68, 69],
         [67, 66, 69, 68]]),
 tensor([[102, 103],
         [103, 102]])]

In [151]:
loaded_features.topology['frames'].shape

torch.Size([115, 3])

In [152]:
loaded_features.topology['bond_lengths']

tensor([1.4352, 1.2458, 1.5321, 1.2222, 1.5257, 1.3601, 1.2458, 1.5492, 1.4767,
        1.4874, 1.5348, 1.2565, 1.5391, 1.5151, 1.5262, 1.4668, 1.2550, 1.2322,
        1.4668, 1.4893, 1.5073, 1.5349, 1.5326, 1.4728, 1.4567, 1.5145, 1.4767,
        1.5150, 1.3875, 1.4893, 1.2224, 1.5151, 1.8186, 1.5689, 1.5349, 1.2460,
        1.4339, 1.5321, 1.5182, 1.4349, 1.3937, 1.2357, 1.3711, 1.5149, 1.5159,
        1.5242, 1.2550, 1.2458, 1.4591, 1.3908, 1.5297, 1.5463, 1.4349, 1.5349,
        1.4594, 1.5263, 1.3906, 1.5145, 1.2460, 1.5326, 1.5524, 1.2300, 1.5349,
        1.3948, 1.4119, 1.5689, 1.2300, 1.5362, 1.5211, 1.3753, 1.4349, 1.5376,
        1.4742, 1.4912, 1.5378, 1.4102, 1.5182, 1.5273, 1.5151, 1.5297, 1.5221,
        1.5451, 1.2306, 1.4594, 1.2496, 1.4000, 1.2458, 1.5263, 1.4813, 1.5263,
        1.8186, 1.4017, 1.3977, 1.5242, 1.5378, 1.8186, 1.3504, 1.8186, 1.5257,
        1.2300, 1.5213, 1.3721, 1.4617, 1.3898, 1.4893, 1.4893, 1.5263, 1.5247,
        1.4805, 1.2541, 1.5346, 1.5149, 

# Into a batch

In [153]:
chains = []
i = 0
for _, chain in ds:
    clean_chain = parser.clean_metal_bonding_patterns(chain)
    chains.append(clean_chain)
    i += 1
    if i == 4:
        break

features_batched = featurizer(
    chains,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=True,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.2,
    residue_collapse_other_atom_noise_sigma=0.2,
    residue_collapse_time=0.0,
    metal_classification=True,
)

In [154]:
visualize_protein_data_3d(features_batched[3], velocities='flow')

<py3Dmol.view at 0x7fe90ee5afd0>

In [155]:
for features in features_batched:
    print(len(features.atom_name))

158
138
203
141


In [156]:
from metalsitenn.graph_data import BatchProteinData

In [157]:
features_batched.save('tmp_batch_data_save.pt')

In [158]:
features_batched_loaded = BatchProteinData.load('tmp_batch_data_save.pt')

  state_dict = torch.load(path, map_location=device)


In [159]:
features_batched_loaded.element.shape

torch.Size([640, 1])

In [160]:
visualize_protein_data_3d(features_batched_loaded[3], velocities='flow')

<py3Dmol.view at 0x7fe89807d1f0>

In [161]:
def compare_protein_data(data1, data2, tolerance=1e-6, verbose=True):
    """
    Compare two ProteinData objects, handling different data types appropriately.
    
    Args:
        data1, data2: ProteinData objects to compare
        tolerance: Float tolerance for tensor comparisons
        verbose: Print detailed mismatches
    
    Returns:
        bool: True if all attributes match within tolerance
    """
    import torch
    import numpy as np
    from dataclasses import fields
    
    all_match = True
    
    for field in fields(data1):
        attr_name = field.name
        val1 = getattr(data1, attr_name)
        val2 = getattr(data2, attr_name)
        
        # Both None
        if val1 is None and val2 is None:
            continue
            
        # One None, other not
        if (val1 is None) != (val2 is None):
            if verbose:
                print(f"❌ {attr_name}: One is None, other is not")
                print(f"   data1: {type(val1)}")
                print(f"   data2: {type(val2)}")
            all_match = False
            continue
        
        # Compare torch tensors
        if isinstance(val1, torch.Tensor):
            if not isinstance(val2, torch.Tensor):
                if verbose:
                    print(f"❌ {attr_name}: Type mismatch - tensor vs {type(val2)}")
                all_match = False
                continue
                
            # Check shapes
            if val1.shape != val2.shape:
                if verbose:
                    print(f"❌ {attr_name}: Shape mismatch - {val1.shape} vs {val2.shape}")
                all_match = False
                continue
                
            # Check dtypes
            if val1.dtype != val2.dtype:
                if verbose:
                    print(f"❌ {attr_name}: Dtype mismatch - {val1.dtype} vs {val2.dtype}")
                all_match = False
                continue
                
            # Check values
            try:
                if torch.allclose(val1, val2, atol=tolerance, rtol=tolerance):
                    if verbose:
                        print(f"✅ {attr_name}: Tensors match (shape: {val1.shape})")
                else:
                    if verbose:
                        max_diff = torch.max(torch.abs(val1 - val2)).item()
                        print(f"❌ {attr_name}: Tensor values differ (max diff: {max_diff})")
                        # Show first few differences for debugging
                        diff_mask = ~torch.isclose(val1, val2, atol=tolerance, rtol=tolerance)
                        if diff_mask.any():
                            flat_val1 = val1.flatten()
                            flat_val2 = val2.flatten()
                            flat_mask = diff_mask.flatten()
                            first_diff_idx = torch.where(flat_mask)[0][:5]  # First 5 differences
                            print(f"   First differences at indices: {first_diff_idx.tolist()}")
                            for idx in first_diff_idx:
                                print(f"     [{idx}]: {flat_val1[idx].item()} vs {flat_val2[idx].item()}")
                    all_match = False
            except Exception as e:
                if verbose:
                    print(f"❌ {attr_name}: Error comparing tensors - {e}")
                all_match = False
                
        # Compare numpy arrays
        elif isinstance(val1, np.ndarray):
            if not isinstance(val2, np.ndarray):
                if verbose:
                    print(f"❌ {attr_name}: Type mismatch - numpy array vs {type(val2)}")
                all_match = False
                continue
                
            # Check shapes
            if val1.shape != val2.shape:
                if verbose:
                    print(f"❌ {attr_name}: Shape mismatch - {val1.shape} vs {val2.shape}")
                all_match = False
                continue
                
            # Check dtypes
            if val1.dtype != val2.dtype:
                if verbose:
                    print(f"❌ {attr_name}: Dtype mismatch - {val1.dtype} vs {val2.dtype}")
                all_match = False
                continue
                
            # Check values
            try:
                if np.array_equal(val1, val2):
                    if verbose:
                        print(f"✅ {attr_name}: Arrays match (shape: {val1.shape})")
                else:
                    if verbose:
                        print(f"❌ {attr_name}: Array values differ")
                        # Show first few differences for strings
                        if val1.dtype.kind in ['U', 'S']:  # Unicode or byte strings
                            diff_mask = val1 != val2
                            if diff_mask.any():
                                diff_indices = np.where(diff_mask)
                                for i in range(min(5, len(diff_indices[0]))):
                                    idx = tuple(d[i] for d in diff_indices)
                                    print(f"     {idx}: '{val1[idx]}' vs '{val2[idx]}'")
                    all_match = False
            except Exception as e:
                if verbose:
                    print(f"❌ {attr_name}: Error comparing arrays - {e}")
                all_match = False
                
        # Compare dictionaries (topology)
        elif isinstance(val1, dict):
            if not isinstance(val2, dict):
                if verbose:
                    print(f"❌ {attr_name}: Type mismatch - dict vs {type(val2)}")
                all_match = False
                continue
                
            # Check keys
            if set(val1.keys()) != set(val2.keys()):
                if verbose:
                    print(f"❌ {attr_name}: Dict keys differ")
                    print(f"   data1 keys: {set(val1.keys())}")
                    print(f"   data2 keys: {set(val2.keys())}")
                all_match = False
                continue
                
            # Check each key's value
            dict_match = True
            for key in val1.keys():
                subval1 = val1[key]
                subval2 = val2[key]
                
                # Handle permuts (list of tensors)
                if key == 'permuts':
                    if isinstance(subval1, list) and isinstance(subval2, list):
                        if len(subval1) != len(subval2):
                            if verbose:
                                print(f"❌ {attr_name}[{key}]: List length mismatch - {len(subval1)} vs {len(subval2)}")
                            dict_match = False
                        else:
                            for i, (t1, t2) in enumerate(zip(subval1, subval2)):
                                if not torch.equal(t1, t2):
                                    if verbose:
                                        print(f"❌ {attr_name}[{key}][{i}]: Tensor mismatch")
                                    dict_match = False
                    else:
                        if verbose:
                            print(f"❌ {attr_name}[{key}]: Expected lists, got {type(subval1)} vs {type(subval2)}")
                        dict_match = False
                        
                # Handle regular tensors in topology
                elif isinstance(subval1, torch.Tensor) and isinstance(subval2, torch.Tensor):
                    if not torch.allclose(subval1, subval2, atol=tolerance, rtol=tolerance):
                        if verbose:
                            print(f"❌ {attr_name}[{key}]: Tensor values differ")
                        dict_match = False
                else:
                    if subval1 != subval2:
                        if verbose:
                            print(f"❌ {attr_name}[{key}]: Values differ - {subval1} vs {subval2}")
                        dict_match = False
                        
            if dict_match and verbose:
                print(f"✅ {attr_name}: Dict matches (keys: {list(val1.keys())})")
            elif not dict_match:
                all_match = False
                
        # Compare other types (scalars, etc.)
        else:
            try:
                if val1 == val2:
                    if verbose:
                        print(f"✅ {attr_name}: Values match ({type(val1).__name__})")
                else:
                    if verbose:
                        print(f"❌ {attr_name}: Values differ - {val1} vs {val2}")
                    all_match = False
            except Exception as e:
                if verbose:
                    print(f"❌ {attr_name}: Error comparing values - {e}")
                all_match = False
    
    return all_match

# Usage example:
# result = compare_protein_data(features_list[3], features_list_loaded[3])
# print(f"\nOverall match: {result}")

In [162]:
compare_protein_data(features_batched[3], features_batched_loaded[3], verbose=True)

✅ element: Tensors match (shape: torch.Size([141, 1]))
✅ charge: Tensors match (shape: torch.Size([141, 1]))
✅ nhyd: Tensors match (shape: torch.Size([141, 1]))
✅ hyb: Tensors match (shape: torch.Size([141, 1]))
✅ positions: Tensors match (shape: torch.Size([141, 3]))
✅ atom_movable_mask: Tensors match (shape: torch.Size([141]))
✅ atom_name: Arrays match (shape: (141, 1))
✅ atom_resname: Arrays match (shape: (141, 1))
✅ atom_resid: Tensors match (shape: torch.Size([141, 1]))
✅ atom_ishetero: Tensors match (shape: torch.Size([141, 1]))
✅ distances: Tensors match (shape: torch.Size([2820, 1]))
✅ bond_order: Tensors match (shape: torch.Size([2820, 1]))
✅ is_aromatic: Tensors match (shape: torch.Size([2820, 1]))
✅ is_in_ring: Tensors match (shape: torch.Size([2820, 1]))
✅ edge_index: Tensors match (shape: torch.Size([2820, 2]))
✅ topology: Dict matches (keys: ['chirals', 'frames', 'bond_lengths', 'torsions', 'angles', 'permuts', 'bonds', 'planars'])
✅ time: Tensors match (shape: torch.Size

True

In [163]:
features_batched_loaded.global_labels

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [164]:
features_batched_loaded.time.shape

torch.Size([4, 1])

## Try a collation into batches directly from the dataset with 1 and multiple workers

In [165]:
# import dataloader
from torch.utils.data import DataLoader

In [166]:
def collate_fn(batch):
    """
    Custom collate function to handle ProteinData objects in a batch.
    """
    batch = [item[1] for item in batch]  # Extract the second element from each tuple
    batched = featurizer(
        batch,
        residue_collapse_do=True,
        residue_collapse_rate=0.3,
        residue_collapse_ca_fixed=True,
        residue_collapse_center_atom_noise_sigma=1,
        residue_collapse_limb_atom_noise_sigma=0.2,
        residue_collapse_other_atom_noise_sigma=0.2,
        residue_collapse_time=0.0,
        metal_classification=True,
    )
    return batched

In [167]:
loader = DataLoader(
    ds,
    batch_size=6,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=1)

In [168]:
for batch in loader:
    break

In [169]:
visualize_protein_data_3d(batch[0], velocities='flow')

<py3Dmol.view at 0x7fdbd70e3f70>

In [170]:
# try with multiple workers
loader = DataLoader(
    ds,
    batch_size=6,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=6)

In [171]:
for batch in loader:
    break

In [172]:
visualize_protein_data_3d(batch[0], velocities='flow')

<py3Dmol.view at 0x7fdbd3d17580>

Wow it worked first try that's wild. Though clearly we need a collator because the dataset and the featurizer have to be poked to work together

## Try collator

In [173]:
from metalsitenn.featurizer import MetalSiteCollator

In [174]:
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=True,
    residue_collapse_do=True,
    residue_collapse_rate=0.3,
    residue_collapse_ca_fixed=True,
    residue_collapse_center_atom_noise_sigma=1,
    residue_collapse_limb_atom_noise_sigma=0.3,
    residue_collapse_other_atom_noise_sigma=0.2,
    residue_collapse_time=0.0,
)

In [175]:
loader = DataLoader(
    ds,
    batch_size=6,
    collate_fn=collator,
    shuffle=False,
    num_workers=1)

In [176]:
batch = next(iter(loader))

In [177]:
visualize_protein_data_3d(batch[0], velocities='flow')

<py3Dmol.view at 0x7fe90eec6640>

In [178]:
# we should also have pdb labels now
batch.pdb_id

array([['6fpw_0'],
       ['6fpw_1'],
       ['6fpw_2'],
       ['6fpw_3'],
       ['6fpw_4'],
       ['6fpw_5']], dtype=object)

In [179]:
batch[0].pdb_id

array([['6fpw_0']], dtype=object)

Can we conver to cuda properly?

In [180]:
batch.to('cuda')

BatchProteinData(batch_size=6, total_atoms=899)

In [183]:
batch.element

tensor([[26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [26],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [35],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [30],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [26],
        [26],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [26],
        [26],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
      