In [2]:
%env GEOMSTATS_BACKEND=pytorch

env: GEOMSTATS_BACKEND=pytorch


In [3]:
import sys

sys.path.append("../")

import numpy as np
import os
import pickle
import wandb

from foldflow.data import utils as du
from foldflow.data import residue_constants
from tools.analysis.utils import write_prot_to_pdb

wandb.init(project="foldflow", entity=None, name="vis-2f60")

[34m[1mwandb[0m: Currently logged in as: [33mstanislav-chekmenev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
print("Number of possible atom types:", len(residue_constants.atom_types))
print(residue_constants.atom_order)

Number of possible atom types: 37
{'N': 0, 'CA': 1, 'C': 2, 'CB': 3, 'O': 4, 'CG': 5, 'CG1': 6, 'CG2': 7, 'OG': 8, 'OG1': 9, 'SG': 10, 'CD': 11, 'CD1': 12, 'CD2': 13, 'ND1': 14, 'ND2': 15, 'OD1': 16, 'OD2': 17, 'SD': 18, 'CE': 19, 'CE1': 20, 'CE2': 21, 'CE3': 22, 'NE': 23, 'NE1': 24, 'NE2': 25, 'OE1': 26, 'OE2': 27, 'CH2': 28, 'NH1': 29, 'NH2': 30, 'OH': 31, 'CZ': 32, 'CZ2': 33, 'CZ3': 34, 'NZ': 35, 'OXT': 36}


In [5]:
print("Short restype order:", residue_constants.restype_order_with_x)
print("Full restype order:", residue_constants.resname_to_idx)

Short restype order: {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20}
Full restype order: {'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4, 'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9, 'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14, 'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19, 'UNK': 20}


In [6]:
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'

# Let's plot that for arginine:
rigid_groups = {"backbone": 0, "pre-omega": 1, "phi": 2, "psi": 3, "chi1": 4, "chi2": 5, "chi3": 6, "chi4": 7}

for rigid_group_name, rigid_group_idx in rigid_groups.items():
    group_data = residue_constants.restype_rigid_group_default_frame[1][rigid_group_idx]
    print(f"Rigid group {rigid_group_name}: \n{group_data}\n")

Rigid group backbone: 
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]

Rigid group pre-omega: 
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]

Rigid group phi: 
[[-0.35907093  0.9333103   0.         -0.524     ]
 [ 0.9333103   0.35907093  0.          1.362     ]
 [-0.          0.         -1.         -0.        ]
 [ 0.          0.          0.          1.        ]]

Rigid group psi: 
[[ 1.     0.    -0.     1.525]
 [-0.    -1.    -0.    -0.   ]
 [-0.     0.    -1.    -0.   ]
 [ 0.     0.     0.     1.   ]]

Rigid group chi1: 
[[-0.3424368  -0.51215166  0.7876787  -0.524     ]
 [-0.50842714  0.80601937  0.30304232 -0.778     ]
 [-0.79008794 -0.29670438 -0.53640246 -1.209     ]
 [ 0.          0.          0.          1.        ]]

Rigid group chi2: 
[[ 0.4051618 -0.914245   0.         0.616    ]
 [ 0.914245   0.4051618  0.         1.39     ]
 [-0.         0.         1.        -0.       ]
 [ 0.         0.         0.         1.       ]]

Rigid group chi3: 
[[ 0.370

In [7]:
with open("../data/2f60.pkl", "rb") as f:
    data = pickle.load(f)

  data = pickle.load(f)


In [10]:
print("Atom positions of the first residue of the protein:")
data["atom_positions"][0]

Atom positions of the first residue of the protein:


array([[-5.1976315 , -5.42995258, 15.23373318],
       [-3.87963166, -6.11195178, 14.99373341],
       [-2.66863122, -5.24695201, 15.38573265],
       [-3.70063128, -6.53995319, 13.53573322],
       [-1.70463147, -5.77595325, 16.00773335],
       [-3.84163155, -8.03495212, 13.30173302],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-3.52163138, -8.43095203, 11.87273312],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        ],
       [-0.        ,  0.        ,  0.   

In [7]:
for data_name, data_value in data.items():
    print(data_name)
    print(data_value.shape)
    print("\n")

atom_positions
(125, 37, 3)


aatype
(125,)


atom_mask
(125, 37)


residue_index
(125,)


chain_index
(125,)


b_factors
(125, 37)


bb_mask
(125,)


bb_positions
(125, 3)


modeled_idx
(60,)




In [8]:
zipped = zip(data["aatype"], data["chain_index"], data["bb_mask"], data["atom_mask"], data["residue_index"])

for aa, chain_idx, bb_mask, a_mask, res_idx in zipped:
    print(
        f"Amino acid: {aa}, Chain index: {chain_idx}, Has backbone coords: {bool(bb_mask)}, "
        f"Has any side-chain atoms coords: {a_mask.any()}, Residue index: {res_idx}"
    )

Amino acid: 6, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 121
Amino acid: 8, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 122
Amino acid: 9, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 123
Amino acid: 14, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 124
Amino acid: 7, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 125
Amino acid: 16, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 126
Amino acid: 10, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 127
Amino acid: 1, Chain index: 36, Has backbone coords: True, Has any side-chain atoms coords: True, Residue index: 128
Amino acid: 13, Chain index: 36, Has backbone coords: True, H

In [9]:
data["bb_mask"].sum()

np.float64(60.0)

In [10]:
data.keys()

dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors', 'bb_mask', 'bb_positions', 'modeled_idx'])

In [11]:
bb_mask = data["bb_mask"].astype(bool)
a_mask = data["atom_mask"].astype(bool)

In [12]:
debug_dir = "../data/debug"


if not os.path.exists(debug_dir):
    os.makedirs(debug_dir, exist_ok=True)


debug_2f60_pdb_path = os.path.join(debug_dir, "2f60.pdb")

saved_path = write_prot_to_pdb(data["atom_positions"], debug_2f60_pdb_path, no_indexing=True, b_factors=None)

In [20]:
saved_path

'../data/debug/2f60.pdb'

In [18]:
eval_path = "../eval_outputs/2f60/default/step_2000/len_60_sample_0_flowed_1.00.pdb"

In [19]:
wandb.log({"2f60": wandb.Molecule(open(saved_path))})
wandb.log({"Predicted 2f60": wandb.Molecule(open(eval_path))})

In [None]:
wandb.finish()