In [None]:
import torch
import numpy as np
from torch import nn
from scipy.sparse import csr_matrix

# https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr

In [None]:
class myGraphSAINT(nn.Module):
    """Placeholder GraphSAINT model.

    ``forward`` is currently the identity map and ``sampleGraph`` is an
    unimplemented stub; both exist so the surrounding notebook can be
    wired up before the real model is written.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (identity placeholder)."""
        return x

    def sampleGraph(self, graph):
        """Graph-sampling stub: not implemented yet, always returns ``None``."""
        return None

In [None]:
# Smoke test: the placeholder model should hand back its input unchanged.
model = myGraphSAINT()
x = torch.ones(3)
print(model(x))

## Data loading and processing

In [None]:
# Load the CSR components of the (symmetric) training adjacency matrix.
# The `with` block closes the underlying NpzFile handle once the arrays are read.
with np.load("./ppi/adj_train.npz") as adj_train:
    data = adj_train["data"].astype(int)
    indices = adj_train["indices"]
    indptr = adj_train["indptr"]
    shape = adj_train["shape"]

# Change the SciPy CSR format to torch COO format.
# Row i owns the stored entries indices[indptr[i]:indptr[i+1]], so repeating i
# (indptr[i+1] - indptr[i]) times gives the row index of every stored entry,
# in the same order as `indices` (and therefore `data`).
torch_first_indices = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
torch_indices = np.stack((torch_first_indices, indices))

# The given shape for the ppi train set is much larger than the number of nodes
# that actually have edges, so size the matrix by the largest node id in use.
# Counting unique ids instead would undercount -- and make sparse_coo_tensor
# reject the indices -- whenever some id below the maximum has no edges.
num_nodes = int(torch_indices.max()) + 1
shape_small = [num_nodes, num_nodes]

# Create the adjacency matrix
adj_matrix = torch.sparse_coo_tensor(indices=torch_indices, values=data, size=shape_small, dtype=torch.float64)

# Node degree = number of stored entries in the node's row, with a self-loop
# (row == col) counted twice.
rows, cols = torch_indices
degree = (
    np.bincount(rows, minlength=num_nodes)
    + np.bincount(rows[rows == cols], minlength=num_nodes)
).astype(float)
# Nodes without edges get 0 instead of inf; no stored entry ever indexes them.
inverse_degree = np.divide(1.0, degree, out=np.zeros_like(degree), where=degree > 0)

# Values of the normalized adjacency matrix (each entry scaled by the inverse
# degree of its row), aligned element-wise with `torch_indices`.
norm_adj_data = inverse_degree[torch_indices[0]] * data.astype(float)

## Sampling

In [None]:
# Number of distinct nodes to draw (isolated ones are pruned afterwards, so the
# final subgraph may be smaller). `budget` itself is not modified below.
budget = 18
SEED = 0
rng = np.random.default_rng(SEED)

# Calculate the sampling probabilities: P(v) is proportional to the squared
# L2 norm of column v of the normalized adjacency matrix.
p_nodes = np.bincount(torch_indices[1], weights=norm_adj_data**2, minlength=shape_small[0])
p_nodes = p_nodes / p_nodes.sum()

# Sample `budget` distinct nodes
nodes_sampled = rng.choice(shape_small[0], size=budget, replace=False, p=p_nodes)

# Connect the sampled nodes: keep every edge whose endpoints were both sampled.
# Boolean-mask indexing keeps shape (2, k) and integer dtype even when k == 0.
edge_mask = np.isin(torch_indices[0], nodes_sampled) & np.isin(torch_indices[1], nodes_sampled)
edges_s = torch_indices[:, edge_mask]
data_s = data[edge_mask]

# Remove unconnected nodes (not connected to other nodes or themselves):
# the surviving nodes are exactly the ids that appear in some kept edge.
nodes_sorted = np.unique(edges_s)
nodes_s = set(nodes_sorted.tolist())
to_remove = set(nodes_sampled.tolist()) - nodes_s

# Relabel surviving nodes to 0..k-1 in ascending order of their original id.
orig2sub = {node: i for i, node in enumerate(nodes_sorted.tolist())}
nodes_s_sub = set(range(len(nodes_sorted)))
# Every id in edges_s is present in nodes_sorted, so searchsorted yields its rank.
edges_s_sub = np.searchsorted(nodes_sorted, edges_s)

# Create the adjacency matrix of the sampled graph (without overwriting the
# full-graph `adj_matrix` from the previous cell).
num_nodes_s = len(nodes_sorted)
adj_matrix_s = torch.sparse_coo_tensor(indices=edges_s_sub, values=data_s, size=[num_nodes_s, num_nodes_s],
                                       dtype=torch.float64)
print(adj_matrix_s.to_dense())