# GCN for Network Anomaly Detection


This notebook introduces a method that uses a [**Graph Convolutional Network**](https://tkipf.github.io/graph-convolutional-networks/) (**GCN**) to detect anomalies in network flows.
Detecting such anomalies could help prevent or stop network attacks.

## Introduction

### Introduction to GCN

The basic idea behind **GCN** is to aggregate, for each node, the features of its neighborhood. By repeating this operation $k$ times for a given node $U$, its resulting features are a combination of the features of the $k$-hop neighborhood of $U$.

![GCN Animation Medium](https://miro.medium.com/max/600/1*wt31DAeKTVeDWmPqKprycw.gif)

This simple yet powerful operation allows us to capture implicit topological information about the node as well as its explicit features. The resulting features can be used to determine the role of a node within the sub-graph it belongs to.

> *Incrementing $k$ increases the diameter of the sub-graph centered on $U$. However, that doesn't necessarily make the model better, since it can lead to [**over-smoothing**](https://arxiv.org/pdf/1812.08434.pdf), where all the nodes in the graph end up with nearly the same features.*

In short, a GCN propagates node features through edges.
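The aggregation idea and the over-smoothing effect can be illustrated with a minimal pure-Python sketch. It assumes a plain mean over each node and its neighbors on a hypothetical 5-node path graph; a real GCN layer also applies learned weights and a normalized adjacency matrix, so this only shows how information spreads.

```python
# Undirected path graph 0 - 1 - 2 - 3 - 4, stored as adjacency lists (toy example).
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}


def aggregate(features, neighbors):
    """One round: each node takes the mean of its own and its neighbors' features."""
    new_features = {}
    for node, nbrs in neighbors.items():
        values = [features[node]] + [features[n] for n in nbrs]
        new_features[node] = sum(values) / len(values)
    return new_features


def propagate(features, neighbors, k):
    """Repeat the aggregation k times."""
    for _ in range(k):
        features = aggregate(features, neighbors)
    return features


# Only node 0 carries a signal at the start.
start = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0}

# After k=2 rounds, the signal has reached exactly the 2-hop neighborhood of node 0.
two_hops = propagate(start, neighbors, k=2)
reached = [node for node, value in two_hops.items() if value > 0]

# After many rounds, every node holds almost the same value: over-smoothing.
many_hops = propagate(start, neighbors, k=200)
spread = max(many_hops.values()) - min(many_hops.values())

print(reached)
print(spread < 1e-6)
```

With a small $k$ the nodes stay distinguishable; with a large $k$ the spread between node features collapses, which is the over-smoothing behaviour mentioned above.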

### GCN for network anomaly detection

#### Network capture to Graph
A network snapshot can be represented as a graph by considering IPs as nodes and packet exchanges between IPs as edges (*e.g. IP_A sends a TCP request to IP_B*).

Below is a visualization of a 60-second network snapshot:

![Network graph representation](./data/network_image.png)
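As a rough sketch of this mapping, the snippet below turns a handful of packet records into node ids and directed edges. The record layout `(source IP, destination IP, size)` and the IP values are hypothetical; real captures would first be parsed from PCAP files.

```python
from collections import defaultdict

# Hypothetical packet records: (source IP, destination IP, size in bytes).
packets = [
    ("10.0.0.1", "10.0.0.2", 60),
    ("10.0.0.1", "10.0.0.2", 1500),
    ("10.0.0.2", "10.0.0.1", 40),
    ("10.0.0.3", "10.0.0.2", 80),
]

# Map each IP to an integer node id, in order of first appearance.
node_id = {}
for src, dst, _ in packets:
    for ip in (src, dst):
        node_id.setdefault(ip, len(node_id))

# Group packets by directed (src, dst) pair: one edge per pair.
edge_packets = defaultdict(list)
for src, dst, size in packets:
    edge_packets[(node_id[src], node_id[dst])].append(size)

# Source and destination id lists, in the (u, v) form that dgl.graph accepts.
src_ids = [u for u, _ in edge_packets]
dst_ids = [v for _, v in edge_packets]

print(node_id)
print(src_ids, dst_ids)
```

Keeping the packets grouped per directed edge makes it straightforward to compute edge features from them afterwards.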

#### Useful graph geometry

The role of some nodes on the network can be guessed just by looking at the graph representation.

> For example, the **DNS** servers are probably the nodes at the center of the big clusters.

We can hence infer that **topological information is meaningful** in those graphs and that a **GCN** could be relevant to detect anomalies, as detailed [here](https://github.com/harvardnlp/botnet-detection).

#### Need for more features

However, **topological information is not sufficient** to detect more elaborate attacks. For that purpose, **we need more features**, for example:

- IP geolocation
- IP reputation score (from websites such as [UrlVoid](https://www.urlvoid.com/), for example)
- IP Service Provider (ISP)

However, this is information we don't necessarily have access to in the public datasets we can use to train our model. More generally, **we don't have access to any IP-related (graph node) information, for privacy reasons**.

#### Edge-based features

Fortunately, we can **interpolate features from packets**, within a certain time window:

- Quantity of bytes (min, mean, median, max, std)
- Number of packets
- Protocol (TCP, UDP, ICMP)
- Frequency
- Connection Probability (that we could generate using another model)
- Exchange history (encoded exchanges history, represented by a vector)
- ...

> That's true for almost all the available public datasets I've worked on: [CTU-13](https://www.stratosphereips.org/datasets-ctu13), [CSE-CIC-IDS2018 (2018)](https://www.unb.ca/cic/datasets/ids-2018.html), [UGR'16 (2016)](https://nesg.ugr.es/nesg-ugr16/index.php).
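As an illustration, here is a minimal sketch of how the byte statistics, packet count and frequency listed above could be computed for one edge. The function name `edge_features` and the reading of "frequency" as packets per second are assumptions made for this example, not something the datasets prescribe.

```python
import statistics


def edge_features(sizes, window_seconds):
    """Summarise the packet sizes (bytes) seen on one edge during one time window."""
    return {
        "n_packets": len(sizes),
        "bytes_min": min(sizes),
        "bytes_mean": statistics.mean(sizes),
        "bytes_median": statistics.median(sizes),
        "bytes_max": max(sizes),
        "bytes_std": statistics.pstdev(sizes),  # population standard deviation
        "frequency": len(sizes) / window_seconds,  # assumed: packets per second
    }


# Hypothetical packet sizes observed from IP_A to IP_B within a 60-second window.
features = edge_features([60, 1500, 40, 400], window_seconds=60)
for name, value in features.items():
    print(f"{name}: {value}")
```

Categorical features such as the protocol would need an encoding (a one-hot vector, for instance) before being concatenated with these values.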

Unfortunately, **GCNs are not designed to handle features on edges**: they require node features, which we cannot provide.

#### Node features interpolation from edges

Even if we had access to private IP (node) information, and therefore had features on the nodes, we could not ignore the information carried by IP interactions (edges).

The issue is that, to the best of my knowledge, there is no research on the subject.

In this notebook, I explore a solution that I've naively named **Edge2Node**, which consists of **interpolating node features as a non-linear combination of the in/out edges**.

### Requirements
At the time of writing this notebook, there are 3 major deep learning graph libraries in Python:

- [Pytorch Geometric](https://github.com/rusty1s/pytorch_geometric)
- [GraphNets](https://github.com/deepmind/graph_nets)
- [Deep Graph Library (DGL)](https://github.com/dmlc/dgl)

There aren't any concrete comparisons between the three yet, so I just went with the one that attracted me the most. Since I'm more used to **PyTorch**, I've ignored **GraphNets**. **PyTorch Geometric** implements a lot of GCN papers; however, it seemed a bit rough to me compared to **DGL**, which appeared to have a well-thought-out pipeline and plans for the future.

In conclusion, I had no real reason to go with **DGL**; it's just intuition.

In [15]:
!pip install --user torch==1.6.0 dgl==0.5.2 networkx==2.4 numpy==1.19.3 matplotlib==3.3.1 tqdm



In [22]:
import dgl
import torch
import networkx as nx
import numpy as np
from tqdm.notebook import tqdm

### Project decomposition

The notebook has required the following steps:

1. Dataset Preparation:
    - Truncate raw PCAP (drop useless data)
    - Slide time Window
    - Extract features for each interaction IP_A to IP_B within window (nb packets, nb bytes sent...)
    - Graph Generation
    
2. GCN Model Design:
    - Generate nodes features from in/out edges (**Edge2Node**)
    - Apply GCN on the graph (using predicted nodes features)
    - Generate edges embedding from new nodes features (**Node2Edge**)
    - Classify nodes and edges using the computed features
    
3. Training:
    - Loss to penalize errors on edge/node classification
    - Basic ML pytorch training loop
    - Model Evaluation
    
However, we will mainly focus on step **2. GCN Model Design** here, since the other steps are just drafts used for the proof of concept.
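For reference, the "slide time window" step of the dataset preparation could look like the sketch below. The record layout, the helper name `sliding_windows` and the 30-second step are assumptions made for illustration; only the 60-second window length echoes the snapshot shown earlier.

```python
def sliding_windows(packets, window, step):
    """Yield (start, packets_in_window) for the windows [start, start + window)."""
    if not packets:
        return
    packets = sorted(packets, key=lambda p: p[0])
    start, last = packets[0][0], packets[-1][0]
    while start <= last:
        yield start, [p for p in packets if start <= p[0] < start + window]
        start += step


# Hypothetical records: (timestamp in seconds, source IP, destination IP, bytes).
packets = [
    (0.0, "A", "B", 60),
    (10.0, "B", "A", 40),
    (65.0, "A", "C", 1500),
    (130.0, "C", "A", 80),
]

snapshots = list(sliding_windows(packets, window=60, step=30))
for start, content in snapshots:
    print(start, len(content))
```

Each snapshot would then be turned into one graph, with the edge features computed from the packets it contains. Because the windows overlap, a packet can appear in more than one snapshot.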

## Edge2Node

**Edge2Node** is the method we'll use to interpolate node features from edge features.

> The idea is to describe a node as the influence it has on its neighbors and the influence they have on itself.

### Proof Of Concept

In this section, we will apply the **Edge2Node** idea to a toy example.

#### Problem definition

Let's suppose we have the following graph where edges model heat transfer between nodes:

![edge2node_image](./data/edge2node_demo.png)

How can we classify the nodes into two categories, **COLD** and **HOT**?

##### Intuitive solution

An intuitive solution is to assume that a node is best described by its contribution to the stabilization of the local system. The feature $F$ of a node $u$ would then be:

$$F(u) = \sum_{v \in \mathcal{N}_{in}(u)}{W(v, u)}  - \sum_{v \in \mathcal{N}_{out}(u)}{W(u, v)}$$

It can be interpreted as: *a node loses the energy that goes through its out-edges and gains the energy that comes from its in-edges*.

##### Solution application

This process is represented by the following animation:

![edge2node_demo](./data/edge2node.gif)

We calculate that the contribution of the node $A$ to its neighborhood is:

$$F(A) = W(C, A) - W(A, B) = -16 - 20 = -36$$

Hence, node $A$ loses a lot of heat: there is a **high probability** that it is a **HOT** node.

By applying the same idea to the other nodes, we have:
- $B$ gains a **bit** of heat (+4), there is a **slightly higher probability** that it is a **COLD** node
- $C$ gains a **lot** of heat (+32), there is a **high probability** that it is a **COLD** node
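The three values can be checked by hand with a few lines of plain Python before moving to DGL. This is only a sketch of the arithmetic, with the edge weights taken from the toy graph ($A \to B = 20$, $B \to C = 12$, $C \to B = -4$, $C \to A = -16$); the helper name `contribution` is made up for the example.

```python
# Directed edges (source, destination) -> heat transfer W, as in the toy graph.
W = {("A", "B"): 20.0, ("B", "C"): 12.0, ("C", "B"): -4.0, ("C", "A"): -16.0}


def contribution(node, W):
    """F(node) = sum of in-edge weights minus sum of out-edge weights."""
    gained = sum(w for (u, v), w in W.items() if v == node)  # in-edges
    lost = sum(w for (u, v), w in W.items() if u == node)    # out-edges
    return gained - lost


F = {node: contribution(node, W) for node in "ABC"}
print(F)  # {'A': -36.0, 'B': 4.0, 'C': 32.0}
```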

#### DGL implementation

In [3]:
class DummyEdgeToNode(torch.nn.Module):
    def forward(self, graph: dgl.DGLGraph, h):
        """
        graph: DGLGraph on which we apply the operation
        h: tensor of all edges features
        """
        h_in  = h  # in-edges  -> energy gain
        h_out = -h # out-edges -> energy loss
        
        # operations on local_scope do not affect the input graph
        with graph.local_scope():
            # shallow copy of the computed `h_in` and `h_out` as edges features
            graph.edata['e_in'] = h_in
            graph.edata['e_out'] = h_out
            
            # 1. Handle IN-EDGES
            # build and apply message passing function to the graph
            graph.update_all(
                # copy in-edges features `e_in` in the mailboxes `n_in` of the nodes
                dgl.function.copy_e('e_in', 'n_in'), 
                # reduce the messages of the mailboxes `n_in` by summing them and store
                # the result into `n_in`
                dgl.function.sum('n_in', 'n_in')
            )
            
            # 2. Handle OUT-EDGES
            # At the time of writing, we can't apply message passing algorithm on out-edges
            # with DGL. The only way I've found is to re-apply the previous logic on the
            # reversed-graph, so the out-edges become in-edges.
            graph = graph.reverse(copy_ndata=True, copy_edata=True)
            graph.update_all(
                dgl.function.copy_e('e_out', 'n_out'), 
                dgl.function.sum('n_out', 'n_out')
            )
            
            return graph.ndata['n_in'] + graph.ndata['n_out']

In [4]:
dummy_e2n = DummyEdgeToNode()

# we create the toy example graph
uv = [0, 1, 2, 2], [1, 2, 1, 0] # directed edges ('A'<=>0, 'B'<=>1, 'C'<=>2)
G = dgl.graph(uv)

# we apply the dummy e2n
edge_features = torch.tensor([20., 12., -4., -16.])
node_features = dummy_e2n(graph=G, h=edge_features).numpy()

## tada!
print(f"A = {node_features[0]}")
print(f"B = {node_features[1]}")
print(f"C = {node_features[2]}")

A = -36.0
B = 4.0
C = 32.0


##### Going Further
This simple idea seems to be enough on a small graph like this one, but what if the graph contained thousands of nodes?

We would need to look further than the direct node neighborhood to truly understand its role on the graph. Now that we have features on the nodes, we can introduce **GCN** to better capture the role of the node within the graph.

### Formula upgrade

The difference between the previously seen **DummyEdgeToNode** and the real one is that the dummy version uses a fixed (hardcoded) combination, which in our case was:

$$F(u) = \sum_{v \in \mathcal{N}_{in}(u)}{W(v, u)} - \sum_{v \in \mathcal{N}_{out}(u)}{W(u, v)}$$


This formula handles different types of edges on different types of graphs the same way, which is not always relevant. There might be cases where the out-edges have less impact than the in-edges, or cases where a given feature of the feature tensor has more impact on out-edges while another one has a negative impact on in-edges.

To handle those differences (which depend on the type of graph, edges and features), we make the formula a bit more sophisticated:

$$F(u) = \theta(\sum_{v \in \mathcal{N}_{in}(u)}{\sigma(\theta_{in}(W(v, u)))} + \sum_{v \in \mathcal{N}_{out}(u)}{\sigma(\theta_{out}(W(u, v)))})$$

Where:
- $\sigma$ is a non-linear function (e.g. ReLU, sigmoid, tanh...)
- $\theta, \theta_{in}, \theta_{out}$ are functions with parameters to be optimized (e.g. a torch `Linear` layer)
- $W(u, v)$ returns the features of the edge $(u,v)$

> There are no restrictions on $\theta, \theta_{in}, \theta_{out}$ (e.g. it can be a deep CNN as it can be a simple linear layer)

It should allow **Edge2Node** to scale to different graphs.
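To see how the upgraded formula relates to the hardcoded one, here is a small sketch with scalar edge features, where $\theta$, $\theta_{in}$, $\theta_{out}$ and $\sigma$ are plain Python functions instead of trainable layers. Choosing the identity for $\theta$, $\theta_{in}$ and $\sigma$ and a negation for $\theta_{out}$ gives back the dummy formula; the second set of functions is an arbitrary, hypothetical choice that only shows how the result changes.

```python
def edge_to_node(node, W, theta_in, theta_out, theta, sigma):
    """F(node) = theta(sum_in sigma(theta_in(W)) + sum_out sigma(theta_out(W)))."""
    s_in = sum(sigma(theta_in(w)) for (u, v), w in W.items() if v == node)
    s_out = sum(sigma(theta_out(w)) for (u, v), w in W.items() if u == node)
    return theta(s_in + s_out)


def identity(x):
    return x


def negate(x):
    return -x


def halve(x):
    return 0.5 * x


def relu(x):
    return max(x, 0.0)


# Toy graph from the proof of concept: (source, destination) -> edge feature.
W = {("A", "B"): 20.0, ("B", "C"): 12.0, ("C", "B"): -4.0, ("C", "A"): -16.0}

# Special case: reproduces the hardcoded "in minus out" combination.
fixed = {n: edge_to_node(n, W, identity, negate, identity, identity) for n in "ABC"}

# Another (hypothetical) parameter choice: in-edges count half, ReLU as sigma.
other = {n: edge_to_node(n, W, halve, negate, identity, relu) for n in "ABC"}

print(fixed)
print(other)
```

In the trainable version, these functions are replaced by layers whose parameters are learned from the data rather than chosen by hand.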

In [5]:
class EdgeToNode(torch.nn.Module):
    def __init__(self, in_features, hid_features, out_features, non_linear=None):
        super().__init__()
        self.linear_in = torch.nn.Linear(in_features, hid_features) # theta_in
        self.linear_out = torch.nn.Linear(in_features, hid_features) # theta_out
        self.linear_final = torch.nn.Linear(hid_features, out_features) # theta
        self.non_linear = non_linear or torch.nn.Identity()
        
    def forward(self, graph: dgl.DGLGraph, h):
        # apply theta_in and theta_out to edges features
        h_in, h_out = self.linear_in(h), self.linear_out(h)
        # apply sigma (non_linear) to `h_in` and `h_out`
        h_in, h_out = self.non_linear(h_in), self.non_linear(h_out)
        
        with graph.local_scope(): 
            graph.edata['e_in'] = h_in
            graph.edata['e_out'] = h_out
            
            graph.update_all(
                dgl.function.copy_e('e_in', 'n_in'),
                dgl.function.sum('n_in', 'n_in')
            ) 
            
            graph = graph.reverse(copy_ndata=True, copy_edata=True) 
            graph.update_all(
                dgl.function.copy_e('e_out', 'n_out'),
                dgl.function.sum('n_out', 'n_out')
            ) 
            
            return self.linear_final(graph.ndata['n_in'] + graph.ndata['n_out'])

### Node2Edge

For edge classification, we need the inverse process of **Edge2Node**. After we've applied a GCN to get meaningful node embeddings, we need to "send" this information to the edges.

This can be done the following way:

$$W(u, v) = \theta( \sigma(\theta_{src}(F(u))) + \sigma(\theta_{dst}(F(v))) )$$

The intuition behind it is the same as for **Edge2Node**.
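A scalar sketch of this formula, again with plain Python functions standing in for the trainable $\theta$, $\theta_{src}$, $\theta_{dst}$ and $\sigma$, is given below. The node values are the ones obtained on the heat-transfer toy graph, and the particular functions are hypothetical choices made for the example.

```python
def node_to_edge(u, v, F, theta_src, theta_dst, theta, sigma):
    """W(u, v) = theta(sigma(theta_src(F(u))) + sigma(theta_dst(F(v))))."""
    return theta(sigma(theta_src(F[u])) + sigma(theta_dst(F[v])))


def identity(x):
    return x


def negate(x):
    return -x


# Node features from the toy example, and its directed edges.
F = {"A": -36.0, "B": 4.0, "C": 32.0}
edges = [("A", "B"), ("B", "C"), ("C", "B"), ("C", "A")]

# With these choices, W(u, v) = F(u) - F(v).
W = {(u, v): node_to_edge(u, v, F, identity, negate, identity, identity)
     for u, v in edges}
print(W)
```

Because the source and destination nodes go through different functions, the edges $(B, C)$ and $(C, B)$ receive different values, so the direction of an edge is preserved.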

In [6]:
class NodeToEdge(torch.nn.Module):
    def __init__(self, in_features, hid_features, out_features, non_linear=None):
        super().__init__()
        self.linear_src = torch.nn.Linear(in_features, hid_features)
        self.linear_dst = torch.nn.Linear(in_features, hid_features)
        self.linear_final = torch.nn.Linear(hid_features, out_features)
        self.non_linear = non_linear or torch.nn.Identity()
        
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
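            # zero edge features, so that `e_add_u` / `e_add_v` below simply copy
            # the source / destination node features onto each edge
            graph.edata['h'] = torch.zeros((graph.number_of_edges(), h.shape[-1]))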
            
            graph.apply_edges(dgl.function.e_add_u('h', 'h', 'h_src'))
            graph.apply_edges(dgl.function.e_add_v('h', 'h', 'h_dst'))
            
            h_src = self.linear_src(graph.edata['h_src'])
            h_dst = self.linear_dst(graph.edata['h_dst'])
            
            h_src = self.non_linear(h_src)
            h_dst = self.non_linear(h_dst)
            
            return self.linear_final(h_src + h_dst)

### Demonstration on synthetic data

We will now try node classification on a **synthetic dataset** to demonstrate the validity of **Edge2Node**. To do so, we:

- Generate a random graph
- Generate, for each node, synthetic features that describe one class among a fixed number of classes
- Generate edge features from the synthetic node features
- Train a model to predict node classes solely based on the edge features

#### Data generation

In [85]:
from sklearn.datasets import make_gaussian_quantiles
def generate_graph(n2e, n_nodes, n_ft_nodes=10, n_ft_edges=None, n_classes=8, cov=3):
    n_ft_edges = n_ft_edges or 2*n_ft_nodes
    
    g = nx.barabasi_albert_graph(n_nodes, 1)
    g = dgl.from_networkx(g)
    
    x, y = make_gaussian_quantiles(
        cov=cov, n_classes=n_classes, 
        n_features=n_ft_nodes, n_samples=n_nodes
    )
    x, y = torch.from_numpy(x).float(), torch.from_numpy(y).long()
    g.ndata['ft'] = x
    g.ndata['y'] = y
    
    with torch.no_grad():
        g.edata['ft'] = n2e(g, g.ndata['ft'])
        
    return g

n_features_nodes = 8
n_features_edges = 32
n_classes = 10
n2e = NodeToEdge(
    in_features=n_features_nodes,
    hid_features=3*n_features_nodes,
    out_features=n_features_edges,
    non_linear=torch.nn.ReLU(),
)

In [86]:
from torch.utils.data import random_split
# generate 200 synthetic graphs
dt = [
    generate_graph(
        n2e, 
        n_nodes=np.random.randint(100, 4000),
        n_ft_nodes=n_features_nodes,
        n_ft_edges=n_features_edges, 
        n_classes=n_classes
    )  for _ in tqdm(range(200))
]

dt_train, dt_val, dt_test = random_split(
    dt, 
    (len(dt) * np.array([0.7, 0.2, 0.1])).astype(int)
)

#### Baseline

Before training and evaluating the **NodePredictor**, we need baselines for comparisons.

- **Random**: the accuracy of such a model is 1/n_classes; for n_classes=10, the accuracy score is 0.1. This gives us the **minimal performance** our model needs to achieve.
- **ML Classifier** trained on the **original node features**. This gives us the **maximal performance** our model should be able to achieve.

> The reason why the **ML Classifier** should give the **maximal performance** possible for our **GCN-based Classifier** is that the class of a node depends solely on its own features. Hence, the **node relationships are meaningless** for the classification. Moreover, the ML Classifier works directly on the **unaltered node features**, while the GCN Classifier has to interpolate them from the edges using **Edge2Node**.
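The 1/n_classes figure for the random baseline can be checked empirically with a short simulation. This is a standalone sketch on made-up labels (the helper name, seeds and sample size are arbitrary), not a run on the synthetic graphs of this notebook.

```python
import random


def random_baseline_accuracy(labels, n_classes, seed=0):
    """Accuracy of a classifier that guesses each class uniformly at random."""
    rng = random.Random(seed)
    hits = sum(rng.randrange(n_classes) == y for y in labels)
    return hits / len(labels)


# Hypothetical labels: 100,000 nodes spread over 10 classes.
label_rng = random.Random(1)
labels = [label_rng.randrange(10) for _ in range(100_000)]

acc = random_baseline_accuracy(labels, n_classes=10)
print(f"random baseline accuracy: {acc:.3f}")
```

Since the guesses are uniform and independent of the labels, the expected accuracy is 1/n_classes even when the classes are not balanced.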

In [65]:
def evaluate(model, forward, dt, criterion):
    losses = []
    corr, tot = 0, 0
    with torch.no_grad():
        model.eval()
        for g in tqdm(dt, desc='eval', leave=False):
            Y_true = g.ndata['y']
            
            Y_pred = forward(model, g) 
            
            loss = criterion(Y_pred, Y_true)
            losses.append(loss.item())
            Y_pred = Y_pred.max(dim=-1)[1]
            
            tot += len(Y_pred)
            corr += (Y_true == Y_pred).sum().item()
            
    return np.mean(losses), corr/tot

def train(model, criterion, forward, dt_train, dt_val, dt_test, nb_epochs=100, freq_show_loss=10):
    optimizer = torch.optim.Adam(model.parameters())
    
    loss_val, acc_val = evaluate(model, forward, dt_val, criterion)
    loss_test, acc_test = evaluate(model, forward, dt_test, criterion)
    print('Before:')
    print(f'\tloss_val: {loss_val:.3f} / acc_val: {acc_val:.3f}')
    print(f'\tloss_test: {loss_test:.3f} / acc_tes: {acc_test:.3f}')
    
    for epoch in tqdm(range(nb_epochs), desc='epoch'):
        losses = []
        corr, tot = 0, 0
        
        model.train()
        pb_training = tqdm(dt_train, desc='train', leave=False)
        for idx, g in enumerate(pb_training):
            Y_true = g.ndata['y']
            Y_pred = forward(model, g)
            
            loss = criterion(Y_pred, Y_true)
            losses.append(loss.item())
            
            Y_pred = Y_pred.max(dim=-1)[1]
            tot += len(Y_pred)
            corr += (Y_pred == Y_true).sum().item()
            
            if idx % freq_show_loss == 0:
                pb_training.set_description(f"l:{np.mean(losses):.3f}, acc:{corr/tot:.3f}")
            
            # weights optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        loss_val, acc_val = evaluate(model, forward, dt_val, criterion)
        print(f"epoch {epoch:03d}: loss_val:{loss_val:.3f}, acc: {acc_val:.3f}")
              
    loss_val, acc_val = evaluate(model, forward, dt_val, criterion)
    loss_test, acc_test = evaluate(model, forward, dt_test, criterion)
    print('After:')
    print(f'\tloss_val: {loss_val:.3f} / acc_val: {acc_val:.3f}')
    print(f'\tloss_test: {loss_test:.3f} / acc_tes: {acc_test:.3f}')

In [87]:
model = torch.nn.Sequential(
    torch.nn.Linear(n_features_nodes, n_features_nodes * 2),
    torch.nn.ReLU(),
    torch.nn.Linear(n_features_nodes * 2, n_features_nodes * 2),
    torch.nn.ReLU(),
    torch.nn.Linear(n_features_nodes * 2, n_classes),
)

criterion = torch.nn.CrossEntropyLoss()
train(model, criterion, lambda f,g: f(g.ndata['ft']), dt_train, dt_val, dt_test, nb_epochs=20)

Before:
	loss_val: 2.344 / acc_val: 0.104
	loss_test: 2.345 / acc_tes: 0.104

epoch 000: loss_val:2.205, acc: 0.127
epoch 001: loss_val:2.003, acc: 0.202
epoch 002: loss_val:1.770, acc: 0.343
epoch 003: loss_val:1.554, acc: 0.424
epoch 004: loss_val:1.391, acc: 0.470
epoch 005: loss_val:1.279, acc: 0.501
epoch 006: loss_val:1.202, acc: 0.519
epoch 007: loss_val:1.149, acc: 0.530
epoch 008: loss_val:1.110, acc: 0.537
epoch 009: loss_val:1.081, acc: 0.541
epoch 010: loss_val:1.059, acc: 0.547
epoch 011: loss_val:1.040, acc: 0.551
epoch 012: loss_val:1.026, acc: 0.554
epoch 013: loss_val:1.013, acc: 0.558
epoch 014: loss_val:1.002, acc: 0.561
epoch 015: loss_val:0.993, acc: 0.564
epoch 016: loss_val:0.984, acc: 0.566
epoch 017: loss_val:0.975, acc: 0.569
epoch 018: loss_val:0.967, acc: 0.573
epoch 019: loss_val:0.957, acc: 0.577

After:
	loss_val: 0.957 / acc_val: 0.577
	loss_test: 0.962 / acc_tes: 0.576


#### GCN Model definition

In [66]:
import torch.nn.functional as F

# Simple two-layer GraphSAGE model (GCN of depth 2)
class SAGE(torch.nn.Module):
    def __init__(self, in_features, hid_features, out_features):
        super().__init__()
        self.conv1 = dgl.nn.SAGEConv(
            in_feats=in_features, 
            out_feats=hid_features,
            aggregator_type='mean'
        )
        self.conv2 = dgl.nn.SAGEConv(
            in_feats=hid_features, 
            out_feats=out_features,
            aggregator_type='mean'
        )

    def forward(self, graph, h):    
        h = F.relu(self.conv1(graph, h))
        h = self.conv2(graph, h)
        return h
    
    
# Node class prediction
class NodePredictor(torch.nn.Module):
    def __init__(self, n_classes, in_features_e, out_features_e=16, hid_features_n=32, out_features_n=16):
        super().__init__()
        
        self.e2n = EdgeToNode(
            in_features=in_features_e, 
            hid_features=hid_features_n, 
            out_features=out_features_e, 
            non_linear=torch.nn.ReLU()
        )
        self.sage = SAGE(out_features_e, hid_features_n, out_features_n)
        self.n2y = torch.nn.Linear(out_features_n, n_classes)
        
    def forward(self, graph, h):    
        h = self.e2n(graph, h)
        h = torch.tanh(h)
        h = self.sage(graph, h)
        h = torch.relu(h)
        
        return self.n2y(h)

In [91]:
model = NodePredictor(n_classes, n_features_edges)
criterion = torch.nn.CrossEntropyLoss()
print(f"Un-trained model precision: {evaluate(model, lambda f, g: f(g, g.edata['ft']), dt_test, criterion)[1]:.3f}")
train(model, criterion, lambda f, g: f(g, g.edata['ft']), dt_train, dt_val, dt_test, nb_epochs=20)
print(f"Trained model precision: {evaluate(model, lambda f, g: f(g, g.edata['ft']), dt_test, criterion)[1]:.3f}")

Un-trained model precision: 0.105

Before:
	loss_val: 2.420 / acc_val: 0.103
	loss_test: 2.419 / acc_tes: 0.105

epoch 000: loss_val:2.128, acc: 0.193
epoch 001: loss_val:1.601, acc: 0.343
epoch 002: loss_val:1.489, acc: 0.381
epoch 003: loss_val:1.415, acc: 0.408
epoch 004: loss_val:1.366, acc: 0.422
epoch 005: loss_val:1.331, acc: 0.435
epoch 006: loss_val:1.307, acc: 0.443
epoch 007: loss_val:1.297, acc: 0.448
epoch 008: loss_val:1.305, acc: 0.446
epoch 009: loss_val:1.292, acc: 0.451
epoch 010: loss_val:1.259, acc: 0.462
epoch 011: loss_val:1.245, acc: 0.466
epoch 012: loss_val:1.239, acc: 0.468
epoch 013: loss_val:1.236, acc: 0.471
epoch 014: loss_val:1.227, acc: 0.471
epoch 015: loss_val:1.220, acc: 0.474
epoch 016: loss_val:1.212, acc: 0.476
epoch 017: loss_val:1.216, acc: 0.475
epoch 018: loss_val:1.209, acc: 0.478
epoch 019: loss_val:1.210, acc: 0.479

After:
	loss_val: 1.210 / acc_val: 0.479
	loss_test: 1.199 / acc_tes: 0.484

Trained model precision: 0.484


#### Observation

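The **GCN-based** classifier (test accuracy 0.484) performs clearly better than random draws (0.1) and is not too far from the baseline classifier trained on the original node features (0.576).

We can therefore conclude that **Edge2Node** isn't totally irrelevant: node classes can be partly recovered from edge features alone.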
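One possible way to quantify "not too far" is to measure which share of the gap between the random baseline and the node-feature classifier is recovered by the GCN-based model. The metric below is an ad hoc reading chosen for this sketch, using the test accuracies reported above.

```python
random_acc = 0.1      # 1 / n_classes with n_classes = 10
baseline_acc = 0.576  # test accuracy of the classifier trained on node features
gcn_acc = 0.484       # test accuracy of the Edge2Node + GCN classifier

# Share of the achievable improvement over random that the GCN model reaches.
recovered = (gcn_acc - random_acc) / (baseline_acc - random_acc)
print(f"{recovered:.3f}")  # 0.807
```

On this single run, the model working from edge features alone reaches roughly 80% of the improvement over random that the classifier with direct access to node features obtains.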