In [1]:
import numpy as np
import scipy.sparse as sp
import torch

In [4]:
def encode_onehot(labels):
    # The classes must be sorted before encoding to enable static class encoding.
    # In other words, make sure the first class always maps to index 0.
    classes = sorted(list(set(labels)))
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
    return labels_onehot

def normalize_adj(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv_sqrt = np.power(rowsum, -0.5).flatten()
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
    return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)


def normalize_features(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

In [10]:
path="./data/cora/"
dataset="cora"
"""Load citation network dataset (cora only for now)"""
print('Loading {} dataset...'.format(dataset))
idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), dtype=np.dtype(str))
features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])
# build graph

idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
idx_map = {j: i for i, j in enumerate(idx)}
edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)
edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)
idx_train = range(140)
idx_val = range(200, 500)
idx_test = range(500, 1500)

print(adj)

# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
features = normalize_features(features)
print(features.shape)
adj = normalize_adj(adj + sp.eye(adj.shape[0]))
adj = torch.FloatTensor(np.array(adj.todense()))
features = torch.FloatTensor(np.array(features.todense()))
labels = torch.LongTensor(np.where(labels)[1])
idx_train = torch.LongTensor(idx_train)
idx_val = torch.LongTensor(idx_val)
idx_test = torch.LongTensor(idx_test)


Loading cora dataset...
  (163, 402)	1.0
  (163, 659)	1.0
  (163, 1696)	1.0
  (163, 2295)	1.0
  (163, 1274)	1.0
  (163, 1286)	1.0
  (163, 1544)	1.0
  (163, 2600)	1.0
  (163, 2363)	1.0
  (163, 1905)	1.0
  (163, 1611)	1.0
  (163, 141)	1.0
  (163, 1807)	1.0
  (163, 1110)	1.0
  (163, 174)	1.0
  (163, 2521)	1.0
  (163, 1792)	1.0
  (163, 1675)	1.0
  (163, 1334)	1.0
  (163, 813)	1.0
  (163, 1799)	1.0
  (163, 1943)	1.0
  (163, 2077)	1.0
  (163, 765)	1.0
  (163, 769)	1.0
  :	:
  (2228, 1093)	1.0
  (2228, 1094)	1.0
  (2228, 2068)	1.0
  (2228, 2085)	1.0
  (2694, 2331)	1.0
  (617, 226)	1.0
  (422, 1691)	1.0
  (2142, 2096)	1.0
  (1477, 1252)	1.0
  (1485, 1252)	1.0
  (2185, 2109)	1.0
  (2117, 2639)	1.0
  (1211, 1247)	1.0
  (1884, 745)	1.0
  (1884, 1886)	1.0
  (1884, 1902)	1.0
  (1885, 745)	1.0
  (1885, 1884)	1.0
  (1885, 1886)	1.0
  (1885, 1902)	1.0
  (1886, 745)	1.0
  (1886, 1902)	1.0
  (1887, 2258)	1.0
  (1902, 1887)	1.0
  (837, 1686)	1.0
(2708, 1433)


In [11]:
#adj, features, labels, idx_train, idx_val, idx_test
print(features)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [8]:
print(labels)

tensor([2, 5, 4,  ..., 1, 0, 2])


In [9]:
print(idx_train)

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139])
