In [1]:
import os
import torch
# Expose the installed PyTorch version to the process environment as $TORCH and echo it.
os.environ.update(TORCH=torch.__version__)
print(os.environ['TORCH'])

1.12.1


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch_geometric.nn.pool import radius
import torch
from torch import nn

import os
import time

def V(x):
    """Detach a tensor from autograd, move it to the CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()
from numpy import newaxis, mean, savetxt
import networkx as nx

from scipy import spatial
from scipy.spatial import cKDTree

#GCN -----------
from torch_geometric.nn import GCNConv

#GAT ----------
from torch_geometric.nn import GATConv

#GN -----------
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean

class NodeModel(torch.nn.Module):
    """Node-update block in the style of a Graph Network (no edge or global features).

    For every edge (src -> dst) in ``edge_index`` the source node's features are
    transformed by ``node_mlp_1``. The resulting messages are mean-aggregated per
    destination node with ``scatter_mean``. Each node's own features are then
    concatenated with its aggregate and mapped by ``node_mlp_2`` to ``out_feat``
    features. All linear layers are bias-free.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()

        self.in_feat, self.out_feat = in_feat, out_feat
        hidden = self.in_feat

        # Message MLP: in_feat -> in_feat.
        self.node_mlp_1 = Seq(
            Lin(hidden, hidden, bias=False),
            ReLU(),
            Lin(hidden, hidden, bias=False),
        )
        # Update MLP on [own features, aggregated message]: 2*in_feat -> out_feat.
        self.node_mlp_2 = Seq(
            Lin(2 * hidden, hidden, bias=False),
            ReLU(),
            Lin(hidden, self.out_feat, bias=False),
        )

    def forward(self, x, edge_index):
        """Apply one round of message passing.

        Args:
            x: (N, in_feat) node features.
            edge_index: (2, E) integer tensor; row 0 holds source nodes, row 1
                destination nodes, with entries in [0, N).

        Returns:
            (N, out_feat) updated node features.
        """
        src, dst = edge_index
        messages = self.node_mlp_1(x[src])
        aggregated = scatter_mean(messages, dst, dim=0, dim_size=x.size(0))
        return self.node_mlp_2(torch.cat([x, aggregated], dim=1))
    

#Loss function -----------

MIN_REPUL_DIST = 1e-3  # not used by Repulsion below
# Default Gaussian width and prefactor of the repulsion term.
# NOTE: `radius` rebinds the name imported from torch_geometric.nn.pool above.
radius = .4
magnitude = 10

def Repulsion(X, length_scale=None, strength=None):
    """Gaussian repulsion energy of a point cloud.

    Computes ``strength * sum_{i,j} exp(-|x_i - x_j|^2 / (4 * length_scale^2))``
    over all ordered pairs (i, j), *including* i == j. Each diagonal term equals
    exp(0) = 1. The diagonal therefore adds the constant ``strength * n`` to the
    value but contributes nothing to the gradient.

    Args:
        X: (n, d) tensor of node positions.
        length_scale: Gaussian width; defaults to the module-level ``radius``.
        strength: prefactor; defaults to the module-level ``magnitude``.

    Returns:
        Scalar tensor.
    """
    # Resolve defaults at call time so later edits to the globals still take effect.
    r0 = radius if length_scale is None else length_scale
    amp = magnitude if strength is None else strength
    # Pairwise differences by broadcasting: (1, n, d) - (n, 1, d) -> (n, n, d).
    dX = X[newaxis] - X[:, newaxis]
    sq_dist = torch.sum(dX ** 2, dim=-1)
    return amp * torch.sum(torch.exp(-sq_dist / 4 / (r0 ** 2)))



def Elastic(X, A):
    """Spring (elastic) energy ``tr(X^T L X)`` of positions X under adjacency A.

    The graph Laplacian is ``L = diag(row sums of A) - A``.
    """
    laplacian = torch.diag(A.sum(dim=1)) - A
    return torch.trace(X.t() @ laplacian @ X)

def Loss(X,A,c=1):
    """Layout energy from a dense adjacency: elastic term plus ``c`` times the repulsion term."""
    elastic_energy = Elastic(X, A)
    return elastic_energy + c * Repulsion(X)


