In [1]:
%load_ext autoreload
%autoreload 2

In [9]:
from openfold.utils.loss import lddt_loss, AlphaFoldLoss
import openfold
import glob
import torch
from pathlib import Path

from plaid.proteins import LatentToStructure, LatentToSequence
from plaid.esmfold.misc import batch_encode_sequences
from plaid.transforms import trim_or_pad_batch_first

device = torch.device("cuda")
print(openfold.__file__)

shard_dir = "/homefs/home/lux70/storage/data/cath/shards/"
pdb_dir = "/data/bucket/lux70/data/cath/dompdb"

def pretty_print(p):
    pprint.pprint(list(p.keys()))
    print("\n _______ \n")

/homefs/home/lux70/openfold/openfold/__init__.py


# Feature Parser Development 

In [3]:

pdb_path = glob.glob(f"{pdb_dir}/*")[0]
print(pdb_path)

with open(pdb_path, "r") as f:
    pdb_str = f.read()
# print(pdb_str)

/data/bucket/lux70/data/cath/dompdb/4a9aC01


In [4]:
from plaid.openfold_utils import protein_from_pdb_string, make_pdb_features


protein_object = protein_from_pdb_string(pdb_str)
pdb_id = Path(pdb_path).stem

# TODO: what is the `is_distillation` argument?
protein_features = make_pdb_features(
    protein_object, description=pdb_id, is_distillation=False
)

In [5]:
import pprint

pdb_feat_keys = list(protein_features.keys())
pdb_feat_keys.sort()
pprint.pprint(pdb_feat_keys)

['aatype',
 'all_atom_mask',
 'all_atom_positions',
 'between_segment_residues',
 'domain_name',
 'is_distillation',
 'residue_index',
 'resolution',
 'seq_length',
 'sequence']


In [6]:
sequence = protein_features['sequence'][0].decode()
sequence

'LEKQPKITLEEFIETERGKLDKSKLTPITIANFAQWKKDHVIAKINAEKKLSSKRKPTGREIILKMSAE'

In [7]:
from plaid.openfold_utils import get_chi_atom_indices

chi_idxs = get_chi_atom_indices()
print(chi_idxs)

[[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 23], [5, 11, 23, 32]], [[0, 1, 3, 5], [1, 3, 5, 16], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 16], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 10], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 14], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 6], [1, 3, 6, 12], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 12], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 19], [5, 11, 19, 35]], [[0, 1, 3, 5], [1, 3, 5, 18], [3, 5, 18, 19], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 12], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 5], [1, 3, 5, 11], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 8], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 1, 3, 9], [0, 0, 0, 0], [0, 0, 0, 0],

In [8]:
pdb_feats = {}
import numpy as np

for k, v in protein_features.items():
    if not v.dtype == np.object_:
        pdb_feats[k] = torch.from_numpy(v) #.cuda()

In [9]:
from plaid.openfold_utils._data_transforms import *



In [10]:
p = make_all_atom_aatype(pdb_feats)
pretty_print(p)

p = make_seq_mask(p)
pretty_print(p)

p = make_atom14_masks(p)
pretty_print(p)

p = make_atom14_positions(p)
pretty_print(p)

p = atom37_to_frames(p)
pretty_print(p)

p = get_backbone_frames(p)
pretty_print(p)

# f = make_pseudo_beta("")
# p = f(p)
# pretty_print(p)

# f = atom37_to_torsion_angles("")
# p = f(p)
# pretty_print(p)

# p = get_chi_angles(p)
# pretty_print(p)

['aatype',
 'between_segment_residues',
 'residue_index',
 'seq_length',
 'all_atom_positions',
 'all_atom_mask',
 'resolution',
 'is_distillation',
 'all_atom_aatype']

 _______ 

['aatype',
 'between_segment_residues',
 'residue_index',
 'seq_length',
 'all_atom_positions',
 'all_atom_mask',
 'resolution',
 'is_distillation',
 'all_atom_aatype',
 'seq_mask']

 _______ 

['aatype',
 'between_segment_residues',
 'residue_index',
 'seq_length',
 'all_atom_positions',
 'all_atom_mask',
 'resolution',
 'is_distillation',
 'all_atom_aatype',
 'seq_mask',
 'atom14_atom_exists',
 'residx_atom14_to_atom37',
 'residx_atom37_to_atom14',
 'atom37_atom_exists']

 _______ 

