In [2]:
import pickle
import os
import torch
import numpy as np
import pdb
import torch_geometric
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.model_selection import KFold
            
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, CoraFull, CitationFull
import torch_geometric.transforms as T

In [3]:
from collections import namedtuple

# Lightweight stand-in for a parsed-arguments object: a one-field record
# whose only attribute, `dataset`, names the dataset to load.
A = namedtuple('a', ['dataset'])
args = A(dataset='CORA')
args.dataset

'CORA'

In [34]:
# Load the Planetoid dataset named by `args.dataset` from ../data.
# The pre_transform is left exactly as it was.
data = Planetoid('../data', args.dataset, pre_transform=T.Compose([T.NormalizeFeatures()]))

# Index the dataset to obtain its graph object rather than reading the
# dataset's internal `.data` storage directly, which recent PyG releases
# warn against. `data` itself stays bound to the dataset for later cells.
graph = data[0]

features = graph.x
edge_index = graph.edge_index
labels = graph.y

# torch.where(mask) with a single argument returns a tuple of index tensors,
# so [0] selects the index tensor before converting it to a Python list.
train_index = torch.where(graph.train_mask)[0].tolist()
train_label = labels[train_index]
valid_index = torch.where(graph.val_mask)[0].tolist()
valid_label = labels[valid_index]
test_index = torch.where(graph.test_mask)[0].tolist()
test_label = labels[test_index]

num_classes = data.num_classes
bn = False  # NOTE(review): not used in the visible cells; presumably a batch-norm switch — confirm

# Values this cell produces (the return values if it were wrapped in a function):
# features, edge_index, train_index, train_label, valid_index, valid_label, test_index, test_label, num_classes



# CORA dataset explanation
Background reading on the CORA dataset: https://medium.com/mlearning-ai/ultimate-guide-to-graph-neural-networks-1-cora-dataset-37338c04fe6f

In [12]:
# Display the feature tensor read from the dataset's `x` attribute.
features

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

# features
`features` is a 2-D tensor:  
- dimension 0 (`features.shape[0]`): nodes  
- dimension 1 (`features.shape[1]`): features of each node

In [17]:
# Sizes of the feature tensor and of the edge-index tensor, shown as a pair.
features.size(), edge_index.size()

(torch.Size([2708, 1433]), torch.Size([2, 10556]))

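# edge_index
All edges of the graph, stored as a tensor of shape `[2, num_edges]`: each column is one edge, with the source node index in row 0 and the target node index in row 1.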

In [6]:
# Display the edge tensor read from the dataset's `edge_index` attribute
# (two rows, one column per edge).
edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [31]:
from scipy.sparse import csr_matrix

# Toy example of building a sparse matrix from coordinate lists:
# entry (row[i], col[i]) receives values[i]; every other entry is zero.
row = np.array([0, 2])
col = np.array([1, 2])
# Named `values` rather than `data` so that this demo does not overwrite the
# dataset object bound to `data`, which other cells still read.
values = np.ones(row.shape[0])

# Densify so the 3x3 result is shown as a plain array.
csr_matrix((values, (row, col)), shape=(3, 3)).toarray()

array([[0., 1., 0.],
       [0., 0., 0.],
       [0., 0., 1.]])

In [7]:
# Read a single scalar from `features`, using the first edge's two endpoints
# as the row and column indices.
# NOTE(review): the second endpoint is a node id but is used here as a
# feature-column index — confirm this probe is intentional.
first_src, first_dst = edge_index[0][0], edge_index[1][0]
features[first_src, first_dst]

tensor(0.)

In [8]:
# Display the label tensor read from the dataset's `y` attribute.
labels

tensor([3, 4, 4,  ..., 3, 3, 3])

In [9]:
# Display the positions where `train_mask` is set, as a plain Python list.
train_index

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139]

In [37]:
# Size of the training mask stored on the dataset.
data.data.train_mask.shape

torch.Size([2708])

In [38]:
# torch.where with only a condition returns a tuple of index tensors (here a
# 1-tuple), which is why element [0] is taken when building `train_index`.
torch.where(data.data.train_mask)



(tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139]),)

In [10]:
# Display the labels selected by `train_index`.
train_label

tensor([3, 4, 4, 0, 3, 2, 0, 3, 3, 2, 0, 0, 4, 3, 3, 3, 2, 3, 1, 3, 5, 3, 4, 6,
        3, 3, 6, 3, 2, 4, 3, 6, 0, 4, 2, 0, 1, 5, 4, 4, 3, 6, 6, 4, 3, 3, 2, 5,
        3, 4, 5, 3, 0, 2, 1, 4, 6, 3, 2, 2, 0, 0, 0, 4, 2, 0, 4, 5, 2, 6, 5, 2,
        2, 2, 0, 4, 5, 6, 4, 0, 0, 0, 4, 2, 4, 1, 4, 6, 0, 4, 2, 4, 6, 6, 0, 0,
        6, 5, 0, 6, 0, 2, 1, 1, 1, 2, 6, 5, 6, 1, 2, 2, 1, 5, 5, 5, 6, 5, 6, 5,
        5, 1, 6, 6, 1, 5, 1, 6, 5, 5, 5, 1, 5, 1, 1, 1, 1, 1, 1, 1])