def Elastic_edgelist(X, edg):
    """Half the sum of squared edge lengths for an (E, 2) index array ``edg``."""
    heads, tails = edg[:, 0], edg[:, 1]
    stretch = X[heads] - X[tails]
    return (stretch ** 2).sum() / 2
    
def Loss_edgelist(X,edg,c=1):
    """Edge-list layout energy: elastic energy over ``edg`` plus ``c`` times the repulsion."""
    spring_energy = Elastic_edgelist(X, edg)
    return spring_energy + c * Repulsion(X)


#Alternative GCN models ----------- 
class GCN(nn.Module):
    """Dense single-layer graph convolution: ``forward(x) = A @ x @ W``.

    Args:
        in_feat: number of input features per node.
        out_feat: number of output features per node.
        A: dense (N, N) adjacency matrix. If given, ``edgelist`` is ignored.
        edgelist: (i, j) index pairs used when ``A`` is not given. They build a
            directed 0/1 float32 adjacency with ``A[i, j] = 1``.
        N: number of nodes for ``edgelist``; defaults to ``max index + 1``.

    Raises:
        TypeError: if neither ``A`` nor ``edgelist`` is given.
        ValueError: if ``edgelist`` is empty and ``N`` is not given.
    """
    def __init__(self, in_feat, out_feat, A=None, edgelist=None, N=None):
        super(GCN,self).__init__()
        # `is None` instead of `!= None`: the latter is elementwise (and ambiguous
        # in `if`) for NumPy arrays.
        if A is None:
            if edgelist is None:
                raise TypeError("GCN requires either `A` or `edgelist`")
            A = self._adjacency_from_edgelist(edgelist, N)
        # Non-persistent buffer: moves with .to(device) but is not saved in state_dict.
        self.register_buffer('A', torch.as_tensor(A), persistent=False)

        self.in_feat = in_feat
        self.out_feat = out_feat
        self.W = nn.Parameter(torch.empty(in_feat, out_feat))
        nn.init.kaiming_normal_(self.W)

    @staticmethod
    def _adjacency_from_edgelist(edgelist, N=None):
        """Build a dense float32 (N, N) adjacency with A[i, j] = 1 for each (i, j) pair."""
        edges = torch.as_tensor(edgelist, dtype=torch.long).reshape(-1, 2)
        if N is None:
            if edges.numel() == 0:
                raise ValueError("cannot infer N from an empty edgelist; pass N")
            N = int(edges.max()) + 1
        A = torch.zeros(N, N, dtype=torch.float32)
        A[edges[:, 0], edges[:, 1]] = 1
        return A

    def forward(self, x):
        """Return ``A @ x @ W`` for node features x of shape (N, in_feat)."""
        return self.A @ x @ self.W


