In [1]:
import sys
import pickle as pkl
import numpy as np
import scipy.sparse as sp
import networkx as nx

In [2]:
def parse_index_file(filename):
    """Parse an index file that holds one integer per line.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[int]: The integers in file order. Blank (whitespace-only)
        lines, e.g. a trailing empty line, are skipped.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    # `with` guarantees the handle is closed even if int() raises.
    with open(filename) as handle:
        # int() tolerates surrounding whitespace/newlines itself.
        return [int(line) for line in handle if line.strip()]


def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: Positions to set -- a list, ``range`` or integer array (anything
            NumPy accepts as a fancy index).
        l: Length of the mask.

    Returns:
        np.ndarray: 1-D array of dtype ``bool``.
    """
    # np.bool was deprecated in NumPy 1.20 and removed in 1.24 (NumPy 2.0
    # re-added it), so the builtin ``bool`` is the portable spelling. Building
    # the array as bool directly also avoids an extra float -> bool copy.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask

In [4]:
# Dataset prefix of the ind.<dataset>.* files to load, e.g. 'cora' or 'citeseer'.
dataset_str = 'cora'
# Directory containing the ind.<dataset>.<name> files.
DATA_DIR = "gcn/data"
# Number of validation nodes, taken immediately after the training nodes.
NUM_VAL = 500

# The tuple unpacking below relies on this exact order.
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
objects = []
for name in names:
    path = "{}/ind.{}.{}".format(DATA_DIR, dataset_str, name)
    # NOTE(review): pickle.load can execute arbitrary code -- only load these
    # files from a trusted source.
    with open(path, 'rb') as f:
        if sys.version_info > (3, 0):
            # latin1 lets Python 3 read pickles presumably written by Python 2.
            objects.append(pkl.load(f, encoding='latin1'))
        else:
            objects.append(pkl.load(f))

x, y, tx, ty, allx, ally, graph = tuple(objects)
# Node ids of the test nodes; assumed to be in the same order as the rows of
# tx/ty (the reordering below relies on this).
test_idx_reorder = parse_index_file(
    "{}/ind.{}.test.index".format(DATA_DIR, dataset_str))
test_idx_range = np.sort(test_idx_reorder)

if dataset_str == 'citeseer':
    # Fix citeseer dataset (there are some isolated nodes in the graph).
    # The test ids are not contiguous: pad tx/ty with zero rows so every id in
    # [min, max] has a row, and place the real rows at their id offsets.
    test_idx_range_full = range(
        min(test_idx_reorder), max(test_idx_reorder)+1)
    tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
    tx_extended[test_idx_range-min(test_idx_range), :] = tx
    tx = tx_extended
    ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
    ty_extended[test_idx_range-min(test_idx_range), :] = ty
    ty = ty_extended

features = sp.vstack((allx, tx)).tolil()
# The test rows sit at positions test_idx_range in tx order; this moves the
# k-th of them to node id test_idx_reorder[k]. The right-hand side is copied
# before assignment, so overlapping index sets are safe.
features[test_idx_reorder, :] = features[test_idx_range, :]
# NOTE(review): adjacency row order follows the graph dict's insertion order,
# assumed to be node ids 0..N-1 so it lines up with `features` -- confirm.
adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

labels = np.vstack((ally, ty))
labels[test_idx_reorder, :] = labels[test_idx_range, :]

# Fail loudly instead of silently training on misaligned rows.
assert features.shape[0] == labels.shape[0], "feature/label row count mismatch"
assert adj.shape[0] == features.shape[0], "graph/feature node count mismatch"

idx_test = test_idx_range.tolist()
idx_train = range(len(y))
idx_val = range(len(y), len(y) + NUM_VAL)

train_mask = sample_mask(idx_train, labels.shape[0])
val_mask = sample_mask(idx_val, labels.shape[0])
test_mask = sample_mask(idx_test, labels.shape[0])

# Per-split label matrices: rows outside the split are left as zeros.
y_train = np.zeros(labels.shape)
y_val = np.zeros(labels.shape)
y_test = np.zeros(labels.shape)
y_train[train_mask, :] = labels[train_mask, :]
y_val[val_mask, :] = labels[val_mask, :]
y_test[test_mask, :] = labels[test_mask, :]

In [6]:
# Boolean mask over all nodes: True for the first len(y) (training) nodes.
train_mask

