In [1]:
import torch
import numpy as np
import util.npose_util as nu
import os
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
# from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
import time
torch.cuda.is_available()

True

In [2]:
# from data_rigid_diffuser import so3_diffuser
# from data_rigid_diffuser import r3_diffuser
# from scipy.spatial.transform import Rotation
# from data_rigid_diffuser import rigid_utils as ru
# import yaml
from data_rigid_diffuser.diffuser import FrameDiffNoise

In [3]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential, SE3Transformer
from se3_transformer.model.transformer_topk import SE3Transformer_topK
from se3_transformer.model.FAPE_Loss import FAPE_loss, Qs2Rs, normQ
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling, Latent_Unpool, Unpool_Layer
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [4]:
#npose indexing
# Useful numbers
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]

N_CA_dist = torch.tensor(1.458/10.0).to('cuda')
C_CA_dist = torch.tensor(1.523/10.0).to('cuda')

if ( hasattr(os, 'ATOM_NAMES') ):
    assert( hasattr(os, 'PDB_ORDER') )

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES=['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    globals()[atom_name] = i

R = len(ATOM_NAMES)

if ( "N" not in globals() ):
    N = -1
if ( "C" not in globals() ):
    C = -1
if ( "CB" not in globals() ):
    CB = -1


_pdb_order = []
for name in PDB_ORDER:
    _pdb_order.append( ATOM_NAMES.index(name) )

In [5]:
# data_path_str  = 'data/h4_ca_coords.npz'
# test_limit = 1028
# rr = np.load(data_path_str)
# ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
# ca_coords.shape

# getting N-Ca, Ca-C vectors to add as typeI features
#apa = apart helices for val/train split
#tog = together helices for val/train split
apa_path_str  = 'data_npose/h4_apa_coords.npz'
tog_path_str  = 'data_npose/h4_tog_coords.npz'

#grab the first 3 atoms which are N,CA,C
test_limit = 5048
rr = np.load(apa_path_str)
coords_apa = [rr[f] for f in rr.files][0][:test_limit,:]

rr = np.load(tog_path_str)
coords_tog = [rr[f] for f in rr.files][0][:test_limit,:]

In [6]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate O an CB atoms"""
    rot_mat_cat = np.ones(sum((coords_in.shape[:-1], (1,)), ()))
    
    coords = np.concatenate((coords_in,rot_mat_cat),axis=-1)
    
    npose = np.ones((coords_in.shape[0]*5,4)) #5 is atoms per res

    by_res = npose.reshape(-1, 5, 4)
    
    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    
    npose =  build_npose_from_coords(coords_in)
    nu.dump_npdb(npose,fileOut)

In [7]:
#goal define edges of
#connected backbone 1, 
#unconnected atoms 0,


def get_midpoint(ep_in):
    """Get midpoint, of each batched set of points"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ep_in.shape[1], ep_in.shape[2])
    
    return midpoint

# def normQ(Q):
#     """normalize a quaternions
#     """
#     return Q / torch.linalg.norm(Q, keepdim=True, dim=-1)

# def Rs2Qs(Rs):
#     Qs = torch.zeros((*Rs.shape[:-2],4), device=Rs.device)

#     Qs[...,0] = 1.0 + Rs[...,0,0] + Rs[...,1,1] + Rs[...,2,2]
#     Qs[...,1] = 1.0 + Rs[...,0,0] - Rs[...,1,1] - Rs[...,2,2]
#     Qs[...,2] = 1.0 - Rs[...,0,0] + Rs[...,1,1] - Rs[...,2,2]
#     Qs[...,3] = 1.0 - Rs[...,0,0] - Rs[...,1,1] + Rs[...,2,2]
#     Qs[Qs<0.0] = 0.0
#     Qs = torch.sqrt(Qs) / 2.0
#     Qs[...,1] *= torch.sign( Rs[...,2,1] - Rs[...,1,2] )
#     Qs[...,2] *= torch.sign( Rs[...,0,2] - Rs[...,2,0] )
#     Qs[...,3] *= torch.sign( Rs[...,1,0] - Rs[...,0,1] )

#     return Qs

# def Qs2Rs(Qs):
#     Rs = torch.zeros((*Qs.shape[:-1],3,3), device=Qs.device)

#     Rs[...,0,0] = Qs[...,0]*Qs[...,0]+Qs[...,1]*Qs[...,1]-Qs[...,2]*Qs[...,2]-Qs[...,3]*Qs[...,3]
#     Rs[...,0,1] = 2*Qs[...,1]*Qs[...,2] - 2*Qs[...,0]*Qs[...,3]
#     Rs[...,0,2] = 2*Qs[...,1]*Qs[...,3] + 2*Qs[...,0]*Qs[...,2]
#     Rs[...,1,0] = 2*Qs[...,1]*Qs[...,2] + 2*Qs[...,0]*Qs[...,3]
#     Rs[...,1,1] = Qs[...,0]*Qs[...,0]-Qs[...,1]*Qs[...,1]+Qs[...,2]*Qs[...,2]-Qs[...,3]*Qs[...,3]
#     Rs[...,1,2] = 2*Qs[...,2]*Qs[...,3] - 2*Qs[...,0]*Qs[...,1]
#     Rs[...,2,0] = 2*Qs[...,1]*Qs[...,3] - 2*Qs[...,0]*Qs[...,2]
#     Rs[...,2,1] = 2*Qs[...,2]*Qs[...,3] + 2*Qs[...,0]*Qs[...,1]
#     Rs[...,2,2] = Qs[...,0]*Qs[...,0]-Qs[...,1]*Qs[...,1]-Qs[...,2]*Qs[...,2]+Qs[...,3]*Qs[...,3]

#     return Rs


def normalize_points(input_xyz, print_dist=False):
    
    #broadcast to distance matrix [Batch, M, R3] to [Batch,M,1, R3] to [Batch,1,M, R3] to [Batch, M,M, R3] 
    vec_diff = input_xyz[...,None,:]-input_xyz[...,None,:,:]
    dist = np.sqrt(np.sum(np.square(vec_diff),axis=len(input_xyz.shape)))
    furthest_dist = np.max(dist)
    centroid  = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')
    
    xyz_mean_zero = input_xyz - centroid[:,None,:]
    return xyz_mean_zero/furthest_dist

def define_graph_edges(n_nodes):
    #connected backbone

    con_v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    con_v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ind = con_v1*(n_nodes-1)+con_v2-1 #account for removed self connections (-1)


    #unconnected backbone

    nodes = np.arange(n_nodes)
    v1 = np.repeat(nodes,n_nodes-1) #starting vertices, same number repeated for each edge

    start_v2 = np.repeat(np.arange(n_nodes)[None,:],n_nodes,axis=0)
    diag_ind = np.diag_indices(n_nodes)
    start_v2[diag_ind] = -1 #diagonal of matrix is self connections which we remove (self connections are managed by SE3 Conv channels)
    v2 = start_v2[start_v2>-0.5] #remove diagonal and flatten

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1
    
    return v1,v2,edge_data, ind



def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(n_nodes)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    
    if print_out == True:
        for x in range(int(n_nodes/12)):
            print(np.round(pe[x],1))
    
    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
#norm_p = normalize_points(ca_coords,print_dist=True)
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [8]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    pe = make_pe_encoding(n_nodes=n_nodes)
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph


In [9]:
def torch_normalize(v, eps=1e-6):
    """Normalize vector in last axis"""
    norm = torch.linalg.vector_norm(v, dim=len(v.shape)-1)+eps
    return v / norm[...,None]

def normalize(v):
    """Normalize vector in last axis"""
    norm = np.linalg.norm(v,axis=len(v.shape)-1)
    norm[norm == 0] = 1
    return v / norm[...,None]

def get_CN_vector(coords_in):
    N_CA_vec = normalize(coords_in[...,N,:3]-coords_in[...,CA,:3])
    C_CA_vec = normalize(coords_in[...,C,:3]-coords_in[...,CA,:3])
    return N_CA_vec, C_CA_vec





In [10]:
def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos

class Helix4_Dataset(Dataset):
    def __init__(self, coordinates: np.array, cast_type=torch.float32):
        #prots,#length_prot in aa, #residues/aa, #xyz per atom
           
        #alphaFold reduce by 10
        coord_div = 10
        
        coordinates = coordinates/coord_div
        self.ca_coords = torch.tensor(coordinates[:,:,CA,:], dtype=cast_type)
        #unsqueeze to stack together later
        self.N_CA_vec = torch.tensor(coordinates[:,:,N,:] - coordinates[:,:,CA,:], dtype=cast_type)
        self.C_CA_vec = torch.tensor(coordinates[:,:,C,:] - coordinates[:,:,CA,:], dtype=cast_type)
        
        self.N_CA_vec = torch_normalize(self.N_CA_vec).unsqueeze(2)
        self.C_CA_vec = torch_normalize(self.C_CA_vec).unsqueeze(2)
        
    def __len__(self):
        return len(self.ca_coords)

    def __getitem__(self, idx):
        return {'CA':self.ca_coords[idx], 'N_CA':self.N_CA_vec[idx], 'C_CA':self.C_CA_vec[idx]}
    
    

    
    
class Make_KNN_MP_Graphs():
    
    #8 long positional encoding
    NODE_FEATURE_DIM_0 = 12
    EDGE_FEATURE_DIM = 1 # 0 or 1 primary seq connection or not
    NODE_FEATURE_DIM_1 = 2
    
    def __init__(self, mp_stride=4, n_nodes=65, radius=15, coord_div=10, cast_type=torch.float32, channels_start=32,
                       ndf1=6, ndf0=32,cuda=True):
        
        self.KNN = 30
        self.n_nodes = n_nodes
        self.pe = make_pe_encoding(n_nodes=n_nodes)
        self.mp_stride = mp_stride
        self.cast_type = cast_type
        self.channels_start = channels_start
        
        self.cuda = cuda
        self.ndf1 = ndf1 #awkard adding of nodes features to mpGraph
        self.ndf0 = ndf0
        
    def create_and_batch(self, bb_dict):
        
        graphList = []
        mpGraphList = []
        mpRevGraphList = []
        mpSelfGraphList = []
        
        for j, caXYZ in enumerate(bb_dict['CA']):
            graph = dgl.knn_graph(caXYZ, self.KNN)
            graph.ndata['pe'] = pe
            graph.ndata['pos'] = caXYZ
            graph.ndata['bb_ori'] = torch.cat((bb_dict['N_CA'][j],  bb_dict['C_CA'][j]),axis=1)
            
            #define covalent connections
            esrc, edst = graph.edges()
            graph.edata['con'] = (torch.abs(esrc-edst)==1).type(self.cast_type).reshape((-1,1))
            
            mp_list = torch.zeros((len(list(range(0,self.n_nodes, self.mp_stride))),caXYZ.shape[1]))
            
            new_src = torch.tensor([],dtype=torch.int)
            new_dst = torch.tensor([],dtype=torch.int)
            
            new_src_rev = torch.tensor([], dtype=torch.int)
            new_dst_rev = torch.tensor([], dtype=torch.int)
           
            i=0#mp list counter
            for x in range(0,self.n_nodes, self.mp_stride):
                src, dst = graph.in_edges(x) #dst repeats x
                n_tot = torch.cat((torch.tensor(x).unsqueeze(0),src)) #add x to node list
                mp_list[i] = caXYZ[n_tot].sum(axis=0)/n_tot.shape[0]
                mp_node = i + graph.num_nodes() #add midpoints nodes at end of graph
                #define edges between midpoint nodes and nodes defining midpoint for midpointGraph
                
                new_src = torch.cat((new_src,n_tot))
                new_dst = torch.cat((new_dst,
                                     (torch.tensor(mp_node).unsqueeze(0).repeat(n_tot.shape[0]))))
                #and reverse graph for coming off
                new_src_rev = torch.cat((new_src_rev,
                                         (torch.tensor(mp_node).unsqueeze(0).repeat(n_tot.shape[0]))))
                new_dst_rev = torch.cat((new_dst_rev,n_tot))
                
                i+=1
                
            mpGraph = dgl.graph((new_src,new_dst))
            mpGraph.ndata['pos'] = torch.cat((caXYZ,mp_list),axis=0).type(self.cast_type)
            mp_node_indx = torch.arange(0,self.n_nodes, self.mp_stride).type(torch.int)
            #match output shape of first transformer
            pe_mp = torch.cat((pe,torch.zeros((pe.shape[0],self.channels_start-pe.shape[1]))),axis=1)
            mpGraph.ndata['pe'] = torch.cat((pe_mp,pe_mp[mp_node_indx]))
            mpGraph.edata['con'] = torch.zeros((mpGraph.num_edges(),1))
            
            mpGraph_rev = dgl.graph((new_src_rev,new_dst_rev))
            mpGraph_rev.ndata['pos'] = torch.cat((caXYZ,mp_list),axis=0).type(self.cast_type)
            mpGraph_rev.ndata['pe'] = torch.cat((pe_mp,pe_mp[mp_node_indx]))
            mpGraph_rev.edata['con'] = torch.zeros((mpGraph_rev.num_edges(),1))
            
            #make graph for self interaction of midpoints
            v1,v2,edge_data, ind = define_graph_edges(len(mp_list))
            mpSelfGraph = dgl.graph((v1,v2))
            mpSelfGraph.edata['con'] = edge_data.reshape((-1,1))
            mpSelfGraph.ndata['pe'] = pe[mp_node_indx] #not really needed
            mpSelfGraph.ndata['pos'] = mp_list.type(self.cast_type)
            
            
            mpSelfGraphList.append(mpSelfGraph) 
            mpGraphList.append(mpGraph)
            mpRevGraphList.append(mpGraph_rev)
            graphList.append(graph)
        
        return dgl.batch(graphList), dgl.batch(mpGraphList), dgl.batch(mpSelfGraphList), dgl.batch(mpRevGraphList)
    
    def prep_for_network(self, bb_dict, cuda=True):
    
        batched_graph, batched_mpgraph, batched_mpself_graph, batched_mpRevgraph =  self.create_and_batch(bb_dict)
        
        edge_feats        =    {'0':   batched_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        edge_feats_mp     = {'0': batched_mpgraph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]} #def all zero now
        edge_feats_mpself = {'0': batched_mpself_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
#         edge_feats_mp     = {'0': batched_mpRevgraph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos']   = _get_relative_pos(batched_graph)
        batched_mpgraph.edata['rel_pos'] = _get_relative_pos(batched_mpgraph)
        batched_mpself_graph.edata['rel_pos'] = _get_relative_pos(batched_mpself_graph)
        batched_mpRevgraph.edata['rel_pos'] = _get_relative_pos(batched_mpRevgraph)
        # get node features
        node_feats =         {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None],
                              '1': batched_graph.ndata['bb_ori'][:,:self.NODE_FEATURE_DIM_1, :3]}
        node_feats_mp =      {'0': batched_mpgraph.ndata['pe'][:, :self.ndf0, None],
                              '1': torch.ones((batched_mpgraph.num_nodes(),self.ndf1,3))}
        #unused
        node_feats_mpself =  {'0': batched_mpself_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None]}
        
        if cuda:
            bg,nf,ef = to_cuda(batched_graph), to_cuda(node_feats), to_cuda(edge_feats)
            bg_mp, nf_mp, ef_mp = to_cuda(batched_mpgraph), to_cuda(node_feats_mp), to_cuda(edge_feats_mp)
            bg_mps, nf_mps, ef_mps = to_cuda(batched_mpself_graph), to_cuda(node_feats_mpself), to_cuda(edge_feats_mpself)
            bg_mpRev = to_cuda(batched_mpRevgraph)
            
            return bg,nf,ef, bg_mp, nf_mp, ef_mp, bg_mps, nf_mps, ef_mps, bg_mpRev
        
        else:
            bg,nf,ef = batched_graph, node_feats, edge_feats
            bg_mp, nf_mp, ef_mp = batched_mpgraph, node_feats_mp, edge_feats_mp
            bg_mps, nf_mps, ef_mps = batched_mpself_graph, node_feats_mpself, edge_feats_mpself
            bg_mpRev = batched_mpRevgraph
            
            return bg,nf,ef, bg_mp, nf_mp, ef_mp, bg_mps, nf_mps, ef_mps, bg_mpRev
        
            