class GCN_Lowrank(nn.Module):
    """Graph convolution with a rank-k spectral approximation of the adjacency.

    The adjacency is symmetrized as ``(A + A^T) / 2`` and eigendecomposed. The k
    largest eigenpairs, with ``k = max(1, int(pc_num_frac * N))``, give
    ``A_k = V_k diag(lambda_k) V_k^T``, and ``forward(x) = A_k @ x @ W``.

    Args:
        in_feat: number of input features per node.
        out_feat: number of output features per node.
        A: dense (N, N) adjacency matrix. If given, ``edgelist`` is ignored.
        pc_num_frac: fraction of the N eigenpairs to keep.
        edgelist: (i, j) index pairs used when ``A`` is not given. They build a
            directed 0/1 float32 adjacency with ``A[i, j] = 1``.
        N: number of nodes for ``edgelist``; defaults to ``max index + 1``.

    Raises:
        TypeError: if neither ``A`` nor ``edgelist`` is given.
        ValueError: if ``edgelist`` is empty and ``N`` is not given.
    """
    def __init__(self, in_feat, out_feat, A=None, pc_num_frac = 0.1,
                 edgelist=None, N=None):
        super(GCN_Lowrank,self).__init__()
        # `is None` instead of `!= None`: the latter is elementwise (and ambiguous
        # in `if`) for NumPy arrays.
        if A is None:
            if edgelist is None:
                raise TypeError("GCN_Lowrank requires either `A` or `edgelist`")
            A = self._adjacency_from_edgelist(edgelist, N)
        # Non-persistent buffers move with .to(device) but stay out of state_dict.
        self.register_buffer('A', torch.as_tensor(A), persistent=False)
        self._prep_A_lowrank(A, pc_num_frac)

        self.in_feat = in_feat
        self.out_feat = out_feat
        self.W = nn.Parameter(torch.empty(in_feat, out_feat))
        nn.init.kaiming_normal_(self.W)

    @staticmethod
    def _adjacency_from_edgelist(edgelist, N=None):
        """Build a dense float32 (N, N) adjacency with A[i, j] = 1 for each (i, j) pair."""
        edges = torch.as_tensor(edgelist, dtype=torch.long).reshape(-1, 2)
        if N is None:
            if edges.numel() == 0:
                raise ValueError("cannot infer N from an empty edgelist; pass N")
            N = int(edges.max()) + 1
        A = torch.zeros(N, N, dtype=torch.float32)
        A[edges[:, 0], edges[:, 1]] = 1
        return A

    def _prep_A_lowrank(self, A, pc_num_frac):
        """Store the k largest eigenpairs of the symmetrized adjacency as buffers."""
        A = torch.as_tensor(A)
        A = (A + A.t()) / 2
        v, p = torch.linalg.eigh(A)
        # Keep at least one component so small graphs do not give an all-zero layer.
        k = max(1, int(pc_num_frac * len(A)))
        # eigh returns eigenvalues in ascending order, so the k largest must be
        # selected explicitly (slicing [:k] would keep the k smallest).
        idx = torch.argsort(v, descending=True)[:k]
        pval, pvec = v[idx], p[:, idx]
        self.register_buffer('pval', pval, persistent=False)
        self.register_buffer('pvec', pvec, persistent=False)
        # diag(lambda) @ V^T with shape (k, N), so agg(x) = V @ diag(lambda) @ V^T @ x.
        self.register_buffer('pvec_tv', (pvec * pval).t(), persistent=False)

    def agg(self,x):
        """Apply the rank-k adjacency approximation to node features x of shape (N, F)."""
        x = self.pvec_tv @ x
        return self.pvec @ x

    def forward(self, x):
        """Return ``A_k @ x @ W`` (low-rank counterpart of ``A @ x @ W``)."""
        return self.agg(x) @ self.W

    
def time_it(f):
    """Decorator that prints the wall-clock duration of every call to ``f``.

    Prints ``'<name> time: <seconds> s'`` after ``f`` returns. The wrapped
    function's return value is passed through, and its name and docstring are
    preserved via ``functools.wraps``.
    """
    import functools

    @functools.wraps(f)
    def dt(*args, **kw):
        t0 = time.perf_counter()
        result = f(*args, **kw)
        print('%s time: %.3g s' % (f.__name__, time.perf_counter() - t0))
        return result
    return dt

#NeuLay model -----------

