In [23]:
import DeviceDir

# DIR, RESULTS_DIR = DeviceDir.get_directory()
# device, NUM_PROCESSORS = DeviceDir.get_device()

In [24]:
import networkx as nx
import matplotlib.pyplot as plt
import torch_geometric
from  torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops
from ipynb.fs.full.Dataset import get_data
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigs
import numpy as np
from scipy import linalg
from sklearn.neighbors import NearestNeighbors
import os
import pickle
import torch.nn as nn
import copy

In [25]:
def get_pos(edge_index = None, y = None, n = None):
    """Return spring-layout node positions for the graph described by ``edge_index``.

    Args:
        edge_index: tensor whose transpose lists one ``[u, v]`` pair per edge
            (it is converted with ``edge_index.t().tolist()``).
        y: unused; accepted so the signature mirrors ``draw_graph``.
        n: number of nodes; nodes ``0 .. n-1`` are added even if isolated.

    Returns:
        Whatever ``nx.spring_layout`` returns for the constructed graph.
    """
    graph = nx.Graph()
    # Register every node id first so isolated nodes also get a position.
    graph.add_nodes_from(range(n))
    graph.add_edges_from(edge_index.t().tolist())
    return nx.spring_layout(graph)

def draw_graph(edge_index = None, y = None, n = None, pos = None, G = None):
    """Draw a graph with nodes coloured by their label.

    Args:
        edge_index: tensor whose transpose lists one ``[u, v]`` pair per edge.
            Only read when ``G`` is None.
        y: per-node labels supporting ``tolist()``; used as node colours.
            When None the graph is drawn without label colours.
        n: number of nodes; only read when ``G`` is None.
        pos: precomputed layout. When None a spring layout is computed.
        G: an existing NetworkX graph to draw. When None a graph is built
            from ``edge_index`` and ``n``.
    """
    plt.figure(figsize=(5, 5))

    if G is None:
        # Build the NetworkX graph from the edge list.
        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(edge_index.t().tolist())

    if pos is None:
        pos = nx.spring_layout(G)

    if y is None:
        # No labels supplied: fall back to the default node colour.
        nx.draw(G, pos, with_labels=True)
    else:
        # One colour per label value 0..max(y). plt.get_cmap is used because
        # matplotlib.cm.get_cmap no longer exists in recent matplotlib releases.
        num_classes = int(max(y)) + 1
        nx.draw(G, pos, with_labels=True, node_color=y.tolist(),
                cmap=plt.get_cmap('cool', num_classes))
    plt.show()

# draw_graph(data.edge_index, data.y, data.num_nodes)

In [26]:
import torch
from torch_geometric.data import Data
from tqdm import tqdm
import torch.nn.functional as F
from torch_geometric.utils.convert import to_networkx, from_networkx

In [27]:
# x = torch.Tensor([[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]])
# y = torch.LongTensor([0,0,0, 1, 1, 1, 1])
# edge_index = torch.LongTensor([[1,2],[1,4],[1,5],[2,1],[3,6],[3,7],[4,5],[4,1],[4,6],[4,7],[5,1],[5,4],[5,6],[6,3],[6,4],[6,5],[6,7],[7,3],[7,4],[7,6]]).T
# edge_index = edge_index-1
# data = Data(x=x, y=y, edge_index = edge_index)

In [28]:
# data, dataset = get_data('karate')

In [29]:
class EdgeLoader(torch.utils.data.DataLoader):
    """DataLoader that serves one similarity/distance value per edge.

    The loader registers itself as its own dataset: ``__len__`` reports the
    number of edges and ``__getitem__`` scores a single edge from the features
    of its two endpoints.
    """

    def __init__(self, data, batch_size: int, metric='cosine', log: bool = True, **kwargs):
        """Store the graph and configure the underlying DataLoader.

        Args:
            data: object exposing ``edge_index``, ``x`` and ``num_edges``;
                ``edge_index`` must not live on a CUDA device.
            batch_size: number of edges per batch.
            metric: ``'euclidean'`` selects the L2 pairwise distance; any
                other value selects cosine similarity.
            log: stored on the instance; not otherwise used by this class.
            **kwargs: forwarded unchanged to ``torch.utils.data.DataLoader``.
        """
        assert not data.edge_index.is_cuda
        self.data = data
        self.metric = metric
        self.log = log
        self.E = data.num_edges
        self.__batch_size__ = batch_size
        # Passing ``self`` as the dataset makes the sampler index edges
        # through __len__/__getitem__ below.
        super().__init__(self, batch_size=batch_size, **kwargs)

    def __len__(self):
        """Number of edges (the dataset size seen by the sampler)."""
        return self.E

    def __getitem__(self, idx):
        """Score edge ``idx`` from the feature rows of its two endpoints."""
        src, dst = self.data.edge_index[:, idx]
        # Keep a leading axis so both operands have a batch dimension.
        feat_src = self.data.x[src].unsqueeze(0)
        feat_dst = self.data.x[dst].unsqueeze(0)

        if self.metric != 'euclidean':
            return F.cosine_similarity(feat_src, feat_dst, dim=-1)
        return F.pairwise_distance(feat_src, feat_dst, p=2)

    