def get_edge_features(graph,edge_feature_dim=1):
    return {'0': graph.edata['con'][:, :edge_feature_dim, None]}

def define_poolGraph(n_nodes, batch_size, cast_type=torch.float32, cuda_out=True ):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    #pe = make_pe_encoding(n_nodes=n_nodes)#pe e
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)
    
    if cuda_out:
        return to_cuda(batched_graph)
    else:
        return batched_graph            
        

In [11]:
def pull_edge_features(graph, edge_feat_dim=1):
    return {'0': graph.edata['con'][:, :edge_feat_dim, None]}

def prep_for_gcn(graph, xyz_pos, edge_feats_input, idx, max_degree=3, comp_grad=True):
    
    src, dst = graph.edges()
    
    new_pos = F.gather_row(xyz_pos, idx)
    rel_pos = F.gather_row(new_pos,dst) - F.gather_row(new_pos,src) 
    
    basis_out = get_basis(rel_pos, max_degree=max_degree,
                                   compute_gradients=comp_grad,
                                   use_pad_trick=False)
    basis_out = update_basis_with_fused(basis_out, max_degree, use_pad_trick=False,
                                            fully_fused=False)
    edge_feats_out = get_populated_edge_features(rel_pos, edge_feats_input)
    return edge_feats_out, basis_out, new_pos    