class ResGCN(nn.Module):
    """NeuLay layout model: graph layers re-parametrize node positions, then direct fine-tuning.

    ``forward`` returns ``projection_layer(concat([h_0, h_1, ..., h_L], dim=1))``,
    where ``h_0 = latent`` and ``h_{k+1} = layer_k(h_k)``. ``train`` first optimizes
    the latent embedding and layer weights through ``forward`` (the "GCN" phase).
    It then copies the result into ``fine_pos`` and optimizes ``fine_pos``
    directly (the "FDL" phase).

    Args:
        feat_dims: ``[latent_dim, hidden_1, ..., out_dim]``. ``len(feat_dims) - 2``
            graph layers are built, and ``out_dim`` is the layout dimension.
        A: dense (n, n) adjacency, used when ``edgelist`` is not given
            (loss: ``Loss(x, A)``).
        edge_index: (2, E) edge index passed to the graph layers. When omitted, it
            is derived from the non-zero entries of the adjacency matrix.
        edgelist: (i, j) pairs; takes precedence over ``A``. Builds a directed 0/1
            adjacency with ``n = max index + 1`` (loss: ``Loss_edgelist``).
        normalize_A: if True, ``self.DA = diag(1 / (0.1 + deg)) @ A``; otherwise
            ``self.DA = A``.
        activation: unused.
        device: torch device; defaults to 'cuda' when available, else 'cpu'.
        lr: Adam learning rate for ``fine_pos``. The GCN-phase optimizer uses ``lr / 2``.
        GCN_class: one of 'GCNConv', 'GATConv', 'GraphNet' (layers called as
            ``layer(x, edge_index)``), or a module class built as
            ``GCN_class(in_feat, out_feat, A=self.DA)`` and called as ``layer(x)``.

    Raises:
        ValueError: if neither ``A`` nor a non-empty ``edgelist`` is given, or if
            ``GCN_class`` is an unknown string.
    """
    def __init__(self, feat_dims = [1,1], A=None, edge_index = None, edgelist=None, normalize_A = True,
                activation = None, device = None, lr = 1e-2, GCN_class = GCN):
        super(ResGCN,self).__init__()
        if device is None:
            device = ('cuda' if torch.cuda.is_available() else 'cpu')
            print(device)
        self.device = torch.device(device)

        # `is not None` instead of `!= None`: the latter is elementwise for arrays.
        if edgelist is not None:
            self.edgelist = torch.as_tensor(edgelist, dtype=torch.long).to(self.device)
            if self.edgelist.numel() == 0:
                raise ValueError("edgelist is empty; cannot infer the number of nodes")
            n_nodes = int(self.edgelist.max()) + 1
            # Directed 0/1 adjacency: only A[i, j] is set for each pair (i, j).
            A = torch.zeros(n_nodes, n_nodes, dtype=torch.float32, device=self.device)
            A[self.edgelist[:, 0], self.edgelist[:, 1]] = 1
            self.A = A
            self.Loss = lambda x: Loss_edgelist(x, self.edgelist)
        elif A is not None:
            self.A = torch.as_tensor(A).to(self.device)
            self.Loss = lambda x: Loss(x, self.A)
        else:
            raise ValueError("ResGCN requires either `A` or `edgelist`")

        if edge_index is None:
            # (2, E) index of the non-zero adjacency entries (row -> column).
            edge_index = self.A.nonzero().t()
        self.edge_index = torch.as_tensor(edge_index).to(self.device)

        if normalize_A:
            degs = self.A.sum(dim=1)
            self.DA = torch.diag(1/(1e-1+degs)) @ self.A
        else:
            self.DA = self.A

        n = len(self.A)
        # =====  Module Parameters ======
        self.latent = nn.Parameter(torch.empty(n,feat_dims[0]).to(self.device))
        # Initial spread of the latent embedding: r = n ** (1 / latent_dim).
        r = n**(1./feat_dims[0])
        print("latent radius {:.3g}".format(r))
        nn.init.normal_(self.latent,std = r)

        # Layer factories keyed by name; only the selected one is instantiated.
        named_layers = {
            'GCNConv': lambda i: GCNConv(feat_dims[i], feat_dims[i+1], improved=True, bias=False),
            'GATConv': lambda i: GATConv(feat_dims[i], feat_dims[i+1], bias=False),
            'GraphNet': lambda i: NodeModel(feat_dims[i], feat_dims[i+1]),
        }
        if isinstance(GCN_class, str):
            if GCN_class not in named_layers:
                raise ValueError(f"unknown GCN_class {GCN_class!r}; expected one of {sorted(named_layers)}")
            make_layer = named_layers[GCN_class]
            self._layers_take_edge_index = True
        else:
            # Dense layers such as GCN / GCN_Lowrank carry their own adjacency.
            make_layer = lambda i: GCN_class(feat_dims[i], feat_dims[i+1], A=self.DA)
            self._layers_take_edge_index = False
        self.gcn_list = nn.ModuleList([make_layer(i).to(self.device)
                                       for i in range(len(feat_dims)-2)])

        self.projection_layer = (nn.Linear(sum(feat_dims[:-1]), feat_dims[-1]).to(self.device)
                                 if len(feat_dims)>1 else nn.Identity())
        self.loss_history = []
        # Created before fine_pos is registered, so fine_pos is NOT optimized here.
        self.optim = torch.optim.Adam(self.parameters(), lr = lr/2)

        self.fine_pos = nn.Parameter(torch.empty(n,feat_dims[-1]).to(self.device))
        nn.init.normal_(self.fine_pos,std = r)
        self.optim_fine = torch.optim.Adam([self.fine_pos], lr = lr)

    def forward(self,):
        """Return (n, feat_dims[-1]) positions computed from the latent embedding."""
        takes_edge_index = getattr(self, '_layers_take_edge_index', True)
        out = [self.latent]
        for g in self.gcn_list:
            if takes_edge_index:
                out.append(g(out[-1], self.edge_index))
            else:
                out.append(g(out[-1]))
        out = torch.concat(out,dim = 1)
        return self.projection_layer(out)

    @time_it
    def train(self,gcn_steps=200, fdl_steps=2000, early_stop_check_steps = 100,
              min_steps=100,
              gcn_stop_threshold = 2e-2,
              fdl_stop_threshold = 5e-3,
              **stop_kws):
        """Run the two-phase layout optimization.

        1. GCN phase (skipped when ``gcn_steps <= 0`` or there are no graph layers):
           optimize the latent embedding and layer weights through ``forward`` for up
           to ``gcn_steps`` steps. Then copy the result into ``fine_pos``.
        2. FDL phase: optimize ``fine_pos`` directly for up to ``fdl_steps`` steps.

        Each phase early-stops via ``early_stopping``, using ``gcn_stop_threshold``
        or ``fdl_stop_threshold`` respectively; ``stop_kws`` are forwarded to it.
        A KeyboardInterrupt ends training gracefully.

        A boolean argument, ``train(True)`` / ``train(False)`` (as issued by
        ``nn.Module.eval()``), only toggles training mode, as for any ``nn.Module``.
        """
        # This method shadows nn.Module.train; without this guard, eval() would
        # launch an FDL optimization run instead of switching modes.
        if isinstance(gcn_steps, bool):
            return super().train(gcn_steps)
        try:
            if gcn_steps > 0 and len(self.gcn_list) > 0:
                self.train_gcn(steps=gcn_steps, early_stop_check=early_stop_check_steps,
                               min_steps=min_steps,
                               stop_delta_ratio=gcn_stop_threshold, **stop_kws)
                self.fine_pos.data = self()
            print(f'\nFDL training {fdl_steps} steps')
            self.train_fine(steps=fdl_steps, early_stop_check=early_stop_check_steps,
                            min_steps=min_steps,
                            stop_delta_ratio=fdl_stop_threshold, **stop_kws)
        except KeyboardInterrupt:
            print('\nTraining interrupted')
            return

    def _optimize(self, optimizer, positions, steps, early_stop_check, min_steps,
                  stop_delta_ratio, **stop_kws):
        """Shared loop of train_gcn/train_fine: minimize ``self.Loss(positions())``.

        Appends each loss value to ``self.loss_history``. At steps i > min_steps
        with i = 1 (mod early_stop_check), it asks ``early_stopping`` whether to stop.
        """
        for i in range(steps):
            optimizer.zero_grad()
            loss = self.Loss(positions())
            loss.backward()
            self.loss_history.append(float(loss.item()))
            optimizer.step()
            # Written as (i - 1) % k == 0 so that k == 1 checks every step
            # (i % 1 == 1 is never true).
            if i > min_steps and (i - 1) % early_stop_check == 0:
                if early_stopping(self.loss_history, stop_delta_ratio=stop_delta_ratio, **stop_kws):
                    print('\n===========\nstopping at step ', i)
                    break

    @time_it
    def train_gcn(self,steps = 100, early_stop_check = 100, min_steps=100, stop_delta_ratio = 5e-3, **stop_kws):
        """Optimize latent + layer parameters through ``forward`` (GCN phase)."""
        self._optimize(self.optim, self, steps, early_stop_check, min_steps,
                       stop_delta_ratio, **stop_kws)

    @time_it
    def train_fine(self,steps = 100, early_stop_check = 100, min_steps=100, stop_delta_ratio = 5e-3, **stop_kws):
        """Optimize ``fine_pos`` directly (FDL phase)."""
        self._optimize(self.optim_fine, lambda: self.fine_pos, steps, early_stop_check,
                       min_steps, stop_delta_ratio, **stop_kws)

    def get_node_pos(self):
        """Return the fine-tuned (n, feat_dims[-1]) position parameter."""
        return self.fine_pos

    def save_layout(self,save_dir='./', save_name = 'nodes', delimiter=','):
        """Save node positions to ``<save_dir>/<save_name>.csv`` (directory created if needed)."""
        pos = V(self.get_node_pos())
        os.makedirs(save_dir, exist_ok=True)
        savetxt(os.path.join(save_dir, save_name+'.csv'), pos, delimiter=delimiter)

        # TBA: create edgelist if not made
        # savetxt(os.path.join(save_dir, 'edges.csv'), self.edgelist, delimiter=delimiter)
            
            

