# GraphSAGE - Link Stealing Attack
> Steal links from a GraphSAGE model

#### Imports

In [1]:
import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args
from dgl.data import citation_graph as citegrh
from dgl.nn.pytorch.conv import SAGEConv
import matplotlib.pyplot as plt

Using backend: pytorch


#### Random Seed

In [2]:
import random

# Seed every random number generator this notebook can draw from, not just
# Python's own: model initialisation and dropout use torch's generator, so
# seeding `random` alone does not make a training run repeatable.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

#### Disable Warnings

In [3]:
import warnings

# Install a catch-all "ignore" filter so no warning of any category is shown
# for the rest of the session.
warnings.simplefilter('ignore')

#### GraphSAGE model

In [4]:
class GraphSAGE(nn.Module):
    """Stack of DGL ``SAGEConv`` layers for node-level prediction.

    The stack consists of one input layer (``in_feats -> n_hidden``),
    ``n_layers - 1`` hidden layers (``n_hidden -> n_hidden``) and one output
    layer (``n_hidden -> n_classes``), all sharing the same aggregator type.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        """Create the layers.

        Args:
            in_feats: Input size of the first layer.
            n_hidden: Output size of every layer except the last.
            n_classes: Output size of the last layer.
            n_layers: The input layer plus ``n_layers - 1`` hidden layers are
                built before the output layer.
            activation: Callable applied after every layer except the last.
            dropout: Probability passed to ``nn.Dropout``.
            aggregator_type: Aggregator name forwarded to each ``SAGEConv``.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

        # (input size, output size) of each layer, in stacking order:
        # input layer, hidden layers, output layer.
        layer_sizes = [(in_feats, n_hidden)]
        layer_sizes += [(n_hidden, n_hidden)] * (n_layers - 1)
        layer_sizes.append((n_hidden, n_classes))

        for size_in, size_out in layer_sizes:
            self.layers.append(SAGEConv(size_in, size_out, aggregator_type))

    def forward(self, graph, inputs):
        """Run ``inputs`` through every layer on ``graph`` and return the result.

        Dropout is applied to the raw inputs; activation followed by dropout is
        applied after every layer except the last, whose output is returned
        as-is.
        """
        final_idx = len(self.layers) - 1
        hidden = self.dropout(inputs)
        for idx, conv in enumerate(self.layers):
            hidden = conv(graph, hidden)
            if idx == final_idx:
                # output layer: no activation, no dropout
                continue
            hidden = self.dropout(self.activation(hidden))
        return hidden

#### Classification

In [5]:
def evaluate(model, graph, features, labels, nid):
    """Return the fraction of nodes in ``nid`` that ``model`` labels correctly.

    The model is switched to evaluation mode (and left in it) and run on the
    whole graph without gradient tracking; the predicted class of a node is
    the index of its largest output along dimension 1.

    Args:
        model: Module called as ``model(graph, features)``.
        graph: Graph passed through to the model.
        features: Node features passed through to the model.
        labels: Tensor of target class indices, indexed by ``nid``.
        nid: Index tensor selecting the nodes to score.

    Returns:
        Accuracy as a Python float.
    """
    model.eval()
    with torch.no_grad():
        scores = model(graph, features)[nid]
        targets = labels[nid]
        predicted = scores.max(dim=1)[1]
        n_correct = (predicted == targets).sum().item()
    return float(n_correct) / len(targets)

In [6]:
def load_data(args):
    """Load the dataset named by ``args['dataset']``.

    Args:
        args: Dict of settings; only the ``'dataset'`` key is read. Supported
            values are ``'cora'``, ``'citeseer'``, ``'pubmed'`` and any name
            starting with ``'reddit'`` (self-loops are requested when the name
            contains ``'self-loop'``).

    Returns:
        The dataset object produced by the matching DGL loader.

    Raises:
        ValueError: If the dataset name is ``None`` or not recognised.
    """
    # `args` is a dict throughout this notebook, so the name is read by key
    # everywhere (attribute access would raise AttributeError instead of the
    # intended ValueError below).
    name = args['dataset']

    if name == 'cora':
        return citegrh.load_cora()
    if name == 'citeseer':
        return citegrh.load_citeseer()
    if name == 'pubmed':
        return citegrh.load_pubmed()
    if name is not None and name.startswith('reddit'):
        # Imported here because the notebook's import cell does not bring
        # RedditDataset into scope.
        from dgl.data import RedditDataset
        return RedditDataset(self_loop=('self-loop' in name))

    raise ValueError('Unknown dataset: {}'.format(name))