In [12]:
class GraphUNet(torch.nn.Module):
    def __init__(self, 
                 fiber_start = Fiber({0:12, 1:2}),
                 fiber_out = Fiber({1:2}),
                 k=4,
                 batch_size = 8,
                 stride=4,
                 max_degree=3,
                 channels=32,
                 num_heads = 8,
                 channels_div=4,
                 num_layers = 1,
                 num_layers_ca = 1,
                 edge_feature_dim=1,
                 latent_pool_type = 'avg',
                 t_size = 1,
                 cuda=True):
        super(GraphUNet, self).__init__()
        
        
        self.comp_basis_grad = True
        self.cuda = cuda
        
        if cuda:
            self.device='cuda:0'
        else:
            self.device='cpu'
        
        self.max_degree=max_degree
        self.B = batch_size
        self.k = k
        self.ts = t_size
        
        self.num_layers = 1
        self.num_layers_ca = num_layers_ca
        self.channels = 32
        self.feat0 = 32
        self.feat1 = 6
        self.channels_div = 4
        self.num_heads = 8
        self.mult = int(stride/2)
        self.fiber_edge=Fiber({0:edge_feature_dim})
        self.edge_feat_dim = edge_feature_dim
        
        self.pool_type = latent_pool_type
        
        self.channels_down_ca = channels
        #down c_alpha interactions by radius
        self.fiber_start =  fiber_start
        self.fiber_hidden_down_ca = Fiber.create(self.max_degree, self.channels_down_ca)
        self.fiber_out_down_ca =Fiber({0: self.feat0, 1: self.feat1})
        
        #concat_t, plus one on input fiber, run concat_t method on forward
        self.down_ca = SE3Transformer(num_layers = self.num_layers_ca,
                        fiber_in=self.fiber_start+self.ts,
                        fiber_hidden= self.fiber_hidden_down_ca, 
                        fiber_out=self.fiber_out_down_ca,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.channels_down_ca2mp = self.channels_down_ca*self.mult
        
        #pool from c_alpha onto midpoints
        self.fiber_in_down_ca2mp     = self.fiber_out_down_ca
        self.fiber_hidden_down_ca2mp = Fiber.create(max_degree, self.channels_down_ca2mp)
        self.fiber_out_down_ca2mp    = Fiber({0: self.feat0*self.mult, 1: self.feat1*self.mult})
        
        #concat_t, plus one on input fiber, run concat_t method on forward
        self.down_ca2mp = SE3Transformer(num_layers = self.num_layers,
                            fiber_in     = self.fiber_in_down_ca2mp+self.ts,
                            fiber_hidden = self.fiber_hidden_down_ca2mp, 
                            fiber_out    = self.fiber_out_down_ca2mp,
                            num_heads =    self.num_heads,
                            channels_div = self.channels_div,
                            fiber_edge=self. fiber_edge,
                            low_memory=True,
                            tensor_cores=False)
        
        self.fiber_in_mptopk =  self.fiber_out_down_ca2mp
        self.fiber_hidden_down_mp  =self.fiber_hidden_down_ca2mp
        self.fiber_out_down_mp_out =self.fiber_out_down_ca2mp
        self.fiber_out_topkpool=Fiber({0: self.feat0*self.mult*self.mult})
        
        #concat_t, plus one on input fiber, run concat_t method on forward
        self.mp_topk = SE3Transformer_topK(num_layers      = self.num_layers,
                                        fiber_in      = self.fiber_in_mptopk+self.ts,
                                        fiber_hidden  = self.fiber_hidden_down_mp, 
                                        fiber_out     = self.fiber_out_down_mp_out ,
                                        fiber_out_topk= self.fiber_out_topkpool,
                                        k             = self.k,
                                        num_heads     = self.num_heads,
                                        channels_div  = self.channels_div,
                                        fiber_edge    =  self.fiber_edge,
                                        low_memory=True,
                                        tensor_cores=False)
        
        self.gsmall = define_poolGraph(self.k, self.B, cast_type=torch.float32, cuda_out=self.cuda)
        self.ef_small = pull_edge_features(self.gsmall, edge_feat_dim=self.edge_feat_dim)
        
        #change to doing convolutions instead of points
        self.fiber_in_down_gcn   =  self.fiber_out_topkpool
        self.fiber_out_down_gcn  = Fiber({0: self.feat0*self.mult*self.mult, 1: self.feat1*self.mult})

        self.down_gcn = ConvSE3(fiber_in  = self.fiber_in_down_gcn,
                           fiber_out = self.fiber_out_down_gcn,
                           fiber_edge= self.fiber_edge,
                             self_interaction=True,
                             use_layer_norm=True,
                             max_degree=self.max_degree,
                             fuse_level= ConvSE3FuseLevel.NONE,
                             low_memory= True)
        
        
        self.fiber_in_down_gcnp = self.fiber_out_down_gcn
        #probably rename latent
        self.latent_size = self.feat0*self.mult*self.mult
        self.fiber_latent = Fiber({0: self.latent_size})

        self.down_gcn2pool = ConvSE3(fiber_in=self.fiber_in_down_gcnp,
                                     fiber_out=self.fiber_latent,
                                     fiber_edge=self.fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=True,
                                     max_degree=self.max_degree,
                                     fuse_level= ConvSE3FuseLevel.NONE,
                                     low_memory= True)
        
        self.global_pool = GPooling(pool=self.pool_type, feat_type=0)

        self.latent_unpool_layer = Latent_Unpool(fiber_in = self.fiber_latent, fiber_add = self.fiber_out_down_gcn, 
                                            knodes = self.k)

        self.up_gcn = ConvSE3(fiber_in=self.fiber_out_down_gcn,
                             fiber_out=self.fiber_out_down_gcn,
                             fiber_edge=self.fiber_edge,
                             self_interaction=True,
                             use_layer_norm=True,
                             max_degree=self.max_degree,
                             fuse_level= ConvSE3FuseLevel.NONE,
                             low_memory= True)
        
        self.unpool_layer = Unpool_Layer(fiber_in=self.fiber_out_down_gcn, fiber_add=self.fiber_out_down_ca)
        
        self.fiber_in_up_gcn_mp = self.unpool_layer.fiber_out
        self.fiber_hidden_up_mp= self.fiber_hidden_down_ca2mp
        self.fiber_out_up_gcn_mp = self.fiber_out_down_mp_out

        self.up_gcn_mp = SE3Transformer(num_layers = num_layers,
                        fiber_in=self.fiber_in_up_gcn_mp,
                        fiber_hidden= self.fiber_hidden_up_mp, 
                        fiber_out=self.fiber_out_up_gcn_mp,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.unpool_layer_off_mp = Unpool_Layer(fiber_in=self.fiber_out_down_mp_out, fiber_add=self.fiber_out_down_mp_out)

        self.fiber_in_up_off_mp = self.fiber_out_up_gcn_mp
        self.fiber_hidden_up_off_mp = self.fiber_hidden_up_mp
        self.fiber_out_up_off_mp = self.fiber_out_down_ca 
        
        #uses reverse graph to move mp off 
        
        self.up_off_mp = SE3Transformer(num_layers = self.num_layers,
                        fiber_in=self.fiber_in_up_off_mp,
                        fiber_hidden= self.fiber_hidden_up_off_mp, 
                        fiber_out=self.fiber_out_up_off_mp,
                        num_heads = self.num_heads,
                        channels_div = self.channels_div,
                        fiber_edge=self.fiber_edge,
                        low_memory=True,
                        tensor_cores=False)
        
        self.pre_linear = Fiber({1:36})
        
        #concat_t, plus one on input fiber, run concat_t method on forward
        
        self.up_ca = SE3Transformer(num_layers = self.num_layers_ca,
                                    fiber_in=self.fiber_out_down_ca+self.ts,
                                    fiber_hidden= self.fiber_hidden_down_ca, 
                                    fiber_out=self.pre_linear,
                                    num_heads = self.num_heads,
                                    channels_div = self.channels_div,
                                    fiber_edge= self.fiber_edge,
                                    low_memory=True,
                                    tensor_cores=False)
        
        self.fiber_out = fiber_out
        
        self.linear = LinearSE3(fiber_in=self.pre_linear,
                                fiber_out=fiber_out)
        
        self.zero_linear()
        
    def zero_linear(self):
        nn.init.zeros_(self.linear.weights['1'])
        
    def concat_mp_feats(self, ca_feats_in, mp_feats):

        nf0_c = ca_feats_in['0'].shape[-2]
        nf1_c = ca_feats_in['1'].shape[-2]

        out0_cat_shape = (B,self.ca_nodes,-1,1)
        mp0_cat_shape  = (B,self.mp_nodes,-1,1)
        out1_cat_shape = (B,self.ca_nodes,-1,3)
        mp1_cat_shape  = (B,self.mp_nodes,-1,3)

        nf_c = {} #nf_cat
        nf_c['0'] = torch.cat((ca_feats_in['0'].reshape(out0_cat_shape), 
                                 mp_feats['0'].reshape(mp0_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
                              axis=1).reshape((-1,nf0_c,1))

        nf_c['1'] = torch.cat((ca_feats_in['1'].reshape(out1_cat_shape), 
                                 mp_feats['1'].reshape(mp1_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
                              axis=1).reshape((-1,nf1_c,3))

        return nf_c
        
    def pull_out_mp_feats(self, ca_mp_feats):

        nf0_c = ca_mp_feats['0'].shape[1]
        nf1_c = ca_mp_feats['1'].shape[1]

        nf_mp_ = {}
        #select just mp nodes to move on, the other nodes don't connect but mainting self connections
        nf_mp_['0'] = ca_mp_feats['0'].reshape(B,self.mp_nodes,
                                               nf0_c,1)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf0_c,1))
        nf_mp_['1'] = ca_mp_feats['1'].reshape(B,self.mp_nodes,
                                               nf1_c,3)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf1_c,3))

        return nf_mp_
    
    def concat_t(self, feats_in, t_vec):
        """Concatenate T to first position of each tensor. Pad Zeros left for degree 1."""
        feats_out = {}
        key = next(iter(feats_in.keys()))
        shape_tuple = (self.B,-1)+feats_in[key].shape[1:]
        batch_shape = feats_in[key].reshape(shape_tuple).shape
        L = batch_shape[1] #can be ca, ca+mp, mp, k nodes long

        if '0' in feats_in.keys():
            feats_out['0'] = torch.concat((t_vec[...,None,None,None].repeat(1,L,1,1), 
                                           feats_in['0'].reshape((self.B,L,-1,1))),
                                          axis=2).reshape((self.B*L,-1,1))
        if '1' in feats_in.keys():
            pshape = t_vec[...,None,None,None].repeat(1,L,1,1)
            p1d = (2,0)
            out = torch.nn.functional.pad(pshape, p1d, "constant", 0)
            feats_out['1'] = torch.concat((out, feats_in['1'].reshape((self.B,L,-1,3)))
                                          , dim=2).reshape((self.B*L,-1,3))

        return feats_out
        
    def forward(self, input_tuple, batched_t):

        b_graph, nf, ef, b_graph_mp, nf_mp, ef_mp, b_graph_mps, nf_mps, ef_mps, b_graph_mpRev = input_tuple
        #assumes equal node numbers in g raphs
        self.ca_nodes = int(b_graph.batch_num_nodes()[0])
        self.mp_nodes = int(b_graph_mp.batch_num_nodes()[0]) #ca+mp nodes number

        #SE3 Attention Transformer, c_alpha 
        t_nf = self.concat_t(nf, batched_t) #concat_t on
        nf_ca_down_out = self.down_ca(b_graph, t_nf, ef)

        #concatenate on midpoints feats
        
        nf_down_cat_mp = self.concat_mp_feats(nf_ca_down_out, nf_mp)

        #pool from ca onto selected midpoints via SE3 Attention transformer
        #edges from ca to mp only (ca nodes zero after this)
        t_nf_down_cat_mp = self.concat_t(nf_down_cat_mp, batched_t) #concat_t on
        nf_down_ca2mp_out = self.down_ca2mp(b_graph_mp, t_nf_down_cat_mp, ef_mp)

        #remove ca node feats from tensor 
        nf_mp_out = self.pull_out_mp_feats(nf_down_ca2mp_out)

        t_nf_mp_out = self.concat_t(nf_mp_out, batched_t) #concat_t on
        node_feats_tk, topk_feats, topk_indx = self.mp_topk(b_graph_mps, t_nf_mp_out, ef_mps)

        #make new basis for small graph of k selected midpoints
        edge_feats_out, basis_out, new_pos = prep_for_gcn(self.gsmall, b_graph_mps.ndata['pos'], self.ef_small,
                                                          topk_indx,
                                                          max_degree=self.max_degree, comp_grad=True)

        down_gcn_out = self.down_gcn(topk_feats, edge_feats_out, self.gsmall,  basis_out)

        down_gcnpool_out = self.down_gcn2pool(down_gcn_out, edge_feats_out, self.gsmall,  basis_out)

        pooled_tensor = self.global_pool(down_gcnpool_out,self.gsmall)
        pooled = {'0':pooled_tensor}
        #----------------------------------------- end of down section
        lat_unp = self.latent_unpool_layer(pooled,down_gcn_out)

        up_gcn_out = self.up_gcn(lat_unp, edge_feats_out, self.gsmall,  basis_out)

        k_to_mp  = self.unpool_layer(up_gcn_out,node_feats_tk,topk_indx)

        up_mp_gcn_out = self.up_gcn_mp(b_graph_mps, k_to_mp, ef_mps)
        
        off_mp_add = {}
        for k,v in up_mp_gcn_out.items():
            off_mp_add[k] = torch.add(up_mp_gcn_out[k],nf_mp_out[k])


        #####triple check from here
        #midpoints node indices for unpool layer
        mp_node_indx = torch.arange(self.ca_nodes,self.mp_nodes, device=self.device)
        mp_idx = mp_node_indx[None,...].repeat_interleave(self.B,0)
        mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))+mp_idx).reshape((-1))
        
        #during unpool, keep mp=values and ca=zeros
        zeros_mp_ca = {}
        for k,v in nf_down_cat_mp.items():
            zeros_mp_ca[k] = torch.zeros_like(v, device=self.device)


        unpoff_out = self.unpool_layer_off_mp(off_mp_add, zeros_mp_ca, mp_idx)
        
        out_up_off_mp = self.up_off_mp(b_graph_mpRev, unpoff_out, ef_mp)
        
        #select just ca nodes, mp = zeros from last convolution
        inv_mp_idx= torch.arange(0,self.ca_nodes, device=self.device)
        inv_mp_idx = inv_mp_idx[None,...].repeat_interleave(self.B,0)
        inv_mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))
                     +inv_mp_idx).reshape((-1))

        node_final_ca = {}
        for key in out_up_off_mp.keys():
            node_final_ca[key] = torch.add(out_up_off_mp[key][inv_mp_idx,...],nf_ca_down_out[key])

        #return updates 
        t_node_final_ca = self.concat_t(node_final_ca, batched_t) #concat_t on
        
        return self.linear(self.up_ca(b_graph, t_node_final_ca, ef))
                
        
        

