In [1]:
import torch
import numpy as np
import util.npose_util as nu
import os
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
torch.cuda.is_available()

True

In [25]:
from data_rigid_diffuser import so3_diffuser
from data_rigid_diffuser import r3_diffuser
from scipy.spatial.transform import Rotation
from data_rigid_diffuser import rigid_utils as ru
import yaml

In [26]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential, SE3Transformer
from se3_transformer.model.transformer_topk import SE3Transformer_topK
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [27]:
#npose indexing
# Useful numbers
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]

N_CA_dist = torch.tensor(1.458/10.0).to('cuda')
C_CA_dist = torch.tensor(1.523/10.0).to('cuda')

if ( hasattr(os, 'ATOM_NAMES') ):
    assert( hasattr(os, 'PDB_ORDER') )

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES=['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    globals()[atom_name] = i

R = len(ATOM_NAMES)

if ( "N" not in globals() ):
    N = -1
if ( "C" not in globals() ):
    C = -1
if ( "CB" not in globals() ):
    CB = -1


_pdb_order = []
for name in PDB_ORDER:
    _pdb_order.append( ATOM_NAMES.index(name) )

In [28]:
# data_path_str  = 'data/h4_ca_coords.npz'
# test_limit = 1028
# rr = np.load(data_path_str)
# ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
# ca_coords.shape

# getting N-Ca, Ca-C vectors to add as typeI features
#apa = apart helices for test/train split
#tog = together helices for test/train split
apa_path_str  = 'data_npose/h4_apa_coords.npz'
tog_path_str  = 'data_npose/h4_tog_coords.npz'

#grab the first 3 atoms which are N,CA,C
test_limit = 1028
rr = np.load(apa_path_str)
coords_apa = [rr[f] for f in rr.files][0][:test_limit,:]

rr = np.load(tog_path_str)
coords_tog = [rr[f] for f in rr.files][0][:test_limit,:]

In [29]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate O an CB atoms"""
    rot_mat_cat = np.ones(sum((coords_in.shape[:-1], (1,)), ()))
    
    coords = np.concatenate((coords_in,rot_mat_cat),axis=-1)
    
    npose = np.ones((coords_in.shape[0]*5,4)) #5 is atoms per res

    by_res = npose.reshape(-1, 5, 4)
    
    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    
    npose =  build_npose_from_coords(coords_in)
    nu.dump_npdb(npose,fileOut)

In [30]:
#goal define edges of
#connected backbone 1, 
#unconnected atoms 0,


def get_midpoint(ep_in):
    """Get midpoint, of each batched set of points"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ep_in.shape[1], ep_in.shape[2])
    
    return midpoint


def normalize_points(input_xyz, print_dist=False):
    
    #broadcast to distance matrix [Batch, M, R3] to [Batch,M,1, R3] to [Batch,1,M, R3] to [Batch, M,M, R3] 
    vec_diff = input_xyz[...,None,:]-input_xyz[...,None,:,:]
    dist = np.sqrt(np.sum(np.square(vec_diff),axis=len(input_xyz.shape)))
    furthest_dist = np.max(dist)
    centroid  = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')
    
    xyz_mean_zero = input_xyz - centroid[:,None,:]
    return xyz_mean_zero/furthest_dist

def define_graph_edges(n_nodes):
    #connected backbone

    con_v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    con_v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ind = con_v1*(n_nodes-1)+con_v2-1 #account for removed self connections (-1)


    #unconnected backbone

    nodes = np.arange(n_nodes)
    v1 = np.repeat(nodes,n_nodes-1) #starting vertices, same number repeated for each edge

    start_v2 = np.repeat(np.arange(n_nodes)[None,:],n_nodes,axis=0)
    diag_ind = np.diag_indices(n_nodes)
    start_v2[diag_ind] = -1 #diagonal of matrix is self connections which we remove (self connections are managed by SE3 Conv channels)
    v2 = start_v2[start_v2>-0.5] #remove diagonal and flatten

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1
    
    return v1,v2,edge_data, ind



def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(n_nodes)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    
    if print_out == True:
        for x in range(int(n_nodes/12)):
            print(np.round(pe[x],1))
    
    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
#norm_p = normalize_points(ca_coords,print_dist=True)
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [31]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    pe = make_pe_encoding(n_nodes=n_nodes)
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph


In [32]:
def torch_normalize(v, eps=1e-6):
    """Normalize vector in last axis"""
    norm = torch.linalg.vector_norm(v, dim=len(v.shape)-1)+eps
    return v / norm[...,None]

def normalize(v):
    """Normalize vector in last axis"""
    norm = np.linalg.norm(v,axis=len(v.shape)-1)
    norm[norm == 0] = 1
    return v / norm[...,None]

def get_CN_vector(coords_in):
    N_CA_vec = normalize(coords_in[...,N,:3]-coords_in[...,CA,:3])
    C_CA_vec = normalize(coords_in[...,C,:3]-coords_in[...,CA,:3])
    return N_CA_vec, C_CA_vec



# Applies to Python-3 Standard Library
class Struct(object):
    def __init__(self, data):
        for name, value in data.items():
            setattr(self, name, self._wrap(value))

    def _wrap(self, value):
        if isinstance(value, (tuple, list, set, frozenset)): 
            return type(value)([self._wrap(v) for v in value])
        else:
            return Struct(value) if isinstance(value, dict) else value

In [33]:
def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos

class Helix4_Dataset(Dataset):
    def __init__(self, coordinates: np.array, cast_type=torch.float32):
        #prots,#length_prot in aa, #residues/aa, #xyz per atom
           
        #alphaFold reduce by 10
        coord_div = 10
        
        coordinates = coordinates/coord_div
        self.ca_coords = torch.tensor(coordinates[:,:,CA,:], dtype=cast_type)
        #unsqueeze to stack together later
        self.N_CA_vec = torch.tensor(coordinates[:,:,N,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        self.C_CA_vec = torch.tensor(coordinates[:,:,C,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)

    def __len__(self):
        return len(self.ca_coords)

    def __getitem__(self, idx):
        return {'CA':self.ca_coords[idx], 'N_CA':self.N_CA_vec[idx], 'C_CA':self.C_CA_vec[idx]}
    
    

    
    
class Make_KNN_MP_Graphs():
    
    #8 long positional encoding
    NODE_FEATURE_DIM_0 = 12
    EDGE_FEATURE_DIM = 1 # 0 or 1 primary seq connection or not
    NODE_FEATURE_DIM_1 = 2
    
    def __init__(self, mp_stride=4, n_nodes=65, radius=15, coord_div=10, cast_type=torch.float32, channels_start=32,
                       ndf1=6, ndf0=32,cuda=True):
        
        self.KNN = 30
        self.n_nodes = n_nodes
        self.pe = make_pe_encoding(n_nodes=n_nodes)
        self.mp_stride = mp_stride
        self.cast_type = cast_type
        self.channels_start = channels_start
        
        self.cuda = cuda
        self.ndf1 = ndf1 #awkard adding of nodes features to mpGraph
        self.ndf0 = ndf0
        
    def create_and_batch(self, bb_dict):
        
        graphList = []
        mpGraphList = []
        mpRevGraphList = []
        mpSelfGraphList = []
        
        for j, caXYZ in enumerate(bb_dict['CA']):
            graph = dgl.knn_graph(caXYZ, self.KNN)
            graph.ndata['pe'] = pe
            graph.ndata['pos'] = caXYZ
            graph.ndata['bb_ori'] = torch.cat((bb_dict['N_CA'][j],  bb_dict['C_CA'][j]),axis=1)
            
            #define covalent connections
            esrc, edst = graph.edges()
            graph.edata['con'] = (torch.abs(esrc-edst)==1).type(self.cast_type).reshape((-1,1))
            
            mp_list = torch.zeros((len(list(range(0,self.n_nodes, self.mp_stride))),caXYZ.shape[1]))
            
            new_src = torch.tensor([],dtype=torch.int)
            new_dst = torch.tensor([],dtype=torch.int)
            
            new_src_rev = torch.tensor([], dtype=torch.int)
            new_dst_rev = torch.tensor([], dtype=torch.int)
           
            i=0#mp list counter
            for x in range(0,self.n_nodes, self.mp_stride):
                src, dst = graph.in_edges(x) #dst repeats x
                n_tot = torch.cat((torch.tensor(x).unsqueeze(0),src)) #add x to node list
                mp_list[i] = caXYZ[n_tot].sum(axis=0)/n_tot.shape[0]
                mp_node = i + graph.num_nodes() #add midpoints nodes at end of graph
                #define edges between midpoint nodes and nodes defining midpoint for midpointGraph
                
                new_src = torch.cat((new_src,n_tot))
                new_dst = torch.cat((new_dst,
                                     (torch.tensor(mp_node).unsqueeze(0).repeat(n_tot.shape[0]))))
                #and reverse graph for coming off
                new_src_rev = torch.cat((new_src_rev,
                                         (torch.tensor(mp_node).unsqueeze(0).repeat(n_tot.shape[0]))))
                new_dst_rev = torch.cat((new_dst_rev,n_tot))
                
                i+=1
                
            mpGraph = dgl.graph((new_src,new_dst))
            mpGraph.ndata['pos'] = torch.cat((caXYZ,mp_list),axis=0).type(self.cast_type)
            mp_node_indx = torch.arange(0,self.n_nodes, self.mp_stride).type(torch.int)
            #match output shape of first transformer
            pe_mp = torch.cat((pe,torch.zeros((pe.shape[0],self.channels_start-pe.shape[1]))),axis=1)
            mpGraph.ndata['pe'] = torch.cat((pe_mp,pe_mp[mp_node_indx]))
            mpGraph.edata['con'] = torch.zeros((mpGraph.num_edges(),1))
            
            mpGraph_rev = dgl.graph((new_src_rev,new_dst_rev))
            mpGraph_rev.ndata['pos'] = torch.cat((caXYZ,mp_list),axis=0).type(self.cast_type)
            mpGraph_rev.ndata['pe'] = torch.cat((pe_mp,pe_mp[mp_node_indx]))
            mpGraph_rev.edata['con'] = torch.zeros((mpGraph_rev.num_edges(),1))
            
            #make graph for self interaction of midpoints
            v1,v2,edge_data, ind = define_graph_edges(len(mp_list))
            mpSelfGraph = dgl.graph((v1,v2))
            mpSelfGraph.edata['con'] = edge_data.reshape((-1,1))
            mpSelfGraph.ndata['pe'] = pe[mp_node_indx] #not really needed
            mpSelfGraph.ndata['pos'] = mp_list.type(self.cast_type)
            
            
            mpSelfGraphList.append(mpSelfGraph) 
            mpGraphList.append(mpGraph)
            mpRevGraphList.append(mpGraph_rev)
            graphList.append(graph)
        
        return dgl.batch(graphList), dgl.batch(mpGraphList), dgl.batch(mpSelfGraphList), dgl.batch(mpRevGraphList)
    
    def prep_for_network(self, bb_dict, cuda=True):
    
        batched_graph, batched_mpgraph, batched_mpself_graph, batched_mpRevgraph =  self.create_and_batch(bb_dict)
        
        edge_feats        =    {'0':   batched_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        edge_feats_mp     = {'0': batched_mpgraph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]} #def all zero now
        edge_feats_mpself = {'0': batched_mpself_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
#         edge_feats_mp     = {'0': batched_mpRevgraph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos']   = _get_relative_pos(batched_graph)
        batched_mpgraph.edata['rel_pos'] = _get_relative_pos(batched_mpgraph)
        batched_mpself_graph.edata['rel_pos'] = _get_relative_pos(batched_mpself_graph)
        batched_mpRevgraph.edata['rel_pos'] = _get_relative_pos(batched_mpRevgraph)
        # get node features
        node_feats =         {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None],
                              '1': batched_graph.ndata['bb_ori'][:,:self.NODE_FEATURE_DIM_1, :3]}
        node_feats_mp =      {'0': batched_mpgraph.ndata['pe'][:, :self.ndf0, None],
                              '1': torch.ones((batched_mpgraph.num_nodes(),self.ndf1,3))}
        #unused
        node_feats_mpself =  {'0': batched_mpself_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None]}
        
        if cuda:
            bg,nf,ef = to_cuda(batched_graph), to_cuda(node_feats), to_cuda(edge_feats)
            bg_mp, nf_mp, ef_mp = to_cuda(batched_mpgraph), to_cuda(node_feats_mp), to_cuda(edge_feats_mp)
            bg_mps, nf_mps, ef_mps = to_cuda(batched_mpself_graph), to_cuda(node_feats_mpself), to_cuda(edge_feats_mpself)
            bg_mpRev = to_cuda(batched_mpRevgraph)
            
            return bg,nf,ef, bg_mp, nf_mp, ef_mp, bg_mps, nf_mps, ef_mps, bg_mpRev
        
        else:
            bg,nf,ef = batched_graph, node_feats, edge_feats
            bg_mp, nf_mp, ef_mp = batched_mpgraph, node_feats_mp, edge_feats_mp
            bg_mps, nf_mps, ef_mps = batched_mpself_graph, node_feats_mpself, edge_feats_mpself
            bg_mpRev = batched_mpRevgraph
            
            return bg,nf,ef, bg_mp, nf_mp, ef_mp, bg_mps, nf_mps, ef_mps, bg_mpRev
        
            

def get_edge_features(graph,edge_feature_dim=1):
    return {'0': graph.edata['con'][:, :edge_feature_dim, None]}

def define_poolGraph(n_nodes, batch_size, cast_type=torch.float32, cuda_out=True ):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    #pe = make_pe_encoding(n_nodes=n_nodes)#pe e
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)
    
    if cuda_out:
        return to_cuda(batched_graph)
    else:
        return batched_graph            
        

In [34]:
def pull_edge_features(graph, edge_feat_dim=1):
    return {'0': graph.edata['con'][:, :edge_feat_dim, None]}

def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    
    src, dst = graph.edges()
    
    new_pos = F.gather_row(xyz_pos, idx)
    rel_pos = F.gather_row(new_pos,dst) - F.gather_row(new_pos,src) 
    
    basis_out = get_basis(rel_pos, max_degree=max_degree,
                                   compute_gradients=comp_grad,
                                   use_pad_trick=False)
    basis_out = update_basis_with_fused(basis_out, max_degree, use_pad_trick=False,
                                            fully_fused=False)
    edge_feats_out = get_populated_edge_features(rel_pos, edge_feats_input)
    return edge_feats_out, basis_out, new_pos    

In [35]:
class FrameDiffNoise(nn.Module):
    """Generate Diffusion Noise based on FrameDiff"""
    
    def __init__(self, config_path='data_rigid_diffuser/base.yaml'):
        super().__init__()
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
        conf = Struct(config['diffuser'])
        
        self.so3d = so3_diffuser.SO3Diffuser(conf.so3)
        self.r3d =  r3_diffuser.R3Diffuser(conf.r3)
        
    def forward(self, bb_dict, t, cast=torch.float32):
        
        ca = bb_dict['CA']
        nc_vec = bb_dict['N_CA'].reshape((-1,3))
        cc_vec = bb_dict['C_CA'].reshape((-1,3))
        #sample rotation
        n_samples = nc_vec.shape[0]
        rotvec = self.so3d.sample(t, n_samples=n_samples)
        rotmat = Rotation.from_rotvec(rotvec).as_matrix()
        #apply rotation
        
        batch_shape =  bb_dict['N_CA'].shape
        
        nc_vec_noised = ru.rot_vec_mul(torch.tensor(rotmat,dtype=cast),nc_vec).reshape(batch_shape)
        cc_vec_noised = ru.rot_vec_mul(torch.tensor(rotmat,dtype=cast),cc_vec).reshape(batch_shape)
        
        #r3d requires numpy
        ca_noised , _  = fdn.r3d.forward_marginal(ca.numpy(), t)
        ca_noised = torch.tensor(ca_noised, dtype=cast)
        
        bb_noised_out = {'CA': ca_noised, 'N_CA': nc_vec_noised, 'C_CA': cc_vec_noised}
        
        return bb_noised_out

In [36]:
def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    #N, Ca, C - [B,L, 3]
    #R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix
    B,L = N.shape[:2]
    
    v1 = C-Ca
    v2 = N-Ca
    e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
    u2 = v2-(torch.einsum('bli, bli -> bl', e1, v2)[...,None]*e1)
    e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix
    
    if non_ideal:
        v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
        cosref = torch.clamp( torch.sum(e1*v2, dim=-1), min=-1.0, max=1.0) # cosine of current N-CA-C bond angle
        costgt = cos_ideal_NCAC.item()
        cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
        cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
        sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
        Rp = torch.eye(3, device=N.device).repeat(B,L,1,1)
        Rp[:,:,0,0] = cosdel
        Rp[:,:,0,1] = -sindel
        Rp[:,:,1,0] = sindel
        Rp[:,:,1,1] = cosdel
    
        R = torch.einsum('blij,bljk->blik', R,Rp)

    return R, Ca

def get_t(N, Ca, C, non_ideal=False, eps=1e-5):
    I,B,L=N.shape[:3]
    Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), non_ideal=non_ideal, eps=eps)
    Rs = Rs.view(I,B,L,3,3)
    Ts = Ts.view(I,B,L,3)
    t = Ts[:,:,None] - Ts[:,:,:,None] # t[0,1] = residue 0 -> residue 1 vector
    return einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3)

