In [144]:
import pathlib

import dgl
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

In [145]:
# Location of the .npz coordinate archive and the number of leading entries kept for this test run.
data_path_str  = 'data/bcov_ca.npz'
test_limit = 128

# An .npz archive decompresses each member only when it is indexed, so read just the
# first stored array instead of materialising every member, and close the archive
# as soon as the slice has been taken.
with np.load(data_path_str) as rr:
    # keep the first `test_limit` entries and the first three values of the last axis
    ca_coords = rr[rr.files[0]][:test_limit,:,:3]
ca_coords.shape

(128, 65, 3)

In [146]:
def get_midpoint(ep_in):
    """Return the midpoint of each batched set of points.

    The points along axis 1 are summed and divided by the number of points,
    so an input indexed as [batch, point, coordinate] yields one midpoint per
    batch entry, shaped [batch, coordinate].
    """
    n_points, n_coords = ep_in.shape[1], ep_in.shape[2]

    # one divisor per coordinate, each equal to the number of points that were summed
    divisor = np.repeat(n_points, n_coords)
    point_sum = ep_in.sum(axis=1)

    return point_sum / divisor


def normalize_points(input_xyz, print_dist=False):
    """Centre each point set on its midpoint and divide by the largest pairwise distance.

    The scale factor is a single number: the maximum point-to-point distance
    found over the whole input (all batch entries together), so every batch
    entry is divided by the same value.

    Args:
        input_xyz: array indexed as [batch, point, coordinate].
        print_dist: when true, print the largest distance that was found.

    Returns:
        Array of the same shape as input_xyz, centred per batch entry and rescaled.
    """
    # pairwise displacements by broadcasting: [B, M, 1, 3] - [B, 1, M, 3] -> [B, M, M, 3]
    pair_delta = input_xyz[..., None, :] - input_xyz[..., None, :, :]

    # Euclidean length over the trailing coordinate axis of the displacement array
    pair_dist = np.sqrt(np.square(pair_delta).sum(axis=input_xyz.ndim))
    furthest_dist = pair_dist.max()

    centroid = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')

    # centroid is [B, 3]; add a point axis so it broadcasts over every point
    return (input_xyz - centroid[:, None, :]) / furthest_dist
    

In [147]:
# Centre and rescale the loaded coordinates; print_dist=True also reports the largest pairwise distance.
norm_p = normalize_points(ca_coords,print_dist=True)

largest distance 32.8


In [148]:
# Goal: define the graph edges and label each one:
#   1 = connected backbone,
#   0 = unconnected atoms.

def define_graph_edges(n_nodes):
    """Build the edge list of a fully connected directed graph and flag the backbone edges.

    Every ordered pair (i, j) with i != j is an edge. Edges are listed grouped by
    source node and, within one source, by ascending target node. The edges
    i -> i+1 are the connected backbone and are labelled 1; every other edge is
    labelled 0. Only the forward direction is flagged: the reverse edge
    i+1 -> i keeps label 0.

    Args:
        n_nodes: number of nodes in the graph.

    Returns:
        v1: numpy array with the source node of each edge, length n_nodes*(n_nodes-1).
        v2: numpy array with the target node of each edge, same length.
        edge_data: torch tensor of the same length, 1 on backbone edges and 0 elsewhere.
        ind: numpy array with the positions of the n_nodes-1 backbone edges in the edge list.
    """
    # The node count comes from the argument only; it is no longer overwritten
    # from a notebook-level variable.
    nodes = np.arange(n_nodes)

    # all ordered pairs
    v1 = np.repeat(nodes, n_nodes-1) #starting vertices, same number repeated for each edge

    # row i of the grid lists every candidate target for source i; dropping the
    # entries where target == source removes the self connections
    # (self connections are managed elsewhere) and flattens row by row
    target_grid = np.repeat(nodes[None,:], n_nodes, axis=0)
    v2 = target_grid[target_grid != nodes[:,None]]

    # connected backbone
    con_v1 = nodes[:-1] #vertex 1 of edges in chronological order
    con_v2 = nodes[1:]  #vertex 2 of edges in chronological order

    # each source owns n_nodes-1 consecutive slots; inside a slot block the target
    # index is shifted down by one once it passes the removed self connection (-1)
    ind = con_v1*(n_nodes-1)+con_v2-1

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1

    return v1,v2,edge_data, ind
    
    