array([ True,  True,  True, ..., False, False, False])

In [8]:
# Label row of node 0; node 0 is in idx_train, so y_train keeps its label.
y_train[0]

array([0., 0., 0., 1., 0., 0., 0.])

In [9]:
# Node 0 is not in idx_val, so its row in y_val stays all zeros.
y_val[0]

array([0., 0., 0., 0., 0., 0., 0.])

In [4]:
# `x` as loaded from ind.<dataset>.x (the same object as objects[0]).
print(x)

  (0, 19)	1.0
  (0, 81)	1.0
  (0, 146)	1.0
  (0, 315)	1.0
  (0, 774)	1.0
  (0, 877)	1.0
  (0, 1194)	1.0
  (0, 1247)	1.0
  (0, 1274)	1.0
  (1, 19)	1.0
  (1, 88)	1.0
  (1, 149)	1.0
  (1, 212)	1.0
  (1, 233)	1.0
  (1, 332)	1.0
  (1, 336)	1.0
  (1, 359)	1.0
  (1, 472)	1.0
  (1, 507)	1.0
  (1, 548)	1.0
  (1, 687)	1.0
  (1, 763)	1.0
  (1, 808)	1.0
  (1, 889)	1.0
  (1, 1058)	1.0
  :	:
  (138, 1263)	1.0
  (138, 1274)	1.0
  (138, 1290)	1.0
  (138, 1307)	1.0
  (138, 1406)	1.0
  (139, 1)	1.0
  (139, 41)	1.0
  (139, 187)	1.0
  (139, 212)	1.0
  (139, 357)	1.0
  (139, 404)	1.0
  (139, 464)	1.0
  (139, 505)	1.0
  (139, 507)	1.0
  (139, 581)	1.0
  (139, 635)	1.0
  (139, 874)	1.0
  (139, 988)	1.0
  (139, 1071)	1.0
  (139, 1230)	1.0
  (139, 1231)	1.0
  (139, 1258)	1.0
  (139, 1263)	1.0
  (139, 1274)	1.0
  (139, 1393)	1.0


In [21]:
# `y` as loaded from ind.<dataset>.y (the same object as objects[1]).
print(y)