['aatype',
 'between_segment_residues',
 'residue_index',
 'seq_length',
 'all_atom_positions',
 'all_atom_mask',
 'resolution',
 'is_distillation',
 'all_atom_aatype',
 'seq_mask',
 'atom14_atom_exists',
 'residx_atom14_to_atom37',
 'residx_atom37_to_atom14',
 'atom37_atom_exists',
 'atom14_gt_exists',
 'atom14_gt_positions',


In [11]:
for k, v in p.items():
    try:
        print(k, v.shape)
    except:
        print(k)
        pass

aatype torch.Size([69, 21])
between_segment_residues torch.Size([69])
residue_index torch.Size([69])
seq_length torch.Size([69])
all_atom_positions torch.Size([69, 37, 3])
all_atom_mask torch.Size([69, 37])
resolution torch.Size([1])
is_distillation torch.Size([])
all_atom_aatype torch.Size([69, 21])
seq_mask torch.Size([69, 21])
atom14_atom_exists torch.Size([69, 21, 14])
residx_atom14_to_atom37 torch.Size([69, 21, 14])
residx_atom37_to_atom14 torch.Size([69, 21, 37])
atom37_atom_exists torch.Size([69, 21, 37])
atom14_gt_exists torch.Size([69, 21, 14])
atom14_gt_positions torch.Size([69, 21, 14, 3])
atom14_alt_gt_positions torch.Size([69, 21, 14, 3])
atom14_alt_gt_exists torch.Size([69, 21, 14])
atom14_atom_is_ambiguous torch.Size([69, 21, 14])
rigidgroups_gt_frames torch.Size([69, 21, 8, 4, 4])
rigidgroups_gt_exists torch.Size([69, 21, 8])
rigidgroups_group_exists torch.Size([69, 21, 8])
rigidgroups_group_is_ambiguous torch.Size([69, 21, 8])
rigidgroups_alt_gt_frames torch.Size([69, 

In [12]:
# f = make_pseudo_beta("")
# p = f(p)
# pretty_print(p)

# # f = atom37_to_torsion_angles("")
# # p = f(p)
# # pretty_print(p)

# # p = get_chi_angles(p)
# # pretty_print(p)

# Load Features

In [5]:
# get saved embedding
from plaid.datasets import CATHStructureDataModule

shard_dir = "/homefs/home/lux70/storage/data/cath/shards/"
pdb_dir = "/data/bucket/lux70/data/cath/dompdb"
# shard_dir = "/homefs/home/lux70/storage/data/rocklin/shards/"
# pdb_dir = "/data/bucket/lux70/data/rocklin/structures/"

NUM_SAMPLES = 4
max_seq_len=64

dm = CATHStructureDataModule(
    shard_dir,
    pdb_dir,
    seq_len=max_seq_len,
    batch_size=NUM_SAMPLES,
    max_num_samples=NUM_SAMPLES,
    shuffle_val_dataset=False
) 
    
dm.setup()
val_dataloader = dm.val_dataloader()
next_batch = next(iter(val_dataloader))
print(len(val_dataloader.dataset))

4


In [6]:
x, sequence, batch = next_batch

for k, v in batch.items():
    try:
        batch[k] = v.to(device)
    except:
        pass


print(batch.keys())
print(batch['aatype'].shape)
print(batch['atom14_gt_positions'].shape)
print(batch['atom14_alt_gt_positions'].shape)
print(batch['backbone_rigid_tensor'].shape)

dict_keys(['aatype', 'between_segment_residues', 'domain_name', 'residue_index', 'seq_length', 'sequence', 'all_atom_positions', 'all_atom_mask', 'resolution', 'is_distillation', 'mask', 'all_atom_aatype', 'seq_mask', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'atom14_gt_exists', 'atom14_gt_positions', 'atom14_alt_gt_positions', 'atom14_alt_gt_exists', 'atom14_atom_is_ambiguous', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'rigidgroups_group_exists', 'rigidgroups_group_is_ambiguous', 'rigidgroups_alt_gt_frames', 'backbone_rigid_tensor', 'backbone_rigid_mask'])
torch.Size([4, 64])
torch.Size([4, 64, 14, 3])
torch.Size([4, 64, 14, 3])
torch.Size([4, 64, 4, 4])


In [7]:
# get corresponding mask, to be used later
_, mask, _, _, _ = batch_encode_sequences(sequence)

mask = trim_or_pad_batch_first(mask, pad_to=max_seq_len, pad_idx=0).to(device)

# implicitly calls Openfold structure module
latent_to_structure = LatentToStructure()
latent_to_structure.to(device)
struct = latent_to_structure.to_structure(
    x, sequence, return_raw_features=True, batch_size=NUM_SAMPLES, num_recycles=1
)

loading esmfold model...
ESMFold model created in 45.73 seconds.


(Generating structure): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.61s/it]


In [11]:
import pprint
out = struct[-1][0]
pretty_print(out)

['frames',
 'sidechain_frames',
 'unnormalized_angles',
 'angles',
 'positions',
 'states',
 's_initial',
 's_after_layernorm',
 'z_after_layernorm',
 'single',
 'sm_s',
 'sm_z',
 's_s',
 's_z',
 'distogram_logits',
 'lm_logits',
 'aatype',
 'atom14_atom_exists',
 'residx_atom14_to_atom37',
 'residx_atom37_to_atom14',
 'atom37_atom_exists',
 'residue_index',
 'lddt_head',
 'plddt',
 'ptm_logits',
 'ptm',
 'aligned_confidence_probs',
 'predicted_aligned_error',
 'max_predicted_aligned_error']

 _______ 



In [12]:
print(out['positions'].shape)
print(out['positions'].device)
print(batch["atom14_alt_gt_positions"].shape)
print(batch["atom14_alt_gt_positions"].device)

torch.Size([8, 4, 64, 14, 3])
cuda:0
torch.Size([4, 64, 14, 3])
cuda:0


In [13]:
# save some CUDA memory:
del latent_to_structure

### Renaming

In [14]:
# from openfold.utils.loss import compute_renamed_ground_truth
from plaid.openfold_utils._losses import compute_renamed_ground_truth
outd  = compute_renamed_ground_truth(batch, out['positions'])
print(outd.keys())

batch = batch | outd

dict_keys(['alt_naming_is_better', 'renamed_atom14_gt_positions', 'renamed_atom14_gt_exists'])


## Get loss config

In [15]:
from openfold.config import model_config
config = model_config(name="initial_training", train=True)
print(type(config))
print(config.keys())
print(config.loss.keys())
print(config.loss.fape.keys())

<class 'ml_collections.config_dict.config_dict.ConfigDict'>
['data', 'ema', 'globals', 'loss', 'model', 'relax']
['distogram', 'eps', 'experimentally_resolved', 'fape', 'masked_msa', 'plddt_loss', 'supervised_chi', 'tm', 'violation']
['backbone', 'eps', 'sidechain', 'weight']


# Try Losses

## Backbone Loss

In [16]:
from openfold.utils.loss import backbone_loss
traj = out['frames']
bb_loss = backbone_loss(traj=traj, **{**batch, **config.loss.fape.backbone})
print(bb_loss)

tensor(0.9094, device='cuda:0')


## Sidechain Loss

In [18]:
from plaid.openfold_utils._losses import sidechain_loss

with torch.no_grad():
    sc_loss = sidechain_loss(
        out["sidechain_frames"],
        out["positions"],
        **{**batch, **config.loss.fape.sidechain},
    )

## Distogram Loss

In [19]:
from openfold.utils.loss import distogram_loss
distogram_loss(out['distogram_logits'], **{**batch, **config.loss.distogram})

TypeError: distogram_loss() missing 2 required positional arguments: 'pseudo_beta' and 'pseudo_beta_mask'

## Experimentally resolved loss

In [20]:
from openfold.utils.loss import experimentally_resolved_loss
experimentally_resolved_loss(
    logits=out['experimentally_resolved_logits'],
    **{**batch, **config.loss.experimentally_resolved}
)

KeyError: 'experimentally_resolved_logits'

## Combined FAPE loss

In [21]:
from openfold.utils.loss import fape_loss
wrapper_out = {"sm": out}
fape_loss(wrapper_out, batch, config.loss.fape)

RuntimeError: The size of tensor a (896) must match the size of tensor b (7168) at non-singleton dimension 2

## LDDT loss

In [22]:
from openfold.utils.loss import lddt_loss
lddt_loss(
    logits=out['plddt'].to(device), #TODO: these are not logits
    all_atom_pred_pos=out['positions'].to(device),  #TODO: confim if these are the same
    **{**batch, **config.loss.plddt_loss}
)
out.keys()

RuntimeError: The size of tensor a (50) must match the size of tensor b (37) at non-singleton dimension 3

# Supervised chi loss

In [23]:
from openfold.utils.loss import supervised_chi_loss
supervised_chi_loss(
    out["angles"], out["unnormalized_angles"], **{**batch, **config.loss.supervised_chi}
)

TypeError: supervised_chi_loss() missing 2 required positional arguments: 'chi_mask' and 'chi_angles_sin_cos'

## Violation Loss

In [24]:
from openfold.utils.loss import violation_loss
violation_loss(
    out['violation'], **{**batch, **config.loss.violation}
)

KeyError: 'violation'

## TM loss

In [25]:
from openfold.utils.loss import tm_loss
tm_loss(
    logits=out['ptm_logits'], 
    **{**batch, **out, **config.loss.tm}
)

TypeError: tm_loss() missing 1 required positional argument: 'final_affine_tensor'

In [26]:
from openfold.utils.loss import chain_center_of_mass_loss
chain_center_of_mass_loss(
    all_atom_pred_pos=out['positions'],
    **{**batch, **config.loss.chain_center_of_mass}
)

ImportError: cannot import name 'chain_center_of_mass_loss' from 'openfold.utils.loss' (/homefs/home/lux70/openfold/openfold/utils/loss.py)