In [149]:
# No notebook-level n_nodes exists on a fresh kernel, so take the node count
# from the point axis of the normalized coordinates.
v1,v2,edge_data, ind = define_graph_edges(norm_p.shape[1])

In [150]:
def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    """Sinusoidal positional encoding of the node index.

    Row t is [sin(w_1 t), cos(w_1 t), sin(w_2 t), cos(w_2 t), ...] with
    w_k = 1 / scale**(2k/embed_dim) for k = 1 .. embed_dim/2.

    Args:
        n_nodes: number of positions (rows) to encode, starting at 0.
        embed_dim: width of each encoding row.
        scale: base of the frequency schedule.
        cast_type: torch type of the returned tensor.
        print_out: when True, print the first int(n_nodes/12) rows rounded to one decimal.

    Returns:
        Tensor of shape [n_nodes, embed_dim].
    """
    freq_index = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(freq_index*2/embed_dim)))

    # column of positions times row of frequencies -> [n_nodes, embed_dim/2]
    positions = np.arange(n_nodes)
    phase = wk*positions.reshape((-1,1))

    sin_part = torch.tensor(np.sin(phase))
    cos_part = torch.tensor(np.cos(phase))

    # stacking on a new last axis and flattening it interleaves sin and cos per frequency
    pe = torch.stack((sin_part,cos_part),axis=2).reshape(positions.shape[0],embed_dim).type(cast_type)

    if print_out == True:
        n_rows_shown = int(n_nodes/12)
        for row in range(n_rows_shown):
            print(np.round(pe[row],1))

    return pe

In [151]:
# Build the 65-node positional encoding; print_out=True prints its first rows as a visual check.
pe = make_pe_encoding(n_nodes=65,print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.3000, 1.0000, 0.1000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000,
        1.0000, 0.0000, 1.0000])
tensor([0.6000, 0.8000, 0.2000, 1.0000, 0.1000, 1.0000, 0.0000, 1.0000, 0.0000,
        1.0000, 0.0000, 1.0000])
tensor([0.8000, 0.6000, 0.3000, 1.0000, 0.1000, 1.0000, 0.0000, 1.0000, 0.0000,
        1.0000, 0.0000, 1.0000])
tensor([1.0000, 0.3000, 0.4000, 0.9000, 0.1000, 1.0000, 0.0000, 1.0000, 0.0000,
        1.0000, 0.0000, 1.0000])


In [154]:
def define_graph(batch_size=8,n_nodes=65):
    """Create a batch of identical graphs carrying edge labels and positional encodings.

    Each graph uses the edge list and edge labels from define_graph_edges(n_nodes);
    edges store the labels under 'con' and nodes store the positional encoding
    under 'pe'.

    Args:
        batch_size: number of graph copies placed in the batch.
        n_nodes: number of nodes in every graph.

    Returns:
        The DGL batched graph made of batch_size copies.
    """
    # the backbone edge positions (fourth return value) are not needed here
    v1,v2,edge_data, _ = define_graph_edges(n_nodes)

    # size the encoding from n_nodes so it always has exactly one row per node
    # (it was previously fixed at 65 rows regardless of the argument)
    pe = make_pe_encoding(n_nodes=n_nodes)

    graphList = []

    for _ in range(batch_size):

        # num_nodes pins the node count even when the edge list is empty
        g = dgl.graph((v1,v2), num_nodes=n_nodes)
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph

In [155]:
# Build a batch of 8 graphs using the default node count.
bg = define_graph(batch_size=8)

In [None]:




def normalize_pc(points):
    """Center a point cloud at zero and divide by the distance of its furthest point.

    The input is left untouched: the centered and rescaled cloud is returned as
    a new array, so integer input and views of other arrays are safe to pass.

    Args:
        points: array of points with the coordinates on the last axis; the
            centroid is taken over axis 0.

    Returns:
        (normalized_points, furthest_distance) where furthest_distance is the
        largest distance of any centered point from the origin; multiplying
        normalized_points by it restores the centered coordinates.
    """
    centroid = np.mean(points, axis=0)
    centered = points - centroid

    # the cloud is centered on zero, so the furthest point is simply the one with the largest norm
    furthest_distance = np.max(np.sqrt(np.sum(abs(centered)**2,axis=-1)))

    return centered / furthest_distance, furthest_distance
    
def make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32):
    """Sinusoidal positional encoding for positions 0 .. i_pos-1.

    Each row interleaves sin and cos of the position multiplied by the
    frequencies 1 / scale**(2k/embed_dim), k = 1 .. embed_dim/2.

    Note: this shares its name with the make_pe_encoding defined earlier in the
    notebook (which takes n_nodes / print_out and other defaults) and replaces
    it once this cell has run.

    Returns:
        Tensor of shape [i_pos, embed_dim] converted to cast_type.
    """
    #frequency schedule
    freq_idx = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(freq_idx*2/embed_dim)))

    #positions as a column so the product broadcasts to [i_pos, embed_dim/2]
    t_col = np.arange(i_pos).reshape((-1,1))
    angle = wk*t_col

    sin_cos = (torch.tensor(np.sin(angle)), torch.tensor(np.cos(angle)))

    #new trailing axis holds the (sin, cos) pair; flattening it interleaves them
    return torch.stack(sin_cos,axis=2).reshape(t_col.shape[0],embed_dim).type(cast_type)


def make_graph_struct(batch_size=32, n_nodes = 8):
    """Make a batch of placeholder chain graphs to be filled with generator outputs.

    Every graph is a chain 0 -> 1 -> ... -> n_nodes-1. Edges carry an 'ss' flag
    that alternates 1, 0, 1, ... along the chain (original note: helix, loop,
    helix, etc) and nodes carry an 8-wide positional encoding under 'pe'.

    Args:
        batch_size: number of graphs in the batch.
        n_nodes: number of nodes in every graph.

    Returns:
        The DGL batched graph.
    """
    v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ss = np.zeros(len(v1),dtype=np.int32)
    ss[::2] = 1  #even-indexed edges get 1, odd-indexed edges stay 0
    # add the trailing feature axis and convert once; the values are the same for every graph
    ss = torch.tensor(ss[:,None],dtype=torch.float32)

    # one encoding row per node: follow n_nodes instead of a fixed 8 positions
    pe = make_pe_encoding(i_pos=n_nodes, embed_dim = 8, scale = 10, cast_type=torch.float32)

    graphList = []
    for _ in range(batch_size):
        # num_nodes pins the node count even when the chain has no edges
        g = dgl.graph((v1,v2), num_nodes=n_nodes)
        g.edata['ss'] = ss
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph


class GraphDataset(Dataset):
    """Dataset of chain graphs built from normalized endpoint coordinates.

    The first array stored in the .npz archive is read, truncated to `limit`
    entries, normalized as a single point cloud with normalize_pc and reshaped
    to [n_samples, 8, 3]. One graph is built per sample with node features
    'pos' (coordinates) and 'pe' (positional encoding) and edge feature 'ss'.
    """

    def __init__(self, ep_file : pathlib.Path, limit=1000):
        """Load, normalize and convert the endpoints into a list of graphs.

        Args:
            ep_file: path of the .npz archive holding the endpoints.
            limit: maximum number of leading entries to keep (None keeps all).
        """
        self.data_path = ep_file

        # read only the first stored array, honour `limit`, and close the archive afterwards
        with np.load(self.data_path) as rr:
            ep = rr[rr.files[0]][:limit]

        #need to save furthest distance to regen later
        #maybe consider small change for next steps
        ep, self.furthest_distance = normalize_pc(ep.reshape((-1,3)))
        self.ep = ep.reshape((-1,8,3))

        n_nodes = self.ep.shape[1]
        v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1))
        ss[::2] = 1  #even-indexed edges get 1, odd stay 0 (original note: helix, loop, helix, etc)
        # add the trailing feature axis and convert once; the values are identical for every graph
        ss = torch.tensor(ss[:,None],dtype=torch.float32)

        #positional encoding of node, one row per node
        pe = make_pe_encoding(i_pos=n_nodes, embed_dim = 8, scale = 10, cast_type=torch.float32)

        graphList = []

        for c in self.ep:

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = ss
            g.ndata['pe'] = pe

            graphList.append(g)

        self.graphList = graphList


    def __len__(self):
        """Number of graphs held by the dataset."""
        return len(self.graphList)

    def __getitem__(self, idx):
        """Return the graph stored at position idx."""
        return self.graphList[idx]

    
