In [1]:
import torch
import numpy as np
import util.npose_util as nu
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
torch.cuda.is_available()

True

In [2]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer_reset import Sequential, SE3Transformer
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [3]:
import os
# Useful numbers
# N [-1.45837285,  0 , 0]
# CA [0., 0., 0.]
# C [0.55221403, 1.41890368, 0.        ]
# CB [ 0.52892494, -0.77445692, -1.19923854]

if ( hasattr(os, 'ATOM_NAMES') ):
    assert( hasattr(os, 'PDB_ORDER') )

    ATOM_NAMES = os.ATOM_NAMES
    PDB_ORDER = os.PDB_ORDER
else:
    ATOM_NAMES=['N', 'CA', 'CB', 'C', 'O']
    PDB_ORDER = ['N', 'CA', 'C', 'O', 'CB']

_byte_atom_names = []
_atom_names = []
for i, atom_name in enumerate(ATOM_NAMES):
    long_name = " " + atom_name + "       "
    _atom_names.append(long_name[:4])
    _byte_atom_names.append(atom_name.encode())

    globals()[atom_name] = i

R = len(ATOM_NAMES)

if ( "N" not in globals() ):
    N = -1
if ( "C" not in globals() ):
    C = -1
if ( "CB" not in globals() ):
    CB = -1


_pdb_order = []
for name in PDB_ORDER:
    _pdb_order.append( ATOM_NAMES.index(name) )

In [4]:
# data_path_str  = 'data/h4_ca_coords.npz'
# test_limit = 1028
# rr = np.load(data_path_str)
# ca_coords = [rr[f] for f in rr.files][0][:test_limit,:,:3]
# ca_coords.shape

# getting N-Ca, Ca-C vectors to add as typeI features
#apa = apart helices for test/train split
#tog = together helices for test/train split
apa_path_str  = 'data/h4_apa_coords.npz'
tog_path_str  = 'data/h4_tog_coords.npz'

#grab the first 3 atoms which are N,CA,C
test_limit = 1028
rr = np.load(apa_path_str)
coords_apa = [rr[f] for f in rr.files][0][:test_limit,:]

rr = np.load(tog_path_str)
coords_tog = [rr[f] for f in rr.files][0][:test_limit,:]

In [5]:
def build_npose_from_coords(coords_in):
    """Use N, CA, C coordinates to generate O an CB atoms"""
    rot_mat_cat = np.ones(sum((coords_in.shape[:-1], (1,)), ()))
    
    coords = np.concatenate((coords_in,rot_mat_cat),axis=-1)
    
    npose = np.ones((coords_in.shape[0]*5,4)) #5 is atoms per res

    by_res = npose.reshape(-1, 5, 4)
    
    if ( "N" in ATOM_NAMES ):
        by_res[:,N,:3] = coords_in[:,0,:3]
    if ( "CA" in ATOM_NAMES ):
        by_res[:,CA,:3] = coords_in[:,1,:3]
    if ( "C" in ATOM_NAMES ):
        by_res[:,C,:3] = coords_in[:,2,:3]
    if ( "O" in ATOM_NAMES ):
        by_res[:,O,:3] = nu.build_O(npose)
    if ( "CB" in ATOM_NAMES ):
        tpose = nu.tpose_from_npose(npose)
        by_res[:,CB,:] = nu.build_CB(tpose)

    return npose

def dump_coord_pdb(coords_in, fileOut='fileOut.pdb'):
    
    npose =  build_npose_from_coords(coords_in)
    nu.dump_npdb(npose,fileOut)

In [6]:
#goal define edges of
#connected backbone 1, 
#unconnected atoms 0,


