# Attribute-Guided Sampler

In [34]:
#https://github.com/pyg-team/pytorch_geometric/issues/1961
#https://github.com/rusty1s/pytorch_sparse/blob/99735df9ac54b1ae8d46ddf5360da58fd0eeef0c/torch_sparse/sample.py

In [35]:
#importance pooling: https://arxiv.org/pdf/1806.01973.pdf
#graphsaint sampling: https://arxiv.org/pdf/1907.04931.pdf

In [36]:
#As it turned out, interactive shells (like Jupyter) cannot handle CPU multiprocessing well, so check which medium the code is running in.
#We write the code in Jupyter for understanding purposes, but the final execution will be in a shell.
from ipynb.fs.full.Utils import isnotebook
from ipynb.fs.full.Dataset import get_data
from torch_geometric.utils import degree

In [37]:
import os.path as osp
from typing import Optional

import torch
from torch_sparse import SparseTensor
from tqdm import tqdm
import math
import time
import numpy as np
import getpy as gp

import random
random.seed(12345)
import numpy as np
np.random.seed(12345)

In [38]:
from collections.abc import Sequence
from typing import Any, Callable, Dict, List, Optional, Union

from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader.base import BaseDataLoader
from torch_geometric.loader.utils import (edge_type_to_str, filter_data,
                                          filter_hetero_data, to_csc,
                                          to_hetero_csc)
from torch_geometric.typing import EdgeType, InputNodes

NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]

In [39]:
from ipynb.fs.full.SubmodularWeights import SubModularWeightFacilityFaster
from ipynb.fs.full.KNNWeights import KNNWeight
#from KNNWeights import KNNWeight
#from ipynb.fs.full.PretrainedLink import LinkPred, LinkNN, LinkSub
from ipynb.fs.full.RandomSparse import RandomSparse
from ipynb.fs.full.PretrainedLinkFast import get_link_weight,  LinkNN, LinkSub

In [40]:
from collections import defaultdict
from random import choices
import random
import torch
import numpy as np

def sample_with_weights(colptr, row, input_node, num_neighbors, replace, directed, weights):
    """Layer-wise weighted neighbor sampling on a graph stored in CSC form.

    Args:
        colptr: Column pointer array; the neighbors of node ``w`` are
            ``row[colptr[w]:colptr[w + 1]]``.
        row: Neighbor ids, indexed by the offsets derived from ``colptr``.
        input_node: Iterable of seed node ids. It is copied, so the caller's
            object is left untouched.
        num_neighbors: One entry per hop giving how many neighbors to draw
            for every node of the current frontier. A negative value keeps
            all neighbors.
        replace: Draw neighbors with replacement (``random.choices``) instead
            of without replacement (``np.random.choice``).
        directed: If true, return only the sampled edges. Otherwise return
            every stored edge whose two end points were both sampled.
        weights: Per-edge weights aligned with ``row``; must support slicing
            and element-wise division (e.g. a NumPy array).

    Returns:
        Tuple ``(samples, rows, cols, edges, sampled_weights)`` of tensors:
        the sampled node ids (seeds first), the local row/col index of every
        returned edge, its offset into ``row`` and its weight. The last three
        edge-level tensors always have the same length.
    """
    # Work on a private copy: the list grows while sampling.
    samples = list(input_node)

    # Unknown nodes map to the next free local id (the current length of `samples`).
    to_local_node = defaultdict(lambda: len(samples))
    for i, v in enumerate(samples):
        to_local_node[v] = i

    rows, cols, edges, sampled_weights = [], [], [], []

    begin, end = 0, len(samples)
    for num_samples in num_neighbors:
        for i in range(begin, end):
            w = samples[i]
            col_start = colptr[w]
            col_end = colptr[w + 1]
            col_count = col_end - col_start

            if col_count == 0:
                continue

            if (num_samples < 0) or (not replace and num_samples >= col_count):
                # Keep the whole neighborhood.
                sampled_indices = range(col_start, col_end)
            else:
                candidates = np.arange(col_start, col_end)
                local_weights = weights[col_start:col_end]
                probs = local_weights / np.sum(local_weights)
                if replace:
                    sampled_indices = choices(candidates, weights=probs, k=num_samples)
                else:
                    sampled_indices = np.random.choice(candidates, size=num_samples,
                                                       replace=False, p=probs)

            for offset in sampled_indices:
                # Plain ints keep dictionary keys and tensor construction
                # independent of the array type used for `row`.
                v = int(row[offset])
                res = to_local_node[v]
                if res == len(samples):
                    samples.append(v)
                if directed:
                    cols.append(i)
                    rows.append(res)
                    edges.append(int(offset))
                    # Recorded together with the edge so both stay aligned.
                    sampled_weights.append(float(weights[offset]))

        # The nodes discovered in this hop form the next frontier.
        begin, end = end, len(samples)

    if not directed:
        # Induced sub-graph: every stored edge between two sampled nodes.
        local_node_indices = {v: i for i, v in enumerate(samples)}
        for i, w in enumerate(samples):
            for offset in range(colptr[w], colptr[w + 1]):
                res = local_node_indices.get(int(row[offset]))
                if res is not None:
                    rows.append(res)
                    cols.append(i)
                    edges.append(offset)
                    sampled_weights.append(float(weights[offset]))

    return (
        torch.tensor(samples, dtype=torch.int64),
        torch.tensor(rows, dtype=torch.int64),
        torch.tensor(cols, dtype=torch.int64),
        torch.tensor(edges, dtype=torch.int64),
        torch.tensor(sampled_weights, dtype=torch.float32)
    )