In [13]:
def model_step(backbone_dict, noised_dict, batched_t, scores_scales, graph_maker, graph_unet, train=True):
    
    CA_t  = backbone_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + backbone_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_t = CA_t + backbone_dict['C_CA'].reshape(B, L, 3).to('cuda')
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    CA_n  = noised_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_n = CA_n + noised_dict['N_CA'].reshape(B, L, 3).to('cuda')
    CC_n = CA_n + noised_dict['C_CA'].reshape(B, L, 3).to('cuda')
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)
    
    x = graph_maker.prep_for_network(noised_dict)
    out = graph_unet(x, batched_t)
    CA_p = out['1'][:,0,:].reshape(B, L, 3)+CA_n #translation of Calpha
    Qs = out['1'][:,1,:] # rotation
    Qs = Qs.unsqueeze(1).repeat((1,2,1))
    Qs = torch.cat((torch.ones((B*L,2,1),device=Qs.device),Qs),dim=-1).reshape(B,L,2,4)
    Qs = normQ(Qs)
    Rs = Qs2Rs(Qs)
    N_C_to_Rot = torch.cat((noised_dict['N_CA'].reshape(B, L, 3).to('cuda'),
                            noised_dict['C_CA'].reshape(B, L, 3).to('cuda')),dim=2).reshape(B,L,2,1,3)

    
    
    rot_vecs = einsum('bnkij,bnkhj->bnki',Rs, N_C_to_Rot)
    NC_p = CA_p + rot_vecs[:,:,0,:].to('cuda')*N_CA_dist
    CC_p = CA_p + rot_vecs[:,:,1,:].reshape(B, L, 3).to('cuda')*C_CA_dist

    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)
    
    tloss, loss = FAPE_loss(pred.unsqueeze(0), true, scores_scales)
    
    return tloss

