In [None]:
import torch
import numpy as np
from torch import nn
from functools import partial
from time import time
# from scipy.sparse import csr_matrix

# https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr

In [None]:
# class myGraphSAINT(nn.Module):
#     def __init__(self, hidden_sizes):
#         super(myGraphSAINT, self).__init__()
# #         self.layers = nn.ModuleList([nn.Linear(hidden_sizes[i], hidden_sizes[i+1], bias=False)
# #                                      for i in range(len(hidden_sizes)-1)])
#         self.weights = [
#             nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty((hidden_sizes[i], hidden_sizes[i+1])), gain=1/np.sqrt(6.0)))
#             for i in range(len(hidden_sizes)-1)
#         ]
    
#     def forward(self, x, A):
#         for W in self.weights:
#             x = nn.functional.relu(A @ x @ W)
#         return x
    
#     def sampleGraph(self, graph):
#         pass
    
#     def train(self, )

In [None]:
# # print(nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty((2, 3)), gain=1.0)))
# x = torch.rand((4,3))
# A = torch.eye(4)
# model = myGraphSAINT([3,10,3])
# print(model(x, A))

## Data loading and processing

In [None]:
# Load the training adjacency matrix (stored as SciPy CSR components) and the node features.
# The .npz archive is opened in a context manager so its file handle is released
# as soon as the component arrays have been read.
with np.load("./ppi/adj_train.npz") as adj_train:
    # The data for the (symmetric) adjacency matrix
    data = adj_train["data"].astype(int)
    indices = adj_train["indices"]
    indptr = adj_train["indptr"]
    shape = adj_train["shape"]
features = np.load("./ppi/feats.npy")

# # This is a less memory-efficient method to get the sparse torch adjacency matrix, that also requires SciPy
# adj_matrix = csr_matrix((data, indices, indptr), shape=shape).toarray().astype(int)
# np.set_printoptions(threshold=sys.maxsize)
# adj_matrix = torch.from_numpy(adj_matrix).to_sparse()

# Change the SciPy csr format to torch (COO) format.
# In CSR, row i owns the column entries indices[indptr[i]:indptr[i+1]], so the row index of
# every stored entry is obtained by repeating i once per entry of that row.
torch_first_indices = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
torch_indices = np.stack((torch_first_indices, indices))

# The given shape for the ppi train set is much larger than the number of actual nodes in the set
num_nodes = len(np.unique(torch_indices))
shape_small = [num_nodes, num_nodes]

# Create the adjacency matrix
adj_matrix = torch.sparse_coo_tensor(indices=torch_indices, values=data, size=shape_small, dtype=torch.float)

# Calculate the node degrees: every stored entry adds 1 to the degree of its row node,
# and a self-loop (row == column) adds 1 more.
edge_rows, edge_cols = torch_indices
degree = np.bincount(edge_rows, minlength=shape_small[0]).astype(float)
degree += np.bincount(edge_rows[edge_rows == edge_cols], minlength=shape_small[0])
inverse_degree = np.reciprocal(degree)

# Calculate the normalized adjacency matrix (each entry is scaled by 1/degree of its row node)
norm_adj_data = inverse_degree[torch_indices[0]]*data.astype(float)
# norm_adj_matrix = torch.sparse_coo_tensor(indices=torch_indices, values=norm_adj_data.astype(float), size=shape_small,
#                                           dtype=torch.float64)

## Sampling

In [None]:
def sample_nodes(num_nodes, budget, p_nodes, p_edges, indices, data, features, max_retries=100):
    """Sample a node-induced subgraph and return its re-weighted adjacency and features.

    ``budget`` distinct nodes are drawn from ``range(num_nodes)`` with probabilities
    ``p_nodes``. Only edges whose two endpoints were both drawn are kept, and drawn
    nodes that appear in none of the kept edges are discarded. The remaining nodes
    are relabelled ``0..k-1`` in ascending order of their original index.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the full graph.
    budget : int
        Number of nodes to draw per attempt.
    p_nodes : array_like, length ``num_nodes``
        Node sampling probabilities (passed to ``np.random.choice`` as ``p``).
    p_edges : array_like, length ``E``
        Per-edge probabilities, aligned with the columns of ``indices``.
    indices : np.ndarray, shape ``(2, E)``
        Edge list; row 0 holds the first endpoint, row 1 the second.
    data : np.ndarray, length ``E``
        Edge values, aligned with the columns of ``indices``.
    features : np.ndarray
        Feature array indexed by original node index along its first axis.
    max_retries : int, optional
        How many additional draws to attempt when a draw yields no connected nodes.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        A sparse ``k x k`` float tensor holding ``data * alpha`` for the kept edges,
        where ``alpha = p_edges / p_nodes[first endpoint]``, and a float tensor with
        the feature rows of the ``k`` kept nodes (same ordering as the relabelling).

    Raises
    ------
    RuntimeError
        If no draw within ``max_retries + 1`` attempts produced a connected node.
    """
    # Every attempt draws with the ORIGINAL budget. (Retrying with a budget that was
    # already reduced by the discarded nodes would retry with a budget of zero.)
    for _ in range(max_retries + 1):
        sampled = np.unique(np.random.choice(np.arange(num_nodes), size=budget, p=p_nodes, replace=False))

        # Keep the edges whose two endpoints were both sampled
        keep = np.isin(indices, sampled).all(axis=0)
        edges_s = indices[:, keep]

        # Endpoints of kept edges are sampled nodes by construction, so this is exactly
        # the set of sampled nodes connected to other sampled nodes or to themselves.
        nodes_s = np.unique(edges_s)
        if nodes_s.size > 0:
            break
    else:
        raise RuntimeError(
            f"sample_nodes: no connected nodes were sampled in {max_retries + 1} attempts"
        )

    data_s = data[keep]
    edge_indices = np.flatnonzero(keep)

    # Relabel original node ids to 0..k-1; nodes_s is sorted and contains every endpoint,
    # so the position of an endpoint in nodes_s is its subgraph index.
    num_sub = nodes_s.size
    edges_s_sub = torch.from_numpy(np.searchsorted(nodes_s, edges_s))

    # Create the adjacency matrix of the sampled graph
    adj_matrix_s = torch.sparse_coo_tensor(indices=edges_s_sub, values=data_s, size=[num_sub, num_sub],
                                           dtype=torch.float)

    # Calculate the normalizing constants
    # NOTE(review): the adjacency is multiplied (not divided) by alpha here -- confirm this
    # matches the intended normalization.
    p_nodes_s = np.take(p_nodes, edges_s[0])
    p_edges_s = np.take(p_edges, edge_indices)
    alpha = p_edges_s/p_nodes_s
    alpha_matrix = torch.sparse_coo_tensor(indices=edges_s_sub, values=alpha, size=[num_sub, num_sub],
                                           dtype=torch.float)
    return adj_matrix_s * alpha_matrix, torch.from_numpy(features[nodes_s]).to(torch.float)