def compute_edge_weight(data, metric='cosine', log=True):
    """Compute one weight per edge by streaming the edges through ``EdgeLoader``.

    Args:
        data: graph object exposing ``num_edges`` plus whatever ``EdgeLoader``
            reads from it.
        metric: forwarded to ``EdgeLoader`` (``'euclidean'`` or cosine).
        log: when True a tqdm progress bar labelled ``Edges`` is shown.

    Returns:
        A 1-D tensor with the per-edge values in loader order (shuffling is
        disabled). An empty tensor is returned when the loader yields nothing,
        e.g. for a graph without edges.
    """
    # Worker processes are only used for very large edge sets.
    num_workers = 8 if data.num_edges > 1000000 else 0
    loader = EdgeLoader(data, batch_size=100000, metric=metric, log=log,
                        num_workers=num_workers, shuffle=False, drop_last=False)

    pbar = None
    if log:
        pbar = tqdm(total=data.num_edges)
        pbar.set_description('Edges')

    chunks = []
    try:
        for batch in loader:
            chunks.append(batch.view(-1))
            if pbar is not None:
                pbar.update(batch.shape[0])
    finally:
        # Close the bar even if a batch fails so the output is not left dangling.
        if pbar is not None:
            pbar.close()

    if not chunks:
        # torch.cat rejects an empty list, so return an explicit empty result.
        return torch.empty(0)

    return torch.cat(chunks)

In [30]:
# data.edge_weight = compute_edge_weight(data, metric = 'euclidean')
# print(data.edge_weight)
# data

In [31]:
# pos = get_pos(data.edge_index, data.y, data.num_nodes)
# draw_graph(data.edge_index, data.y, data.num_nodes, pos=pos)

In [32]:
def sparse_graph(data, K=2, minimum=False, draw=False, log=True):
    """Peel up to ``K`` edge-disjoint spanning trees off the graph in ``data``.

    Each round extracts a spanning tree from the current graph, records it
    together with its total ``e_weight``, then removes the tree's edges before
    the next round.

    Args:
        data: graph object convertible with ``to_networkx`` and carrying an
            ``e_weight`` edge attribute.
        K: maximum number of trees to extract.
        minimum: ``True`` -> minimum spanning tree, ``False`` -> maximum
            spanning tree, any other value (e.g. ``'rand'``) -> random
            spanning tree sampled with ``e_weight`` as multiplicative weight.
        draw: when True every extracted tree is plotted on a shared layout.
        log: when True progress is printed for every round.

    Returns:
        ``(graph_collections, graph_weights)``: the extracted trees and their
        summed ``e_weight`` values, in extraction order. Fewer than ``K``
        entries are returned if random sampling fails.
    """
    # One layout of the full graph is shared by every plot so trees are comparable.
    pos = get_pos(data.edge_index, data.y, data.num_nodes) if draw else None

    G = to_networkx(data, edge_attrs=['e_weight'], to_undirected=True)
    graph_collections = []
    graph_weights = []

    for i in range(K):

        if log: print("MST: ",i, end=' ... ')

        # Edges carry 'e_weight', not networkx's default 'weight' key, so the
        # key must be named explicitly or every edge is treated as weight 1.
        # '==' (not truthiness) keeps non-bool modes such as 'rand' in the else arm.
        if minimum == True:
            mst_G = nx.minimum_spanning_tree(G, weight='e_weight')
        elif minimum == False:
            mst_G = nx.maximum_spanning_tree(G, weight='e_weight')
        else:
            try:
                mst_G = nx.random_spanning_tree(G, weight='e_weight', multiplicative=True)
            except Exception:
                # Sampling fails once the remaining graph has no spanning tree.
                print('No connected spanning tree')
                break

        graph_collections.append(mst_G)
        graph_weights.append(mst_G.size(weight='e_weight'))

        if log: print("DONE: ",i)

        # Nothing is extracted after the last round, so skip the edge removal.
        if i == K - 1:
            break

        # Remove the tree's edges; nx.difference drops edge attributes, so the
        # weights are copied back from the previous graph.
        dG = nx.difference(G, mst_G)
        for u, v in dG.edges():
            dG[u][v]['e_weight'] = G[u][v]['e_weight']
        G = dG

    if draw:
        for i, g in enumerate(graph_collections):
            print("MST: ",i)
            draw_graph(data.edge_index, data.y, data.num_nodes, pos=pos, G=g)

    return graph_collections, graph_weights