def early_stopping(metric_list,
            small_window = 32,
            big_window = 1000,
            stop_delta_ratio = 1e-3, verbose=True):
    """Decide whether a training curve has plateaued.

    Let ``last`` be the mean of the final ``small_window`` values. The function compares:
      * the recent change: |last - mean of the ``small_window`` values before those|;
      * the long-range change: |last - mean of the first ``small_window`` values of
        the final ``big_window`` entries| (or of the whole list if it is shorter).
    It stops when recent / long-range < ``stop_delta_ratio``. A completely flat
    curve (both changes zero) counts as converged.

    Args:
        metric_list: sequence of scalar metric values (e.g. loss history).
        small_window: averaging window length.
        big_window: look-back span for the long-range change (at least 2*small_window).
        stop_delta_ratio: stopping threshold on the change ratio.
        verbose: print the current ratio (carriage-return, same line).

    Returns:
        bool; False while fewer than ``2 * small_window`` values are available.
    """
    if len(metric_list) < 2*small_window:
        return False
    # Check whether the change within the small window is small relative to the
    # change across the big window.
    big_window = max(big_window, 2*small_window)
    last = mean(metric_list[-small_window:])
    dl_small = abs(last - mean(metric_list[-2*small_window:-small_window]))
    idx = max(0, len(metric_list)-big_window)
    dl_big = abs(last - mean(metric_list[idx:idx+small_window]))
    if dl_big > 0:
        ratio = dl_small / dl_big
    else:
        # No long-range change: a flat curve has converged; otherwise keep going.
        ratio = 0.0 if dl_small == 0 else float('inf')
    if verbose:
        print(f'step: {len(metric_list)}, Loss change ratio: {ratio:.3g}', end='\r')
    return bool(ratio < stop_delta_ratio)

