# Load a protein structure

In [1]:
import torch

from abflow.structure import extract_pdb_structure
                    
# Extract backbone coordinates
N_coords, CA_coords, C_coords, CB_coords, res_type, cdr_mask = extract_pdb_structure("./data/7eow.pdb")

print(N_coords.shape, CA_coords.shape, C_coords.shape, CB_coords.shape, res_type.shape)

# Make sure the coordinates are in the correct shape (1, N, 3) for a single batch
N_coords = N_coords.unsqueeze(0)
CA_coords = CA_coords.unsqueeze(0)
C_coords = C_coords.unsqueeze(0)
CB_coords = CB_coords.unsqueeze(0)
res_type = res_type.unsqueeze(0)

[INFO] Chain A does not contain valid Fv region for scheme chothia.
torch.Size([330, 3]) torch.Size([330, 3]) torch.Size([330, 3]) torch.Size([330, 3]) torch.Size([330])


In [2]:
cdr_mask['LCDR2'].shape

torch.Size([330])

# Frames to coords and back

In [None]:
from abflow.structure import bb_coords_to_frames, bb_frames_to_coords, impute_CB_coords, write_to_pdb
from abflow.constants import aa3_name_to_index

# Calculate frames from coordinates
rotations, translations = bb_coords_to_frames(N_coords, CA_coords, C_coords)

# Convert frames back to coordinates
N_coords_reconstructed, CA_coords_reconstructed, C_coords_reconstructed = bb_frames_to_coords(rotations, translations)

# Impute CB coordinates
CB_coords_reconstructed = impute_CB_coords(N_coords_reconstructed, CA_coords_reconstructed, C_coords_reconstructed)

# For glycine, set CB_coords to CA_coords
gly_mask = res_type == aa3_name_to_index["GLY"]
CB_coords_reconstructed[gly_mask] = CA_coords_reconstructed[gly_mask]

data = {
    "N_coords": N_coords_reconstructed,
    "CA_coords": CA_coords_reconstructed,
    "C_coords": C_coords_reconstructed,
    "CB_coords": CB_coords_reconstructed,
    "res_type": res_type,
    "valid_mask": torch.ones_like(res_type, dtype=torch.bool)
}

# Write the reconstructed structure to a PDB file
write_to_pdb(data, "./data/7eow_reconstructed.pdb")

# Structure violation

In [5]:
import torch
from abflow.model.metrics import get_violation, get_bb_bond_angle_loss, get_bb_bond_length_loss, get_bb_clash_loss

num_res = 5
noise = 0.005
# add some noise to the coordinates
pred_N_coords = N_coords[:, :num_res] * (1 + torch.randn_like(N_coords[:, :num_res]) * noise)
pred_CA_coords = CA_coords[:, :num_res] * (1 + torch.randn_like(CA_coords[:, :num_res]) * noise)
pred_C_coords = C_coords[:, :num_res] * (1 + torch.randn_like(C_coords[:, :num_res]) * noise)
masks = [torch.ones(pred_N_coords.shape[:2])]

print(get_bb_clash_loss(pred_N_coords, pred_CA_coords, pred_C_coords, masks))
print(get_bb_bond_length_loss(pred_N_coords, pred_CA_coords, pred_C_coords, masks))
print(get_bb_bond_angle_loss(pred_N_coords, pred_CA_coords, pred_C_coords, masks))
print(get_violation(pred_N_coords, pred_CA_coords, pred_C_coords, masks))

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0239]])
tensor([[0., 0., 0., 0., 1.]])
(tensor([0.0048]), tensor([0.2000]))
tensor([[0.0264, 1.2024, 1.0924, 0.1371, 0.2910]])
tensor([[1., 1., 1., 1., 1.]])
(tensor([0.5499]), tensor([1.]))
tensor([[0.0000, 0.1133, 0.2795, 0.3422, 0.0000]])
tensor([[0., 1., 1., 1., 0.]])
(tensor([0.1470]), tensor([0.6000]))
tensor([[0.0264, 1.2024, 1.0924, 0.1371, 0.2910]])
tensor([[1., 1., 1., 1., 1.]])
tensor([[0.0000, 0.1133, 0.2795, 0.3422, 0.0000]])
tensor([[0., 1., 1., 1., 0.]])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0239]])
tensor([[0., 0., 0., 0., 1.]])
tensor([1.])


In [3]:
(1 + torch.randn_like(N_coords[:, :num_res]) * noise)

tensor([[[1.0028, 1.0051, 0.9985],
         [1.0111, 0.9914, 0.9933],
         [0.9980, 0.9893, 0.9853],
         [0.9930, 1.0068, 1.0230],
         [0.9871, 0.9849, 0.9919]]])

# TM score

In [13]:
import torch
from abflow.model.metrics import get_tm_score

num_res = 3
true_N_coords = N_coords[:, :num_res]
true_CA_coords = CA_coords[:, :num_res]
true_C_coords = C_coords[:, :num_res]

# add a bit of noise to the true structure
pred_N_coords = true_N_coords + torch.randn_like(true_N_coords) * 0.1
pred_CA_coords = true_CA_coords + torch.randn_like(true_CA_coords) * 0.1
pred_C_coords = true_C_coords + torch.randn_like(true_C_coords) * 0.1

masks = [torch.ones(pred_N_coords.shape[:2])]

pred_coords = [pred_N_coords]
true_coords = [true_N_coords]
print(get_tm_score(pred_coords, true_coords, masks))


tensor([0.5109])
