In [2]:
#import logging
import pathlib
from typing import List
import dgl
from sklearn.preprocessing import MinMaxScaler
import time

import numpy as np
import torch

import torch.distributed as dist
import torch.nn as nn
import collections.abc as container_abcs
#from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
#from se3_transformer.runtime import gpu_affinity
#from se3_transformer.runtime.arguments import PARSER
#from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    #PerformanceCallback
#from se3_transformer.runtime.inference import evaluate
#from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity
#import helix.helix_bb as hh

In [3]:
from torch.utils.data import random_split, DataLoader, Dataset

In [4]:
def get_midpoint(ep_in):
    """Get midpoints of input endpoints, 2 points per helix"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ind_ep.shape[1], ind_ep.shape[2])
    
    return midpoint



def normalize_pc(points):
    """Center at Zero Divide furtherst points"""
    centroid = np.mean(points, axis=0)
    points -= centroid
    furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
    points /= furthest_distance

    return points, furthest_distance
    
def make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(i_pos)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    return pe


def make_graph_struct(batch_size=32, n_nodes = 8):
    # make a fake graph to be filled with generator outputs
    
    v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ss = np.zeros(len(v1),dtype=np.int32)
    ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
    ss = ss[:,None] #unsqueeze
    
    pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

    graphList = []
    for i in range(batch_size):
        g = dgl.graph((v1,v2))
        g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph


class GraphDataset(Dataset):
    def __init__(self, ep_file : pathlib.Path, limit=1000):
        self.data_path = ep_file
        rr = np.load(self.data_path)
        ep = [rr[f] for f in rr.files][0][:1000]
        
        #need to save furthest distance to regen later
        #maybe consider small change for next steps
        ep, self.furthest_distance = normalize_pc(ep.reshape((-1,3)))
        self.ep = ep.reshape((-1,8,3))
        
        
        v1 = np.arange(self.ep.shape[1]-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,self.ep.shape[1]) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1))
        ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
        ss = ss[:,None] #unsqueeze

        #positional encoding of node
        pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

        graphList = []

        for i,c in enumerate(self.ep):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

    
class HGenDataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 8
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        
        self.data_dir = data_dir 
        self.GraphDatasetObj = GraphDataset(self.data_dir)
        self.gds = DataLoader(self.GraphDatasetObj,batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        return (batched_graph, node_feats, edge_feats)
    
def eval_gen(batch_size=8,z=12):
    
    in_z = torch.randn((batch_size,z), device='cuda',dtype = torch.float32)
    out = hg(in_z)*31
    out = out.reshape((-1,8,3)).detach().cpu().numpy()
    
    return eval_endpoints(out)
    
    

def eval_endpoints(ep_in): 
    
    ep = ep_in.reshape((-1,8,3))

    v1 = np.arange(ep.shape[1]-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,ep.shape[1]) #vertex 2 of edges in chronological order

    hLL = np.linalg.norm(ep[:,v1]-ep[:,v2],axis=2)

    hLoc = np.array([0,2,4,6])
    lLoc = np.array([1,3,5])

    return np.mean(hLL[:,hLoc]), np.mean(hLL[:,lLoc])
        

In [5]:
torch.cuda.is_available()

True

In [6]:
#num channels relates to self interaction of features of the same degree on the same node (1x1 convolution)
#learnable skip connections, since nodes don't attend to themselves nedded


#--num_degrees: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: 4)
#so num degrees is 3?
# Fiber

# A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
# In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.

#Edge feature dimension does not include rel_pos which is concatenated per forward pass

#I believe channels needs to equal degrees times heads to match expansion by heads
#with the self atention by channels 
def to_detach(x):
    """ Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
    if isinstance(x, Tensor):
        return x.detach()
    elif isinstance(x, tuple):
        return (to_detach(v) for v in x)
    elif isinstance(x, list):
        return [to_detach(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_detach(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x
def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos

class helixGen(nn.Module):
    def __init__(self, input_z=12, hidden=64, output=24):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_z, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output),
            nn.Tanh()
        )

    def forward(self, x):
        out_xyz = self.linear_relu_stack(x)
        return out_xyz.reshape((-1,3))
    
    


batch_size = 128
num_channels = 16
num_degrees = 4

kwargs = dict()
kwargs['pooling'] = 'max'
kwargs['num_layers'] = 4
kwargs['num_heads'] = 8
kwargs['channels_div'] =torch.tensor(2,dtype=torch.int32)

#channels dive = channels/num_heads


dm = HGenDataModule(pathlib.Path('data/ep_for_gmp.npz'),batch_size=batch_size)


model = SE3TransformerPooled(
        fiber_in=Fiber({0: dm.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: num_degrees * num_channels}),
        fiber_edge=Fiber({0: dm.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=False,
        num_degrees=num_degrees,
        num_channels=num_channels,
        **kwargs
    )

device=torch.cuda.current_device()

hg = helixGen()
hg.to(torch.float32)
hg.to(device)

model.to(device)
model.to(torch.float32)
loss_fn = nn.BCEWithLogitsLoss().to(device)
#loss_fn = nn.MSELoss()
d_opt = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=0.001)
g_opt = torch.optim.SGD(hg.parameters(), lr=0.001, weight_decay=0.001)
batched_graph = make_graph_struct(batch_size=batch_size)
batched_graph = to_cuda(batched_graph)

In [7]:
def train_step(input_z, input_real):
    
    batched_graph_real, node_feats_real, edge_feats_real = input_real
    batched_graph_real = to_cuda(batched_graph_real)
    node_feats_real = to_cuda(node_feats_real)
    edge_feats_real = to_cuda(edge_feats_real)
    
    model.zero_grad()
    hg.zero_grad()
    #----------------compute the generators loss---------------------------
    outpos = hg(input_z)

    batched_graph.ndata['pos'] = outpos
    batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
    #these node and edges feats are static, the s03 transformer is order independent
    #right now they just denote position
    g_fake_out = model(batched_graph, node_feats_real, edge_feats_real, compute_gradients=True) 
    real_fake_targets = torch.ones(g_fake_out.shape[0],dtype=torch.float32, device='cuda')
    g_loss = loss_fn(g_fake_out, real_fake_targets)
    
    #retain_graph=True
    g_loss.backward()
    g_opt.step()
    
    
    
    
    #----------compute the discrimanators loss
    model.zero_grad()
    d_real_out = model(batched_graph_real, node_feats_real, edge_feats_real)
    real_targets = torch.ones(d_real_out.shape[0],dtype=torch.float32, device='cuda')
    d_loss_real = loss_fn(d_real_out, real_targets)
    
    batched_graph.edata['rel_pos'] = batched_graph.edata['rel_pos'].detach()
    
    d_fake_out = model(batched_graph, node_feats_real, edge_feats_real)
    # I believe we can detach the input batch_graph position here, and not use retain_graph equals true
    #also fix repeated used of node_feats_real, edge_feats_real
    fake_targets = torch.zeros(d_fake_out.shape[0],dtype=torch.float32, device='cuda')
    d_loss_fake = loss_fn(d_fake_out, fake_targets)
    
    d_loss = d_loss_real + d_loss_fake
    
    d_loss.backward()
    d_opt.step()
    
    
    #save probs here
    #d_loss.detach(), g_loss.detach(),
    
    return  g_loss.detach(), d_loss.detach(), g_fake_out.detach(), d_fake_out.detach(), d_real_out.detach(), 
    
    
    

In [8]:
dm.GraphDatasetObj.furthest_distance

29.20556601787868

In [24]:
dm.GraphDatasetObj.ep

In [22]:
print(eval_endpoints(dm.GraphDatasetObj.ep))

In [15]:
all_losses = []
all_d_vals = []
epochs = 200
start_time = time.time()
for epoch in range(epochs):

    epoch_d_vals, epoch_losses = [],[]

    for i, input_real in enumerate(dm.gds):
        in_z = torch.randn((batch_size,12), device='cuda',dtype = torch.float32)

        gloss, dloss, gfo, dfo, dfr = train_step(in_z, input_real)
        d_probs_real = torch.mean(torch.sigmoid(dfr))
        d_probs_fake = torch.mean(torch.sigmoid(dfo))
        g_probs = torch.mean(torch.sigmoid(gfo))

        epoch_losses.append((gloss.cpu().numpy(), dloss.cpu().numpy()))
        epoch_d_vals.append((d_probs_real.cpu().numpy(), d_probs_fake.cpu().numpy()))
        
    all_losses.append(epoch_losses)
    all_d_vals.append(epoch_d_vals)

    track = f'Epoch {epoch:03d} |  ET {(time.time()-start_time)/60:.2f} min AvgLosses >> G/D '
    track = f'{track}{(np.mean(all_losses[-1][0],axis=0)):.3f}/{(np.mean(all_losses[-1][1],axis=0)):.3f}'
    track = f'{track} D Real :{(np.mean(all_d_vals[-1][0],axis=0)):.3f}'
    track = f'{track} D Fake :{(np.mean(all_d_vals[-1][1],axis=0)):.3f}'
    track = f'{track} Length EvaL {eval_gen(batch_size=batch_size)}'

    print(track)
    
    

    
    

  assert input.numel() == input.storage().size(), (


Epoch 000 |  ET 0.28 min AvgLosses >> G/D 1.027/1.024 D Real :0.506 D Fake :0.507 Length EvaL (8.068716, 10.25965)
Epoch 001 |  ET 0.31 min AvgLosses >> G/D 1.021/1.018 D Real :0.506 D Fake :0.505 Length EvaL (8.63751, 10.347848)
Epoch 002 |  ET 0.34 min AvgLosses >> G/D 1.012/1.010 D Real :0.507 D Fake :0.510 Length EvaL (7.9941826, 10.26891)
Epoch 003 |  ET 0.36 min AvgLosses >> G/D 1.003/1.005 D Real :0.511 D Fake :0.509 Length EvaL (8.932562, 11.068898)
Epoch 004 |  ET 0.39 min AvgLosses >> G/D 1.000/0.997 D Real :0.509 D Fake :0.511 Length EvaL (8.0107355, 10.384983)
Epoch 005 |  ET 0.41 min AvgLosses >> G/D 0.994/0.995 D Real :0.511 D Fake :0.508 Length EvaL (8.7419195, 10.533969)
Epoch 006 |  ET 0.44 min AvgLosses >> G/D 0.991/0.989 D Real :0.510 D Fake :0.511 Length EvaL (8.523566, 10.109623)
Epoch 007 |  ET 0.46 min AvgLosses >> G/D 0.988/0.985 D Real :0.508 D Fake :0.510 Length EvaL (9.120443, 10.897237)
Epoch 008 |  ET 0.49 min AvgLosses >> G/D 0.986/0.986 D Real :0.510 D Fa

Epoch 071 |  ET 2.15 min AvgLosses >> G/D 1.038/0.975 D Real :0.470 D Fake :0.523 Length EvaL (25.90562, 11.868195)
Epoch 072 |  ET 2.17 min AvgLosses >> G/D 0.965/1.036 D Real :0.525 D Fake :0.468 Length EvaL (25.903322, 12.134791)
Epoch 073 |  ET 2.20 min AvgLosses >> G/D 1.124/0.907 D Real :0.417 D Fake :0.580 Length EvaL (25.9601, 12.28291)
Epoch 074 |  ET 2.23 min AvgLosses >> G/D 0.914/1.100 D Real :0.574 D Fake :0.429 Length EvaL (25.829723, 12.139056)
Epoch 075 |  ET 2.25 min AvgLosses >> G/D 1.067/0.928 D Real :0.450 D Fake :0.559 Length EvaL (25.704674, 11.964114)
Epoch 076 |  ET 2.28 min AvgLosses >> G/D 0.995/1.011 D Real :0.503 D Fake :0.489 Length EvaL (24.880882, 11.397389)
Epoch 077 |  ET 2.30 min AvgLosses >> G/D 1.081/0.958 D Real :0.442 D Fake :0.529 Length EvaL (25.462776, 11.527448)
Epoch 078 |  ET 2.33 min AvgLosses >> G/D 0.922/1.067 D Real :0.557 D Fake :0.451 Length EvaL (25.595926, 11.170032)
Epoch 079 |  ET 2.36 min AvgLosses >> G/D 1.017/0.984 D Real :0.486 

Epoch 142 |  ET 4.08 min AvgLosses >> G/D 1.282/0.930 D Real :0.395 D Fake :0.579 Length EvaL (22.762112, 10.680062)
Epoch 143 |  ET 4.10 min AvgLosses >> G/D 0.951/1.264 D Real :0.556 D Fake :0.402 Length EvaL (22.958635, 10.38737)
Epoch 144 |  ET 4.13 min AvgLosses >> G/D 1.282/0.856 D Real :0.386 D Fake :0.634 Length EvaL (22.765224, 10.519515)
Epoch 145 |  ET 4.16 min AvgLosses >> G/D 0.978/1.071 D Real :0.536 D Fake :0.480 Length EvaL (22.357494, 10.123593)
Epoch 146 |  ET 4.19 min AvgLosses >> G/D 1.259/0.920 D Real :0.402 D Fake :0.580 Length EvaL (22.402143, 9.833534)
Epoch 147 |  ET 4.22 min AvgLosses >> G/D 0.920/1.175 D Real :0.579 D Fake :0.437 Length EvaL (22.537233, 9.919255)
Epoch 148 |  ET 4.24 min AvgLosses >> G/D 1.139/0.944 D Real :0.444 D Fake :0.556 Length EvaL (22.396545, 9.899922)
Epoch 149 |  ET 4.27 min AvgLosses >> G/D 0.913/1.207 D Real :0.577 D Fake :0.420 Length EvaL (22.27145, 9.894733)
Epoch 150 |  ET 4.30 min AvgLosses >> G/D 0.940/1.192 D Real :0.566 D 

In [16]:
eval_gen(batch_size=batch_size)

(22.54369, 9.892837)

In [None]:
data_path = 'data/ep_for_gmp.npz'
rr = np.load(data_path)
ep = [rr[f] for f in rr.files][0]

In [11]:
eval_endpoints(ep)

(21.155681206729565, 9.422310760234602)