def get_midpoint(ep_in):
    """Get midpoint, of each batched set of points"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ep_in.shape[1], ep_in.shape[2])
    
    return midpoint


def normalize_points(input_xyz, print_dist=False):
    
    #broadcast to distance matrix [Batch, M, R3] to [Batch,M,1, R3] to [Batch,1,M, R3] to [Batch, M,M, R3] 
    vec_diff = input_xyz[...,None,:]-input_xyz[...,None,:,:]
    dist = np.sqrt(np.sum(np.square(vec_diff),axis=len(input_xyz.shape)))
    furthest_dist = np.max(dist)
    centroid  = get_midpoint(input_xyz)
    if print_dist:
        print(f'largest distance {furthest_dist:0.1f}')
    
    xyz_mean_zero = input_xyz - centroid[:,None,:]
    return xyz_mean_zero/furthest_dist

def normalize(v):
    norm = np.linalg.norm(v)
    if norm == 0: 
        return v
    return v / norm

def define_graph_edges(n_nodes):
    #connected backbone

    con_v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    con_v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ind = con_v1*(n_nodes-1)+con_v2-1 #account for removed self connections (-1)


    #unconnected backbone

    nodes = np.arange(n_nodes)
    v1 = np.repeat(nodes,n_nodes-1) #starting vertices, same number repeated for each edge

    start_v2 = np.repeat(np.arange(n_nodes)[None,:],n_nodes,axis=0)
    diag_ind = np.diag_indices(n_nodes)
    start_v2[diag_ind] = -1 #diagonal of matrix is self connections which we remove (self connections are managed by SE3 Conv channels)
    v2 = start_v2[start_v2>-0.5] #remove diagonal and flatten

    edge_data = torch.zeros(len(v2))
    edge_data[ind] = 1
    
    return v1,v2,edge_data, ind

def make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 1000, cast_type=torch.float32, print_out=False):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(n_nodes)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    
    if print_out == True:
        for x in range(int(n_nodes/12)):
            print(np.round(pe[x],1))
    
    return pe
    
    
#v1,v2,edge_data, ind = define_graph_edges(n_nodes)
#norm_p = normalize_points(ca_coords,print_dist=True)
pe = make_pe_encoding(n_nodes=65, embed_dim = 12, scale = 10, print_out=True)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
tensor([0.6000, 0.8000, 0.4000, 0.9000, 0.3000, 1.0000, 0.2000, 1.0000, 0.1000,
        1.0000, 0.1000, 1.0000])
tensor([1.0000, 0.2000, 0.8000, 0.6000, 0.6000, 0.8000, 0.4000, 0.9000, 0.3000,
        1.0000, 0.2000, 1.0000])
tensor([ 0.9000, -0.5000,  1.0000,  0.2000,  0.8000,  0.6000,  0.6000,  0.8000,
         0.4000,  0.9000,  0.3000,  1.0000])
tensor([ 0.4000, -0.9000,  1.0000, -0.3000,  1.0000,  0.3000,  0.8000,  0.7000,
         0.6000,  0.8000,  0.4000,  0.9000])


In [7]:
#v1,v2,edge_data, ind = define_graph_edges(4)

In [8]:
#?dgl.nn.pytorch.KNNGraph, nearest neighbor graph maker
def define_graph(batch_size=8,n_nodes=65):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    pe = make_pe_encoding(n_nodes=n_nodes)
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data
        g.ndata['pe'] = pe

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)

    return batched_graph


In [9]:
def define_UGraph(n_nodes, batch_size, cast_type=torch.float32, cuda_out=True ):
    
    v1,v2,edge_data, ind = define_graph_edges(n_nodes)
    #pe = make_pe_encoding(n_nodes=n_nodes)#pe e
    
    graphList = []
    
    for i in range(batch_size):
        
        g = dgl.graph((v1,v2))
        g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
        g.ndata['pos'] = torch.zeros((n_nodes,3),dtype=torch.float32)

        graphList.append(g)
        
    batched_graph = dgl.batch(graphList)
    
    if cuda_out:
        return to_cuda(batched_graph)
    else:
        return batched_graph

def get_edge_features(graph,edge_feature_dim=1):
    return {'0': graph.edata['con'][:, :edge_feature_dim, None]}

class Graph_4H_Dataset(Dataset):
    def __init__(self, coordinates, cast_type=torch.float32):
                                    #prots,#length_prot in aa, #residues/aa, #xyz per atom
            
        #alphaFold reduce by 10
        coordinates = coordinates/10
            
        self.ca_coords = coordinates[:,:,CA,:]
        #unsqueeze to stack together later
        self.N_CA_vec = torch.tensor(coordinates[:,:,N,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        self.C_CA_vec = torch.tensor(coordinates[:,:,C,:] - coordinates[:,:,CA,:], dtype=cast_type).unsqueeze(2)
        
        #set mean to zero and max_distance between points to 1, is this necessary? since se3 transforms distances
        #self.norm_ca = normalize_points(self.ca_coords)
        
        
        n_nodes = self.ca_coords.shape[1] 
        
        v1,v2,edge_data, ind = define_graph_edges(n_nodes)
        pe = make_pe_encoding(n_nodes=n_nodes)

        graphList = []

        for i,c in enumerate(self.ca_coords):

            g = dgl.graph((v1,v2))
            g.edata['con'] = edge_data.type(cast_type).reshape((-1,1))
            g.ndata['pe'] = pe
            g.ndata['pos'] = torch.tensor(c,dtype=cast_type)
            g.ndata['bb_ori'] = torch.cat((self.N_CA_vec[i],self.C_CA_vec[i]),axis=1)
            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos
    
#needs to be done
class H4_DataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM_0 = 12
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop
    NODE_FEATURE_DIM_1 = 2
    

    def __init__(self,
                 coords: np.array, batch_size=8):
        
        self.GraphDatasetObj = Graph_4H_Dataset(coords)
        self.gds = DataLoader(self.GraphDatasetObj, batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['con'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM_0, None],
                      '1': batched_graph.ndata['bb_ori'][:,:self.NODE_FEATURE_DIM_1, :3]}
        
        return (batched_graph, node_feats, edge_feats)
    
class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value your are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False` then the scale of the noise
            won't be seen as a constant but something to optimize: this will bias the
            network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma
        self.noise = torch.tensor(0,dtype=torch.float)

    def forward(self,x,scale=1.0):
        if self.sigma != 0:
            #without modifer mult, mean=0, std_dev=1
            sampled_noise = (self.noise.repeat(*x.size()).normal_() * self.sigma*scale).to(x.device)
            x = x + sampled_noise
        return x, sampled_noise
    

def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    #N, Ca, C - [B,L, 3]
    #R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix
    B,L = N.shape[:2]
    
    v1 = C-Ca
    v2 = N-Ca
    e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
    u2 = v2-(torch.einsum('bli, bli -> bl', e1, v2)[...,None]*e1)
    e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix
    
    if non_ideal:
        v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
        cosref = torch.clamp( torch.sum(e1*v2, dim=-1), min=-1.0, max=1.0) # cosine of current N-CA-C bond angle
        costgt = cos_ideal_NCAC.item()
        cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
        cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
        sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
        Rp = torch.eye(3, device=N.device).repeat(B,L,1,1)
        Rp[:,:,0,0] = cosdel
        Rp[:,:,0,1] = -sindel
        Rp[:,:,1,0] = sindel
        Rp[:,:,1,1] = cosdel
    
        R = torch.einsum('blij,bljk->blik', R,Rp)

    return R, Ca

def get_t(N, Ca, C, non_ideal=False, eps=1e-5):
    I,B,L=N.shape[:3]
    Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), non_ideal=non_ideal, eps=eps)
    Rs = Rs.view(I,B,L,3,3)
    Ts = Ts.view(I,B,L,3)
    t = Ts[:,:,None] - Ts[:,:,:,None] # t[0,1] = residue 0 -> residue 1 vector
    return einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3)
def FAPE_loss(pred, true,  d_clamp=10.0, d_clamp_inter=30.0, A=10.0, gamma=1.0, eps=1e-6):
    '''
    Calculate Backbone FAPE loss from RosettaTTAFold
    https://github.com/uw-ipd/RoseTTAFold2/blob/main/network/loss.py
    Input:
        - pred: predicted coordinates (I, B, L, n_atom, 3)
        - true: true coordinates (B, L, n_atom, 3)
    Output: str loss
    '''
    I = pred.shape[0]
    true = true.unsqueeze(0)
    t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
    t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])

    difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps)
    eij_label = difference[-1].clone().detach()

    clamp = torch.zeros_like(difference)

    # intra vs inter#me coded
    clamp[:,True] = d_clamp

    difference = torch.clamp(difference, max=clamp)
    loss = difference / A # (I, B, L, L)

    # calculate masked loss (ignore missing regions when calculate loss)
    loss = (loss[:,True]).sum(dim=-1) / (torch.ones_like(loss).sum()+eps) # (I)

    # weighting loss
    w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
    w_loss = torch.flip(w_loss, (0,))
    w_loss = w_loss / w_loss.sum()

    tot_loss = (w_loss * loss).sum()
    
    return tot_loss, loss.detach()

In [10]:

def init_lecun_normal(module, scale=1.0):
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        normal = torch.distributions.normal.Normal(0, 1)

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma

        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform

        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
        x = torch.clamp(x, a, b)

        return x

    def sample_truncated_normal(shape, scale=1.0):
        stddev = np.sqrt(scale/shape[-1])/.87962566103423978  # shape[-1] = fan_in
        return stddev * truncated_normal(torch.rand(shape))

    module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
    return module

def init_lecun_normal_param(weight, scale=1.0):
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        normal = torch.distributions.normal.Normal(0, 1)

        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma

        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform

        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
        x = torch.clamp(x, a, b)

        return x

    def sample_truncated_normal(shape, scale=1.0):
        stddev = np.sqrt(scale/shape[-1])/.87962566103423978  # shape[-1] = fan_in
        return stddev * truncated_normal(torch.rand(shape))

    weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
    return weight

class SE3TransformerWrapper(nn.Module):
    """SE(3) equivariant GCN with attention"""
    def __init__(self, batch_size=16, num_channels=8, num_degrees=3):
        super().__init__()     
        

        kwargs = dict()
        kwargs['pooling'] = None
        kwargs['num_layers'] = 4
        kwargs['num_heads'] = 8
        kwargs['channels_div'] =torch.tensor(1,dtype=torch.int32)
        self.se3 = SE3Transformer(
                                    fiber_in=Fiber({0:12,1:2}),
                                    fiber_hidden=Fiber.create(num_degrees,num_channels),
                                    fiber_out=Fiber({1:3}),
                                    fiber_edge=Fiber({0: 1}),
                                    tensor_cores=False,
                                    num_degrees=num_degrees,
                                    num_channels=num_channels,
                                    use_layer_norm=True, **kwargs)

        self.reset_parameter()

    def reset_parameter(self):

        # make sure linear layer before ReLu are initialized with kaiming_normal_
        for n, p in self.se3.named_parameters():
            if "bias" in n:
                nn.init.zeros_(p)
            elif len(p.shape) == 1:
                continue
            else:
                if "radial_func" not in n:
                    p = init_lecun_normal_param(p) 
                else:
                    if "net.6" in n:
                        nn.init.zeros_(p)
                    else:
                        nn.init.kaiming_normal_(p, nonlinearity='relu')

        ## make last layers to be zero-initialized
        #nn.init.zeros_(self.se3.graph_modules[-1].project.weights['0'])
        #if self.l1_out > 0:
        #    nn.init.zeros_(self.se3.graph_modules[-1].project.weights['1'])
        # make last layers to be zero-initialized
        #nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
        #if self.l1_out > 0:
        nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])

    def forward(self, batched_graph, node_features, edge_features):
        return self.se3(batched_graph, node_features, edge_features)

In [11]:
def get_noise_pred_true(batched_graph, node_feats, gauss_noise, edge_feats, trans, B, L=65):
    
    true_pos = batched_graph.ndata['pos'].clone()

    #add vectors for N and C atomts
    CA_t  = true_pos.reshape(B, L, 3)
    NC_t = CA_t + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_t = CA_t + node_feats['1'][:,1,:].reshape(B, L, 3)
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

    
    batched_graph.ndata['pos'], noise = gauss_noise.forward(batched_graph.ndata['pos'])
    
    CA_n = batched_graph.ndata['pos'].clone().reshape(B, L, 3)
    NC_n = CA_n + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_n = CA_n + node_feats['1'][:,1,:].reshape(B, L, 3)
    noise_xyz =  torch.cat((NC_n,CA_n,CC_n),dim=2).reshape(B,L,3,3)


    shift= trans.forward(batched_graph, node_feats,edge_feats)
    #offset = shift['1'].reshape(B, L, 3)
    #pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
    CA_p = shift['1'][:,0,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    NC_p = CA_p + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_p = CA_p + node_feats['1'][:,1,:].reshape(B, L, 3)
    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)
    
    return true.to('cpu').numpy()*10, noise_xyz.to('cpu').numpy()*10, pred.detach().to('cpu').numpy()*10

In [12]:
def train_step(batched_graph, node_feats, gauss_noise, edge_feats, trans):
    
    true_pos = batched_graph.ndata['pos'].clone()

    #add vectors for 
    CA_t  = true_pos.reshape(B, L, 3)
    NC_t = CA_t + node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_t = CA_t + node_feats['1'][:,1,:].reshape(B, L, 3)
    true =  torch.cat((NC_t,CA_t,CC_t),dim=2).reshape(B,L,3,3)

    
    batched_graph.ndata['pos'], noise = gauss_noise.forward(batched_graph.ndata['pos'])


    shift= trans.forward(batched_graph, node_feats,edge_feats)
    #offset = shift['1'].reshape(B, L, 3)
    #pred = torch.add(Ts.reshape(B, L, 3), batched_graph.ndata['pos'].reshape(B, L, 3))
    
    CA_p = shift['1'][:,0,:].reshape(B, L, 3)+batched_graph.ndata['pos'].reshape(B, L, 3)
    NC_p = CA_p + shift['1'][:,1,:].reshape(B, L, 3)+node_feats['1'][:,0,:].reshape(B, L, 3)
    CC_p = CA_p + shift['1'][:,2,:].reshape(B, L, 3)+node_feats['1'][:,1,:].reshape(B, L, 3)
    pred = torch.cat((NC_p,CA_p,CC_p),dim=2).reshape(B,L,3,3)

    tloss, loss = FAPE_loss(pred.unsqueeze(0), true)
    
    opti.zero_grad()
    tloss.backward()
    opti.step()
    
    return tloss.detach()

In [13]:
device=torch.cuda.current_device()

batch_size = 16
num_channels = 8
num_degrees = 3

kwargs = dict()
kwargs['pooling'] = None
kwargs['num_layers'] = 4
kwargs['num_heads'] = 8
kwargs['channels_div'] =torch.tensor(1,dtype=torch.int32)

#channels dive = channels/num_heads

# model = SE3Transformer(
#         fiber_in=Fiber({0:12,1:2}),
#         fiber_hidden=Fiber.create(num_degrees,num_channels),
#         fiber_out=Fiber({1:1}),
#         fiber_edge=Fiber({0: 1}),
#         tensor_cores=False,
#         num_degrees=num_degrees,
#         num_channels=num_channels,
#         **kwargs
#     ).to(device)


model = SE3TransformerWrapper().to(device)
opti = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-5)

In [16]:
B=batch_size
L=65
dm = H4_DataModule(coords_tog,batch_size=B)
gn = GaussianNoise(sigma=0.1).to('cuda')
for e in range(300):
    lsum = 0
    for i, inp in enumerate(dm.gds):
        batched_graph, node_feats, edge_feats = inp
        bg = to_cuda(batched_graph)
        nf = to_cuda(node_feats)
        ef = to_cuda(edge_feats)

        out = train_step(bg, nf, gn, ef,model)
        lsum += out
        
    
    for x, inp in enumerate(dm.gds):
        batched_graph, node_feats, edge_feats = inp
        bg = to_cuda(batched_graph)
        nf = to_cuda(node_feats)
        ef = to_cuda(edge_feats)
        true, noise, pred = get_noise_pred_true(bg, nf, gn, ef, model, B, L=65)
        dump_coord_pdb(true[0],fileOut=f'output/true_{e}.pdb')
        dump_coord_pdb(noise[0],fileOut=f'output/noise_{e}.pdb')
        dump_coord_pdb(pred[0],fileOut=f'output/pred_{e}.pdb')
        break

    print(lsum/i)

tensor(0.0226, device='cuda:0')
tensor(0.0226, device='cuda:0')
tensor(0.0226, device='cuda:0')
tensor(0.0225, device='cuda:0')
tensor(0.0225, device='cuda:0')
tensor(0.0226, device='cuda:0')


KeyboardInterrupt: 

In [34]:
for i, inp in enumerate(dm.gds):
    batched_graph, node_feats, edge_feats = inp
    bg = to_cuda(batched_graph)
    nf = to_cuda(node_feats)
    ef = to_cuda(edge_feats)
    true, noise, pred = get_noise_pred_true(bg, nf, gn, ef, model, B, L=65)
    break

In [35]:
dump_coord_pdb(true[0],fileOut='output/true.pdb')
dump_coord_pdb(noise[0],fileOut='output/noise.pdb')
dump_coord_pdb(pred[0],fileOut='output/pred.pdb')

In [22]:
nf['1'][:,0,:].reshape(B, L, 3).shape

torch.Size([4, 65, 3])

In [19]:
B

4

In [20]:
260/65

4.0

In [None]:
batched_graph.ndata['pos'] = gauss_noise.forward(batched_graph.ndata['pos'])

In [27]:
out = model(batched_graph, node_feats, edge_feats)

In [28]:
out['1'].shape

torch.Size([260, 1, 3])