https://github.com/williamleif/graphsage-simple

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable

import numpy as np
import time
import random
from sklearn.metrics import f1_score
from collections import defaultdict

from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator

The Cora dataset consists of Machine Learning papers. These papers are classified into one of the following seven classes:
		
        Case_Based
		
        Genetic_Algorithms
		
        Neural_Networks
		
        Probabilistic_Methods
		
        Reinforcement_Learning
		
        Rule_Learning
		
        Theory
        
The .content file contains descriptions of the papers in the following format:

		<paper_id> <word_attributes>+ <class_label>

The first entry in each line contains the unique string ID of the paper followed by binary values indicating whether each word in the vocabulary is present (indicated by 1) or absent (indicated by 0) in the paper. Finally, the last entry in the line contains the class label of the paper.

The .cites file contains the citation graph of the corpus. Each line describes a link in the following format:

		<ID of cited paper> <ID of citing paper>

Each line contains two paper IDs. The first entry is the ID of the paper being cited and the second ID stands for the paper which contains the citation. The direction of the link is from right to left. If a line is represented by "paper1 paper2" then the link is "paper2->paper1".

In [2]:
"""
Simple supervised GraphSAGE model as well as examples running the model
on the Cora and Pubmed datasets.
"""

class SupervisedGraphSage(nn.Module):

    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        self.xent = nn.CrossEntropyLoss()

        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        init.xavier_uniform(self.weight)

    def forward(self, nodes):
        embeds = self.enc(nodes)
        scores = self.weight.mm(embeds)
        return scores.t()

    def loss(self, nodes, labels):
        scores = self.forward(nodes)
        return self.xent(scores, labels.squeeze())