In [41]:
# import DeviceDir
# DIR, RESULTS_DIR = DeviceDir.get_directory()
# device, NUM_PROCESSORS = DeviceDir.get_device()

# n=7
# x = torch.Tensor([[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]])
# y = torch.LongTensor([0,0,0, 1, 1, 1, 1])
# edge_index = torch.LongTensor([[1,2],[1,4],[1,5],[2,1],[3,6],[3,7],[4,5],[4,1],[4,6],[4,7],[5,1],
#                                [5,4],[5,6],[6,3],[6,4],[6,5],[6,7],[7,3],[7,4],[7,6]]).T
# edge_index = edge_index-1
# mask = torch.zeros(n, dtype=torch.bool)
# mask[[0,1,4,5]] = True
# data = Data(x = x, y = y, edge_index = edge_index, train_mask = mask, test_mask = ~mask, val_mask = ~mask)    
# print(data)

# #data, dataset = get_data('Cora', DIR=DIR, log = False) 

# (row, col) = data.edge_index
# size = data.size()
# perm = (col * size[0]).add_(row).argsort()
# colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
# row = row[perm]

# weights = (1. / degree(col, data.num_nodes)[col]) # Norm by in-degree.
# weights = weights[perm]
# index = torch.LongTensor([0,1])
# num_neighbors= [20, 10]
# # sample_with_weights_getpy(colptr,row,index,num_neighbors,False, True, weights)

In [42]:
import sys
#import torch
#sys.path.append("/home/sferdou/CPPSamplerNew/build/src")

#sys.path.append("/home/das90/GNNcodes/CVE2020/GNN-NC/Graph-Sparsification/CPPsamplerPy/build/src")

#import sampling_module

In [43]:
# result = sampling_module.sample(
#     colptr,
#     row,
#     index,
#     num_neighbors
# )
# result

In [44]:
import os

def save_weight(method,save_dir,weights):
    """Store ``weights`` with ``torch.save`` under ``save_dir + method + ".pt"``.

    ``save_dir`` is used as a plain string prefix, so it has to end with a
    path separator when it is meant to be a directory. Missing parent
    directories are created.
    """
    filename = save_dir + method + ".pt"

    directory = osp.dirname(filename)
    # An empty directory component means "current directory": nothing to create.
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save(weights, filename)
    
def load_weight(method, save_dir):
    """Return the object stored at ``save_dir + method + ".pt"``.

    Yields ``None`` when no such file exists; otherwise the result of
    ``torch.load`` on that path.
    """
    cache_path = save_dir + method + ".pt"
    return torch.load(cache_path) if osp.exists(cache_path) else None
    