In [14]:
def get_noise_pred_true(backbone_dict, noised_dict, batched_t, graph_maker, graph_unet):
    
    CA_t  = bb_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + bb_dict['N_CA'].reshape(B, L, 3).to('cuda')*N_CA_dist
    CC_t = CA_t + bb_dict['C_CA'].reshape(B, L, 3).to('cuda')*C_CA_dist
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    CA_n  = noised_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_n = CA_n + noised_dict['N_CA'].reshape(B, L, 3).to('cuda')*N_CA_dist
    CC_n = CA_n + noised_dict['C_CA'].reshape(B, L, 3).to('cuda')*C_CA_dist
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)
    
    x = graph_maker.prep_for_network(noised_dict)
    out = graph_unet(x, batched_t)
    CA_p = out['1'][:,0,:].reshape(B, L, 3)+CA_n #translation of Calpha
    Qs = out['1'][:,1,:] # rotation
    Qs = Qs.unsqueeze(1).repeat((1,2,1))
    Qs = torch.cat((torch.ones((B*L,2,1),device=Qs.device),Qs),dim=-1).reshape(B,L,2,4)
    Qs = normQ(Qs)
    Rs = Qs2Rs(Qs)
    N_C_to_Rot = torch.cat((noised_dict['N_CA'].reshape(B, L, 3).to('cuda'),
                            noised_dict['C_CA'].reshape(B, L, 3).to('cuda')),dim=2).reshape(B,L,2,1,3)
    
    
    rot_vecs = einsum('bnkij,bnkhj->bnki',Rs, N_C_to_Rot)
    NC_p = CA_p + rot_vecs[:,:,0,:].to('cuda')*N_CA_dist
    CC_p = CA_p + rot_vecs[:,:,1,:].reshape(B, L, 3).to('cuda')*C_CA_dist

    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)
    
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10, pred.detach().to('cpu').numpy()*10

def dump_tnp(true, noise, pred, t_val, e=0, numOut=1,outdir='output/'):
    
    if numOut>true.shape[0]:
        numOut = true.shape[0]
    
    for x in range(numOut):
        dump_coord_pdb(true[x], fileOut=f'{outdir}/true_{t_val[x]*100:.0f}_e{e}_{x}.pdb')
        dump_coord_pdb(noise[x], fileOut=f'{outdir}/noise_{t_val[x]*100:.0f}_e{e}_{x}.pdb')
        dump_coord_pdb(pred[x], fileOut=f'{outdir}/pred_{t_val[x]*100:.0f}_e{e}_{x}.pdb')
        
def visualize_model(bb_dict, noised_bb, batched_t, epoch, numOut=1, outdir='output/'):
    true, noise, pred = get_noise_pred_true(bb_dict, noised_bb, batched_t, gm, gu)
    dump_tnp(true,noise,pred, batched_t, e=epoch, numOut=numOut, outdir=f'{outdir}/models/')

In [15]:
def make_save_folder(name=''):
    base_folder = time.strftime(f'log/%y%b%d_%I%M%p_{name}/', time.localtime())
    if not os.path.exists(base_folder):
        os.makedirs(base_folder)
    subfolders = ['models']
    for subfolder in subfolders:
        if not os.path.exists(base_folder + subfolder):
            os.makedirs(base_folder + subfolder)
            
    return base_folder
        
