In [1]:
#import logging
import pathlib
from typing import List
import dgl
from sklearn.preprocessing import MinMaxScaler
import time

import numpy as np
import torch

import torch.distributed as dist
import torch.nn as nn
import collections.abc as container_abcs
#from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity
import helix.helix_bb as hh

In [2]:
from torch.utils.data import random_split, DataLoader, Dataset

In [3]:
def get_midpoint(ep_in):
    """Get midpoints of input endpoints, 2 points per helix"""
    
    #calculate midpoint
    midpoint = ep_in.sum(axis=1)/np.repeat(ind_ep.shape[1], ind_ep.shape[2])
    
    return midpoint



def normalize_pc(points):
    """Center at Zero Divide furtherst points"""
    centroid = np.mean(points, axis=0)
    points -= centroid
    furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
    points /= furthest_distance

    return points, furthest_distance
    
def make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32):
    #positional encoding of node
    i_array = np.arange(1,(embed_dim/2)+1)
    wk = (1/(scale**(i_array*2/embed_dim)))
    t_array = np.arange(i_pos)
    si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
    ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
    pe = torch.stack((si,ci),axis=2).reshape(t_array.shape[0],embed_dim).type(cast_type)
    return pe


def make_graph_struct(batch_size=32, n_nodes = 8):
    # make a fake graph to be filled with generator outputs
    
    v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

    ss = np.zeros(len(v1),dtype=np.int32)
    ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
    ss = ss[:,None] #unsqueeze
    
    pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

    graphList = []
    for i in range(batch_size):
        g = dgl.graph((v1,v2))
        g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
        g.ndata['pe'] = pe

        graphList.append(g)

    batched_graph = dgl.batch(graphList)

    return batched_graph


class GraphDataset(Dataset):
    def __init__(self, ep_file : pathlib.Path, limit=1000):
        self.data_path = ep_file
        rr = np.load(self.data_path)
        ep = [rr[f] for f in rr.files][0][:1000]
        
        #need to save furthest distance to regen later
        #maybe consider small change for next steps
        ep, self.furthest_distance = normalize_pc(ep.reshape((-1,3)))
        self.ep = ep.reshape((-1,8,3))
        
        
        v1 = np.arange(self.ep.shape[1]-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,self.ep.shape[1]) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1))
        ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
        ss = ss[:,None] #unsqueeze

        #positional encoding of node
        pe = make_pe_encoding(i_pos=8, embed_dim = 8, scale = 10, cast_type=torch.float32)

        graphList = []

        for i,c in enumerate(self.ep):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList.append(g)
        
        self.graphList = graphList


    def __len__(self):
        return len(self.graphList)

    def __getitem__(self, idx):
        return self.graphList[idx]

    