def FAPE_loss(pred, true,  d_clamp=10.0, d_clamp_inter=30.0, A=10.0, gamma=1.0, eps=1e-6):
    '''
    Calculate Backbone FAPE loss from RosettaTTAFold
    https://github.com/uw-ipd/RoseTTAFold2/blob/main/network/loss.py
    Input:
        - pred: predicted coordinates (I, B, L, n_atom, 3)
        - true: true coordinates (B, L, n_atom, 3)
    Output: str loss
    '''
    I = pred.shape[0]
    true = true.unsqueeze(0)
    t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
    t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])

    difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps)
    eij_label = difference[-1].clone().detach()

    clamp = torch.zeros_like(difference)

    # intra vs inter#me coded
    clamp[:,True] = d_clamp

    difference = torch.clamp(difference, max=clamp)
    loss = difference / A # (I, B, L, L)

    # calculate masked loss (ignore missing regions when calculate loss)
    loss = (loss[:,True]).sum(dim=-1) / (torch.ones_like(loss).sum()+eps) # (I)

    # weighting loss
    w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
    w_loss = torch.flip(w_loss, (0,))
    w_loss = w_loss / w_loss.sum()

    tot_loss = (w_loss * loss).sum()
    
    return tot_loss, loss.detach()

In [37]:
class Latent_Unpool(torch.nn.Module):
    """
    Duplicate Latent onto Graph. Add upper features from U-net
    """

    def __init__(self, fiber_in: Fiber, fiber_add: Fiber, knodes: int):
        super().__init__()
        self.fiber_out = Fiber.combine_max(fiber_in, fiber_add)
        self.node_repeat = knodes

    def forward(self, features: Dict[str, Tensor], u_features: Dict[str, Tensor]):
        out_feats = {}
        for degree_out, channels_out in self.fiber_out:
            cd = str(degree_out)
            if cd in features.keys():
                #repeat latent for all nodes
                feat_out = features[cd].repeat_interleave(self.node_repeat,0)[...,None]
                #add upper level features to front of 
                out_feats[cd] = torch.add(feat_out, 
                                          torch.concat((u_features[cd],
                                                       torch.zeros((u_features[cd].shape[0],)+
                                                                   (features[cd].shape[1]-u_features[cd].shape[1],)+
                                                                   u_features[cd].shape[2:]
                                                                   ,device = u_features[cd].device)
                     ),axis=1))

            else:
                #upper feats have additional type features
                out_feats[cd] = u_features[cd]
                    
        return out_feats
    