def save_chkpt(model_path, model, optimizer, epoch, batch, val_losses, train_losses):
    """Save a training checkpoint
    Args:
        model_path (str): the path to save the model to
        model (nn.Module): the model to save
        optimizer (torch.optim.Optimizer): the optimizer to save
        epoch (int): the current epoch
        batch (int): the current batch in the epoch
        loss_domain (list of int): a list of the shared domain for val and training 
            losses
        val_losses (list of float): a list containing the validation losses
        train_losses (list of float): a list containing the training losses
    """
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    state_dict = dict()
    state_dict.update({'model':model.state_dict(),
                       'optimizer':optimizer.state_dict(),
                       'epoch':epoch,
                       'batch':batch,
                       'train_losses':train_losses,
                       'val_losses':val_losses
                       })
    torch.save(state_dict, f'{model_path}model_e{epoch}')


In [17]:
B = 8
L=65
limit = 5048
h4_trainData = Helix4_Dataset(coords_tog[:limit])
h4_valData = Helix4_Dataset(coords_apa[:limit])
train_dL = DataLoader(h4_trainData, batch_size=B, shuffle=True, drop_last=True)
val_dL   = DataLoader(h4_valData, batch_size=B, shuffle=True, drop_last=True)


In [18]:
gu = GraphUNet(batch_size = B, num_layers_ca = 2).to('cuda')
opti = torch.optim.Adam(gu.parameters(), lr=0.001, weight_decay=5e-5)
gm = Make_KNN_MP_Graphs() #consider precalculating graphs for training
fdn= FrameDiffNoise()
#visualize_T
vis_t = np.array([0.01,0.05,0.1,0.2,0.3,0.5,0.8,1.0])
vis_t = vis_t[None,...].repeat(int(np.ceil(B/len(vis_t))),axis=0).flatten()[:B]

In [19]:
#t=0.1
#t_vec = np.ones((B,))*t
#print(t_vec)
model_path = make_save_folder(name=f'full_diff')
num_epochs = 300
save_per=10
avg_vloss=0

for e in range(num_epochs):
    
    running_tloss = 0 
    start = time.time()
    for i, bb_dict in enumerate(train_dL):
        noised_bb, tv, ss = fdn(bb_dict,t_vec=None)
        tv = tv.to('cuda')
        ss = ss.to('cuda')
        train_loss = model_step(bb_dict, noised_bb, tv, ss, gm, gu)
        opti.zero_grad()
        train_loss.backward()
        opti.step()

        running_tloss += train_loss.detach().cpu()
    
    end = time.time()
    avg_tloss = running_tloss/(i+1)
    print(f'Average Train Loss Epoch {e}: {avg_tloss};   Epoch time: {end-start:.0f}')

    if e %save_per==save_per-1:
        with torch.no_grad():
            running_vloss = 0
            for i, bb_dictv in enumerate(val_dL):
                noised_bb, tv, ss = fdn(bb_dictv,t_vec=None)
                tv = tv.to('cuda')
                ss = ss.to('cuda')
                valid_loss = model_step(bb_dictv, noised_bb, tv, ss, gm, gu)
                running_vloss += valid_loss
                
        avg_vloss = running_vloss/(i+1)
        print(f'Average Valid Loss Epoch {e}: {avg_vloss}')
                
        noised_bb, tv, ss = fdn(bb_dict, t_vec=vis_t)
        tv = tv.to('cuda')
        visualize_model(bb_dict, noised_bb, tv, e, numOut=8,outdir=model_path)
        save_chkpt(model_path, gu, opti, e, B, avg_vloss, avg_tloss)


  assert input.numel() == input.storage().size(), (


Average Train Loss Epoch 0: 0.26205548644065857;   Epoch time: 287
Average Train Loss Epoch 1: 0.1957971602678299;   Epoch time: 282
Average Train Loss Epoch 2: 0.17563512921333313;   Epoch time: 283
Average Train Loss Epoch 3: 0.18008320033550262;   Epoch time: 283
Average Train Loss Epoch 4: 0.17358267307281494;   Epoch time: 287
Average Train Loss Epoch 5: 0.174149289727211;   Epoch time: 282
Average Train Loss Epoch 6: 0.1762097030878067;   Epoch time: 282
Average Train Loss Epoch 7: 0.17160162329673767;   Epoch time: 280
Average Train Loss Epoch 8: 0.19456170499324799;   Epoch time: 280
Average Train Loss Epoch 9: 0.18396130204200745;   Epoch time: 280
Average Valid Loss Epoch 9: 0.17739665508270264
Average Train Loss Epoch 10: 0.19999010860919952;   Epoch time: 278
Average Train Loss Epoch 11: 0.1806856095790863;   Epoch time: 283
Average Train Loss Epoch 12: 0.18081332743167877;   Epoch time: 284
Average Train Loss Epoch 13: 0.17731228470802307;   Epoch time: 281
Average Train L

Average Train Loss Epoch 113: 0.15757107734680176;   Epoch time: 280
Average Train Loss Epoch 114: 0.1559603363275528;   Epoch time: 280
Average Train Loss Epoch 115: 0.15274885296821594;   Epoch time: 280
Average Train Loss Epoch 116: 0.15463188290596008;   Epoch time: 279
Average Train Loss Epoch 117: 0.15500536561012268;   Epoch time: 281
Average Train Loss Epoch 118: 0.15408453345298767;   Epoch time: 281
Average Train Loss Epoch 119: 0.15469998121261597;   Epoch time: 280
Average Valid Loss Epoch 119: 0.15367326140403748
Average Train Loss Epoch 120: 0.15501688420772552;   Epoch time: 280
Average Train Loss Epoch 121: 0.1543092131614685;   Epoch time: 280
Average Train Loss Epoch 122: 0.15332169830799103;   Epoch time: 280
Average Train Loss Epoch 123: 0.15771473944187164;   Epoch time: 280
Average Train Loss Epoch 124: 0.162766695022583;   Epoch time: 280
Average Train Loss Epoch 125: 0.15332628786563873;   Epoch time: 280
Average Train Loss Epoch 126: 0.15671999752521515;   Epoc

Average Train Loss Epoch 225: 0.15389785170555115;   Epoch time: 281
Average Train Loss Epoch 226: 0.15311820805072784;   Epoch time: 281
Average Train Loss Epoch 227: 0.16113580763339996;   Epoch time: 281
Average Train Loss Epoch 228: 0.15181714296340942;   Epoch time: 281
Average Train Loss Epoch 229: 0.15009094774723053;   Epoch time: 281
Average Valid Loss Epoch 229: 0.15398027002811432
Average Train Loss Epoch 230: 0.1537669450044632;   Epoch time: 282
Average Train Loss Epoch 231: 0.15142640471458435;   Epoch time: 284
Average Train Loss Epoch 232: 0.15242484211921692;   Epoch time: 290


KeyboardInterrupt: 

In [54]:
B=10028

In [55]:
test_batch['CA'].shape

torch.Size([8, 65, 3])

In [56]:
test_iter = iter(train_dL)
test_batch = next(test_iter)

t=0.05
t_vec = np.ones((B,))*t
nd, tv, ss = fdn(test_batch, t_vec=None)

# tv = tv.to('cuda')
# train_loss = model_step(bb_dict, noised_bb, tv, gm, gu)

In [57]:
ss.sum()

tensor(14.7698)

In [None]:
def noise_test(backbone_dict, noised_dict):
    CA_t  = bb_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_t = CA_t + bb_dict['N_CA'].reshape(B, L, 3).to('cuda')*N_CA_dist
    CC_t = CA_t + bb_dict['C_CA'].reshape(B, L, 3).to('cuda')*C_CA_dist
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)
    
    CA_n  = noised_dict['CA'].reshape(B, L, 3).to('cuda')
    NC_n = CA_n + noised_dict['N_CA'].reshape(B, L, 3).to('cuda')*N_CA_dist
    CC_n = CA_n + noised_dict['C_CA'].reshape(B, L, 3).to('cuda')*C_CA_dist
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10