def is_compute(kwargs, method):
    """Decide whether the weights for ``method`` have to be computed.

    Returns ``(compute, w)``. When ``kwargs['recompute']`` equals ``True`` no
    cache lookup happens and the result is ``(True, None)``. Otherwise ``w``
    is whatever ``load_weight`` finds under ``kwargs['save_dir']`` and
    ``compute`` is true exactly when nothing was found.
    """
    if kwargs['recompute'] == True:
        return True, None

    cached = load_weight(method, kwargs['save_dir'])
    return cached is None, cached

In [45]:
class CustomNeighborSampler:
    """Callable neighbor sampler producing one sub-graph per edge-weight scheme.

    For every entry of ``weight_func`` the constructor obtains one per-edge
    weight vector. Calling the sampler with a batch of seed nodes returns a
    list with one ``(node, row, col, edge, batch_size)`` tuple per scheme:
    ``'random'`` uses the ``torch_sparse`` ``neighbor_sample`` op, every
    other scheme uses :meth:`weighted_sample`.

    Keyword arguments read by the constructor:
        log: Print progress messages when truthy (required).
        weight_func: Sequence of scheme names out of ``'knn'``,
            ``'submodular'``, ``'fastlink'``, ``'link-nn'``, ``'link-sub'``,
            ``'apricot'`` and ``'random'`` (required). An empty sequence
            falls back to a single ``'random'`` scheme.
        params: Per-scheme options, e.g. ``{'knn': {'metric': 'cosine'}}``
            (required).
        recompute, save_dir: Forwarded to ``is_compute`` / ``save_weight``
            for the schemes whose weights are cached on disk.
        num_workers: Forwarded to ``get_link_weight`` when ``'fastlink'``
            weights have to be computed.

    If ``data`` already carries a ``weights`` attribute it is used as is
    (one row per scheme, in ``edge_index`` order) and nothing is computed.
    """

    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        replace: bool = False,
        directed: bool = True,
        input_node_type: Optional[str] = None,
        **kwargs,

    ):
        self.data_cls = data.__class__
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.directed = directed

        self.N = N = data.num_nodes
        self.E = E = data.num_edges
        self.data = data

        # Adjacency whose stored values are the edge positions in
        # `data.edge_index`; used by `weighted_sample` for row slicing.
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        log = kwargs['log']
        self.log = log

        weight_funcs = kwargs['weight_func']
        params = kwargs['params']
        self.weight_funcs = weight_funcs

        if 'weights' not in data:
            if len(weight_funcs) > 0:
                weights = [self._edge_weights_for(method, params, kwargs)
                           for method in weight_funcs]
            else:
                # No scheme requested: fall back to a single in-degree
                # normalised weight vector handled like the 'random' scheme.
                row, col = data.edge_index
                data.weight = (1. / degree(col, data.num_nodes)[col]).to(data.edge_index.device)
                weights = [data.weight]
                self.weight_funcs = ['random']

            self.weights = torch.stack(weights).to(data.edge_index.device)
            data.weights = self.weights
        else:
            self.weights = data.weights

        if isinstance(data, Data):
            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(data, device='cpu')
            self.colptr_npy = self.colptr.numpy()
            self.row_npy = self.row.numpy()

            assert isinstance(num_neighbors, (list, tuple))
        else:
            raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

        self.sample_fn = torch.ops.torch_sparse.neighbor_sample

        # Reorder every weight vector with the CSC permutation. The result is
        # a new object, so `data.weights` keeps its `edge_index` order and a
        # second sampler built on the same `data` permutes exactly once too.
        if isinstance(self.weights, Tensor):
            self.weights = self.weights[:, self.perm]
        else:
            self.weights = [w[self.perm] for w in self.weights]
        self.weights_npy = [self.weights[i].numpy() for i in range(len(self.weights))]

    def _compute_or_load(self, name, kwargs, compute_fn):
        """Return the weights cached under ``name``, computing and saving them if needed."""
        compute, w = is_compute(kwargs, name)
        if compute:
            w = compute_fn()
            if self.log:
                print("saving weights ", name)
            save_weight(name, kwargs['save_dir'], w)
        else:
            if self.log:
                print("Loading weights ", name)
        return w

    def _edge_weights_for(self, method, params, kwargs):
        """Return the per-edge weight vector of one weighting scheme."""
        data, log = self.data, self.log

        if method == 'knn':
            metric = params[method]['metric']
            return self._compute_or_load(
                method + metric, kwargs,
                lambda: KNNWeight(data, metric=metric, log=log).compute_weights())

        if method == 'submodular':
            metric = params[method]['metric']
            return self._compute_or_load(
                method + metric, kwargs,
                lambda: SubModularWeightFacilityFaster(data, metric=metric, log=log).compute_weights())

        if method == 'fastlink':
            return self._compute_or_load(
                method, kwargs,
                lambda: get_link_weight(data, selfloop=True, log=log, worker=kwargs['num_workers']))

        if method == 'link-nn':
            # min favors similar nodes, max dissimilar ones
            value = params[method]['value']
            return self._compute_or_load(
                method + value, kwargs,
                lambda: LinkNN(data, value=value, log=log).compute_weights())

        if method == 'link-sub':
            # min favors similar nodes, max dissimilar ones
            value = params[method]['value']
            return self._compute_or_load(
                method + value, kwargs,
                lambda: LinkSub(data, value=value, selfloop=True, log=log).compute_weights())

        if method == 'apricot':
            from ipynb.fs.full.SubmodularWeightsApricot import SubModularWeightApricot

            sub_func = params[method]['sub_func']
            metric = params[method]['metric']
            return self._compute_or_load(
                method + sub_func + metric, kwargs,
                lambda: SubModularWeightApricot(data, metric=metric, sub_func=sub_func, log=log).compute_weights())

        if method == 'random':
            row, col = data.edge_index
            return 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

        raise NotImplementedError(f'Unknown weight function: {method!r}')

    def weighted_sample(self, index: Union[List[int], Tensor], weight_index):
        """Sample a multi-hop neighborhood using row slices of ``self.adj``.

        For every frontier node ``u`` the edges with ``edge_index[0] == u``
        are looked up. With ``k == -1`` or at most ``k`` such edges all of
        them are kept; otherwise ``k`` distinct ones are drawn with
        probability proportional to ``self.weights[weight_index]``.
        ``self.replace`` and ``self.directed`` are not consulted here.

        Returns:
            ``(node, row, col, edge, batch_size)``: the sampled node ids in
            first-seen order (seeds first), the end points of every sampled
            edge relabelled to positions in ``node`` (``row`` is the
            neighbor, ``col`` the frontier node), the edge positions in
            ``data.edge_index`` and the number of seeds.
        """
        if not isinstance(index, torch.LongTensor):
            index = torch.LongTensor(index)

        nodes = [index]  # seeds first so that they receive the lowest local ids
        rows, cols, edges = [], [], []

        frontier = index
        for k in self.num_neighbors:
            next_frontier = []

            for u in frontier:
                # Row slice of the adjacency: neighbors of u and the matching edge ids.
                _, neighbors, edge = self.adj[u.item(), :].coo()

                if k == -1 or k >= len(neighbors):
                    next_frontier.extend(neighbors)
                    nodes.append(neighbors)
                    rows.append(self.data.edge_index[1][edge])
                    cols.append(self.data.edge_index[0][edge])
                    edges.append(edge)
                else:
                    # NOTE(review): `self.weights` was reordered with the CSC
                    # permutation in __init__, while `edge` holds positions in
                    # `data.edge_index` - confirm this lookup is intended.
                    edge_weight = self.weights[weight_index][edge].numpy()

                    edge_weight = edge_weight / sum(edge_weight)
                    # Force the probabilities to sum to one for np.random.choice.
                    edge_weight[-1] = 1 - np.sum(edge_weight[0:-1])

                    picked = np.random.choice(len(neighbors), k, replace=False, p=edge_weight)
                    chosen = neighbors[picked]

                    next_frontier.extend(chosen)
                    nodes.append(chosen)
                    rows.append(chosen)
                    cols.append(u.repeat(len(picked)))
                    edges.append(edge[picked])

            frontier = next_frontier

        empty = torch.empty(0, dtype=torch.long)
        node = torch.cat(nodes)
        row = torch.cat(rows) if rows else empty
        col = torch.cat(cols) if cols else empty
        edge = torch.cat(edges) if edges else empty

        # Relabel global node ids to consecutive local ids in first-seen order.
        node_dict = {}
        for u in node.tolist():
            if u not in node_dict:
                node_dict[u] = len(node_dict)

        node_unique = torch.LongTensor(list(node_dict.keys()))
        row = torch.LongTensor([node_dict[v] for v in row.tolist()])
        col = torch.LongTensor([node_dict[v] for v in col.tolist()])

        return node_unique, row, col, edge, index.numel()

    def call__original(self, index: Union[List[int], Tensor]):
        """Sample with the unweighted ``torch_sparse`` ``neighbor_sample`` op.

        Returns ``(node, row, col, edge, batch_size)`` as produced by the op
        for the CSC arrays built in the constructor.
        """
        if not isinstance(index, torch.LongTensor):
            index = torch.LongTensor(index)

        if issubclass(self.data_cls, Data):

            sample_fn = torch.ops.torch_sparse.neighbor_sample
            node, row, col, edge = sample_fn(
                self.colptr,
                self.row,
                index,
                self.num_neighbors,
                self.replace,
                self.directed,
            )
            return (node, row, col, edge, index.numel())

        else:
            raise TypeError(f'NeighborLoader found invalid type: {type(self.data)}')

    def call_weighted_sample(self, index: Union[List[int], Tensor], weight_index):
        """Sample with the pure-Python ``sample_with_weights`` routine.

        Uses the NumPy views of the CSC arrays and of the weight vector
        selected by ``weight_index``; the sampled weights are discarded.
        """
        if not isinstance(index, torch.LongTensor):
            index = torch.LongTensor(index)

        if issubclass(self.data_cls, Data):
            node, row, col, edge, sampled_weight = sample_with_weights(
                self.colptr_npy,
                self.row_npy,
                index.tolist(),
                self.num_neighbors,
                self.replace,
                self.directed,
                self.weights_npy[weight_index]
            )
            return (node, row, col, edge, index.numel())

        else:
            raise TypeError(f'NeighborLoader found invalid type: {type(self.data)}')

    def call_weighted_sample_cpp(self, index: Union[List[int], Tensor], weight_index):
        """Sample with the external ``sampling_module.weighted_sample`` routine.

        NOTE(review): ``sampling_module`` is not imported anywhere in this
        notebook (its import is commented out), so calling this method raises
        ``NameError`` unless that name is made available first.
        """
        if not isinstance(index, torch.LongTensor):
            index = torch.LongTensor(index)

        if issubclass(self.data_cls, Data):
            node, row, col, edge = sampling_module.weighted_sample(
                self.colptr,
                self.row,
                index,
                self.num_neighbors,
                self.weights[weight_index],
                self.replace,
                self.directed,
            )
            return (node, row, col, edge, index.numel())

        else:
            raise TypeError(f'NeighborLoader found invalid type: {type(self.data)}')

    def __call__(self, index: Union[List[int], Tensor]):
        """Return one sampled sub-graph tuple per configured weight scheme."""
        output = []

        for i, method in enumerate(self.weight_funcs):
            if method == 'random':
                output.append(self.call__original(index))
            else:
                output.append(self.weighted_sample(index, i))  # sparse tensor based implementation
                #output.append(self.call_weighted_sample(index, i)) ## c inspired implementation
                #output.append(self.call_weighted_sample_cpp(index, i)) ## modified c installation

        return output