class Unpool_Layer(torch.nn.Module):
    """
    Uses indices to place nodes into zeros. Add upper features
    Assumes lower features are more
    """

    def __init__(self, fiber_in: Fiber, fiber_add: Fiber):
        super().__init__()
        self.fiber_in = fiber_in
        self.fiber_add = fiber_add
        self.fiber_out = Fiber.combine_max(fiber_in, fiber_add)
        
    def forward(self, features: Dict[str, Tensor], u_features: Dict[str, Tensor], idx : Tensor):
        out_feats = {}
        for degree_out, channels_out in self.fiber_out:
            cd = str(degree_out)
            unpool_feats = u_features[cd].new_zeros([u_features[cd].shape[0], features[cd].shape[1], u_features[cd].shape[2]])
            pad = features[cd].new_zeros([unpool_feats.shape[0], unpool_feats.shape[1]-u_features[cd].shape[1], unpool_feats.shape[2]])
            out_feats[cd] = torch.add(F.scatter_row(unpool_feats,idx,features[cd]), torch.cat((u_features[cd],pad),1))

        return out_feats
    
def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    
    src, dst = graph.edges()
    
    new_pos = F.gather_row(xyz_pos, idx)
    rel_pos = F.gather_row(new_pos,dst) - F.gather_row(new_pos,src) 
    
    basis_out = get_basis(rel_pos, max_degree=max_degree,
                                   compute_gradients=comp_grad,
                                   use_pad_trick=False)
    basis_out = update_basis_with_fused(basis_out, max_degree, use_pad_trick=False,
                                            fully_fused=False)
    edge_feats_out = get_populated_edge_features(rel_pos, edge_feats_input)
    return edge_feats_out, basis_out, new_pos

