In [1]:
import torch
import numpy as np
import util.npose_util as nu
import os
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
torch.cuda.is_available()

True

In [2]:
from data_rigid_diffuser import so3_diffuser
from data_rigid_diffuser import r3_diffuser
from scipy.spatial.transform import Rotation
from data_rigid_diffuser import rigid_utils as ru
import yaml

In [3]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential, SE3Transformer
from se3_transformer.model.transformer_topk import SE3Transformer_topK
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [4]:
# npose indexing: module-level atom-slot constants used to index per-residue coordinate arrays.
# Useful numbers (reference atom positions; presumably an ideal residue frame with CA at the origin — TODO confirm)
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]

# Bond-length constants divided by 10 (same /10 scaling the dataset applies to coordinates).
# NOTE(review): both tensors are moved to 'cuda' unconditionally, so this cell needs a GPU.
N_CA_dist = torch.tensor(1.458/10.0).to('cuda')
C_CA_dist = torch.tensor(1.523/10.0).to('cuda')

# The atom layout can be overridden by attaching ATOM_NAMES / PDB_ORDER attributes to the
# `os` module before this cell runs; otherwise the five-atom defaults below are used.
if ( hasattr(os, 'ATOM_NAMES') ):
    assert( hasattr(os, 'PDB_ORDER') )

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES=['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

# Padded 4-character names and byte-encoded names, one entry per atom in ATOM_NAMES order.
_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    # Creates a module-level variable named after the atom (N, CA, CB, C, O with the
    # defaults) whose value is the atom's slot index inside a residue.
    globals()[atom_name] = i

# Number of atom slots per residue.
R = len(ATOM_NAMES)

# Fallback index of -1 for atoms missing from a custom ATOM_NAMES list.
if ( "N" not in globals() ):
    N = -1
if ( "C" not in globals() ):
    C = -1
if ( "CB" not in globals() ):
    CB = -1


# For every atom in PDB_ORDER, its slot index in ATOM_NAMES.
_pdb_order = []
for name in PDB_ORDER:
    _pdb_order.append( ATOM_NAMES.index(name) )

In [5]:
# Earlier CA-only loading path, kept for reference:
# data_path_str  = 'data/h4_ca_coords.npz'
# test_limit = 1028
# rr = np.load(data_path_str)
# ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
# ca_coords.shape

# Load backbone coordinates; the N-CA and C-CA vectors derived from them are later added as type-1 features.
# apa = "apart" helices for test/train split
# tog = "together" helices for test/train split
apa_path_str  = 'data_npose/h4_apa_coords.npz'
tog_path_str  = 'data_npose/h4_tog_coords.npz'

# Keep only the first `test_limit` entries along the leading axis of the first array in
# each archive. NOTE(review): the slice does not select atoms; all remaining axes are kept whole.
test_limit = 1028
rr = np.load(apa_path_str)
coords_apa = [rr[f] for f in rr.files][0][:test_limit,:]

# `rr` is rebound here, so the first archive handle is no longer referenced afterwards.
rr = np.load(tog_path_str)
coords_tog = [rr[f] for f in rr.files][0][:test_limit,:]

In [6]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate O and CB atoms.

    Parameters
    ----------
    coords_in : np.ndarray
        Backbone coordinates indexed as ``coords_in[:, 0]`` (N),
        ``coords_in[:, 1]`` (CA) and ``coords_in[:, 2]`` (C); only the first
        three entries of the last axis are read.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_residues * 5, 4)`` initialised to ones, with the
        N/CA/C slots filled from ``coords_in`` and the O/CB slots filled from
        ``nu.build_O`` / ``nu.build_CB``.
    """
    n_residues = coords_in.shape[0]
    npose = np.ones((n_residues * 5, 4))  # 5 atom slots per residue

    # `by_res` is a reshaped view, so every write below lands in `npose` too.
    by_res = npose.reshape(-1, 5, 4)

    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    # O and CB are derived from the backbone atoms written above, so they
    # must be built after N/CA/C are in place.
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    """Build an npose from backbone coordinates and write it with ``nu.dump_npdb``."""
    nu.dump_npdb(build_npose_from_coords(coords_in), fileOut)

In [7]:
# Goal: label every graph edge as either
#   1 = connected backbone (consecutive nodes), or
#   0 = unconnected atoms.


def get_midpoint(ep_in):
    """Return the mean point of each batch entry (average over axis 1)."""
    n_points, n_coords = ep_in.shape[1], ep_in.shape[2]

    # Sum the points, then divide every coordinate column by the point count.
    totals = ep_in.sum(axis=1)
    divisor = np.full(n_coords, n_points)

    return totals / divisor


def normalize_points(input_xyz, print_dist=False):
    """Centre each point set on its midpoint and scale by the largest pairwise distance.

    The scale factor is the single largest point-to-point distance found
    anywhere in ``input_xyz`` (one value for the whole batch). When
    ``print_dist`` is true that distance is printed.
    """
    # Broadcast [..., M, 1, 3] against [..., 1, M, 3] to get all pairwise offsets.
    pairwise_delta = input_xyz[..., None, :] - input_xyz[..., None, :, :]
    xyz_axis = pairwise_delta.ndim - 1
    pairwise_dist = np.sqrt(np.square(pairwise_delta).sum(axis=xyz_axis))
    max_dist = np.max(pairwise_dist)

    centre = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {max_dist:0.1f}')

    return (input_xyz - centre[:, None, :]) / max_dist

def define_graph_edges(n_nodes):
    """Enumerate the edges of a fully connected directed graph without self loops.

    Returns ``(v1, v2, edge_data, ind)`` where ``v1`` / ``v2`` are the source
    and destination node ids of every edge (row-major, ``n_nodes - 1`` edges
    per source), ``edge_data`` is a float tensor holding 1 on the forward
    backbone edges ``i -> i + 1`` and 0 elsewhere, and ``ind`` holds the flat
    positions of those backbone edges.
    """
    node_ids = np.arange(n_nodes)

    # In row i the self connection is dropped, so destination i + 1 sits at
    # column i; its flat position is i * (n_nodes - 1) + i = i * n_nodes.
    ind = np.arange(n_nodes - 1) * n_nodes

    # Every source node repeated once per outgoing edge.
    v1 = np.repeat(node_ids, n_nodes - 1)

    # All (row, col) destinations with the diagonal (self connections, which
    # are handled by the SE3 conv channels) masked out, flattened row-major.
    dst_grid = np.tile(node_ids, (n_nodes, 1))
    off_diagonal = ~np.eye(n_nodes, dtype=bool)
    v2 = dst_grid[off_diagonal]

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1

    return v1, v2, edge_data, ind



def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    """Sinusoidal positional encoding, one row per node.

    Each row interleaves sine and cosine values (sin, cos, sin, cos, ...) for
    ``embed_dim / 2`` frequencies ``1 / scale**(2k / embed_dim)``,
    ``k = 1 .. embed_dim / 2``. Returns a ``(n_nodes, embed_dim)`` tensor of
    dtype ``cast_type``. With ``print_out`` set to True the first
    ``int(n_nodes / 12)`` rows are printed rounded to one decimal.
    """
    freq_index = np.arange(1, (embed_dim/2)+1)
    freqs = 1/(scale**(freq_index*2/embed_dim))
    positions = np.arange(n_nodes)

    # (n_nodes, embed_dim/2) grid of position * frequency.
    phase = freqs*positions.reshape((-1,1))
    sin_part = torch.tensor(np.sin(phase))
    cos_part = torch.tensor(np.cos(phase))

    # Stack on a trailing axis so the reshape interleaves sin/cos per frequency.
    interleaved = torch.stack((sin_part, cos_part), dim=2)
    pe = interleaved.reshape(positions.shape[0], embed_dim).type(cast_type)

    if print_out == True:
        n_preview = int(n_nodes/12)
        for row in range(n_preview):
            print(np.round(pe[row],1))

    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
#norm_p = normalize_points(ca_coords,print_dist=True)
# Demo encoding for 65 nodes with 12 channels and scale 10; print_out=True prints the
# first int(65/12) = 5 rows, which is the output shown under this cell.
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [9]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    """Batch ``batch_size`` identical fully connected graphs of ``n_nodes`` nodes.

    Every graph carries the backbone-connection flags as ``edata['con']`` and
    the default positional encoding as ``ndata['pe']``.
    """
    src_nodes, dst_nodes, con_flags, _ = define_graph_edges(n_nodes)
    pos_enc = make_pe_encoding(n_nodes=n_nodes)

    def _single_graph():
        g = dgl.graph((src_nodes, dst_nodes))
        g.edata['con'] = con_flags
        g.ndata['pe'] = pos_enc
        return g

    return dgl.batch([_single_graph() for _ in range(batch_size)])


In [10]:
def torch_normalize(v, eps=1e-6):
    """Scale vectors along the last axis to (almost) unit length.

    ``eps`` is added to every norm before dividing, so zero vectors map to
    zero instead of producing NaN.
    """
    lengths = torch.linalg.vector_norm(v, dim=-1).add(eps)
    return v / lengths.unsqueeze(-1)

def normalize(v):
    """Normalize vectors along the last axis.

    Zero-length vectors are returned unchanged (their norm is treated as 1),
    so no division by zero occurs. Works for a single 1-D vector as well as
    for arrays of vectors.

    Parameters
    ----------
    v : np.ndarray
        Array whose last axis holds the vector components.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``v`` with unit-length vectors.
    """
    # keepdims keeps the norm an array even for 1-D input; without it the
    # norm of a single vector is a NumPy scalar that cannot be patched in place.
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    safe_norm = np.where(norm == 0, 1.0, norm)
    return v / safe_norm

def get_CN_vector(coords_in):
    """Return the unit vectors pointing from CA to N and from CA to C.

    Atoms are picked on the second-to-last axis with the module-level
    ``N`` / ``CA`` / ``C`` slot indices; only the first three coordinates of
    the last axis are used.
    """
    ca_xyz = coords_in[...,CA,:3]
    n_direction = normalize(coords_in[...,N,:3] - ca_xyz)
    c_direction = normalize(coords_in[...,C,:3] - ca_xyz)
    return n_direction, c_direction



# Applies to Python-3 Standard Library
class Struct(object):
    """Expose the keys of a (possibly nested) dict as attributes.

    Nested dicts become nested ``Struct`` objects; tuples, lists, sets and
    frozensets are rebuilt with the same container type and their elements
    wrapped recursively. All other values are stored unchanged.
    """

    def __init__(self, data):
        for key, raw in data.items():
            setattr(self, key, self._wrap(raw))

    def _wrap(self, value):
        """Convert ``value`` so that any dict inside it becomes a ``Struct``."""
        if isinstance(value, dict):
            return Struct(value)
        if isinstance(value, (tuple, list, set, frozenset)):
            container = type(value)
            return container(list(map(self._wrap, value)))
        return value

In [11]:
def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    """Per-edge displacement: destination ``ndata['pos']`` minus source ``ndata['pos']``."""
    positions = graph_in.ndata['pos']
    src_ids, dst_ids = graph_in.edges()
    return positions[dst_ids] - positions[src_ids]

class Graph_RadiusMP_4H_Dataset(Dataset):
    """Dataset that pre-builds three DGL graphs per structure.

    For every structure in ``coordinates`` the constructor stores:

    * ``graphList``       - radius graph over the CA atoms, with ``pe``, ``pos`` and
      ``bb_ori`` node data and a ``con`` edge flag;
    * ``mpGraphList``     - graph whose edges run from CA nodes to "midpoint" nodes
      appended after the CA nodes;
    * ``mpSelfGraphList`` - fully connected graph over the midpoint nodes only.

    ``__getitem__`` returns the three graphs for one structure in a dict with keys
    ``'g'``, ``'mp'`` and ``'mpself'``.
    """
    def __init__(self, coordinates: np.array, radius: float, cast_type=torch.float32, mp_stride = 1,
                channels_start = 32):
        """Build all graphs eagerly.

        coordinates    : indexed below as [structure, residue, atom, xyz]
                         (presumably a NumPy array of npose-ordered atoms - TODO confirm).
        radius         : neighbour cut-off for the CA radius graph, in the same units as
                         ``coordinates``; both are divided by 10 before use.
        cast_type      : dtype used for the stored tensors.
        mp_stride      : a midpoint node is created for every ``mp_stride``-th CA node.
        channels_start : width to which the positional encoding of the midpoint graph is
                         zero-padded.
        """
        #prots,#length_prot in aa, #residues/aa, #xyz per atom
           
        # Scale coordinates and radius down by 10 (original note: "alphaFold reduce by 10").
        coord_div = 10
        
        coordinates = coordinates/coord_div
        self.radius = radius/coord_div
        self.ca_coords = coordinates[:,:,CA,:]
        # CA->N and CA->C offsets; unsqueeze(2) adds an axis so the two can be concatenated later.
        self.N_CA_vec = torch.tensor(coordinates[:,:,N,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        self.C_CA_vec = torch.tensor(coordinates[:,:,C,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
       
        self.graphList = []
        self.mpGraphList = []
        self.mpSelfGraphList = []
    
       
        # NOTE(review): `i` is the structure index here but is reset to 0 and reused as the
        # midpoint counter further down; it is only read as the structure index before that reset.
        for i,c in enumerate(self.ca_coords):
            n_nodes = c.shape[0]
            pe = make_pe_encoding(n_nodes=n_nodes)
   
            caXYZ = torch.tensor(c)
            graph = dgl.radius_graph(caXYZ, self.radius)
            graph.ndata['pe'] = pe
            graph.ndata['pos'] = torch.tensor(c,dtype=cast_type)
            # Two vectors per node: the N-CA offset followed by the C-CA offset.
            graph.ndata['bb_ori'] = torch.cat((self.N_CA_vec[i],self.C_CA_vec[i]),axis=1)
           
            # Define covalent connections: 1 where the two node ids differ by exactly one.
            esrc, edst = graph.edges()
            graph.edata['con'] = (torch.abs(esrc-edst)==1).type(cast_type).reshape((-1,1))
           
           
            # One midpoint row per strided node; edge lists for the midpoint graph start empty.
            mp_list = torch.zeros((len(list(range(0,graph.num_nodes(), mp_stride))),caXYZ.shape[1]))
            new_src = torch.tensor([],dtype=torch.int)
            new_dst = torch.tensor([],dtype=torch.int)
           
            i=0#mp list counter
            for x in range(0,graph.num_nodes(), mp_stride):
                src, dst = graph.in_edges(x) #dst repeats x
                n_tot = torch.cat((torch.tensor(x).unsqueeze(0),src)) #add x to node list
                # Midpoint = mean position of node x and the sources of its incoming edges.
                mp_list[i] = caXYZ[n_tot].sum(axis=0)/n_tot.shape[0]
                mp_node = i + graph.num_nodes() #add midpoints nodes at end of graph
                #define edges between midpoint nodes and nodes defining midpoint for midpointGraph
                new_src = torch.cat((new_src,n_tot))
                new_dst = torch.cat((new_dst,
                                     (torch.tensor(mp_node).unsqueeze(0).repeat(n_tot.shape[0]))))
                i+=1
            
            
            # Midpoint graph: CA nodes first, midpoint nodes appended after them.
            mpGraph = dgl.graph((new_src,new_dst))
            mpGraph.ndata['pos'] = torch.cat((caXYZ,mp_list),axis=0).type(cast_type)
            mp_node_indx = torch.arange(0,graph.num_nodes(), mp_stride).type(torch.int)
            # Zero-pad the encoding to channels_start columns (original note: match the output
            # shape of the first transformer); midpoint nodes reuse the row of their seed node.
            pe_mp = torch.cat((pe,torch.zeros((pe.shape[0],channels_start-pe.shape[1]))),axis=1)
            mpGraph.ndata['pe'] = torch.cat((pe_mp,pe_mp[mp_node_indx]))
            mpGraph.edata['con'] = torch.zeros((mpGraph.num_edges(),1))
            
            #make graph for self interaction of midpoints
            v1,v2,edge_data, ind = define_graph_edges(len(mp_list))
            mpSelfGraph = dgl.graph((v1,v2))
            mpSelfGraph.edata['con'] = edge_data.reshape((-1,1))
            mpSelfGraph.ndata['pe'] = pe[mp_node_indx] #not really needed
            mpSelfGraph.ndata['pos'] = mp_list.type(cast_type)
            
            self.mpSelfGraphList.append(mpSelfGraph) 
            self.mpGraphList.append(mpGraph)
            self.graphList.append(graph)
       


    def __len__(self):
        """Number of structures (one entry per CA radius graph)."""
        return len(self.graphList)

    def __getitem__(self, idx):
        """Return the three graphs of structure ``idx`` keyed by 'g', 'mp' and 'mpself'."""
        return {'g':self.graphList[idx], 'mp':self.mpGraphList[idx], 'mpself':self.mpSelfGraphList[idx]}


#needs to be done
class H4_DataModule():
    """
    Batching wrapper around ``Graph_RadiusMP_4H_Dataset``.

    Builds the dataset from ``coords`` and exposes ``self.gds``, a shuffling
    ``DataLoader`` (``drop_last=True``) whose batches are produced by ``_collate``.
    """
    # Feature widths used when slicing node / edge data in ``_collate``.
    NODE_FEATURE_DIM_0 = 12  # columns of 'pe' used as type-0 node features
    EDGE_FEATURE_DIM = 1     # columns of 'con' used as type-0 edge features
    NODE_FEATURE_DIM_1 = 2   # rows of 'bb_ori' used as type-1 node features

    def __init__(self, coords: np.array, radius=5, batch_size=8, mp_stride=1, cuda=True, ndf1=6, ndf0=32):
        self.cuda = cuda
        # Feature widths attached to the midpoint graph's nodes.
        self.ndf1 = ndf1
        self.ndf0 = ndf0

        self.GraphDatasetObj = Graph_RadiusMP_4H_Dataset(coords, radius, mp_stride=mp_stride)
        self.gds = DataLoader(self.GraphDatasetObj,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=self._collate)

    def _collate(self, graphs_in):
        """Batch the three graphs of every sample and build their feature dicts.

        With ``self.cuda`` set, returns a flat 9-tuple
        ``(g, nf, ef, g_mp, nf_mp, ef_mp, g_mps, nf_mps, ef_mps)`` moved through
        ``to_cuda``. Otherwise returns three ``(graph, node_feats, edge_feats)``
        triples nested in one tuple - note that the two layouts differ.
        """
        def _batch(key):
            return dgl.batch([sample[key] for sample in graphs_in])

        def _edge_feats(g):
            return {'0': g.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}

        full_g, mp_g, mpself_g = _batch('g'), _batch('mp'), _batch('mpself')

        full_ef = _edge_feats(full_g)
        mp_ef = _edge_feats(mp_g)  # all zeros for now
        mpself_ef = _edge_feats(mpself_g)

        for g in (full_g, mp_g, mpself_g):
            g.edata['rel_pos'] = _get_relative_pos(g)

        full_nf = {'0': full_g.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None],
                   '1': full_g.ndata['bb_ori'][:, :self.NODE_FEATURE_DIM_1, :3]}
        mp_nf = {'0': mp_g.ndata['pe'][:, :self.ndf0, None],
                 '1': torch.ones((mp_g.num_nodes(), self.ndf1, 3))}
        # currently unused downstream
        mpself_nf = {'0': mpself_g.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None]}

        if not self.cuda:
            return ((full_g, full_nf, full_ef),
                    (mp_g, mp_nf, mp_ef),
                    (mpself_g, mpself_nf, mpself_ef))

        return (to_cuda(full_g), to_cuda(full_nf), to_cuda(full_ef),
                to_cuda(mp_g), to_cuda(mp_nf), to_cuda(mp_ef),
                to_cuda(mpself_g), to_cuda(mpself_nf), to_cuda(mpself_ef))
            
               
    

def get_edge_features(graph,edge_feature_dim=1):
    """Return the first ``edge_feature_dim`` columns of ``edata['con']`` as type-0 edge features.

    A singleton axis is inserted after the channel axis and the result is
    wrapped in a dict under the key ``'0'``.
    """
    con = graph.edata['con']
    return {'0': con[:, :edge_feature_dim].unsqueeze(2)}

def define_poolGraph(n_nodes, batch_size, cast_type=torch.float32, cuda_out=True ):
    """Batch ``batch_size`` fully connected graphs with zeroed node positions.

    Each graph has the backbone-connection flag in ``edata['con']`` (cast to
    ``cast_type``, shape ``(E, 1)``) and an all-zero ``ndata['pos']``. The
    batched graph is passed through ``to_cuda`` when ``cuda_out`` is true.
    """
    src_nodes, dst_nodes, con_flags, _ = define_graph_edges(n_nodes)

    graphs = []
    for _ in range(batch_size):
        g = dgl.graph((src_nodes, dst_nodes))
        g.edata['con'] = con_flags.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)
        graphs.append(g)

    batched = dgl.batch(graphs)
    return to_cuda(batched) if cuda_out else batched
        

In [12]:
def pull_edge_features(graph, edge_feat_dim=1):
    """Slice ``edata['con']`` to ``edge_feat_dim`` columns and return it as ``{'0': ...}``.

    The returned tensor gets a trailing singleton axis after the channel axis.
    """
    con = graph.edata['con']
    feats = {}
    feats['0'] = con[:, 0:edge_feat_dim, None]
    return feats

def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    """Gather pooled node positions and derive the inputs a graph-conv layer needs.

    Rows ``idx`` of ``xyz_pos`` become the node positions of ``graph``; from
    the per-edge offsets (destination minus source) the basis is computed and
    the edge features populated. Returns ``(edge_feats, basis, new_pos)``.
    """
    edge_src, edge_dst = graph.edges()

    pooled_pos = F.gather_row(xyz_pos, idx)
    edge_offset = F.gather_row(pooled_pos, edge_dst) - F.gather_row(pooled_pos, edge_src)

    raw_basis = get_basis(edge_offset,
                          max_degree=max_degree,
                          compute_gradients=comp_grad,
                          use_pad_trick=False)
    fused_basis = update_basis_with_fused(raw_basis, max_degree,
                                          use_pad_trick=False,
                                          fully_fused=False)
    populated = get_populated_edge_features(edge_offset, edge_feats_input)
    return populated, fused_basis, pooled_pos
    
class Unpool(torch.nn.Module):
    """
    Scatter pooled features back into a zero tensor sized for the full graph,
    then add the matching skip features.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        def _place(val, skip):
            # Zero canvas with one row per graph node and a single trailing component.
            canvas = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            return torch.add(F.scatter_row(canvas, idx, val), skip)

        return {key: _place(val, u_features[key]) for key, val in features.items()}
    