In [None]:
from util.npose_util import makePointPDB
#gds = Graph_RadiusMP_4H_Dataset(coords_tog[:100], 10, mp_stride = 3)
def view_mp_graph(mps: DGLGraph, coords: np.array ):
    p = mps.ndata['pos']*10
    
    to = np.concatenate((coords, np.ones_like(coords)[:,:,0][...,None]),axis=2)
    
    makePointPDB(p,'test.pdb',outDirec='output')
    nu.dump_npdb(to,'output/test2.pdb')
#view_mp_graph(gds.mpSelfGraphList[0], coords_tog[0])

In [None]:
# class GraphUNet(torch.nn.Module):
#     def __init__(self, 
#                  fiber_start = Fiber({0:12, 1:2}),
#                  fiber_out = Fiber({1:2}),
#                  k=4,
#                  batch_size = 8,
#                  stride=4,
#                  max_degree=3,
#                  channels=32,
#                  num_heads = 8,
#                  channels_div=4,
#                  num_layers = 1,
#                  num_layers_ca = 1,
#                  edge_feature_dim=1,
#                  latent_pool_type = 'avg',
#                  cuda=True):
#         super(GraphUNet, self).__init__()
        
        
#         self.comp_basis_grad = True
#         self.cuda = cuda
        
#         if cuda:
#             self.device='cuda:0'
#         else:
#             self.device='cpu'
        
#         self.max_degree=max_degree
#         self.B = batch_size
#         self.k = k
        
#         self.num_layers = 1
#         self.num_layers_ca = num_layers_ca
#         self.channels = 32
#         self.feat0 = 32
#         self.feat1 = 6
#         self.channels_div = 4
#         self.num_heads = 8
#         self.mult = int(stride/2)
#         self.fiber_edge=Fiber({0:edge_feature_dim})
#         self.edge_feat_dim = edge_feature_dim
        
#         self.pool_type = latent_pool_type
        
#         self.channels_down_ca = channels
#         #down c_alpha interactions by radius
#         self.fiber_start =  fiber_start
#         self.fiber_hidden_down_ca = Fiber.create(self.max_degree, self.channels_down_ca)
#         self.fiber_out_down_ca =Fiber({0: self.feat0, 1: self.feat1})
        
#         self.down_ca = SE3Transformer(num_layers = self.num_layers,
#                         fiber_in=self.fiber_start,
#                         fiber_hidden= self.fiber_hidden_down_ca, 
#                         fiber_out=self.fiber_out_down_ca,
#                         num_heads = self.num_heads,
#                         channels_div = self.channels_div,
#                         fiber_edge=self.fiber_edge,
#                         low_memory=True,
#                         tensor_cores=False)
        
#         self.channels_down_ca2mp = self.channels_down_ca*self.mult
        
#         #pool from c_alpha onto midpoints
#         self.fiber_in_down_ca2mp     = self.fiber_out_down_ca
#         self.fiber_hidden_down_ca2mp = Fiber.create(max_degree, self.channels_down_ca2mp)
#         self.fiber_out_down_ca2mp    = Fiber({0: self.feat0*self.mult, 1: self.feat1*self.mult})

#         self.down_ca2mp = SE3Transformer(num_layers = self.num_layers_ca,
#                             fiber_in     = self.fiber_in_down_ca2mp,
#                             fiber_hidden = self.fiber_hidden_down_ca2mp, 
#                             fiber_out    = self.fiber_out_down_ca2mp,
#                             num_heads =    self.num_heads,
#                             channels_div = self.channels_div,
#                             fiber_edge=self. fiber_edge,
#                             low_memory=True,
#                             tensor_cores=False)
        
#         self.fiber_in_mptopk =  self.fiber_out_down_ca2mp
#         self.fiber_hidden_down_mp  =self.fiber_hidden_down_ca2mp
#         self.fiber_out_down_mp_out =self.fiber_out_down_ca2mp
#         self.fiber_out_topkpool=Fiber({0: self.feat0*self.mult*self.mult})

#         self.mp_topk = SE3Transformer_topK(num_layers      = self.num_layers,
#                                         fiber_in      = self.fiber_in_mptopk,
#                                         fiber_hidden  = self.fiber_hidden_down_mp, 
#                                         fiber_out     = self.fiber_out_down_mp_out ,
#                                         fiber_out_topk= self.fiber_out_topkpool,
#                                         k             = self.k,
#                                         num_heads     = self.num_heads,
#                                         channels_div  = self.channels_div,
#                                         fiber_edge    =  self.fiber_edge,
#                                         low_memory=True,
#                                         tensor_cores=False)
        
#         self.gsmall = define_poolGraph(self.k, self.B, cast_type=torch.float32, cuda_out=self.cuda)
#         self.ef_small = pull_edge_features(self.gsmall, edge_feat_dim=self.edge_feat_dim)
        
#         #change to doing convolutions instead of points
#         self.fiber_in_down_gcn   =  self.fiber_out_topkpool
#         self.fiber_out_down_gcn  = Fiber({0: self.feat0*self.mult*self.mult, 1: self.feat1*self.mult})

#         self.down_gcn = ConvSE3(fiber_in  = self.fiber_in_down_gcn,
#                            fiber_out = self.fiber_out_down_gcn,
#                            fiber_edge= self.fiber_edge,
#                              self_interaction=True,
#                              use_layer_norm=True,
#                              max_degree=self.max_degree,
#                              fuse_level= ConvSE3FuseLevel.NONE,
#                              low_memory= True)
        
        
#         self.fiber_in_down_gcnp = self.fiber_out_down_gcn
#         #probably rename latent
#         self.latent_size = self.feat0*self.mult*self.mult
#         self.fiber_latent = Fiber({0: self.latent_size})

#         self.down_gcn2pool = ConvSE3(fiber_in=self.fiber_in_down_gcnp,
#                                      fiber_out=self.fiber_latent,
#                                      fiber_edge=self.fiber_edge,
#                                      self_interaction=True,
#                                      use_layer_norm=True,
#                                      max_degree=self.max_degree,
#                                      fuse_level= ConvSE3FuseLevel.NONE,
#                                      low_memory= True)
        
#         self.global_pool = GPooling(pool=self.pool_type, feat_type=0)

#         self.latent_unpool_layer = Latent_Unpool(fiber_in = self.fiber_latent, fiber_add = self.fiber_out_down_gcn, 
#                                             knodes = self.k)

#         self.up_gcn = ConvSE3(fiber_in=self.fiber_out_down_gcn,
#                              fiber_out=self.fiber_out_down_gcn,
#                              fiber_edge=self.fiber_edge,
#                              self_interaction=True,
#                              use_layer_norm=True,
#                              max_degree=self.max_degree,
#                              fuse_level= ConvSE3FuseLevel.NONE,
#                              low_memory= True)
        
#         self.unpool_layer = Unpool_Layer(fiber_in=self.fiber_out_down_gcn, fiber_add=self.fiber_out_down_ca)
        
#         self.fiber_in_up_gcn_mp = self.unpool_layer.fiber_out
#         self.fiber_hidden_up_mp= self.fiber_hidden_down_ca2mp
#         self.fiber_out_up_gcn_mp = self.fiber_out_down_mp_out

#         self.up_gcn_mp = SE3Transformer(num_layers = num_layers,
#                         fiber_in=self.fiber_in_up_gcn_mp,
#                         fiber_hidden= self.fiber_hidden_up_mp, 
#                         fiber_out=self.fiber_out_up_gcn_mp,
#                         num_heads = self.num_heads,
#                         channels_div = self.channels_div,
#                         fiber_edge=self.fiber_edge,
#                         low_memory=True,
#                         tensor_cores=False)
        