def gs(dict_in):
    str_out = ''
    for x in dict_in.keys():
        sh = dict_in[x].shape
        str_out = f'{str_out}{x}: {sh}   '
    print(str_out)
        

In [38]:
class GraphUNet(torch.nn.Module):
    def __init__(self, 
                 fiber_start = Fiber({0:12, 1:2}),
                 fiber_out = Fiber({1:3}),
                 k=4,
                 batch_size = 8,
                 stride=4,
                 max_degree=3,
                 channels=32,
                 num_heads = 8,
                 channels_div=4,
                 num_layers = 1,
                 edge_feature_dim=1,
                 latent_pool_type = 'avg',
                 cuda=True):
        super(GraphUNet, self).__init__()
        
        ##to do pass precalculated basis
        
        self.comp_basis_grad = True
        self.cuda = cuda
        
        if cuda:
            self.device='cuda:0'
        else:
            self.device='cpu'
        
        self.max_degree=max_degree
        self.B = batch_size
        self.k = k
        
        self.num_layers = 1
        self.channels = 32
        self.feat0 = 32
        self.feat1 = 6
        self.channels_div = 4
        self.num_heads = 8
        self.mult = int(stride/2)
        self.fiber_edge=Fiber({0:edge_feature_dim})
        self.edge_feat_dim = edge_feature_dim
        
        self.pool_type = latent_pool_type
        
        self.channels_down_ca = channels
        #down c_alpha interactions by radius
        self.fiber_start =  fiber_start
        self.fiber_hidden_down_ca = Fiber.create(self.max_degree, self.channels_down_ca)
        self.fiber_out_down_ca =Fiber({0: self.feat0, 1: self.feat1})
        
        self.down_ca = SE3Transformer(num_layers = self.num_layers,
                        fiber_in=self.fiber_start,
                        fiber_hidden= self.fiber_hidden_down_ca, 
                        fiber_out=self.fiber_out_down_ca,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.channels_down_ca2mp = self.channels_down_ca*self.mult
        
        #pool from c_alpha onto midpoints
        self.fiber_in_down_ca2mp     = self.fiber_out_down_ca
        self.fiber_hidden_down_ca2mp = Fiber.create(max_degree, self.channels_down_ca2mp)
        self.fiber_out_down_ca2mp    = Fiber({0: self.feat0*self.mult, 1: self.feat1*self.mult})

        self.down_ca2mp = SE3Transformer(num_layers = self.num_layers,
                            fiber_in     = self.fiber_in_down_ca2mp,
                            fiber_hidden = self.fiber_hidden_down_ca2mp, 
                            fiber_out    = self.fiber_out_down_ca2mp,
                            num_heads =    self.num_heads,
                            channels_div = self.channels_div,
                            fiber_edge=self. fiber_edge,
                            low_memory=True,
                            tensor_cores=False)
        
        self.fiber_in_mptopk =  self.fiber_out_down_ca2mp
        self.fiber_hidden_down_mp  =self.fiber_hidden_down_ca2mp
        self.fiber_out_down_mp_out =self.fiber_out_down_ca2mp
        self.fiber_out_topkpool=Fiber({0: self.feat0*self.mult*self.mult})

        self.mp_topk = SE3Transformer_topK(num_layers      = self.num_layers,
                                        fiber_in      = self.fiber_in_mptopk,
                                        fiber_hidden  = self.fiber_hidden_down_mp, 
                                        fiber_out     = self.fiber_out_down_mp_out ,
                                        fiber_out_topk= self.fiber_out_topkpool,
                                        k             = self.k,
                                        num_heads     = self.num_heads,
                                        channels_div  = self.channels_div,
                                        fiber_edge    =  self.fiber_edge,
                                        low_memory=True,
                                        tensor_cores=False)
        
        self.gsmall = define_poolGraph(self.k, self.B, cast_type=torch.float32, cuda_out=self.cuda)
        self.ef_small = pull_edge_features(self.gsmall, edge_feat_dim=self.edge_feat_dim)
        
        #change to doing convolutions instead of points
        self.fiber_in_down_gcn   =  self.fiber_out_topkpool
        self.fiber_out_down_gcn  = Fiber({0: self.feat0*self.mult*self.mult, 1: self.feat1*self.mult})

        self.down_gcn = ConvSE3(fiber_in  = self.fiber_in_down_gcn,
                           fiber_out = self.fiber_out_down_gcn,
                           fiber_edge= self.fiber_edge,
                             self_interaction=True,
                             use_layer_norm=True,
                             max_degree=self.max_degree,
                             fuse_level= ConvSE3FuseLevel.NONE,
                             low_memory= True)
        
        
        self.fiber_in_down_gcnp = self.fiber_out_down_gcn
        #probably rename latent
        self.latent_size = self.feat0*self.mult*self.mult
        self.fiber_latent = Fiber({0: self.latent_size})

        self.down_gcn2pool = ConvSE3(fiber_in=self.fiber_in_down_gcnp,
                                     fiber_out=self.fiber_latent,
                                     fiber_edge=self.fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=True,
                                     max_degree=self.max_degree,
                                     fuse_level= ConvSE3FuseLevel.NONE,
                                     low_memory= True)
        
        self.global_pool = GPooling(pool=self.pool_type, feat_type=0)

        self.latent_unpool_layer = Latent_Unpool(fiber_in = self.fiber_latent, fiber_add = self.fiber_out_down_gcn, 
                                            knodes = self.k)

        self.up_gcn = ConvSE3(fiber_in=self.fiber_out_down_gcn,
                             fiber_out=self.fiber_out_down_gcn,
                             fiber_edge=self.fiber_edge,
                             self_interaction=True,
                             use_layer_norm=True,
                             max_degree=self.max_degree,
                             fuse_level= ConvSE3FuseLevel.NONE,
                             low_memory= True)
        
        self.unpool_layer = Unpool_Layer(fiber_in=self.fiber_out_down_gcn, fiber_add=self.fiber_out_down_ca)
        
        self.fiber_in_up_gcn_mp = self.unpool_layer.fiber_out
        self.fiber_hidden_up_mp= self.fiber_hidden_down_ca2mp
        self.fiber_out_up_gcn_mp = self.fiber_out_down_mp_out

        self.up_gcn_mp = SE3Transformer(num_layers = num_layers,
                        fiber_in=self.fiber_in_up_gcn_mp,
                        fiber_hidden= self.fiber_hidden_up_mp, 
                        fiber_out=self.fiber_out_up_gcn_mp,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.unpool_layer_off_mp = Unpool_Layer(fiber_in=self.fiber_out_down_mp_out, fiber_add=self.fiber_out_down_mp_out)

        self.fiber_in_up_off_mp = self.fiber_out_up_gcn_mp
        self.fiber_hidden_up_off_mp = self.fiber_hidden_up_mp
        self.fiber_out_up_off_mp = self.fiber_out_down_ca 
        
        #uses reverse graph to move mp off 
        
        self.up_off_mp = SE3Transformer(num_layers = self.num_layers,
                        fiber_in=self.fiber_in_up_off_mp,
                        fiber_hidden= self.fiber_hidden_up_off_mp, 
                        fiber_out=self.fiber_out_up_off_mp,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.fiber_out = fiber_out
        
        self.up_ca = SE3Transformer(num_layers =self.num_layers,
                                    fiber_in=self.fiber_out_down_ca,
                                    fiber_hidden= self.fiber_hidden_down_ca, 
                                    fiber_out=self.fiber_out,
                                    num_heads = self.num_heads,
                                    channels_div = self.channels_div,
                                    fiber_edge= self.fiber_edge,
                                    low_memory=True,
                                    tensor_cores=False)
        
        self.linear = LinearSE3(fiber_in=fiber_out,
                                fiber_out=fiber_out)
        
        self.zero_linear()
        
    def zero_linear(self):
        nn.init.zeros_(self.linear.weights['1'])
        
    def concat_mp_feats(self, ca_feats_in, mp_feats):

        nf0_c = ca_feats_in['0'].shape[-2]
        nf1_c = ca_feats_in['1'].shape[-2]

        out0_cat_shape = (B,self.ca_nodes,-1,1)
        mp0_cat_shape  = (B,self.mp_nodes,-1,1)
        out1_cat_shape = (B,self.ca_nodes,-1,3)
        mp1_cat_shape  = (B,self.mp_nodes,-1,3)

        nf_c = {} #nf_cat
        nf_c['0'] = torch.cat((ca_feats_in['0'].reshape(out0_cat_shape), 
                                 mp_feats['0'].reshape(mp0_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
                              axis=1).reshape((-1,nf0_c,1))

        nf_c['1'] = torch.cat((ca_feats_in['1'].reshape(out1_cat_shape), 
                                 mp_feats['1'].reshape(mp1_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
                              axis=1).reshape((-1,nf1_c,3))

        return nf_c
        
    def pull_out_mp_feats(self, ca_mp_feats):

        nf0_c = ca_mp_feats['0'].shape[1]
        nf1_c = ca_mp_feats['1'].shape[1]

        nf_mp_ = {}
        #select just mp nodes to move on, the other nodes don't connect but mainting self connections
        nf_mp_['0'] = ca_mp_feats['0'].reshape(B,self.mp_nodes,
                                               nf0_c,1)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf0_c,1))
        nf_mp_['1'] = ca_mp_feats['1'].reshape(B,self.mp_nodes,
                                               nf1_c,3)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf1_c,3))

        return nf_mp_
        
        
    def forward(self, input_tuple):

        b_graph, nf, ef, b_graph_mp, nf_mp, ef_mp, b_graph_mps, nf_mps, ef_mps, b_graph_mpRev = input_tuple
        #assumes equal node numbers in g raphs
        self.ca_nodes = int(b_graph.batch_num_nodes()[0])
        self.mp_nodes = int(b_graph_mp.batch_num_nodes()[0]) #ca+mp nodes number

        #SE3 Attention Transformer, c_alpha 
        nf_ca_down_out = self.down_ca(b_graph, nf, ef)

        #concatenate on midpoints feats
        nf_down_cat_mp = self.concat_mp_feats(nf_ca_down_out, nf_mp)

        #pool from ca onto selected midpoints via SE3 Attention transformer
        #edges from ca to mp only (ca nodes zero after this)
        nf_down_ca2mp_out = self.down_ca2mp(b_graph_mp, nf_down_cat_mp, ef_mp)

        #remove ca node feats from tensor 
        nf_mp_out = self.pull_out_mp_feats(nf_down_ca2mp_out)

        node_feats_tk, topk_feats, topk_indx = self.mp_topk(b_graph_mps, nf_mp_out, ef_mps)

        #make new basis for small graph of k selected midpoints
        edge_feats_out, basis_out, new_pos = prep_for_gcn(self.gsmall, b_graph_mps.ndata['pos'], self.ef_small,
                                                          topk_indx,
                                                          max_degree=self.max_degree, comp_grad=True)

        down_gcn_out = self.down_gcn(topk_feats, edge_feats_out, self.gsmall,  basis_out)

        down_gcnpool_out = self.down_gcn2pool(down_gcn_out, edge_feats_out, self.gsmall,  basis_out)

        pooled_tensor = self.global_pool(down_gcnpool_out,self.gsmall)
        pooled = {'0':pooled_tensor}
        #----------------------------------------- end of down section
        lat_unp = self.latent_unpool_layer(pooled,down_gcn_out)

        up_gcn_out = self.up_gcn(lat_unp, edge_feats_out, self.gsmall,  basis_out)

        k_to_mp  = self.unpool_layer(up_gcn_out,node_feats_tk,topk_indx)

        up_mp_gcn_out = self.up_gcn_mp(b_graph_mps, k_to_mp, ef_mps)
        
        off_mp_add = {}
        for k,v in up_mp_gcn_out.items():
            off_mp_add[k] = torch.add(up_mp_gcn_out[k],nf_mp_out[k])


        #####triple check from here
        #midpoints node indices for unpool layer
        mp_node_indx = torch.arange(self.ca_nodes,self.mp_nodes, device=self.device)
        mp_idx = mp_node_indx[None,...].repeat_interleave(self.B,0)
        mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))+mp_idx).reshape((-1))
        
        #during unpool, keep mp=values and ca=zeros
        zeros_mp_ca = {}
        for k,v in nf_down_cat_mp.items():
            zeros_mp_ca[k] = torch.zeros_like(v, device=self.device)


        unpoff_out = self.unpool_layer_off_mp(off_mp_add, zeros_mp_ca, mp_idx)
        
        out_up_off_mp = self.up_off_mp(b_graph_mpRev, unpoff_out, ef_mp)
        
        #select just ca nodes, mp = zeros from last convolution
        inv_mp_idx= torch.arange(0,self.ca_nodes, device=self.device)
        inv_mp_idx = inv_mp_idx[None,...].repeat_interleave(self.B,0)
        inv_mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))
                     +inv_mp_idx).reshape((-1))

        node_final_ca = {}
        for key in out_up_off_mp.keys():
            node_final_ca[key] = torch.add(out_up_off_mp[key][inv_mp_idx,...],nf_ca_down_out[key])

        #return updates 
        return self.linear(self.up_ca(b_graph, node_final_ca, ef))
                
        
        

