In [18]:
import torch
import numpy as np
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph

In [19]:
#torch.use_deterministic_algorithms(True)

In [20]:
#?torch_geometric.nn.pool.ASAPooling

In [21]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [22]:
torch.cuda.is_available()

True

In [23]:
data_path_str  = 'data/h4_ca_coords.npz'
test_limit = 128
rr = np.load(data_path_str)
ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
ca_coords.shape

(128, 65, 3)

In [24]:
#goal define edges of
#connected backbone 1, 
#unconnected atoms 0,


def get_midpoint(ep_in):
    """Get midpoint, of each batched set of points"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ep_in.shape[1], ep_in.shape[2])
    
    return midpoint


def normalize_points(input_xyz, print_dist=False):
    
    #broadcast to distance matrix [Batch, M, R3] to [Batch,M,1, R3] to [Batch,1,M, R3] to [Batch, M,M, R3] 
    vec_diff = input_xyz[...,None,:]-input_xyz[...,None,:,:]
    dist = np.sqrt(np.sum(np.square(vec_diff),axis=len(input_xyz.shape)))
    furthest_dist = np.max(dist)
    centroid  = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')
    
    xyz_mean_zero = input_xyz - centroid[:,None,:]
    return xyz_mean_zero/furthest_dist



def define_graph_edges(n_nodes):
    #connected backbone

    con_v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    con_v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ind = con_v1*(n_nodes-1)+con_v2-1 #account for removed self connections (-1)


    #unconnected backbone

    nodes = np.arange(n_nodes)
    v1 = np.repeat(nodes,n_nodes-1) #starting vertices, same number repeated for each edge

    start_v2 = np.repeat(np.arange(n_nodes)[None,:],n_nodes,axis=0)
    diag_ind = np.diag_indices(n_nodes)
    start_v2[diag_ind] = -1 #diagonal of matrix is self connections which we remove (self connections are managed elsewhere)
    v2 = start_v2[start_v2>-0.5] #remove diagonal and flatten

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1
    
    return v1,v2,edge_data, ind

def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(n_nodes)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    
    if print_out == True:
        for x in range(int(n_nodes/12)):
            print(np.round(pe[x],1))
    
    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
norm_p = normalize_points(ca_coords,print_dist=True)
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

largest distance 32.8
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [25]:
#v1,v2,edge_data, ind = define_graph_edges(4)

In [26]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    pe = make_pe_encoding(n_nodes=n_nodes)
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph


In [27]:
def define_UGraph(n_nodes, batch_size, cast_type=torch.float32 ):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    #pe = make_pe_encoding(n_nodes=n_nodes)#pe e
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph

class Graph_4H_Dataset(Dataset):
    def __init__(self, ca_coordinates, limit=1000, cast_type=torch.float32):
        
        self.ca_coords = ca_coordinates
        self.norm_ca = normalize_points(ca_coordinates)
        
        n_nodes = self.ca_coords.shape[1] 
        
        v1,v2,edge_data, ind = define_graph_edges(n_nodes)
        pe = make_pe_encoding(n_nodes=n_nodes)

        graphList = []

        for i,c in enumerate(self.norm_ca):

            g = dgl.graph((v1,v2))
            g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
            g.ndata['pe'] = pe
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos
    
#needs to be done
class H4_DataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 12
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 ca_coords: np.array, batch_size=8):
        
        self.GraphDatasetObj = Graph_4H_Dataset(ca_coords)
        self.gds = DataLoader(self.GraphDatasetObj, batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        return (batched_graph, node_feats, edge_feats)

In [28]:
def topK_se3(graph, feat, xi, k):
    #remove this read from graph code, since se3 transformer natively uses pulled out feats from graph
    # READOUT_ON_ATTRS = {
    #     "nodes": ("ndata", "batch_num_nodes", "number_of_nodes"),
    #     "edges": ("edata", "batch_num_edges", "number_of_edges"),
    # }
    # _, batch_num_objs_attr, _ = READOUT_ON_ATTRS["nodes"]

    # #this is a fancy way of saying 'batch_num_nodes
    # data = getattr(bg, "nodes")[None].data
    # if F.ndim(data[feat]) > 2:
    #     raise DGLError(
    #         "Only support {} feature `{}` with dimension less than or"
    #         " equal to 2".format(typestr, feat)
    #     )
    # feat = data[feat]


    hidden_size = feat.shape[-1]
    batch_num_objs = getattr(bg, 'batch_num_nodes')(None)
    batch_size = len(batch_num_objs)
    descending = True

    length = max(max(F.asnumpy(batch_num_objs)), k) #max k or batch of nodes size
    fill_val = -float("inf") if descending else float("inf")
    
    feat_y = F.pad_packed_tensor(
        feat, batch_num_objs, fill_val, l_min=k
    )  # (batch_size, l, d)

    order = F.argsort(feat_y, 1, descending=descending)
    topk_indices_unsort_batch = F.slice_axis(order, 1, 0, k)
    #sort to matches original connectivity with define_graph_edges, likely change but probably won't hurt now
    topk_indices, tpk_ind = torch.sort(topk_indices_unsort_batch,dim=1) 

    #get batch shifts
    feat_ = F.reshape(feat_y, (-1,))
    shift = F.repeat(
        F.arange(0, batch_size), k * hidden_size, -1
    ) * length * hidden_size + F.cat(
        [F.arange(0, hidden_size)] * batch_size * k, -1
    )
    
    shift = F.copy_to(shift, F.context(feat))
    topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
    #trainable params gather
    out_y = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size*k, -1))
    out_y = F.replace_inf_with_zero(out_y)
    #nodes features gather
    out_xi = F.reshape(F.gather_row(xi, topk_indices_), (batch_size*k, -1))
    out_xi = F.replace_inf_with_zero(out_xi)
    return out_y, out_xi, topk_indices_


class TopK_Pool(torch.nn.Module):
    """
    https://arxiv.org/pdf/1905.05178.pdf
    Project Node Features to 1D for topK pooling using trainable weights
    Only type '0' features coded, would need to mix features between degrees, avoiding this for now
    """

    def __init__(self, fiber_in: Fiber, k=5):
        super().__init__()
        self.k = k
        fiber_out = Fiber({0: 1}) #convert to 1D of nodes
        self.weights = torch.nn.ParameterDict({
            str(degree_out): torch.nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })
        


    def forward(self, features: Dict[str, Tensor], graph: DGLGraph) -> Dict[str, Tensor]:
        #add topK selection, sigmoid, return nodes
        yi = {
            degree: torch.div(self.weights[degree] @ features[degree], self.weights[degree].norm())
            for degree, weight in self.weights.items()
        }
        y_selected, feats_selected, topk_indices_batched = topK_se3(graph, yi['0'], features['0'], self.k)
        return torch.sigmoid(y_selected)*feats_selected, topk_indices_batched
    
class Unpool(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        idx_count = 0
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            out_feats[key] = F.scatter_row(new_h,idx[idx_count],val)
            idx_count +=1
        return out_feats
    
class Latent_Unpool(torch.nn.Module):
    """
    Duplicate Latent onto Graph
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, u_features: Dict[str, Tensor]):
        out_feats = {}
        for key,val in features.items():
            new_h = val.repeat_interleave(int(graph.num_nodes()/val.shape[0]),0)
            out_feats[key] = torch.add(new_h.unsqueeze(-1),u_features[key])
        return out_feats
    
    
    