#         self.unpool_layer_off_mp = Unpool_Layer(fiber_in=self.fiber_out_down_mp_out, fiber_add=self.fiber_out_down_mp_out)

#         self.fiber_in_up_off_mp = self.fiber_out_up_gcn_mp
#         self.fiber_hidden_up_off_mp = self.fiber_hidden_up_mp
#         self.fiber_out_up_off_mp = self.fiber_out_down_ca 
        
#         #uses reverse graph to move mp off 
        
#         self.up_off_mp = SE3Transformer(num_layers = self.num_layers,
#                         fiber_in=self.fiber_in_up_off_mp,
#                         fiber_hidden= self.fiber_hidden_up_off_mp, 
#                         fiber_out=self.fiber_out_up_off_mp,
#                         num_heads = self.num_heads,
#                         channels_div = self.channels_div,
#                         fiber_edge=self.fiber_edge,
#                         low_memory=True,
#                         tensor_cores=False)
        
#         self.fiber_out = fiber_out
        
#         self.up_ca = SE3Transformer(num_layers = self.num_layers_ca,
#                                     fiber_in=self.fiber_out_down_ca,
#                                     fiber_hidden= self.fiber_hidden_down_ca, 
#                                     fiber_out=self.fiber_out,
#                                     num_heads = self.num_heads,
#                                     channels_div = self.channels_div,
#                                     fiber_edge= self.fiber_edge,
#                                     low_memory=True,
#                                     tensor_cores=False)
        
#         self.linear = LinearSE3(fiber_in=fiber_out,
#                                 fiber_out=fiber_out)
        
#         self.zero_linear()
        
#     def zero_linear(self):
#         nn.init.zeros_(self.linear.weights['1'])
        
#     def concat_mp_feats(self, ca_feats_in, mp_feats):

#         nf0_c = ca_feats_in['0'].shape[-2]
#         nf1_c = ca_feats_in['1'].shape[-2]

#         out0_cat_shape = (B,self.ca_nodes,-1,1)
#         mp0_cat_shape  = (B,self.mp_nodes,-1,1)
#         out1_cat_shape = (B,self.ca_nodes,-1,3)
#         mp1_cat_shape  = (B,self.mp_nodes,-1,3)

#         nf_c = {} #nf_cat
#         nf_c['0'] = torch.cat((ca_feats_in['0'].reshape(out0_cat_shape), 
#                                  mp_feats['0'].reshape(mp0_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
#                               axis=1).reshape((-1,nf0_c,1))

#         nf_c['1'] = torch.cat((ca_feats_in['1'].reshape(out1_cat_shape), 
#                                  mp_feats['1'].reshape(mp1_cat_shape)[:,-(self.mp_nodes-self.ca_nodes):,:,:]),
#                               axis=1).reshape((-1,nf1_c,3))

#         return nf_c
        
#     def pull_out_mp_feats(self, ca_mp_feats):

#         nf0_c = ca_mp_feats['0'].shape[1]
#         nf1_c = ca_mp_feats['1'].shape[1]

#         nf_mp_ = {}
#         #select just mp nodes to move on, the other nodes don't connect but mainting self connections
#         nf_mp_['0'] = ca_mp_feats['0'].reshape(B,self.mp_nodes,
#                                                nf0_c,1)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf0_c,1))
#         nf_mp_['1'] = ca_mp_feats['1'].reshape(B,self.mp_nodes,
#                                                nf1_c,3)[:,-(self.mp_nodes-self.ca_nodes):,...].reshape((-1,nf1_c,3))

#         return nf_mp_
        
        
#     def forward(self, input_tuple):

#         b_graph, nf, ef, b_graph_mp, nf_mp, ef_mp, b_graph_mps, nf_mps, ef_mps, b_graph_mpRev = input_tuple
#         #assumes equal node numbers in g raphs
#         self.ca_nodes = int(b_graph.batch_num_nodes()[0])
#         self.mp_nodes = int(b_graph_mp.batch_num_nodes()[0]) #ca+mp nodes number

#         #SE3 Attention Transformer, c_alpha 
#         nf_ca_down_out = self.down_ca(b_graph, nf, ef)

#         #concatenate on midpoints feats
#         nf_down_cat_mp = self.concat_mp_feats(nf_ca_down_out, nf_mp)

#         #pool from ca onto selected midpoints via SE3 Attention transformer
#         #edges from ca to mp only (ca nodes zero after this)
#         nf_down_ca2mp_out = self.down_ca2mp(b_graph_mp, nf_down_cat_mp, ef_mp)

#         #remove ca node feats from tensor 
#         nf_mp_out = self.pull_out_mp_feats(nf_down_ca2mp_out)

#         node_feats_tk, topk_feats, topk_indx = self.mp_topk(b_graph_mps, nf_mp_out, ef_mps)

#         #make new basis for small graph of k selected midpoints
#         edge_feats_out, basis_out, new_pos = prep_for_gcn(self.gsmall, b_graph_mps.ndata['pos'], self.ef_small,
#                                                           topk_indx,
#                                                           max_degree=self.max_degree, comp_grad=True)

#         down_gcn_out = self.down_gcn(topk_feats, edge_feats_out, self.gsmall,  basis_out)

#         down_gcnpool_out = self.down_gcn2pool(down_gcn_out, edge_feats_out, self.gsmall,  basis_out)

#         pooled_tensor = self.global_pool(down_gcnpool_out,self.gsmall)
#         pooled = {'0':pooled_tensor}
#         #----------------------------------------- end of down section
#         lat_unp = self.latent_unpool_layer(pooled,down_gcn_out)

#         up_gcn_out = self.up_gcn(lat_unp, edge_feats_out, self.gsmall,  basis_out)

#         k_to_mp  = self.unpool_layer(up_gcn_out,node_feats_tk,topk_indx)

#         up_mp_gcn_out = self.up_gcn_mp(b_graph_mps, k_to_mp, ef_mps)
        
#         off_mp_add = {}
#         for k,v in up_mp_gcn_out.items():
#             off_mp_add[k] = torch.add(up_mp_gcn_out[k],nf_mp_out[k])


#         #####triple check from here
#         #midpoints node indices for unpool layer
#         mp_node_indx = torch.arange(self.ca_nodes,self.mp_nodes, device=self.device)
#         mp_idx = mp_node_indx[None,...].repeat_interleave(self.B,0)
#         mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))+mp_idx).reshape((-1))
        
#         #during unpool, keep mp=values and ca=zeros
#         zeros_mp_ca = {}
#         for k,v in nf_down_cat_mp.items():
#             zeros_mp_ca[k] = torch.zeros_like(v, device=self.device)


#         unpoff_out = self.unpool_layer_off_mp(off_mp_add, zeros_mp_ca, mp_idx)
        
#         out_up_off_mp = self.up_off_mp(b_graph_mpRev, unpoff_out, ef_mp)
        
#         #select just ca nodes, mp = zeros from last convolution
#         inv_mp_idx= torch.arange(0,self.ca_nodes, device=self.device)
#         inv_mp_idx = inv_mp_idx[None,...].repeat_interleave(self.B,0)
#         inv_mp_idx =((torch.arange(self.B,device=self.device)*(self.mp_nodes)).reshape((-1,1))
#                      +inv_mp_idx).reshape((-1))

#         node_final_ca = {}
#         for key in out_up_off_mp.keys():
#             node_final_ca[key] = torch.add(out_up_off_mp[key][inv_mp_idx,...],nf_ca_down_out[key])

#         #return updates 
#         return self.linear(self.up_ca(b_graph, node_final_ca, ef))
                
        
        