In [46]:
# input_node_type = get_input_node_type(data.train_mask)
# replace=False
# directed=True
# sampler = CustomNeighborSampler(data, [1,1],replace, directed,input_node_type)
# print(sampler.call__original([1,0]))

# print("-"*100)

# print(sampler.__call__([1,0]))

In [47]:
class WeightedNeighborLoader(BaseDataLoader):
    """Mini-batch loader that samples neighborhoods with edge-weight schemes.

    Seed nodes are batched by the underlying data loader and handed to a
    ``CustomNeighborSampler`` (used as ``collate_fn``), which returns one
    sampled sub-graph per weight scheme. ``transform_fn`` turns every
    sub-graph into a data object carrying ``weight`` and ``batch_size``.

    Keyword arguments ``log`` and ``weight_func`` are required. ``save_dir``
    defaults to ``'weights/'`` and ``recompute`` to ``False``. These two,
    together with ``params``, are forwarded to the sampler and removed
    before the remaining keyword arguments (``batch_size``, ``shuffle``,
    ``num_workers``, ...) reach the base class.
    """

    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        input_nodes: InputNodes = None,
        replace: bool = False,
        directed: bool = True,
        transform: Callable = None,
        neighbor_sampler: Optional[CustomNeighborSampler] = None,
        **kwargs,
    ):
        # These two are set by this loader itself.
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)

        kwargs.setdefault('save_dir', 'weights/')
        kwargs.setdefault('recompute', False)

        # Save for PyTorch Lightning:
        self.data = data
        self.num_neighbors = num_neighbors
        self.input_nodes = input_nodes
        self.replace = replace
        self.directed = directed
        self.transform = transform
        self.neighbor_sampler = neighbor_sampler
        self.log = kwargs['log']

        self.weight_funcs = kwargs['weight_func']

        if neighbor_sampler is None:
            input_node_type = get_input_node_type(input_nodes)
            self.neighbor_sampler = CustomNeighborSampler(data, num_neighbors,
                                                          replace, directed,
                                                          input_node_type, **kwargs)

        # Needed by `transform_fn` regardless of who built the sampler.
        self.weights = self.neighbor_sampler.weights

        # Sampler-only options must not reach the underlying DataLoader.
        for key in ('weight_func', 'params', 'log', 'save_dir', 'recompute'):
            kwargs.pop(key, None)

        super().__init__(get_input_node_indices(self.data, input_nodes),
                         collate_fn=self.neighbor_sampler, **kwargs)

    def transform_fn(self, out: Any) -> Union[Data, HeteroData]:
        """Convert the sampler output into mini-batch data objects.

        ``out`` holds one ``(node, row, col, edge, batch_size)`` tuple per
        weight scheme. The result is a single data object when there is
        exactly one scheme and a list otherwise; it is passed through
        ``self.transform`` when one was given.
        """
        batch_data = []

        if isinstance(self.data, Data):
            perm = self.neighbor_sampler.perm
            for i, (node, row, col, edge, batch_size) in enumerate(out):
                b_data = filter_data(self.data, node, row, col, edge, perm)
                # NOTE(review): the sampler already reordered its weights with
                # `perm`; confirm that indexing with `perm[edge]` is intended.
                b_data.weight = self.weights[i][perm[edge]]
                b_data.batch_size = batch_size
                batch_data.append(b_data)

        if len(batch_data) == 1:
            batch_data = batch_data[0]

        return batch_data if self.transform is None else self.transform(batch_data)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

