# Lab 6: GNNs (Graph Neural Networks)

In this lab, we will build Graph Neural Networks (GNNs).

Reference:
- [Graph Neural Networks: A review of methods and applications](https://arxiv.org/ftp/arxiv/papers/1812/1812.08434.pdf)
- [A Gentle Introduction to Graph Neural Networks (Basics, DeepWalk, and GraphSage)](https://towardsdatascience.com/a-gentle-introduction-to-graph-neural-network-basics-deepwalk-and-graphsage-db5d540d50b3)
- [Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric](https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8)
- [Graph Neural Network (GNN): What It Is and How to Use It](https://builtin.com/data-science/gnn)
- [An Introduction to Graph Neural Network(GNN) For Analysing Structured Data](https://towardsdatascience.com/an-introduction-to-graph-neural-network-gnn-for-analysing-structured-data-afce79f4cfdc)

## What are GNNs?

Graph Neural Networks (GNNs) were first introduced in 2009 ([cite](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.1015.7227&rep=rep1&type=pdf)). A GNN is a neural model that can be applied directly to graphs without prior knowledge of each component. GNNs provide a convenient way to perform node-level, edge-level, and graph-level prediction tasks.

GNNs have recently received a lot of attention because of their ability to analyze graph-structured data, which models a set of objects and their relationships.

To incorporate graph-structured information in the data processing step, the underlying graph-structured data is encoded using the topological relationships between the nodes of the graph. This class of techniques includes recursive neural networks and Markov chains, which are commonly used to solve graph-focused and node-focused problems.

## Graph Theory

A graph is a data structure made up of two parts: **vertices ($V$)** and **edges ($E$)**. It is a mathematical structure that is used to examine the pair-wise relationship between objects and entities.

A graph ($G$) is defined as
$$G=(V,E)$$
where
- $V$ is the set of nodes, also called vertices; the two terms are interchangeable.
- $E$ is the set of edges between them. Edges can be directed or undirected, depending on whether there are directional dependencies between vertices.

<img src="img/GNNSimpleGraph.webp" title="A simple Graph" style="width: 150px;" />

A graph is often represented by an adjacency matrix ($A$). If a graph has $N$ nodes, then $A$ has dimension $N\times N$, and $A_{ij}=1$ if there is an edge from node $i$ to node $j$ and $0$ otherwise. For an undirected graph, $A$ is symmetric.

<img src="img/AdjacencyMatrix.PNG" title="Adjacency matrix" style="width: 600px;" />

A **feature matrix** is sometimes also provided to describe the nodes in the graph. If each node has $F$ features, the feature matrix $X$ has dimension $N \times F$.

<img src="img/FeatureMatrix.PNG" title="Feature matrix" style="width: 600px;" />

We can compute $H = AX$. Row $v$ of $H$ is the sum of the feature vectors of the nodes connected to node $v$, so this single matrix multiplication aggregates each node's neighborhood.

<img src="img/multiMatrix.PNG" title="Summation matrix" style="width: 600px;" />
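To make $A$, $X$, and $H = AX$ concrete, here is a small NumPy sketch on a made-up 4-node undirected graph (the edges and feature values are illustrative, not taken from the figures):

```python
import numpy as np

# Hypothetical undirected graph with edges 0-1, 0-2, 1-3.
edges = [(0, 1), (0, 2), (1, 3)]
N = 4

A = np.zeros((N, N), dtype=int)
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1  # undirected: store both directions, so A is symmetric

# Feature matrix X: one row of F = 2 features per node (made-up values).
X = np.array([[1, 0],
              [0, 1],
              [2, 2],
              [3, 1]])

H = A @ X  # row v of H = sum of the features of v's neighbors
print(H)
```

Each node aggregates only its neighbors here; adding self-loops ($A + I$) before multiplying would also include each node's own features.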

## Why are graphs so difficult?

1. A graph does not exist in Euclidean space, so it cannot be represented by any of the coordinate systems we are familiar with. This makes graph data much more difficult to interpret when compared to other types of data such as waves, images, or time-series signals, all of which can be mapped to a 2-D or 3-D space.
2. Graphs do not have a fixed shape. Consider the following example. Graphs A and B look completely different from one another, but when converted to adjacency matrix representation, they have the same adjacency matrix (if we ignore the weights of the edges). In other words, they are the same graph drawn in two different ways.

<table><tr>
<td> <img src="img/GraphA.webp" title="A" style="width: 150px;" /> </td>
<td> <img src="img/GraphB.webp" title="B" style="width: 150px;" /> </td>
</tr></table>

3. In general, graphs are difficult to visualize for human interpretation. This is not about small graphs, but about massive graphs with hundreds or thousands of nodes. Humans struggle to understand the graph when the dimension is very high and the nodes are densely grouped. As a result, training a machine for this task is difficult.

<img src="img/circuitnetlist.webp" title="B" style="width: 300px;" />

## Advantages of Graphs

Graphs are used for a variety of reasons:

1. Graphs are a better way to represent abstract concepts such as relationships and interactions. They also provide a visually intuitive approach to thinking about these concepts. Graphs are also a natural starting point for analyzing relationships in a social context.
2. Graphs can solve more complex problems by reducing them to simpler representations or by transforming them into representations from various perspectives.
3. Graph theories and concepts are used to investigate and model Social Networks, Fraud patterns, Power consumption patterns, Virality, and Influence in Social Media. The most well-known application of Graph Theory for Data Science is probably Social Network Analysis (SNA).

## Traditional graph analysis methods

1. Searching algorithms, e.g. Breadth-First Search (BFS), Depth-First Search (DFS)
2. Shortest path algorithms, e.g. Dijkstra’s algorithm, Nearest Neighbour
3. Spanning-tree algorithms, e.g. Prim’s algorithm
4. Clustering methods, e.g. Highly Connected Components, k-means

The limitation of such algorithms is that we need prior knowledge of the graph before we can apply them. They do not let us learn from the graph itself. Most importantly, they cannot perform graph-level classification.
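For contrast with learned methods, here is a minimal breadth-first search in plain Python (the adjacency list is a hypothetical example). It follows fixed rules and learns nothing from the graph:

```python
from collections import deque


def bfs(adj, source):
    """Return nodes in breadth-first order and their hop distance from source."""
    dist = {source: 0}
    order = [source]
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:          # first visit gives the shortest hop count
                dist[v] = dist[u] + 1
                order.append(v)
                queue.append(v)
    return order, dist


adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
order, dist = bfs(adj, 0)
```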

## Graph Neural Network

A Graph Neural Network is a type of neural network that operates directly on the graph structure. Node classification is a common application of GNNs: each node in the graph has a label, and we want to predict the labels of the nodes whose ground truth is unknown.

There are mainly three types of graph neural networks in the literature:

1. Recurrent Graph Neural Network
2. Spatial Convolutional Network
3. Spectral Convolutional Network

### GNN function

Each node $v$ in the node classification problem is characterized by its feature $x_v$ and associated with a ground-truth label $t_v$. The goal is to use the labeled nodes in a partially labeled graph $G$ to predict the labels of the unlabeled nodes. The GNN learns to represent each node with a $d$-dimensional vector (state) $h_v$ that contains information about its neighborhood. Specifically,

$$h_v=f(x_v,x_{co[v]},h_{ne[v]},x_{ne[v]})$$

where $x_{co[v]}$ refers to the features of the edges connecting with $v$, $h_{ne[v]}$ refers to the embedding of $v$'s neighboring nodes, and $x_{ne[v]}$ refers to the features of $v$'s neighboring nodes. The transition function $f$ is responsible for projecting these inputs onto a d-dimensional space.

Since we are looking for a unique solution for $h_v$, we can apply the Banach fixed-point theorem and rewrite the above equation as an iterative update process. This operation is also known as message passing or neighborhood aggregation.

$$H^{t+1}=F(H^t,X)$$

$H$ and $X$ denote the concatenation of all the $h$ and $x$, respectively.
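The iteration $H^{t+1}=F(H^t,X)$ can be illustrated with a toy NumPy sketch. It assumes $F(H, X) = \tanh(A_{norm} H W + X U)$, where $A_{norm}$ is the symmetrically normalized adjacency matrix and the random weights $W$ and $U$ are fixed, with $W$ scaled so that $F$ is a contraction. In the original GNN, $f$ is a learned network trained under a contraction constraint; the graph, weights, and the name `F_update` here are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical undirected 4-cycle 0-1-2-3-0.
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
deg = A.sum(axis=1)
A_norm = A / np.sqrt(np.outer(deg, deg))  # D^{-1/2} A D^{-1/2}, spectral norm <= 1

N, F, d = 4, 3, 2
X = rng.normal(size=(N, F))
U = rng.normal(size=(F, d))
W0 = rng.normal(size=(d, d))
W = 0.4 * W0 / np.linalg.norm(W0, 2)  # spectral norm 0.4 -> F_update is a contraction


def F_update(H, X):
    # New state of each node from its own features and its neighbors' states.
    return np.tanh(A_norm @ H @ W + X @ U)


H = np.zeros((N, d))
for t in range(100):
    H_next = F_update(H, X)
    delta = np.max(np.abs(H_next - H))
    H = H_next
    if delta < 1e-10:  # successive states agree: H is (numerically) the fixed point
        break
```

Because the map is a contraction, the starting value of $H$ does not affect the state it converges to.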

### Output

The output of the GNN is computed by passing the state $h_v$ as well as the feature $x_v$ to an output function $g$.

$$o_v=g(h_v,x_v)$$

Both $f$ and $g$ here can be interpreted as feed-forward fully-connected Neural Networks. 

### Loss function

A simple loss over the $p$ labeled nodes is the L1 (absolute error) loss:

$$\mathcal{L}_1=\sum_{i=1}^p|t_i-o_i|$$

which can be optimized via gradient descent.
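As a small numeric sketch of this loss (targets, outputs, and the labeled mask below are made-up values), only the $p$ labeled nodes contribute:

```python
import numpy as np

# Targets t_v and outputs o_v for 5 nodes (made-up numbers).
t = np.array([1.0, 0.0, 1.0, 0.0, 1.0])
o = np.array([0.8, 0.3, 0.6, 0.1, 0.9])
# Only the first p = 3 nodes are labeled; unlabeled nodes are excluded.
labeled = np.array([True, True, True, False, False])

l1 = np.abs(t[labeled] - o[labeled]).sum()  # |0.2| + |-0.3| + |0.4|
```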

### Original GNN limitation
There are three main limitations:

1. If the assumption of "fixed point" is relaxed, Multi-layer Perceptron can be used to learn a more stable representation while eliminating the iterative update process. This is because different iterations of the original proposal use the same parameters of the transition function f, whereas different parameters in different layers of MLP allow for hierarchical feature extraction.
2. It cannot process edge information (for example, different edges in a knowledge graph may indicate different relationships between nodes).
3. Fixed point can discourage node distribution diversification and thus may be unsuitable for learning to represent nodes.

Several GNN variants have been proposed to address these issues. However, they are not covered here, as they are not the focus of this lab.

### DeepWalk

Reference: http://www.perozzi.net/publications/14_kdd_deepwalk.pdf

DeepWalk is the first algorithm to propose learning node embeddings in an unsupervised manner. Its training closely resembles word embedding. The motivation is that the frequency with which nodes appear in short random walks and the frequency of words in a corpus both follow a power law, as illustrated in the figure below:

<img src="img/deepwalk.webp" title="http://www.perozzi.net/publications/14_kdd_deepwalk.pdf" style="width: 500px;" />

The algorithm contains two steps:

1. Perform random walks on nodes in a graph to generate node sequences
2. Run skip-gram to learn the embedding of each node based on the node sequences generated in step 1

At each time step of the random walk, the next node is sampled uniformly from the neighbors of the current node. Each sequence is then truncated into sub-sequences of length $2|w| + 1$, where $w$ denotes the window size in skip-gram.
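The two DeepWalk ingredients described above, uniform random walks and the window of $w$ nodes on each side of a center node, can be sketched in plain Python. This is an illustrative sketch, not the reference implementation; the graph, walk length, and the helper names `random_walk` and `skipgram_pairs` are hypothetical. The resulting (center, context) pairs are what a skip-gram model would be trained on.

```python
import random


def random_walk(adj, start, length, rng):
    """Uniform random walk: each next node is drawn uniformly from the current node's neighbors."""
    walk = [start]
    while len(walk) < length:
        neighbors = adj[walk[-1]]
        if not neighbors:
            break  # dead end (isolated node)
        walk.append(rng.choice(neighbors))
    return walk


def skipgram_pairs(walk, w):
    """(center, context) pairs within w positions on each side, i.e. spans of length 2w + 1."""
    pairs = []
    for i, center in enumerate(walk):
        for j in range(max(0, i - w), min(len(walk), i + w + 1)):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


adj = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}
rng = random.Random(42)
walks = [random_walk(adj, v, length=6, rng=rng) for v in adj for _ in range(2)]
pairs = skipgram_pairs(walks[0], w=2)
```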

You can learn more about skip-gram in the [link](https://towardsdatascience.com/word-embedding-with-word2vec-and-fasttext-a209c1d3e12c)

Hierarchical softmax is applied to avoid the costly softmax computation caused by the huge number of nodes. To compute the softmax value of any single output element, we must compute $e^{x_k}$ for every element $k$:

$$softmax(x)_i = \frac{e^{x_i}}{\sum_{k=1}^K e^{x_k}}$$

Therefore, the computation time is $O(|V|)$ for the original softmax, where $V$ denotes the set of vertices in the graph.

To solve this problem, hierarchical softmax uses a binary tree whose leaves $(v_1, v_2, \dots)$ are the vertices of the graph. Each inner node holds a binary classifier that decides which branch to take. The probability of a given vertex $v_k$ is the product of the branch probabilities along the path from the root to the leaf $v_k$. Because the probabilities of each inner node's two children sum to one, the probabilities of all vertices still sum to one. Since the path from the root to any leaf of a balanced tree has length $O(\log|V|)$, the computation time drops to $O(\log|V|)$.

<img src="img/HierarchicalSoftmax.webp" title="http://www.perozzi.net/publications/14_kdd_deepwalk.pdf" style="width: 500px;" />
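The key property of hierarchical softmax, that the leaf probabilities still sum to one while each one costs only $\log_2|V|$ sigmoid evaluations, can be checked with a small NumPy sketch. It assumes a complete binary tree with $|V| = 8$ leaves stored in heap order and one randomly initialized classifier vector per inner node; word2vec's implementation uses a Huffman tree instead, so treat this as an illustration of the idea only.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
num_leaves = 8   # |V| = 8 vertices; a power of two keeps the tree complete
dim = 4
# Heap layout: root = 1, children of n are 2n and 2n + 1, leaves are 8..15.
# Inner nodes 1..7 each get one classifier vector (row 0 is unused).
theta = rng.normal(size=(num_leaves, dim))
context = rng.normal(size=dim)  # embedding of the input vertex


def leaf_probability(leaf, theta, context):
    """Multiply the branch probabilities along the root-to-leaf path (log2(8) = 3 steps)."""
    node = leaf + num_leaves  # heap index of the leaf
    prob = 1.0
    while node > 1:
        parent = node // 2
        p_left = sigmoid(theta[parent] @ context)
        prob *= p_left if node % 2 == 0 else 1.0 - p_left  # even index = left child
        node = parent
    return prob


probs = np.array([leaf_probability(v, theta, context) for v in range(num_leaves)])
```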

After training a DeepWalk GNN, the model has learned a good representation of each node, as shown in the figure below. In the input graph, different colors represent different labels. We can see in the output graph (embedding with two dimensions) that nodes with the same labels are clustered together, whereas most nodes with different labels are properly separated.

<img src="img/Deepwalkresult.webp" title="http://www.perozzi.net/publications/14_kdd_deepwalk.pdf" style="width: 500px;" />

The main problem with DeepWalk is that it lacks generalization ability. When a new node is added, the model must be re-trained to represent this node (transductive). As a result, such GNN is unsuitable for dynamic graphs with constantly changing nodes.

## GraphSage

Reference: https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf

GraphSage solves this problem by learning the embedding of each node inductively. In particular, each node is represented by an aggregation of its neighbors' embeddings. As a result, even if a node that was not seen during training appears in the graph, it can still be properly represented through its neighbors.

<img src="img/GraphSageAlgo.webp" title="http://www.perozzi.net/publications/14_kdd_deepwalk.pdf" style="width: 500px;" />

The outer loop runs over the update iterations, and $h^k_v$ is the latent vector of node $v$ at iteration $k$. At each iteration, $h^k_v$ is updated from the latent vectors of $v$ and of $v$'s neighborhood at the previous iteration, using an aggregation function and a weight matrix $W^k$. The paper proposes three aggregation functions:

1. **Mean aggregator**: takes the average of the latent vectors of a node and all its neighborhood.

$$h_v^k\leftarrow \sigma(W\cdot \text{MEAN}(\{h_v^{k-1}\}\cup \{h_u^{k-1},\forall u \in \mathcal{N}(v)\}))$$

Compared with the original update (line 5 of the pseudocode above), the mean aggregator removes the concatenation operation. That concatenation can be viewed as a “skip connection”, which the paper later shows substantially improves the model's performance.

2. **LSTM aggregator**: LSTMs are not permutation invariant, and the neighbors of a node have no natural order, so the authors apply the LSTM to a random permutation of the node's neighbors.

3. **Pooling aggregator**: This operator applies an element-wise pooling function to the neighboring set. An example with max-pooling:

$$\text{AGGREGATE}_k^{pool}= \max(\{\sigma(W_{pool}h_{u_i}^k+b),\forall u_i \in \mathcal{N}(v)\})$$

The max can be replaced with mean-pooling or any other symmetric pooling function. The paper reports that the pooling aggregator performs best, with the mean-pooling and max-pooling variants performing similarly; max-pooling is the paper's default aggregation function.
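The mean and max-pooling aggregators can be sketched in a few lines of NumPy. This is a simplified sketch: it assumes ReLU as the nonlinearity $\sigma$, omits the bias in the mean aggregator, uses identity weights in the example, and leaves out the concatenation and normalization steps of the full GraphSAGE update. The helper names are hypothetical, and max-pooling assumes every node has at least one neighbor.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def sage_mean(H, neighbors, W):
    """Mean aggregator: h_v <- sigma(W . MEAN({h_v} U {h_u : u in N(v)}))."""
    return np.stack([relu(W @ H[[v] + nbrs].mean(axis=0))
                     for v, nbrs in enumerate(neighbors)])


def sage_max_pool(H, neighbors, W_pool, b):
    """Pooling aggregator: element-wise max of sigma(W_pool h_u + b) over u in N(v)."""
    return np.stack([relu(H[nbrs] @ W_pool.T + b).max(axis=0)
                     for nbrs in neighbors])


H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # latent vectors h^{k-1}
neighbors = [[1], [0, 2], [1]]                       # path graph 0 - 1 - 2
W = np.eye(2)
h_mean = sage_mean(H, neighbors, W)
h_pool = sage_max_pool(H, neighbors, W, np.zeros(2))
```

Because mean and max are symmetric functions, reordering a node's neighbor list leaves both outputs unchanged, which is why the LSTM aggregator needs random permutations while these do not.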

### Loss function

The loss function is defined as the following:

$$J_{\mathcal{G}}(z_u)=-\log(\sigma(z_u^\top z_v))-Q\cdot \mathbb{E}_{v_n\sim P_n(v)}\log(\sigma(-z_u^\top z_{v_n}))$$

where $u$ and $v$ co-occur in a fixed-length random walk, $v_n$ are negative samples drawn from the negative sampling distribution $P_n$ that do not co-occur with $u$, and $Q$ is the number of negative samples. This loss encourages nearby nodes to have similar embeddings and pushes distant nodes apart in the embedding space. In this way, the nodes gain more and more information about their neighborhoods.
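For a single positive pair $(u, v)$ and $Q$ negative samples, the loss can be written directly in NumPy. This sketch estimates the expectation by averaging over the sampled negatives, so $Q \cdot \mathbb{E}[\cdot]$ becomes a sum over them; the function name `sage_unsup_loss` is hypothetical.

```python
import numpy as np


def log_sigmoid(z):
    # Numerically stable log(sigma(z)) = -log(1 + e^{-z}).
    return -np.logaddexp(0.0, -z)


def sage_unsup_loss(z_u, z_v, z_neg):
    """J(z_u) = -log sigma(z_u . z_v) - Q * mean_n log sigma(-z_u . z_{v_n}).

    z_neg holds the Q negative-sample embeddings, one per row.
    """
    Q = z_neg.shape[0]
    pos = log_sigmoid(z_u @ z_v)
    neg = log_sigmoid(-(z_neg @ z_u)).mean()
    return -pos - Q * neg
```

For example, with $z_u = 0$ every term equals $\log 2$, so three negatives give a loss of $4\log 2$; and aligning $z_v$ with $z_u$ lowers the loss relative to pointing it the other way.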

By aggregating its nearby nodes, GraphSage generates representable embedding for unseen nodes. It enables the application of node embedding to domains involving dynamic graphs, where the structure of the graph is constantly changing.

## Recurrent Graph Neural Network

RecGNN is built on the Banach fixed-point theorem, which states:

***Let $(X,d)$ be a complete metric space and let $T:X\rightarrow X$ be a contraction mapping. Then $T$ has a unique fixed point $x^*$, and for any $x\in X$ the sequence $T^n(x)$ converges to $x^*$ as $n \rightarrow \infty$.***

This means that if we apply the mapping $T$ to $x$ repeatedly, then after enough iterations $x^k$ is almost equal to $x^{k-1}$:

$$x^k=T(x^{k-1}),\quad k=1,\dots,n$$
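A minimal numeric illustration of the theorem: $T(x) = \cos x$ maps $\mathbb{R}$ into $[-1, 1]$ and is a contraction there, so iterating it from any starting point converges to the same fixed point (about 0.739). The helper `iterate_to_fixed_point` is hypothetical:

```python
import math


def iterate_to_fixed_point(T, x0, tol=1e-12, max_iter=1000):
    """Apply x^k = T(x^{k-1}) until successive iterates differ by less than tol."""
    x = x0
    for k in range(1, max_iter + 1):
        x_next = T(x)
        if abs(x_next - x) < tol:
            return x_next, k
        x = x_next
    raise RuntimeError("did not converge within max_iter iterations")


x_star, steps = iterate_to_fixed_point(math.cos, 1.0)
```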

RecGNN defines a parameterized function $f_w$:

$$x_n=f_w(l_n,l_{co[n]},x_{ne[n]},l_{ne[n]})$$

$l_n$, $l_{co[n]}$, $x_{ne[n]}$, and $l_{ne[n]}$ represent the features of the current node $n$, the features of its edges, the states of its neighboring nodes, and the features of its neighboring nodes, respectively.

<img src="img/7_gnn.png" title="The Graph Neural Network Model paper" style="width: 500px;" />

Finally, after $k$ iterations, the graph neural network model makes use of the final node state to produce an output in order to make a decision about each node. The output function is defined as:

$$o_n=g_w(x_n,l_n)$$

## Spatial Convolution network

The spatial convolutional network is similar to the convolutional neural networks (CNNs) that dominate the image classification and segmentation literature. In a nutshell, convolution on an image is a weighted sum of the pixels around a center pixel, where the weights come from a filter of fixed size with learnable parameters. A spatial convolutional network applies the same idea: it aggregates the features of neighboring nodes into the center node.

<img src="img/9_gnn.png" title="Left: Convolution on a regular graph such as an image. Right: Convolution on the arbitrary graph structure. | Image: A Comprehensive Survey on Graph Neural Networks" style="width: 500px;" />

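## Spectral Convolutional Networks

Compared with other types of GNN, spectral convolutional networks have a solid mathematical foundation. They are based on graph signal processing theory, together with simplifications and approximations of graph convolution. With a Chebyshev polynomial approximation, graph convolution can be reduced to:

$$g_{\theta'}\ast x \approx \sum_{k=0}^K \theta'_kT_k(\tilde{L})x$$

where $T_k$ is the Chebyshev polynomial of order $k$, $\theta'$ is a vector of Chebyshev coefficients, and $\tilde{L}=\frac{2}{\lambda_{max}}L-I_N$ is the graph Laplacian $L$ rescaled by its largest eigenvalue $\lambda_{max}$.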

After further simplification Kipf and Welling suggest a two-layered neural network structure, described as:

$$Z=f(X,A)=\mathit{softmax}(\hat{A}\mathit{Relu}(\hat{A}XW^{(0)})W^{(1)})$$

$\hat{A}=\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ is the normalized adjacency matrix with self-loops, where $\tilde{A}=A+I_N$ and $\tilde{D}$ is the degree matrix of $\tilde{A}$. This formula may look familiar if you have some experience in machine learning: apart from the multiplications by $\hat{A}$, it is just two fully connected layers. Here, however, it serves as graph convolution.
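The two-layer model can be sketched as a forward pass in NumPy. This is an untrained sketch with random weights; the graph, layer sizes, and helper names are made up for illustration, and $\hat{A}$ is computed with the renormalization described above:

```python
import numpy as np


def normalize_adjacency(A):
    """A_hat = D~^{-1/2} (A + I) D~^{-1/2}."""
    A_tilde = A + np.eye(A.shape[0])
    d = A_tilde.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt


def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def gcn_forward(X, A, W0, W1):
    """Z = softmax(A_hat ReLU(A_hat X W0) W1)."""
    A_hat = normalize_adjacency(A)
    H = np.maximum(A_hat @ X @ W0, 0.0)  # first layer + ReLU
    return softmax(A_hat @ H @ W1)       # second layer + row-wise softmax


rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 0],
              [0, 1, 0, 0]], dtype=float)  # star graph centered at node 1
X = rng.normal(size=(4, 3))    # 4 nodes, 3 input features
W0 = rng.normal(size=(3, 8))   # hidden size 8
W1 = rng.normal(size=(8, 2))   # 2 output classes
Z = gcn_forward(X, A, W0, W1)  # one class distribution per node
```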

### Spectral vs spatial convolutional networks

Despite having different starting points, spectral and spatial convolutional networks follow the same propagation rule. All currently available convolutional graph neural networks share this format: they learn a function that passes node information around and update node states through this message-passing process. Any of them can be expressed as a message-passing neural network with
 - message-passing function: $M_t: m^{l+1}=M_t(H^l,A)$
 - node update function: $U_t: H^{l+1}=U_t(H^l,m^{l+1})$
 - readout function: $R: y=R(H^L)$, where $H^L$ denotes the node states after the last layer

## What GNN can do?

GNNs can solve three types of tasks:

- Node classification
- Link prediction
- Graph classification

### Node classification

The goal of node classification is to predict the label of each node in a graph. This type of problem is typically trained in a semi-supervised way, with only part of the graph labeled. Citation networks, Reddit posts, YouTube videos, and Facebook friendships are common node classification applications.

### Link prediction

The task of link prediction is to understand the relationship between entities in graphs and predict whether two entities are connected. A recommender system, for example, can be modeled as a link prediction problem. When we feed the model a collection of user reviews of various products, the task is to predict the users' preferences and tune the recommender system to push more relevant products based on the users' interests.

### Graph classification

The goal of graph classification is to assign entire graphs to different categories. It is similar to image classification, but the target is a graph. There are numerous industrial problems where graph classification can be used; for example, in chemistry, biomedicine, or physics, we can provide the model with a molecular structure and ask it to classify the target into meaningful categories. The model then speeds up the analysis of atoms, molecules, and other structured data.

## GNN Applications

### Natural Language Processing (NLP)

GNNs are frequently used in natural language processing (NLP), which is where GNNs got their start. If you've worked with NLP, you're probably thinking that text is a type of sequential or temporal data that we can describe with a recurrent neural network (RNN) or a long short-term memory (LSTM) network. GNNs, on the other hand, approach the problem from an entirely different perspective: they predict categories by exploiting the inner relationships between words or documents. A citation network, for example, attempts to predict each paper's label based on the citation relationships between papers and the words cited in other papers. GNNs can also build a syntactic model by examining different parts of a sentence rather than only working sequentially as RNNs or LSTMs do.

### Computer Vision

Many CNN-based methods have achieved cutting-edge performance in image object detection, but they do not tell us the relationships between the detected objects. Using graphs to model relationships between objects detected by a CNN-based detector is one successful application of GNNs in computer vision (CV). After objects are detected in an image, they are fed into a GNN for relationship prediction, which produces a graph that models the relationships between the objects.

<img src="img/16_gnn.png" title="" style="width: 500px;" />

Image generation from graph descriptions is another intriguing CV application. This is almost the inverse of the preceding application. Text-to-image generation using generative adversarial network (GAN) or autoencoder is the traditional method of image generation. Rather than using text to describe images, graph-to-image generation provides more information about the semantic structures of the images.

<img src="img/17_gnn.png" title="" style="width: 500px;" />

The most intriguing application is zero-shot learning (ZSL), which learns to classify an object with no training samples from the target class. Without training samples, the model must reason in order to recognize a target. Assume we're given three images and told to find an okapi among them. We've never seen an okapi before, but if we're also told that an okapi is a deer-faced animal with four legs and zebra-striped skin, it's not difficult to figure out which one is the okapi. To simulate this thought process, the detected features are typically converted into text. However, text encodings are independent of one another, so the relationships between the text descriptions are difficult to model. Graph representations, on the other hand, model these relationships explicitly and help the machine reason more like a human.

<img src="img/18_gnn.png" title="" style="width: 500px;" />

### Other applications

Human behavior detection, traffic control, molecular structure study, recommender systems, program verification, logical reasoning, social influence prediction, and adversarial attack prevention are some of the more practical applications of GNN. Through social network analysis, for example, GNN can be used to cluster people into different community groups.

## GNN in code

We now implement GNNs using PyTorch and PyTorch Geometric (PyG), a Graph Neural Network library built on top of PyTorch. According to benchmarks published by the PyG authors, it ran several times faster than DGL, another well-known GNN framework.

### PyTorch Geometric Basics (PyG)

#### Data
The *torch_geometric.data* module contains a Data class that allows you to create graphs from your data very easily. You only need to specify:

1. the attributes/features associated with each node
2. the connectivity/adjacency of each node (edge index)

Let’s use the following graph to demonstrate how to create a Data object:

<img src="img/GNNEx01.PNG" title="" style="width: 500px;" />

There are 4 nodes in the graph, $A \dots D$, each associated with a 2-dimensional feature vector and a label y indicating its class. Both can be represented as float tensors, as shown in the code after the installation step.

#### Install PyG

Please follow the [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html).

In [None]:
!pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

In [1]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu102.html


In [2]:
import torch
from torch_geometric.data import Data

x = torch.tensor([[2,1], [5,6], [3,7], [12,0]], dtype=torch.float)
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)

The graph connectivity (edge index) should be given in COO format: the first list contains the indices of the source nodes, and the second list contains the indices of the target nodes.

In [3]:
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                           [1, 0, 1, 3, 2]], dtype=torch.long)

Note that the order of the edges in edge_index does not matter for the Data object you create, since the edges define the same adjacency matrix either way. Therefore, the edge_index above expresses the same information as the following one.

In [4]:
edge_index = torch.tensor([[0, 2, 1, 0, 3],
                           [3, 1, 0, 1, 2]], dtype=torch.long)
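One way to confirm that both edge orders encode the same graph is to scatter each `edge_index` into a dense adjacency matrix. This sketch uses plain PyTorch indexing and a hypothetical helper `to_dense_adjacency`; PyG ships a comparable utility, `torch_geometric.utils.to_dense_adj`.

```python
import torch


def to_dense_adjacency(edge_index, num_nodes):
    """Scatter COO edges (row 0 = sources, row 1 = targets) into an N x N 0/1 matrix."""
    A = torch.zeros(num_nodes, num_nodes)
    A[edge_index[0], edge_index[1]] = 1.0
    return A


edge_index_a = torch.tensor([[0, 1, 2, 0, 3],
                             [1, 0, 1, 3, 2]], dtype=torch.long)
edge_index_b = torch.tensor([[0, 2, 1, 0, 3],
                             [3, 1, 0, 1, 2]], dtype=torch.long)

A_a = to_dense_adjacency(edge_index_a, 4)
A_b = to_dense_adjacency(edge_index_b, 4)
```

Note that the edges are directed: the edge 2 → 1 sets `A[2, 1]`, but not `A[1, 2]`.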

In [5]:
data = Data(x=x, y=y, edge_index=edge_index)

### Dataset
The dataset creation procedure is not very straightforward, but it may feel familiar to those who have used torchvision, as PyG follows its conventions. PyG provides two dataset classes, InMemoryDataset and Dataset. As their names suggest, the former is for data that fit in RAM, while the latter is for much larger data. Since their implementations are quite similar, we only cover InMemoryDataset.

To create an InMemoryDataset object, there are 4 functions you need to implement:

- *raw_file_names()*: Returns a list of raw, unprocessed file names. If you only have one file, the returned list should contain a single element. In fact, you can simply return an empty list and specify your file later in *process()*.
- *processed_file_names()*: Returns a list of the file names of all the processed data, which *process()* writes. Usually, the list has only one element: the name of the single processed data file.
- *download()*: Downloads the data you are working on to the directory specified in self.raw_dir. If you don’t need to download data, simply put <code>pass</code> in the function body.
- *process()*: The most important method of the dataset. You gather your data into a list of Data objects, then call *self.collate()* to compute the slices that the DataLoader object will use.

In [None]:
import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load the collated tensors written by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Files expected in self.raw_dir (placeholder names).
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`; leave as `pass` if nothing needs downloading.
        pass

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Merge all graphs into one big Data object plus the slices needed to split it.
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

### DataLoader

The DataLoader class allows you to feed data by batch into the model effortlessly. To create a DataLoader object, you simply specify the Dataset and the batch size you want.

In [None]:
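from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG releases

loader = DataLoader(dataset, batch_size=512, shuffle=True)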

Every iteration of a DataLoader object yields a Batch object, which is very much like a Data object but with an additional attribute, “batch”. It records which graph each node belongs to. Since a DataLoader concatenates x, y, and edge_index from different samples/graphs into one Batch, the GNN model needs this “batch” vector to know which nodes belong to the same graph when performing computations within a batch.

In [None]:
for batch in loader:
    print(batch)
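To see how the `batch` vector is used, here is a plain-PyTorch sketch that averages node features per graph, a simple graph-level readout. The two toy graphs are made up; PyG provides this operation as `torch_geometric.nn.global_mean_pool(x, batch)`.

```python
import torch

# Two small graphs concatenated into one batch: graph 0 has 3 nodes, graph 1 has 2.
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
batch = torch.tensor([0, 0, 0, 1, 1])  # graph id of every node, like Batch.batch
num_graphs = int(batch.max()) + 1

# Sum node features per graph, then divide by the node counts.
sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1).to(x.dtype)
graph_repr = sums / counts  # one feature vector per graph
```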

### MessagePassing

Message passing is the essence of GNNs: it describes how node embeddings are learned. It was introduced above, so here we only briefly restate it in terms that match the PyG documentation.

$$x_i^{(k)}= \gamma^{(k)}(x_i^{(k-1)}, \square_{j \in \mathcal{N}(i)}\phi^{(k)}(x_i^{(k-1)},x_j^{(k-1)},e_{i,j}))$$

$x$ denotes the node embeddings, $e$ the edge features, $\phi$ the message function, $\square$ the aggregation function, and $\gamma$ the update function. If the edges in the graph have no features other than connectivity, $e$ is essentially the edge index of the graph. The superscript is the layer index; when $k=1$, $x^{(k-1)}=x^{(0)}$ is the input feature of each node. The PyG methods involved work as follows:

- <code>propagate(edge_index, size=None, **kwargs)</code>: Takes the edge index and other optional information, such as node features (embeddings). Calling this function in turn calls message and update.
- <code>message(**kwargs)</code>: Specifies how the “message” is constructed for each node pair $(x_i, x_j)$. Since it is called from propagate, it can take any argument passed to propagate. Note that you can map an argument to a specific node of the pair by appending “_i” (the central, target node) or “_j” (the neighboring, source node) to its name, so name this function's arguments carefully.
- <code>update(aggr_out, **kwargs)</code>: Takes the aggregated message and other arguments passed to propagate, and assigns a new embedding value to each node.
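The call flow above can be mimicked without PyG. The following NumPy sketch is a toy re-implementation for illustration only, not PyG's code: the `propagate` signature here is hypothetical, and only "add" and "max" aggregation are supported. Messages flow from the source `edge_index[0]` to the target `edge_index[1]`, so `x_j` is the neighbor and `x_i` the node being updated.

```python
import numpy as np


def propagate(x, edge_index, message, aggr, update):
    """Toy version of the propagate -> message -> aggregate -> update flow."""
    src, dst = edge_index
    x_i, x_j = x[dst], x[src]   # per-edge views, like PyG's *_i / *_j arguments
    msgs = message(x_i, x_j)    # one message per edge
    out = np.full((x.shape[0], msgs.shape[1]), -np.inf if aggr == "max" else 0.0)
    if aggr == "max":
        np.maximum.at(out, dst, msgs)  # element-wise max over incoming messages
        out[np.isinf(out)] = 0.0       # nodes without incoming edges get zeros
    else:
        np.add.at(out, dst, msgs)      # "add" aggregation
    return update(out, x)


x = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]])
edge_index = np.array([[0, 1, 2],
                       [1, 2, 1]])  # edges 0->1, 1->2, 2->1
out = propagate(x, edge_index,
                message=lambda x_i, x_j: x_j,             # message = neighbor feature
                aggr="add",
                update=lambda aggr_out, x: aggr_out + x)  # add own features back
```

With `aggr="max"`, a message of ReLU(linear(`x_j`)), and an update that concatenates and projects, this flow corresponds to the SAGEConv layer below.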

### SageConv

Now, we implement [SageConv](https://arxiv.org/abs/1706.02216). The message passing formula of SageConv is defined as:

$$h_{\mathcal{N}(v)}^k \leftarrow \text{AGGREGATE}_k(\{h_u^{k-1},\forall u \in \mathcal{N}(v)\})$$
$$h_v^k \leftarrow \sigma(W^k \cdot \text{CONCAT}(h_v^{k-1},h_{\mathcal{N}(v)}^k))$$

Max pooling is used as the aggregation method. Therefore, the right-hand side of the first line can be written as:
$$\max(\{\sigma(W_{pool}h_{u_i}^k+b), \forall u_i \in \mathcal{N}(v)\})$$

which illustrates how the “message” is constructed: each neighboring node embedding is multiplied by a weight matrix, shifted by a bias, and passed through an activation function.

For the update step, the aggregated message and the current node embedding are concatenated. The result is then multiplied by another weight matrix and passed through another activation function.

The SAGEConv layer class is:

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops


class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "Max" aggregation.
        # Message: sigma(W_pool h_u + b) applied to every neighbor embedding.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        # Update: W . CONCAT(aggregated message, current embedding), no bias.
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Add self-loops so each node's own embedding also enters the max-pool.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]
        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        return x_j  # shape [E, out_channels]

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]; x has shape [N, in_channels]
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        return new_embedding  # shape [N, out_channels]

### Example RecSys Challenge 2015

The RecSys Challenge 2015 is challenging data scientists to build a session-based recommender system. Participants in this challenge are asked to solve two tasks:

1. Predict whether there will be a buy event followed by a sequence of clicks
2. Predict which item will be bought

We can download the data from the official website of [RecSys Challenge 2015 and construct a Dataset in Kaggle](https://www.kaggle.com/datasets/chadgostopp/recsys-challenge-2015).

The challenge provides two main data files, yoochoose-clicks.dat and yoochoose-buys.dat, containing click events and buy events, respectively.

#### Preprocessing

After downloading the data, we preprocess it so that it can be fed to our model. The item_id values are label-encoded so that the encoded IDs, which will later index an embedding matrix, start at 0.

In [6]:
import numpy as np
import pandas as pd
import pickle
import csv
import os
import torch
from torch_geometric.data import Data
from tqdm import tqdm

In [7]:
df = pd.read_csv('./archive/yoochoose-clicks.dat', header=None)
df.columns=['session_id','timestamp','item_id','category']

df.head()



|   | session_id | timestamp | item_id | category |
|---|---|---|---|---|
| 0 | 1 | 2014-04-07T10:51:09.277Z | 214536502 | 0 |
| 1 | 1 | 2014-04-07T10:54:09.868Z | 214536500 | 0 |
| 2 | 1 | 2014-04-07T10:54:46.998Z | 214536506 | 0 |
| 3 | 1 | 2014-04-07T10:57:00.306Z | 214577561 | 0 |
| 4 | 2 | 2014-04-07T13:56:37.614Z | 214662742 | 0 |


In [8]:
buy_df = pd.read_csv('./archive/yoochoose-buys.dat', header=None)
buy_df.columns=['session_id','timestamp','item_id','price','quantity']

buy_df.head()

|   | session_id | timestamp | item_id | price | quantity |
|---|---|---|---|---|---|
| 0 | 420374 | 2014-04-06T18:44:58.314Z | 214537888 | 12462 | 1 |
| 1 | 420374 | 2014-04-06T18:44:58.325Z | 214537850 | 10471 | 1 |
| 2 | 281626 | 2014-04-06T09:40:13.032Z | 214535653 | 1883 | 1 |
| 3 | 420368 | 2014-04-04T06:13:28.848Z | 214530572 | 6073 | 1 |
| 4 | 420368 | 2014-04-04T06:13:28.858Z | 214835025 | 2617 | 1 |


In [9]:
# keep only sessions with more than 2 clicks (drops sessions of length <= 2)
df['valid_session'] = df.session_id.map(df.groupby('session_id')['item_id'].size() > 2)
df = df.loc[df.valid_session].drop('valid_session',axis=1)
df.nunique()

session_id     4431931
timestamp     24590089
item_id          48255
category           331
dtype: int64
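Note that the condition `> 2` keeps only sessions with at least three clicks, so two-click sessions are dropped as well. A quick check on made-up data, using the equivalent `groupby(...).transform('size')` idiom (an alternative we chose for illustration, not what the cell above uses):

```python
import pandas as pd

# toy click log (made-up data)
clicks = pd.DataFrame({
    "session_id": [1, 1, 1, 2, 2, 3, 3, 3, 3],
    "item_id":    [10, 11, 10, 20, 21, 30, 31, 32, 33],
})

# number of clicks in each row's session, broadcast back to every row
session_len = clicks.groupby("session_id")["item_id"].transform("size")
# same condition as the cell above: keep sessions with more than 2 clicks
filtered = clicks[session_len > 2]
print(sorted(filtered.session_id.unique().tolist()))  # [1, 3]
```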

Since the data is quite large, we subsample it for easier demonstration.

In [10]:
# randomly sample 1,000,000 sessions (no seed is set, so the sample differs per run)
sampled_session_id = np.random.choice(df.session_id.unique(), 1000000, replace=False)
df = df.loc[df.session_id.isin(sampled_session_id)]
df.nunique()

session_id    1000000
timestamp     5557628
item_id         37216
category          257
dtype: int64

In [11]:
# average length of session 
df.groupby('session_id')['item_id'].size().mean()

5.559688

In [12]:
from sklearn.preprocessing import LabelEncoder

item_encoder = LabelEncoder()
category_encoder = LabelEncoder()
df['item_id'] = item_encoder.fit_transform(df.item_id)
df['category'] = category_encoder.fit_transform(df.category.apply(str))
df.head()

|    | session_id | timestamp | item_id | category |
|---|---|---|---|---|
| 0 | 1 | 2014-04-07T10:51:09.277Z | 1612 | 0 |
| 1 | 1 | 2014-04-07T10:54:09.868Z | 1611 | 0 |
| 2 | 1 | 2014-04-07T10:54:46.998Z | 1613 | 0 |
| 3 | 1 | 2014-04-07T10:57:00.306Z | 7173 | 0 |
| 49 | 19 | 2014-04-01T20:52:12.357Z | 5759 | 0 |
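`LabelEncoder` maps the sorted distinct values to the codes `0 .. k-1`. As a sketch of what that means, the same codes can be obtained with `np.unique(..., return_inverse=True)`; the toy IDs below are illustrative.

```python
import numpy as np

raw_item_ids = np.array([214536502, 214536500, 214536506, 214536500])

# np.unique sorts the distinct values; return_inverse gives each element's
# position in that sorted array, i.e. a 0-based categorical code
classes, codes = np.unique(raw_item_ids, return_inverse=True)
print(classes)  # [214536500 214536502 214536506]
print(codes)    # [1 0 2 0]
```

Because the codes are dense and start at 0, `num_items = df.item_id.max() + 1` is exactly the number of rows the item embedding table needs.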


In [13]:
# .copy() avoids pandas' SettingWithCopyWarning on the assignment below
buy_df = buy_df.loc[buy_df.session_id.isin(df.session_id)].copy()
buy_df['item_id'] = item_encoder.transform(buy_df.item_id)
buy_df.head()

Unnamed: 0,session_id,timestamp,item_id,price,quantity
0,420374,2014-04-06T18:44:58.314Z,1901,12462,1
1,420374,2014-04-06T18:44:58.325Z,1893,10471,1
3,420368,2014-04-04T06:13:28.848Z,776,6073,1
4,420368,2014-04-04T06:13:28.858Z,29347,2617,1
26,70427,2014-04-02T15:54:07.144Z,26493,3769,1


In [14]:
buy_item_dict = dict(buy_df.groupby('session_id')['item_id'].apply(list))
buy_item_dict

{87: [12996, 31056, 22447, 27395, 26550, 27422, 26275],
 189: [10178],
 197: [24949],
 277: [30016, 30017],
 319: [9294, 9294],
 484: [16786, 16786],
 507: [21558, 6043],
 593: [26726, 7935],
 612: [23235],
 651: [21394, 21394, 30613],
 708: [8081, 16369],
 873: [26492],
 899: [19622],
 966: [2924, 21418],
 1112: [26267],
 1372: [27326, 27329, 27327, 27328],
 1457: [26722],
 1562: [21599],
 1713: [27455, 27429],
 1834: [30766, 30996],
 1962: [27458, 27455, 27429, 27467, 29279, 26734, 27457],
 2071: [27378, 27322, 27441, 27330],
 2101: [26402],
 2103: [27422, 21517, 27406, 27408],
 2256: [30766, 30994, 30996, 30995],
 2377: [30766],
 2409: [18174, 26483, 26496, 26482],
 2426: [27464, 27469],
 2553: [26720, 26768, 6522],
 2666: [19404, 18795],
 2834: [21687],
 2859: [20098],
 2998: [7935],
 3074: [28911, 21030, 17360, 17360],
 3191: [16785, 31206],
 3264: [20099],
 3456: [26351, 27416],
 3474: [19635, 19635],
 3517: [20943, 24949],
 3624: [18528],
 3811: [26722, 29279],
 ...

In [15]:
df.head()

|    | session_id | timestamp | item_id | category |
|---|---|---|---|---|
| 0 | 1 | 2014-04-07T10:51:09.277Z | 1612 | 0 |
| 1 | 1 | 2014-04-07T10:54:09.868Z | 1611 | 0 |
| 2 | 1 | 2014-04-07T10:54:46.998Z | 1613 | 0 |
| 3 | 1 | 2014-04-07T10:57:00.306Z | 7173 | 0 |
| 49 | 19 | 2014-04-01T20:52:12.357Z | 5759 | 0 |


To determine the ground truth, i.e. whether there is any buy event for a given session, we simply check whether a session_id in yoochoose-clicks.dat also appears in yoochoose-buys.dat.

In [16]:
df['label'] = df.session_id.isin(buy_df.session_id)
df.head()

|    | session_id | timestamp | item_id | category | label |
|---|---|---|---|---|---|
| 0 | 1 | 2014-04-07T10:51:09.277Z | 1612 | 0 | False |
| 1 | 1 | 2014-04-07T10:54:09.868Z | 1611 | 0 | False |
| 2 | 1 | 2014-04-07T10:54:46.998Z | 1613 | 0 | False |
| 3 | 1 | 2014-04-07T10:57:00.306Z | 7173 | 0 | False |
| 49 | 19 | 2014-04-01T20:52:12.357Z | 5759 | 0 | False |
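The `label` column above is a session-level label, while the dataset built next assigns a label to each item (node) using `buy_item_dict`. A small sketch on made-up data contrasting the two views:

```python
import pandas as pd

# toy clicks and buys (made-up data)
clicks = pd.DataFrame({"session_id": [1, 1, 2, 2, 3],
                       "item_id":    [5, 6, 7, 8, 9]})
buys = pd.DataFrame({"session_id": [2], "item_id": [8]})

# session-level ground truth: was this session followed by any purchase?
clicks["label"] = clicks.session_id.isin(buys.session_id)
print(clicks.label.tolist())  # [False, False, True, True, False]

# item-level view: which items were bought in each session (like buy_item_dict)
buy_items = buys.groupby("session_id")["item_id"].apply(list).to_dict()
print(buy_items)
```

In session 2 both clicks get `True` at the session level, but only item 8 would become a positive node in the graph dataset.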


#### Dataset Construction

After preprocessing, the data is ready to be transformed into a Dataset object. Here, we treat each item in a session as a node, so all items in the same session form a graph, with a directed edge from each clicked item to the item clicked next. Each node is labeled 1 if that item was bought in the session (according to buy_item_dict) and 0 otherwise. To build the dataset, we group the preprocessed data by session_id and iterate over these groups. In each iteration, the item_id values in the group are label-encoded again, because node indices within each graph must start from 0. Thus, we have the following:

In [17]:
import torch
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm

class YooChooseDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(YooChooseDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []
    @property
    def processed_file_names(self):
        return ['../archive/yoochoose_click_binary_100000_sess.dataset']

    def download(self):
        pass
    
    def process(self):
        
        data_list = []

        # process by session_id
        grouped = df.groupby('session_id')
        for session_id, group in tqdm(grouped):
            le = LabelEncoder()
            sess_item_id = le.fit_transform(group.item_id)
            group = group.reset_index(drop=True)
            group['sess_item_id'] = sess_item_id
            node_features = group.loc[group.session_id==session_id,['sess_item_id','item_id','category']].sort_values('sess_item_id')[['item_id','category']].drop_duplicates().values

            node_features = torch.LongTensor(node_features).unsqueeze(1)
            target_nodes = group.sess_item_id.values[1:]
            source_nodes = group.sess_item_id.values[:-1]

            edge_index = torch.tensor([source_nodes,
                                   target_nodes], dtype=torch.long)
            x = node_features

            if session_id in buy_item_dict:
                positive_indices = le.transform(buy_item_dict[session_id])
                label = np.zeros(len(node_features))
                label[positive_indices] = 1
            else:
                label = [0] * len(node_features)


            y = torch.FloatTensor(label)

            data = Data(x=x, edge_index=edge_index, y=y)

            data_list.append(data)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
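The core of `process()` can be sketched for a single session in plain NumPy. `session_to_graph` is a hypothetical helper, not part of PyTorch Geometric; one deliberate difference is that `np.isin` silently ignores bought items that were never clicked in the session, whereas `le.transform` in the cell above would raise an error for them.

```python
import numpy as np

def session_to_graph(item_ids, bought_items):
    """Turn one session's click sequence into (node_items, edge_index, y).

    Nodes are the distinct items of the session, indexed from 0 in sorted
    order (as LabelEncoder does); edges link consecutive clicks.
    """
    node_items, sess_idx = np.unique(item_ids, return_inverse=True)
    # directed edge from each click to the next one in the sequence
    edge_index = np.stack([sess_idx[:-1], sess_idx[1:]])
    # node label is 1 if that item was bought in this session
    y = np.isin(node_items, bought_items).astype(np.float32)
    return node_items, edge_index, y

items, edge_index, y = session_to_graph([1612, 1611, 1613, 1611], bought_items=[1613])
print(items)       # [1611 1612 1613]
print(edge_index)
print(y)           # [0. 0. 1.]
```

Item 1611 is clicked twice but becomes a single node (index 0), so the edges `1 -> 0`, `0 -> 2`, `2 -> 0` revisit it.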
        

In [18]:
dataset = YooChooseDataset('./')

In [19]:
dataset = dataset.shuffle()
one_tenth_length = int(len(dataset) * 0.1)
train_dataset = dataset[:one_tenth_length * 8]
val_dataset = dataset[one_tenth_length*8:one_tenth_length * 9]
test_dataset = dataset[one_tenth_length*9:]
len(train_dataset), len(val_dataset), len(test_dataset)

(800000, 100000, 100000)
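The split uses whole tenths of the dataset, so any remainder from `int(len(dataset) * 0.1)` ends up in the test slice. A tiny sketch of that arithmetic (`split_sizes` is a hypothetical helper for illustration):

```python
def split_sizes(n, train_tenths=8, val_tenths=1):
    """Sizes of the train/val/test slices used above for a dataset of n graphs."""
    tenth = int(n * 0.1)
    train_end = tenth * train_tenths
    val_end = tenth * (train_tenths + val_tenths)
    # the test slice takes everything after val_end, including any remainder
    return train_end, val_end - train_end, n - val_end

print(split_sizes(1_000_000))  # (800000, 100000, 100000)
print(split_sizes(1_005))      # (800, 100, 105)
```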

In [20]:
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
batch_size= 512
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
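The DataLoader batches graphs by stacking them into one large disconnected graph: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes before it, and `data.batch` records which graph every node belongs to (this is the vector `Net.forward` uses for pooling). A simplified NumPy sketch of that idea; `collate` is a hypothetical helper, and the real collation also handles other attributes such as `y`.

```python
import numpy as np

def collate(graphs):
    """Merge (x, edge_index) pairs into one disjoint graph plus a batch vector."""
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # shift node indices so they point at this graph's rows in the merged x
        edges.append(edge_index + offset)
        batch.append(np.full(len(x), graph_id))
        offset += len(x)
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

g1 = (np.array([[1.0], [2.0]]), np.array([[0], [1]]))
g2 = (np.array([[3.0], [4.0], [5.0]]), np.array([[0, 1], [1, 2]]))
x, edge_index, batch = collate([g1, g2])
print(edge_index.tolist())  # [[0, 2, 3], [1, 3, 4]]
print(batch)                # [0 0 1 1 1]
```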



In [21]:
num_items = df.item_id.max() +1
num_categories = df.category.max()+1
num_items , num_categories

(37216, 256)

#### Build a Graph Neural Network

In [22]:
embed_dim = 128
from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv, SAGEConv, SGConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = GraphConv(embed_dim * 2, 128)
        self.pool1 = TopKPooling(128, ratio=0.9)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.9)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.9)
        self.item_embedding = torch.nn.Embedding(num_embeddings=num_items, embedding_dim=embed_dim)
        self.category_embedding = torch.nn.Embedding(num_embeddings=num_categories, embedding_dim=embed_dim)        
        self.lin1 = torch.nn.Linear(256, 256)
        self.lin2 = torch.nn.Linear(256, 128)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()        
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        item_id = x[:,:,0]
        category = x[:,:,1]
        

        emb_item = self.item_embedding(item_id).squeeze(1)
        emb_category = self.category_embedding(category).squeeze(1)
        
        x = torch.cat([emb_item, emb_category], dim=1)  # [N, 2 * embed_dim]
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
     
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.act2(x)      
        
        outputs = []
        for i in range(x.size(0)):
            output = torch.matmul(emb_item[data.batch == i], x[i,:])

            outputs.append(output)
              
        x = torch.cat(outputs, dim=0)
        x = torch.sigmoid(x)
        
        return x
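Two pieces of `forward` are easier to see on toy arrays: the readout, which concatenates per-graph max pooling and mean pooling (what `gmp` and `gap` compute), and the final loop, which scores every node by the dot product of its item embedding with its session's embedding. The NumPy sketch below uses hypothetical helper names and replaces the loop with indexing by the batch vector; this should be equivalent because the nodes of each graph are stored contiguously in a batch. In PyTorch the same idea would read `(emb_item * x[data.batch]).sum(dim=1)`.

```python
import numpy as np

def readout(x, batch, num_graphs):
    """Concatenate per-graph max pooling and mean pooling, like gmp/gap."""
    dim = x.shape[1]
    mx = np.full((num_graphs, dim), -np.inf)
    np.maximum.at(mx, batch, x)
    sums = np.zeros((num_graphs, dim))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return np.concatenate([mx, sums / counts], axis=1)

def score_items(emb_item, session_emb, batch):
    """Per-node sigmoid(<item embedding, session embedding>) without a Python loop."""
    logits = np.sum(emb_item * session_emb[batch], axis=1)
    return 1.0 / (1.0 + np.exp(-logits))

batch = np.array([0, 0, 1])                   # nodes 0, 1 -> graph 0; node 2 -> graph 1
feats = np.array([[1.0, 2.0], [3.0, 0.0], [5.0, 5.0]])
pooled = readout(feats, batch, num_graphs=2)  # rows: [max | mean] per graph

emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sess = np.array([[2.0, 0.0], [0.0, 0.0]])
scores = score_items(emb, sess, batch)
# the loop form used in Net.forward, for comparison
loop = np.concatenate([emb[batch == i] @ sess[i] for i in range(2)])
print(np.allclose(scores, 1.0 / (1.0 + np.exp(-loop))))  # True
```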

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
crit = torch.nn.BCELoss()

#### Training

Training our custom GNN is straightforward: we iterate over the DataLoader constructed from the training set and back-propagate the loss. We use Adam as the optimizer with a learning rate of 0.001 and binary cross-entropy as the loss function.

In [24]:
def train():
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)

        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
    return loss_all / len(train_dataset)
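For reference, with default settings `torch.nn.BCELoss` averages the per-node binary cross-entropy $-\frac{1}{N}\sum_{n}\left[y_n \log p_n + (1 - y_n)\log(1 - p_n)\right]$. A NumPy sketch of that formula; it clips probabilities with a small `eps` (our own choice), whereas PyTorch instead clamps the log terms at -100, so the two only differ for extreme predictions.

```python
import numpy as np

def bce(pred, target, eps=1e-12):
    """Mean binary cross-entropy over all nodes."""
    pred = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

# a confident correct positive and a mostly correct negative
print(round(bce(np.array([0.9, 0.2]), np.array([1.0, 0.0])), 4))  # 0.1643
```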

#### Validation

The labels are highly unbalanced, with an overwhelming majority of negatives, since most sessions are not followed by any buy event. In other words, a naive model that predicts all negatives would already exceed 90% accuracy. Therefore, instead of accuracy, the Area Under the ROC Curve (AUC) is a better metric for this task, as it only measures whether positive examples are scored higher than negative examples. We use the off-the-shelf `roc_auc_score` function from scikit-learn.

In [25]:
from sklearn.metrics import roc_auc_score
def evaluate(loader):
    model.eval()

    predictions = []
    labels = []

    with torch.no_grad():
        for data in loader:

            data = data.to(device)
            pred = model(data).detach().cpu().numpy()

            label = data.y.detach().cpu().numpy()
            predictions.append(pred)
            labels.append(label)

    predictions = np.hstack(predictions)
    labels = np.hstack(labels)
    
    return roc_auc_score(labels, predictions)
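AUC can be read as the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting one half. The sketch below computes it directly from all positive/negative pairs (`pairwise_auc` is a hypothetical helper; it uses O(P x N) memory and is meant only to illustrate the definition, not to replace `roc_auc_score`) and contrasts it with the accuracy of always predicting "negative" on data with 10% positives.

```python
import numpy as np

def pairwise_auc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

labels = np.array([0] * 9 + [1])           # 10% positives
scores = np.linspace(0.1, 1.0, 10)         # the positive gets the highest score
print(np.mean(labels == 0))                # 0.9  accuracy of "always negative"
print(pairwise_auc(labels, scores))        # 1.0
print(pairwise_auc(labels, np.zeros(10)))  # 0.5  constant scores rank nothing
```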

#### Result
We train the model for one epoch and measure the training, validation, and test AUC scores:

In [30]:
for epoch in range(1, 2):
    loss = train()
    # evaluate() returns ROC AUC, not accuracy
    train_auc = evaluate(train_loader)
    val_auc = evaluate(val_loader)
    test_auc = evaluate(test_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train AUC: {:.5f}, Val AUC: {:.5f}, Test AUC: {:.5f}'.
          format(epoch, loss, train_auc, val_auc, test_auc))

ValueError: Encountered invalid 'dim_size' (got '2372' but expected >= '4294967297')

## Exercise

Dr. Matt, please decide :-)

Maybe we can try Vision GNN: the [paper (pdf)](https://arxiv.org/pdf/2206.00272.pdf), the [code](https://github.com/huawei-noah/Efficient-AI-Backbones), and the [MindSpore model](https://gitee.com/mindspore/models).