class Latent_Unpool(torch.nn.Module):
    """
    Repeat each latent row across the nodes of the graph and add the skip features.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, u_features: Dict[str, Tensor]):
        def _spread(val):
            # Number of consecutive nodes that share one latent row.
            copies = int(graph.num_nodes()/val.shape[0])
            return val.repeat_interleave(copies, 0)

        return {key: torch.add(_spread(val), u_features[key]) for key, val in features.items()}
    
    
    

In [13]:
class FrameDiffNoise(nn.Module):
    """Generate diffusion noise based on FrameDiff.

    The diffuser settings are read from the ``diffuser`` section of the YAML
    file at ``config_path``; its ``so3`` and ``r3`` sub-sections configure the
    rotation and translation diffusers respectively.
    """

    def __init__(self, config_path='data_rigid_diffuser/base.yaml'):
        super().__init__()
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
        conf = Struct(config['diffuser'])

        self.so3d = so3_diffuser.SO3Diffuser(conf.so3)
        self.r3d =  r3_diffuser.R3Diffuser(conf.r3)

    def forward(self,feats_in, batched_graph, t):
        """Noise the type-1 features and the node positions for time ``t``.

        Parameters
        ----------
        feats_in : dict
            Feature dict; ``feats_in['1']`` holds 3-vectors on its last axis.
            It is overwritten in place with the rotated vectors (float32).
        batched_graph : DGLGraph
            Graph whose ``ndata['pos']`` is overwritten in place with the
            noised positions (float32).
        t : diffusion time passed to both diffusers.

        Returns
        -------
        (batched_graph, feats_in) - the same objects, modified in place.

        Tensors are kept on whatever device the inputs already live on rather
        than being forced onto ``'cuda'``.
        """
        vectors = feats_in['1']

        # One sampled rotation per 3-vector.
        n_samples = int(np.prod(vectors.shape[:-1]))
        rotvec = self.so3d.sample(t, n_samples=n_samples)
        rotmat = Rotation.from_rotvec(rotvec).as_matrix()

        # Apply the rotations on the same device as the features.
        rot = torch.tensor(rotmat).to(vectors.device)
        rotated = ru.rot_vec_mul(rot, vectors.reshape((-1,3))).reshape(vectors.shape)
        feats_in['1'] = rotated.type(torch.float32)

        # The translation diffuser is fed a NumPy copy; write the result back
        # to the device the positions came from.
        pos = batched_graph.ndata['pos']
        x_t, _  = self.r3d.forward_marginal(pos.clone().to('cpu').numpy(), t)
        batched_graph.ndata['pos'] = torch.tensor(x_t, dtype=torch.float32).to(pos.device)

        return batched_graph, feats_in

# More complicated version splits error in CA-N and CA-C (giving more accurate CB position)
# It returns the rigid transformation from local frame to global frame


def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    """Build a per-residue rotation frame from N, CA and C positions.

    N, Ca, C : [B, L, 3] tensors.
    Returns ``(R, Ca)`` with R of shape [B, L, 3, 3]; the columns of R are
    e1 (unit CA->C), e2 (unit CA->N component orthogonal to e1) and
    e3 = e1 x e2. With ``non_ideal`` set, R is additionally multiplied by an
    in-plane rotation computed from the difference between the current
    N-CA-C angle and ``cos_ideal_NCAC``.
    """
    B, L = N.shape[:2]

    def _unit(vec):
        return vec/(torch.norm(vec, dim=-1, keepdim=True)+eps)

    ca_to_c = C-Ca
    ca_to_n = N-Ca

    # Gram-Schmidt: e1 along CA->C, e2 from the remainder of CA->N.
    e1 = _unit(ca_to_c)
    along_e1 = torch.einsum('bli, bli -> bl', e1, ca_to_n)[...,None]
    e2 = _unit(ca_to_n-(along_e1*e1))
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack((e1, e2, e3), dim=-1) #[B,L,3,3] - rotation matrix

    if non_ideal:
        n_dir = _unit(ca_to_n)
        cosref = torch.clamp( torch.sum(e1*n_dir, dim=-1), min=-1.0, max=1.0) # cosine of current N-CA-C bond angle
        costgt = cos_ideal_NCAC.item()
        cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
        cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
        sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)

        # Rotation about the third axis by the half-angle correction.
        Rp = torch.eye(3, device=N.device).repeat(B,L,1,1)
        Rp[:,:,0,0] = cosdel
        Rp[:,:,0,1] = -sindel
        Rp[:,:,1,0] = sindel
        Rp[:,:,1,1] = cosdel

        R = torch.einsum('blij,bljk->blik', R,Rp)

    return R, Ca

def get_t(N, Ca, C, non_ideal=False, eps=1e-5):
    """Express every residue-to-residue CA offset in the frame of the first residue.

    N, Ca, C : tensors whose first three axes are (I, B, L), with xyz last.
    Returns an (I, B, L, L, 3) tensor: entry [l, m] is the offset from
    residue l to residue m contracted with the rotation frame of residue l.
    """
    I, B, L = N.shape[:3]
    flat_shape = (I*B, L, 3)

    frames, origins = rigid_from_3_points(N.view(flat_shape), Ca.view(flat_shape), C.view(flat_shape),
                                          non_ideal=non_ideal, eps=eps)
    frames = frames.view(I,B,L,3,3)
    origins = origins.view(I,B,L,3)

    # offsets[..., l, m, :] = origin[m] - origin[l]
    offsets = origins[:,:,None] - origins[:,:,:,None]
    return einsum('iblkj, iblmk -> iblmj', frames, offsets) # (I,B,L,L,3)

def FAPE_loss(pred, true,  d_clamp=10.0, d_clamp_inter=30.0, A=10.0, gamma=1.0, eps=1e-6):
    '''
    Calculate Backbone FAPE loss from RoseTTAFold
    https://github.com/uw-ipd/RoseTTAFold2/blob/main/network/loss.py
    Input:
        - pred: predicted coordinates (I, B, L, n_atom, 3)
        - true: true coordinates (B, L, n_atom, 3)
        - d_clamp: upper bound applied to every frame-aligned distance
        - d_clamp_inter: accepted for signature compatibility; not used here
        - A: length scale the clamped distances are divided by
        - gamma: per-iteration weight base; the last iteration gets weight
          gamma**0 before normalisation
        - eps: numerical stabiliser
    Output:
        - tot_loss: scalar, gamma-weighted sum of the per-iteration losses
        - loss: detached per-iteration losses, shape (I,)
    '''
    I = pred.shape[0]
    true = true.unsqueeze(0)
    # Atoms 0, 1, 2 on the n_atom axis are passed as N, CA, C.
    t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
    t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])

    difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps) # (I, B, L, L)

    # A single clamp is applied to all pairs (no intra/inter distinction).
    difference = torch.clamp(difference, max=d_clamp)
    loss = difference / A # (I, B, L, L)

    # Mean over every (batch, residue, residue) term of each iteration, so the
    # result really has shape (I,) and lines up with the weights below.
    n_terms = loss[0].numel()
    loss = loss.sum(dim=(1, 2, 3)) / (n_terms + eps) # (I)

    # weighting loss
    w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
    w_loss = torch.flip(w_loss, (0,))
    w_loss = w_loss / w_loss.sum()

    tot_loss = (w_loss * loss).sum()

    return tot_loss, loss.detach()

In [17]:
class Latent_Unpool(torch.nn.Module):
    """
    Duplicate the latent onto the graph nodes and add the upper (skip) features
    from the U-net.
    """

    def __init__(self, fiber_in: Fiber, fiber_add: Fiber, knodes: int):
        super().__init__()
        self.fiber_out = Fiber.combine_max(fiber_in, fiber_add)
        self.node_repeat = knodes

    def forward(self, features: Dict[str, Tensor], u_features: Dict[str, Tensor]):
        out_feats = {}
        for degree_out, channels_out in self.fiber_out:
            cd = str(degree_out)
            upper = u_features[cd]

            if cd not in features:
                # Only the upper level carries this feature type: pass it through.
                out_feats[cd] = upper
                continue

            latent = features[cd]
            # Repeat each latent row for all of its nodes and add a trailing axis.
            tiled = latent.repeat_interleave(self.node_repeat,0)[...,None]

            # Zero-pad the upper features on the channel axis up to the latent
            # width, so they occupy the leading channels of the sum.
            pad_shape = (upper.shape[0], latent.shape[1]-upper.shape[1]) + tuple(upper.shape[2:])
            padding = torch.zeros(pad_shape, device = upper.device)
            out_feats[cd] = torch.add(tiled, torch.cat((upper, padding), dim=1))

        return out_feats
    
class Unpool_Layer(torch.nn.Module):
    """
    Scatter lower-level node features into a zero tensor at rows ``idx`` and add
    the upper-level features, zero-padded on the channel axis.
    Assumes the lower features have at least as many channels as the upper ones.
    """

    def __init__(self, fiber_in: Fiber, fiber_add: Fiber):
        super().__init__()
        self.fiber_in = fiber_in
        self.fiber_add = fiber_add
        self.fiber_out = Fiber.combine_max(fiber_in, fiber_add)

    def forward(self, features: Dict[str, Tensor], u_features: Dict[str, Tensor], idx : Tensor):
        out_feats = {}
        for degree_out, channels_out in self.fiber_out:
            cd = str(degree_out)
            upper, lower = u_features[cd], features[cd]

            # Target size: upper node count, lower channel count, upper component count.
            n_nodes, n_channels, n_comp = upper.shape[0], lower.shape[1], upper.shape[2]

            canvas = upper.new_zeros([n_nodes, n_channels, n_comp])
            channel_pad = lower.new_zeros([n_nodes, n_channels-upper.shape[1], n_comp])

            scattered = F.scatter_row(canvas, idx, lower)
            out_feats[cd] = torch.add(scattered, torch.cat((upper, channel_pad), 1))

        return out_feats
    
def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    """Select the pooled positions and compute conv inputs for ``graph``.

    ``xyz_pos`` rows are gathered by ``idx`` to give the node positions; edge
    offsets (destination minus source) feed the basis computation and the
    edge-feature population. Returns ``(edge_feats, basis, new_pos)``.
    """
    sources, targets = graph.edges()

    node_pos = F.gather_row(xyz_pos, idx)
    offsets = F.gather_row(node_pos, targets) - F.gather_row(node_pos, sources)

    basis = update_basis_with_fused(
        get_basis(offsets, max_degree=max_degree, compute_gradients=comp_grad, use_pad_trick=False),
        max_degree,
        use_pad_trick=False,
        fully_fused=False,
    )
    edge_feats = get_populated_edge_features(offsets, edge_feats_input)
    return edge_feats, basis, node_pos

In [15]:
#normalize outvec
# Driver settings for a single smoke-test batch.
# NOTE(review): `L` and `B` are rebound here as plain module-level ints.
B=8
L=65
edge_feature_dim = 1
stride = 4

# Build the data module on the first 100 "together" structures and pull exactly one batch.
# With cuda=True the collate function yields a flat 9-tuple, which is what is unpacked below.
dm = H4_DataModule(coords_tog[:100],batch_size=B,radius=15, mp_stride = stride, cuda=True)
for x in dm.gds:
    #batched_graphs, node_feats, edge_feats for full, midpoint gather , and midpoints self graphs
    bg, nf,ef, bg_mp, nf_mp, ef_mp, bg_mps, nf_mps, ef_mps = x
    break

In [None]:
class GraphUNet(torch.nn.Module):
    """Down-sampling half of a graph U-net built from SE3-Transformer blocks.

    The constructor wires four stages, each stored as a submodule:

    * ``d_ca``     - SE3Transformer over the C-alpha radius graph;
    * ``d_ca2mp``  - SE3Transformer pooling C-alpha features onto midpoints;
    * ``mp_topk``  - SE3Transformer_topK over the midpoints;
    * ``down_gcn`` - ConvSE3 applied after the top-k pooling.

    No ``forward`` is defined in this cell yet.

    Parameters
    ----------
    fiber_start : Fiber describing the input node features.
    fiber_out : accepted but not used by the layers built here.
    k : top-k value passed to ``SE3Transformer_topK``.
    batch_size : stored as ``self.B``.
    stride : midpoint stride; ``int(stride / 2)`` is the channel multiplier.
    max_degree : maximum feature degree of the hidden fibers.
    channels, num_heads, channels_div, num_layers : transformer hyper-parameters.
    edge_feature_dim : number of type-0 edge channels.
    cuda : selects ``self.device`` (``'cuda:0'`` or ``'cpu'``).
    """
    def __init__(self, 
                 fiber_start = Fiber({0:12, 1:2}),
                 fiber_out = Fiber({1:3}),
                 k=4,
                 batch_size = 8,
                 stride=4,
                 max_degree=3,
                 channels=32,
                 num_heads = 8,
                 channels_div=4,
                 num_layers = 1,
                 edge_feature_dim=1,
                 cuda=True):
        super(GraphUNet, self).__init__()

        self.comp_basis_grad = True
        # Stored as `use_cuda`: assigning a bool to `self.cuda` would shadow
        # nn.Module.cuda() and make `model.cuda()` uncallable.
        self.use_cuda = cuda
        self.device = 'cuda:0' if cuda else 'cpu'

        self.max_degree=max_degree
        self.B = batch_size
        self.k = k

        # Honour the constructor arguments (their defaults equal the values
        # that used to be hard-coded here).
        self.num_layers = num_layers
        self.channels = channels
        self.feat0 = 32
        self.feat1 = 6
        self.channels_div = channels_div
        self.num_heads = num_heads
        self.mult = int(stride/2)
        # Kept on self because every layer below reads self.fiber_edge.
        self.fiber_edge = Fiber({0:edge_feature_dim})

        self.channels_dca = channels
        #down c_alpha interactions by radius
        self.fiber_start =  fiber_start
        self.fiber_hidden_dca = Fiber.create(self.max_degree, self.channels_dca)
        self.fiber_out_dca =Fiber({0: self.feat0, 1: self.feat1})

        self.d_ca = SE3Transformer(num_layers = self.num_layers,
                        fiber_in=self.fiber_start,
                        fiber_hidden= self.fiber_hidden_dca, 
                        fiber_out=self.fiber_out_dca,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)

        self.channels_dca2mp = self.channels_dca*self.mult
        # Alias for the spelling this attribute was originally assigned under.
        self.channels_d_ca2mp = self.channels_dca2mp

        #pool from c_alpha onto midpoints
        self.fiber_in_dca2mp     = self.fiber_out_dca
        self.fiber_hidden_dca2mp = Fiber.create(max_degree, self.channels_dca2mp)
        self.fiber_out_dca2mp    = Fiber({0: self.feat0*self.mult, 1: self.feat1*self.mult})

        self.d_ca2mp = SE3Transformer(num_layers = self.num_layers,
                            fiber_in     = self.fiber_in_dca2mp,
                            fiber_hidden = self.fiber_hidden_dca2mp, 
                            fiber_out    = self.fiber_out_dca2mp,
                            num_heads =    self.num_heads,
                            channels_div = self.channels_div,
                            fiber_edge   = self.fiber_edge,
                            low_memory=True,
                            tensor_cores=False)

        self.fiber_in_mptopk =  self.fiber_out_dca2mp
        #hidden    =self.fiber_hidden_dca2mp
        #fiber_out =self.fiber_out_dca2mp
        self.fiber_out_topkpool=Fiber({0: self.feat0*self.mult*self.mult})

        # Assigned to self so the layer's parameters are registered with the module.
        self.mp_topk = SE3Transformer_topK(num_layers      = self.num_layers,
                                        fiber_in      = self.fiber_in_mptopk,
                                        fiber_hidden  = self.fiber_hidden_dca2mp, 
                                        fiber_out     = self.fiber_out_dca2mp ,
                                        fiber_out_topk= self.fiber_out_topkpool,
                                        k             = self.k,
                                        num_heads     = self.num_heads,
                                        channels_div  = self.channels_div,
                                        fiber_edge    =  self.fiber_edge,
                                        low_memory=True,
                                        tensor_cores=False)

        #change to doing convolutions instead of points
        self.fiber_in_d_gcn   =  self.fiber_out_topkpool
        self.fiber_out_d_gcn  = Fiber({0: self.feat0*self.mult*self.mult, 1: self.feat1*self.mult})

        # Assigned to self for the same parameter-registration reason as mp_topk.
        self.down_gcn = ConvSE3(fiber_in  = self.fiber_in_d_gcn,
                           fiber_out = self.fiber_out_d_gcn,
                           fiber_edge= self.fiber_edge,
                             self_interaction=True,
                             use_layer_norm=True,
                             max_degree=self.max_degree,
                             fuse_level= ConvSE3FuseLevel.NONE,
                             low_memory= True)

        
        
        

In [None]:
from util.npose_util import makePointPDB
#gds = Graph_RadiusMP_4H_Dataset(coords_tog[:100], 10, mp_stride = 3)
def view_mp_graph(mps: DGLGraph, coords: np.array ):
    """Write the midpoint positions and a structure to PDB files for inspection.

    ``mps.ndata['pos']`` is multiplied by 10 and written with ``makePointPDB``
    as ``test.pdb`` in the ``output`` directory; ``coords`` gets a column of
    ones appended on axis 2 and is written to ``output/test2.pdb``.
    """
    midpoints = mps.ndata['pos']*10

    # Column of ones with the same leading shape and dtype as coords.
    ones_column = np.ones_like(coords[:,:,0])[...,None]
    padded_coords = np.concatenate((coords, ones_column),axis=2)

    makePointPDB(midpoints,'test.pdb',outDirec='output')
    nu.dump_npdb(padded_coords,'output/test2.pdb')
#view_mp_graph(gds.mpSelfGraphList[0], coords_tog[0])