In [48]:
def get_input_node_type(input_nodes: "InputNodes") -> Optional[str]:
    """Extract the node type named by ``input_nodes``, if any.

    Returns ``input_nodes`` itself when it is a string, its first element
    when it is a non-empty list/tuple starting with a string, and ``None``
    in every other case (tensor, ``None``, empty sequence, plain sequence of
    node indices).
    """
    if isinstance(input_nodes, str):
        return input_nodes
    if isinstance(input_nodes, (list, tuple)) and len(input_nodes) > 0:
        head = input_nodes[0]
        if isinstance(head, str):
            return head
    return None


def get_input_node_indices(data: Union[Data, HeteroData],
                           input_nodes: InputNodes) -> Sequence:
    """Resolve ``input_nodes`` into a sequence of seed node indices.

    * ``None`` together with a ``Data`` object selects every node.
    * A boolean tensor is treated as a mask and replaced by the positions of
      its non-zero entries; any other tensor is converted with ``tolist()``.
    * Anything else has to be a ``Sequence`` already and is returned as is.
    """
    select_all = isinstance(data, Data) and input_nodes is None
    if select_all:
        return range(data.num_nodes)

    if isinstance(input_nodes, Tensor):
        selected = input_nodes
        if selected.dtype == torch.bool:
            selected = selected.nonzero(as_tuple=False).view(-1)
        input_nodes = selected.tolist()

    assert isinstance(input_nodes, Sequence)
    return input_nodes