class HGenDataModule():
    """
    Data module around the hGen graph dataset: 8 helical endpoints defining a four helix protein.

    Builds a GraphDataset from the given file and exposes a shuffling DataLoader
    (attribute `gds`) whose batches are produced by `_collate`.
    """
    # width of the node feature slice: the 8 long positional encoding
    NODE_FEATURE_DIM = 8
    # width of the edge feature slice: a single 0 or 1 flag, helix or loop
    EDGE_FEATURE_DIM = 1

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        """Create the dataset and its loader; incomplete final batches are dropped."""
        self.data_dir = data_dir
        self.GraphDatasetObj = GraphDataset(self.data_dir)

        loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        self.gds = DataLoader(self.GraphDatasetObj, **loader_options)

    def _collate(self, graphs):
        """Merge a list of graphs into one batched graph plus feature dictionaries.

        Returns a tuple (batched graph, node features, edge features); both
        dictionaries hold one entry under the key '0' with a trailing axis of
        length one appended to the sliced feature.
        """
        merged = dgl.batch(graphs)

        # edge features: leading EDGE_FEATURE_DIM columns of 'ss'
        ss_flags = merged.edata['ss']
        edge_feats = {'0': ss_flags[:, :self.EDGE_FEATURE_DIM, None]}

        # _get_relative_pos is expected to be defined elsewhere; it is not part of this section
        merged.edata['rel_pos'] = _get_relative_pos(merged)

        # node features: leading NODE_FEATURE_DIM columns of the positional encoding
        pos_enc = merged.ndata['pe']
        node_feats = {'0': pos_enc[:, :self.NODE_FEATURE_DIM, None]}

        return (merged, node_feats, edge_feats)
    
def eval_gen(batch_size=8,z=12):
    """Sample the generator and score the resulting endpoints.

    Draws batch_size random latent vectors of length z on the CUDA device,
    passes them through the notebook-level callable `hg` (not defined in this
    section), multiplies the output by 31, reshapes it to [-1, 8, 3] and
    returns eval_endpoints of that array.

    NOTE(review): the factor 31 presumably undoes the distance normalization
    applied to the training data - confirm against the dataset's furthest distance.
    """
    latent = torch.randn((batch_size,z), device='cuda',dtype = torch.float32)

    generated = hg(latent)*31
    endpoints = generated.reshape((-1,8,3)).detach().cpu().numpy()

    return eval_endpoints(endpoints)
    
    

def eval_endpoints(ep_in): 
    """Mean length of the even- and odd-indexed segments of 8-point chains.

    The input is reshaped to [-1, 8, 3]; the distance between each point and
    its successor gives 7 segment lengths per chain.

    Returns:
        (mean over segments 0, 2, 4, 6, mean over segments 1, 3, 5), each
        averaged over all chains; the first group is the helix segments and the
        second the loop segments.
    """
    coords = ep_in.reshape((-1,8,3))

    # distance from every point to the next one along the chain -> [n_chains, 7]
    seg_len = np.linalg.norm(coords[:, :-1] - coords[:, 1:], axis=2)

    helix_cols = np.array([0,2,4,6])
    loop_cols = np.array([1,3,5])

    helix_mean = np.mean(seg_len[:, helix_cols])
    loop_mean = np.mean(seg_len[:, loop_cols])

    return helix_mean, loop_mean
        