# Collate some site graphs into features

In [77]:
import random

import numpy as np
import torch

from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteFeaturizer
from metalsitenn.placer_modules.cifutils import mutate_chain
from metalsitenn.utils import visualize_metal_site_3d, visualize_chain_3d,visualize_featurized_metal_site_3d

In [3]:
# Folder holding the cached parsed sites (path suggests the output of step 1.1).
SITES_CACHE_FOLDER = '../data/1/1.1_parse_sites_metadata'
ds = MetalSiteDataset(cache_folder=SITES_CACHE_FOLDER)

In [4]:
# Per-site metadata table: one row per site (columns shown in the output below).
md = ds.get_all_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,n_metal_ligands,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance
0,6fpw,6fpw_0,0,19,158,150,Fe,4,0,0,1,18,4,0,,0,2.9
1,6fpw,6fpw_1,1,22,138,124,Fe,3,3,0,1,18,3,0,,0,2.9
2,6fpw,6fpw_2,2,29,203,184,Fe,4,2,0,1,26,6,0,,0,2.9
3,6fpw,6fpw_3,3,20,141,123,"Fe,Ni",2,0,0,1,19,4,0,,0,2.9
4,6fpw,6fpw_4,4,19,101,83,Mg,1,6,0,1,12,3,0,,0,2.9
5,6fpw,6fpw_5,5,19,158,150,Fe,4,0,0,1,18,4,0,,0,2.9
6,6fpw,6fpw_6,6,23,146,131,Fe,3,3,0,1,19,3,0,,0,2.9
7,6fpw,6fpw_7,7,29,203,184,Fe,4,2,0,1,26,6,0,,0,2.9
8,6fpw,6fpw_8,8,20,141,123,"Fe,Ni",2,0,0,1,19,4,0,,0,2.9
9,6fpw,6fpw_9,9,19,101,83,Mg,1,6,0,1,12,3,0,,0,2.9


In [10]:
# NOTE(review): ds[i] appears to return a sequence; element [1] is the chain object that is
# visualized and featurized below. TODO: confirm what element [0] holds.
site_chain = ds[1][1]

In [11]:
visualize_chain_3d(site_chain)

<py3Dmol.view at 0x7f4ab7a4e730>

In [12]:
# Per-atom and per-bond categorical features to tokenize.
ATOM_FEATURE_NAMES = ['element', 'charge', 'nhyd', 'hyb']
BOND_FEATURE_NAMES = ['bond_order', 'is_in_ring', 'is_aromatic']

featurizer = MetalSiteFeaturizer(
    atom_features=ATOM_FEATURE_NAMES,
    bond_features=BOND_FEATURE_NAMES,
)
# Compare with the metal_unknown=True call further down.
features = featurizer(site_chain, metal_unknown=False)

In [13]:
# The featurizer returns three dicts: atom-level features, bond-level features and topology data
# (their keys are listed in the next cells).
atom__features, bond_features, topology_data = features

In [14]:
atom__features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions'])

In [15]:
bond_features.keys()

dict_keys(['bond_order', 'is_in_ring', 'is_aromatic', 'bond_distances', 'bond_loss_mask'])

In [16]:
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [17]:
atom__features['positions']

tensor([[ 1.2071e+00,  6.6543e+00,  4.9434e+00],
        [ 1.6911e+00,  5.3683e+00,  4.4754e+00],
        [ 2.7281e+00,  4.8383e+00,  5.4324e+00],
        [ 3.4071e+00,  5.6043e+00,  6.1254e+00],
        [ 2.8591e+00,  3.5163e+00,  5.4614e+00],
        [ 3.7661e+00,  2.8143e+00,  6.3714e+00],
        [ 4.7371e+00,  1.9333e+00,  5.6124e+00],
        [ 4.5881e+00,  7.1526e-01,  5.5414e+00],
        [ 2.9891e+00,  2.0353e+00,  7.4044e+00],
        [ 4.0641e+00,  1.4033e+00,  8.7324e+00],
        [ 5.7751e+00,  2.5483e+00,  5.0494e+00],
        [ 6.0151e+00,  3.9763e+00,  4.8754e+00],
        [ 5.3701e+00,  4.4913e+00,  3.5874e+00],
        [ 4.7651e+00,  3.6833e+00,  2.8484e+00],
        [ 7.5411e+00,  4.0303e+00,  4.7584e+00],
        [ 7.8921e+00,  2.7673e+00,  4.0054e+00],
        [ 6.8321e+00,  1.7613e+00,  4.3834e+00],
        [-6.8359e+00,  3.6373e+00, -3.7826e+00],
        [-6.1249e+00,  2.5203e+00, -4.4016e+00],
        [-7.0399e+00,  1.7833e+00, -5.3786e+00],
        [-6.6889e+00

### Observe vocab

In [18]:
featurizer.tokenizers['element'].i2d

{0: '<UNK>',
 1: '<MASK>',
 2: 'Ag',
 3: 'Al',
 4: 'Au',
 5: 'Br',
 6: 'C',
 7: 'Ca',
 8: 'Cd',
 9: 'Cl',
 10: 'Co',
 11: 'Cr',
 12: 'Cu',
 13: 'D',
 14: 'Dy',
 15: 'F',
 16: 'Fe',
 17: 'Ga',
 18: 'H',
 19: 'Hg',
 20: 'I',
 21: 'Ir',
 22: 'K',
 23: 'La',
 24: 'Li',
 25: 'Mg',
 26: 'Mn',
 27: 'Mo',
 28: 'N',
 29: 'Na',
 30: 'Nd',
 31: 'Ni',
 32: 'O',
 33: 'P',
 34: 'Pb',
 35: 'Pr',
 36: 'Pt',
 37: 'S',
 38: 'Se',
 39: 'Si',
 40: 'Tb',
 41: 'Te',
 42: 'Ti',
 43: 'U',
 44: 'V',
 45: 'W',
 46: 'Zn',
 47: '<METAL>'}

In [20]:
# Index of the first atom named 'CZ2' (a ring carbon), used below to spot-check per-atom features.
# Fail with a clear message if this site has no CZ2 atom, instead of a bare IndexError.
cz2_indices = np.flatnonzero(np.asarray(atom__features['atom_name']) == 'CZ2')
assert cz2_indices.size > 0, "no atom named 'CZ2' in this site"
atom_id = int(cz2_indices[0])

In [21]:
atom__features['element']

tensor([[28],
        [ 6],
        [ 6],
        [32],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [32],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [32],
        [28],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
      

In [22]:
atom__features['element'][atom_id]

tensor([6])

In [23]:
featurizer.tokenizers['element'].metal_token_id

47

In [24]:
# with unknown metal
atom__features, bond_features, topology = featurizer(site_chain, metal_unknown=True)
atom__features['element']

tensor([[28],
        [ 6],
        [ 6],
        [32],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [32],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [32],
        [28],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
      

In [25]:
# check other metals are properly converted to metal token
# 'Ho' is not in the element vocabulary above, so it should map to the <METAL> token rather than <UNK>.
# Assert it instead of relying on eyeballing the printed id.
element_tokenizer = featurizer.tokenizers['element']
ho_token = element_tokenizer.encode('Ho')
assert ho_token == element_tokenizer.metal_token_id, ho_token
ho_token

47

### Charge

In [26]:
featurizer.tokenizers['charge'].i2d

{0: '<MASK>', 1: -3, 2: -2, 3: -1, 4: 0, 5: 1, 6: 2, 7: 3}

In [27]:
# Names of atoms with a non-zero formal charge. Look up the token for charge 0 in the vocabulary
# instead of hard-coding its id (4 in the i2d shown above), so this survives vocabulary changes.
charge_vocab = featurizer.tokenizers['charge'].i2d
neutral_charge_token = next(tok for tok, value in charge_vocab.items() if value == 0)
# Convert the torch mask to a flat numpy bool array before indexing the numpy name array.
is_charged = np.asarray(atom__features['charge'] != neutral_charge_token).ravel()
np.array(atom__features['atom_name'])[is_charged]

array(['NZ'], dtype='<U3')

### Number of hydrogens (nhyd)

In [28]:
featurizer.tokenizers['nhyd'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [29]:
atom__features['nhyd'][atom_id]

tensor([2])

### Hybridization

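Note: in the original code, `hyb` is encoded as an integer (i.e., 1 = sp, 2 = sp2, 3 = sp3, ...). The tokenizer below maps these values to shifted token ids (e.g., token 4 ↔ hyb 2).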

In [30]:
featurizer.tokenizers['hyb'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5}

In [31]:
atom__features['hyb'][atom_id]

tensor([4])

In [32]:
# CZ2 in benzene ring is indeed sp2: its hyb token is 4, which the i2d above decodes to hyb=2 (sp2).

### Bond order

In [33]:
featurizer.tokenizers['bond_order'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4}

In [34]:
# topology['bonds'] has two columns; transpose and unpack them into per-bond atom indices.
src, dst = topology['bonds'].T

In [35]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [36]:
bond_order = bond_features['bond_order']

In [37]:
bond_order.shape

torch.Size([138, 138])

In [38]:
atom_resids = np.array(atom__features['atom_resid'])
atom_names = np.array(atom__features['atom_name'])
# One tuple per bond: (bond-order token, src resid, src atom name, dst resid, dst atom name).
list(zip(bond_order[src, dst], atom_resids[src], atom_names[src], atom_resids[dst], atom_names[dst]))

[(tensor(4), '8', 'C', '8', 'O'),
 (tensor(3), '7', 'N', '7', 'CA'),
 (tensor(3), '8', 'CA', '8', 'C'),
 (tensor(3), '16', 'CA', '16', 'CB'),
 (tensor(3), '5', 'CB', '5', 'OG1'),
 (tensor(3), '10', 'CA', '10', 'CB'),
 (tensor(3), '13', 'N', '13', 'CA'),
 (tensor(3), '5', 'N', '5', 'CA'),
 (tensor(3), '14', 'CA', '14', 'C'),
 (tensor(3), '13', 'CB', '13', 'SG'),
 (tensor(3), '14', 'CG', '14', 'CD1'),
 (tensor(4), '9', 'CD2', '9', 'CE2'),
 (tensor(3), '10', 'CA', '10', 'C'),
 (tensor(3), '3', 'CG', '3', 'CD'),
 (tensor(3), '17', 'CA', '17', 'CB'),
 (tensor(3), '17', 'CA', '17', 'C'),
 (tensor(3), '8', 'N', '8', 'CA'),
 (tensor(4), '1', 'C', '1', 'O'),
 (tensor(3), '9', 'CA', '9', 'C'),
 (tensor(3), '2', 'CB', '2', 'SG'),
 (tensor(3), '9', 'CD1', '9', 'CE1'),
 (tensor(3), '4', 'CA', '4', 'CB'),
 (tensor(3), '9', 'CA', '9', 'CB'),
 (tensor(4), '14', 'C', '14', 'O'),
 (tensor(3), '8', 'CD1', '8', 'NE1'),
 (tensor(3), '11', 'CB', '11', 'CG1'),
 (tensor(4), '16', 'C', '16', 'O'),
 (tensor(4),

### Is aromatic

In [39]:
featurizer.tokenizers['is_aromatic'].i2d

{0: '<MASK>', 1: False, 2: True}

In [40]:
# Count aromatic entries in the bond matrix. Each bond appears as both (i, j) and (j, i)
# (see the listing further down), so this is twice the number of aromatic bonds.
# Look up the token for True instead of hard-coding 2.
aromatic_true_token = next(
    tok for tok, value in featurizer.tokenizers['is_aromatic'].i2d.items()
    if isinstance(value, (bool, np.bool_)) and bool(value)
)
(bond_features['is_aromatic'] == aromatic_true_token).sum()

tensor(32)

In [41]:
is_aromatic = bond_features['is_aromatic']

In [42]:
is_aromatic.shape

torch.Size([138, 138])

In [43]:
# Token 2 decodes to True in the is_aromatic vocabulary above; these are the row/column indices of aromatic entries.
src, dst = np.where(is_aromatic == 2)

In [44]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [45]:
list(zip(bond_order[src, dst], np.array(atom__features['atom_resid'])[src], np.array(atom__features['atom_name'])[src], np.array(atom__features['atom_resid'])[dst], np.array(atom__features['atom_name'])[dst]))

[(tensor(4), '8', 'CG', '8', 'CD1'),
 (tensor(3), '8', 'CG', '8', 'CD2'),
 (tensor(4), '8', 'CD1', '8', 'CG'),
 (tensor(3), '8', 'CD1', '8', 'NE1'),
 (tensor(3), '8', 'CD2', '8', 'CG'),
 (tensor(4), '8', 'CD2', '8', 'CE2'),
 (tensor(3), '8', 'CD2', '8', 'CE3'),
 (tensor(3), '8', 'NE1', '8', 'CD1'),
 (tensor(3), '8', 'NE1', '8', 'CE2'),
 (tensor(4), '8', 'CE2', '8', 'CD2'),
 (tensor(3), '8', 'CE2', '8', 'NE1'),
 (tensor(3), '8', 'CE2', '8', 'CZ2'),
 (tensor(3), '8', 'CE3', '8', 'CD2'),
 (tensor(4), '8', 'CE3', '8', 'CZ3'),
 (tensor(3), '8', 'CZ2', '8', 'CE2'),
 (tensor(4), '8', 'CZ2', '8', 'CH2'),
 (tensor(4), '8', 'CZ3', '8', 'CE3'),
 (tensor(3), '8', 'CZ3', '8', 'CH2'),
 (tensor(4), '8', 'CH2', '8', 'CZ2'),
 (tensor(3), '8', 'CH2', '8', 'CZ3'),
 (tensor(4), '9', 'CG', '9', 'CD1'),
 (tensor(3), '9', 'CG', '9', 'CD2'),
 (tensor(4), '9', 'CD1', '9', 'CG'),
 (tensor(3), '9', 'CD1', '9', 'CE1'),
 (tensor(3), '9', 'CD2', '9', 'CG'),
 (tensor(4), '9', 'CD2', '9', 'CE2'),
 (tensor(3), '9', 'C

The tryptophan and phenylalanine residues here have the proper aromatic labels.

### Is in ring

In [46]:
featurizer.tokenizers['is_in_ring'].i2d

{0: '<MASK>', 1: False, 2: True}

In [47]:
# Token 2 decodes to True in the is_in_ring vocabulary above; these are the row/column indices of in-ring bond entries.
src, dst = np.where(bond_features['is_in_ring'] == 2)

In [48]:
list(zip(bond_features['bond_order'][src, dst], np.array(atom__features['atom_resid'])[src], np.array(atom__features['atom_name'])[src], np.array(atom__features['atom_resid'])[dst], np.array(atom__features['atom_name'])[dst]))

[(tensor(3), '3', 'N', '3', 'CA'),
 (tensor(3), '3', 'N', '3', 'CD'),
 (tensor(3), '3', 'CA', '3', 'N'),
 (tensor(3), '3', 'CA', '3', 'CB'),
 (tensor(3), '3', 'CB', '3', 'CA'),
 (tensor(3), '3', 'CB', '3', 'CG'),
 (tensor(3), '3', 'CG', '3', 'CB'),
 (tensor(3), '3', 'CG', '3', 'CD'),
 (tensor(3), '3', 'CD', '3', 'N'),
 (tensor(3), '3', 'CD', '3', 'CG'),
 (tensor(4), '8', 'CG', '8', 'CD1'),
 (tensor(3), '8', 'CG', '8', 'CD2'),
 (tensor(4), '8', 'CD1', '8', 'CG'),
 (tensor(3), '8', 'CD1', '8', 'NE1'),
 (tensor(3), '8', 'CD2', '8', 'CG'),
 (tensor(4), '8', 'CD2', '8', 'CE2'),
 (tensor(3), '8', 'CD2', '8', 'CE3'),
 (tensor(3), '8', 'NE1', '8', 'CD1'),
 (tensor(3), '8', 'NE1', '8', 'CE2'),
 (tensor(4), '8', 'CE2', '8', 'CD2'),
 (tensor(3), '8', 'CE2', '8', 'NE1'),
 (tensor(3), '8', 'CE2', '8', 'CZ2'),
 (tensor(3), '8', 'CE3', '8', 'CD2'),
 (tensor(4), '8', 'CE3', '8', 'CZ3'),
 (tensor(3), '8', 'CZ2', '8', 'CE2'),
 (tensor(4), '8', 'CZ2', '8', 'CH2'),
 (tensor(4), '8', 'CZ3', '8', 'CE3'),
 (

Awesome. The only non-aromatic ring bonds I see (including ones involving CA and CB) belong to proline (residue 3), and all of the tryptophan ring bonds are flagged as in-ring.

## Try masking atoms

In [50]:
# Masking (and the residue collapsing below) is stochastic. Seed the global RNGs so that
# Restart & Run All reproduces the same masks. It is not visible here which RNG the featurizer
# draws from, so seed Python's, numpy's and torch's.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

masked_atom_features, masked_bond_features = featurizer.mask_atoms(
    atom__features, bond_features,
    random_token_prob=0.05, mask_prob=0.15,
)

In [51]:
masked_atom_features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions', 'element_labels'])

In [52]:
(masked_atom_features['element'] == featurizer.tokenizers['element'].mask_token_id).float().mean()

tensor(0.1232)

In [53]:
masked_atom_features['element'][masked_atom_features['element'] != masked_atom_features['element_labels']]

tensor([ 1,  1,  1,  1, 38,  1,  1, 22,  1, 25, 10,  1,  1, 20,  1,  1,  1,  1,
         1,  1,  1, 28,  1])

In [54]:
masked_atom_features['atom_loss_mask'].shape

torch.Size([138, 1])

In [55]:
masked_atom_features['element_labels'][masked_atom_features['atom_loss_mask']].shape

torch.Size([26])

In [58]:
featurizer.tokenizers['element'].i2d

{0: '<UNK>',
 1: '<MASK>',
 2: 'Ag',
 3: 'Al',
 4: 'Au',
 5: 'Br',
 6: 'C',
 7: 'Ca',
 8: 'Cd',
 9: 'Cl',
 10: 'Co',
 11: 'Cr',
 12: 'Cu',
 13: 'D',
 14: 'Dy',
 15: 'F',
 16: 'Fe',
 17: 'Ga',
 18: 'H',
 19: 'Hg',
 20: 'I',
 21: 'Ir',
 22: 'K',
 23: 'La',
 24: 'Li',
 25: 'Mg',
 26: 'Mn',
 27: 'Mo',
 28: 'N',
 29: 'Na',
 30: 'Nd',
 31: 'Ni',
 32: 'O',
 33: 'P',
 34: 'Pb',
 35: 'Pr',
 36: 'Pt',
 37: 'S',
 38: 'Se',
 39: 'Si',
 40: 'Tb',
 41: 'Te',
 42: 'Ti',
 43: 'U',
 44: 'V',
 45: 'W',
 46: 'Zn',
 47: '<METAL>'}

In [57]:
loss_mask = masked_atom_features['atom_loss_mask']
# For atoms in the loss mask: column 0 is the model input token, column 1 is the true element label.
torch.stack(
    [masked_atom_features['element'][loss_mask], masked_atom_features['element_labels'][loss_mask]],
    dim=-1,
)

tensor([[32, 32],
        [ 1, 32],
        [ 1, 28],
        [ 1,  6],
        [ 1,  6],
        [38,  6],
        [ 1,  6],
        [ 1, 32],
        [ 6,  6],
        [22, 28],
        [ 1,  6],
        [25,  6],
        [10, 32],
        [ 1,  6],
        [ 1,  6],
        [20,  6],
        [ 1,  6],
        [ 1,  6],
        [ 1,  6],
        [ 1, 28],
        [ 1,  6],
        [ 1,  6],
        [ 6,  6],
        [ 1, 32],
        [28,  6],
        [ 1, 37]])

In [59]:
# Atoms whose element input is the <MASK> token. Randomly replaced tokens are in the loss mask
# but are not <MASK>, so they are excluded here.
actually_masked = masked_atom_features['element'] == featurizer.tokenizers['element'].mask_token_id

In [60]:
masked_atom_features['charge'][actually_masked]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [61]:
masked_atom_features['nhyd'][actually_masked]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [63]:
featurizer.tokenizers['hyb'].i2d

{0: '<UNK>', 1: '<MASK>', 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5}

In [62]:
masked_atom_features['hyb'][actually_masked]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

How were the bonds masked?

In [64]:
# actually_masked has a trailing singleton dimension (element is a column tensor), so torch.where
# returns (rows, cols); keep the row indices, i.e. the masked atom indices.
actually_masked_id = torch.where(actually_masked)[0]

In [65]:
masked_bond_features['bond_order'][actually_masked_id]

tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])

In [66]:
masked_bond_features['bond_order'][actually_masked_id].shape

torch.Size([17, 138])

In [67]:
masked_bond_features['bond_order'].T[actually_masked_id]

tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])

### topology data

In [69]:
np.array(atom__features['atom_name'])[topology['chirals']]

array([['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'CG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'OG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'CG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C']], dtype='<U3')

## Attempt to corrupt (collapse) residues

In [70]:
# Collapse randomly chosen residues. With fixed_ca=True the CA of a collapsed residue should stay
# at its original position (checked in the listing below).
# NOTE(review): the exact meaning of collapse_rate / min_residues / *_gaussian_sigma is defined in
# MetalSiteFeaturizer.collapse_residues; confirm there.
corrupted_atom_features, corrupted_bond_features = featurizer.collapse_residues(
    atom__features, bond_features, 
    collapse_rate = 0.2,
    min_residues = 1,
    fixed_ca=True,
    center_gaussian_sigma=1,
    collapse_gaussian_sigma=0.2
)

In [71]:
corrupted_atom_features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions', 'positions_labels'])

In [72]:
corrupted_atom_features['collapse_mask']

tensor([[ True],
        [False],
        [ True],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [ True],
        [ True],
        [ True],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False

In [73]:
np.array(corrupted_atom_features['atom_resid']).reshape(-1,1)[corrupted_atom_features['collapse_mask']]

array(['1', '1', '1', '7', '7', '7', '7', '7', '15', '15', '15', '18',
       '18', '18', '18', '18', '18', '18', '18'], dtype='<U2')

In [74]:
# One tuple per atom: (resid, atom name, collapse flag, position).
list(zip(*(corrupted_atom_features[key] for key in ('atom_resid', 'atom_name', 'collapse_mask', 'positions'))))

[('1', 'N', tensor([True]), tensor([1.4516, 5.4261, 4.4644])),
 ('1', 'CA', tensor([False]), tensor([1.6911, 5.3683, 4.4754])),
 ('1', 'C', tensor([True]), tensor([1.9691, 4.9829, 4.5932])),
 ('1', 'O', tensor([True]), tensor([1.8015, 5.0461, 4.4342])),
 ('2', 'N', tensor([False]), tensor([2.8591, 3.5163, 5.4614])),
 ('2', 'CA', tensor([False]), tensor([3.7661, 2.8143, 6.3714])),
 ('2', 'C', tensor([False]), tensor([4.7371, 1.9333, 5.6124])),
 ('2', 'O', tensor([False]), tensor([4.5881, 0.7153, 5.5414])),
 ('2', 'CB', tensor([False]), tensor([2.9891, 2.0353, 7.4044])),
 ('2', 'SG', tensor([False]), tensor([4.0641, 1.4033, 8.7324])),
 ('3', 'N', tensor([False]), tensor([5.7751, 2.5483, 5.0494])),
 ('3', 'CA', tensor([False]), tensor([6.0151, 3.9763, 4.8754])),
 ('3', 'C', tensor([False]), tensor([5.3701, 4.4913, 3.5874])),
 ('3', 'O', tensor([False]), tensor([4.7651, 3.6833, 2.8484])),
 ('3', 'CB', tensor([False]), tensor([7.5411, 4.0303, 4.7584])),
 ('3', 'CG', tensor([False]), tensor(

The atoms labeled as collapsed are placed around the CA, which itself stays fixed because `fixed_ca=True`.

In [75]:
# try with letting CA move
corrupted_atom_features, corrupted_bond_features = featurizer.collapse_residues(
    atom__features, bond_features, 
    collapse_rate = 0.2,
    min_residues = 1,
    fixed_ca=False,
    center_gaussian_sigma=1,
    collapse_gaussian_sigma=0.2
)

In [76]:
list(zip(corrupted_atom_features['atom_resid'], corrupted_atom_features['atom_name'], corrupted_atom_features['collapse_mask'], corrupted_atom_features['positions']))

[('1', 'N', tensor([False]), tensor([1.2071, 6.6543, 4.9434])),
 ('1', 'CA', tensor([False]), tensor([1.6911, 5.3683, 4.4754])),
 ('1', 'C', tensor([False]), tensor([2.7281, 4.8383, 5.4324])),
 ('1', 'O', tensor([False]), tensor([3.4071, 5.6043, 6.1254])),
 ('2', 'N', tensor([False]), tensor([2.8591, 3.5163, 5.4614])),
 ('2', 'CA', tensor([False]), tensor([3.7661, 2.8143, 6.3714])),
 ('2', 'C', tensor([False]), tensor([4.7371, 1.9333, 5.6124])),
 ('2', 'O', tensor([False]), tensor([4.5881, 0.7153, 5.5414])),
 ('2', 'CB', tensor([False]), tensor([2.9891, 2.0353, 7.4044])),
 ('2', 'SG', tensor([False]), tensor([4.0641, 1.4033, 8.7324])),
 ('3', 'N', tensor([False]), tensor([5.7751, 2.5483, 5.0494])),
 ('3', 'CA', tensor([False]), tensor([6.0151, 3.9763, 4.8754])),
 ('3', 'C', tensor([False]), tensor([5.3701, 4.4913, 3.5874])),
 ('3', 'O', tensor([False]), tensor([4.7651, 3.6833, 2.8484])),
 ('3', 'CB', tensor([False]), tensor([7.5411, 4.0303, 4.7584])),
 ('3', 'CG', tensor([False]), tens

### Visualize masking and corruption

In [80]:
# untouched
visualize_featurized_metal_site_3d(
    atom_features_dict=atom__features,
    bond_features_dict=bond_features)
    

<py3Dmol.view at 0x7f4ab6bb70d0>

In [81]:
# atom masked 
visualize_featurized_metal_site_3d(
    atom_features_dict=masked_atom_features,
    bond_features_dict=masked_bond_features,
    mask_color='purple',
)

<py3Dmol.view at 0x7f4ab7b12880>

In [82]:
# residues collapsed
visualize_featurized_metal_site_3d(
    atom_features_dict=corrupted_atom_features,
    bond_features_dict=corrupted_bond_features,
    
)

<py3Dmol.view at 0x7f4ab80c0520>

## Try masking and corrupting in one go

In [83]:
# Mask and collapse in a single featurizer call.
# NOTE: this rebinds bond_features and topology (previously the unmasked features used by the
# earlier cells). Re-running an earlier cell after this one will silently use these masked/collapsed
# features. atom_features (single underscore) is distinct from atom__features.
atom_features, bond_features, topology = featurizer(
    site_chain, metal_unknown=False,
    do_masking=True, do_collapsing=True,
    random_token_prob=0.05, mask_prob=0.15,
    collapse_rate=0.2, min_residues=1,
    fixed_ca=True, center_gaussian_sigma=.5,
    collapse_gaussian_sigma=0.2
)

In [84]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features,
    mask_color='purple')

<py3Dmol.view at 0x7f4ab7926370>

## Try mutating

In [85]:
site_chain.residues

{'1': Residue(name='GLY', atoms={'N': Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C': Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False)}, bonds=[Bond(a=('A', '1', 'GLY', 'N'), b=('A', '1', 'GLY', 'CA'), aromatic=Fal

In [87]:
# Mutate residue 8 from TRP to GLY, producing a new chain. site_chain is reused further down with
# TRP still at residue 8, so it is presumably left untouched. The output below shows the new residue's
# atoms all starting at a single position with occ=0.0.
new_site = mutate_chain(
    site_chain,
    target_res_num='8',
    target_res_name='TRP',
    new_res_name='GLY')

In [88]:
new_site.residues['8'].atoms

{'N': Atom(name=('A', '8', 'GLY', 'N'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=1, hyb=3, nhyd=3, hvydeg=1, align=1, hetero=False),
 'CA': Atom(name=('A', '8', 'GLY', 'CA'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False),
 'C': Atom(name=('A', '8', 'GLY', 'C'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 'O': Atom(name=('A', '8', 'GLY', 'O'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 'OXT': Atom(name=('A', '8', 'GLY', 'OXT'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=True, leaving_group=

In [89]:
new_site.atoms

{('A',
  '1',
  'GLY',
  'N'): Atom(name=('A', '1', 'GLY', 'N'), xyz=[-19.936, 7.408, 17.572], occ=1.0, bfac=8.54, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'CA'): Atom(name=('A', '1', 'GLY', 'CA'), xyz=[-19.452, 6.122, 17.104], occ=1.0, bfac=9.08, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'C'): Atom(name=('A', '1', 'GLY', 'C'), xyz=[-18.415, 5.592, 18.061], occ=1.0, bfac=7.63, leaving=False, leaving_group=['OXT', 'HXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'GLY',
  'O'): Atom(name=('A', '1', 'GLY', 'O'), xyz=[-17.736, 6.358, 18.754], occ=1.0, bfac=8.77, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 ('A',


In [90]:
atom_features, bond_features, topology = featurizer(new_site, metal_unknown=False)

In [91]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f572c4040d0>

In [92]:
# mutate in the featurizer call - confirm that the collapse mask is correct
# Each mutation is (residue number, current residue name, new residue name), matching the
# target_res_num / target_res_name / new_res_name arguments of the mutate_chain call above.
atom_features, bond_features, topology = featurizer(
    site_chain, mutations=[('8', 'TRP', 'GLY')])

In [93]:
list(zip(atom_features['atom_resid'], atom_features['atom_name'], atom_features['collapse_mask'], atom_features['positions']))

[('1', 'N', tensor([False]), tensor([1.6063, 6.7166, 4.8309])),
 ('1', 'CA', tensor([False]), tensor([2.0903, 5.4306, 4.3629])),
 ('1', 'C', tensor([False]), tensor([3.1273, 4.9006, 5.3199])),
 ('1', 'O', tensor([False]), tensor([3.8063, 5.6666, 6.0129])),
 ('2', 'N', tensor([False]), tensor([3.2583, 3.5786, 5.3489])),
 ('2', 'CA', tensor([False]), tensor([4.1653, 2.8766, 6.2589])),
 ('2', 'C', tensor([False]), tensor([5.1363, 1.9956, 5.4999])),
 ('2', 'O', tensor([False]), tensor([4.9873, 0.7776, 5.4289])),
 ('2', 'CB', tensor([False]), tensor([3.3883, 2.0976, 7.2919])),
 ('2', 'SG', tensor([False]), tensor([4.4633, 1.4656, 8.6199])),
 ('3', 'N', tensor([False]), tensor([6.1743, 2.6106, 4.9369])),
 ('3', 'CA', tensor([False]), tensor([6.4143, 4.0386, 4.7629])),
 ('3', 'C', tensor([False]), tensor([5.7693, 4.5536, 3.4749])),
 ('3', 'O', tensor([False]), tensor([5.1643, 3.7456, 2.7359])),
 ('3', 'CB', tensor([False]), tensor([7.9403, 4.0926, 4.6459])),
 ('3', 'CG', tensor([False]), tens

In [94]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f57e9b2bc10>

In [97]:
# check if the mutated residue gets collapsed that is uses the new collapse mask
atom_features, bond_features, topology = featurizer(
    site_chain, mutations=[('8', 'TRP', 'GLY')],
    do_collapsing=True,
    resids_to_collapse=[('8', 'GLY')],
    fixed_ca=False,
    center_gaussian_sigma=1.5)

In [98]:
list(zip(atom_features['atom_resid'], atom_features['atom_name'], atom_features['collapse_mask'], atom_features['positions']))

[('1', 'N', tensor([False]), tensor([1.6063, 6.7166, 4.8309])),
 ('1', 'CA', tensor([False]), tensor([2.0903, 5.4306, 4.3629])),
 ('1', 'C', tensor([False]), tensor([3.1273, 4.9006, 5.3199])),
 ('1', 'O', tensor([False]), tensor([3.8063, 5.6666, 6.0129])),
 ('2', 'N', tensor([False]), tensor([3.2583, 3.5786, 5.3489])),
 ('2', 'CA', tensor([False]), tensor([4.1653, 2.8766, 6.2589])),
 ('2', 'C', tensor([False]), tensor([5.1363, 1.9956, 5.4999])),
 ('2', 'O', tensor([False]), tensor([4.9873, 0.7776, 5.4289])),
 ('2', 'CB', tensor([False]), tensor([3.3883, 2.0976, 7.2919])),
 ('2', 'SG', tensor([False]), tensor([4.4633, 1.4656, 8.6199])),
 ('3', 'N', tensor([False]), tensor([6.1743, 2.6106, 4.9369])),
 ('3', 'CA', tensor([False]), tensor([6.4143, 4.0386, 4.7629])),
 ('3', 'C', tensor([False]), tensor([5.7693, 4.5536, 3.4749])),
 ('3', 'O', tensor([False]), tensor([5.1643, 3.7456, 2.7359])),
 ('3', 'CB', tensor([False]), tensor([7.9403, 4.0926, 4.6459])),
 ('3', 'CG', tensor([False]), tens

In [99]:
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f4ab7926700>

Indeed, CA is now movable, and the positions got tweaked a bit.

### Topology

In [100]:
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [101]:
topology_data['frames'].shape

torch.Size([115, 3])

In [102]:
atom__features['positions'].shape

torch.Size([138, 3])

In [103]:
featurizer.get_feature_vocab_sizes()

{'element': 48,
 'charge': 8,
 'nhyd': 6,
 'hyb': 8,
 'bond_order': 7,
 'is_in_ring': 3,
 'is_aromatic': 3}