def get_random_sparse(data, sel_K = 2, max_K = 5, metric = 'cosine', minimum=False, draw=False, log=True, dataset_name='default', recompute=True):
    """Build a sparsified graph from a random subset of cached spanning trees.

    Up to ``max_K`` spanning trees are computed with ``sparse_graph`` (or
    loaded from two pickle files derived from ``dataset_name``, ``max_K``,
    ``metric`` and ``minimum``). ``sel_K`` of them are then sampled without
    replacement and merged into one graph.

    Args:
        data: graph object; ``data.e_weight`` is overwritten when the trees
            are recomputed.
        sel_K: number of trees to sample (capped at the number available).
        max_K: number of trees to extract when computing.
        metric: edge-weight metric forwarded to ``compute_edge_weight``.
        minimum: tree mode forwarded to ``sparse_graph``. When it equals
            ``False`` trees are sampled proportionally to their stored
            weights; otherwise by linearly decreasing rank weights.
        draw: when True the intermediate trees and the composite are plotted.
        log: when True progress and the selection are printed.
        dataset_name: prefix (may include a directory) of the cache files.
        recompute: when true the trees are recomputed and the cache rewritten;
            otherwise the cache is used if both files exist.

    Returns:
        ``(data_updated, sel_mst_index)``: the merged graph converted with
        ``from_networkx`` with remaining self loops added, and the indices of
        the sampled trees.

    Raises:
        ValueError: if no spanning tree is available to sample from.
    """
    suffix = str(max_K) + metric + str(minimum) + ".pkl"
    graph_filename = dataset_name + "_graphs" + suffix
    weight_filename = dataset_name + "_weights" + suffix

    # The load path needs both files; a missing weights file forces a recompute.
    cache_ready = os.path.exists(graph_filename) and os.path.exists(weight_filename)

    if recompute or not cache_ready:

        if log:print('Computing: ',graph_filename)

        data.e_weight = compute_edge_weight(data, metric = metric, log=log)

        graph_collections, graph_weights = sparse_graph(data, K=max_K, minimum=minimum, draw=draw, log=log)
        # Floor every weight at a small rank-dependent positive value so the
        # sampling probabilities below are never zero.
        graph_weights = [max((len(graph_weights)-i)*1e-3,w) for i,w in enumerate(graph_weights)]

        if log: print("Subgraph weights: ", graph_weights)

        with open(graph_filename, 'wb') as file_handle_graph:
            pickle.dump(graph_collections, file_handle_graph)

        with open(weight_filename, 'wb') as file_handle_weight:
            pickle.dump(graph_weights, file_handle_weight)

    else:
        if log: print('Loading..')

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # are expected to be the ones written by the branch above; do not
        # point dataset_name at untrusted locations.
        with open(graph_filename, 'rb') as file_handle_graph:
            graph_collections = pickle.load(file_handle_graph)

        with open(weight_filename, 'rb') as file_handle_weight:
            graph_weights = pickle.load(file_handle_weight)

    num_graphs = len(graph_weights)
    if num_graphs == 0:
        raise ValueError("No spanning trees available to sample from")

    # '==' (not truthiness) keeps non-bool modes such as 'rand' on the rank path.
    if not (minimum == False):
        # Ignore the stored weights and prefer earlier trees by rank.
        graph_weights = [(num_graphs - i) * 1e-3 for i in range(num_graphs)]

    probs = np.asarray(graph_weights, dtype=float)
    sel_mst_index = np.random.choice(num_graphs, size=min(sel_K, num_graphs), replace=False, p=probs / probs.sum())

    if log:
        print("MST weights: ", graph_weights)
        print("Selected MSTs: ", sel_mst_index, " Weights: ",[graph_weights[i] for i in sel_mst_index])

    selected = [graph_collections[i] for i in sel_mst_index]
    # col_G is bound in both cases so the drawing code below can always use it.
    if len(selected) == 1:
        col_G = selected[0]
    else:
        col_G = nx.compose_all(selected)
    data_updated = from_networkx(col_G)

    if draw:
        print("Composite:")
        pos = get_pos(data.edge_index, data.y, data.num_nodes)
        draw_graph(data.edge_index, data.y, data.num_nodes, pos=pos, G=col_G)

    # Edge attributes are dropped before self loops are added; e_weight then
    # holds whatever add_remaining_self_loops returns for a None edge_attr.
    data_updated.edge_attr = None

    edge_index, edge_attr = add_remaining_self_loops(edge_index = data_updated.edge_index, edge_attr = data_updated.edge_attr, num_nodes = data.num_nodes)

    data_updated.e_weight = edge_attr
    data_updated.edge_index = edge_index

    return data_updated, sel_mst_index