In [39]:
def noise_test(backbone_dict, noised_dict):
    CA_t  = bb_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + bb_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_t = CA_t + bb_dict['C_CA'].reshape(B, L, 3).to('cuda')
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    CA_n  = noised_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_n = CA_n + noised_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_n = CA_n + noised_dict['C_CA'].reshape(B, L, 3).to('cuda')
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10
    

def get_noise_pred_true(backbone_dict, noised_dict, graph_maker, graph_unet):
    
    CA_t  = bb_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + bb_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_t = CA_t + bb_dict['C_CA'].reshape(B, L, 3).to('cuda')
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    CA_n  = noised_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_n = CA_n + noised_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_n = CA_n + noised_dict['C_CA'].reshape(B, L, 3).to('cuda')
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)

    x = graph_maker.prep_for_network(noised_dict)
    out = graph_unet(x)
    
    CA_p = out['1'][:,0,:].reshape(B, L, 3)+noised_dict['CA'].reshape(B, L, 3).to('cuda') #translation of Calpha

    pred = torch.cat((NC_t,CA_p,CC_t),dim=2).reshape(B,L,3,3)
    
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10, pred.detach().to('cpu').numpy()*10

def dump_tnp(true, noise, pred,e=0, numOut=1,outdir='output/'):
    
    if numOut>true.shape[0]:
        numOut = true.shape[0]
    
    for x in range(numOut):
        dump_coord_pdb(true[x], fileOut=f'{outdir}/true_{e}_{x}.pdb')
        dump_coord_pdb(noise[x], fileOut=f'{outdir}/noise_{e}_{x}.pdb')
        dump_coord_pdb(pred[x], fileOut=f'{outdir}/pred_{e}_{x}.pdb')
        