In [29]:
class GraphUnet(nn.Module):

    def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
        super(GraphUnet, self).__init__()
        self.ks = ks
        self.bottom_gcn = GCN(dim, dim, act, drop_p)
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        for i in range(self.l_n):
            self.down_gcns.append(GCN(dim, dim, act, drop_p))
            self.up_gcns.append(GCN(dim, dim, act, drop_p))
            self.pools.append(Pool(ks[i], dim, drop_p))
            self.unpools.append(Unpool(dim, dim, drop_p))

    def forward(self, g, h):
        adj_ms = []
        indices_list = []
        down_outs = []
        hs = []
        org_h = h
        for i in range(self.l_n):
            h = self.down_gcns[i](g, h)
            adj_ms.append(g)
            down_outs.append(h)
            g, h, idx = self.pools[i](g, h)
            indices_list.append(idx)
        h = self.bottom_gcn(g, h)
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            g, idx = adj_ms[up_idx], indices_list[up_idx]
            g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
            h = self.up_gcns[i](g, h)
            h = h.add(down_outs[up_idx])
            hs.append(h)
        h = h.add(org_h)
        hs.append(h)
        return hs

In [46]:
import torch.nn as nn

def prep_for_gcn(graph, xyz_pos, edge_feats_in, idx, max_degree=3):
    
    src, dst = graph.edges()
    
    new_pos = F.gather_row(xyz_pos, idx)
    rel_pos = F.gather_row(new_pos,dst) - F.gather_row(new_pos,src) 
    
    basis_out = get_basis(rel_pos, max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)
    basis_out = update_basis_with_fused(basis_out, max_degree, use_pad_trick=False,
                                            fully_fused=False)
    edge_feats_out = get_populated_edge_features(rel_pos, edge_feats_in)
    return basis_out, edge_feats_out