In [49]:
# data.edge_index
# #data.weight = torch.Tensor(list(range(data.edge_index.shape[1])))+100
# data.weight = torch.ones(data.edge_index.shape[1])
# data.weight.to(data.edge_index.device)

In [50]:
# data.edge_index

In [51]:
# #loader = WeightedNeighborLoader(data, batch_size=2, num_neighbors=[-1], input_nodes=data.train_mask, save_dir=dataset.processed_dir, num_workers=0, shuffle=False)
# loader = WeightedNeighborLoader(data, batch_size=2, num_neighbors=[-1,-1], input_nodes=data.train_mask, num_workers=0, shuffle=False)
# batch  = next(iter(loader))
# print(batch.weight)
# batch

In [52]:
# for batch_data in loader:
    
#     print("*"*50)
#     print(batch_data.edge_index)
#     print("*"*50)

# Main

In [53]:
if __name__ == '__main__':

    from ipynb.fs.full.Dataset import get_data

    # Smoke test: build a loader with every weight scheme on the 'karate'
    # dataset and print the batches it produces.
    data, dataset = get_data('karate', log=False, h_score=True)

    # Other combinations tried during development:
    # ['knn', 'submodular'], ['knn', 'submodular', 'random', 'link-nn', 'link-sub'], ['apricot']
    weight_func = ['random', 'knn', 'submodular', 'fastlink', 'link-sub', 'link-nn', 'apricot']

    # Per-scheme options; the values become part of the cache file names.
    params = {
        'knn': {'metric': 'cosine'},
        'submodular': {'metric': 'cosine'},
        'link-nn': {'value': 'min'},
        'link-sub': {'value': 'max'},
        'apricot': {'sub_func': 'facility', 'metric': 'cosine'},
    }

    loader = WeightedNeighborLoader(
        data,
        batch_size=16,
        num_neighbors=[2, 2],
        input_nodes=data.train_mask,
        log=True,
        num_workers=0,
        shuffle=False,
        weight_func=weight_func,
        params=params,
        replace=False,
        directed=True,
        save_dir='Results/',
        recompute=False,
    )

    print(data)

    # Draw the first batch twice, each time from a fresh iterator.
    batch = next(iter(loader))
    print(batch)

    batch = next(iter(loader))
    print(batch)

    # Full pass over the loader, one banner-delimited entry per batch.
    for batch_data in loader:
        print("*"*50)
        print(batch_data)
        print("*"*50)