In [33]:
# data_updated, sel_mst_index = get_random_sparse(copy.copy(data), sel_K = 2, max_K = 3, metric = 'cosine', minimum='rand', draw=True, log=True, dataset_name='default', recompute=True)
# data_updated

In [34]:
def fun():
    """Run one sparsification of the karate dataset and return the result.

    The dataset is loaded with ``get_data``, a deep copy is passed to
    ``get_random_sparse`` (so the loaded object is not modified), and the
    resulting graph plus the selected tree indices are printed.

    Returns:
        The sparsified data object produced by ``get_random_sparse``.
    """
    DIR, RESULTS_DIR = DeviceDir.get_directory()
    dataset_name = 'karate'
    data, dataset = get_data(dataset_name, log=False)

    # The cache prefix lives under DIR; recompute=False reuses existing trees.
    sparsify_options = dict(
        sel_K=2,
        max_K=4,
        metric='euclidean',
        minimum=False,
        draw=False,
        log=True,
        dataset_name=DIR + dataset_name,
        recompute=False,
    )
    data_updated, sel_mst_index = get_random_sparse(copy.deepcopy(data), **sparsify_options)

    print(data_updated, sel_mst_index)

    return data_updated

if __name__ == '__main__':    

    # Smoke run: call fun() ten times. fun() passes recompute=False, so the
    # spanning trees are only computed when the cache is missing; each call
    # still draws a fresh random selection of trees.
    for i in range(10):
        fun()

Computing:  /scratch/gilbreth/das90/Dataset/karate_graphs4euclideanFalse.pkl


Edges: 100%|██████████| 156/156 [00:00<00:00, 23800.93it/s]

MST:  0 ... DONE:  0
MST:  1 ... DONE:  1
MST:  2 ... DONE:  2
MST:  3 ... DONE:  3
Subgraph weights:  [46.66904675960541, 42.426406145095825, 15.55634891986847, 5.656854152679443]
MST weights:  [46.66904675960541, 42.426406145095825, 15.55634891986847, 5.656854152679443]
Selected MSTs:  [0 1]  Weights:  [46.66904675960541, 42.426406145095825]
Data(edge_index=[2, 160], num_nodes=34) [0 1]
Loading..
MST weights:  [46.66904675960541, 42.426406145095825, 15.55634891986847, 5.656854152679443]
Selected MSTs:  [1 0]  Weights:  [42.426406145095825, 46.66904675960541]
Data(edge_index=[2, 160], num_nodes=34) [1 0]
Loading..
MST weights:  [46.66904675960541, 42.426406145095825, 15.55634891986847, 5.656854152679443]
Selected MSTs:  [0 1]  Weights:  [46.66904675960541, 42.426406145095825]
Data(edge_index=[2, 160], num_nodes=34) [0 1]
Loading..
MST weights:  [46.66904675960541, 42.426406145095825, 15.55634891986847, 5.656854152679443]
Selected MSTs:  [0 1]  Weights:  [46.66904675960541, 42.42640614