class Sequential(torch.nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input

class GraphUNet(torch.nn.Module):
    def __init__(self,ks = [5],
                 batch_size = 8,
                 in_dim=12,
                 ndf_mult=12,
                 max_degree=3,
                 num_heads = 8,
                 channels_div=2,
                 batchsize=8,
                 max_nodes = 65):
        super().__init__()
        self.edge_feature_dim = 1
        
        
        self.ks = ks
        
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        
        self.l_n = len(ks)
        
        out_dim = in_dim*ndf_mult
        
        for i in range(self.l_n):
            self.down_gcns.append(AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True'))
        
            self.pools.append(TopK_Pool(Fiber({0: out_dim}), k=ks[i]))
                                  
            in_dim = out_dim
            out_dim = in_dim*ndf_mult
                                  
        self.bottom_gcn = AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True')
        
        self.global_pool = GPooling(pool='avg', feat_type=0)
        self.latent_unpool = Latent_Unpool()
        
        in_dim = out_dim
        out_dim = out_dim/ndf_mult
                                          
        for i in range(self.l_n,0,-1):
            self.up_gcns.append(AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True'))
        
            self.unpools.append(Unpool())
            
            in_dim = out_dim
            out_dim = out_dim/ndf_mult
            
        self.top_gcn = AttentionBlockSE3( fiber_in= Fiber({0: in_dim}),
                                                     fiber_out  = Fiber({0: out_dim}),
                                                     fiber_edge = Fiber({0: self.edge_feature_dim}),
                                                     num_heads=num_heads,
                                                     channels_div=channels_div,
                                                     use_layer_norm=True,
                                                     max_degree=max_degree,
                                                     fuse_level=ConvSE3FuseLevel.NONE,
                                                     low_memory='True')
        
        self.graph_list = [define_UGraph(max_nodes, batch_size, cast_type=torch.float32 )]
        for i in range(self.l_n):
            max_nodes = ks[i]
            self.graph_list.append(define_UGraph(max_nodes, batch_size, cast_type=torch.float32 ))
            

            
#     def forward(self, node_feats, rel_pos, ed):
        
#         indices_list = []
#         down_outs = []
        
#         for i in range(self.l_n):
#             feats
        
#         def forward(self, g, h):
#         adj_ms = []
#         indices_list = []
#         down_outs = []
#         hs = []
#         org_h = h
#         for i in range(self.l_n):
            
#             basis = get_basis(rel_pos, max_degree=max_degree,
#                                    compute_gradients=True,
#                                    use_pad_trick=False)
            
#             h = self.down_gcns[i](self.graph_list[i], h)
            
            
#             adj_ms.append(g)
#             down_outs.append(h)
#             g, h, idx = self.pools[i](g, h)
#             indices_list.append(idx)
#         h = self.bottom_gcn(g, h)
#         for i in range(self.l_n):
#             up_idx = self.l_n - i - 1
#             g, idx = adj_ms[up_idx], indices_list[up_idx]
#             g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
#             h = self.up_gcns[i](g, h)
#             h = h.add(down_outs[up_idx])
#             hs.append(h)
#         h = h.add(org_h)
#         hs.append(h)
#         return hs
                                
        
        
            
        

In [33]:
gu = GraphUNet()

In [34]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break

In [37]:
pos = batched_graph.ndata['pos']

Graph(num_nodes=520, num_edges=33280,
      ndata_schemes={'pos': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32)})

In [45]:
basis_out, edge_feats_out = prep_for_gcn(gu.graph_list[0], pos, edge_feats, gu.graph_list[0].nodes())

In [47]:
out = gu.down_gcns[0].forward(node_feats, edge_feats_out,graph=gu.graph_list[0],basis=basis_out)

  assert input.numel() == input.storage().size(), (


In [None]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break
    
bg = to_cuda(batched_graph)
edge_feats = to_cuda(edge_feats)
node_feats = to_cuda(node_feats)

basis = get_basis(bg.edata['rel_pos'], max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)

#need to add basis fused here?

basis = update_basis_with_fused(basis, max_degree, use_pad_trick=False,
                                        fully_fused=False)

#concatenate on the distances of the edge based on 'rel pos'
edge_feats_cat = get_populated_edge_features(bg.edata['rel_pos'], edge_feats)

In [13]:
dm = H4_DataModule(ca_coords)

In [15]:
n_nodes = 65
NODE_FEATURE_DIM = 12
EDGE_FEATURE_DIM = 1 # probably expand to [2] one hot primary connect 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 8 # how many
num_heads = 4
channels_div = 2
max_degree = 4

use_layer_norm = True

fuse_level = ConvSE3FuseLevel.NONE

fiber_in=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden=Fiber({0: num_degrees * num_channels})
fiber_edge=Fiber({0: EDGE_FEATURE_DIM})
fiber_out = Fiber({0: num_degrees * num_channels}) # can this be arbitrary, or projected

In [16]:
ablock = AttentionBlockSE3(fiber_in=fiber_in,
               fiber_out=fiber_hidden,
               fiber_edge=fiber_edge,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda = ablock.to('cuda')

In [17]:
tk = TopK_Pool(fiber_hidden)
tk_cuda = tk.to('cuda')
# tblock = [ablock,tk]
# model = Sequential(*tblock)

In [18]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    break
    
bg = to_cuda(batched_graph)
edge_feats = to_cuda(edge_feats)
node_feats = to_cuda(node_feats)

basis = get_basis(bg.edata['rel_pos'], max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)

#need to add basis fused here?

basis = update_basis_with_fused(basis, max_degree, use_pad_trick=False,
                                        fully_fused=False)

#concatenate on the distances of the edge based on 'rel pos'
edge_feats_cat = get_populated_edge_features(bg.edata['rel_pos'], edge_feats)

In [27]:
out = acuda.forward(node_feats, edge_feats_cat,graph=bg,basis=basis)

  assert input.numel() == input.storage().size(), (


In [28]:
node_feat_1, inde = tk.forward(out, bg)
ndf1 = {}
ndf1['0'] = node_feat_1.unsqueeze(-1)

new_pos = F.gather_row(bg.ndata['pos'], inde)

In [29]:
#define subgraph 1, and rel pos indices 

bg_pool1 = define_UGraph(tk.k,batch_size=8) 
src, dst = bg_pool1.edges()
src_pool1 = to_cuda(src)
dst_pool1 = to_cuda(dst)


In [30]:

#
edge_feats_1 = {'0': bg_pool1.edata['con'][:, :1, None]}

edge_feats_1 = to_cuda(edge_feats_1)


rel_pos_pool1 = F.gather_row(new_pos,dst_pool1) - F.gather_row(new_pos,src_pool1) 
#rel_pos_pool1 = 

edge_feats_1_cat = get_populated_edge_features(rel_pos_pool1, edge_feats_1)
basis_1 = get_basis(rel_pos_pool1, max_degree=max_degree,
                                   compute_gradients=True,
                                   use_pad_trick=False)
basis_1 = update_basis_with_fused(basis_1, max_degree, use_pad_trick=False,
                                        fully_fused=False)



In [31]:
NODE_FEATURE_DIM = 32
EDGE_FEATURE_DIM = 1 # 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 16 # how many
num_heads = 8
channels_div = 2
max_degree = 4

use_layer_norm = True
fuse_level = ConvSE3FuseLevel.NONE


fiber_in2=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden2=Fiber({0: num_degrees * num_channels*4})
fiber_edge2=Fiber({0: EDGE_FEATURE_DIM})
#fiber_out2 = Fiber({0: num_degrees * num_channels * 4}) # can this be arbitrary, or projected
ablock2 = AttentionBlockSE3(fiber_in=fiber_in2,
               fiber_out=fiber_hidden2,
               fiber_edge=fiber_edge2,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda2 = ablock2.to('cuda')

In [42]:
out2 = acuda2.forward(ndf1, edge_feats_1_cat, graph=to_cuda(bg_pool1), basis=basis_1)

In [43]:
#i think i need to build a new graph after the pool, ugh, ugh ,ugh
#unpooling should be easier just add onto old.

In [44]:
out2['0'].shape

torch.Size([40, 256, 1])

In [45]:
global_pooling_module = GPooling(pool='max', feat_type=0)

In [46]:
latent_pool = global_pooling_module(out2, to_cuda(bg_pool1))

In [53]:
lp = Latent_Unpool()

In [56]:
lp.forward(latent_pool, bg_pool1,out2)

TypeError: forward() missing 1 required positional argument: 'u_features'

In [90]:
global_pooling_module = GPooling(pool='max', feat_type=0)
latent_pool = global_pooling_module(out2, to_cuda(bg_pool1))
node_feat_up1 = latent_pool.repeat_interleave(5,0) #copy pool to all new nodes
node_feat_up1  = torch.add(node_feat_up1.unsqueeze(-1),out2['0']) #unet add 

In [91]:
node_feat_up1 = latent_pool.repeat_interleave(5,0) #copy pool to all new nodes
node_feat_up1  = torch.add(node_feat_up1.unsqueeze(-1),out2['0']) #unet add 


In [92]:
NODE_FEATURE_DIM = 256
EDGE_FEATURE_DIM = 1 # 
num_degrees = 4 # how many levels of spherical harmonics to use
num_channels = 4 # how many
num_heads = 4
channels_div = 1
max_degree = 4

use_layer_norm = True
fuse_level = ConvSE3FuseLevel.NONE


fiber_in3=Fiber({0: NODE_FEATURE_DIM})
fiber_hidden3=Fiber({0: num_degrees * num_channels*2})
fiber_edge3=Fiber({0: EDGE_FEATURE_DIM})
#fiber_out2 = Fiber({0: num_degrees * num_channels * 4}) # can this be arbitrary, or projected
ablock3 = AttentionBlockSE3(fiber_in=fiber_in3,
               fiber_out=fiber_hidden3,
               fiber_edge=fiber_edge3,
               num_heads=num_heads,
               channels_div=channels_div,
               use_layer_norm=use_layer_norm,
               max_degree=max_degree,
               fuse_level=fuse_level,
               low_memory='True')
acuda3 = ablock3.to('cuda')

In [93]:
out3 = acuda3.forward({'0':node_feat_up1}, edge_feats_1_cat, graph=to_cuda(bg_pool1), basis=basis_1)

In [106]:
out3['0'].shape

torch.Size([40, 32, 1])

In [107]:
#unpool 

node_feat_up2 = torch.zeros((bg.num_nodes(), out3['0'].shape[1],1 )).to('cuda')
dd=F.scatter_row(node_feat_up2,inde,out3['0'])

#add U-net
pad= dd.shape[1]-node_feats['0'].shape[1]
unet_add = torch.cat((node_feats['0'], torch.zeros(node_feats['0'].shape[0],pad ,1).to('cuda')), 1)

torch.add(unet_add,dd).shape

torch.Size([520, 32, 1])

In [121]:
upool = Unpool()
inde_list = [inde]
upool(out3,bg,inde_list,node_feats)['0'].shape

torch.Size([520, 32, 1])

In [117]:
indelist = [inde]
idx_count = 0
out_feats = {}
for key,val in out3.items():
    new_h = val.new_zeros([bg.num_nodes(), val.shape[1],1])
    pad = val.new_zeros([bg.num_nodes(), new_h.shape[1]-node_feats[key].shape[1],1])
    print(new_h.shape,pad.shape)
    F.scatter_row(new_h,indelist[idx_count],val)
    break

torch.Size([520, 32, 1]) torch.Size([520, 20, 1])


In [116]:
inde.shape

torch.Size([40])

In [119]:
class Unpool(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        idx_count = 0
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            pad = val.new_zeros([graph.num_nodes(), new_h.shape[1]-u_features[key].shape[1],1])
            out_feats[key] = torch.add(F.scatter_row(new_h,idx[idx_count],val),torch.cat((u_features[key],pad),1))
            idx_count +=1
        return out_feats

In [89]:
unet_add = torch.cat((node_feats['0'], torch.zeros(node_feats['0'].shape[0],dd.shape[1]-node_feats['0'].shape[1],1).to('cuda')), 1).shape

In [73]:
inde

tensor([ 45,  46,  47,  48,  49, 110, 111, 112, 113, 114, 175, 176, 177, 178,
        179, 240, 241, 242, 243, 244, 305, 306, 307, 308, 309, 370, 371, 372,
        373, 374, 435, 436, 437, 438, 439, 500, 501, 502, 503, 504],
       device='cuda:0')

In [51]:
?torch.scatter

In [None]:
F.scatter_row

In [41]:
out3['0'].shape

torch.Size([40, 32, 1])

In [51]:
out2['0'].shape

torch.Size([40, 256, 1])

In [52]:
node_feat_up1.shape

torch.Size([40, 256])

In [48]:
latent_pool.repeat_interleave(5,0).shape

torch.Size([40, 256])

In [None]:
#unpooling

#need torch.zeros size of previous graph
torch.zeros((bg_pool1.num_nodes,))


In [31]:
bg_pool1.num_nodes

Graph(num_nodes=40, num_edges=160,
      ndata_schemes={'pe': Scheme(shape=(12,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'con': Scheme(shape=(1,), dtype=torch.float32)})

In [None]:
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        pooled = self.pool(graph, features[str(self.feat_type)])
        return pooled.squeeze(dim=-1)

In [None]:
#remake without pulling from graph?

READOUT_ON_ATTRS = {
    "nodes": ("ndata", "batch_num_nodes", "number_of_nodes"),
    "edges": ("edata", "batch_num_edges", "number_of_edges"),
}

def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
    """Internal function to take graph-wise top-k node/edge features of
    field :attr:`feat` in :attr:`graph` ranked by keys at given
    index :attr:`sortby`. If :attr:`descending` is set to False, return the
    k smallest elements instead.

    Parameters
    ---------
    graph : DGLGraph
        The graph
    typestr : str
        'nodes' or 'edges'
    feat : str
        The feature field name.
    k : int
        The :math:`k` in "top-:math`k`".
    descending : bool
        Controls whether to return the largest or smallest elements,
         defaults to True.
    sortby : int
        The key index we sort :attr:`feat` on, if set to None, we sort
        the whole :attr:`feat`.
    ntype_or_etype : str, tuple of str
        Node/edge type.

    Returns
    -------
    sorted_feat : Tensor
        A tensor with shape :math:`(B, K, D)`, where
        :math:`B` is the batch size of the input graph.
    sorted_idx : Tensor
        A tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if sortby
        is set to None), where
        :math:`B` is the batch size of the input graph, :math:`D`
        is the feature size.


    Notes
    -----
    If an example has :math:`n` nodes/edges and :math:`n<k`, in the first
    returned tensor the :math:`n+1` to :math:`k`th rows would be padded
    with all zero; in the second returned tensor, the behavior of :math:`n+1`
    to :math:`k`th elements is not defined.
    """
    _, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
    data = getattr(graph, typestr)[ntype_or_etype].data
    if F.ndim(data[feat]) > 2:
        raise DGLError(
            "Only support {} feature `{}` with dimension less than or"
            " equal to 2".format(typestr, feat)
        )
    feat = data[feat]
    hidden_size = F.shape(feat)[-1]
    batch_num_objs = getattr(graph, batch_num_objs_attr)(ntype_or_etype)
    batch_size = len(batch_num_objs)
    length = max(max(F.asnumpy(batch_num_objs)), k)
    fill_val = -float("inf") if descending else float("inf")
    feat_ = F.pad_packed_tensor(
        feat, batch_num_objs, fill_val, l_min=k
    )  # (batch_size, l, d)

    if F.backend_name == "pytorch" and sortby is not None:
        # PyTorch's implementation of top-K
        keys = feat_[..., sortby]  # (batch_size, l)
        return _topk_torch(keys, k, descending, feat_)
    else:
        # Fallback to framework-agnostic implementation of top-K
        if sortby is not None:
            keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby + 1), -1)
            order = F.argsort(keys, -1, descending=descending)
        else:
            order = F.argsort(feat_, 1, descending=descending)
        topk_indices = F.slice_axis(order, 1, 0, k)

        if sortby is not None:
            feat_ = F.reshape(feat_, (batch_size * length, -1))
            shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
            shift = F.copy_to(shift, F.context(feat))
            topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
        else:
            feat_ = F.reshape(feat_, (-1,))
            shift = F.repeat(
                F.arange(0, batch_size), k * hidden_size, -1
            ) * length * hidden_size + F.cat(
                [F.arange(0, hidden_size)] * batch_size * k, -1
            )
            shift = F.copy_to(shift, F.context(feat))
            topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
        out = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1))
        out = F.replace_inf_with_zero(out)
        return out, topk_indices

In [73]:
for x in range(8):
    for y in range(30):
        if len(np.unique(sort_ind[x][:,y].cpu().numpy())) != 30:
            print(False)

In [56]:
sort_feat[0]

tensor([[ 1.2111, -0.0900,  0.7318,  0.2344, -0.0669,  0.0927,  0.5824,  0.4390,
          0.0690,  0.5615,  0.7202,  0.1591,  0.5839,  0.1566,  0.3246,  0.1450,
          0.8620,  0.5528,  1.3103,  0.5718,  0.4452, -0.5404,  0.4007,  0.3480,
          0.4214,  0.7397,  0.1821,  0.0995,  0.1473,  0.8472,  0.4826,  0.7203],
        [ 1.2044, -0.0912,  0.7244,  0.2296, -0.1206,  0.0902,  0.5798,  0.4309,
          0.0607,  0.5256,  0.7131,  0.1099,  0.5822,  0.1490,  0.3212,  0.1406,
          0.8589,  0.5515,  1.3042,  0.5710,  0.4410, -0.5436,  0.3697,  0.3449,
          0.4145,  0.7275,  0.1739,  0.0936,  0.1362,  0.8349,  0.4770,  0.7097],
        [ 1.2012, -0.0977,  0.7028,  0.2276, -0.1264,  0.0872,  0.5759,  0.4240,
          0.0556,  0.5194,  0.6933,  0.0538,  0.5749,  0.1320,  0.3177,  0.1326,
          0.8483,  0.5320,  1.2934,  0.5513,  0.4327, -0.5531,  0.3374,  0.3360,
          0.4096,  0.7245,  0.1727,  0.0884,  0.1138,  0.8332,  0.4558,  0.6915],
        [ 1.1841, -0.0988

In [54]:
sort_ind[0].shape

torch.Size([10, 32])

In [None]:
model = SE3TransformerPooled(
        fiber_in=Fiber({0: dm.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: num_degrees * num_channels}),
        fiber_edge=Fiber({0: dm.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(False),
        num_degrees=num_degrees,
        num_channels=num_channels,
        **kwargs
    )

In [41]:
?dgl.nn.pytorch.

In [None]:
# get dataset from pdb using npose utils
# import util.npose_util as nu
# import os
# import numpy as np

# model_direc = '/mnt/c/Users/nwood/OneDrive/Desktop/hTest/HelixGen_master/data/4H_dataset/models/'

# fL = os.listdir(model_direc)
# coords = np.zeros((len(fL),65*5,4)) #65 aa, 5 atoms per aa
# for i,file in enumerate(fL):
#     coords[i] = nu.npose_from_file(f'{model_direc}/{file}')

# coords_out = coords.reshape((27894,65,5,4))[...,:3]
# ca_coords = coords_out.reshape((27894,65,5,3))[:,:,1,:]

# np.savez_compressed('../gudiff/data/h4_coords.npz',coords_out)
# np.savez_compressed('../gudiff/data/h4_ca_coords.npz',ca_coords)

In [10]:




def normalize_pc(points):
    """Center at Zero Divide furtherst points"""
    centroid = np.mean(points, axis=0)
    points -= centroid
    #since the points are centered zero, the furthest points is the abs value di
    furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
    points /= furthest_distance

    return points, furthest_distance
    
def make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(i_pos)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    return pe


def make_graph_struct(batch_size=32, n_nodes = 8):
    # make a fake graph to be filled with generator outputs
    
    v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ss = np.zeros(len(v1),dtype=np.int32)
    ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
    ss = ss[:,None] #unsqueeze
    
    pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

    graphList = []
    for i in range(batch_size):
        g = dgl.graph((v1,v2))
        g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph


class GraphDataset(Dataset):
    def __init__(self, ep_file : pathlib.Path, limit=1000):
        self.data_path = ep_file
        rr = np.load(self.data_path)
        ep = [rr[f] for f in rr.files][0][:1000]
        
        #need to save furthest distance to regen later
        #maybe consider small change for next steps
        ep, self.furthest_distance = normalize_pc(ep.reshape((-1,3)))
        self.ep = ep.reshape((-1,8,3))
        
        
        v1 = np.arange(self.ep.shape[1]-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,self.ep.shape[1]) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1))
        ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
        ss = ss[:,None] #unsqueeze

        #positional encoding of node
        pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

        graphList = []

        for i,c in enumerate(self.ep):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

    
class HGenDataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 8
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        
        self.data_dir = data_dir 
        self.GraphDatasetObj = GraphDataset(self.data_dir)
        self.gds = DataLoader(self.GraphDatasetObj,batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        return (batched_graph, node_feats, edge_feats)
    
def eval_gen(batch_size=8,z=12):
    
    in_z = torch.randn((batch_size,z), device='cuda',dtype = torch.float32)
    out = hg(in_z)*31
    out = out.reshape((-1,8,3)).detach().cpu().numpy()
    
    return eval_endpoints(out)
    
    

def eval_endpoints(ep_in): 
    
    ep = ep_in.reshape((-1,8,3))

    v1 = np.arange(ep.shape[1]-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,ep.shape[1]) #vertex 2 of edges in chronological order

    hLL = np.linalg.norm(ep[:,v1]-ep[:,v2],axis=2)

    hLoc = np.array([0,2,4,6])
    lLoc = np.array([1,3,5])

    return np.mean(hLL[:,hLoc]), np.mean(hLL[:,lLoc])
        

NameError: name 'Dataset' is not defined

In [123]:
class UnpoolDep(torch.nn.Module):
    """
    Place features into torch.zeros array
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, idx: Tensor, u_features: Dict[str, Tensor]):
        idx_count = 0
        out_feats = {}
        for key,val in features.items():
            new_h = val.new_zeros([graph.num_nodes(), val.shape[1], 1])
            pad = val.new_zeros([graph.num_nodes(), new_h.shape[1]-u_features[key].shape[1],1])
            out_feats[key] = torch.add(F.scatter_row(new_h,idx[idx_count],val),torch.cat((u_features[key],pad),1))
            idx_count +=1
        return out_feats