class HGenDataModule():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 8
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        
        self.data_dir = data_dir 
        self.GraphDatasetObj = GraphDataset(self.data_dir)
        self.gds = DataLoader(self.GraphDatasetObj,batch_size=batch_size, shuffle=True, drop_last=True,
                              collate_fn=self._collate)
        
    
        
    def _collate(self, graphs):
        batched_graph = dgl.batch(graphs)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats = {'0': batched_graph.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        return (batched_graph, node_feats, edge_feats)
    
def eval_gen(batch_size=8,z=12):
    
    in_z = torch.randn((batch_size,z), device='cuda',dtype = torch.float32)
    out = hg(in_z)*31
    out = out.reshape((-1,8,3)).detach().cpu().numpy()
    
    return eval_endpoints(out)
    
    

def eval_endpoints(ep_in): 
    
    ep = ep_in.reshape((-1,8,3))

    v1 = np.arange(ep.shape[1]-1) #vertex 1 of edges in chronological order
    v2 = np.arange(1,ep.shape[1]) #vertex 2 of edges in chronological order

    hLL = np.linalg.norm(ep[:,v1]-ep[:,v2],axis=2)

    hLoc = np.array([0,2,4,6])
    lLoc = np.array([1,3,5])

    return np.mean(hLL[:,hLoc]), np.mean(hLL[:,lLoc])
        

In [4]:
torch.cuda.is_available()

False

In [61]:
#num channels relates to self interaction of features of the same degree on the same node (1x1 convolution)
#learnable skip connections, since nodes don't attend to themselves nedded


#--num_degrees: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: 4)
#so num degrees is 3?
# Fiber

# A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
# In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.

#Edge feature dimension does not include rel_pos which is concatenated per forward pass

#I believe channels needs to equal degrees times heads to match expansion by heads
#with the self atention by channels 
def to_detach(x):
    """ Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
    if isinstance(x, Tensor):
        return x.detach()
    elif isinstance(x, tuple):
        return (to_detach(v) for v in x)
    elif isinstance(x, list):
        return [to_detach(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_detach(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x
def _get_relative_pos(graph_in: dgl.DGLGraph) -> torch.Tensor:
    x = graph_in.ndata['pos']
    src, dst = graph_in.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos

class helixGen(nn.Module):
    def __init__(self, input_z=12, hidden=64, output=24):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_z, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output),
            nn.Tanh()
        )

    def forward(self, x):
        out_xyz = self.linear_relu_stack(x)
        return out_xyz.reshape((-1,3))
    
    


batch_size = 128
num_channels = 16
num_degrees = 4

kwargs = dict()
kwargs['pooling'] = 'max'
kwargs['num_layers'] = 4
kwargs['num_heads'] = 8
kwargs['channels_div'] =torch.tensor(2,dtype=torch.int32)

#channels dive = channels/num_heads


dm = HGenDataModule(pathlib.Path('data/ep_for_gmp.npz'),batch_size=batch_size)


model = SE3TransformerPooled(
        fiber_in=Fiber({0: dm.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: num_degrees * num_channels}),
        fiber_edge=Fiber({0: dm.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(False),
        num_degrees=num_degrees,
        num_channels=num_channels,
        **kwargs
    )

device=torch.cuda.current_device()

hg = helixGen()
hg.to(torch.float32)
hg.to(device)

model.to(device)
model.to(torch.float32)
loss_fn = nn.BCEWithLogitsLoss().to(device)
#loss_fn = nn.MSELoss()
d_opt = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=0.001)
g_opt = torch.optim.SGD(hg.parameters(), lr=0.001, weight_decay=0.001)
batched_graph = make_graph_struct(batch_size=batch_size)
batched_graph = to_cuda(batched_graph)

In [62]:
def train_step(input_z, input_real):
    
    batched_graph_real, node_feats_real, edge_feats_real = input_real
    batched_graph_real = to_cuda(batched_graph_real)
    node_feats_real = to_cuda(node_feats_real)
    edge_feats_real = to_cuda(edge_feats_real)
    
    model.zero_grad()
    hg.zero_grad()
    #----------------compute the generators loss---------------------------
    outpos = hg(input_z)

    batched_graph.ndata['pos'] = outpos
    batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
    #these node and edges feats are static, the s03 transformer is order independent
    #right now they just denote position
    g_fake_out = model(batched_graph, node_feats_real, edge_feats_real, compute_gradients=True) 
    real_fake_targets = torch.ones(g_fake_out.shape[0],dtype=torch.float32, device='cuda')
    g_loss = loss_fn(g_fake_out, real_fake_targets)
    
    #retain_graph=True
    g_loss.backward()
    g_opt.step()
    
    
    
    
    #----------compute the discrimanators loss
    model.zero_grad()
    d_real_out = model(batched_graph_real, node_feats_real, edge_feats_real)
    real_targets = torch.ones(d_real_out.shape[0],dtype=torch.float32, device='cuda')
    d_loss_real = loss_fn(d_real_out, real_targets)
    
    batched_graph.edata['rel_pos'] = batched_graph.edata['rel_pos'].detach()
    
    d_fake_out = model(batched_graph, node_feats_real, edge_feats_real)
    # I believe we can detach the input batch_graph position here, and not use retain_graph equals true
    #also fix repeated used of node_feats_real, edge_feats_real
    fake_targets = torch.zeros(d_fake_out.shape[0],dtype=torch.float32, device='cuda')
    d_loss_fake = loss_fn(d_fake_out, fake_targets)
    
    d_loss = d_loss_real + d_loss_fake
    
    d_loss.backward()
    d_opt.step()
    
    
    #save probs here
    #d_loss.detach(), g_loss.detach(),
    
    return  g_loss.detach(), d_loss.detach(), g_fake_out.detach(), d_fake_out.detach(), d_real_out.detach(), 
    
    
    

In [63]:
all_losses = []
all_d_vals = []
epochs = 200
start_time = time.time()
for epoch in range(epochs):

    epoch_d_vals, epoch_losses = [],[]

    for i, input_real in enumerate(dm.gds):
        in_z = torch.randn((batch_size,12), device='cuda',dtype = torch.float32)

        gloss, dloss, gfo, dfo, dfr = train_step(in_z, input_real)
        d_probs_real = torch.mean(torch.sigmoid(dfr))
        d_probs_fake = torch.mean(torch.sigmoid(dfo))
        g_probs = torch.mean(torch.sigmoid(gfo))

        epoch_losses.append((gloss.cpu().numpy(), dloss.cpu().numpy()))
        epoch_d_vals.append((d_probs_real.cpu().numpy(), d_probs_fake.cpu().numpy()))
        
    all_losses.append(epoch_losses)
    all_d_vals.append(epoch_d_vals)

    track = f'Epoch {epoch:03d} |  ET {(time.time()-start_time)/60:.2f} min AvgLosses >> G/D '
    track = f'{track}{(np.mean(all_losses[-1][0],axis=0)):.3f}/{(np.mean(all_losses[-1][1],axis=0)):.3f}'
    track = f'{track} D Real :{(np.mean(all_d_vals[-1][0],axis=0)):.3f}'
    track = f'{track} D Fake :{(np.mean(all_d_vals[-1][1],axis=0)):.3f}'
    track = f'{track} Length EvaL {eval_gen(batch_size=batch_size)}'

    print(track)
    
    

    
    

Epoch 000 |  ET 0.03 min AvgLosses >> G/D 0.953/0.952 D Real :0.575 D Fake :0.573 Length EvaL (8.25626, 7.3440986)
Epoch 001 |  ET 0.05 min AvgLosses >> G/D 0.947/0.945 D Real :0.569 D Fake :0.568 Length EvaL (7.8896675, 7.0250325)
Epoch 002 |  ET 0.07 min AvgLosses >> G/D 0.942/0.944 D Real :0.566 D Fake :0.564 Length EvaL (7.863228, 6.9552994)
Epoch 003 |  ET 0.09 min AvgLosses >> G/D 0.939/0.940 D Real :0.563 D Fake :0.561 Length EvaL (8.103188, 7.189209)
Epoch 004 |  ET 0.11 min AvgLosses >> G/D 0.940/0.938 D Real :0.557 D Fake :0.556 Length EvaL (8.279875, 7.4276423)


KeyboardInterrupt: 

In [104]:
eval_gen(batch_size=batch_size)

(24.367468, 10.302194)

In [54]:
896*8

7168

In [105]:
dm.GraphDatasetObj.furthest_distance

31.228954134944612

In [106]:
in_z = torch.randn((batch_size,12), device='cuda',dtype = torch.float32)
out = hg(in_z)*31.228954134944612
out = out.reshape((-1,8,3)).detach().cpu().numpy()

In [85]:
data_path = 'data/ep_for_gmp.npz'
rr = np.load(data_path)
ep = [rr[f] for f in rr.files][0]

In [86]:
gds= GraphDataset(data_path)

In [87]:
gds.furthest_distance

31.228954134944612

In [88]:
ep2 = gds.ep*gds.furthest_distance

In [107]:
hep = hh.EP_Recon(out)

In [108]:
npl = hep.to_npose()

In [109]:
for x in range(len(npl)):
    hh.nu.dump_npdb(npl[x],f'output/test_{x}.pdb')

In [30]:
out = (batched_graph_real.ndata['pos']*36).numpy().reshape((-1,8,3))

In [10]:
for i, input_real in enumerate(dm.gds):
    batched_graph_real, node_feats_real, edge_feats_real = input_real
    h,l = eval_endpoints(batched_graph_real.ndata['pos']*36)
    print(h,l)
    

12.077777 9.12193
11.456908 8.926224
11.475114 8.901618
11.425886 9.046969
11.976165 9.069525
11.843756 9.040355
11.139473 8.855804
11.722401 9.133283
11.81967 8.958594
11.975039 9.1037855
11.305781 8.875032
11.437936 8.918839
11.171131 8.927791
11.85101 8.967086
11.807449 9.046146
11.386745 8.943198
11.5919485 8.931197
11.43671 9.010324
12.039566 8.899229
11.381808 9.015571
11.148035 9.03046
11.83581 9.007893
11.907744 8.951764
11.17659 8.871845
11.521595 8.8635
11.803186 8.921136
10.994955 8.978087
11.404093 8.949065
11.466954 8.976123
11.342091 9.141507
11.6009865 8.899449
11.377945 9.0340805
11.199827 9.032464
11.858041 9.061273
11.615349 8.943565
11.5568495 9.001395
11.211485 9.044753
11.289133 8.870537
11.711969 8.875341
11.706545 8.941649
11.017423 8.992263
10.645308 8.903129
11.671707 8.967496
11.30115 9.053582
12.039323 9.041419
11.751722 8.876072
11.331128 8.96942
11.78825 9.005442
11.689248 8.999727
11.587079 9.010528
11.073433 8.976633
11.86665 9.0105915
11.370591 8.971354


In [None]:
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')




In [116]:
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:0'})
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']

In [111]:
def save_state(model_save: nn.Module, optimizer_save: Optimizer, epoch: int, path: pathlib.Path):
    """ Saves model, optimizer and epoch states to path (only once per node) """

    state_dict = model_save.state_dict()
    checkpoint = {
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer_save.state_dict(),
        'epoch': epoch
    }

    torch.save(checkpoint, str(path))

In [118]:
save_state(model, d_opt, 250, f'results/test1/d_check')
save_state(hg, g_opt, 250, f'results/test1/g_check')

In [119]:
load_state(model, d_opt, f'results/test1/d_check')
load_state(hg, g_opt, f'results/test1/g_check')

250

In [None]:
class HGenTest():
    """
    Datamodule wrapping hGen data set. 8 Helical endpoints defining a four helix protein.
    """
    #8 long positional encoding
    NODE_FEATURE_DIM = 8
    EDGE_FEATURE_DIM = 1 # 0 or 1 helix or loop

    def __init__(self,
                 data_dir: pathlib.Path, batch_size=32):
        
        self.data_dir = data_dir 
        self.gds_fake = DataLoader(Fake_GraphDataset(self.data_dir),batch_size=32, shuffle=True, collate_fn=self._collate)

        
    def _collate(self, samples):
        graphs_real, graphs_fake = map(list, zip(*samples))    

        batched_graph_real = dgl.batch(graphs_real)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats_real = {'0': batched_graph_real.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph_real.edata['rel_pos'] = _get_relative_pos(batched_graph_real)
        # get node features
        node_feats_real = {'0': batched_graph_real.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        batched_graph_fake = dgl.batch(graphs_fake)
        #reshape that batched graph to redivide into the individual graphs
        edge_feats_fake = {'0': batched_graph_fake.edata['ss'][:, :self.EDGE_FEATURE_DIM, None]}
        batched_graph_fake.edata['rel_pos'] = _get_relative_pos(batched_graph_fake)
        # get node features
        node_feats_fake = {'0': batched_graph_fake.ndata['pe'][:, :self.NODE_FEATURE_DIM, None]}
        
        
        
        
        return (batched_graph_real, node_feats_real, edge_feats_real, batched_graph_fake, node_feats_fake, edge_feats_fake)

In [None]:
class Fake_GraphDataset(Dataset):
    def __init__(self, ep_file : pathlib.Path, limit=1000):
        self.data_path = ep_file
        rr = np.load(data_path)
        ep = [rr[f] for f in rr.files][0]
        
        ep = ep[:limit]
        
        v1 = np.arange(ep.shape[1]-1) #vertex 1 of edges in chronological order
        v2 = np.arange(1,ep.shape[1]) #vertex 2 of edges in chronological order

        ss = np.zeros(len(v1),dtype=np.int32)
        ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
        ss = ss[:,None] #unsqueeze

        #positional encoding of node
        embed_dim = 8
        scale = 10
        i_array = np.arange(1,(embed_dim/2)+1)
        wk = (1/(scale**(i_array*2/embed_dim)))
        t_array = np.arange(ep.shape[1])
        si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
        ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
        pe = torch.stack((si,ci),axis=2).reshape(ep.shape[1],embed_dim).type(torch.float32)

        graphList_real = []
        graphList_fake = []
        
        #randomize
        ep_f = ep.copy()
        rng = np.random.default_rng()
        rng.shuffle(ep_f,axis=1)
        

        for i,c in enumerate(ep):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList_real.append(g)
            
        for i,c in enumerate(ep_f):

            g = dgl.graph((v1,v2))
            g.ndata['pos'] = torch.tensor(c,dtype=torch.float32)
            g.edata['ss'] = torch.tensor(ss,dtype=torch.float32)
            g.ndata['pe'] = pe

            graphList_fake.append(g)
        
        self.graphList_real = graphList_real
        self.graphList_fake = graphList_fake


    def __len__(self):
        return len(self.graphList_real)

    def __getitem__(self, idx):
        return self.graphList_real[idx], self.graphList_fake[idx]

In [101]:
def train_epoch_(data_module):
    loss_acc = torch.zeros((1,), device='cuda')
    for i, batch in enumerate(data_module.gds_fake):
        batched_graph_real, node_feats_real, edge_feats_real, batched_graph_fake, node_feats_fake, edge_feats_fake = to_cuda(batch)

        #print(node_feats_real['0'][0])
        #print(edge_feats_real['0'][0])
        #print(batched_graph_real.edata['rel_pos'])
        pred_real = model(batched_graph_real, node_feats_real, edge_feats_real)
        real_targets = torch.ones(pred_real.shape[0], dtype=torch.long, device='cuda')

        loss_real = loss_fn(pred_real,real_targets)
        
        loss_real.backward()
        optimizer.step()
        model.zero_grad()


        pred_fake = model(batched_graph_fake, node_feats_fake, edge_feats_fake)
        fake_targets = torch.zeros(pred_fake.shape[0],dtype=torch.long, device='cuda')

        loss_fake = loss_fn(pred_fake,fake_targets)

        loss_fake.backward()
        optimizer.step()
        model.zero_grad()
        
        loss_acc += loss_fake.detach()
        loss_acc += loss_real.detach()
    return loss_acc / (i+1)

    
    

In [None]:
# make a fake graph to be filled with generator outputs
n_nodes = 8
batch_size = 32

v1 = np.arange(n_nodes-1) #vertex 1 of edges in chronological order
v2 = np.arange(1,n_nodes) #vertex 2 of edges in chronological order

ss = np.zeros(len(v1),dtype=np.int32)
ss[np.arange(ss.shape[0])%2==0]=1  #alternate 0,1 for helix, loop, helix, etc
ss = ss[:,None] #unsqueeze

#positional encoding of node
embed_dim = 8
scale = 10
i_array = np.arange(1,(embed_dim/2)+1)
wk = (1/(scale**(i_array*2/embed_dim)))
t_array = np.arange(ep.shape[1])
si = torch.tensor(np.sin(wk*t_array.reshape((-1,1))))
ci = torch.tensor(np.cos(wk*t_array.reshape((-1,1))))
pe = torch.stack((si,ci),axis=2).reshape(n_nodes,embed_dim).type(torch.float32)

graphList = []
for i in range(batch_size):
    g = dgl.graph((v1,v2))
    g.edata['ss'] = torch.tensor(ss)
    g.ndata['pe'] = pe
    
    graphList.append(g)
    
batched_graph = dgl.batch(graphList)
#reshape that batched graph to redivide into the individual graphs
edge_feats = {'0': batched_graph.edata['ss'][:, :8, None]}

# get node features
node_feats = {'0': batched_graph.ndata['pe'][:,:1, None]}
  

src = torch.tensor(v1,dtype=torch.long)
dst = torch.tensor(v2,dtype=torch.long)

In [50]:
for i, input_real in enumerate(dm.gds):
    batched_graph_real, node_feats_real, edge_feats_real = input_real
    h,l = eval_endpoints(batched_graph_real.ndata['pos']*36)
    print(h,l)

25.981705 11.658938
26.098038 11.485578
26.076225 11.557309
26.028965 11.7545805
26.1562 11.585332
26.181648 11.615834
26.119846 11.527458


In [None]:
# num_epochs = 100
# for x in range(num_epochs):
#     loss = train_epoch(dm)
#     print(loss)

In [None]:
# tilBad = True
# counter = 0
# #shitty work around for div by zero
# while tilBad and counter<len(gg) and counter<max:
#     try:
#         eprec = hh.EP_Recon(gg[:counter])
#         arr = aa.to_npose()
#         counter += 1
#     except Exception:
#         tilBad = False