In [1]:
import sys

sys.path.append('/projects/metalsitenn/pdbx')

from metalsitenn.placer_modules.cifutils import CIFParser, mutate_chain

from metalsitenn.utils import visualize_metal_site_3d, visualize_chain_3d

from metalsitenn.featurizer import MetalSiteFeaturizer
from metalsitenn.utils import visualize_featurized_metal_site_3d
import pandas as pd
import numpy as np
import torch

In [2]:
parser = CIFParser()

In [3]:
parsed_data = parser.parse('/datasets/alphafold_data/data_v2/pdb_mmcif/mmcif_files/6fpw.cif')

In [4]:
chains, assemblies, covalent_bonds, metadata = parsed_data

In [5]:
chains.keys()

dict_keys(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'AA'])

In [6]:
chains['A'].residues

{'1': Residue(name='LEU', atoms={'N': Atom(name='N', xyz=[0.0, 0.0, 0.0], occ=0.0, bfac=0.0, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=1, hyb=3, nhyd=3, hvydeg=1, align=1, hetero=False), 'CA': Atom(name='CA', xyz=[0.0, 0.0, 0.0], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False), 'C': Atom(name='C', xyz=[0.0, 0.0, 0.0], occ=0.0, bfac=0.0, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name='O', xyz=[0.0, 0.0, 0.0], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False), 'CB': Atom(name='CB', xyz=[0.0, 0.0, 0.0], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='CG', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False), 'C

In [7]:
# Extract the metal sites from the parsed structure (one keyword per line for readability).
sites = parser.get_metal_sites(
    parsed_data,
    max_atoms_per_site=500,
    max_water_bfactor=15,
    merge_threshold=6,
    cutoff_distance=6,
    backbone_treatment='free',
)

In [8]:
site = sites[1]

In [9]:
site_chain = site['site_chain']

In [10]:
site_chain.residues

{'1': Residue(name='ILE', atoms={'N': Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False), 'C': Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False), 'CB': Atom(name=('A', '1', 'ILE', 'CB'), xyz=[-26.772, 2.288, 9.312], occ=1.0, bfa

In [11]:
site_chain.atoms

{('A',
  '1',
  'ILE',
  'N'): Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'CA'): Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'C'): Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'O'): Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 ('A',
  '1

In [12]:
site_chain.planars

[[('A', '14', 'ASN', 'CG'),
  ('A', '14', 'ASN', 'CB'),
  ('A', '14', 'ASN', 'OD1'),
  ('A', '14', 'ASN', 'ND2')],
 [('A', '21', 'TRP', 'CG'),
  ('A', '21', 'TRP', 'CB'),
  ('A', '21', 'TRP', 'CD1'),
  ('A', '21', 'TRP', 'CD2')],
 [('A', '21', 'TRP', 'CD2'),
  ('A', '21', 'TRP', 'CG'),
  ('A', '21', 'TRP', 'CE2'),
  ('A', '21', 'TRP', 'CE3')],
 [('A', '21', 'TRP', 'CE2'),
  ('A', '21', 'TRP', 'CD2'),
  ('A', '21', 'TRP', 'NE1'),
  ('A', '21', 'TRP', 'CZ2')],
 [('A', '22', 'PHE', 'CG'),
  ('A', '22', 'PHE', 'CB'),
  ('A', '22', 'PHE', 'CD1'),
  ('A', '22', 'PHE', 'CD2')],
 [('A', '16', 'F3S', 'FE1'),
  ('A', '16', 'F3S', 'S1'),
  ('A', '16', 'F3S', 'S2'),
  ('A', '16', 'F3S', 'S3')],
 [('A', '16', 'F3S', 'FE3'),
  ('A', '16', 'F3S', 'S1'),
  ('A', '16', 'F3S', 'S3'),
  ('A', '16', 'F3S', 'S4')],
 [('A', '16', 'F3S', 'FE4'),
  ('A', '16', 'F3S', 'S2'),
  ('A', '16', 'F3S', 'S3'),
  ('A', '16', 'F3S', 'S4')]]

In [13]:
site_chain.chirals

[[('A', '5', 'CYS', 'CA'),
  ('A', '5', 'CYS', 'N'),
  ('A', '5', 'CYS', 'CB'),
  ('A', '5', 'CYS', 'C')],
 [('A', '19', 'PRO', 'CA'),
  ('A', '19', 'PRO', 'N'),
  ('A', '19', 'PRO', 'CB'),
  ('A', '19', 'PRO', 'C')],
 [('A', '1', 'ILE', 'CA'),
  ('A', '1', 'ILE', 'N'),
  ('A', '1', 'ILE', 'CB'),
  ('A', '1', 'ILE', 'C')],
 [('A', '1', 'ILE', 'CB'),
  ('A', '1', 'ILE', 'CA'),
  ('A', '1', 'ILE', 'CG2'),
  ('A', '1', 'ILE', 'CG1')],
 [('A', '9', 'THR', 'CA'),
  ('A', '9', 'THR', 'N'),
  ('A', '9', 'THR', 'CB'),
  ('A', '9', 'THR', 'C')],
 [('A', '9', 'THR', 'CB'),
  ('A', '9', 'THR', 'CA'),
  ('A', '9', 'THR', 'CG2'),
  ('A', '9', 'THR', 'OG1')],
 [('A', '14', 'ASN', 'CA'),
  ('A', '14', 'ASN', 'N'),
  ('A', '14', 'ASN', 'CB'),
  ('A', '14', 'ASN', 'C')],
 [('A', '13', 'CYS', 'CA'),
  ('A', '13', 'CYS', 'N'),
  ('A', '13', 'CYS', 'CB'),
  ('A', '13', 'CYS', 'C')],
 [('A', '21', 'TRP', 'CA'),
  ('A', '21', 'TRP', 'N'),
  ('A', '21', 'TRP', 'CB'),
  ('A', '21', 'TRP', 'C')],
 [('A', '22',

In [14]:

viewer = visualize_metal_site_3d(site)
viewer.show()

## Now get from dataset with requested filtering

In [2]:
from metalsitenn.dataloading import MetalSiteDataset

In [3]:
ds = MetalSiteDataset(
    cache_folder='../data/1/1.1_parse_sites_metadata',
)

In [5]:
# Display the site metadata table returned by the dataset.
# NOTE(review): the other metadata cells in this notebook call ds.get_metadata();
# confirm that get_filtered_metadata() still exists in the current MetalSiteDataset API.
md = ds.get_filtered_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,n_metal_ligands,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance
0,6fpw,6fpw_0,0,19,158,150,Fe,4,0,1,0,18,4,0,SF4,1,2.9
1,6fpw,6fpw_1,1,22,138,124,Fe,3,3,1,0,18,3,0,F3S,1,2.9
2,6fpw,6fpw_2,2,29,203,184,Fe,4,2,1,0,26,6,0,SF3,1,2.9
3,6fpw,6fpw_3,3,20,141,123,"Fe,Ni",2,0,1,0,19,4,0,EJ2,1,2.9
4,6fpw,6fpw_4,4,19,101,83,Mg,1,6,1,0,12,3,0,MG,1,2.9
5,6fpw,6fpw_5,5,19,158,150,Fe,4,0,1,0,18,4,0,SF4,1,2.9
6,6fpw,6fpw_6,6,23,146,131,Fe,3,3,1,0,19,3,0,F3S,1,2.9
7,6fpw,6fpw_7,7,29,203,184,Fe,4,2,1,0,26,6,0,SF3,1,2.9
8,6fpw,6fpw_8,8,20,141,123,"Fe,Ni",2,0,1,0,19,4,0,EJ2,1,2.9
9,6fpw,6fpw_9,9,19,101,83,Mg,1,6,1,0,12,3,0,MG,1,2.9


In [12]:
# Re-create the dataset with two filters. Going by the parameter names and the
# resulting table (one row per PDB code, n_metals == 1), this presumably keeps only
# single-metal sites and at most one site per PDB entry — TODO confirm against the class.
ds = MetalSiteDataset(
    cache_folder='../data/1/1.1_parse_sites_metadata',
    max_metals=1,
    max_sites_per_pdb=1
)

In [13]:
md = ds.get_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,n_metal_ligands,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance
0,1foi,1foi_0,0,13,85,72,Zn,1,0,1,0,12,4,0,ZN,1,2.9
1,2c2j,2c2j_0,0,11,71,60,Mg,1,0,1,0,10,3,0,MG,1,2.9
2,6bpv,6bpv_0,0,13,106,99,Fe,1,0,2,0,11,4,0,"F2Y,FE2",2,2.9
3,6tgt,6tgt_0,0,11,71,60,Ca,1,0,1,0,10,3,0,CA,1,2.9


In [6]:
ds = MetalSiteDataset(
    cache_folder='../data/1/1.1_parse_sites_metadata',
    valid_metals=['Mn', 'Ni']
)

In [8]:
md = ds.get_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,n_metal_ligands,n_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal
0,6fpw,6fpw_3,3,20,141,120,"Fe,Ni",2,0,1,0,19,0,EJ2,1
1,6fpw,6fpw_8,8,20,141,120,"Fe,Ni",2,0,1,0,19,0,EJ2,1
2,5uuy,5uuy_0,0,20,164,151,"Ca,Mn",2,0,3,0,17,0,"CA,MN,XMM",3
3,4ls3,4ls3_0,0,13,90,83,Ni,1,3,1,0,9,0,NI,1
4,4ls3,4ls3_1,1,13,83,76,Ni,1,4,1,0,8,0,NI,1
5,4my4,4my4_0,0,20,129,113,Mn,2,3,2,0,15,0,MN,1


In [19]:
ds = MetalSiteDataset(
    cache_folder='../data/1/1.1_parse_sites_metadata',
)

In [20]:
md = ds.get_metadata()
md

Unnamed: 0,pdb_code,site_name,site_idx,n_entities,n_atoms,n_bonds,metal,n_metals,n_waters,n_organic_ligands,n_metal_ligands,n_amino_acids,n_coordinating_amino_acids,n_nucleotides,non_residue_non_metal_names,n_non_residue_non_metal,coordination_distance
0,6fpw,6fpw_0,0,19,158,150,Fe,4,0,1,0,18,4,0,SF4,1,2.9
1,6fpw,6fpw_1,1,22,138,124,Fe,3,3,1,0,18,3,0,F3S,1,2.9
2,6fpw,6fpw_2,2,29,203,184,Fe,4,2,1,0,26,6,0,SF3,1,2.9
3,6fpw,6fpw_3,3,20,141,123,"Fe,Ni",2,0,1,0,19,4,0,EJ2,1,2.9
4,6fpw,6fpw_4,4,19,101,83,Mg,1,6,1,0,12,3,0,MG,1,2.9
5,6fpw,6fpw_5,5,19,158,150,Fe,4,0,1,0,18,4,0,SF4,1,2.9
6,6fpw,6fpw_6,6,23,146,131,Fe,3,3,1,0,19,3,0,F3S,1,2.9
7,6fpw,6fpw_7,7,29,203,184,Fe,4,2,1,0,26,6,0,SF3,1,2.9
8,6fpw,6fpw_8,8,20,141,123,"Fe,Ni",2,0,1,0,19,4,0,EJ2,1,2.9
9,6fpw,6fpw_9,9,19,101,83,Mg,1,6,1,0,12,3,0,MG,1,2.9


In [21]:
# Use the second element of the last dataset item as the chain to featurize.
# NOTE(review): the saved outputs further down (138 atoms, an F3S residue) match site
# 6fpw_1, not the last metadata row (6fpw_9, 101 atoms, Mg). The outputs may come from
# a different kernel state — re-run from a fresh kernel to confirm which chain is used.
site_chain = ds[-1][1]

In [22]:
viewer = visualize_chain_3d(site_chain)
viewer.show()

In [28]:
# Build a featurizer limited to four per-atom and three per-bond features, then
# featurize the chain with metal_unknown=False (metal atoms keep their element tokens).
featurizer = MetalSiteFeaturizer(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic',]
)
features = featurizer(site_chain, metal_unknown=False)

In [29]:
atom__features, bond_features, topology_data = features

In [30]:
atom__features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions'])

In [31]:
bond_features.keys()

dict_keys(['bond_order', 'is_in_ring', 'is_aromatic', 'bond_distances', 'bond_loss_mask'])

In [32]:
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [33]:
atom__features['positions']

tensor([[-2.0939e+00,  3.6913e+00,  4.6294e+00],
        [-2.1189e+00,  5.1593e+00,  4.6024e+00],
        [-3.2729e+00,  5.8053e+00,  5.3144e+00],
        [-3.3339e+00,  7.0373e+00,  5.3484e+00],
        [-2.1149e+00,  5.6023e+00,  3.1314e+00],
        [-8.7293e-01, -5.3977e+00, -3.7616e+00],
        [-1.8759e+00, -6.1557e+00, -4.4746e+00],
        [-1.2459e+00, -6.8947e+00, -5.6576e+00],
        [-1.5949e+00, -8.0667e+00, -5.9226e+00],
        [-3.0379e+00, -5.2257e+00, -4.8816e+00],
        [-3.9019e+00, -4.8567e+00, -3.6676e+00],
        [-3.8749e+00, -5.7947e+00, -6.0536e+00],
        [-4.5219e+00, -3.4827e+00, -3.7896e+00],
        [-4.3159e+00, -4.8417e+00,  6.4377e-02],
        [-3.9799e+00, -3.6207e+00,  7.8938e-01],
        [-4.3859e+00, -3.7787e+00,  2.2614e+00],
        [-5.4869e+00, -4.2707e+00,  2.5304e+00],
        [-4.6649e+00, -2.4297e+00,  9.6377e-02],
        [-4.4789e+00, -7.8774e-01,  8.8738e-01],
        [ 5.7041e+00,  8.1026e-01, -6.3326e+00],
        [ 6.7831e+00

### Observe vocab

In [34]:
featurizer.tokenizers['element'].i2d

{0: '<UNK>',
 1: '<MASK>',
 2: 'Ag',
 3: 'Al',
 4: 'Au',
 5: 'Br',
 6: 'C',
 7: 'Ca',
 8: 'Cd',
 9: 'Cl',
 10: 'Co',
 11: 'Cr',
 12: 'Cu',
 13: 'D',
 14: 'Dy',
 15: 'F',
 16: 'Fe',
 17: 'Ga',
 18: 'H',
 19: 'Hg',
 20: 'I',
 21: 'Ir',
 22: 'K',
 23: 'La',
 24: 'Li',
 25: 'Mg',
 26: 'Mn',
 27: 'Mo',
 28: 'N',
 29: 'Na',
 30: 'Nd',
 31: 'Ni',
 32: 'O',
 33: 'P',
 34: 'Pb',
 35: 'Pr',
 36: 'Pt',
 37: 'S',
 38: 'Se',
 39: 'Si',
 40: 'Tb',
 41: 'Te',
 42: 'Ti',
 43: 'U',
 44: 'V',
 45: 'W',
 46: 'Zn',
 47: '<METAL>'}

In [35]:
# Index of the first atom named 'CZ2'; used below to spot-check individual atom features.
atom_id = np.where(np.array(atom__features['atom_name']) == 'CZ2')[0][0]

In [36]:
atom__features['element']

tensor([[28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [32],
        [28],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [16],
        [16],
        [16],
        [37],
        [37],
        [37],
      

In [37]:
atom__features['element'][atom_id]

tensor([6])

In [38]:
featurizer.tokenizers['element'].metal_token_id

47

In [39]:
# Re-featurize with unknown metal: the Fe tokens (16) seen earlier are replaced by the
# <METAL> token (47). Note that this rebinds atom__features and bond_features, and
# introduces `topology` (the earlier unpacking used the name `topology_data`).
atom__features, bond_features, topology = featurizer(site_chain, metal_unknown=True)
atom__features['element']

tensor([[28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [37],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [32],
        [28],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [ 6],
        [28],
        [ 6],
        [ 6],
        [32],
        [ 6],
        [ 6],
        [ 6],
        [47],
        [47],
        [47],
        [37],
        [37],
        [37],
      

In [40]:
# check other metals are properly converted to metal token
featurizer.tokenizers['element'].encode('Ho')

47

### Charge

In [41]:
featurizer.tokenizers['charge'].i2d

{0: '<MASK>', 1: -3, 2: -2, 3: -1, 4: 0, 5: 1, 6: 2, 7: 3}

In [42]:
# Names of atoms whose charge token is not 4. Token 4 decodes to charge 0 in the
# vocabulary printed above, so this lists the non-neutral atoms.
np.array(atom__features['atom_name']).reshape(-1,1)[atom__features['charge'] !=4]

array(['NZ'], dtype='<U3')

### Num hydr

In [29]:
featurizer.tokenizers['nhyd'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [30]:
atom__features['nhyd'][atom_id]

tensor([2])

### hybridization

Note: in the original code the hyb value encodes the hybridization directly (i.e., 1 = sp, 2 = sp2, 3 = sp3, ...).

In [31]:
featurizer.tokenizers['hyb'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5}

In [32]:
atom__features['hyb'][atom_id]

tensor([3])

In [33]:
# Token 3 decodes to hyb = 2 (sp2); CZ2 sits in the benzene ring of the TRP indole, so sp2 is indeed expected.

### Bond order

In [34]:
featurizer.tokenizers['bond_order'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

In [35]:
# Split the bond index array into source and destination atom indices.
# Assumes topology['bonds'] has shape (n_bonds, 2) — TODO confirm.
src, dst = topology['bonds'].T

In [36]:
site_chain.residues

{'1': Residue(name='ILE', atoms={'N': Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False), 'C': Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False), 'CB': Atom(name=('A', '1', 'ILE', 'CB'), xyz=[-26.772, 2.288, 9.312], occ=1.0, bfa

In [37]:
bond_order = bond_features['bond_order']

In [38]:
bond_order.shape

torch.Size([138, 138])

In [39]:
# Decode every bonded pair into (bond-order token, src resid, src atom, dst resid, dst atom).
atom_resids = np.array(atom__features['atom_resid'])
atom_names = np.array(atom__features['atom_name'])
list(zip(bond_order[src, dst], atom_resids[src], atom_names[src], atom_resids[dst], atom_names[dst]))

[(tensor(3), '21', 'CZ2', '21', 'CH2'),
 (tensor(3), '3', 'C', '3', 'O'),
 (tensor(2), '21', 'CG', '21', 'CD2'),
 (tensor(3), '21', 'CD2', '21', 'CE2'),
 (tensor(2), '3', 'CA', '3', 'CB'),
 (tensor(2), '15', 'N', '15', 'CA'),
 (tensor(3), '2', 'C', '2', 'O'),
 (tensor(2), '22', 'N', '22', 'CA'),
 (tensor(2), '19', 'N', '19', 'CA'),
 (tensor(2), '20', 'CG1', '20', 'CD1'),
 (tensor(2), '21', 'NE1', '21', 'CE2'),
 (tensor(2), '14', 'N', '14', 'CA'),
 (tensor(2), '1', 'CA', '1', 'C'),
 (tensor(2), '13', 'N', '13', 'CA'),
 (tensor(2), '21', 'CZ3', '21', 'CH2'),
 (tensor(2), '14', 'CG', '14', 'ND2'),
 (tensor(2), '1', 'N', '1', 'CA'),
 (tensor(2), '15', 'CG', '15', 'CD'),
 (tensor(2), '20', 'N', '20', 'CA'),
 (tensor(3), '22', 'C', '22', 'O'),
 (tensor(2), '20', 'CA', '20', 'CB'),
 (tensor(3), '6', 'C', '6', 'O'),
 (tensor(2), '17', 'CB', '17', 'SG'),
 (tensor(2), '22', 'CA', '22', 'CB'),
 (tensor(2), '14', 'CA', '14', 'C'),
 (tensor(2), '2', 'CA', '2', 'CB'),
 (tensor(2), '14', 'CB', '14', 

### Is aromatic

In [40]:
featurizer.tokenizers['is_aromatic'].i2d

{0: '<MASK>', 1: False, 2: True}

In [41]:
# Count matrix entries carrying the aromatic token (2 decodes to True). Each bond appears
# in both directions in the listing below, so this is twice the number of aromatic bonds.
(bond_features['is_aromatic'] ==2).sum()

tensor(32)

In [42]:
is_aromatic = bond_features['is_aromatic']

In [43]:
is_aromatic.shape

torch.Size([138, 138])

In [44]:
src, dst = np.where(is_aromatic == 2)

In [45]:
site_chain.residues

{'1': Residue(name='ILE', atoms={'N': Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False), 'C': Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False), 'CB': Atom(name=('A', '1', 'ILE', 'CB'), xyz=[-26.772, 2.288, 9.312], occ=1.0, bfa

In [46]:
# Decode every aromatic pair into (bond-order token, src resid, src atom, dst resid, dst atom).
atom_resids = np.array(atom__features['atom_resid'])
atom_names = np.array(atom__features['atom_name'])
list(zip(bond_order[src, dst], atom_resids[src], atom_names[src], atom_resids[dst], atom_names[dst]))

[(tensor(3), '21', 'CG', '21', 'CD1'),
 (tensor(2), '21', 'CG', '21', 'CD2'),
 (tensor(3), '21', 'CD1', '21', 'CG'),
 (tensor(2), '21', 'CD1', '21', 'NE1'),
 (tensor(2), '21', 'CD2', '21', 'CG'),
 (tensor(3), '21', 'CD2', '21', 'CE2'),
 (tensor(2), '21', 'CD2', '21', 'CE3'),
 (tensor(2), '21', 'NE1', '21', 'CD1'),
 (tensor(2), '21', 'NE1', '21', 'CE2'),
 (tensor(3), '21', 'CE2', '21', 'CD2'),
 (tensor(2), '21', 'CE2', '21', 'NE1'),
 (tensor(2), '21', 'CE2', '21', 'CZ2'),
 (tensor(2), '21', 'CE3', '21', 'CD2'),
 (tensor(3), '21', 'CE3', '21', 'CZ3'),
 (tensor(2), '21', 'CZ2', '21', 'CE2'),
 (tensor(3), '21', 'CZ2', '21', 'CH2'),
 (tensor(3), '21', 'CZ3', '21', 'CE3'),
 (tensor(2), '21', 'CZ3', '21', 'CH2'),
 (tensor(3), '21', 'CH2', '21', 'CZ2'),
 (tensor(2), '21', 'CH2', '21', 'CZ3'),
 (tensor(3), '22', 'CG', '22', 'CD1'),
 (tensor(2), '22', 'CG', '22', 'CD2'),
 (tensor(3), '22', 'CD1', '22', 'CG'),
 (tensor(2), '22', 'CD1', '22', 'CE1'),
 (tensor(2), '22', 'CD2', '22', 'CG'),
 (tensor

The tryptophan and phenylalanine in this site have the proper aromatic labels.

### Is in ring

In [47]:
featurizer.tokenizers['is_in_ring'].i2d

{0: '<MASK>', 1: False, 2: True}

In [48]:
src, dst = np.where(bond_features['is_in_ring'] == 2)

In [49]:
# Decode every in-ring pair into (bond-order token, src resid, src atom, dst resid, dst atom).
atom_resids = np.array(atom__features['atom_resid'])
atom_names = np.array(atom__features['atom_name'])
ring_bond_orders = bond_features['bond_order'][src, dst]
list(zip(ring_bond_orders, atom_resids[src], atom_names[src], atom_resids[dst], atom_names[dst]))

[(tensor(2), '15', 'N', '15', 'CA'),
 (tensor(2), '15', 'N', '15', 'CD'),
 (tensor(2), '15', 'CA', '15', 'N'),
 (tensor(2), '15', 'CA', '15', 'CB'),
 (tensor(2), '15', 'CB', '15', 'CA'),
 (tensor(2), '15', 'CB', '15', 'CG'),
 (tensor(2), '15', 'CG', '15', 'CB'),
 (tensor(2), '15', 'CG', '15', 'CD'),
 (tensor(2), '15', 'CD', '15', 'N'),
 (tensor(2), '15', 'CD', '15', 'CG'),
 (tensor(2), '16', 'FE1', '16', 'S1'),
 (tensor(2), '16', 'FE1', '16', 'S2'),
 (tensor(2), '16', 'FE1', '16', 'S3'),
 (tensor(2), '16', 'FE3', '16', 'S1'),
 (tensor(2), '16', 'FE3', '16', 'S3'),
 (tensor(2), '16', 'FE3', '16', 'S4'),
 (tensor(2), '16', 'FE4', '16', 'S2'),
 (tensor(2), '16', 'FE4', '16', 'S3'),
 (tensor(2), '16', 'FE4', '16', 'S4'),
 (tensor(2), '16', 'S1', '16', 'FE1'),
 (tensor(2), '16', 'S1', '16', 'FE3'),
 (tensor(2), '16', 'S2', '16', 'FE1'),
 (tensor(2), '16', 'S2', '16', 'FE4'),
 (tensor(2), '16', 'S3', '16', 'FE1'),
 (tensor(2), '16', 'S3', '16', 'FE3'),
 (tensor(2), '16', 'S3', '16', 'FE4'),


Awesome. The only CA and CB atoms flagged as in a ring are those of proline, and all of the tryptophan ring bonds are marked as in a ring.

### Bond distance

In [50]:
featurizer.tokenizers['bond_distance'].i2d

{0: '<MASK>', 1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}

In [51]:
# Select atom pairs whose bond-distance token is 3 (token 3 decodes to distance 2 in the
# vocabulary printed above).
# NOTE(review): bond_features.keys() printed earlier lists 'bond_distances' (plural);
# confirm that 'bond_distance' is the correct key for the current featurizer.
src, dst = np.where(bond_features['bond_distance'] == 3)

In [52]:
# Decode every selected pair into (bond-order token, src resid, src atom, dst resid, dst atom).
atom_resids = np.array(atom__features['atom_resid'])
atom_names = np.array(atom__features['atom_name'])
pair_bond_orders = bond_features['bond_order'][src, dst]
list(zip(pair_bond_orders, atom_resids[src], atom_names[src], atom_resids[dst], atom_names[dst]))

[(tensor(1), '1', 'N', '1', 'C'),
 (tensor(1), '1', 'N', '1', 'CB'),
 (tensor(1), '1', 'CA', '1', 'O'),
 (tensor(1), '1', 'CA', '1', 'CG1'),
 (tensor(1), '1', 'CA', '1', 'CG2'),
 (tensor(1), '1', 'C', '1', 'N'),
 (tensor(1), '1', 'C', '1', 'CB'),
 (tensor(1), '1', 'O', '1', 'CA'),
 (tensor(1), '1', 'CB', '1', 'N'),
 (tensor(1), '1', 'CB', '1', 'C'),
 (tensor(1), '1', 'CB', '1', 'CD1'),
 (tensor(1), '1', 'CG1', '1', 'CA'),
 (tensor(1), '1', 'CG1', '1', 'CG2'),
 (tensor(1), '1', 'CG2', '1', 'CA'),
 (tensor(1), '1', 'CG2', '1', 'CG1'),
 (tensor(1), '1', 'CD1', '1', 'CB'),
 (tensor(1), '2', 'N', '2', 'C'),
 (tensor(1), '2', 'N', '2', 'CB'),
 (tensor(1), '2', 'CA', '2', 'O'),
 (tensor(1), '2', 'CA', '2', 'CG'),
 (tensor(1), '2', 'C', '2', 'N'),
 (tensor(1), '2', 'C', '2', 'CB'),
 (tensor(1), '2', 'O', '2', 'CA'),
 (tensor(1), '2', 'CB', '2', 'N'),
 (tensor(1), '2', 'CB', '2', 'C'),
 (tensor(1), '2', 'CB', '2', 'CD1'),
 (tensor(1), '2', 'CB', '2', 'CD2'),
 (tensor(1), '2', 'CG', '2', 'CA'),


All of the bond orders are 0 (token 1), and all of the pairs are within the same residue (which makes sense because we freed the backbone atoms). For example, N and C of the same residue are at distance 2, which is correct for the peptide backbone (N–CA–C).

## Try masking atoms

In [53]:
# Apply atom masking to the featurized site and keep the corrupted copies under new names.
# NOTE(review): no RNG seed is set in the visible cells, so the sampled mask presumably
# differs between runs — consider seeding for reproducible outputs.
masked_atom_features, masked_bond_features = featurizer.mask_atoms(
    atom__features, bond_features,
    random_token_prob=0.05, mask_prob=0.15,
)

In [54]:
masked_atom_features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions', 'element_labels'])

In [55]:
# Fraction of atoms whose element token was replaced by the <MASK> token.
(masked_atom_features['element'] == featurizer.tokenizers['element'].mask_token_id).float().mean()

tensor(0.1522)

In [56]:
masked_atom_features['element'][masked_atom_features['element'] != masked_atom_features['element_labels']]

tensor([13,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  9,  1,  1,  1,  8, 34,
         1,  1, 38, 42,  1,  1, 35,  1,  1,  3,  1, 33,  4])

In [57]:
masked_atom_features['atom_loss_mask'].shape

torch.Size([138, 1])

In [58]:
masked_atom_features['element_labels'][masked_atom_features['atom_loss_mask']].shape

torch.Size([33])

In [59]:
# Side-by-side view of (corrupted input token, original label) for atoms in the loss mask.
loss_mask = masked_atom_features['atom_loss_mask']
corrupted_tokens = masked_atom_features['element'][loss_mask].reshape(-1, 1)
original_tokens = masked_atom_features['element_labels'][loss_mask].reshape(-1, 1)
torch.cat([corrupted_tokens, original_tokens], dim=-1)

tensor([[13, 28],
        [ 1,  6],
        [ 1, 32],
        [ 1,  6],
        [ 1,  6],
        [ 1,  6],
        [ 1,  6],
        [ 1, 28],
        [ 1,  6],
        [ 1,  6],
        [ 1, 32],
        [ 1,  6],
        [ 9, 28],
        [ 1, 28],
        [ 1,  6],
        [ 1, 47],
        [ 8, 37],
        [34, 37],
        [ 1, 32],
        [ 1, 28],
        [38, 32],
        [42,  6],
        [ 6,  6],
        [ 1,  6],
        [ 1,  6],
        [35, 32],
        [ 1,  6],
        [28, 28],
        [ 1,  6],
        [ 3,  6],
        [ 1,  6],
        [33, 32],
        [ 4,  6]])

In [60]:
actually_masked = masked_atom_features['element'] == featurizer.tokenizers['element'].mask_token_id

In [61]:
masked_atom_features['charge'][actually_masked]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [62]:
masked_atom_features['nhyd'][actually_masked]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [63]:
masked_atom_features['hyb'][actually_masked]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

How were the bonds masked?

In [64]:
actually_masked_id = torch.where(actually_masked)[0]

In [65]:
masked_bond_features['bond_order'][actually_masked_id]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [66]:
masked_bond_features['bond_order'][actually_masked_id].shape

torch.Size([21, 138])

In [67]:
masked_bond_features['bond_order'].T[actually_masked_id]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [68]:
masked_bond_features['bond_distance']

tensor([[1, 2, 3,  ..., 1, 1, 1],
        [2, 1, 2,  ..., 1, 1, 1],
        [3, 2, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 3, 2],
        [1, 1, 1,  ..., 3, 1, 2],
        [1, 1, 1,  ..., 2, 2, 1]])

In [69]:
masked_bond_features['bond_distance'][actually_masked_id]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

### topology data

In [70]:
chains['A'].chirals

[[('A', '1', 'LEU', 'CA'),
  ('A', '1', 'LEU', 'N'),
  ('A', '1', 'LEU', 'CB'),
  ('A', '1', 'LEU', 'C')],
 [('A', '2', 'GLU', 'CA'),
  ('A', '2', 'GLU', 'N'),
  ('A', '2', 'GLU', 'CB'),
  ('A', '2', 'GLU', 'C')],
 [('A', '3', 'ASN', 'CA'),
  ('A', '3', 'ASN', 'N'),
  ('A', '3', 'ASN', 'CB'),
  ('A', '3', 'ASN', 'C')],
 [('A', '4', 'LYS', 'CA'),
  ('A', '4', 'LYS', 'N'),
  ('A', '4', 'LYS', 'CB'),
  ('A', '4', 'LYS', 'C')],
 [('A', '5', 'PRO', 'CA'),
  ('A', '5', 'PRO', 'N'),
  ('A', '5', 'PRO', 'CB'),
  ('A', '5', 'PRO', 'C')],
 [('A', '6', 'ARG', 'CA'),
  ('A', '6', 'ARG', 'N'),
  ('A', '6', 'ARG', 'CB'),
  ('A', '6', 'ARG', 'C')],
 [('A', '7', 'ILE', 'CA'),
  ('A', '7', 'ILE', 'N'),
  ('A', '7', 'ILE', 'CB'),
  ('A', '7', 'ILE', 'C')],
 [('A', '7', 'ILE', 'CB'),
  ('A', '7', 'ILE', 'CA'),
  ('A', '7', 'ILE', 'CG2'),
  ('A', '7', 'ILE', 'CG1')],
 [('A', '8', 'PRO', 'CA'),
  ('A', '8', 'PRO', 'N'),
  ('A', '8', 'PRO', 'CB'),
  ('A', '8', 'PRO', 'C')],
 [('A', '9', 'VAL', 'CA'),
  ('A'

In [71]:
# Map the chiral index quadruples back to atom names so they can be compared with the
# name tuples stored on the chain.
np.array(atom__features['atom_name'])[topology['chirals']]

array([['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'CG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'OG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CB', 'CA', 'CG2', 'CG1'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C'],
       ['CA', 'N', 'CB', 'C']], dtype='<U3')

# attempt to corrupt residues

In [72]:
# Corrupt residue positions via collapse_residues with the CA atoms held fixed.
corrupted_atom_features, corrupted_bond_features = featurizer.collapse_residues(
    atom__features,
    bond_features,
    collapse_rate=0.2,
    min_residues=1,
    fixed_ca=True,
    center_gaussian_sigma=1,
    collapse_gaussian_sigma=0.2,
)

In [73]:
corrupted_atom_features.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions', 'positions_labels'])

In [74]:
corrupted_atom_features['collapse_mask']

tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [ True],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [ True],
        [ True],
        [ True

In [75]:
# Residue ids of the atoms flagged by collapse_mask (one entry per flagged atom).
np.array(corrupted_atom_features['atom_resid']).reshape(-1,1)[corrupted_atom_features['collapse_mask']]

array(['6', '6', '6', '12', '12', '12', '12', '12', '12', '12', '12',
       '16', '16', '16', '16', '16', '16', '16', '20', '20', '20', '20',
       '20', '20', '20'], dtype='<U2')

In [76]:
list(zip(corrupted_atom_features['atom_resid'], corrupted_atom_features['atom_name'], corrupted_atom_features['collapse_mask'], corrupted_atom_features['positions']))

[('1', 'N', tensor([False]), tensor([-6.8359,  3.6373, -3.7826])),
 ('1', 'CA', tensor([False]), tensor([-6.1249,  2.5203, -4.4016])),
 ('1', 'C', tensor([False]), tensor([-7.0399,  1.7833, -5.3786])),
 ('1', 'O', tensor([False]), tensor([-6.6889,  1.4963, -6.5366])),
 ('1', 'CB', tensor([False]), tensor([-5.6289,  1.5343, -3.3166])),
 ('1', 'CG1', tensor([False]), tensor([-4.5229,  2.1583, -2.4506])),
 ('1', 'CG2', tensor([False]), tensor([-5.1679,  0.2163, -3.9106])),
 ('1', 'CD1', tensor([False]), tensor([-3.2959,  2.5853, -3.1806])),
 ('2', 'N', tensor([False]), tensor([-3.4779, -3.3877,  3.1774])),
 ('2', 'CA', tensor([False]), tensor([-3.7699, -3.3397,  4.5914])),
 ('2', 'C', tensor([False]), tensor([-4.4889, -2.0827,  5.0214])),
 ('2', 'O', tensor([False]), tensor([-4.9939, -2.0267,  6.1634])),
 ('2', 'CB', tensor([False]), tensor([-2.4609, -3.3877,  5.3834])),
 ('2', 'CG', tensor([False]), tensor([-1.4869, -4.5077,  5.0944])),
 ('2', 'CD1', tensor([False]), tensor([-0.2409, -4.

The atoms labeled as collapsed are positioned around the CA.

In [77]:
# Same corruption as before, but with fixed_ca=False so the CA atoms may move as well.
corrupted_atom_features, corrupted_bond_features = featurizer.collapse_residues(
    atom__features,
    bond_features,
    collapse_rate=0.2,
    min_residues=1,
    fixed_ca=False,
    center_gaussian_sigma=1,
    collapse_gaussian_sigma=0.2,
)

In [78]:
# One row per atom after the fixed_ca=False collapse:
# (residue id, atom name, collapse flag, position).
[
    (resid, atom_name, collapsed, xyz)
    for resid, atom_name, collapsed, xyz in zip(
        corrupted_atom_features['atom_resid'],
        corrupted_atom_features['atom_name'],
        corrupted_atom_features['collapse_mask'],
        corrupted_atom_features['positions'],
    )
]

[('1', 'N', tensor([False]), tensor([-6.8359,  3.6373, -3.7826])),
 ('1', 'CA', tensor([False]), tensor([-6.1249,  2.5203, -4.4016])),
 ('1', 'C', tensor([False]), tensor([-7.0399,  1.7833, -5.3786])),
 ('1', 'O', tensor([False]), tensor([-6.6889,  1.4963, -6.5366])),
 ('1', 'CB', tensor([False]), tensor([-5.6289,  1.5343, -3.3166])),
 ('1', 'CG1', tensor([False]), tensor([-4.5229,  2.1583, -2.4506])),
 ('1', 'CG2', tensor([False]), tensor([-5.1679,  0.2163, -3.9106])),
 ('1', 'CD1', tensor([False]), tensor([-3.2959,  2.5853, -3.1806])),
 ('2', 'N', tensor([False]), tensor([-3.4779, -3.3877,  3.1774])),
 ('2', 'CA', tensor([False]), tensor([-3.7699, -3.3397,  4.5914])),
 ('2', 'C', tensor([False]), tensor([-4.4889, -2.0827,  5.0214])),
 ('2', 'O', tensor([False]), tensor([-4.9939, -2.0267,  6.1634])),
 ('2', 'CB', tensor([False]), tensor([-2.4609, -3.3877,  5.3834])),
 ('2', 'CG', tensor([False]), tensor([-1.4869, -4.5077,  5.0944])),
 ('2', 'CD1', tensor([False]), tensor([-0.2409, -4.

In [79]:
# Render the features that were passed as input to collapse_residues.
visualize_featurized_metal_site_3d(
    atom_features_dict=atom__features,
    bond_features_dict=bond_features,
)

<py3Dmol.view at 0x7efe3aa10fd0>

In [80]:
# Render the masked features. mask_color='purple' is forwarded to the
# visualizer -- presumably the color used for masked atoms; confirm in
# metalsitenn.utils.
visualize_featurized_metal_site_3d(
    atom_features_dict=masked_atom_features,
    bond_features_dict=masked_bond_features,
    mask_color='purple',
)

<py3Dmol.view at 0x7efe3aa102b0>

In [86]:
# Render the collapsed ("corrupted") atom and bond features.
visualize_featurized_metal_site_3d(
    atom_features_dict=corrupted_atom_features,
    bond_features_dict=corrupted_bond_features,
)

<py3Dmol.view at 0x7efe3faf5910>

## Try masking and corrupting in one go

In [87]:
# Single featurizer call that enables both masking and residue collapsing.
atom_features, bond_features, topology = featurizer(
    site_chain,
    metal_unknown=False,
    do_masking=True,
    do_collapsing=True,
    random_token_prob=0.05,
    mask_prob=0.15,
    collapse_rate=0.2,
    min_residues=1,
    fixed_ca=True,
    center_gaussian_sigma=0.5,
    collapse_gaussian_sigma=0.2,
)

In [88]:
# Render the output of the combined masking + collapsing featurizer call.
# mask_color='purple' is forwarded to the visualizer -- presumably the color
# used for masked atoms; confirm in metalsitenn.utils.
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features,
    mask_color='purple')

<py3Dmol.view at 0x7efe3f9c02b0>

## Try mutating

In [89]:
# Inspect the residues of the site chain before mutating
# (a dict keyed by residue number string, per the recorded output).
site_chain.residues

{'1': Residue(name='ILE', atoms={'N': Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False), 'CA': Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False), 'C': Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False), 'O': Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False), 'CB': Atom(name=('A', '1', 'ILE', 'CB'), xyz=[-26.772, 2.288, 9.312], occ=1.0, bfa

In [92]:
# Mutate residue '21' from TRP to GLY and bind the result to new_site.
# NOTE(review): whether mutate_chain copies site_chain or edits it in place is
# not visible from this notebook -- confirm in placer_modules.cifutils.
new_site = mutate_chain(
    site_chain, 
    target_res_num='21',
    target_res_name='TRP',
    new_res_name='GLY')

In [93]:
# Atoms of the mutated residue 21. In the recorded output the visible GLY atoms
# all share one xyz and have occ=0.0 / bfac=0.0.
new_site.residues['21'].atoms

{'N': Atom(name=('A', '21', 'GLY', 'N'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=1, hyb=3, nhyd=3, hvydeg=1, align=1, hetero=False),
 'CA': Atom(name=('A', '21', 'GLY', 'CA'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=6, metal=False, charge=0, hyb=3, nhyd=2, hvydeg=2, align=1, hetero=False),
 'C': Atom(name=('A', '21', 'GLY', 'C'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 'O': Atom(name=('A', '21', 'GLY', 'O'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 'OXT': Atom(name=('A', '21', 'GLY', 'OXT'), xyz=[-14.36, 1.558, 7.279], occ=0.0, bfac=0.0, leaving=True, leaving_g

In [94]:
# Chain-level atom dict of the mutated site; keys in the recorded output are
# (chain, residue number, residue name, atom name) tuples.
new_site.atoms

{('A',
  '1',
  'ILE',
  'N'): Atom(name=('A', '1', 'ILE', 'N'), xyz=[-27.979, 4.391, 8.846], occ=1.0, bfac=10.19, leaving=False, leaving_group=['H2'], parent='CA', element=7, metal=False, charge=0, hyb=2, nhyd=1, hvydeg=1, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'CA'): Atom(name=('A', '1', 'ILE', 'CA'), xyz=[-27.268, 3.274, 8.227], occ=1.0, bfac=10.13, leaving=False, leaving_group=[], parent='CB', element=6, metal=False, charge=0, hyb=3, nhyd=1, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'C'): Atom(name=('A', '1', 'ILE', 'C'), xyz=[-28.183, 2.537, 7.25], occ=1.0, bfac=9.8, leaving=False, leaving_group=['HXT', 'OXT'], parent='OXT', element=6, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=3, align=1, hetero=False),
 ('A',
  '1',
  'ILE',
  'O'): Atom(name=('A', '1', 'ILE', 'O'), xyz=[-27.832, 2.25, 6.092], occ=1.0, bfac=10.0, leaving=False, leaving_group=[], parent='C', element=8, metal=False, charge=0, hyb=2, nhyd=0, hvydeg=1, align=1, hetero=False),
 ('A',
  '1

In [None]:
# Featurize the mutated chain with default (no masking / collapsing) options.
# NOTE(review): this cell has no execution count; re-run it in order so the
# visualization below reflects new_site rather than stale atom_features.
atom_features, bond_features, topology = featurizer(new_site, metal_unknown=False)

In [95]:
# Render whatever atom_features / bond_features currently hold (intended: the
# featurized mutated site from the previous cell).
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f0b7a242400>

In [114]:
# Apply the mutation inside the featurizer call itself, then confirm that the
# collapse mask is correct. The tuple ('21', 'TRP', 'GLY') mirrors the values
# passed to mutate_chain above.
atom_features, bond_features, topology = featurizer(
    site_chain,
    mutations=[('21', 'TRP', 'GLY')],
)

In [115]:
# One row per atom: (residue id, atom name, collapse flag, position).
[
    (resid, atom_name, collapsed, xyz)
    for resid, atom_name, collapsed, xyz in zip(
        atom_features['atom_resid'],
        atom_features['atom_name'],
        atom_features['collapse_mask'],
        atom_features['positions'],
    )
]

[('1', 'N', tensor([False]), tensor([-6.4367,  3.6996, -3.8951])),
 ('1', 'CA', tensor([False]), tensor([-5.7257,  2.5826, -4.5141])),
 ('1', 'C', tensor([False]), tensor([-6.6407,  1.8456, -5.4911])),
 ('1', 'O', tensor([False]), tensor([-6.2897,  1.5586, -6.6491])),
 ('1', 'CB', tensor([False]), tensor([-5.2297,  1.5966, -3.4291])),
 ('1', 'CG1', tensor([False]), tensor([-4.1237,  2.2206, -2.5631])),
 ('1', 'CG2', tensor([False]), tensor([-4.7687,  0.2786, -4.0231])),
 ('1', 'CD1', tensor([False]), tensor([-2.8967,  2.6476, -3.2931])),
 ('2', 'N', tensor([False]), tensor([-3.0787, -3.3254,  3.0649])),
 ('2', 'CA', tensor([False]), tensor([-3.3707, -3.2774,  4.4789])),
 ('2', 'C', tensor([False]), tensor([-4.0897, -2.0204,  4.9089])),
 ('2', 'O', tensor([False]), tensor([-4.5947, -1.9644,  6.0509])),
 ('2', 'CB', tensor([False]), tensor([-2.0617, -3.3254,  5.2709])),
 ('2', 'CG', tensor([False]), tensor([-1.0877, -4.4454,  4.9819])),
 ('2', 'CD1', tensor([False]), tensor([ 0.1583, -4.

In [116]:
# Render the features produced by the featurizer call that applied the
# mutation via its `mutations` argument.
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7efe425ce040>

In [123]:
# Check that, when the mutated residue is collapsed, the collapse uses the new
# (post-mutation) collapse mask. The residue to collapse is named by its mutated
# identity ('21', 'GLY'), and fixed_ca=False is passed so CA is not held fixed.
atom_features, bond_features, topology = featurizer(
    site_chain, mutations=[('21', 'TRP', 'GLY')],
    do_collapsing=True,
    resids_to_collapse=[('21', 'GLY')],
    fixed_ca=False,
    center_gaussian_sigma=1.5)

In [124]:
# One row per atom after mutating and collapsing residue 21:
# (residue id, atom name, collapse flag, position).
[
    (resid, atom_name, collapsed, xyz)
    for resid, atom_name, collapsed, xyz in zip(
        atom_features['atom_resid'],
        atom_features['atom_name'],
        atom_features['collapse_mask'],
        atom_features['positions'],
    )
]

[('1', 'N', tensor([False]), tensor([-6.4367,  3.6996, -3.8951])),
 ('1', 'CA', tensor([False]), tensor([-5.7257,  2.5826, -4.5141])),
 ('1', 'C', tensor([False]), tensor([-6.6407,  1.8456, -5.4911])),
 ('1', 'O', tensor([False]), tensor([-6.2897,  1.5586, -6.6491])),
 ('1', 'CB', tensor([False]), tensor([-5.2297,  1.5966, -3.4291])),
 ('1', 'CG1', tensor([False]), tensor([-4.1237,  2.2206, -2.5631])),
 ('1', 'CG2', tensor([False]), tensor([-4.7687,  0.2786, -4.0231])),
 ('1', 'CD1', tensor([False]), tensor([-2.8967,  2.6476, -3.2931])),
 ('2', 'N', tensor([False]), tensor([-3.0787, -3.3254,  3.0649])),
 ('2', 'CA', tensor([False]), tensor([-3.3707, -3.2774,  4.4789])),
 ('2', 'C', tensor([False]), tensor([-4.0897, -2.0204,  4.9089])),
 ('2', 'O', tensor([False]), tensor([-4.5947, -1.9644,  6.0509])),
 ('2', 'CB', tensor([False]), tensor([-2.0617, -3.3254,  5.2709])),
 ('2', 'CG', tensor([False]), tensor([-1.0877, -4.4454,  4.9819])),
 ('2', 'CD1', tensor([False]), tensor([ 0.1583, -4.

In [125]:
# Render the site after the in-call mutation and the targeted collapse of
# residue 21.
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f0b7a1291f0>

Indeed, CA is now movable, and the positions got tweaked a bit.

### Topology

In [None]:
# List the entries of the topology dict.
# NOTE(review): the featurizer calls above bind their third return value to
# `topology`; confirm `topology_data` is defined by an earlier cell.
topology_data.keys()

dict_keys(['bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [None]:
# Size of the 'frames' tensor (recorded output: 115 x 3).
topology_data['frames'].size()

torch.Size([115, 3])

In [None]:
# Shape of the positions tensor (recorded output: 138 x 3), shown next to the
# frames size above for comparison.
atom__features['positions'].shape

torch.Size([138, 3])

In [None]:
# Vocabulary size reported by the featurizer for each categorical atom / bond
# feature.
featurizer.get_feature_vocab_sizes()

{'element': 48,
 'charge': 8,
 'nhyd': 6,
 'hyb': 7,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3,
 'bond_distance': 9}