In [7]:
def main(args):
    """Train a GraphSAGE node classifier and print its accuracy.

    Loads the dataset, builds the model, trains it full-graph with Adam and
    cross-entropy on the training nodes, prints validation accuracy after
    every epoch and test accuracy at the end. Nothing is returned.

    Args:
        args: Dict of settings. Keys read here: ``'gpu'`` (a negative value
            keeps everything on the CPU, otherwise the CUDA device index),
            ``'n_hidden'``, ``'n_layers'``, ``'dropout'``,
            ``'aggregator_type'``, ``'lr'``, ``'weight_decay'`` and
            ``'n_epochs'``. The dict is also handed to ``load_data``.
    """
    # load and preprocess dataset
    print('---------Load Dataset----------\n')

    data = load_data(args)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = data.graph.number_of_edges()
    
    print("""\n\n--------Data statistics--------
    
    Edges   %d
    Classes %d
    
    Train samples %d
    Val samples   %d
    Test samples  %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    cuda = args['gpu'] >= 0
    if cuda:
        torch.cuda.set_device(args['gpu'])
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        print("\n    Cuda in use", args['gpu'])

    # nonzero() yields one row per selected node; squeeze(1) removes only that
    # trailing column axis, so a split containing a single node still gives a
    # 1-D index tensor (a bare squeeze() would collapse it to a scalar).
    train_nid = train_mask.nonzero().squeeze(1)
    val_nid = val_mask.nonzero().squeeze(1)
    test_nid = test_mask.nonzero().squeeze(1)

    # drop self loops and recount the edges used for the throughput figure
    g = dgl.remove_self_loop(g)
    n_edges = g.number_of_edges()
    if cuda:
        g = g.int().to(args['gpu'])

    # create GraphSAGE model
    model = GraphSAGE(in_feats,
                      args['n_hidden'],
                      n_classes,
                      args['n_layers'],
                      F.relu,
                      args['dropout'],
                      args['aggregator_type'])

    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

    # training loop; `dur` collects per-epoch wall-clock times
    dur = []
    print("\n\n--------Training process--------\n")
    for epoch in range(args['n_epochs']):
        model.train()
        # the first three epochs are left out of the timing statistics
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(g, features)
        loss = F.cross_entropy(logits[train_nid], labels[train_nid])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_nid)
        # No timings exist yet during the untimed epochs: report nan directly
        # rather than taking the mean of an empty list (which warns).
        mean_dur = np.mean(dur) if dur else np.nan
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.item(),
                                            acc, n_edges / mean_dur / 1000))

    print()
    acc = evaluate(model, g, features, labels, test_nid)
    print("Test Accuracy {:.4f}".format(acc))

In [8]:
# Model parameters
para = {'aggregator_type': 'gcn',
        'dataset': 'cora',
        'dropout': 0.5,
        # Device index handed to main(): GPU 0 when CUDA is available,
        # otherwise -1, which makes main() stay on the CPU instead of
        # failing on a machine without a GPU.
        'gpu': 0 if torch.cuda.is_available() else -1,
        'lr': 0.01,
        'n_epochs': 200,
        'n_hidden': 16,
        'n_layers': 2,
        'weight_decay': 0.0005}

main(para)

---------Load Dataset----------

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


--------Data statistics--------
    
    Edges   10556
    Classes 7
    
    Train samples 140
    Val samples   500
    Test samples  1000

    Cuda in use 0


--------Training process--------

Epoch 00000 | Time(s) nan | Loss 1.9449 | Accuracy 0.2700 | ETputs(KTEPS) nan
Epoch 00001 | Time(s) nan | Loss 1.9417 | Accuracy 0.4160 | ETputs(KTEPS) nan
Epoch 00002 | Time(s) nan | Loss 1.9346 | Accuracy 0.4020 | ETputs(KTEPS) nan
Epoch 00003 | Time(s) 0.0037 | Loss 1.9234 | Accuracy 0.6400 | ETputs(KTEPS) 2816.48
Epoch 00004 | Time(s) 0.0036 | Loss 1.9180 | Accuracy 0.6880 | ETputs(KTEPS) 2902.33
Epoch 00005 | Time(s) 0.0036 | Loss 1.8980 | Accuracy 0.6100 | ETputs(KTEPS) 2937.18
Epoch 00006 | Time(s) 0.0036 | Loss 1.8871 | Accuracy 0.5800 | ETputs(KTEPS) 2951.67
Epoch 00007 |

Epoch 00128 | Time(s) 0.0031 | Loss 0.2729 | Accuracy 0.7860 | ETputs(KTEPS) 3432.08
Epoch 00129 | Time(s) 0.0031 | Loss 0.2326 | Accuracy 0.7720 | ETputs(KTEPS) 3430.46
Epoch 00130 | Time(s) 0.0031 | Loss 0.2249 | Accuracy 0.7620 | ETputs(KTEPS) 3422.12
Epoch 00131 | Time(s) 0.0031 | Loss 0.2134 | Accuracy 0.7740 | ETputs(KTEPS) 3415.44
Epoch 00132 | Time(s) 0.0031 | Loss 0.3129 | Accuracy 0.7880 | ETputs(KTEPS) 3415.43
Epoch 00133 | Time(s) 0.0031 | Loss 0.2777 | Accuracy 0.7960 | ETputs(KTEPS) 3415.93
Epoch 00134 | Time(s) 0.0031 | Loss 0.2555 | Accuracy 0.7900 | ETputs(KTEPS) 3414.49
Epoch 00135 | Time(s) 0.0031 | Loss 0.2439 | Accuracy 0.7920 | ETputs(KTEPS) 3414.90
Epoch 00136 | Time(s) 0.0031 | Loss 0.2102 | Accuracy 0.7880 | ETputs(KTEPS) 3415.34
Epoch 00137 | Time(s) 0.0031 | Loss 0.2361 | Accuracy 0.7940 | ETputs(KTEPS) 3415.76
Epoch 00138 | Time(s) 0.0031 | Loss 0.2446 | Accuracy 0.7960 | ETputs(KTEPS) 3416.14
Epoch 00139 | Time(s) 0.0031 | Loss 0.2325 | Accuracy 0.7900 | ET