def test_model(bb_dict, noised_bb,epoch, numOut=1):
    true, noise, pred = get_noise_pred_true(bb_dict, noised_bb, gm, gu)
    dump_tnp(true,noise,pred, e=epoch,numOut=numOut)

def train_step(backbone_dict, noised_dict, graph_maker, graph_unet):
    
    CA_t  = bb_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + bb_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_t = CA_t + bb_dict['C_CA'].reshape(B, L, 3).to('cuda')
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    x = graph_maker.prep_for_network(noised_dict)
    out = graph_unet(x)
    
    CA_p = out['1'][:,0,:].reshape(B, L, 3)+noised_dict['CA'].reshape(B, L, 3).to('cuda')
    #add rot code here
    
    pred = torch.cat((NC_t,CA_p,CC_t),dim=2).reshape(B,L,3,3)
    
    tloss, loss = FAPE_loss(pred.unsqueeze(0), true)

    opti.zero_grad()
    tloss.backward()
    opti.step()
    
    return tloss.detach()

In [18]:
B = 48
L=65
t = 0.025
h4 = Helix4_Dataset(coords_tog[:2048])
h4_dL = DataLoader(h4, batch_size=B, shuffle=True, drop_last=True)

In [20]:
gu = GraphUNet(batch_size = B).to('cuda')
opti = torch.optim.Adam(gu.parameters(), lr=0.005, weight_decay=5e-7)
gm = Make_KNN_MP_Graphs()