In [None]:
budget = 6000 # 18

# Calculate the sampling probabilities
# A node's (unnormalized) weight is the sum of the squared normalized-adjacency entries
# whose column index is that node; the weights are then normalized to sum to 1.
p_nodes = np.bincount(torch_indices[1], weights=norm_adj_data**2, minlength=shape_small[0])
p_nodes = p_nodes/p_nodes.sum()

# Get the edge probabilities
# For the node sampler, the probability of an edge being sampled is just the probability of both its nodes being sampled
# Note that for an edge connecting a node to itself, the probability of sampling it is just the probability of sampling the node
is_self_loop = torch_indices[0] == torch_indices[1]
self_loops = np.where(is_self_loop)
# For self-loops the second factor is replaced by 1 so the node probability is not squared.
p_edges = p_nodes[torch_indices[0]] * np.where(is_self_loop, 1.0, p_nodes[torch_indices[1]])

# p_edge_matrix = torch.sparse_coo_tensor(indices=torch_indices, values=p_edges, size=shape_small, dtype=torch.float64)

# Bind the graph data so the model can draw a fresh subgraph with a zero-argument call.
nodewise_sampler = partial(sample_nodes, num_nodes, budget, p_nodes, p_edges, torch_indices, norm_adj_data, features)

In [None]:
# def sample_weight_matrix(matrix, indices):
#     return matrix[sorted(indices)][:, sorted(indices)]

# W = torch.rand(shape_small)
# W_batch = sample_weight_matrix(W, nodes_s)

In [None]:
class myGraphSAINT(nn.Module):
    """Stack of bias-free graph layers trained on subgraphs drawn by ``sampler``.

    Each layer computes ``relu(A @ x @ W)``, where ``(A, x)`` is whatever pair the
    zero-argument ``sampler`` callable returns and ``W`` is that layer's weight matrix.
    The module owns its own Adam optimizer.
    """

    def __init__(self, hidden_sizes, lr, sampler):
        """Create one weight matrix per consecutive pair in ``hidden_sizes``.

        hidden_sizes: layer widths; layer ``i`` maps ``hidden_sizes[i]`` to ``hidden_sizes[i+1]``.
        lr: learning rate handed to Adam.
        sampler: zero-argument callable returning ``(A, x)``.
        """
        super().__init__()
        xavier_gain = 1/np.sqrt(6.0)
        layer_weights = []
        for fan_in, fan_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            blank = torch.empty((fan_in, fan_out))
            layer_weights.append(nn.Parameter(torch.nn.init.xavier_uniform_(blank, gain=xavier_gain)))
        self.weights = torch.nn.ParameterList(layer_weights)
        self.sampler = sampler
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self):
        """Draw a subgraph and propagate its features through every layer."""
        adjacency, hidden = self.sampleGraph()
        for weight in self.weights:
            propagated = adjacency @ hidden @ weight
            hidden = nn.functional.relu(propagated)
        return hidden

    def sampleGraph(self):
        """Return the ``(A, x)`` pair produced by the stored sampler."""
        return self.sampler()

    def train_step(self):
        """Run one optimizer update and return the loss as a Python float.

        The target is a placeholder: the output is regressed (MSE) onto a tensor of ones.
        """
        self.optimizer.zero_grad()
        prediction = self()
        placeholder_target = torch.ones_like(prediction)
        loss = torch.nn.functional.mse_loss(prediction, placeholder_target)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self, num_iterations):
        """Switch to training mode, run ``num_iterations`` steps, and return the list of losses."""
        self.train()
        return [self.train_step() for _ in range(num_iterations)]

In [None]:
# nodes_samp = np.unique(np.random.choice(np.arange(num_nodes), size=budget, p=p_nodes, replace=False))
# ta = time()
# condition = np.all(np.in1d(torch_indices.flatten(), nodes_samp).reshape(torch_indices.shape), axis=0)
# torch_indices[:,condition]
# print(np.where(condition)[0])
# tb = time()
# print(tb-ta)

In [None]:
# Build a single-layer model (50 -> 128) that draws its subgraphs from the node-wise sampler,
# run a short training loop, and report the wall-clock duration followed by the per-step losses.
model = myGraphSAINT(hidden_sizes=[50, 128], lr=0.01, sampler=nodewise_sampler)

t_start = time()
losses = model.fit(num_iterations=10)
t_end = time()

print(t_end - t_start)
print(losses)