def plot_layout(res, dims=[0,1], edges=True, node_kws={}, edg_kws=dict(c='k',lw=.5), ax=None):
    """Scatter-plot a layout's node positions and optionally draw its edges.

    Args:
        res: object exposing ``get_node_pos()`` (an (n, d) tensor) and an adjacency
            ``A``; an edge i-j is drawn for every non-zero ``A[i, j]``.
        dims: the two coordinate axes to plot.
        edges: draw edges when True.
        node_kws: extra keyword arguments for the node ``scatter`` call.
        edg_kws: extra keyword arguments for each edge ``plot`` call.
        ax: optional axes object. When given, its ``scatter``/``plot`` methods are
            used instead of the module-level ``scatter``/``plot`` names.

    NOTE(review): ``scatter`` and ``plot`` are not imported anywhere in this notebook.
    They presumably come from ``%pylab`` or ``from matplotlib.pyplot import *``;
    pass ``ax`` (or import them) to avoid a NameError.
    """
    draw_scatter = ax.scatter if ax is not None else scatter
    draw_line = ax.plot if ax is not None else plot
    # Convert to NumPy once instead of detaching/copying the tensor for every edge.
    pos = V(res.get_node_pos())
    draw_scatter(*pos.T[dims], zorder = 1000, **node_kws)
    if edges:
        rows, cols = torch.where(res.A)
        for i, j in zip(rows.tolist(), cols.tolist()):
            draw_line(*pos[[i, j]].T[dims], **edg_kws)

In [3]:
def read_pkl_file(filename):
    """Load a networkx graph stored as a pickle file.

    Uses ``nx.read_gpickle`` when the installed networkx provides it. It was
    removed in networkx 3.0; in that case the file is unpickled directly.

    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    """
    read_gpickle = getattr(nx, 'read_gpickle', None)
    if read_gpickle is not None:
        return read_gpickle(filename)
    import pickle
    with open(filename, 'rb') as fh:
        return pickle.load(fh)

In [4]:
# Space dimensions
DIMENSIONS = 3  # dimensionality of the space the layout is embedded in

MAX_GCN_STEPS = 2_000   # step budget for the GCN (re-parametrized) phase; keep small
MAX_FDL_STEPS = 20_000  # step budget for direct force-directed fine-tuning; keep large

