# In [None]:
import pickle as pkl
import os
from context_generator.utils.protein import parsers
from context_generator.modules.common import all_atom
import numpy as np
import jax.numpy as jnp
import torch

folder_path = "results/tn/architecture.hidden_shapes=[512,512,512],batch_size=8,loss=ism,n=4,steps=300000/42/sample_results"

# Only the immediate sub-directories of folder_path are needed (each holds the
# per-structure pickles read by the loop below), so take just the first
# os.walk entry. os.walk silently yields nothing for a missing path, which
# would otherwise surface later as a NameError on `dirs`; fail clearly instead.
try:
    root, dirs, files = next(os.walk(folder_path))
except StopIteration:
    raise FileNotFoundError(
        f"sample results folder not found or not a directory: {folder_path!r}"
    ) from None
dirs.sort()  # os.walk order is arbitrary; sort for a reproducible run order

# One reconstruction error per processed structure, appended by the loop below.
err_all = []

# Project imports used by the loop, hoisted so they run once rather than on
# every iteration. construct_3d_basis is not used here but is kept so the name
# remains available to later notebook cells.
from context_generator.modules.common.geometry import construct_3d_basis
from context_generator.utils.protein.constants import BBHeavyAtom
from context_generator.modules.model import r3


def _to_vecs(points):
    """Split an ``(..., 3)`` coordinate array into an ``r3.Vecs`` of its x/y/z parts."""
    return r3.Vecs(points[..., 0], points[..., 1], points[..., 2])


# diag(-1, 1, -1): a 180-degree rotation about the y axis (negates x and z).
# It is composed onto the C/CA/N backbone frame built below; presumably this
# converts that frame to the convention torsion_angles_to_frames expects --
# TODO confirm.
_FLIP_XZ = np.diag(np.array([-1.0, 1.0, -1.0], dtype=np.float32))

for pdb in dirs:
    work_dir = f"{folder_path}/{pdb}"

    # Per-structure inputs. NOTE: unpickling can execute arbitrary code --
    # only point this script at trusted result folders.
    with open(f"{work_dir}/data.pickle", "rb") as data_file:
        batch = pkl.load(data_file)
    with open(f"{work_dir}/xs.pickle", "rb") as sampled_chis_file:
        sampled_chis = pkl.load(sampled_chis_file).numpy()

    # Add the atom14 <-> atom37 index maps, then convert any torch tensors to
    # numpy so they can be mixed with the numpy/JAX code below.
    data = parsers.make_atom14_masks(batch)
    for k, v in data.items():
        if torch.is_tensor(v):
            data[k] = v.numpy()

    # Reference heavy-atom positions (first 14 slots) gathered into atom37 layout.
    atom37_data = parsers.batched_gather(
        data['pos_heavyatom'][..., :14, :],
        data['residx_atom37_to_atom14'],
        batch_dims=2,
    )
    # Existence mask in atom37 layout. atom37 slots whose atom14 index is 0 are
    # redirected to column 14 of mask_heavyatom, except atom37 slot 0, which
    # keeps index 0. NOTE(review): presumably column 14 is never set for those
    # slots, so they read as absent -- confirm.
    mask_residx_atom37_to_atom14 = np.where(
        data['residx_atom37_to_atom14'] != 0, data['residx_atom37_to_atom14'], 14
    )
    mask_residx_atom37_to_atom14[:, :, 0] = 0
    atom37_data_exists = parsers.batched_gather(
        data['mask_heavyatom'],
        mask_residx_atom37_to_atom14,
        batch_dims=2,
    )
    atom37_data *= atom37_data_exists[..., None].astype(atom37_data.dtype)
    torsion_angles_dict = all_atom.atom37_to_torsion_angles(
        aatype=data['aa_AF2'],
        all_atom_pos=atom37_data,
        all_atom_mask=atom37_data_exists,
    )

    # Keep the first three torsions from the reference structure (presumably
    # the backbone angles in the AF2 layout) and append the sampled angles as
    # (sin, cos) pairs.
    torsion_angles_sin_cos = torsion_angles_dict['torsion_angles_sin_cos']
    torsion_angles_sin_cos = jnp.concatenate(
        [
            torsion_angles_sin_cos[..., :3, :],
            jnp.stack([jnp.sin(sampled_chis), jnp.cos(sampled_chis)], axis=-1),
        ],
        axis=-2,
    )

    # Backbone frame from the reference C, CA and N atoms, then flipped by _FLIP_XZ.
    pos_atoms = batch['pos_atoms']
    backb_to_global = r3.rigids_from_3_points(
        point_on_neg_x_axis=_to_vecs(pos_atoms[:, :, BBHeavyAtom.C]),
        origin=_to_vecs(pos_atoms[:, :, BBHeavyAtom.CA]),
        point_on_xy_plane=_to_vecs(pos_atoms[:, :, BBHeavyAtom.N]),
    )
    rots = np.tile(_FLIP_XZ, [1, data['pos_heavyatom'].shape[1], 1, 1])
    backb_to_global = r3.rigids_mul_rots(backb_to_global, r3.rots_from_tensor3x3(rots))

    all_frames_to_global = all_atom.torsion_angles_to_frames(
        aatype=data['aa_AF2'],
        backb_to_global=backb_to_global,
        torsion_angles_sin_cos=torsion_angles_sin_cos,
    )
    pred_positions, mask = all_atom.frames_and_literature_positions_to_atom14_pos(
        aatype=data['aa_AF2'],
        all_frames_to_global=all_frames_to_global,
    )

    # Vecs -> (..., 14, 3) array; zero the slots absent in the reference mask.
    pred_positions = jnp.stack([pred_positions.x, pred_positions.y, pred_positions.z], axis=-1)
    pred_positions = pred_positions * batch['mask_heavyatom'][..., :14][..., jnp.newaxis]

    with open(f"{work_dir}/atom14.pickle", 'wb') as f:
        pkl.dump(np.asarray(pred_positions), f)

    # Mean squared coordinate error per atom14 slot: total squared difference
    # divided by the number of slots (all leading dims; masked slots included).
    # Uses the real slot count instead of assuming a single 128-residue chain.
    sq_diff = all_atom.squared_difference(pred_positions, batch['pos_heavyatom'][..., :14, :])
    err = np.sum(sq_diff) / np.prod(sq_diff.shape[:-1])

    err_all.append(err)
    # print(err_all)

# Average the per-structure mean-squared errors and report both the MSE and
# its square root (an RMSD-style figure in the coordinate units). Fail with a
# clear message instead of a ZeroDivisionError when nothing was processed.
if not err_all:
    raise RuntimeError(f"no sample results were processed under {folder_path!r}")
err_all_avg = sum(err_all) / len(err_all)
print(err_all_avg)
print(np.sqrt(err_all_avg))