In [1]:
import torch
import numpy as np
import util.npose_util as nu
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
torch.cuda.is_available()

True

In [2]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [3]:
import os
# Useful numbers
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]

if ( hasattr(os, 'ATOM_NAMES') ):
    assert( hasattr(os, 'PDB_ORDER') )

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES=['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    globals()[atom_name] = i

R = len(ATOM_NAMES)

if ( "N" not in globals() ):
    N = -1
if ( "C" not in globals() ):
    C = -1
if ( "CB" not in globals() ):
    CB = -1


_pdb_order = []
for name in PDB_ORDER:
    _pdb_order.append( ATOM_NAMES.index(name) )

In [4]:
# data_path_str  = 'data/h4_ca_coords.npz'
# test_limit = 1028
# rr = np.load(data_path_str)
# ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
# ca_coords.shape

# getting N-Ca, Ca-C vectors to add as typeI features
#apa = apart helices for test/train split
#tog = together helices for test/train split
apa_path_str  = 'data/h4_apa_coords.npz'
tog_path_str  = 'data/h4_tog_coords.npz'

#grab the first 3 atoms which are N,CA,C
test_limit = 1028
rr = np.load(apa_path_str)
coords_apa = [rr[f] for f in rr.files][0][:test_limit,:]

rr = np.load(tog_path_str)
coords_tog = [rr[f] for f in rr.files][0][:test_limit,:]

In [5]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate O an CB atoms"""
    rot_mat_cat = np.ones(sum((coords_in.shape[:-1], (1,)), ()))
    
    coords = np.concatenate((coords_in,rot_mat_cat),axis=-1)
    
    npose = np.ones((coords_in.shape[0]*5,4)) #5 is atoms per res

    by_res = npose.reshape(-1, 5, 4)
    
    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    
    npose =  build_npose_from_coords(coords_in)
    nu.dump_npdb(npose,fileOut)

In [6]:
#goal define edges of
#connected backbone 1, 
#unconnected atoms 0,


def get_midpoint(ep_in):
    """Get midpoint, of each batched set of points"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ep_in.shape[1], ep_in.shape[2])
    
    return midpoint


def normalize_points(input_xyz, print_dist=False):
    
    #broadcast to distance matrix [Batch, M, R3] to [Batch,M,1, R3] to [Batch,1,M, R3] to [Batch, M,M, R3] 
    vec_diff = input_xyz[...,None,:]-input_xyz[...,None,:,:]
    dist = np.sqrt(np.sum(np.square(vec_diff),axis=len(input_xyz.shape)))
    furthest_dist = np.max(dist)
    centroid  = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')
    
    xyz_mean_zero = input_xyz - centroid[:,None,:]
    return xyz_mean_zero/furthest_dist

def normalize(v):
    norm = np.linalg.norm(v)
    if norm == 0: 
        return v
    return v / norm

def define_graph_edges(n_nodes):
    #connected backbone

    con_v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    con_v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ind = con_v1*(n_nodes-1)+con_v2-1 #account for removed self connections (-1)


    #unconnected backbone

    nodes = np.arange(n_nodes)
    v1 = np.repeat(nodes,n_nodes-1) #starting vertices, same number repeated for each edge

    start_v2 = np.repeat(np.arange(n_nodes)[None,:],n_nodes,axis=0)
    diag_ind = np.diag_indices(n_nodes)
    start_v2[diag_ind] = -1 #diagonal of matrix is self connections which we remove (self connections are managed by SE3 Conv channels)
    v2 = start_v2[start_v2>-0.5] #remove diagonal and flatten

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1
    
    return v1,v2,edge_data, ind

def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(n_nodes)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    
    if print_out == True:
        for x in range(int(n_nodes/12)):
            print(np.round(pe[x],1))
    
    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
#norm_p = normalize_points(ca_coords,print_dist=True)
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [7]:
#v1,v2,edge_data, ind = define_graph_edges(4)

In [8]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    pe = make_pe_encoding(n_nodes=n_nodes)
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph


In [9]:
def define_UGraph(n_nodes, batch_size, cast_type=torch.float32, cuda_out=True ):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    #pe = make_pe_encoding(n_nodes=n_nodes)#pe e
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)
    
    if cuda_out:
        return to_cuda(batched_graph)
    else:
        return batched_graph

def get_edge_features(graph,edge_feature_dim=1):
    return {'0': graph.edata['con'][:, :edge_feature_dim, None]}