# Early-stopping thresholds on the loss-change ratio (lower => training runs longer).
GCN_STOP_THRESHOLD = 5e-3
FDL_STOP_THRESHOLD = 2e-3


In [5]:
data_path = "graphs/erdos_renyi/"
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# "first 10 files" subset is the same on every machine and every run.
files = sorted(os.listdir(data_path))[:10]

# Number of independent layout runs per graph.
e = 3

In [6]:
import shutil

save_dir = "layouts/"
# NOTE: this deletes any layouts saved by a previous run of this cell.
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir)
results = []

# Set True to also run the plain force-directed baseline (gcn_steps=0) for every
# graph/run. Its layouts are saved to save_dir with the prefix 'FDL_short_'.
RUN_FDL_BASELINE = False


def graph_to_tensors(g):
    """Return (adjacency, edge_index) for a networkx graph.

    adjacency is a dense float (n, n) tensor with A[k, i] = 1 for every entry
    g.adj[k][i]. edge_index is the matching (2, E) long tensor. Nodes must be
    labelled 0..n-1 because they are used directly as indices.
    """
    n = len(g.nodes())
    assert set(g.nodes()) == set(range(n)), "expected nodes labelled 0..n-1"
    edg = [[k, i] for k, nbrs in g.adj.items() for i in nbrs]
    # reshape keeps a (2, 0) shape for graphs without edges.
    edge_index = torch.tensor(edg, dtype=torch.long).reshape(-1, 2).T
    adjacency = torch.zeros(n, n)
    adjacency[edge_index[0], edge_index[1]] = 1
    return adjacency, edge_index


def run_layout(label, adjacency, edge_index, gcn_class, gcn_steps, save_name):
    """Train one ResGCN layout, append [label, seconds, final loss] to results, save it."""
    layout = ResGCN(feat_dims=[100, 100, DIMENSIONS], A=adjacency, edge_index=edge_index,
                    lr=1e-1, GCN_class=gcn_class)
    start_time = time.time()
    layout.train(gcn_steps=gcn_steps, fdl_steps=MAX_FDL_STEPS,
                 early_stop_check_steps=100,
                 min_steps=200,
                 gcn_stop_threshold=GCN_STOP_THRESHOLD,
                 fdl_stop_threshold=FDL_STOP_THRESHOLD,
                 )
    # loss_history is empty if training was interrupted before the first step.
    final_loss = layout.loss_history[-1] if layout.loss_history else float('nan')
    results.append([label, time.time() - start_time, final_loss])
    layout.save_layout(save_dir=save_dir, save_name=save_name)
    return layout


for f in files:
    g = read_pkl_file(os.path.join(data_path, f))
    a, edge_index = graph_to_tensors(g)

    for run in range(e):
        if RUN_FDL_BASELINE:
            run_layout('FDL', a, edge_index, 'GCNConv', gcn_steps=0,
                       save_name='FDL_short_' + f + str(run))
        layout = run_layout('GraphNet', a, edge_index, 'GraphNet', gcn_steps=MAX_GCN_STEPS,
                            save_name='GraphNet_short_' + f + str(run))



cpu
latent radius 1.07


Traceback (most recent call last):
  File "_pydevd_bundle/pydevd_cython.pyx", line 1078, in _pydevd_bundle.pydevd_cython.PyDBFrame.trace_dispatch
  File "_pydevd_bundle/pydevd_cython.pyx", line 297, in _pydevd_bundle.pydevd_cython.PyDBFrame.do_wait_suspend
  File "d:\anaconda3\envs\pdm\lib\site-packages\debugpy\_vendored\pydevd\pydevd.py", line 1976, in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
  File "d:\anaconda3\envs\pdm\lib\site-packages\debugpy\_vendored\pydevd\pydevd.py", line 2011, in _do_wait_suspend
    time.sleep(0.01)
KeyboardInterrupt


KeyboardInterrupt: 

In [7]:
# Scratch cell. NOTE(review): this rebinds `a`, which the layout loop above used
# for the adjacency matrix; remove before sharing the notebook.
a = [1, 3, 2]
for i in [0, 1, 2, 3, 4]:
    print(i)

0
1
2
3
4