[[0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 1 0 0]
 [1 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0]
 [1 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0]
 [1 0 0 0 0 0 0]
 [1 0 0 0 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0]
 [0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 0 0 1]
 [0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1]
 [1 0 0 0 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 1 0 0 0 0]
 [1 0 0 0 0 0 0]
 [0 1 0 0 0 0 0]
 [0 0 0 0 0 1 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1]
 [0 0 0 0 0 0 1]
 [0 0 0 0 1 0 0]
 [0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0]
 [0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0]
 [0 0 0 1 0 0 0]
 [1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0]
 [0 1 0 0 0 0 0]
 [0 0 0 0 1 0 0]
 [0 0 0 0 0 0 1]
 [0 0 0 1 0 0 0]
 [0 0 1 0 0 0 

In [6]:
# tx exactly as loaded from ind.<dataset>.tx. objects[2] is the raw pickle;
# for citeseer the `tx` variable is later replaced by a zero-padded copy.
print(objects[2])

  (0, 311)	1.0
  (0, 314)	1.0
  (0, 353)	1.0
  (0, 505)	1.0
  (0, 510)	1.0
  (0, 621)	1.0
  (0, 1075)	1.0
  (0, 1132)	1.0
  (0, 1171)	1.0
  (0, 1226)	1.0
  (0, 1230)	1.0
  (0, 1301)	1.0
  (0, 1379)	1.0
  (0, 1389)	1.0
  (0, 1392)	1.0
  (1, 78)	1.0
  (1, 121)	1.0
  (1, 228)	1.0
  (1, 505)	1.0
  (1, 510)	1.0
  (1, 617)	1.0
  (1, 662)	1.0
  (1, 931)	1.0
  (1, 988)	1.0
  (1, 993)	1.0
  :	:
  (998, 1240)	1.0
  (998, 1258)	1.0
  (998, 1263)	1.0
  (998, 1306)	1.0
  (998, 1314)	1.0
  (999, 30)	1.0
  (999, 65)	1.0
  (999, 432)	1.0
  (999, 548)	1.0
  (999, 570)	1.0
  (999, 610)	1.0
  (999, 690)	1.0
  (999, 720)	1.0
  (999, 724)	1.0
  (999, 749)	1.0
  (999, 763)	1.0
  (999, 993)	1.0
  (999, 1058)	1.0
  (999, 1143)	1.0
  (999, 1150)	1.0
  (999, 1170)	1.0
  (999, 1177)	1.0
  (999, 1205)	1.0
  (999, 1274)	1.0
  (999, 1392)	1.0


In [16]:
# Shape of ty exactly as loaded from ind.<dataset>.ty (objects[3]); for
# citeseer the `ty` variable is later replaced by a zero-padded copy.
print(objects[3].shape)

(1000, 7)


In [23]:
# `allx` as loaded from ind.<dataset>.allx (the same object as objects[4]).
print(allx)

  (0, 19)	1.0
  (0, 81)	1.0
  (0, 146)	1.0
  (0, 315)	1.0
  (0, 774)	1.0
  (0, 877)	1.0
  (0, 1194)	1.0
  (0, 1247)	1.0
  (0, 1274)	1.0
  (1, 19)	1.0
  (1, 88)	1.0
  (1, 149)	1.0
  (1, 212)	1.0
  (1, 233)	1.0
  (1, 332)	1.0
  (1, 336)	1.0
  (1, 359)	1.0
  (1, 472)	1.0
  (1, 507)	1.0
  (1, 548)	1.0
  (1, 687)	1.0
  (1, 763)	1.0
  (1, 808)	1.0
  (1, 889)	1.0
  (1, 1058)	1.0
  :	:
  (1706, 1236)	1.0
  (1706, 1242)	1.0
  (1706, 1320)	1.0
  (1706, 1337)	1.0
  (1707, 4)	1.0
  (1707, 118)	1.0
  (1707, 153)	1.0
  (1707, 180)	1.0
  (1707, 228)	1.0
  (1707, 699)	1.0
  (1707, 701)	1.0
  (1707, 719)	1.0
  (1707, 750)	1.0
  (1707, 758)	1.0
  (1707, 810)	1.0
  (1707, 911)	1.0
  (1707, 1177)	1.0
  (1707, 1233)	1.0
  (1707, 1251)	1.0
  (1707, 1257)	1.0
  (1707, 1262)	1.0
  (1707, 1299)	1.0
  (1707, 1325)	1.0
  (1707, 1386)	1.0
  (1707, 1397)	1.0


In [22]:
# `ally` as loaded from ind.<dataset>.ally (the same object as objects[5]).
print(ally)

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 1 0 0]
 ...
 [0 1 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 1 0]]


In [10]:
# Adjacency lists (node id -> neighbour ids) from ind.<dataset>.graph,
# the same object as objects[6].
print(graph)

defaultdict(<class 'list'>, {0: [633, 1862, 2582], 1: [2, 652, 654], 2: [1986, 332, 1666, 1, 1454], 3: [2544], 4: [2176, 1016, 2176, 1761, 1256, 2175], 5: [1629, 2546, 1659, 1659], 6: [1416, 1602, 1042, 373], 7: [208], 8: [281, 1996, 269], 9: [2614, 723, 723], 10: [476, 2545], 11: [1655, 1839], 12: [2661, 1001, 1318, 2662], 13: [1810, 1701], 14: [2034, 2075, 158, 2077, 2668], 15: [2367, 1093, 1090, 1271, 1093], 16: [2444, 1632, 970, 2642], 17: [24, 2140, 1316, 1315, 927], 18: [2082, 139, 1786, 1560, 2145], 19: [1939], 20: [1072, 2374, 2375, 2269, 2270], 21: [1043, 2310], 22: [1703, 1702, 2238, 39, 1234], 23: [2159], 24: [1701, 2139, 1636, 17, 2141, 598, 201], 25: [1344, 2011, 1301, 2317], 26: [2454, 2455, 123, 99, 122], 27: [1810, 606, 2360, 2578], 28: [1687], 29: [963, 2645], 30: [1358, 1416, 2162, 697, 2343, 738], 31: [1594], 32: [1973, 279, 518, 1850], 33: [2119, 911, 588, 1051, 2120, 286, 2040, 2121, 698], 34: [1358], 35: [1296, 1913, 895], 36: [1146, 1640, 1505, 2106, 1781, 2094, 

In [17]:
# Node ids that have an adjacency-list entry in `graph` (objects[6]).
print(graph.keys())

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,