def load_cora():
    num_nodes = 2708
    num_feats = 1433 # уникальных слов бинарного вектора каждой вершины
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes,1), dtype=np.int64)
    node_map = {}
    label_map = {}
    with open("cora/cora.content") as fp:
        for i,line in enumerate(fp):
            info = line.strip().split()
            feat_data[i,:] = list(map(float, info[1:-1]))
            node_map[info[0]] = i
            if not info[-1] in label_map:
                label_map[info[-1]] = len(label_map)
            labels[i] = label_map[info[-1]]

    adj_lists = defaultdict(set)
    with open("cora/cora.cites") as fp:
        for i,line in enumerate(fp):
            info = line.strip().split()
            paper1 = node_map[info[0]]
            paper2 = node_map[info[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)
    return feat_data, labels, adj_lists

In [3]:
num_nodes = 2708
num_feats = 1433 # уникальных слов бинарного вектора каждой вершины
feat_data = np.zeros((num_nodes, num_feats))
labels = np.empty((num_nodes,1), dtype=np.int64)
node_map = {}
label_map = {} # словарь с кодировкой 7 классов
with open("cora/cora.content") as fp:
    for i, line in enumerate(fp):
        info = line.strip().split()
        feat_data[i,:] = list(map(float, info[1:-1]))
        node_map[info[0]] = i
        if not info[-1] in label_map:
            label_map[info[-1]] = len(label_map)
            labels[i] = label_map[info[-1]]

adj_lists = defaultdict(set)
with open("cora/cora.cites") as fp:
    for i,line in enumerate(fp):
        info = line.strip().split()
        paper1 = node_map[info[0]]
        paper2 = node_map[info[1]]
        adj_lists[paper1].add(paper2)
        adj_lists[paper2].add(paper1)

In [4]:
adj_lists

defaultdict(set,
            {163: {22,
              42,
              55,
              129,
              141,
              145,
              174,
              188,
              189,
              191,
              219,
              237,
              266,
              290,
              309,
              346,
              380,
              390,
              395,
              402,
              415,
              422,
              448,
              523,
              530,
              546,
              563,
              602,
              606,
              624,
              658,
              659,
              689,
              714,
              717,
              727,
              743,
              744,
              757,
              765,
              769,
              781,
              793,
              800,
              813,
              856,
              910,
              935,
              940,
              942,
              961,
            

In [5]:
label_map

{'Case_Based': 6,
 'Genetic_Algorithms': 5,
 'Neural_Networks': 0,
 'Probabilistic_Methods': 3,
 'Reinforcement_Learning': 2,
 'Rule_Learning': 1,
 'Theory': 4}

In [6]:
labels

array([[                  0],
       [                  1],
       [                  2],
       ..., 
       [         4670016448],
       [3418795067127627777],
       [         4368448728]])

In [11]:
def run_cora():
    np.random.seed(1)
    random.seed(1)
    num_nodes = 2708
    feat_data, labels, adj_lists = load_cora()
    features = nn.Embedding(2708, 1433)
    features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)
   # features.cuda()

    agg1 = MeanAggregator(features, cuda=True)
    enc1 = Encoder(features, 1433, 128, adj_lists, agg1, gcn=True, cuda=False)
    agg2 = MeanAggregator(lambda nodes : enc1(nodes).t(), cuda=False)
    enc2 = Encoder(lambda nodes : enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
            base_model=enc1, gcn=True, cuda=False)
    enc1.num_samples = 5
    enc2.num_samples = 5

    graphsage = SupervisedGraphSage(7, enc2)
#    graphsage.cuda()
    rand_indices = np.random.permutation(num_nodes)
    test = rand_indices[:1000]
    val = rand_indices[1000:1500]
    train = list(rand_indices[1500:])

    optimizer = torch.optim.SGD(filter(lambda p : p.requires_grad, graphsage.parameters()), lr=0.7)
    times = []
    for batch in range(100):
        batch_nodes = train[:256]
        random.shuffle(train)
        start_time = time.time()
        optimizer.zero_grad()
        loss = graphsage.loss(batch_nodes, 
                Variable(torch.LongTensor(labels[np.array(batch_nodes)])))
        loss.backward()
        optimizer.step()
        end_time = time.time()
        times.append(end_time-start_time)
        print (batch, loss.data[0])

    val_output = graphsage.forward(val) 
    print ("Validation F1-macro:", f1_score(labels[val], val_output.data.numpy().argmax(axis=1), average="micro"))
    print ("Average batch time:", np.mean(times))

In [10]:
if __name__ == "__main__":
    run_cora()

0 1.9442723989486694
1 1.9286892414093018
2 1.9014158248901367
3 1.8826987743377686
4 1.848921537399292
5 1.8117573261260986
6 1.7734934091567993
7 1.7524316310882568
8 1.669343113899231
9 1.6175191402435303
10 1.540442943572998
11 1.5086159706115723
12 1.409396767616272
13 1.398180365562439
14 1.3389586210250854
15 1.2618186473846436
16 1.195272445678711
17 1.1109561920166016
18 1.038014531135559
19 1.071476936340332
20 0.9232332110404968
21 1.0266202688217163
22 0.8229751586914062
23 0.884732186794281
24 0.7611747980117798
25 0.7287988066673279
26 0.7191015481948853
27 0.8186793923377991
28 0.8818615674972534
29 1.010355830192566
30 0.6069413423538208
31 0.5770217776298523
32 0.5530521273612976
33 0.5401475429534912
34 0.4588921368122101
35 0.48101067543029785
36 0.4982425272464752
37 0.558771550655365
38 0.44576871395111084
39 0.48572468757629395
40 0.4226260185241699
41 0.46171438694000244
42 0.5400210618972778
43 0.4604809284210205
44 0.4884299337863922
45 0.36357200145721436
46 0

In [12]:
if __name__ == "__main__":
    run_cora()

0 1.9444255828857422
1 1.9184678792953491
2 1.8983098268508911
3 1.868328332901001
4 1.825634241104126
5 1.7801891565322876
6 1.7467312812805176
7 1.7252392768859863
8 1.6468331813812256
9 1.583451271057129
10 1.5362558364868164
11 1.507093906402588
12 1.4090080261230469
13 1.4146698713302612
14 1.3599122762680054
15 1.2870018482208252
16 1.234735131263733
17 1.1642166376113892
18 1.0972449779510498
19 1.121261715888977
20 0.9730217456817627
21 1.0703723430633545
22 0.8670238256454468
23 0.9438769221305847
24 0.8158208727836609
25 0.7570282220840454
26 0.7087162733078003
27 0.7047733068466187
28 0.7624689340591431
29 0.8713706731796265
30 1.2225092649459839
31 0.7565930485725403
32 0.6912572383880615
33 0.5860294103622437
34 0.5076517462730408
35 0.5146880149841309
36 0.5156162977218628
37 0.5403442978858948
38 0.42390772700309753
39 0.45227888226509094
40 0.3855976462364197
41 0.396552175283432
42 0.4403824210166931
43 0.3938487768173218
44 0.4035552740097046
45 0.3519592881202698
46 