In [1]:
"""SE(3) diffusion methods."""
#mod imports to just get diffuser
from scipy.spatial.transform import Rotation
# import numpy as np
from data_rigid_diffuser import so3_diffuser
from data_rigid_diffuser import r3_diffuser
from data_rigid_diffuser import se3_diffuser
# from scipy.spatial.transform import Rotation
from data_rigid_diffuser import rigid_utils as ru
import torch
# import utils as du
# import torch
# import logging
import numpy  as np
import os
import yaml
from collections import namedtuple

In [23]:

# Useful numbers: reference backbone atom positions with CA at the origin.
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]
# The bond lengths below are the norms of the N and C positions above, rounded.
N_CA_dist = 1.458
C_CA_dist = 1.523

# The atom layout can be overridden by attaching ATOM_NAMES and PDB_ORDER to
# the `os` module before this cell runs; otherwise the 5-atom default is used.
if hasattr(os, 'ATOM_NAMES'):
    assert hasattr(os, 'PDB_ORDER')

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES = ['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    # Space-prefixed name, truncated/padded to 4 characters.
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    # Expose the atom's index as a global (N, CA, CB, C, O, ...); later cells
    # use these to index the atom axis of the coordinate arrays.
    globals()[atom_name] = i

R = len(ATOM_NAMES)  # atoms per residue

# Define N / C / CB as -1 when the layout lacks them.
# NOTE(review): -1 would silently index the *last* atom if used unguarded, and a
# value left over from an earlier run with another layout survives these checks.
if "N" not in globals():
    N = -1
if "C" not in globals():
    C = -1
if "CB" not in globals():
    CB = -1

# Atom indices (into ATOM_NAMES) listed in PDB output order.
_pdb_order = [ATOM_NAMES.index(name) for name in PDB_ORDER]

apa_path_str  = 'data_npose/h4_apa_coords.npz'
tog_path_str  = 'data_npose/h4_tog_coords.npz'

# Keep only the first `test_limit` entries along axis 0 of the first array in
# each archive (this limits entries; it does not select the N/CA/C atoms).
test_limit = 2
# Open the archives in `with` blocks so the file handles are closed, and read
# only the first stored array instead of materialising every one.
with np.load(apa_path_str) as rr:
    coords_apa = rr[rr.files[0]][:test_limit, :]

with np.load(tog_path_str) as rr:
    coords_tog = rr[rr.files[0]][:test_limit, :]
# NOTE(review): coords_apa stays a numpy array while coords_tog becomes a tensor.
coords_tog = torch.tensor(coords_tog)

In [26]:
# Build rigid frames from the N, CA and C xyz coordinates (`:3` keeps x, y, z).
# Each input is (batch, residue, 3) per the shape inspected below; presumably one
# frame per residue results — TODO confirm against ru.Rigid.make_transform_from_reference.
rigid = ru.Rigid.make_transform_from_reference(coords_tog[...,N,:3], coords_tog[...,CA,:3], coords_tog[...,C,:3])

In [47]:
def _extract_trans_rots(rigid: ru.Rigid):
    rot = rigid.get_rots().get_rot_mats().cpu().numpy()
    rot_shape = rot.shape
    num_rots = np.cumprod(rot_shape[:-2])[-1]
    rot = rot.reshape((num_rots, 3, 3))
    rot = Rotation.from_matrix(rot).as_rotvec().reshape(rot_shape[:-2] +(3,))
    tran = rigid.get_trans().cpu().numpy()
    return tran, rot
def normalize(v):
    """Scale each vector along the last axis of ``v`` to unit length.

    Zero-length vectors are returned unchanged (their norm is treated as 1)
    rather than turning into NaNs.

    Args:
        v: numpy array (or CPU torch tensor) whose last axis holds the vector
            components; a single 1-D vector is also accepted.

    Returns:
        ``v`` divided row-wise by its norm, same shape as ``v``.
    """
    # keepdims leaves a broadcastable trailing axis and always yields an
    # ndarray, so a plain 1-D vector works too (a scalar norm could not be
    # item-assigned below).
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    norm[norm == 0] = 1
    return v / norm

In [48]:
# Split the frames into numpy translations and axis-angle rotation vectors.
trans, rot = _extract_trans_rots(rigid)

In [49]:
# Rotation matrices of the frames as a numpy array (trailing (3, 3) per frame).
rmat = rigid.get_rots().get_rot_mats().cpu().numpy()

In [50]:
# Inspect: shape of the N-atom xyz coordinates (batch, residue, xyz).
coords_tog[...,N,:3].shape

torch.Size([2, 65, 3])

In [51]:
# Sanity check: dividing by the per-row norm with a trailing None axis broadcasts
# back to the input shape. (This normalises absolute N positions, not bond vectors.)
(coords_tog[...,N,:3]/np.linalg.norm(coords_tog[...,N,:3],axis=2)[...,None]).shape

torch.Size([2, 65, 3])

In [52]:
# Unit vectors from CA to N and from CA to C; `normalize` works over the last
# (xyz) axis, so the batch and residue dimensions are preserved.
N_CA_vec = normalize(coords_tog[...,N,:3]-coords_tog[...,CA,:3])
C_CA_vec = normalize(coords_tog[...,C,:3]-coords_tog[...,CA,:3])

In [113]:
# NOTE(review): leftover IPython help lookup (development only); remove before sharing.
?np.random.normal

In [53]:
# Inspect: shape of the CA->N unit vectors.
N_CA_vec.shape

torch.Size([2, 65, 3])

In [54]:
# Load the diffusion configuration (safe_load builds only standard YAML types).
with open('data_rigid_diffuser/base.yaml', 'r') as file:
    config = yaml.safe_load(file)


In [None]:
# NOTE(review): this never-executed cell (In [None]) repeats the config load and
# the Struct definition from the neighbouring cells; keep a single copy.
with open('data_rigid_diffuser/base.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Recursive dict -> attribute-access wrapper. Duplicate of the definition in the
# next cell; whichever cell runs last provides `Struct`.
class Struct(object):
    """Expose a dict's keys as attributes, converting nested dicts recursively.

    Lists, tuples, sets and frozensets are rebuilt with the same container type
    and their elements wrapped the same way; other values are stored as-is.
    """
    def __init__(self, data):
        # `data` must provide .items() (a dict from the YAML config here).
        for name, value in data.items():
            setattr(self, name, self._wrap(value))

    def _wrap(self, value):
        if isinstance(value, (tuple, list, set, frozenset)): 
            return type(value)([self._wrap(v) for v in value])
        else:
            return Struct(value) if isinstance(value, dict) else value
# Attribute-style view of the `diffuser` section (e.g. conf.so3, conf.r3).
conf = Struct(config['diffuser'])

In [55]:
class Struct(object):
    """Wrap a dict so its keys can be read as attributes (e.g. ``conf.so3``).

    Nested dicts become nested ``Struct`` objects; lists, tuples, sets and
    frozensets are rebuilt as the same container type with each element wrapped
    the same way. Every other value is kept unchanged.
    """

    def __init__(self, data):
        # Convert every value first, then attach the results as attributes.
        wrapped = {key: self._wrap(val) for key, val in data.items()}
        for key, val in wrapped.items():
            setattr(self, key, val)

    def _wrap(self, value):
        """Return ``value`` with any dicts inside it turned into ``Struct``."""
        if isinstance(value, dict):
            return Struct(value)
        if not isinstance(value, (tuple, list, set, frozenset)):
            return value
        container_type = type(value)
        return container_type([self._wrap(item) for item in value])
# Attribute-style view of the `diffuser` section (e.g. conf.so3, conf.r3).
conf = Struct(config['diffuser'])

In [56]:
# SO(3) diffuser configured from the `so3` sub-config.
so3d = so3_diffuser.SO3Diffuser(conf.so3)

In [57]:
#so3d.forward_marginal(rot,0.1)

In [58]:
# R3 diffuser configured from the `r3` sub-config.
r3d = r3_diffuser.R3Diffuser(conf.r3)

In [112]:
# Inspect the R3Diffuser.forward method object.
r3d.forward

<bound method R3Diffuser.forward of <data_rigid_diffuser.r3_diffuser.R3Diffuser object at 0x7f8f46d6d0a0>>

In [105]:
# Draw one SO(3) sample per rotation vector in `rot` (shape batch + (3,)):
# n_samples is the product of the batch dimensions. np.prod also handles an
# empty batch shape, where cumprod(...)[-1] would raise IndexError.
# 0.075 is the first argument of SO3Diffuser.sample — presumably the diffusion
# time t; TODO confirm in so3_diffuser.
# NOTE(review): no RNG seed is set in this notebook, so if sampling is random the
# rotations (and the PDB written later) differ between runs.
n_samples = int(np.prod(rot.shape[:-1]))
rotvec = so3d.sample(0.075, n_samples=n_samples)
# Assuming rotvec is (n_samples, 3), rotmat is flat (n_samples, 3, 3); reshape
# with rot.shape[:-1] + (3, 3) if per-residue matrices are needed.
rotmat = Rotation.from_rotvec(rotvec).as_matrix()



In [106]:
# Repeat the first sampled rotation once per residue (the other samples are not
# used here). The residue count is taken from N_CA_vec (batch, n_res, 3), which
# rp is applied to next, instead of a hard-coded 65.
rp = np.repeat(rotmat[0][None, ...], N_CA_vec.shape[-2], axis=0)

In [107]:
# Rotate every CA->N and CA->C unit vector by the same sampled rotation.
# rp is (n_res, 3, 3) and the vectors are (batch, n_res, 3); presumably
# ru.rot_vec_mul broadcasts over the batch axis — TODO confirm in rigid_utils.
N_CAnew = ru.rot_vec_mul(torch.tensor(rp), N_CA_vec)
C_CAnew = ru.rot_vec_mul(torch.tensor(rp), C_CA_vec)

In [99]:
import numpy as np
import util.npose_util as nu

In [100]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate a full npose, including O and CB atoms.

    O and CB are only generated when they appear in ATOM_NAMES; both are built
    with npose_util from the backbone atoms written first.

    Args:
        coords_in: array of shape (n_res, >=3, >=3). Along axis 1 the rows are
            N, CA, C (in that order); the first three columns are x, y, z.

    Returns:
        npose: float array of shape (n_res * len(ATOM_NAMES), 4), atoms of each
        residue in ATOM_NAMES order, with column 3 initialised to 1.
    """
    atoms_per_res = len(ATOM_NAMES)
    n_res = coords_in.shape[0]
    npose = np.ones((n_res * atoms_per_res, 4))

    # (n_res, atoms_per_res, 4) view of npose: writes below fill npose in place.
    by_res = npose.reshape(-1, atoms_per_res, 4)

    if "N" in ATOM_NAMES:
        by_res[:, N, :3] = coords_in[:, 0, :3]
    if "CA" in ATOM_NAMES:
        by_res[:, CA, :3] = coords_in[:, 1, :3]
    if "C" in ATOM_NAMES:
        by_res[:, C, :3] = coords_in[:, 2, :3]
    # O and CB are derived from the atoms already written, so order matters.
    if "O" in ATOM_NAMES:
        by_res[:, O, :3] = nu.build_O(npose)
    if "CB" in ATOM_NAMES:
        tpose = nu.tpose_from_npose(npose)
        # NOTE(review): this assigns all 4 columns, unlike the :3 writes above;
        # presumably build_CB returns homogeneous (x, y, z, 1) rows — confirm.
        by_res[:, CB, :] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    """Write per-residue N/CA/C coordinates (plus generated O/CB) to a PDB file.

    Args:
        coords_in: per-residue N, CA, C coordinates as accepted by
            build_npose_from_coords.
        fileOut: output path handed to npose_util.dump_npdb.
    """
    nu.dump_npdb(build_npose_from_coords(coords_in), fileOut)

In [108]:
# Rebuild N and C positions as CA + ideal bond length * rotated unit vector.
# (Despite the names, these are atom positions, not CA-relative vectors.)
N_CA = coords_tog[...,CA,:3]+N_CA_dist*N_CAnew
C_CA = coords_tog[...,CA,:3]+C_CA_dist*C_CAnew

In [66]:
# Same reconstruction from the *unrotated* unit vectors, kept under separate
# names: reusing N_CA / C_CA here replaced the rotated positions from the
# previous cell on a top-to-bottom run (the recorded session ran that cell
# later, In [108] after In [66]).
N_CA_unrotated = coords_tog[...,CA,:3]+N_CA_dist*N_CA_vec
C_CA_unrotated = coords_tog[...,CA,:3]+C_CA_dist*C_CA_vec

In [109]:
# Stack rebuilt N, original CA and rebuilt C along a new atom axis, giving
# (batch, n_res, 3 atoms, xyz). No .squeeze(): with a batch of one it would also
# drop the batch axis, and `co.numpy()[0]` below would then select the first
# residue instead of the first structure.
co = torch.stack((N_CA, coords_tog[..., CA, :3], C_CA), dim=-2)

In [110]:
# Inspect: stacked backbone shape (batch, residue, atom [N, CA, C], xyz).
co.numpy().shape

(2, 65, 3, 3)

In [111]:
# Write the first structure's backbone (O/CB generated) to the default 'fileOut.pdb'.
dump_coord_pdb(co.numpy()[0])

In [58]:
# NOTE(review): the output below is stale from an earlier session (batch size 1,
# whereas the shapes inspected above show 2); re-run to refresh.
N_CA.shape

torch.Size([1, 65, 3])

In [11]:
# SE(3) diffuser configured from the whole `diffuser` config section.
se3d= se3_diffuser.SE3Diffuser(conf)

In [12]:
# Apply the forward marginal to rigid[0] (presumably the first structure's frames);
# 0.1 is presumably the diffusion time t — TODO confirm. The result is indexed
# below with 'trans_score_scaling' and 'rigids_t'.
dif = se3d.forward_marginal(rigid[0],0.1)

In [32]:
# Inspect the 'trans_score_scaling' entry of the forward_marginal result.
dif['trans_score_scaling']

3.1050834559359974

In [15]:
# Each row of rigids_t packs a frame as 7 numbers: the first 4 are the rotation
# quaternion, the last 3 the translation.
# (A rotation vector, by contrast, is axis-angle: its direction is the rotation
# axis and its norm is the angle.)
# Presumably row 1 is the second residue's frame — TODO confirm rigids_t layout.
dif['rigids_t'][1]

tensor([-0.0211,  0.9482,  0.1337,  0.2875,  0.3689, -2.3312, -7.4057])

In [None]:
# Inspect rigid_utils.quat_to_rot; the next cell appears to be a copy of it.
ru.quat_to_rot

In [None]:
def _build_qtr_mat() -> np.ndarray:
    """Coefficient table for converting a quaternion to a rotation matrix.

    Returns a [4, 4, 3, 3] array ``c`` such that, for a scalar-first quaternion
    ``q = (w, x, y, z)``, ``R[r, s] = sum_ij c[i, j, r, s] * q[i] * q[j]`` is the
    standard rotation matrix of ``q`` (exact for unit quaternions).
    """
    order = "wxyz"
    # Non-zero coefficients of each matrix entry, keyed by (row, col).
    terms = {
        (0, 0): (("ww", 1), ("xx", 1), ("yy", -1), ("zz", -1)),
        (0, 1): (("xy", 2), ("wz", -2)),
        (0, 2): (("xz", 2), ("wy", 2)),
        (1, 0): (("xy", 2), ("wz", 2)),
        (1, 1): (("ww", 1), ("xx", -1), ("yy", 1), ("zz", -1)),
        (1, 2): (("yz", 2), ("wx", -2)),
        (2, 0): (("xz", 2), ("wy", -2)),
        (2, 1): (("yz", 2), ("wx", 2)),
        (2, 2): (("ww", 1), ("xx", -1), ("yy", -1), ("zz", 1)),
    }
    mat = np.zeros((4, 4, 3, 3))
    for (row, col), pairs in terms.items():
        for (a, b), coef in pairs:
            mat[order.index(a), order.index(b), row, col] = coef
    return mat


# quat_to_rot below needs _QTR_MAT, which this notebook did not otherwise
# define (calling it raised NameError).
_QTR_MAT = _build_qtr_mat()


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
        Converts a quaternion to a rotation matrix.

        Args:
            quat: [*, 4] quaternions, scalar first (w, x, y, z)
        Returns:
            [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] outer product q_i * q_j
    quat = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3], same dtype/device as quat
    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3]: contract over both quaternion indices
    return torch.sum(quat, dim=(-3, -4))