In [21]:
t=0.005
fdn= FrameDiffNoise()
for e in range(300):
    lo_sum = 0 
    counter = 1
    for bb_dict in h4_dL:
        noised_bb = fdn(bb_dict,t)
        lo = train_step(bb_dict, noised_bb, gm, gu).cpu()
        lo_sum += lo
        counter += 1
        
    print(lo_sum/counter)

    if e %5==0:
        test_model(bb_dict,noised_bb ,e,numOut=5)


  assert input.numel() == input.storage().size(), (


tensor(0.0414)
tensor(0.0334)
tensor(0.0328)
tensor(0.0322)
tensor(0.0311)
tensor(0.0269)
tensor(0.0245)
tensor(0.0254)
tensor(0.0233)
tensor(0.0231)
tensor(0.0226)
tensor(0.0226)
tensor(0.0222)
tensor(0.0219)
tensor(0.0219)
tensor(0.0226)
tensor(0.0219)
tensor(0.0221)
tensor(0.0220)
tensor(0.0218)
tensor(0.0214)
tensor(0.0212)
tensor(0.0213)
tensor(0.0211)
tensor(0.0207)
tensor(0.0205)
tensor(0.0206)
tensor(0.0203)
tensor(0.0200)
tensor(0.0198)
tensor(0.0196)
tensor(0.0195)
tensor(0.0194)
tensor(0.0193)
tensor(0.0192)
tensor(0.0191)
tensor(0.0191)
tensor(0.0188)
tensor(0.0188)
tensor(0.0188)
tensor(0.0185)
tensor(0.0185)
tensor(0.0186)
tensor(0.0184)
tensor(0.0183)
tensor(0.0183)
tensor(0.0183)
tensor(0.0182)
tensor(0.0183)
tensor(0.0180)
tensor(0.0180)
tensor(0.0180)
tensor(0.0179)
tensor(0.0181)
tensor(0.0181)
tensor(0.0180)
tensor(0.0181)
tensor(0.0183)
tensor(0.0179)
tensor(0.0179)
tensor(0.0179)
tensor(0.0178)
tensor(0.0177)
tensor(0.0178)
tensor(0.0176)
tensor(0.0177)
tensor(0.0

KeyboardInterrupt: 

In [102]:
# for bb_dict in h4_dL:
#     noised_bb = fdn(bb_dict,t)
#     break
# true, noise = noise_test(bb_dict, noised_bb)
# x=0
# e='test'
# outdir='output/'
# dump_coord_pdb(true[x], fileOut=f'{outdir}/true_{e}_{x}.pdb')
# # dump_coord_pdb(noise[x], fileOut=f'{outdir}/noise_{e}_{x}.pdb')
# test_model(bb_dict, noised_bb,'test')

In [40]:
for bb_dict in h4_dL:
    noised_bb = fdn(bb_dict,t)
    break
    
test_model(bb_dict,noised_bb ,'testP',numOut=10)
    

NameError: name 'normQ' is not defined

In [None]:
torch.tensor()

In [None]:
from util.npose_util import makePointPDB
#gds = Graph_RadiusMP_4H_Dataset(coords_tog[:100], 10, mp_stride = 3)
def view_mp_graph(mps: DGLGraph, coords: np.array ):
    p = mps.ndata['pos']*10
    
    to = np.concatenate((coords, np.ones_like(coords)[:,:,0][...,None]),axis=2)
    
    makePointPDB(p,'test.pdb',outDirec='output')
    nu.dump_npdb(to,'output/test2.pdb')
#view_mp_graph(gds.mpSelfGraphList[0], coords_tog[0])