N  34  E  156  d  4.588235294117647 0.8020520210266113 0.7564102411270142 0.6170591711997986 -0.4756128787994385 Loading weights  knncosine
Loading weights  submodular
Loading weights  fastlink
Loading weights  link-submax
Loading weights  link-nnmin


[3 1 5 0 2 4 6]
[488.  48.  27.   1.   1.   1.   1.]
2.338205575942993
81 & 80 & 77 & 56 & 32 & 17 & 0\\
80 & 81 & 80 & 65 & 45 & 32 & 17\\
77 & 80 & 81 & 72 & 56 & 45 & 32\\
56 & 65 & 72 & 81 & 77 & 72 & 65\\
32 & 45 & 56 & 77 & 81 & 80 & 77\\
17 & 32 & 45 & 72 & 80 & 81 & 80\\
0 & 17 & 32 & 65 & 77 & 80 & 81\\
0 & 0\\
1 & 0\\
2 & 0\\
5 & 0\\
7 & 0\\
8 & 0\\
9 & 0\\
cosine
Pool Size:  40


Nodes:   0%|          | 0/34 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Scratch 

In [None]:
# import DeviceDir

# DIR, RESULTS_DIR = DeviceDir.get_directory()
# device, NUM_PROCESSORS = DeviceDir.get_device()

# data, dataset = get_data('Reddit', DIR=DIR+'RedditPyg204', log = False) 

# (row, col) = data.edge_index
# data.edge_index = torch.stack((torch.cat((row, col),dim=0),torch.cat((col, row),dim=0)),dim=0)

In [None]:
# weight_func=['random', 'knn']

# params={
#     'knn':{'metric':'cosine'},
#     'submodular':{'metric':'cosine'}
# }

# loader = WeightedNeighborLoader(data, batch_size=1024, num_neighbors=[4, 4], input_nodes=data.train_mask, log = True,
#                                 num_workers=0, shuffle=False, weight_func=weight_func,params=params,
#                                replace=False, directed=False)

In [None]:
# # print(loader.weights)
# batch  = next(iter(loader))
# for b in batch:
#     print(b)

In [None]:
# for i, batch in enumerate(loader):
#     print("-"*50)
#     for b in batch:
#         #print(b)
#         None
#     if i>100:
#         break