# Collate some site graphs into features

In [1]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteFeaturizer
from metalsitenn.utils import visualize_chain_3d, visualize_protein_data_3d
import numpy as np
import torch

from metalsitenn.placer_modules.cifutils import CIFParser

In [2]:
parser = CIFParser()

In [3]:
ds = MetalSiteDataset(
    cache_folder='../../bonnanzio_metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [4]:
md = ds.get_all_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,...,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance,n_unresolved_removed,coordinating_residues,resolution,max_rczd
0,6fpw,6fpw_0,0,19,158,150,Fe,4,0,0,...,18,4,0,,0,2.9,0,29123,1.35,
1,6fpw,6fpw_1,1,22,138,124,Fe,3,3,0,...,18,3,0,,0,2.9,0,71613,1.35,
2,6fpw,6fpw_2,2,29,203,184,Fe,4,2,0,...,26,6,0,,0,2.9,0,135153196,1.35,
3,6fpw,6fpw_3,3,20,141,123,"Fe,Ni",2,0,0,...,19,4,0,,0,2.9,0,619317,1.35,
4,6fpw,6fpw_4,4,19,101,83,Mg,1,6,0,...,12,3,0,,0,2.9,0,1261,1.35,1.6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
126231,1fyu,1fyu_0,0,15,103,93,Mn,1,1,0,...,12,4,0,,0,2.9,0,2175,2.60,
126232,1fyu,1fyu_1,1,18,130,117,Ca,1,1,1,...,14,4,0,GAL,1,2.9,0,57911,2.60,
126233,1fyu,1fyu_2,2,16,104,93,Mn,1,2,0,...,12,4,0,,0,2.9,0,2157,2.60,
126234,1fyu,1fyu_3,3,18,123,110,Ca,1,2,1,...,13,4,0,GAL,1,2.9,0,57119,2.60,


In [5]:
site_chain = ds[1][1]

In [6]:
visualize_chain_3d(site_chain)

<py3Dmol.view at 0x7f8118949160>

In [7]:
clean_chain = parser.clean_metal_bonding_patterns(site_chain)

In [8]:
visualize_chain_3d(clean_chain)

<py3Dmol.view at 0x7f8118949970>

In [9]:
featurizer = MetalSiteFeaturizer(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic',]
)

### Call featurizer with no special effects, just convert

In [10]:

features = featurizer.featurize_one(clean_chain, metal_unknown=True)

In [11]:
features

ProteinData(
  element: shape=(138, 1),
  charge: shape=(138, 1),
  nhyd: shape=(138, 1),
  hyb: shape=(138, 1),
  positions: shape=(138, 3),
  atom_movable_mask=None,
  atom_name: shape=(138, 1),
  atom_resname: shape=(138, 1),
  atom_resid: shape=(138, 1),
  atom_ishetero: shape=(138, 1),
  distances: shape=(2760, 1),
  bond_order: shape=(2760, 1),
  is_aromatic: shape=(2760, 1),
  is_in_ring: shape=(2760, 1),
  edge_index: shape=(2760, 2),
  topology={'bonds': tensor([[ 29,  30],
        [110, 111],
        [ 21,  23],
        [ 37,  38],
        [ 15,  16],
        [ 51,  52],
        [ 92,  93],
        [101, 103],
        [ 10,  16],
        [ 36,  37],
        [ 61,  64],
        [ 27,  28],
        [ 97, 100],
        [105, 106],
        [100, 101],
        [ 78,  79],
        [ 12,  13],
        [116, 117],
        [ 17,  18],
        [108, 109],
        [ 64,  65],
        [  5,   8],
        [ 72,  75],
        [ 32,  33],
        [ 46,  47],
        [ 11,  12],
        [ 71

### Observe vocab

In [12]:
featurizer.tokenizers['element'].i2d

{0: '<UNK>',
 1: '<MASK>',
 2: 'Ag',
 3: 'Al',
 4: 'Au',
 5: 'Br',
 6: 'C',
 7: 'Ca',
 8: 'Cd',
 9: 'Cl',
 10: 'Co',
 11: 'Cr',
 12: 'Cu',
 13: 'Dy',
 14: 'F',
 15: 'Fe',
 16: 'Ga',
 17: 'Hg',
 18: 'I',
 19: 'Ir',
 20: 'K',
 21: 'La',
 22: 'Li',
 23: 'Mg',
 24: 'Mn',
 25: 'Mo',
 26: 'N',
 27: 'Na',
 28: 'Nd',
 29: 'Ni',
 30: 'O',
 31: 'P',
 32: 'Pb',
 33: 'Pr',
 34: 'Pt',
 35: 'S',
 36: 'Se',
 37: 'Si',
 38: 'Tb',
 39: 'Te',
 40: 'Ti',
 41: 'U',
 42: 'V',
 43: 'W',
 44: 'Zn',
 45: '<METAL>'}

In [13]:
atom_id = np.where(np.array(features.atom_name) == 'CZ2')[0][0]
atom_id

57

In [14]:
features.element

tensor([[26],
        [ 6],
        [ 6],
        [30],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [35],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [30],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [30],
        [26],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [35],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [26],
        [ 6],
        [ 6],
        [30],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
      

In [15]:
features.element[atom_id]

tensor([6], dtype=torch.int32)

Yay it is carbon

In [16]:
featurizer.tokenizers['element'].metal_token_id

45

In [17]:
# with unknown metal
47 in features.element

False

In [18]:
featurizer.tokenizers['element'].get_vocab()

{'<UNK>': 0,
 '<MASK>': 1,
 'Ag': 2,
 'Al': 3,
 'Au': 4,
 'Br': 5,
 'C': 6,
 'Ca': 7,
 'Cd': 8,
 'Cl': 9,
 'Co': 10,
 'Cr': 11,
 'Cu': 12,
 'Dy': 13,
 'F': 14,
 'Fe': 15,
 'Ga': 16,
 'Hg': 17,
 'I': 18,
 'Ir': 19,
 'K': 20,
 'La': 21,
 'Li': 22,
 'Mg': 23,
 'Mn': 24,
 'Mo': 25,
 'N': 26,
 'Na': 27,
 'Nd': 28,
 'Ni': 29,
 'O': 30,
 'P': 31,
 'Pb': 32,
 'Pr': 33,
 'Pt': 34,
 'S': 35,
 'Se': 36,
 'Si': 37,
 'Tb': 38,
 'Te': 39,
 'Ti': 40,
 'U': 41,
 'V': 42,
 'W': 43,
 'Zn': 44,
 '<METAL>': 45}

In [19]:
assert 16 not in features.element

Indeed no iron remains, converted to unknown metal

In [20]:
# check other metals are properly converted to metal token
featurizer.tokenizers['element'].encode('Ho')

45

In [21]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)

In [22]:
assert 15 in features.element

### Charge

In [23]:
featurizer.tokenizers['charge'].i2d

{0: '<MASK>', 1: -3, 2: -2, 3: -1, 4: 0, 5: 1, 6: 2, 7: 3}

In [24]:
np.array(features.atom_name).reshape(-1,1)[features.charge !=4]

array(['NZ'], dtype='<U3')

Indeed that is a charged atom on a Lysine

### Num hydr

In [25]:
featurizer.tokenizers['nhyd'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4}

In [26]:
features.nhyd[atom_id]

tensor([3], dtype=torch.int32)

atom id is from a CZ2 - an aromatic carbon, which should have 1 hydrogen, indeed it does

Chek metal has been set to known or mask

In [27]:
metal_mask = features.element == featurizer.tokenizers['element'].get_vocab()['Fe']
metal_mask.sum()

tensor(3)

In [28]:
featurizer.tokenizers['nhyd'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6}

In [29]:
features.nhyd[metal_mask]

tensor([0, 0, 0], dtype=torch.int32)

### hybridization

Note in the og code hyb goes i.e., 1 = sp, 2 = sp2, 3 = sp3 ...).

In [30]:
featurizer.tokenizers['hyb'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5}

In [31]:
features.hyb[atom_id]

tensor([4], dtype=torch.int32)

CZ2 in benzene ring is indeed sp2

In [32]:
features.hyb[metal_mask]

tensor([0, 0, 0], dtype=torch.int32)

### Bond order

In [33]:
featurizer.tokenizers['bond_order'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [34]:
src, dst = features.edge_index.T

In [35]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [36]:
bond_order = features.bond_order

In [37]:
bond_order.shape

torch.Size([2760, 1])

In [38]:
features.atom_resname

array([['GLY'],
       ['GLY'],
       ['GLY'],
       ['GLY'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['PRO'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['ILE'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['THR'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['ASN'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['CYS'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['PHE'],
       ['PHE'],
       [

In [39]:
list(zip(bond_order, np.array(features.atom_resid)[src], np.array(features.atom_name)[src], np.array(features.atom_resid)[dst], np.array(features.atom_name)[dst]))

[(tensor([2], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['CA'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['C'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([1], dtype=int32),
  array(['O'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['N'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['CA'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([2], dtype=int32),
  array(['CB'], dtype='<U3')),
 (tensor([1], dtype=torch.int32),
  array([1], dtype=int32),
  array(['N'], dtype='<U3'),
  array([3], dtype=int32)

bonds (>2) only occur between atoms in the same resid. Good.

Make sure not bonds for metals

In [40]:
metal_edges = np.where(metal_mask[src] | metal_mask[dst])[0]
metal_edges

array([ 438,  456,  457,  478,  497,  598,  618,  638,  737,  756,  776,
        795,  796,  818,  838,  858,  896,  913,  914,  915, 1098, 1534,
       1535, 1557, 1817, 1837, 1858, 1878, 1897, 1916, 1938, 2056, 2098,
       2118, 2176, 2177, 2197, 2218, 2239, 2256, 2274, 2298, 2560, 2561,
       2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572,
       2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583,
       2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594,
       2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604, 2605,
       2606, 2607, 2608, 2609, 2610, 2611, 2612, 2613, 2614, 2615, 2616,
       2617, 2618, 2619, 2631, 2632, 2633, 2654, 2655, 2656, 2673, 2674,
       2675, 2692, 2693, 2694, 2716, 2717, 2736, 2754, 2755])

In [41]:
features.bond_order[metal_edges]

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],

### Is aromatic

In [42]:
featurizer.tokenizers['is_aromatic'].i2d

{0: '<MASK>', 1: False, 2: True}

In [43]:
(features.is_aromatic ==2).sum()

tensor(32)

In [44]:
is_aromatic = features.is_aromatic

In [45]:
is_aromatic.shape

torch.Size([2760, 1])

In [46]:
aromatic_edges = np.where(is_aromatic == 2)[0]
aromatic_edges

array([1025, 1026, 1047, 1049, 1065, 1068, 1069, 1087, 1089, 1108, 1109,
       1111, 1127, 1131, 1150, 1153, 1171, 1173, 1194, 1195, 1309, 1310,
       1328, 1330, 1349, 1352, 1368, 1371, 1391, 1393, 1411, 1412])

In [47]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [48]:
for bond_id in aromatic_edges:
    src, dst = features.edge_index[bond_id, :]
    print(is_aromatic[bond_id], features.atom_resname[src], features.atom_name[src], features.atom_resname[dst], features.atom_name[dst])

tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE3']
tensor([2], dtype=torch.int32) ['TRP'] ['NE1'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['NE1'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CE2'] ['TRP'] ['CZ2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE3'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CE3'] ['TRP'] ['CZ3']
tensor([2], dtype=torch.int32) ['TRP'] ['CZ2'] ['TRP'] ['CE2']
tensor([2], dtype=torch.int32) ['TRP'] ['CZ2'] ['TRP'] ['CH

The tryptophan and phenolalinine in here has the proper aromatic labels.

### Is in ring

In [49]:
featurizer.tokenizers['is_in_ring'].i2d

{0: '<MASK>', 1: False, 2: True}

In [50]:
is_in_ring = features.is_in_ring

In [51]:
ring_edges = np.where(is_in_ring == 2)[0]

In [52]:
for bondid in ring_edges:
    src, dst = features.edge_index[bondid, :]
    print(is_in_ring[bondid], features.atom_resname[src], features.atom_name[src], features.atom_resname[dst], features.atom_name[dst])

tensor([2], dtype=torch.int32) ['PRO'] ['N'] ['PRO'] ['CA']
tensor([2], dtype=torch.int32) ['PRO'] ['N'] ['PRO'] ['CD']
tensor([2], dtype=torch.int32) ['PRO'] ['CA'] ['PRO'] ['N']
tensor([2], dtype=torch.int32) ['PRO'] ['CA'] ['PRO'] ['CB']
tensor([2], dtype=torch.int32) ['PRO'] ['CB'] ['PRO'] ['CA']
tensor([2], dtype=torch.int32) ['PRO'] ['CB'] ['PRO'] ['CG']
tensor([2], dtype=torch.int32) ['PRO'] ['CG'] ['PRO'] ['CB']
tensor([2], dtype=torch.int32) ['PRO'] ['CG'] ['PRO'] ['CD']
tensor([2], dtype=torch.int32) ['PRO'] ['CD'] ['PRO'] ['N']
tensor([2], dtype=torch.int32) ['PRO'] ['CD'] ['PRO'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD1']
tensor([2], dtype=torch.int32) ['TRP'] ['CG'] ['TRP'] ['CD2']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD1'] ['TRP'] ['NE1']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CG']
tensor([2], dtype=torch.int32) ['TRP'] ['CD2'] ['TRP'] ['CE2']
tensor([2], dtype=to

Awesome. I see some Ca and Cb of only proline. And all of the tryptophan rings are in rings.

#### Visualize the model post featurization, with atoms in a graph noted

In [53]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7f746fad6af0>

In [54]:
# get the atoms with edges to the iron
iron_id = np.where(features.element == featurizer.tokenizers['element'].d2i['Fe'])[0][0]
iron_id

128

In [55]:
visualize_protein_data_3d(features, focus_atom=iron_id)

<py3Dmol.view at 0x7f81a791a460>

## Let's check on the tokenizer methods we will use to help construct global prediction task

In [56]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)

In [57]:
tokenizer = featurizer.tokenizers['element']

In [58]:
tokenizer.get_metal_vocab()

{'Li': 22,
 'Cd': 8,
 'V': 42,
 'Ag': 2,
 'Pt': 34,
 'Si': 37,
 'Hg': 17,
 'Zn': 44,
 'Au': 4,
 'K': 20,
 'Cr': 11,
 'Na': 27,
 'U': 41,
 'Pb': 32,
 'Tb': 38,
 'Nd': 28,
 'Ir': 19,
 'Pr': 33,
 'W': 43,
 'Te': 39,
 'Al': 3,
 'Ga': 16,
 'Ca': 7,
 'Mn': 24,
 'Ti': 40,
 'Cu': 12,
 'Fe': 15,
 'Mo': 25,
 'Ni': 29,
 'Dy': 13,
 'Mg': 23,
 'La': 21,
 'Co': 10,
 '<MASK>': 1,
 '<UNK>': 0,
 '<METAL>': 45}

In [59]:
tokenizer.get_metal_label_indices()

[0,
 1,
 45,
 2,
 3,
 4,
 7,
 8,
 10,
 11,
 12,
 13,
 15,
 16,
 17,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 27,
 28,
 29,
 32,
 33,
 34,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44]

In [60]:
tokenizer.get_metal_labels_mapping()

{'<UNK>': 0,
 '<MASK>': 1,
 '<METAL>': 2,
 'Ag': 3,
 'Al': 4,
 'Au': 5,
 'Ca': 6,
 'Cd': 7,
 'Co': 8,
 'Cr': 9,
 'Cu': 10,
 'Dy': 11,
 'Fe': 12,
 'Ga': 13,
 'Hg': 14,
 'Ir': 15,
 'K': 16,
 'La': 17,
 'Li': 18,
 'Mg': 19,
 'Mn': 20,
 'Mo': 21,
 'Na': 22,
 'Nd': 23,
 'Ni': 24,
 'Pb': 25,
 'Pr': 26,
 'Pt': 27,
 'Si': 28,
 'Tb': 29,
 'Te': 30,
 'Ti': 31,
 'U': 32,
 'V': 33,
 'W': 34,
 'Zn': 35}

In [61]:
labels = tokenizer.encode_metal_composition_counts(['Fe', 'Fe', 'Zn'])

In [62]:
labels

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 2.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0]

In [63]:
tokenizer.decode_metal_composition_counts(labels)

['Fe', 'Fe', 'Zn']

In [64]:
tokenizer.get_vocab()['Fe']

15

In [65]:
15 in features.element

True

In [66]:
# try with the actual featurized tensor
labels = tokenizer.encode_metal_composition_counts_from_tokens(features.element)
labels

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [67]:
features = featurizer._anonymize_metals_for_classification(features)

In [68]:
features.global_labels

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [69]:
visualize_protein_data_3d(features)

<py3Dmol.view at 0x7f80e84f4460>

In [70]:
# check that atom features and bonds are completely masked
is_metal = features.element == tokenizer.metal_token_id
is_metal.sum()

tensor(3)

In [71]:
features.element[is_metal]

tensor([45, 45, 45], dtype=torch.int32)

In [72]:
featurizer.tokenizers['charge'].get_vocab()

{'<MASK>': 0, -3: 1, -2: 2, -1: 3, 0: 4, 1: 5, 2: 6, 3: 7}

In [73]:
features.charge[is_metal]

tensor([0, 0, 0], dtype=torch.int32)

In [74]:
featurizer.tokenizers['nhyd'].get_vocab()

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6}

In [75]:
features.nhyd[is_metal]

tensor([1, 1, 1], dtype=torch.int32)

In [76]:
featurizer.tokenizers['hyb'].get_vocab()    

{'<UNK>': 0, '<MASK>': 1, 0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7}

In [77]:
features.hyb[is_metal]

tensor([1, 1, 1], dtype=torch.int32)

## Try noising atoms for a resid

In [119]:
features = featurizer.featurize_one(clean_chain, metal_unknown=True)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.2, center_atom_noise_sigma=.5)

In [120]:
features_masked.atom_resname[features_masked.atom_noised_mask]

array([['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP'],
       ['TRP']], dtype='<U3')

In [121]:
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f746c55da30>

In [122]:
visualize_protein_data_3d(features_masked, focus_atom=np.where(features_masked.atom_noised_mask)[0][0]) 

<py3Dmol.view at 0x7f746fad6b20>

In [109]:
# check that distances are removed
features_masked.distances

In [117]:
# and that we can reset them
features_masked.set_distances()
edges_mask_with_noised_atoms = features_masked.atom_noised_mask[features_masked.edge_index].any(dim=1)
edges_mask_with_noised_atoms

tensor([False, False, False,  ..., False, False, False])

In [118]:
features_masked.distances[edges_mask_with_noised_atoms]

tensor([[10.4142],
        [10.5385],
        [10.5606],
        [10.7168],
        [ 9.5724],
        [ 9.6620],
        [ 9.8186],
        [ 9.0184],
        [ 8.6950],
        [ 8.8368],
        [ 8.9853],
        [10.5169],
        [10.5307],
        [10.5641],
        [ 9.7478],
        [ 9.5215],
        [ 9.5430],
        [ 9.4964],
        [ 9.6410],
        [ 9.7430],
        [ 9.6611],
        [ 9.4908],
        [ 7.5833],
        [ 7.7543],
        [ 6.5273],
        [ 6.6384],
        [ 6.7011],
        [ 6.8682],
        [ 6.9384],
        [ 5.9078],
        [ 5.9053],
        [ 6.1350],
        [ 6.1691],
        [ 6.7555],
        [ 6.7909],
        [ 7.0583],
        [ 7.1325],
        [ 7.5398],
        [ 6.4639],
        [ 6.7011],
        [ 5.9078],
        [ 6.7803],
        [ 0.4044],
        [ 0.3348],
        [ 0.5566],
        [ 0.4581],
        [ 0.3862],
        [ 0.5358],
        [ 0.3144],
        [ 0.7391],
        [ 0.5523],
        [ 0.2342],
        [ 0.

In [81]:
features_masked.positions = features_masked.position_labels.clone()
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f80e84f46d0>

The limb atoms moved and the center atom is not part of the noise mask

In [82]:
features = featurizer.featurize_one(clean_chain, metal_unknown=True)
features_masked = featurizer._collapse_and_noise_residues(features, resid=5, ca_fixed=False, limb_atom_noise_sigma=.1, center_atom_noise_sigma=5.0)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f811894bd90>

Ca now moves!

In [83]:
list(zip(features.atom_resid, features.atom_resname))

[(tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([1], dtype=torch.int32), array(['GLY'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([2], dtype=torch.int32), array(['CYS'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(['PRO'], dtype='<U3')),
 (tensor([3], dtype=torch.int32), array(

In [84]:
features = featurizer.featurize_one(clean_chain, metal_unknown=True)
features_masked = featurizer._collapse_and_noise_residues(features, resid=19, ca_fixed=True, limb_atom_noise_sigma=.1, center_atom_noise_sigma=1.0)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f81a78a6880>

The metal residues is successfully noised. Try with flow matching eg. at an interpolation time

In [85]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f811894b310>

In [86]:
features_masked.time

tensor([0.])

In [87]:
flow = features_masked.position_flow_labels
features_masked.positions = features_masked.positions + flow *(1 - features_masked.time.item())
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f80e84f4580>

The flow field recovers the final structure

In [88]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=1.0)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f746fb27040>

time=1 also prodices the final structure

In [89]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.7)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f746fad6f40>

In [90]:
features_masked.positions = features_masked.positions + features_masked.position_flow_labels * (1 - features_masked.time.item())
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0])

<py3Dmol.view at 0x7f80e84f4070>

Show the flow field

In [91]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.5)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7f81a78adca0>

Check that the noising of all other atoms works

In [92]:
features = featurizer.featurize_one(clean_chain, metal_unknown=False)
features_masked = featurizer._collapse_and_noise_residues(features, resid=8, ca_fixed=True, limb_atom_noise_sigma=.5, center_atom_noise_sigma=1.0, time=0.0, other_atom_noise_sigma=.6)
visualize_protein_data_3d(features_masked, highlight_atoms=np.where(features_masked.atom_noised_mask)[0], velocities= 'flow')

<py3Dmol.view at 0x7f8118949580>

In [93]:
features.atom_name[~features_masked.atom_noised_mask]

array([['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA'],
       ['CA']], dtype='<U3')

## Node masking for eg. pretraining

In [95]:
features = featurizer.featurize_one(clean_chain, metal_unknown=True)

In [96]:
visualize_protein_data_3d(features, highlight_atoms=[0,1,2,4,5,6,8,9,10])

<py3Dmol.view at 0x7f80e84f4220>

In [97]:
features_masked = featurizer._mask_atoms(
    features, indices_to_mask=[0,], indices_to_tweak= [4,5,6], indices_to_keep=[8,9,10]
)
visualize_protein_data_3d(features, highlight_atoms=list(np.where(features_masked.atom_masked_mask)[0]))

<py3Dmol.view at 0x7f80e84f4f10>

Indeed the one masked atom lost its bonds, the others (eg random or keep) kept their bonds and are included in the loss mask. Last just confirm that the tweaked and same atoms have the appropriate tokens

In [98]:
tokenizer.i2d[features.element[0].item()], tokenizer.i2d[features.element_labels[0].item()]

('<MASK>', 'N')

In [99]:
for id in features.atom_masked_mask.nonzero(as_tuple=True)[0]:
    print(tokenizer.i2d[features.element[id].item()], tokenizer.i2d[features.element_labels[id].item()])

<MASK> N
Au N
Tb C
Ca C
C C
S S
N N


## Try mutating

In [100]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [101]:
from metalsitenn.placer_modules.cifutils import mutate_chain
new_site = mutate_chain(
    site_chain, 
    target_res_num='8',
    target_res_name='TRP',
    new_res_name='GLY')

In [102]:
new_site.residues['8'].atoms

{'N': Atom(name=('A', '8', 'GLY', 'N'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=1, hyb=3, nhyd=3, hvydeg=1, align=1, hetero=False),
 'CA': Atom(name=('A', '8', 'GLY', 'CA'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False),
 'C': Atom(name=('A', '8', 'GLY', 'C'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 'O': Atom(name=('A', '8', 'GLY', 'O'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 'OXT': Atom(name=('A', '8', 'GLY', 'OXT'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=True, leaving_group=

In [103]:
new_site.atoms

{('A',
  '1',
  'GLY',
  'N'): Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'CA'): Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'C'): Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'O'): Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 ('A',


In [104]:
atom_features, bond_features, topology = featurizer(new_site, metal_unknown=False)

AttributeError: 'MetalSiteFeaturizer' object has no attribute 'distances'

In [None]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f572c4040d0>

In [None]:
# mutate in the featurizer call - confirm that the collapse mask is correct
atom_features, bond_features, topology = featurizer(
    site_chain, mutations=[('8', 'TRP', 'GLY')])

In [None]:
list(zip(atom_features['atom_resid'], atom_features['atom_name'], atom_features['collapse_mask'], atom_features['positions']))

[('1', 'N', tensor([False]), tensor([1.6063, 6.7166, 4.8309])),
 ('1', 'CA', tensor([False]), tensor([2.0903, 5.4306, 4.3629])),
 ('1', 'C', tensor([False]), tensor([3.1273, 4.9006, 5.3199])),
 ('1', 'O', tensor([False]), tensor([3.8063, 5.6666, 6.0129])),
 ('2', 'N', tensor([False]), tensor([3.2583, 3.5786, 5.3489])),
 ('2', 'CA', tensor([False]), tensor([4.1653, 2.8766, 6.2589])),
 ('2', 'C', tensor([False]), tensor([5.1363, 1.9956, 5.4999])),
 ('2', 'O', tensor([False]), tensor([4.9873, 0.7776, 5.4289])),
 ('2', 'CB', tensor([False]), tensor([3.3883, 2.0976, 7.2919])),
 ('2', 'SG', tensor([False]), tensor([4.4633, 1.4656, 8.6199])),
 ('3', 'N', tensor([False]), tensor([6.1743, 2.6106, 4.9369])),
 ('3', 'CA', tensor([False]), tensor([6.4143, 4.0386, 4.7629])),
 ('3', 'C', tensor([False]), tensor([5.7693, 4.5536, 3.4749])),
 ('3', 'O', tensor([False]), tensor([5.1643, 3.7456, 2.7359])),
 ('3', 'CB', tensor([False]), tensor([7.9403, 4.0926, 4.6459])),
 ('3', 'CG', tensor([False]), tens

In [None]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f57e9b2bc10>

In [None]:
# check if the mutated residue gets collapsed that is uses the new collapse mask
atom_features, bond_features, topology = featurizer(
    site_chain, mutations=[('8', 'TRP', 'GLY')],
    do_collapsing=True,
    resids_to_collapse=[('8', 'GLY')],
    fixed_ca=False,
    center_gaussian_sigma=1.5)

In [None]:
list(zip(atom_features['atom_resid'], atom_features['atom_name'], atom_features['collapse_mask'], atom_features['positions']))

[('1', 'N', tensor([False]), tensor([1.6063, 6.7166, 4.8309])),
 ('1', 'CA', tensor([False]), tensor([2.0903, 5.4306, 4.3629])),
 ('1', 'C', tensor([False]), tensor([3.1273, 4.9006, 5.3199])),
 ('1', 'O', tensor([False]), tensor([3.8063, 5.6666, 6.0129])),
 ('2', 'N', tensor([False]), tensor([3.2583, 3.5786, 5.3489])),
 ('2', 'CA', tensor([False]), tensor([4.1653, 2.8766, 6.2589])),
 ('2', 'C', tensor([False]), tensor([5.1363, 1.9956, 5.4999])),
 ('2', 'O', tensor([False]), tensor([4.9873, 0.7776, 5.4289])),
 ('2', 'CB', tensor([False]), tensor([3.3883, 2.0976, 7.2919])),
 ('2', 'SG', tensor([False]), tensor([4.4633, 1.4656, 8.6199])),
 ('3', 'N', tensor([False]), tensor([6.1743, 2.6106, 4.9369])),
 ('3', 'CA', tensor([False]), tensor([6.4143, 4.0386, 4.7629])),
 ('3', 'C', tensor([False]), tensor([5.7693, 4.5536, 3.4749])),
 ('3', 'O', tensor([False]), tensor([5.1643, 3.7456, 2.7359])),
 ('3', 'CB', tensor([False]), tensor([7.9403, 4.0926, 4.6459])),
 ('3', 'CG', tensor([False]), tens

In [None]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f4ab7926700>

Indeed CA is now movable, adn the positions got tweeked a bit.

## Topology

In [None]:
features = featurizer.featurize_one(clean_chain, metal_unknown=True)
topology_data = features.topology

In [None]:
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [None]:
topology_data['frames'].size()

torch.Size([115, 3])

In [None]:
featurizer.get_feature_vocab_sizes()

{'element': 48,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [None]:
features.distances

tensor([[1.4516],
        [2.4188],
        [2.7092],
        ...,
        [5.3327],
        [3.1899],
        [2.7116]])