class Graph_4H_Dataset(Dataset):
    def __init__(self, coordinates, cast_type=torch.float32):
                                    #prots,#length_prot in aa, #residues/aa, #xyz per atom
            
        #alphaFold reduce by 10
        coordinates = coordinates/10
            
        self.ca_coords = coordinates[:,:,CA,:]
        #unsqueeze to stack together later
        self.N_CA_vec = torch.tensor(coordinates[:,:,N,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        self.C_CA_vec = torch.tensor(coordinates[:,:,C,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        
        #set mean to zero and max_distance between points to 1, is this necessary? since se3 transforms distances
        #self.norm_ca = normalize_points(self.ca_coords)
        
        
        n_nodes = self.ca_coords.shape[1] 
        
        v1,v2,edge_data, ind = define_graph_edges(n_nodes)
        pe = make_pe_encoding(n_nodes=n_nodes)

        graphList = []

        for i,c in enumerate(self.ca_coords):

            g = dgl.graph((v1,v2))
            g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
            g.ndata['pe'] = pe
            g.ndata['pos'] = torch.tensor(c,dtype=cast_type)
            g.ndata['bb_ori'] = torch.cat((self.N_CA_vec[i],self.C_CA_vec[i]),axis=1)
            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos
    
#needs to be done
class H4_DataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM_0 = 12
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop
    NODE_FEATURE_DIM_1 = 2
    

    def __init__(self,
                 coords: np.array, batch_size=8):
        
        self.GraphDatasetObj = Graph_4H_Dataset(coords)
        self.gds = DataLoader(self.GraphDatasetObj, batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None],
                      '1': batched_graph.ndata['bb_ori'][:,:self.NODE_FEATURE_DIM_1, :3]}
        
        return (batched_graph, node_feats, edge_feats)
    
class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value your are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False` then the scale of the noise
            won't be seen as a constant but something to optimize: this will bias the
            network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma
        self.noise = torch.tensor(0,dtype=torch.float)

    def forward(self,x,scale=1.0):
        if self.sigma != 0:
            #without modifer mult, mean=0, std_dev=1
            sampled_noise = (self.noise.repeat(*x.size()).normal_() * self.sigma*scale).to(x.device)
            x = x + sampled_noise
        return x, sampled_noise

In [10]:
def topK_se3(graph, feat, xi, k):
    #remove this read from graph code, since se3 transformer natively uses pulled out feats from graph
    # READOUT_ON_ATTRS = {
    #     "nodes": ("ndata", "batch_num_nodes", "number_of_nodes"),
    #     "edges": ("edata", "batch_num_edges", "number_of_edges"),
    # }
    # _, batch_num_objs_attr, _ = READOUT_ON_ATTRS["nodes"]

    # #this is a fancy way of saying 'batch_num_nodes
    # data = getattr(bg, "nodes")[None].data
    # if F.ndim(data[feat]) > 2:
    #     raise DGLError(
    #         "Only support {} feature `{}` with dimension less than or"
    #         " equal to 2".format(typestr, feat)
    #     )
    # feat = data[feat]


    hidden_size = feat.shape[-1]
    batch_num_objs = getattr(graph, 'batch_num_nodes')(None)
    batch_size = len(batch_num_objs)
    descending = True

    length = max(max(F.asnumpy(batch_num_objs)), k) #max k or batch of nodes size
    fill_val = -float("inf") if descending else float("inf")
    
    feat_y = F.pad_packed_tensor(
        feat, batch_num_objs, fill_val, l_min=k
    )  # (batch_size, l, d)

    order = F.argsort(feat_y, 1, descending=descending)
    topk_indices_unsort_batch = F.slice_axis(order, 1, 0, k)
    #sort to matches original connectivity with define_graph_edges, likely change but probably won't hurt now
    topk_indices, tpk_ind = torch.sort(topk_indices_unsort_batch,dim=1) 

    #get batch shifts
    feat_ = F.reshape(feat_y, (-1,))
    shift = F.repeat(
        F.arange(0, batch_size), k * hidden_size, -1
    ) * length * hidden_size + F.cat(
        [F.arange(0, hidden_size)] * batch_size * k, -1
    )
    
    shift = F.copy_to(shift, F.context(feat))
    topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
    #trainable params gather
    out_y = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size*k, -1))
    out_y = F.replace_inf_with_zero(out_y)
    #nodes features gather
    out_xi = F.reshape(F.gather_row(xi, topk_indices_), (batch_size*k, -1))
    out_xi = F.replace_inf_with_zero(out_xi)
    return out_y, out_xi, topk_indices_


class TopK_Pool(torch.nn.Module):
    """
    https://arxiv.org/pdf/1905.05178.pdf
    Project Node Features to 1D for topK pooling using trainable weights
    #code from Linear Layer SE3 add topK_se3 method
    Only type '0' features coded so far, no interactions between types on linear layers
    """
    
    #in the future can I pool '1' features

    def __init__(self, fiber_in: Fiber, k=5):
        super().__init__()
        self.k = k
        fiber_out = Fiber({0: 1}) #convert to 1D of nodes for topK selection
        self.weights = torch.nn.ParameterDict({
            str(degree_out): torch.nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })
        


    def forward(self, features: Dict[str, Tensor], graph: DGLGraph) -> Dict[str, Tensor]:
        #add topK selection, sigmoid, return nodes
        yi = {
            degree: torch.div(self.weights[degree] @ features[degree], self.weights[degree].norm())
            for degree, weight in self.weights.items()
        }
        y_selected, feats_selected, topk_indices_batched = topK_se3(graph, yi['0'], features['0'], self.k)
        return {'0':(torch.sigmoid(y_selected)*feats_selected).unsqueeze(-1)}, topk_indices_batched
    
class Unpool(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            out_feats[key] = F.scatter_row(new_h,idx,val)
            out_feats[key] = torch.add(out_feats[key],u_features[key])
        return out_feats
    
class Latent_Unpool(torch.nn.Module):
    """
    Duplicate Latent onto Graph
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, u_features: Dict[str, Tensor]):
        out_feats = {}
        for key,val in features.items():
            new_h = val.repeat_interleave(int(graph.num_nodes()/val.shape[0]),0)
            out_feats[key] = torch.add(new_h,u_features[key])
        return out_feats
    
    
    

In [11]:
# More complicated version splits error in CA-N and CA-C (giving more accurate CB position)
# It returns the rigid transformation from local frame to global frame


def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    #N, Ca, C - [B,L, 3]
    #R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix
    B,L = N.shape[:2]
    
    v1 = C-Ca
    v2 = N-Ca
    e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
    u2 = v2-(torch.einsum('bli, bli -> bl', e1, v2)[...,None]*e1)
    e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix
    
    if non_ideal:
        v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
        cosref = torch.clamp( torch.sum(e1*v2, dim=-1), min=-1.0, max=1.0) # cosine of current N-CA-C bond angle
        costgt = cos_ideal_NCAC.item()
        cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
        cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
        sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
        Rp = torch.eye(3, device=N.device).repeat(B,L,1,1)
        Rp[:,:,0,0] = cosdel
        Rp[:,:,0,1] = -sindel
        Rp[:,:,1,0] = sindel
        Rp[:,:,1,1] = cosdel
    
        R = torch.einsum('blij,bljk->blik', R,Rp)

    return R, Ca

def get_t(N, Ca, C, non_ideal=False, eps=1e-5):
    I,B,L=N.shape[:3]
    Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), non_ideal=non_ideal, eps=eps)
    Rs = Rs.view(I,B,L,3,3)
    Ts = Ts.view(I,B,L,3)
    t = Ts[:,:,None] - Ts[:,:,:,None] # t[0,1] = residue 0 -> residue 1 vector
    return einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3)

def FAPE_loss(pred, true,  d_clamp=10.0, d_clamp_inter=30.0, A=10.0, gamma=1.0, eps=1e-6):
    '''
    Calculate Backbone FAPE loss from RosettaTTAFold
    https://github.com/uw-ipd/RoseTTAFold2/blob/main/network/loss.py
    Input:
        - pred: predicted coordinates (I, B, L, n_atom, 3)
        - true: true coordinates (B, L, n_atom, 3)
    Output: str loss
    '''
    I = pred.shape[0]
    true = true.unsqueeze(0)
    t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
    t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])

    difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps)
    eij_label = difference[-1].clone().detach()

    clamp = torch.zeros_like(difference)

    # intra vs inter#me coded
    clamp[:,True] = d_clamp

    difference = torch.clamp(difference, max=clamp)
    loss = difference / A # (I, B, L, L)

    # calculate masked loss (ignore missing regions when calculate loss)
    loss = (loss[:,True]).sum(dim=-1) / (torch.ones_like(loss).sum()+eps) # (I)

    # weighting loss
    w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
    w_loss = torch.flip(w_loss, (0,))
    w_loss = w_loss / w_loss.sum()

    tot_loss = (w_loss * loss).sum()
    
    return tot_loss, loss.detach()

In [12]:
def pull_edge_features(graph, edge_feat_dim=1):
    return {'0': graph.edata['con'][:, :edge_feat_dim, None]}

def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    
    src, dst = graph.edges()
    
    new_pos = F.gather_row(xyz_pos, idx)
    rel_pos = F.gather_row(new_pos,dst) - F.gather_row(new_pos,src) 
    
    basis_out = get_basis(rel_pos, max_degree=max_degree,
                                   compute_gradients=comp_grad,
                                   use_pad_trick=False)
    basis_out = update_basis_with_fused(basis_out, max_degree, use_pad_trick=False,
                                            fully_fused=False)
    edge_feats_out = get_populated_edge_features(rel_pos, edge_feats_input)
    return edge_feats_out, basis_out, new_pos


class GraphUNet(torch.nn.Module):
    def __init__(self, 
                 pred_fiber=Fiber({0: 12, 1:12}),
                 ks = [5],
                 batch_size = 8,
                 in_dim=12,
                 ndf_mult=4,
                 max_degree=3,
                 num_heads = 8,
                 channels_div=2,
                 max_nodes = 65,
                 cuda_out=True,
                 comp_basis_grad=True):
        super(GraphUNet, self).__init__()
        self.edge_feature_dim = 1
        
        self.comp_basis_grad = comp_basis_grad
        self.ks = ks
        
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        
        self.l_n = len(ks)
        
        out_dim = in_dim*ndf_mult
        
        for i in range(self.l_n):
            self.down_gcns.append(AttentionBlockSE3(fiber_in= Fiber({0: in_dim, 1: 2}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True'))
        
            self.pools.append(TopK_Pool(Fiber({0: out_dim}), k=ks[i]))
                                  
            in_dim = out_dim
            out_dim = in_dim*ndf_mult
                                  
        self.bottom_gcn = AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True')
        
        self.global_pool = GPooling(pool='avg', feat_type=0)
        self.latent_unpool = Latent_Unpool()
        
        in_dim = out_dim
        out_dim = out_dim/ndf_mult
                                          
        for i in range(self.l_n):
            self.up_gcns.append(AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True'))
        
            self.unpools.append(Unpool())
            
            in_dim = out_dim
            out_dim = out_dim/ndf_mult
            
        #channels div set at num head here
        self.top_gcn = AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = pred_fiber,
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=4,
                                                     channels_div=2,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True')
        
#         self.pred_gcn = AttentionBlockSE3( fiber_in= Fiber({0: 24}),
#                                                      fiber_out  = pred_fiber ,
#                                                      fiber_edge = Fiber({0: self.edge_feature_dim}),
#                                                      num_heads=16,
#                                                      channels_div=2,
#                                                      use_layer_norm=True,
#                                                      max_degree=max_degree,
#                                                      fuse_level=ConvSE3FuseLevel.NONE,
#                                                      low_memory='True')
#         self.final_conv   =   ConvSE3(fiber_in=pred_fiber,
#                                       fiber_out= Fiber({0:1,1:2}),
#                                       fiber_edge= Fiber({0: self.edge_feature_dim}),
#                                       self_interaction=True,
#                                       use_layer_norm=True,
#                                       max_degree=max_degree,
#                                       low_memory=False)
        
        self.final = LinearSE3(fiber_in=pred_fiber,fiber_out= Fiber({0:1,1:3}))

        
        self.graph_list = [define_UGraph(max_nodes, batch_size, cast_type=torch.float32, cuda_out=cuda_out )]
        self.edge_pre = [pull_edge_features(self.graph_list[-1], edge_feat_dim=1)] #define edge feats here from graph definitions
        for i in range(self.l_n):
            max_nodes = ks[i]
            self.graph_list.append(define_UGraph(max_nodes, batch_size, cast_type=torch.float32, cuda_out=cuda_out))
            self.edge_pre.append(pull_edge_features(self.graph_list[-1], edge_feat_dim=1))
            

            
    def forward(self, node_feats_in, batched_graph):
        
        indices_list = [batched_graph.num_nodes()]
        down_gcn_in = [node_feats_in] #node features from gcn outputs
        down_gcn_out = []
        down_pools = []
        
        up_gcn_in = []
        up_gcn_out = []
        
        pos = [batched_graph.ndata['pos']]
        edge_basis_pos_post = []
        
        #gcn and down pooling
        for i in range(self.l_n):
            #define basis (spherical harmonics) from xyz_positions, pull edge connections connectivity
            edge_basis_pos_post.append(prep_for_gcn(gu.graph_list[i], pos[i], self.edge_pre[i], gu.graph_list[i].nodes(),
                                                    comp_grad = self.comp_basis_grad))

            down_gcn_out.append(gu.down_gcns[i].forward(down_gcn_in[i], edge_basis_pos_post[i][0],
                                                        graph=gu.graph_list[i],basis=edge_basis_pos_post[i][1]))
            #top k pool, save indices pooled for unpooling
            out_and_indx = gu.pools[i](down_gcn_out[i], gu.graph_list[i])
            #save indices, level outputs (topk pool node features), and positions for other side of unet (unpooling and adding)
            #and use in lower levels
            down_gcn_in.append(out_and_indx[0])
            indices_list.append(out_and_indx[1])
            pos.append(edge_basis_pos_post[i][2])
          
        edge_basis_pos_post.append(prep_for_gcn(gu.graph_list[-1], pos[-1], self.edge_pre[-1], gu.graph_list[-1].nodes(),
                                               comp_grad = self.comp_basis_grad))
        
        
        bottom_out = self.bottom_gcn.forward(down_gcn_in[-1], edge_basis_pos_post[-1][0],
                                graph=gu.graph_list[-1],basis=edge_basis_pos_post[-1][1])
        
        latent = {'0':self.global_pool(bottom_out, graph=gu.graph_list[-1]).unsqueeze(-1)}
        up_gcn_in.append(gu.latent_unpool(latent, graph=gu.graph_list[-1],u_features=bottom_out))
        
        reverse_counter = self.l_n
        #up gcns
        for i in range(self.l_n):
            up_gcn_out.append(gu.up_gcns[i].forward(up_gcn_in[i],edge_basis_pos_post[reverse_counter-i][0],
                                                    graph=gu.graph_list[reverse_counter-i],
                                                    basis=edge_basis_pos_post[reverse_counter-i][1]))
            
            up_gcn_in.append(gu.unpools[i](up_gcn_out[i],graph=gu.graph_list[i],
                                           idx = indices_list[reverse_counter-i],
                                           u_features=down_gcn_out[reverse_counter-i-1])) #add from level up
            
        
        final = self.top_gcn(up_gcn_in[-1],edge_basis_pos_post[0][0],
                     graph=gu.graph_list[0],basis=edge_basis_pos_post[0][1])
        
#         pred_move = self.pred_gcn(final,edge_basis_pos_post[0][0],
#                      graph=gu.graph_list[0],basis=edge_basis_pos_post[0][1])
        
        final2 =   self.final(final,edge_basis_pos_post[0][0])
        
        #add NC_ CA Vecs back to start
        final2['1'][:,1,:] = final2['1'][:,1,:] + down_gcn_in[0]['1'][:,0,:]
        final2['1'][:,2,:] = final2['1'][:,2,:] + down_gcn_in[0]['1'][:,1,:]
        
        return final2
        
            
        
#         for i in range(self.l_n):
#             feats
        
#         def forward(self, g, h):
#         adj_ms = []
#         indices_list = []
#         down_outs = []
#         hs = []
#         org_h = h
#         for i in range(self.l_n):
            
#             basis = get_basis(rel_pos, max_degree=max_degree,
#                                    compute_gradients=True,
#                                    use_pad_trick=False)
            
#             h = self.down_gcns[i](self.graph_list[i], h)
            
            
#             adj_ms.append(g)
#             down_outs.append(h)
#             g, h, idx = self.pools[i](g, h)
#             indices_list.append(idx)
#         h = self.bottom_gcn(g, h)
#         for i in range(self.l_n):
#             up_idx = self.l_n - i - 1
#             g, idx = adj_ms[up_idx], indices_list[up_idx]
#             g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
#             h = self.up_gcns[i](g, h)
#             h = h.add(down_outs[up_idx])
#             hs.append(h)
#         h = h.add(org_h)
#         hs.append(h)
#         return hs
    
def train_step(batched_graph, node_feats, gauss_noise, graph_unet):
    
    true_pos = batched_graph.ndata['pos'].clone()

    #add vectors for 
    CA_t  = true_pos.reshape(B, L, 3)
    NC_t = CA_t + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_t = CA_t + node_feats['1'][:,1,:].reshape(B, L, 3)
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

    
    batched_graph.ndata['pos'], noise = gauss_noise.forward(batched_graph.ndata['pos'])


    shift=gu.forward(node_feats,  batched_graph)
    #offset = shift['1'].reshape(B, L, 3)
    #pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
    CA_p = shift['1'][:,0,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    NC_p = shift['1'][:,1,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    CC_p = shift['1'][:,2,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)

    tloss, loss = FAPE_loss(pred.unsqueeze(0), true)
    
    opti.zero_grad()
    tloss.backward()
    opti.step()
    
    return tloss.detach()
                            
        
        
            
        

In [13]:
def get_noise_pred_true(batched_graph, node_feats, gauss_noise, model, B, L=65):
    
    true_pos = batched_graph.ndata['pos'].clone()

    #add vectors for N and C atomts
    CA_t  = true_pos.reshape(B, L, 3)
    NC_t = CA_t + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_t = CA_t + node_feats['1'][:,1,:].reshape(B, L, 3)
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

    
    batched_graph.ndata['pos'], noise = gauss_noise.forward(batched_graph.ndata['pos'])
    
    CA_n = batched_graph.ndata['pos'].clone().reshape(B, L, 3)
    NC_n = CA_n + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_n = CA_n + node_feats['1'][:,1,:].reshape(B, L, 3)
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)


    shift= model.forward(node_feats, batched_graph)
    #offset = shift['1'].reshape(B, L, 3)
    #pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
    CA_p = shift['1'][:,0,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    NC_p = shift['1'][:,1,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    CC_p = shift['1'][:,2,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)
    
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10, pred.detach().to('cpu').numpy()*10

In [14]:
B=32
L=65
gu = GraphUNet(pred_fiber = Fiber({0:24,1:24}), max_degree=3,batch_size=B,cuda_out=True, comp_basis_grad=False).to('cuda')
dm = H4_DataModule(coords_tog,batch_size=B)
gn = GaussianNoise(sigma=0.05).to('cuda')
opti = torch.optim.Adam(gu.parameters(), lr=0.001, weight_decay=5e-6)
loss = nn.MSELoss()

In [17]:
for e in range(300):
    lsum = 0
    for i, inp in enumerate(dm.gds):
        batched_graph, node_feats, edge_feats = inp
        bg = to_cuda(batched_graph)
        nf = to_cuda(node_feats)
        out = train_step(bg, nf, gn, gu)
        lsum += out

    print(lsum/i)
    if e%5 ==0:
        for y, inp in enumerate(dm.gds):
            batched_graph, node_feats, edge_feats = inp
            bg = to_cuda(batched_graph)
            nf = to_cuda(node_feats)

            true, noise, pred= get_noise_pred_true(bg,nf,gn,gu,B)



            dump_coord_pdb(true[y],f'output/true_{e}.pdb')
            dump_coord_pdb(noise[y],f'output/noised_{e}.pdb')
            dump_coord_pdb(pred[y],fileOut=f'output/after_{e}.pdb')

            break

tensor(0.0210, device='cuda:0')
tensor(0.0151, device='cuda:0')
tensor(0.0135, device='cuda:0')
tensor(0.0130, device='cuda:0')
tensor(0.0126, device='cuda:0')
tensor(0.0124, device='cuda:0')
tensor(0.0123, device='cuda:0')
tensor(0.0121, device='cuda:0')
tensor(0.0121, device='cuda:0')
tensor(0.0119, device='cuda:0')
tensor(0.0118, device='cuda:0')
tensor(0.0117, device='cuda:0')
tensor(0.0116, device='cuda:0')
tensor(0.0115, device='cuda:0')
tensor(0.0115, device='cuda:0')
tensor(0.0114, device='cuda:0')
tensor(0.0112, device='cuda:0')
tensor(0.0112, device='cuda:0')
tensor(0.0109, device='cuda:0')
tensor(0.0107, device='cuda:0')
tensor(0.0104, device='cuda:0')
tensor(0.0104, device='cuda:0')
tensor(0.0102, device='cuda:0')
tensor(0.0100, device='cuda:0')
tensor(0.0100, device='cuda:0')
tensor(0.0098, device='cuda:0')
tensor(0.0098, device='cuda:0')
tensor(0.0097, device='cuda:0')
tensor(0.0097, device='cuda:0')
tensor(0.0096, device='cuda:0')
tensor(0.0096, device='cuda:0')
tensor(0

KeyboardInterrupt: 

In [16]:
# for i, inp in enumerate(dm.gds):
#     batched_graph, node_feats, edge_feats = inp
#     bg = to_cuda(batched_graph)
#     nf = to_cuda(node_feats)
    
#     xyz_noise, noise, true_pos, noisexyz_N_C = get_noise(bg,nf,gn)
#     out = pred_CA(bg,nf,gn,gu)
#     for x in range(4):
#         dump_coord_pdb(true_pos.to('cpu').numpy()[x]*10,f'output/true_{x}.pdb')
#         dump_coord_pdb(noisexyz_N_C.to('cpu').numpy()[x]*10,f'output/noised_{x}.pdb')
#         dump_coord_pdb(out.to('cpu').numpy()[x]*10,fileOut=f'output/after_{x}.pdb')
    
#     break

In [43]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    bg = to_cuda(batched_graph)
    nf = to_cuda(node_feats)

    true, noise, pred= get_noise_pred_true(bg,nf,gn,gu,B)
    
    desired_outputs = 4
    for x in range(desired_outputs):
        dump_coord_pdb(true[x],f'output/true_{x}.pdb')
        dump_coord_pdb(noise[x],f'output/noised_{x}.pdb')
        dump_coord_pdb(pred[x],fileOut=f'output/after_{x}.pdb')

    break

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/mnt/c/Users/nwood/OneDrive/Desktop/gudiff/se3_transformer/model/basis.py", line 114, in update_basis_with_fused
    for d_out in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
                      ~~~~~~~~~~~ <--- HERE
                                  device=device, dtype=dtype)
        acc_d, acc_f = 0, 0
RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



In [34]:
dm.gds.dataset[0]

Graph(num_nodes=65, num_edges=4160,
      ndata_schemes={'pe': Scheme(shape=(12,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32), 'bb_ori': Scheme(shape=(2, 3), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32)})

In [43]:
# B=32
# L=65
# gu = GraphUNet(pred_fiber = Fiber({0:32,1:32}), max_degree=4,batch_size=B,cuda_out=False, comp_basis_grad=False)
# dm = H4_DataModule(coords_tog,batch_size=B)
# gn = GaussianNoise(sigma=0.5)
# opti = torch.optim.Adam(gu.parameters(), lr=0.0005, weight_decay=0.00008)
# #loss = nn.MSELoss()
# for e in range(100):
#     lsum = 0
#     for i, inp in enumerate(dm.gds):
#         batched_graph, node_feats, edge_feats = inp
#         out = train_step(batched_graph, node_feats, gn, gu)
#         lsum += out
#         print(lsum/(i+1))

#     print(lsum/i)

tensor(0.2224)
tensor(0.2151)
tensor(0.2123)


KeyboardInterrupt: 

In [36]:
# for i, inp in enumerate(dm.gds):
#     batched_graph, node_feats, edge_feats = inp
#     bg = to_cuda(batched_graph)
#     nf = to_cuda(node_feats)
    
#     xyz_noise, noise, true_pos, noisexyz_N_C = get_noise(bg,nf,gn)
#     out = pred(bg,nf,gn,gu)
#     for x in range(8):
#         dump_coord_pdb(true_pos.to('cpu').numpy()[x],f'output/true_{x}.pdb')
#         dump_coord_pdb(noisexyz_N_C.to('cpu').numpy()[x],f'output/noised_{x}.pdb')
#         dump_coord_pdb(out.to('cpu').numpy()[x],fileOut=f'output/after_{x}.pdb')
    
#     break
    

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/mnt/c/Users/nwood/OneDrive/Desktop/gudiff/se3_transformer/model/basis.py", line 129, in update_basis_with_fused
    for d_in in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
                      ~~~~~~~~~~~ <--- HERE
                                  device=device, dtype=dtype)
        acc_d, acc_f = 0, 0
RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



In [58]:
batched_graph

Graph(num_nodes=520, num_edges=33280,
      ndata_schemes={'pe': Scheme(shape=(12,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32), 'bb_ori': Scheme(shape=(2, 3), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32), 'rel_pos': Scheme(shape=(3,), dtype=torch.float32)})

In [50]:
out = pred(bg,nf,gn,gu)
    

In [97]:
xyz_noise, noise, true_pos, noisexyz_N_C = get_noise(bg,nf,gn)




RuntimeError: shape '[8, 65, 3]' is invalid for input of size 6240

In [93]:
dump_coord_pdb(noisexyz_N_C.to('cpu').numpy()[0],'output/before.pdb')

In [51]:
dump_coord_pdb(out.to('cpu').numpy()[0],fileOut='output/after.pdb')

In [83]:
 batched_graph

Graph(num_nodes=2080, num_edges=133120,
      ndata_schemes={'pe': Scheme(shape=(12,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32), 'bb_ori': Scheme(shape=(2, 3), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32), 'rel_pos': Scheme(shape=(3,), dtype=torch.float32)})

In [119]:
true_pos = bg.ndata['pos'].clone()
bg.ndata['pos'], noise = gn.forward(bg.ndata['pos'])

In [120]:
B=8
L=65

shift=gu.forward(nf,  bg)

In [121]:
CA_p = shift['1'][:,0,:].reshape(B, L, 3)+bg.ndata['pos'].reshape(B, L, 3)
NC_p = shift['1'][:,1,:].reshape(B, L, 3)+bg.ndata['pos'].reshape(B, L, 3)
CC_p = shift['1'][:,2,:].reshape(B, L, 3)+bg.ndata['pos'].reshape(B, L, 3)


CA_t  = true_pos.reshape(B, L, 3)
NC_t = CA_t + nf['1'][:,0,:].reshape(B, L, 3)
CC_t = CA_t + nf['1'][:,1,:].reshape(B, L, 3)



pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)
true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

In [122]:
pred.shape

torch.Size([8, 65, 3, 3])

In [123]:
true.shape

torch.Size([8, 65, 3, 3])

In [124]:
tot_loss, loss = FAPE_loss(pred.unsqueeze(0), true)

In [125]:
tot_loss

tensor(0.8654, device='cuda:0', grad_fn=<SumBackward0>)

In [126]:
true.shape

torch.Size([8, 65, 3, 3])

KeyboardInterrupt: 

In [131]:
true_pos = batched_graph.ndata['pos'].clone()
        
B=8
L=65
#add vectors for 
CA_t  = true_pos.reshape(B, L, 3)
NC_t = CA_t + node_feats['1'][:,0,:].reshape(B, L, 3)
CC_t = CA_t + node_feats['1'][:,1,:].reshape(B, L, 3)
true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)


out_pos, noise = gn.forward(bg.ndata['pos'])



In [113]:
CA_o  = out_pos.reshape(B, L, 3)
NC_o = CA_o + nf['1'][:,0,:].reshape(B, L, 3)
CC_o = CA_o + nf['1'][:,1,:].reshape(B, L, 3)
out_noise =  torch.cat((NC_o,CA_o,CC_o),dim=2).reshape(B,L,3,3)

In [60]:
tp = true_pos.reshape(8,65,3).to('cpu')

In [61]:
# true_pos = batched_graph.ndata['pos']

# batched_graph.ndata['pos'], noise = gn.forward(true_pos)
# shift=gu.forward(node_feats,  batched_graph)

In [72]:
import util.npose_util as nu
def build_npose_from_coords(coords_in):
    
    rot_mat_cat = np.ones(sum((coords_in.shape[:-1], (1,)), ()))
    
    coords = np.concatenate((coords_in,rot_mat_cat),axis=-1)
    
    npose = np.ones((coords_in.shape[0]*5,4)) #5 is atoms per res

    by_res = npose.reshape(-1, 5, 4)
    
    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

In [135]:
npose = build_npose_from_coords(true.reshape(8,65,3,3)[0])

In [105]:
npose = build_npose_from_coords(po.reshape(8,65,3,3)[0])

In [137]:
npose = build_npose_from_coords(out_noise.reshape(8,65,3,3).to('cpu').numpy()[0])

In [136]:
nu.dump_npdb(npose,'output/tester.pdb')

In [101]:
nu.dump_npdb(out_noise,'output/tester2.pdb')

In [138]:
nu.dump_npdb(npose,'output/tester3.pdb')

In [35]:
npose.shape

(40, 4)

In [82]:
po.shape

(8, 65, 3, 3)

In [94]:
predout = pred(bg, nf, gn, gu)

In [95]:
po = predout.to('cpu').numpy()

In [91]:
npose = build_npose_from_coords(po[0])

In [271]:
true_pos = batched_graph.ndata['pos']

batched_graph.ndata['pos'], noise = gn.forward(true_pos)
shift=gu.forward(node_feats,  batched_graph)

In [272]:
output = loss(noise,shift['1'].squeeze(1))
print(output)

tensor(0.0065, grad_fn=<MseLossBackward0>)


In [273]:
opti.zero_grad()
output.backward()
opti.step()

In [189]:
true_pos = batched_graph.ndata['pos']
shift=gu.forward(node_feats,  batched_graph)

In [191]:
shift['1']

tensor([[[ 2.6800,  4.0886, -0.8873]],

        [[ 3.4822,  4.9773, -0.8749]],

        [[ 3.9890,  4.0204, -1.2399]],

        ...,

        [[ 0.1211, -3.3383, -0.6193]],

        [[-0.4978, -0.5428,  1.7745]],

        [[ 4.5303, -1.8521,  0.1575]]], grad_fn=<TransposeBackward0>)

In [193]:
#torch.autograd.gradcheck(FAPE_loss_CA,(shift['1'].reshape(B, L, 3).unsqueeze(0),true_pos.reshape(B, L, 3)) , eps=1e-3, atol=1e-3)

In [None]:
pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))

In [None]:
FAPE_loss_CA(pred.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))

In [None]:
FAPE_loss_CA()

In [112]:
def train_step(batched_graph, node_feats, gauss_noise, graph_unet):
    
    true_pos = batched_graph.ndata['pos'].clone()
    batched_graph.ndata['pos'] = gauss_noise.forward(batched_graph.ndata['pos'])
    
    B=8
    L=65

    shift=gu.forward(node_feats,  batched_graph)
    offset = shift['1'].reshape(B, L, , 3)
    #pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
        
    #tloss, loss = FAPE_loss_CA(pred.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))
    
    opti.zero_grad()
    tloss.backward()
    opti.step()
    
    return offset

In [122]:
num_heads = 16
gu = GraphUNet(pred_fiber = Fiber({0:32,1:32}),cdivpred=2).to('cuda')
dm = H4_DataModule(ca_coords)
gn = GaussianNoise(sigma=0.01)
opti = torch.optim.Adam(gu.parameters(), lr=0.01,
            weight_decay=0)




In [None]:
num_heads = 16
gu = GraphUNet(pred_fiber = Fiber({0:32,1:32}),cdivpred=2).to('cuda')
dm = H4_DataModule(ca_coords)
gn = GaussianNoise(sigma=0.01).to('cuda')
opti = torch.optim.Adam(gu.parameters(), lr=0.001, weight_decay=0)

In [123]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break

In [124]:
true_pos = batched_graph.ndata['pos']
shift=gu.forward(node_feats,  batched_graph)



RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [None]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp

    bg = to_cuda(batched_graph)
    bg.ndata['pos'].requires_grad = True
    nf = to_cuda(node_feats)
    nf['0'].requires_grad = True
    break

In [117]:
true_pos = batched_graph.ndata['pos'].clone()
batched_graph.ndata['pos'] = gn.forward(batched_graph.ndata['pos'])

B=8
L=65

shift=gu.forward(node_feats,  batched_graph)
offset = shift['1'].reshape(B, L, 2, 3)
Ts = offset[:,:,0,:]# translationupdate['1'][0]
pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))


tloss, loss = FAPE_loss_CA(pred.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))

opti.zero_grad()
tloss.backward()
opti.step()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [95]:
num_heads = 16
gu = GraphUNet(pred_fiber = Fiber({0:32,1:32}),cdivpred=2).to('cuda')
dm = H4_DataModule(ca_coords)
gn = GaussianNoise(sigma=0.01).to('cuda')
opti = torch.optim.Adam(gu.parameters(), lr=0.001, weight_decay=0)

In [96]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break

In [98]:
true_pos = batched_graph.ndata['pos'].clone()
noise_pos = gn.forward(batched_graph.ndata['pos'])
noise_pos2 = gn.forward(batched_graph.ndata['pos'],scale=10000)



In [99]:
noise_pos.shape

torch.Size([520, 3])

In [100]:
true_pos[1]

tensor([ 0.1053, -0.3110, -0.2766])

In [101]:
noise_pos2[1]

tensor([129.7982, -45.9416,  70.2688])

In [102]:
B=8
L=65
FAPE_loss_CA(noise_pos.reshape(B, L, 3).unsqueeze(0), true_pos.reshape(B, L, 3))

(tensor(0.0083), tensor([0.0083]))

In [103]:
B=8
L=65
FAPE_loss_CA(noise_pos2.reshape(B, L, 3).unsqueeze(0), true_pos.reshape(B, L, 3))

(tensor(0.9841), tensor([0.9841]))

In [None]:
    true_pos = batched_graph.ndata['pos'].clone()
    batched_graph.ndata['pos'] = gauss_noise.forward(batched_graph.ndata['pos'])
    
    B=8
    L=65

    shift=gu.forward(node_feats,  batched_graph)
    offset = shift['1'].reshape(B, L, 2, 3)
    Ts = offset[:,:,0,:]# translationupdate['1'][0]
    pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
        
    tloss, loss = FAPE_loss_CA(pred.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))
    
    opti.zero_grad()
    tloss.backward()
    opti.step()
    
    return offset

In [114]:

for e in range(100):
    lsum = 0
    for i, inp in enumerate(dm.gds):
        batched_graph, node_feats, edge_feats = inp

        bg = to_cuda(batched_graph)
        bg.ndata['pos'].requires_grad = True
        nf = to_cuda(node_feats)
        nf['0'].requires_grad = True

        loss = train_step(bg, nf, gn, gu)
        
        lsum += loss
        
    print(lsum/i)


RuntimeError: shape '[8, 65, 2, 3]' is invalid for input of size 1560

In [133]:
B=8
L=65

shift=gu.forward(node_feats,  batched_graph)

offset = shift['1'].reshape(B, L, 2, 3)
Ts = offset[:,:,0,:] * 31 # translationupdate['1'][0]
pred1 = torch.add(Ts.reshape(B, L, 3),batched_graph.ndata['pos'].reshape(B, L, 3))



In [134]:
tloss, loss = FAPE_loss_CA(  pred1.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))

In [32]:
bg =to_cuda(batched_graph)


In [35]:
bg.edges()

(tensor([  0,   0,   0,  ..., 519, 519, 519], device='cuda:0'),
 tensor([  1,   2,   3,  ..., 516, 517, 518], device='cuda:0'))

In [135]:
opti.zero_grad()
tloss.backward()
opti.step()

In [123]:
FAPE_loss_CA(  pred1.unsqueeze(0).repeat(2,1,1,1), true_pos.reshape(B, L, 3))

torch.Size([2, 8, 63, 63])
torch.Size([2])


(tensor(0.9841, grad_fn=<SumBackward0>), tensor([0.4921, 0.4921]))

In [43]:
true_pos.shape

torch.Size([520, 3])

In [19]:
offset['1'].shape

torch.Size([520, 2, 3])

In [16]:
offset = shift['1'].reshape(B, L, 2, 3)
        Ts = offset[:,:,0,:] * 10.0 # translationupdate['1'][0]

tensor([[  256.8314, -1025.1440,  -437.2067],
        [  -32.3810,   269.8080,  -132.2642]], grad_fn=<SelectBackward0>)

In [None]:
ppp['1'].shape

In [113]:
199680/12

16640.0

In [56]:
399360/33280

12.0

In [None]:
channels_div = num_heads/channels

fiber = max_degrees*channel

In [None]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break
    
bg = to_cuda(batched_graph)
edge_feats = to_cuda(edge_feats)
node_feats = to_cuda(node_feats)

basis = get_basis(bg.edata['rel_pos'], max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)

#need to add basis fused here?

basis = update_basis_with_fused(basis, max_degree, use_pad_trick=False,
                                        fully_fused=False)

#concatenate on the distances of the edge based on 'rel pos'
edge_feats_cat = get_populated_edge_features(bg.edata['rel_pos'], edge_feats)

In [None]:
#sample denoted 
gu = GraphUNet()
dm = H4_DataModule(ca_coords)
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break
pos = batched_graph.ndata['pos']
edge_feats_out,basis_out, pos = prep_for_gcn(gu.graph_list[0], pos, edge_feats, gu.graph_list[0].nodes())
out = gu.down_gcns[0].forward(node_feats, edge_feats_out,graph=gu.graph_list[0],basis=basis_out)
out2, indx = gu.pools[0](out,gu.graph_list[0])
#figure out the basis prep here

edge_feats_1, basis_1, pos2 = prep_for_gcn(gu.graph_list[1], pos, get_edge_features(gu.graph_list[1]), indx)
out3 = gu.bottom_gcn.forward(out2, edge_feats_1, graph=gu.graph_list[1],basis=basis_1)
latent = {'0':gu.global_pool(out3,graph=gu.graph_list[1]).unsqueeze(-1)}
up1 = gu.latent_unpool(latent,graph=gu.graph_list[1],u_features=out3)
up2 = gu.up_gcns[0].forward(up1,edge_feats_1,graph=gu.graph_list[1],basis=basis_1)
up3 = gu.unpools[0](up2,graph=gu.graph_list[0], idx = indx,u_features=out)
final = gu.top_gcn.forward(up3,edge_feats_out, graph=gu.graph_list[0], basis=basis_out)

In [13]:
dm = H4_DataModule(ca_coords)

In [15]:
n_nodes = 65
NODE_FEATURE_DIM = 12
EDGE_FEATURE_DIM = 1 # probably expand to [2] one hot primary connect 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 8 # how many
num_heads = 4
channels_div = 2
max_degree = 4

use_layer_norm = True

fuse_level = ConvSE3FuseLevel.NONE

fiber_in=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden=Fiber({0: num_degrees * num_channels})
fiber_edge=Fiber({0: EDGE_FEATURE_DIM})
fiber_out = Fiber({0: num_degrees * num_channels}) # can this be arbitrary, or projected

In [16]:
ablock = AttentionBlockSE3(fiber_in=fiber_in,
               fiber_out=fiber_hidden,
               fiber_edge=fiber_edge,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda = ablock.to('cuda')

In [17]:
tk = TopK_Pool(fiber_hidden)
tk_cuda = tk.to('cuda')
# tblock = [ablock,tk]
# model = Sequential(*tblock)

In [18]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break
    
bg = to_cuda(batched_graph)
edge_feats = to_cuda(edge_feats)
node_feats = to_cuda(node_feats)

basis = get_basis(bg.edata['rel_pos'], max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)

#need to add basis fused here?

basis = update_basis_with_fused(basis, max_degree, use_pad_trick=False,
                                        fully_fused=False)

#concatenate on the distances of the edge based on 'rel pos'
edge_feats_cat = get_populated_edge_features(bg.edata['rel_pos'], edge_feats)

In [27]:
out = acuda.forward(node_feats, edge_feats_cat,graph=bg,basis=basis)

  assert input.numel() == input.storage().size(), (


In [28]:
node_feat_1, inde = tk.forward(out, bg)
ndf1 = {}
ndf1['0'] = node_feat_1.unsqueeze(-1)

new_pos = F.gather_row(bg.ndata['pos'], inde)

In [29]:
#define subgraph 1, and rel pos indices 

bg_pool1 = define_UGraph(tk.k,batch_size=8) 
src, dst = bg_pool1.edges()
src_pool1 = to_cuda(src)
dst_pool1 = to_cuda(dst)


In [30]:

#
edge_feats_1 = {'0': bg_pool1.edata['con'][:, :1, None]}

edge_feats_1 = to_cuda(edge_feats_1)


rel_pos_pool1 = F.gather_row(new_pos,dst_pool1) - F.gather_row(new_pos,src_pool1) 
#rel_pos_pool1 = 

edge_feats_1_cat = get_populated_edge_features(rel_pos_pool1, edge_feats_1)
basis_1 = get_basis(rel_pos_pool1, max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)
basis_1 = update_basis_with_fused(basis_1, max_degree, use_pad_trick=False,
                                        fully_fused=False)



In [31]:
NODE_FEATURE_DIM = 32
EDGE_FEATURE_DIM = 1 # 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 16 # how many
num_heads = 8
channels_div = 2
max_degree = 4

use_layer_norm = True
fuse_level = ConvSE3FuseLevel.NONE


fiber_in2=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden2=Fiber({0: num_degrees * num_channels*4})
fiber_edge2=Fiber({0: EDGE_FEATURE_DIM})
#fiber_out2 = Fiber({0: num_degrees * num_channels * 4}) # can this be arbitrary, or projected
ablock2 = AttentionBlockSE3(fiber_in=fiber_in2,
               fiber_out=fiber_hidden2,
               fiber_edge=fiber_edge2,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda2 = ablock2.to('cuda')

In [42]:
out2 = acuda2.forward(ndf1, edge_feats_1_cat, graph=to_cuda(bg_pool1), basis=basis_1)

In [43]:
#i think i need to build a new graph after the pool, ugh, ugh ,ugh
#unpooling should be easier just add onto old.

In [44]:
out2['0'].shape

torch.Size([40, 256, 1])

In [45]:
global_pooling_module = GPooling(pool='max', feat_type=0)

In [46]:
latent_pool = global_pooling_module(out2, to_cuda(bg_pool1))

In [53]:
lp = Latent_Unpool()

In [56]:
lp.forward(latent_pool, bg_pool1,out2)

TypeError: forward() missing 1 required positional argument: 'u_features'

In [90]:
global_pooling_module = GPooling(pool='max', feat_type=0)
latent_pool = global_pooling_module(out2, to_cuda(bg_pool1))
node_feat_up1 = latent_pool.repeat_interleave(5,0) #copy pool to all new nodes
node_feat_up1  = torch.add(node_feat_up1.unsqueeze(-1),out2['0']) #unet add 

In [91]:
node_feat_up1 = latent_pool.repeat_interleave(5,0) #copy pool to all new nodes
node_feat_up1  = torch.add(node_feat_up1.unsqueeze(-1),out2['0']) #unet add 


In [92]:
NODE_FEATURE_DIM = 256
EDGE_FEATURE_DIM = 1 # 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 4 # how many
num_heads = 4
channels_div = 1
max_degree = 4

use_layer_norm = True
fuse_level = ConvSE3FuseLevel.NONE


fiber_in3=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden3=Fiber({0: num_degrees * num_channels*2})
fiber_edge3=Fiber({0: EDGE_FEATURE_DIM})
#fiber_out2 = Fiber({0: num_degrees * num_channels * 4}) # can this be arbitrary, or projected
ablock3 = AttentionBlockSE3(fiber_in=fiber_in3,
               fiber_out=fiber_hidden3,
               fiber_edge=fiber_edge3,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda3 = ablock3.to('cuda')

In [93]:
out3 = acuda3.forward({'0':node_feat_up1}, edge_feats_1_cat, graph=to_cuda(bg_pool1), basis=basis_1)

In [106]:
out3['0'].shape

torch.Size([40, 32, 1])

In [107]:
#unpool 

node_feat_up2 = torch.zeros((bg.num_nodes(), out3['0'].shape[1],1 )).to('cuda')
dd=F.scatter_row(node_feat_up2,inde,out3['0'])

#add U-net
pad= dd.shape[1]-node_feats['0'].shape[1]
unet_add = torch.cat((node_feats['0'], torch.zeros(node_feats['0'].shape[0],pad ,1).to('cuda')), 1)

torch.add(unet_add,dd).shape

torch.Size([520, 32, 1])

In [121]:
upool = Unpool()
inde_list = [inde]
upool(out3,bg,inde_list,node_feats)['0'].shape

torch.Size([520, 32, 1])

In [117]:
indelist = [inde]
idx_count = 0
out_feats = {}
for key,val in out3.items():
    new_h = val.new_zeros([bg.num_nodes(), val.shape[1],1])
    pad = val.new_zeros([bg.num_nodes(), new_h.shape[1]-node_feats[key].shape[1],1])
    print(new_h.shape,pad.shape)
    F.scatter_row(new_h,indelist[idx_count],val)
    break

torch.Size([520, 32, 1]) torch.Size([520, 20, 1])


In [116]:
inde.shape

torch.Size([40])

In [119]:
class Unpool(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        idx_count = 0
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            pad = val.new_zeros([graph.num_nodes(), new_h.shape[1]-u_features[key].shape[1],1])
            out_feats[key] = torch.add(F.scatter_row(new_h,idx[idx_count],val),torch.cat((u_features[key],pad),1))
            idx_count +=1
        return out_feats

In [89]:
unet_add = torch.cat((node_feats['0'], torch.zeros(node_feats['0'].shape[0],dd.shape[1]-node_feats['0'].shape[1],1).to('cuda')), 1).shape

In [73]:
inde

tensor([ 45,  46,  47,  48,  49, 110, 111, 112, 113, 114, 175, 176, 177, 178,
        179, 240, 241, 242, 243, 244, 305, 306, 307, 308, 309, 370, 371, 372,
        373, 374, 435, 436, 437, 438, 439, 500, 501, 502, 503, 504],
       device='cuda:0')

In [51]:
?torch.scatter

In [None]:
F.scatter_row

In [41]:
out3['0'].shape

torch.Size([40, 32, 1])

In [51]:
out2['0'].shape

torch.Size([40, 256, 1])

In [52]:
node_feat_up1.shape

torch.Size([40, 256])

In [48]:
latent_pool.repeat_interleave(5,0).shape

torch.Size([40, 256])

In [None]:
#unpooling

#need torch.zeros size of previous graph
torch.zeros((bg_pool1.num_nodes,))


In [31]:
bg_pool1.num_nodes

Graph(num_nodes=40, num_edges=160,
      ndata_schemes={'pe': Scheme(shape=(12,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32)})

In [None]:
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        pooled = self.pool(graph, features[str(self.feat_type)])
        return pooled.squeeze(dim=-1)

In [None]:
#remake without pulling from graph?

READOUT_ON_ATTRS = {
    "nodes": ("ndata", "batch_num_nodes", "number_of_nodes"),
    "edges": ("edata", "batch_num_edges", "number_of_edges"),
}

def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
    """Internal function to take graph-wise top-k node/edge features of
    field :attr:`feat` in :attr:`graph` ranked by keys at given
    index :attr:`sortby`. If :attr:`descending` is set to False, return the
    k smallest elements instead.

    Parameters
    ---------
    graph : DGLGraph
        The graph
    typestr : str
        'nodes' or 'edges'
    feat : str
        The feature field name.
    k : int
        The :math:`k` in "top-:math`k`".
    descending : bool
        Controls whether to return the largest or smallest elements,
         defaults to True.
    sortby : int
        The key index we sort :attr:`feat` on, if set to None, we sort
        the whole :attr:`feat`.
    ntype_or_etype : str, tuple of str
        Node/edge type.

    Returns
    -------
    sorted_feat : Tensor
        A tensor with shape :math:`(B, K, D)`, where
        :math:`B` is the batch size of the input graph.
    sorted_idx : Tensor
        A tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if sortby
        is set to None), where
        :math:`B` is the batch size of the input graph, :math:`D`
        is the feature size.


    Notes
    -----
    If an example has :math:`n` nodes/edges and :math:`n<k`, in the first
    returned tensor the :math:`n+1` to :math:`k`th rows would be padded
    with all zero; in the second returned tensor, the behavior of :math:`n+1`
    to :math:`k`th elements is not defined.
    """
    _, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
    data = getattr(graph, typestr)[ntype_or_etype].data
    if F.ndim(data[feat]) > 2:
        raise DGLError(
            "Only support {} feature `{}` with dimension less than or"
            " equal to 2".format(typestr, feat)
        )
    feat = data[feat]
    hidden_size = F.shape(feat)[-1]
    batch_num_objs = getattr(graph, batch_num_objs_attr)(ntype_or_etype)
    batch_size = len(batch_num_objs)
    length = max(max(F.asnumpy(batch_num_objs)), k)
    fill_val = -float("inf") if descending else float("inf")
    feat_ = F.pad_packed_tensor(
        feat, batch_num_objs, fill_val, l_min=k
    )  # (batch_size, l, d)

    if F.backend_name == "pytorch" and sortby is not None:
        # PyTorch's implementation of top-K
        keys = feat_[..., sortby]  # (batch_size, l)
        return _topk_torch(keys, k, descending, feat_)
    else:
        # Fallback to framework-agnostic implementation of top-K
        if sortby is not None:
            keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby + 1), -1)
            order = F.argsort(keys, -1, descending=descending)
        else:
            order = F.argsort(feat_, 1, descending=descending)
        topk_indices = F.slice_axis(order, 1, 0, k)

        if sortby is not None:
            feat_ = F.reshape(feat_, (batch_size * length, -1))
            shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
            shift = F.copy_to(shift, F.context(feat))
            topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
        else:
            feat_ = F.reshape(feat_, (-1,))
            shift = F.repeat(
                F.arange(0, batch_size), k * hidden_size, -1
            ) * length * hidden_size + F.cat(
                [F.arange(0, hidden_size)] * batch_size * k, -1
            )
            shift = F.copy_to(shift, F.context(feat))
            topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
        out = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1))
        out = F.replace_inf_with_zero(out)
        return out, topk_indices

In [73]:
for x in range(8):
    for y in range(30):
        if len(np.unique(sort_ind[x][:,y].cpu().numpy())) != 30:
            print(False)

In [56]:
sort_feat[0]

tensor([[ 1.2111, -0.0900,  0.7318,  0.2344, -0.0669,  0.0927,  0.5824,  0.4390,
          0.0690,  0.5615,  0.7202,  0.1591,  0.5839,  0.1566,  0.3246,  0.1450,
          0.8620,  0.5528,  1.3103,  0.5718,  0.4452, -0.5404,  0.4007,  0.3480,
          0.4214,  0.7397,  0.1821,  0.0995,  0.1473,  0.8472,  0.4826,  0.7203],
        [ 1.2044, -0.0912,  0.7244,  0.2296, -0.1206,  0.0902,  0.5798,  0.4309,
          0.0607,  0.5256,  0.7131,  0.1099,  0.5822,  0.1490,  0.3212,  0.1406,
          0.8589,  0.5515,  1.3042,  0.5710,  0.4410, -0.5436,  0.3697,  0.3449,
          0.4145,  0.7275,  0.1739,  0.0936,  0.1362,  0.8349,  0.4770,  0.7097],
        [ 1.2012, -0.0977,  0.7028,  0.2276, -0.1264,  0.0872,  0.5759,  0.4240,
          0.0556,  0.5194,  0.6933,  0.0538,  0.5749,  0.1320,  0.3177,  0.1326,
          0.8483,  0.5320,  1.2934,  0.5513,  0.4327, -0.5531,  0.3374,  0.3360,
          0.4096,  0.7245,  0.1727,  0.0884,  0.1138,  0.8332,  0.4558,  0.6915],
        [ 1.1841, -0.0988

In [54]:
sort_ind[0].shape

torch.Size([10, 32])

In [None]:
model = SE3TransformerPooled(
        fiber_in=Fiber({0: dm.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: num_degrees * num_channels}),
        fiber_edge=Fiber({0: dm.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(False),
        num_degrees=num_degrees,
        num_channels=num_channels,
        **kwargs
    )

In [41]:
?dgl.nn.pytorch.

In [None]:
# get dataset from pdb using npose utils
# import util.npose_util as nu
# import os
# import numpy as np

# model_direc = '/mnt/c/Users/nwood/OneDrive/Desktop/hTest/HelixGen_master/data/4H_dataset/models/'

# fL = os.listdir(model_direc)
# coords = np.zeros((len(fL),65*5,4)) #65 aa, 5 atoms per aa
# for i,file in enumerate(fL):
#     coords[i] = nu.npose_from_file(f'{model_direc}/{file}')

# coords_out = coords.reshape((27894,65,5,4))[...,:3]
# ca_coords = coords_out.reshape((27894,65,5,3))[:,:,1,:]

# np.savez_compressed('../gudiff/data/h4_coords.npz',coords_out)
# np.savez_compressed('../gudiff/data/h4_ca_coords.npz',ca_coords)

In [10]:




def normalize_pc(points):
    """Center at Zero Divide furtherst points"""
    centroid = np.mean(points, axis=0)
    points -= centroid
    #since the points are centered zero, the furthest points is the abs value di
    furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
    points /= furthest_distance

    return points, furthest_distance
    
def make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(i_pos)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    return pe


def make_graph_struct(batch_size=32, n_nodes = 8):
    # make a fake graph to be filled with generator outputs
    
    v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ss = np.zeros(len(v1),dtype=np.int32)
    ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
    ss = ss[:,None] #unsqueeze
    
    pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

    graphList = []
    for i in range(batch_size):
        g = dgl.graph((v1,v2))
        g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph


class GraphDataset(Dataset):
    def __init__(self, ep_file : pathlib.Path, limit=1000):
        self.data_path = ep_file
        rr = np.load(self.data_path)
        ep = [rr[f] for f in rr.files][0][:1000]
        
        #need to save furthest distance to regen later
        #maybe consider small change for next steps
        ep, self.furthest_distance = normalize_pc(ep.reshape((-1,3)))
        self.ep = ep.reshape((-1,8,3))
        
        
        v1 = np.arange(self.ep.shape[1]-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,self.ep.shape[1]) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1))
        ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
        ss = ss[:,None] #unsqueeze

        #positional encoding of node
        pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

        graphList = []

        for i,c in enumerate(self.ep):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

    
class HGenDataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 8
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        
        self.data_dir = data_dir 
        self.GraphDatasetObj = GraphDataset(self.data_dir)
        self.gds = DataLoader(self.GraphDatasetObj,batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        return (batched_graph, node_feats, edge_feats)
    
def eval_gen(batch_size=8,z=12):
    
    in_z = torch.randn((batch_size,z), device='cuda',dtype = torch.float32)
    out = hg(in_z)*31
    out = out.reshape((-1,8,3)).detach().cpu().numpy()
    
    return eval_endpoints(out)
    
    

def eval_endpoints(ep_in): 
    
    ep = ep_in.reshape((-1,8,3))

    v1 = np.arange(ep.shape[1]-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,ep.shape[1]) #vertex 2 of edges in chronological order

    hLL = np.linalg.norm(ep[:,v1]-ep[:,v2],axis=2)

    hLoc = np.array([0,2,4,6])
    lLoc = np.array([1,3,5])

    return np.mean(hLL[:,hLoc]), np.mean(hLL[:,lLoc])
        

NameError: name 'Dataset' is not defined

In [123]:
class UnpoolDep(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        idx_count = 0
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            pad = val.new_zeros([graph.num_nodes(), new_h.shape[1]-u_features[key].shape[1],1])
            out_feats[key] = torch.add(F.scatter_row(new_h,idx[idx_count],val),torch.cat((u_features[key],pad),1))
            idx_count +=1
        return out_feats

In [None]:
class Graph_4H_Dataset(Dataset):
    def __init__(self, ca_coordinates, limit=1000, cast_type=torch.float32):
        
        self.ca_coords = ca_coordinates
        self.norm_ca = normalize_points(ca_coordinates)
        
        n_nodes = self.ca_coords.shape[1] 
        
        v1,v2,edge_data, ind = define_graph_edges(n_nodes)
        pe = make_pe_encoding(n_nodes=n_nodes)

        graphList = []

        for i,c in enumerate(self.norm_ca):

            g = dgl.graph((v1,v2))
            g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
            g.ndata['pe'] = pe
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]