# Gratis

Gratis is a novel graph representation learning framework that produces a graph representation with a task-specific topology and multi-dimensional edge features to describe an arbitrary graph or non-graph data, including pre-defined graphs, texts, audio, images, etc.

In this tutorial, we first introduce the preliminaries and core methods of Gratis. We then describe, step by step, how Gratis is implemented for graph and non-graph inputs. Finally, to help reproduce our reported results, we give the detailed experimental settings for the datasets used.

**Table of contents**
1. [Background](#Background)
2. [Methods](#Methods)
3. [Step-by-Step Implementation on Graph](#Graph_implementation)
4. [Step-by-Step Implementation on Non-Graph](#Nongraph_implementation)
5. [Experiments](#Experiments)

Readers who are already familiar with basic graph neural networks are encouraged to skip the **Background** section.

## 1 Background <a name="Background"></a>

### 1.1 Graph Representation

A graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ is made up of a set of vertices $\mathcal{V} \subseteq \{\mathbf{v}_i \in \mathbb{R}^{1 \times K} \}$ and a set of edges $\mathcal{E} \subseteq \{ \mathbf{e}_{i,j} = \mathbf{e}(\mathbf{v}_i, \mathbf{v}_j) \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V},  i \neq j \}$, where $\mathbf{v}_i$ represents the $K$ attributes of the $i$-th object in the graph and $\mathbf{e}_{i,j}$ represents the edge feature that defines the relationship between vertices $\mathbf{v}_i$ and $\mathbf{v}_j$. Each pair of vertices can be connected by at most one undirected edge or two directed edges. A standard way to describe such edges is the adjacency matrix $\mathcal{A} \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}|}$, where all vertices in a graph are ordered so that each vertex indexes a specific row and column. The presence of each edge $\mathbf{e}_{i,j}$ is then described by a binary value: $\mathcal{A}_{i,j} = 1$ if $\mathbf{v}_i$ and $\mathbf{v}_j$ are connected, and $\mathcal{A}_{i,j} = 0$ otherwise. The adjacency matrix is always symmetric if all edges are undirected, but can be non-symmetric if one or more directed edges exist.

![Graph Representation](Graph_Representation.png)
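As a minimal sketch of the adjacency matrix described above, the snippet below builds $\mathcal{A}$ from an edge list in plain Python. The helper names `adjacency_matrix` and `is_symmetric` and the 4-vertex edge list are illustrative assumptions, not part of Gratis.

```python
def adjacency_matrix(num_vertices, edges, directed=True):
    """Return a binary adjacency matrix as a list of rows.

    A[i][j] = 1 means there is an edge from vertex i to vertex j.
    """
    A = [[0] * num_vertices for _ in range(num_vertices)]
    for i, j in edges:
        A[i][j] = 1
        if not directed:
            A[j][i] = 1  # an undirected edge is stored in both directions
    return A


def is_symmetric(A):
    """Check whether A[i][j] == A[j][i] for every pair of vertices."""
    n = len(A)
    return all(A[i][j] == A[j][i] for i in range(n) for j in range(n))


# Hypothetical 4-vertex example.
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]
A_directed = adjacency_matrix(4, edges, directed=True)
A_undirected = adjacency_matrix(4, edges, directed=False)

print(is_symmetric(A_directed))    # False: directed edges break the symmetry
print(is_symmetric(A_undirected))  # True
```

The same edge list gives a symmetric matrix only when each edge is stored in both directions, which matches the undirected versus directed distinction made above.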

So far, we have covered the basic data structure of a graph. But what happens inside a graph neural network (GNN)?


### 1.2 Graph Neural Network (GNN)

Before turning to GNNs, let's first recap what happens in a convolutional neural network.

![Convolution](Convolution.png)

In a convolutional neural network (CNN), the so-called convolution is an element-wise multiplication of an input feature patch with a kernel, followed by a summation of the products. Using the figure above as an example, we have a 5x5 feature map from layer $(l-1)$ and a learned 3x3 convolutional kernel; each element of the resulting feature map is computed by this multiply-and-sum operation. For example, the center element is $765 = 255 \times 1 + 255 \times 0 + 255 \times (-1) + 255 \times 2 + 255 \times 0 + 0 \times (-2) + 255 \times 1 + 255 \times 0 + 0 \times (-1)$.
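The arithmetic for the center element can be checked with a few lines of Python. The patch and kernel values are taken from the products listed above; arranging them row by row into 3x3 grids is an assumption about the reading order of the figure, and `conv_at` is a hypothetical helper name.

```python
# Values read off the worked example, assumed to be listed row by row.
patch = [
    [255, 255, 255],
    [255, 255, 0],
    [255, 255, 0],
]
kernel = [
    [1, 0, -1],
    [2, 0, -2],
    [1, 0, -1],
]


def conv_at(patch, kernel):
    """Multiply a patch with a kernel element-wise and sum the products."""
    return sum(
        p * k
        for patch_row, kernel_row in zip(patch, kernel)
        for p, k in zip(patch_row, kernel_row)
    )


center = conv_at(patch, kernel)
print(center)  # 765
```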

Slightly more complex than a CNN, the $l$-th layer $G^l$ of a GNN $G$ takes the graph $\mathcal{G}^{l-1} = (\mathcal{V}^{l-1}, \mathcal{E}^{l-1})$ produced by the $(l-1)$-th layer as input and generates a new graph $\mathcal{G}^{l} = (\mathcal{V}^l, \mathcal{E}^l)$, which can be formulated as:
\begin{equation}
    \mathcal{G}^{l} = G^l(\mathcal{G}^{l-1})
\end{equation}
Specifically, the vertex feature $\mathbf{v}^l_i$ in $\mathcal{G}^{l}$ is computed based on:
1. its previous status $\mathbf{v}_i^{l-1}$ in $\mathcal{G}^{l-1}$;
2. the adjacent vertices $\mathbf{v}_j^{l-1}$ of $\mathbf{v}_i^{l-1}$ in $\mathcal{G}^{l-1}$ (denoted as $\mathbf{v}_j^{l-1} \in \mathcal{N}(\mathbf{v}_i^{l-1})$, where $\mathcal{A}^{l-1}_{i,j} = 1$ and $\mathcal{A}^{l-1}$ is the adjacency matrix of $\mathcal{G}^{l-1}$);
3. the edge feature $\mathbf{e}_{i,j}^{l-1}$ that represents the relationship between $\mathbf{v}_i^{l-1}$ and $\mathbf{v}_j^{l-1}$ in $\mathcal{N}(\mathbf{v}_i^{l-1})$.

Here, the message $\mathbf{m}$ that the vertex $\mathbf{v}_i^{l-1}$ receives from its adjacent vertices $\mathcal{N}(\mathbf{v}^{l-1}_i)$ can be denoted as:
\begin{equation}
\begin{split}
& \mathbf{m}_{\mathcal{N}(\mathbf{v}^{l-1}_i)} = M(\mathbin\Vert ^{N}_{j=1} f(\mathbf{v}_j^{l-1},\mathbf{e}_{i,j}^{l-1})) \\
& f(\mathbf{v}_j^{l-1},\mathbf{e}_{i,j}^{l-1}) = 0 \quad \text{subject to} \quad \mathcal{A}^{l-1}_{i,j} = 0
\end{split}
\label{eq:message-passing}
\end{equation}
where $M$ is a differentiable function that aggregates the messages produced by the adjacent vertices; $N$ denotes the number of vertices in the graph $\mathcal{G}^{l-1}$; $f(\mathbf{v}_j^{l-1},\mathbf{e}_{i,j}^{l-1})$ is a differentiable function that defines the influence of the adjacent vertex $\mathbf{v}_j^{l-1}$ on $\mathbf{v}_i^{l-1}$ through their edge $\mathbf{e}_{i,j}^{l-1}$; and $\mathbin\Vert $ is the **aggregation** operator that combines the messages of all adjacent vertices of $\mathbf{v}_i^{l-1}$. As a result, the vertex feature $\mathbf{v}^l_i$ is produced as:
\begin{equation}
\begin{split}
   \mathbf{v}^l_i = G_v^l(\mathbf{v}^{l-1}_i,\mathbf{m}_{\mathcal{N}(\mathbf{v}^{l-1}_i)})
\end{split}
\label{eq:vertex}
\end{equation}
where $G_v^l$ denotes a differentiable function of the $l_{th}$ GNN layer, which updates each vertex feature for the graph $\mathcal{G}^{l}$. 
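To make the vertex update concrete, here is a minimal sketch of one message-passing step in plain Python. The specific choices are assumptions made only for illustration: $f(\mathbf{v}_j, \mathbf{e}_{i,j}) = e_{i,j}\,\mathbf{v}_j$ with a scalar edge feature, a sum as the aggregator $M$, and a residual sum as $G_v$. Real GNN layers use learned functions instead.

```python
def message_passing_step(V, A, E):
    """One vertex update: v_i^l = G_v(v_i^{l-1}, m_i).

    V: list of N vertex feature vectors (each a list of K floats)
    A: N x N binary adjacency matrix
    E: N x N scalar edge features (only read where A[i][j] == 1)
    """
    N, K = len(V), len(V[0])
    V_new = []
    for i in range(N):
        # m_i sums f(v_j, e_ij) over adjacent vertices; f contributes 0 when A[i][j] == 0.
        m = [0.0] * K
        for j in range(N):
            if A[i][j] == 1:
                for k in range(K):
                    m[k] += E[i][j] * V[j][k]  # f(v_j, e_ij) = e_ij * v_j (assumed)
        # G_v: add the aggregated message to the previous state (assumed)
        V_new.append([V[i][k] + m[k] for k in range(K)])
    return V_new


V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
E = [[0.0, 0.5, 2.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
V_next = message_passing_step(V, A, E)
print(V_next)  # [[3.0, 2.5], [1.0, 2.0], [1.0, 1.0]]
```

Vertex 2 has no adjacent vertices in this example, so its feature is unchanged, while vertex 0 receives messages from vertices 1 and 2 weighted by their edge values.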



Besides vertices, an edge feature $\mathbf{e}_{i,j}^l$ in the graph $\mathcal{G}^{l}$ can either be kept the same as its previous status $\mathbf{e}_{i,j}^{l-1}$ in the graph $\mathcal{G}^{l-1}$ (denoted as GNN type 1), or be updated during the propagation of the GNN (denoted as GNN type 2). In the latter case, the edge feature $\mathbf{e}_{i,j}^l$ in $\mathcal{G}^{l}$ is computed based on: (i) its previous status $\mathbf{e}^{l-1}_{i,j}$ in $\mathcal{G}^{l-1}$; and (ii) the corresponding vertex features $\mathbf{v}^{l-1}_i$ and $\mathbf{v}^{l-1}_j$ in $\mathcal{G}^{l-1}$. Mathematically, $\mathbf{e}^{l}_{i,j}$ can be computed as:
\begin{equation}
\begin{split}
    \mathbf{e}_{i,j}^l = 
    \begin{cases}
    \mathbf{e}^{l-1}_{i,j} &  \text{GNN type 1} \\
    G_e^l(\mathbf{e}^{l-1}_{i,j},g(\mathbf{v}^{l-1}_i, \mathbf{v}^{l-1}_j)) & \text{GNN type 2}
    \end{cases}
\end{split}
\end{equation}
where $G_e^l$ is a differentiable function of the $l$-th GNN layer, which updates each edge feature for the graph $\mathcal{G}^{l}$, and $g$ is also a differentiable function that models relationship cues between $\mathbf{v}^{l-1}_i$ and $\mathbf{v}^{l-1}_j$. In summary, during the propagation of a GNN, the updates of vertex features and edge features influence each other.
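The two cases of the edge update can be sketched as follows. The functions chosen for $g$ (element-wise product of the two vertex features) and $G_e$ (element-wise sum) are assumptions for illustration only, and the sketch assumes that edge and vertex features have the same length.

```python
def update_edge(e_prev, v_i, v_j, gnn_type):
    """Return e_ij^l from e_ij^{l-1} and the endpoint features v_i, v_j."""
    if gnn_type == 1:
        return list(e_prev)  # type 1: the edge feature is carried over unchanged
    # type 2: g models the vertex pair, G_e merges it with the previous edge feature
    g = [a * b for a, b in zip(v_i, v_j)]       # g: element-wise product (assumed)
    return [e + r for e, r in zip(e_prev, g)]   # G_e: element-wise sum (assumed)


e_prev = [0.5, 0.5]
v_i, v_j = [1.0, 2.0], [3.0, 0.5]
print(update_edge(e_prev, v_i, v_j, gnn_type=1))  # [0.5, 0.5]
print(update_edge(e_prev, v_i, v_j, gnn_type=2))  # [3.5, 1.5]
```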

##  2 Methods <a name="Methods"></a>


### 2.1 Motivation
A typical graph consists of a set of vertices and edges, where each vertex usually represents the mathematical abstraction of an object, and each edge describes the relationship between a pair of vertices. To encode a raw data sample as a graph, the majority of existing approaches manually define the vertices, the topology of the graph (i.e., edges' presence) as well as each edge's features based on pre-defined rules. Such hand-crafted strategies frequently assign the same topology for graph representations of all samples in the dataset. Subsequently, task-related connections may be ignored in the manually-defined graphs, and thus the performance of the graph analysis would be limited.

Edge features are also essential components of graphs. However, the majority of existing graphs only employ a single value as each edge's feature to describe the relationship between its vertices, which usually ignores crucial relationship cues. To comprehensively utilize rich relationship cues between vertices for graph analysis tasks, several studies developed novel edge message passing methods that allow GNNs to process multi-dimensional edge feature-based graphs. However, these methods only focus on efficiently processing multi-dimensional edge features that are already contained in the input graph. Although some studies manually design multi-dimensional edge features to describe some specific relationship between vertices, these hand-crafted edge features still fail to learn task-specific relationship cues between vertices. In summary, there is a lack of a generic graph representation learning framework that can automatically generate a graph representation with a task-specific topology and multi-dimensional edge features for any arbitrary input data (i.e., pre-defined graph or non-graph data such as image, video, audio, and text).


### 2.2 Problem Formulation


**Problem 1: manually defined task-agnostic graph topology.**

With a manually defined graph topology (represented as an adjacency matrix) $\mathcal{A} \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}|}$, a pair of vertices $\mathbf{v}^{l}_i$ and $\mathbf{v}^{l}_j$ are connected when their relationship meets the pre-defined criteria $\mathcal{R}$, while no edge is present between a pair of vertices whose relationship is not considered by $\mathcal{R}$. Assuming $\mathcal{A}_{i,j} \in \mathcal{A}$ describes the presence of the edge $\mathbf{e}_{i,j}$ between vertices $\mathbf{v}_i$ and $\mathbf{v}_j$ in $\mathcal{G}$, then it is represented as
\begin{equation}
\begin{split}
    \mathcal{A}_{i,j} = 
    \begin{cases}
    1 & \{\mathbf{v}_i, \mathbf{v}_j \} \in \mathcal{R} \\
    0 &  \text{Otherwise}
    \end{cases}
\end{split}
\label{eq:edge}
\end{equation}
As mentioned above, the message passing of vertex features in the graph $\mathcal{G}$ depends on its adjacency matrix $\mathcal{A}$. Consequently, a manually defined adjacency matrix may not provide a task-specific message passing mechanism for the graph. In other words, properly exploring a task-specific adjacency matrix $\mathcal{\hat{A}} \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}|}$ for $\mathcal{G}$ would allow vertex messages to be passed via task-specific paths and result in better graph analysis performance.

**Problem 2: message passing via single-value edges.**

While edges are essential components for a graph and decide its message passing process, many existing approaches only use a single value as each edge's representation (denoted as $\mathbf{e}_{i,j} = [e_{i,j}(1)]$) to define either the edge's presence or the strength of association between a pair of vertices (i.e., $e_{i,j}(1) = \mathcal{A}_{i,j}$). Let's define each vertex in a single-value edge feature-based graph $\mathcal{G}$ as $\mathbf{v}_j = [v_j(1), v_j(2), \cdots, v_j(K)]$ ($j = 1, 2, \cdots, N$). Then, the function $f$ in Eq. \ref{eq:message-passing} can be rewritten as:
\begin{equation}
\begin{split}
f(\mathbf{v}_j,\mathbf{e}_{i,j}) = f([v_j(1),
\cdots, v_j(K)], [e_{i,j}(1)])
\end{split}
\label{eq:single_edge}
\end{equation}
where the impact of the vertex $\mathbf{v}_j$ on its adjacent vertex $\mathbf{v}_i$ is only controlled by a single value $e_{i,j}(1)$, which may fail to include all crucial relationship cues between vertices $\mathbf{v}_i$ and $\mathbf{v}_j$. Consequently, the messages passed via such single-value edges may not be optimal. 
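The limitation can be illustrated with a small sketch that contrasts a single-value edge with a multi-dimensional one. Both message functions below are hypothetical stand-ins for $f$: the first scales the whole vertex feature by one scalar, the second weights each attribute separately. Neither is the function used by Gratis.

```python
def f_single(v_j, e_ij):
    """Single-value edge: one scalar scales every attribute of v_j."""
    return [e_ij[0] * x for x in v_j]


def f_multi(v_j, e_ij):
    """Multi-dimensional edge: each attribute of v_j gets its own weight."""
    return [w * x for w, x in zip(e_ij, v_j)]


v_j = [2.0, 4.0, 6.0]
print(f_single(v_j, [0.5]))           # [1.0, 2.0, 3.0]
print(f_multi(v_j, [1.0, 0.0, 0.5]))  # [2.0, 0.0, 3.0]
```

With a single value, the edge cannot pass one attribute while suppressing another. A multi-dimensional edge feature can.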

### 2.3 Gratis

In contrast, our Gratis can **generate a task-specific graph representation with a task-specific topology and multi-dimensional edge features from an arbitrary graph or non-graph data**. The pipeline is the same for graph and non-graph inputs and is shown in the figure below:

![pipeline](pipeline.png)

Gratis takes an arbitrary input $\mathcal{D}^{\text{in}}$, and produces a task-specific graph representation $\mathcal{\hat{G}}(\mathcal{V}, \mathcal{\hat{E}})$ (or $\mathcal{\hat{G}}(\mathcal{\hat{V}}, \mathcal{\hat{E}})$ for non-graph data) which consists of $N$ vertices $\mathcal{V} \subseteq \{\mathbf{v}_i \in \mathbb{R}^{1 \times K} \}$ ($i = 1, 2, \cdots, N$), and a set of **directed** edges whose presences are defined by a binary adjacency matrix $\mathcal{\hat{A}} \in \mathbb{R}^{N \times N}$, where each of these presented **directed** edges is described by a task-specific multi-dimensional edge feature. These edges can be denoted as $\mathcal{\hat{E}} \subseteq \{ \mathbf{\hat{e}}_{i,j} = \mathbf{\hat{e}}(\mathbf{v}_i, \mathbf{v}_j) \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V} \quad \text{and} \quad \mathcal{\hat{A}}_{i,j} = 1 \}$.

The proposed framework consists of three modules:

1. the **Graph Definition (GD)** produces a basic graph representation $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$ from the input data $\mathcal{D}^{\text{in}}$. The $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$ is defined by a set of vertex features $\mathcal{V}$, a basic topology (adjacency matrix) $\mathcal{A}$, and a set of basic edge features $\mathcal{E} \subseteq \{ \mathbf{e}_{i,j} = \mathbf{e}(\mathbf{v}_i, \mathbf{v}_j) \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V} \quad \text{and} \quad \mathcal{A}_{i,j} = 1 \}$;
2. the **Task-specific Topology Prediction (TTP)** module produces a task-specific graph topology, i.e., replacing the basic graph topology defined by $\mathcal{A}$ with a task-specific adjacency matrix $\mathcal{\hat{A}} \in \mathbb{R}^{N \times N}$;
3. the **Multi-dimensional Edge Feature Generation (MEFG)**  specifically assigns a task-specific multi-dimensional edge feature $\mathbf{\hat{e}}_{i,j}$ to each presented edge ($\mathcal{\hat{A}}_{i,j} = 1$), describing multiple task-specific relationship cues between vertices $\mathbf{v}_i$ and $\mathbf{v}_j$ (i.e., replacing basic edge features $\mathcal{E}$ with task-specific multi-dimensional edge features $\mathcal{\hat{E}}$).

Since we train the GD, TTP, and MEFG with the backbone and GNN predictor in an end-to-end manner, both TTP and MEFG learn to assign a task-specific topology and task-specific edge features to the final graph representation. Moreover, TTP and MEFG consider not only the relationship between the corresponding vertex features but also the global contextual information of both vertices when deciding each edge's presence and feature. As a result, a task-specific graph representation $\mathcal{\hat{G}}(\mathcal{V}, \mathcal{\hat{E}})$ (or $\mathcal{\hat{G}}(\hat{\mathcal{V}}, \mathcal{\hat{E}})$ for non-graph data) whose topology is defined by $\mathcal{\hat{A}}$ can be obtained from any arbitrary input $\mathcal{D}^{\text{in}}$ (if the input non-graph data is represented by a set of vectors, we concatenate them into a single matrix $\mathcal{D}^{\text{in}}$).

## 3 Step-by-Step Implementation on Graph <a name="Graph_implementation"></a>

In this section, we start with a randomly initialized graph and use it as a running example to walk through the core idea of Gratis.

With an input graph and its task-specific ground truth label, the proposed Gratis helps us

* find an optimal task-specific topology
* generate the multi-dimensional edge feature

In this example, we will leave out the ground-truth label for now and only focus on the learning of **task-specific topology** and **multi-dimensional edge feature**. The usage of ground-truth labels will be introduced for each specific task in later sections.

### 3.1 Graph Definition

* We first create a simple graph with four nodes and five edges.
* We then assign random values for the node feature denoted as 'feat'.
* We also assign random values for the edge feature denoted as 'feat'.

Finally, we obtain the following graph:
```
Graph(num_nodes=4, num_edges=5,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
```
Note that when training the full Gratis pipeline, this randomly initialized graph is replaced by the graphs obtained by iterating over the data loader.

In [1]:
#Define a simple Graph.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import dgl

from layers.gated_gcn_layer import GatedGCNLayer
from layers.mlp_readout_layer import MLPReadout

from layers.cross_attention_layer import CrossTransformerEncoder
from layers.transformer_layer import ResidualAttentionBlock

u, v = torch.tensor([0, 0, 0 , 1, 1]), torch.tensor([1, 2, 3, 2, 3])
g = dgl.graph((u, v))
print(g)
print(g.all_edges())
print(g.num_nodes())
print(g.num_edges())
g.ndata['feat'] = torch.rand(g.num_nodes(), 3)
g.edata['feat'] = torch.rand(g.num_edges(), 1)
print(g.ndata['feat'][:])
print(g.edata['feat'])
print(g.all_edges())
print(g)


Graph(num_nodes=4, num_edges=5,
      ndata_schemes={}
      edata_schemes={})
(tensor([0, 0, 0, 1, 1]), tensor([1, 2, 3, 2, 3]))
4
5
tensor([[0.1186, 0.0780, 0.7348],
        [0.7600, 0.2731, 0.1711],
        [0.2024, 0.3863, 0.6379],
        [0.6366, 0.6394, 0.2253]])
tensor([[0.7877],
        [0.7362],
        [0.5946],
        [0.8739],
        [0.8137]])
(tensor([0, 0, 0, 1, 1]), tensor([1, 2, 3, 2, 3]))
Graph(num_nodes=4, num_edges=5,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})


>**Given the input data $\mathcal{D}^{\text{in}}$, we first categorize it as either a pre-defined graph $\mathcal{D}^{\text{in}} = \mathcal{G}^{\text{in}}(\mathcal{V}^{\text{in}}, \mathcal{E}^{\text{in}})$ or a non-graph data $\mathcal{D}^{\text{in}}$ (e.g., image, text, etc.), and then obtain the vertex features $\mathcal{V}$, basic topology $\mathcal{A}$ and basic edge features $\mathcal{E}$ to build its basic graph representation $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$**



As we described in the paper, for a pre-defined single-value edge graph $\mathcal{G}^{\text{in}}(\mathcal{V}^{\text{in}}, \mathcal{E}^{\text{in}})$ that has an adjacency matrix $\mathcal{A}^{\text{in}}$, we directly employ its original vertices $\mathcal{V}^{\text{in}}$, topology $\mathcal{A}^{\text{in}}$ and edge features $\mathcal{E}^{\text{in}}$ to define vertex features, basic topology, and basic edge features for $\mathcal{G}^{\text{B}}$. This can be formulated as:
    \begin{equation}
        \mathcal{V} = \mathcal{V}^{\text{in}}, \quad
        \mathcal{A} =  \mathcal{A}^{\text{in}}, \quad
        \mathcal{E} = \mathcal{E}^{\text{in}} 
    \end{equation}
    
In other words, **this simple graph ```g ``` will be the direct input of Gratis.**

In [2]:
batch_graphs = g
batch_x = batch_graphs.ndata['feat']
batch_e = batch_graphs.edata['feat']

### 3.2 Task-Specific Topology Prediction (TTP)

The key to predicting a task-specific graph is to learn a **task-specific adjacency matrix** $\mathcal{\hat{A}}$, whose entries $\mathcal{\hat{A}}_{i,j}$ are defined in the paper as:
\begin{equation}
\mathcal{\hat{A}}_{i,j} = 
\begin{cases}
1 & \mathcal{\hat{A}}^{\text{prob}}_{i,j} \geqslant \theta \quad  \text{or} \quad  \mathcal{A}_{i,j} = 1 \\
0 &  \text{Otherwise}
\end{cases}
\label{eq:ttp_graph}
\end{equation}    
where $\theta \in [0,1]$ is a threshold and $\mathcal{\hat{A}}^{\text{prob}}$ is obtained by:
\begin{equation}
  \mathcal{\hat{A}}^{\text{prob}} = \text{Softmax}^{\text{Row}}(h(X)) \otimes \text{Softmax}^{\text{Column}}(h(X))
\end{equation}
where each component $\mathcal{\hat{A}}^{\text{prob}}_{i,j}$ of $\mathcal{\hat{A}}^{\text{prob}}$ ranges from 0 to 1 and denotes the presence probability of an edge $\mathbf{e}_{i,j}$.
    
Clearly, $\mathcal{\hat{A}}^{\text{prob}}_{i,j}$ is obtained by taking the global context $X$ into consideration. In other words, to compute $\mathcal{\hat{A}}^{\text{prob}}$, we first need to compute the global contextual representation $X$.

**In summary, the procedure followed includes**:

1. compute the global contextual representation $X$
2. define a trainable linear function $h$ (a FC layer) to project $X$ to a matrix $X^h = h(X)$ that has the same size as the target adjacency matrix $\mathcal{\hat{A}}$ ($N \times N$ dimensions)
3. conduct Softmax operations on $X^h$ along its row and column vectors, i.e., $\text{Softmax}^{\text{Row}}(h(X))$ and $\text{Softmax}^{\text{Column}}(h(X))$
4. combine two generated matrices via element-wise product $\otimes$
    
    
Note that for graph data, TTP leaves the vertex features unchanged, i.e., $\mathcal{\hat{V}} = \mathcal{V}$ (the $\mathcal{\hat{V}}$ is identical to the basic vertex features $\mathcal{V}$).
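Before looking at the actual module, steps 3 and 4 and the thresholding rule for $\mathcal{\hat{A}}$ can be sketched in plain Python. This sketch assumes the projected logits $X^h = h(X)$ are already given as an $N \times N$ nested list. The names `softmax` and `predict_topology` and the example values are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def predict_topology(Xh, A, theta):
    """Return (A_prob, A_hat) from logits Xh, basic topology A and threshold theta."""
    N = len(Xh)
    row_sm = [softmax(row) for row in Xh]                            # each row sums to 1
    cols = [softmax([Xh[i][j] for i in range(N)]) for j in range(N)]
    col_sm = [[cols[j][i] for j in range(N)] for i in range(N)]      # each column sums to 1
    A_prob = [[row_sm[i][j] * col_sm[i][j] for j in range(N)] for i in range(N)]
    # An edge is present if it is predicted (prob >= theta) or already in the basic topology.
    A_hat = [
        [1 if (A_prob[i][j] >= theta or A[i][j] == 1) else 0 for j in range(N)]
        for i in range(N)
    ]
    return A_prob, A_hat


Xh = [[0.0, 3.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0]]
A = [[0, 0, 0],
     [0, 0, 1],
     [0, 0, 0]]
A_prob, A_hat = predict_topology(Xh, A, theta=0.5)
print(A_hat)  # [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
```

In this example the edge $(0,1)$ is added because its probability exceeds $\theta$, while the edge $(1,2)$ is kept only because it already exists in the basic topology $\mathcal{A}$.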

The whole TTP code is given below (a line-by-line explanation follows in the next section):

In [3]:
# The TTP code
class TTP(nn.Module):
    def __init__(self, in_dim, hidden_dim, edge_thresh):
        super().__init__()
        self.proj_g1 = nn.Linear(in_dim,hidden_dim**2)
        self.bn_node_lr_g1 = nn.BatchNorm1d(hidden_dim**2)
        self.proj_g2 = nn.Linear(in_dim,hidden_dim)
        self.bn_node_lr_g2 = nn.BatchNorm1d(hidden_dim)
        self.hidden_dim = hidden_dim #lr_g
        self.proj_g = nn.Linear(hidden_dim, 1)
        self.edge_thresh = edge_thresh                       # theta, the threshold
    def forward(self, g, h, e):
        lr_gs = []
        gs = dgl.unbatch(g)
        for g in gs:
            N = g.number_of_nodes()
            h_single = g.ndata['feat'].to(h.device)
            # h_single is X_GCN 
            print('X_GCN', h_single.shape)            # [N,K]
            
            h_proj1 = F.dropout(F.relu(self.bn_node_lr_g1(self.proj_g1(h_single))), 0.1, training=self.training).view(-1,self.hidden_dim)
            
            
            # h_proj1 is M_1 
            print('M_1', h_proj1.shape)            # M_1 is [N, D^2], viewed here as [N * D, D], where D is hidden_dim
            
            
            h_proj2 = F.dropout(F.relu(self.bn_node_lr_g2(self.proj_g2(h_single))), 0.1, training=self.training).permute(1,0)
            
            
            # h_proj2 is M_2 
            print('M_2', h_proj2.shape)            # [D, N] after permute (M_2 itself is [N, D])
            
            mm = torch.mm(h_proj1,h_proj2)
            mm = mm.view(N,self.hidden_dim,-1).permute(0,2,1)     # mm is the global contextual representation
            
            print('X', mm.shape)                   #[N, N, D]
              
            mm = self.proj_g(mm).squeeze(-1)       # X^h = h(X): [N, N, D] -> [N, N]

            diag_mm = torch.diag(mm)

            diag_mm = torch.diag_embed(diag_mm)

            mm -= diag_mm                          # subtract the diagonal so self-connection logits become 0


            matrix = F.softmax(mm, dim=0) * F.softmax(mm, dim=1)   # A^prob: [N, N]
            
            lr_connetion = torch.where(matrix>self.edge_thresh)
            print('new connections:  ', lr_connetion[0], lr_connetion[1])
            g.add_edges(lr_connetion[0], lr_connetion[1])
            lr_gs.append(g)    
        g = dgl.batch(lr_gs).to(h.device)

        return g

#### 3.2.1 Global Contextual Representation $X$

We propose a GCN-CNN network as the backbone to project the input graph data $\mathcal{D}^{\text{in}}$ to an $X \in \mathbb{R}^{N \times N \times D}$, where $N$ denotes the number of vertices and $D$ is the hidden dimension. In other words, the backbone projects the input graph to an $N \times N \times D$-dimensional latent space that summarizes its global contextual information.



1. The GCN part first projects the input graph $\mathcal{D}^{\text{in}}$ to a matrix $X^{\text{GCN}}$ with the size of $N \times K$, where $K$ is the original dimensionality of the vertex for the input graph $\mathcal{D}^{\text{in}}$, which is defined as \begin{equation}
    \begin{split}
        X^{\text{GCN}} &= \text{GCN}(\mathcal{D}^{\text{in}}) \\
        \mathcal{D}^{\text{in}} &= \mathcal{G}^{\text{in}}(\mathcal{V}^{\text{in}}, \mathcal{E}^{\text{in}}) \\
    \end{split} \end{equation} where $\mathcal{V}^{\text{in}} \subseteq \{\mathbf{v}_i^{\text{in}} \in \mathbb{R}^{1 \times K} \mid i = 1,2 \cdots N  \}$, and $\mathcal{E}^{\text{in}} \subseteq \{ \mathbf{e}_{i,j}^{\text{in}} = \mathbf{e}(\mathbf{v}_i, \mathbf{v}_j) \mid \mathbf{v}_i^{\text{in}}, \mathbf{v}_j^{\text{in}} \in \mathcal{V}^{\text{in}} \}$ are the vertices set and edges set of the original input graph.

2. the CNN part produces the global contextual representation $X \in \mathbb{R}^{N\times N\times D}$ from the $X^{\text{GCN}}$: \begin{equation} \begin{split}
    &X = \mathbf{\bar{M}}_1 \mathbf{M}_2^\mathbf{T} \\
    &\mathbf{M}_1 = X^{\text{GCN}}W_1 \\
    &\mathbf{M}_2 = X^{\text{GCN}}W_2 \\ \end{split} \label{eq:topology_global} \end{equation} where $W_1 \in \mathbb{R}^{K\times D^2}$ and $W_2 \in \mathbb{R}^{K\times D}$ are learnable weight matrices. During this process, $W_1$ projects $X^{\text{GCN}}$ to a high-dimensional matrix $\mathbf{M}_1  \in \mathbb{R}^{N\times D^2}$ (i.e., $\mathbf{M}_1$ has $N$ rows and $D^2$ columns), while $W_2$ projects each row vector of $X^{\text{GCN}}$ from $K$ dimensions to $D$ dimensions, resulting in a matrix $\mathbf{M}_2 \in \mathbb{R}^{N\times D}$. Then, $\mathbf{M}_1$ is reshaped as $\mathbf{\bar{M}}_1 \in \mathbb{R}^{ND\times D}$, and we multiply $\mathbf{\bar{M}}_1$ by $\mathbf{M}_2^\mathbf{T}$ (giving an $ND \times N$ matrix) and reshape the result to obtain a global contextual representation $X \in \mathbb{R}^{N\times N\times D}$ that summarizes the entire contextual information of $\mathcal{D}^{\text{in}}$.

**In our implementation**:

1. For simplicity, we directly read the node features into the matrix `h_single` with ```h_single = g.ndata['feat']```, where `h_single` plays the role of $X^{\text{GCN}}$ and has [N,K]=[4,3] dimensions in our case.

2. The $W_1 \in \mathbb{R}^{K\times D^2}$  is defined as:
``` python
    self.proj_g1 = nn.Linear(in_dim,hidden_dim**2)
    self.bn_node_lr_g1 = nn.BatchNorm1d(hidden_dim**2) 
    # hidden_dim is D
```
Then, $\mathbf{M}_1$ is reshaped as $\mathbf{\bar{M}}_1 \in \mathbb{R}^{ND\times D}$ with the code:
```python
    h_proj1 = F.dropout(F.relu(self.bn_node_lr_g1(self.proj_g1(h_single))), 0.1, training=self.training).view(-1,self.hidden_dim)

```
3. The $W_2 \in \mathbb{R}^{K\times D}$ is defined as:
```python
    self.proj_g2 = nn.Linear(in_dim,hidden_dim)
    self.bn_node_lr_g2 = nn.BatchNorm1d(hidden_dim)

```
while the $\mathbf{M}_2 \in \mathbb{R}^{N\times D}$ is computed by
```python
    h_proj2 = F.dropout(F.relu(self.bn_node_lr_g2(self.proj_g2(h_single))), 0.1, training=self.training).permute(1,0)
    # The final permute() is for matrix multiplication
```
4. Finally, $X = \mathbf{\bar{M}}_1 \mathbf{M}_2^\mathbf{T}$ is computed by ``` mm = torch.mm(h_proj1,h_proj2)```, and is reshaped to $X \in \mathbb{R}^{N\times N\times D}$ by ```mm = mm.view(N,self.hidden_dim,-1).permute(0,2,1)```.
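The reshape steps above are easier to follow with explicit indices. The sketch below mimics the `view` and `permute` calls with nested lists, ignoring the batch normalization, ReLU and dropout of the real module. It suggests one way to read the result: $X_{i,j,:}$ equals the $i$-th row of $\mathbf{M}_1$ reshaped to a $D \times D$ matrix and applied to the $j$-th row of $\mathbf{M}_2$. The function name `global_context` and the toy values are hypothetical.

```python
def global_context(M1, M2, D):
    """Mimic X = reshape(M1_bar @ M2^T) with nested lists.

    M1: N rows of length D*D, M2: N rows of length D.
    Returns X with shape [N][N][D].
    """
    N = len(M1)
    # M1_bar: [N*D, D]; row i*D + a holds M1[i][a*D:(a+1)*D] (row-major view).
    M1_bar = [M1[i][a * D:(a + 1) * D] for i in range(N) for a in range(D)]
    # mm = M1_bar @ M2^T: [N*D, N]
    mm = [[sum(r[k] * M2[j][k] for k in range(D)) for j in range(N)] for r in M1_bar]
    # view(N, D, N) then permute(0, 2, 1): X[i][j][a] = mm[i*D + a][j]
    return [[[mm[i * D + a][j] for a in range(D)] for j in range(N)] for i in range(N)]


N, D = 2, 2
M1 = [[1, 0, 0, 1],   # vertex 0: reshapes to the 2x2 identity matrix
      [0, 1, 1, 0]]   # vertex 1: reshapes to a 2x2 swap matrix
M2 = [[5, 7],
      [2, 3]]
X = global_context(M1, M2, D)
print(X[0][1])  # identity applied to M2[1] -> [2, 3]
print(X[1][0])  # swap applied to M2[0]     -> [7, 5]
```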

#### 3.2.2 Compute $\mathcal{\hat{A}}^{\text{prob}}$ and Produce the final $\mathcal{\hat{A}}$

**Recap of our procedure**:
1. compute the global contextual representation $X$
2. define a trainable linear function $h$ (a FC layer) to project $X$ to a matrix $X^h = h(X)$ that has the same size as the target adjacency matrix $\mathcal{\hat{A}}$ ($N \times N$ dimensions)
3. conduct Softmax operations on $X^h$ along its row and column vectors, i.e., $\text{Softmax}^{\text{Row}}(h(X))$ and $\text{Softmax}^{\text{Column}}(h(X))$
4. combine two generated matrices via element-wise product $\otimes$

**So far, we have finished the 1st step, i.e., computing $X$.**

The code below defines the $h$ (a FC layer) and projects $X$ to a matrix $X^h = h(X)$

```python
    self.proj_g = nn.Linear(hidden_dim, 1) # h is a FC layer
    mm = self.proj_g(mm).squeeze(-1) ##h(X)
```

After subtracting the diagonal elements, the softmax-based $\mathcal{\hat{A}}^{\text{prob}}$ is computed by:

``` python
    diag_mm = torch.diag(mm)
    diag_mm = torch.diag_embed(diag_mm)
    mm -= diag_mm
    matrix = F.softmax(mm, dim=0) * F.softmax(mm, dim=1)
```

Finally, using ```edge_thresh``` as the threshold, we obtain the new adjacency matrix $\mathcal{\hat{A}}$:

``` python
    lr_connetion = torch.where(matrix>self.edge_thresh)
    g.add_edges(lr_connetion[0], lr_connetion[1])
```
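One detail of the snippets above is worth checking by hand: subtracting the diagonal sets every self-connection logit to exactly 0 rather than excluding it, so self-loops can still pass the threshold when the off-diagonal logits are negative. The plain-Python sketch below reproduces this behavior under that reading. The helper names and the logits are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def new_connections(logits, edge_thresh):
    """Return the (i, j) pairs whose probability exceeds edge_thresh."""
    N = len(logits)
    # Equivalent of mm -= diag_embed(diag(mm)): each diagonal logit becomes 0.
    z = [[0.0 if i == j else logits[i][j] for j in range(N)] for i in range(N)]
    sm_dim0 = [softmax([z[i][j] for i in range(N)]) for j in range(N)]  # softmax(dim=0)
    sm_dim1 = [softmax(z[i]) for i in range(N)]                         # softmax(dim=1)
    prob = [[sm_dim0[j][i] * sm_dim1[i][j] for j in range(N)] for i in range(N)]
    return [(i, j) for i in range(N) for j in range(N) if prob[i][j] > edge_thresh]


logits = [[5.0, -2.0, -2.0],
          [-2.0, 1.0, -2.0],
          [-2.0, -2.0, -3.0]]
selected = new_connections(logits, edge_thresh=0.2)
print(selected)  # [(0, 0), (1, 1), (2, 2)]
```

Here the original diagonal values (5.0, 1.0, -3.0) have no influence on the result. Only the fact that they become 0 matters.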

**After TTP, the learned ```lr_g``` may be slightly different from the original ```batch_graphs```.**

In [4]:
ttp = TTP(in_dim = 3, hidden_dim = 13, edge_thresh=0.2)

print(batch_graphs)
lr_g = ttp(batch_graphs, batch_x, batch_e)
print(lr_g)
print(lr_g.num_edges())

Graph(num_nodes=4, num_edges=5,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
X_GCN torch.Size([4, 3])
M_1 torch.Size([52, 13])
M_2 torch.Size([13, 4])
X torch.Size([4, 4, 13])
new connections:   tensor([1, 2, 3]) tensor([1, 0, 3])
Graph(num_nodes=4, num_edges=8,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
8


In [5]:
print(lr_g.all_edges())

(tensor([0, 0, 0, 1, 1, 1, 2, 3]), tensor([1, 2, 3, 2, 3, 1, 0, 3]))


In [6]:
print(g.all_edges())

(tensor([0, 0, 0, 1, 1]), tensor([1, 2, 3, 2, 3]))


### 3.3 Multi-Dimensional Edge Feature Generation (MEFG)

Once all vertex features $\mathcal{\hat{V}}$ and the task-specific topology $\mathcal{\hat{A}}$ are obtained, we propose a Multi-dimensional Edge Feature Generation (MEFG) module that further learns multiple task-specific relationship cues between vertices to describe each presented edge in a $1 \times K$ dimensional space, i.e., assigning each presented edge with a multi-dimensional feature $\mathbf{\hat{e}}_{i,j} \in \mathbb{R}^{1 \times K}$.


MEFG generates a task-specific multi-dimensional edge feature $\mathbf{\hat{e}}_{i,j}$ for each **directed** edge ($\mathcal{\hat{A}}_{i,j} = 1$) by considering: (i) the corresponding vertex features $\mathbf{v}_i$ and $\mathbf{v}_j$; (ii) the global contextual representation $X$; and (iii) the task-specific graph topology $\hat{\mathcal{A}}$ generated by TTP module. This process can be formulated as:
\begin{equation}
\begin{split}
    &\mathbf{\hat{e}}_{i,j} = \text{MEFG}(X, \hat{\mathcal{A}}, \mathbf{\hat{v}}_i, \mathbf{\hat{v}}_j) \\
    &\mathcal{\hat{E}} \subseteq \{ \mathbf{\hat{e}}_{i,j} = \mathbf{\hat{e}}(\mathbf{\hat{v}}_i, \mathbf{\hat{v}}_j)\mid \mathbf{\hat{v}}_i, \mathbf{\hat{v}}_j \in \mathcal{\hat{V}} \quad \text{and} \quad \mathcal{\hat{A}}_{i,j} = 1 \}    
\end{split}
\end{equation}




Specifically, the MEFG module is more complex than TTP because it contains three major blocks:

* A **vertex-context relationship modeling (VCR)** block first locates each vertex-related cue in the global contextual representation $X$, outputting context-aware vertex features. 
* A **vertex-vertex relationship modeling (VVR)** block further extracts context-aware vertex-vertex relationship features from the produced vertex-context representation (i.e., the outputs of the VCR) to generate a context-aware multi-dimensional edge feature for each edge. 
* Finally, a **topology masking (TM)** block is introduced to incorporate task-specific topology information (learned by TTP) into all obtained context-aware vertex-vertex relationship features, resulting in a set of task-specific, topology-aware and context-aware multi-dimensional edge features.
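As a rough illustration of the data layout behind the VVR and TM blocks, the sketch below expands $N$ vertex features into $N \times N$ (start, end) pairs and then masks the pair features with $\mathcal{\hat{A}}$. The pair ordering follows the `repeat` and `view` calls in the `CrossTransformer` code of this section. The sum used as a stand-in for the learned VVR layers and the zero-masking used for TM are assumptions, and both helper names are hypothetical.

```python
def expand_pairs(Q):
    """Expand N vertex features into N*N (start, end) pairs.

    Position p = i*N + j pairs start = Q[j] with end = Q[i]:
    starts tile the whole list N times, ends repeat each vertex N times.
    """
    N = len(Q)
    starts = [Q[p % N] for p in range(N * N)]
    ends = [Q[p // N] for p in range(N * N)]
    return starts, ends


def topology_mask(pair_feats, A_hat):
    """Zero out pair features whose edge is absent in A_hat (assumed TM behavior)."""
    N = len(A_hat)
    return [
        [pair_feats[i][j] if A_hat[i][j] == 1 else [0.0] * len(pair_feats[i][j])
         for j in range(N)]
        for i in range(N)
    ]


Q = [[1.0], [2.0], [3.0]]
N = len(Q)
starts, ends = expand_pairs(Q)
# Stand-in for the learned VVR output: start + end per pair, reshaped to [N][N][D].
pair = [
    [[s + e for s, e in zip(starts[i * N + j], ends[i * N + j])] for j in range(N)]
    for i in range(N)
]
A_hat = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
E_hat = topology_mask(pair, A_hat)
print(E_hat[0][1], E_hat[1][2], E_hat[2][0])  # [3.0] [5.0] [4.0]
```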

The code of MEFG is given below:

In [7]:
class CrossTransformer(nn.Module):

    def __init__(self, d_model, nhead=1, layer_nums=1, attention_type='linear'):
        super().__init__()
        
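        encoder_layer = CrossTransformerEncoder(d_model, nhead, attention_type)
        # Note: the same encoder_layer instance is placed in both lists,
        # so all VCR and VVR layers share one set of parameters.
        self.VCR_layers = nn.ModuleList([encoder_layer for _ in range(layer_nums)])
        self.VVR_layers = nn.ModuleList([encoder_layer for _ in range(layer_nums)])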
        
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, qfea, kfea, mask0=None, mask1=None):
        """
        Args:
            qfea (torch.Tensor): [B, N, D]
            kfea (torch.Tensor): [B, D]
            mask0 (torch.Tensor): [B, N] (optional)
            mask1 (torch.Tensor): [B, N] (optional)
        """
        #assert self.d_model == qfea.size(2), "the feature number of src and transformer must be equal"
        
        B,N,D = qfea.shape
        kfea = kfea.unsqueeze(1).repeat(1, N, 1) #[B,N,D]
        
        mask1 = torch.ones([B,N]).to(qfea.device)
        for layer in self.VCR_layers:
            qfea = layer(qfea, kfea, mask0, mask1) #[B,N,D]
            #kfea = layer(kfea, qfea, mask1, mask0)
        
        qfea_end = qfea.repeat(1,1,N).view(B,-1,D) #[B,N*N,D]
        qfea_start = qfea.repeat(1,N,1).view(B,-1,D) #[B,N*N,D]
        #mask2 = mask0.repeat([1,N])
        for layer in self.VVR_layers:
            #qfea_start = layer(qfea_start, qfea_end, mask2, mask2)
            qfea_start = layer(qfea_start, qfea_end)#[B,N*N,D]

        return qfea_start.view([B,N,N,D]) 

class MEFG(nn.Module):
    
    def __init__(self, in_dim,hidden_dim, max_node_num, global_layer_num = 2, dropout = 0.1):
        super().__init__()
        self.edge_proj = nn.Conv1d(in_channels=2,out_channels=1,kernel_size=3,padding=1)
        self.edge_proj2 = nn.Linear(in_dim,hidden_dim) #baseline4
        self.edge_proj3 = nn.Linear(in_dim,hidden_dim)
        self.edge_proj4 = nn.Linear(hidden_dim,hidden_dim)
        self.hidden_dim = hidden_dim #baseline4
        self.bn_node_lr_e = nn.BatchNorm1d(hidden_dim)
        
        self.max_node_num = max_node_num
        
        self.global_layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, True, True) for _ in range(global_layer_num -1) ]) 
        self.global_layers.append(GatedGCNLayer(hidden_dim, hidden_dim, dropout, True, True))

        #self.global_layers = nn.ModuleList([ ResidualAttentionBlock( d_model = hidden_dim, n_head = 1)
        #                                    for _ in range(global_layer_num) ]) 
        
        self.CrossT = CrossTransformer(hidden_dim, nhead=1, layer_nums=1, attention_type='linear')
        
    def forward(self, g, h, e):
        
        g.apply_edges(lambda edges: {'src' : edges.src['feat']})
        src = g.edata['src'].unsqueeze(1) #[M,1,D]
        g.apply_edges(lambda edges: {'dst' : edges.dst['feat']})
        dst = g.edata['dst'].unsqueeze(1) #[M,1,D]
        edge = torch.cat((src,dst),1).to(h.device) #[M,2,D]
        lr_e_local = self.edge_proj(edge).squeeze(1)#[M,D]
        lr_e_local = self.edge_proj2(lr_e_local)

        hs = []
        gs = dgl.unbatch(g)
        mask0 = torch.zeros([len(gs),self.max_node_num]).to(h.device)
        for i,g0 in enumerate(gs):
            Ng = g0.number_of_nodes()
            padding = nn.ConstantPad1d((0,self.max_node_num - Ng),0)
            pad_h = padding(g0.ndata['feat'].T).T #[Nmax, D]
            hs.append(pad_h.unsqueeze(0))
            mask0[i,:Ng] = 1
        hs = torch.cat(hs,0).to(h.device) #[B,Nmax,Din]
        hs = self.edge_proj3(hs) #[B,Nmax,hidden_num]
        
        if e is None:
            e = torch.ones([g.number_of_edges() ,h.shape[-1]]).to(h.device)
        # Gated-GCN for extract global feature
        hs2g = h
        for conv in self.global_layers:
            hs2g, _ = conv(g, hs2g, e)
        g.ndata['hs2g'] = hs2g
        global_g = dgl.mean_nodes(g, 'hs2g') #[B,hidden_num]
        
        '''
        # Transformer for extract global feature
        mask_t = mask0.unsqueeze(1)*mask0.unsqueeze(2)
        mask_t = (mask_t==0)
        #mask_t = None
        
        hs2g = hs.permute((1,0,2))
        for conv in self.global_layers:
            hs2g = conv(hs2g, mask_t)
        global_g = hs2g.permute((1,0,2)).mean(1) #[B,D]
        '''
        # hs ([B, MaxnumNode, Hidden_Num])
        # global_g ([B, Hidden_Num])
        edge = self.CrossT(hs, global_g, mask0) #[B,N,N,D]
        
        index_edge = []
        for i,g0 in enumerate(gs):
            index_edge.append(edge[i, g0.all_edges()[0],g0.all_edges()[1],:])
        index_edge = torch.cat(index_edge,0)
        
        lr_e_global = self.edge_proj4(index_edge)
        
        
        if e is not None:
            e = e + lr_e_local + lr_e_global 
        else:
            e = lr_e_local + lr_e_global 
#        lr_e = lr_e_local + lr_e_global 
    
        # bn=>relu=>dropout
        e = self.bn_node_lr_e(e)
        e = F.relu(e)
        e = F.dropout(e, 0.1, training=self.training)
        
        return e

Before going to the MEFG module,
* **We first need to project the current node feature into a higher dimension**. Our motivation is straightforward: the current dimension of the node feature is $K$=3, and MEFG assigns each present edge a feature of the same dimension. Such a small dimension might not be able to capture the abundant links between nodes, so we first enlarge the dimension from $K$ to **hidden_dim**, which is a hyper-parameter to be tuned.
* **We directly use the task-specific graph topology generated by the TTP module.**
* **We create an empty placeholder (```e = None```) to store the learned edge features.**


```python
in_dim = 3
hidden_dim = 13
embedding_h = nn.Linear(in_dim, hidden_dim)
h = g.ndata['feat']
print(h.shape)  # [4, 3]
h = embedding_h(h)
print(h.shape)  # [4, hidden_dim] = [4, 13]
```

```
torch.Size([4, 3])
torch.Size([4, 13])
```


```python
g_afterTTP = copy.deepcopy(lr_g)
print(g_afterTTP.all_edges())
```

```
(tensor([0, 0, 0, 1, 1, 1, 2, 3]), tensor([1, 2, 3, 2, 3, 1, 0, 3]))
```


```python
e = None
```

```python
g_afterTTP.edata['feat']
```

```
tensor([[0.7877],
        [0.7362],
        [0.5946],
        [0.8739],
        [0.8137],
        [0.0000],
        [0.0000],
        [0.0000]])
```

#### 3.3.1 Local Features Extractor

We first compute a direct edge feature from the source and destination nodes of each edge. We name this feature ```lr_e_local``` because it is based solely on these two nodes. The code is given below:


```python
    g.apply_edges(lambda edges: {'src' : edges.src['feat']})
    src = g.edata['src'].unsqueeze(1) #[M,1,D], here D is original K = 3
    g.apply_edges(lambda edges: {'dst' : edges.dst['feat']})
    dst = g.edata['dst'].unsqueeze(1) #[M,1,D]
    edge = torch.cat((src,dst),1).to(h.device) #[M,2,D]
    lr_e_local = self.edge_proj(edge).squeeze(1)#[M,D]
    lr_e_local = self.edge_proj2(lr_e_local)
```
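The shape flow of this local branch can be checked in isolation. The following is a minimal sketch, assuming random endpoint features and illustrative sizes (`M = 8` edges, `D = 3`, `hidden_dim = 13`); it only mirrors the `Conv1d` + `Linear` pattern above in plain PyTorch, without DGL:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
M, D, hidden_dim = 8, 3, 13          # illustrative sizes (assumptions)

src = torch.randn(M, 1, D)           # source-node feature of every edge
dst = torch.randn(M, 1, D)           # destination-node feature of every edge
edge = torch.cat((src, dst), dim=1)  # [M, 2, D]: the two endpoints become 2 channels

# Conv1d fuses the 2 endpoint channels into 1 channel along the feature axis
edge_proj = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
# The linear layer lifts the fused feature from D to hidden_dim
edge_proj2 = nn.Linear(D, hidden_dim)

lr_e_local = edge_proj(edge).squeeze(1)   # [M, D]
lr_e_local = edge_proj2(lr_e_local)       # [M, hidden_dim]
print(lr_e_local.shape)                   # torch.Size([8, 13])
```

Because `padding=1` matches `kernel_size=3`, the convolution keeps the feature length `D` unchanged, so only the channel axis is collapsed.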

On the other hand, the VCR and VVR blocks are used to predict the global edge feature which also considers the global contextual information.

#### 3.3.2  Vertex-Context Relationship Modelling (VCR) and Vertex-Vertex Relationship Modelling (VVR)

The VCR block takes vertex features  $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$ and the global contextual representation $X$ as input. It first conducts cross attention between $\mathbf{\hat{v}}_i$ and $X$ as well as $\mathbf{\hat{v}}_j$ and $X$. Here, the vertex features $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$ are independently used as queries to locate vertex-context relationship features $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$ in $X$ (i.e., $X$ is treated as the key and value for attention operations). Mathematically speaking, this process can be represented as:
\begin{equation}
\begin{split}
    \mathcal{F}_{i,x} = \text{VCR}(\mathbf{\hat{v}}_i, X) \\
    \mathcal{F}_{j,x} = \text{VCR}(\mathbf{\hat{v}}_j, X)    
\end{split}
\end{equation}
with the cross attention operation in VCR defined as:
\begin{equation}
     \text{VCR}(A, B) = \text{softmax}(\frac{A W_q (B W_k)^T}{\sqrt{d_k} }) B W_v
\label{eq:VCR}
\end{equation}
where $W_q$, $W_k$ and $W_v$ are learnable weight vectors or matrices (depending on the shape of the input data) for the query, key and value encoding, respectively, and $d_k$ is a scaling factor set equal to the number of channels of $B$. As a result, the produced $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$ contain the vertex $\mathbf{\hat{v}}_i$-related and vertex $\mathbf{\hat{v}}_j$-related task-specific cues extracted from the global contextual representation $X$.
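The cross-attention formula can be written out directly. The sketch below is only an illustration under assumed sizes (feature size `K = 13`, context length `T = 49`) with random weight matrices standing in for the learned $W_q$, $W_k$ and $W_v$; the helper name `cross_attention` is hypothetical:

```python
import math
import torch

torch.manual_seed(0)

def cross_attention(A, B, W_q, W_k, W_v):
    """softmax(A W_q (B W_k)^T / sqrt(d_k)) B W_v, with d_k = number of channels of B."""
    d_k = B.shape[-1]
    scores = (A @ W_q) @ (B @ W_k).transpose(-2, -1) / math.sqrt(d_k)
    attn = torch.softmax(scores, dim=-1)   # each query row sums to 1 over the context
    return attn @ (B @ W_v), attn

K, T = 13, 49                      # feature size and context length (assumptions)
v_i = torch.randn(1, K)            # one vertex feature used as the query
X = torch.randn(T, K)              # global context used as key and value
W_q, W_k, W_v = (torch.randn(K, K) for _ in range(3))

F_ix, attn = cross_attention(v_i, X, W_q, W_k, W_v)
print(F_ix.shape, attn.shape)      # torch.Size([1, 13]) torch.Size([1, 49])
```

The attention weights form a distribution over the `T` context positions, so `F_ix` is a weighted summary of the context entries most related to the query vertex.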


Based on the $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$, the VVR block further extracts task-specific context cues that relate to both vertices. VVR is also a cross-attention block that has the same form as VCR. In particular, it individually takes $\mathcal{F}_{i,x}$ as the query and $\mathcal{F}_{j,x}$ as the key and value, as well as $\mathcal{F}_{j,x}$ as the query and $\mathcal{F}_{i,x}$ as the key and value, producing two context-aware vertex-vertex relationship features $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$, respectively. Here, $\mathcal{F}_{i,x,j}$ encodes $\mathcal{F}_{i,x}$-related cues in $\mathcal{F}_{j,x}$, while $\mathcal{F}_{j,x,i}$ encodes $\mathcal{F}_{j,x}$-related cues in $\mathcal{F}_{i,x}$. In other words, the context-aware vertex-vertex relationship features $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$ contain cues that not only come from the whole context, but also relate to both vertices $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$. We formulate this process as:
\begin{equation}
\begin{split}
\mathcal{F}_{i,x,j} = \text{VVR}(\mathcal{F}_{i,x}, \mathcal{F}_{j,x}) \\
\mathcal{F}_{j,x,i} = \text{VVR}(\mathcal{F}_{j,x}, \mathcal{F}_{i,x})    
\end{split}
\end{equation}
Depending on the data shape, we finally employ either a pooling layer or a fully-connected layer, to flatten $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$ to a pair of multi-dimensional edge feature vectors $\mathbf{\bar{e}}_{i,j}$ and $\mathbf{\bar{e}}_{j,i}$ (we denote this operation as $FL$):
\begin{equation}
\begin{split}
\mathbf{\bar{e}}_{i,j} = FL(\mathcal{F}_{i,x,j}) \\
\mathbf{\bar{e}}_{j,i} = FL(\mathcal{F}_{j,x,i})
\end{split}
\end{equation}
As a result, each of the produced multi-dimensional edge features encodes task-specific cues, drawn from the whole context of the input data $\mathcal{D}^{\text{in}}$, that relate to both $\mathbf{v}_i$ and $\mathbf{v}_j$.
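To illustrate why the two directions yield different edge features, here is a rough sketch of VVR followed by $FL$. It assumes each vertex-context feature is a set of `T` tokens with `K` channels, omits the learnable projections for brevity, and realises $FL$ as average pooling; the helper `vvr` is hypothetical:

```python
import torch

torch.manual_seed(0)
T, K = 49, 16                      # tokens per vertex-context feature, channels (assumptions)
F_ix = torch.randn(T, K)           # stand-in for VCR(v_i, X)
F_jx = torch.randn(T, K)           # stand-in for VCR(v_j, X)

def vvr(query, key_value):
    # Plain scaled dot-product cross attention; learnable projections are omitted
    scores = query @ key_value.T / key_value.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ key_value

F_ixj = vvr(F_ix, F_jx)            # F_{i,x} queries F_{j,x}
F_jxi = vvr(F_jx, F_ix)            # F_{j,x} queries F_{i,x}

# FL realised as average pooling over the token axis
e_ij = F_ixj.mean(dim=0)           # [K]
e_ji = F_jxi.mean(dim=0)           # [K]
print(e_ij.shape, e_ji.shape)
```

Since the query and key/value roles are swapped, `e_ij` and `e_ji` generally differ, which is what allows a separate feature for each directed edge.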

In our code implementation,
* **the vertex features are denoted as ```hs``` below:**

```python
    hs = []
    gs = dgl.unbatch(g)
    # mask0[i, n] = 1 marks a real node, 0 marks a padded position
    mask0 = torch.zeros([len(gs),self.max_node_num]).to(h.device)
    for i,g0 in enumerate(gs):
        Ng = g0.number_of_nodes()
        # Keep the first Ng rows and pad zeros for the remaining Nmax-Ng rows
        padding = nn.ConstantPad1d((0,self.max_node_num - Ng),0)
        pad_h = padding(g0.ndata['feat'].T).T #[Nmax,Din]
        hs.append(pad_h.unsqueeze(0))
        mask0[i,:Ng] = 1
    hs = torch.cat(hs,0).to(h.device) #list of [1,Nmax,Din] -> tensor [B,Nmax,Din]
    hs = self.edge_proj3(hs) #[B,Nmax,hidden]
```
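The padding and masking step can be reproduced without DGL. This is a small sketch on two toy graphs with hand-picked sizes; it uses `torch.nn.functional.pad` in place of the `nn.ConstantPad1d` plus transpose combination above, which is an equivalent substitution made here for readability:

```python
import torch
import torch.nn.functional as F

# Two graphs with 2 and 4 nodes, 3 features per node (illustrative values)
node_feats = [torch.ones(2, 3), 2 * torch.ones(4, 3)]
max_node_num = 4

hs = []
mask0 = torch.zeros(len(node_feats), max_node_num)
for i, feat in enumerate(node_feats):
    Ng = feat.shape[0]
    # F.pad lists the last dimension first: (0, 0) leaves the features alone,
    # (0, max_node_num - Ng) appends zero rows after the real nodes
    hs.append(F.pad(feat, (0, 0, 0, max_node_num - Ng)).unsqueeze(0))
    mask0[i, :Ng] = 1
hs = torch.cat(hs, dim=0)           # [B, Nmax, D]

print(hs.shape)                     # torch.Size([2, 4, 3])
print(mask0)                        # rows: [1, 1, 0, 0] and [1, 1, 1, 1]
```

The mask lets later attention layers ignore the zero-padded positions of the smaller graph.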

* **since the topology of the input graph has been changed in TTP, we need to recompute the global contextual representation $X$:**

```python
    if e is None:
        e = torch.ones([g.number_of_edges() ,h.shape[-1]]).to(h.device)
    # Gated-GCN for extract global feature
    hs2g = h
    for conv in self.global_layers:
        hs2g, _ = conv(g, hs2g, e)
    g.ndata['hs2g'] = hs2g
    global_g = dgl.mean_nodes(g, 'hs2g') #[B,hidden_num]
```

Note that $X$ is computed differently here than in the TTP module: (1) we first learn the node features, denoted as ```hs2g```, from the input graph, which have a resolution of $N \times K$; (2) the global contextual representation $X$ is then obtained directly by average pooling over the nodes, giving a resolution of $1 \times K$.
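The node-wise average pooling can be sketched in plain PyTorch for a batch of graphs. This is only an illustration of the pooling itself, assuming a toy batch of two graphs (2 and 3 nodes) and a hypothetical `graph_id` vector that records which graph each node belongs to:

```python
import torch

# Node features of a batch of two graphs stacked together: 2 + 3 nodes, K = 4
hs2g = torch.arange(20, dtype=torch.float32).view(5, 4)
graph_id = torch.tensor([0, 0, 1, 1, 1])    # graph membership of every node
B = 2

# Sum the node features per graph, then divide by the number of nodes per graph
sums = torch.zeros(B, hs2g.shape[1]).index_add_(0, graph_id, hs2g)
counts = torch.bincount(graph_id, minlength=B).unsqueeze(1)
global_g = sums / counts                    # [B, K], one context vector per graph
print(global_g)
# tensor([[ 2.,  3.,  4.,  5.],
#         [12., 13., 14., 15.]])
```

Each row of `global_g` plays the role of the $1 \times K$ representation $X$ for one graph in the batch.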

* **VCR and VVR are based on the same CrossTransformerEncoder function**

```python
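    # In __init__: the same encoder_layer instance is placed in both lists,
    # so the VCR and VVR layers share their parameters
    encoder_layer = CrossTransformerEncoder(d_model, nhead, attention_type)
    self.VCR_layers = nn.ModuleList([encoder_layer for _ in range(layer_nums)])
    self.VVR_layers = nn.ModuleList([encoder_layer for _ in range(layer_nums)])

    # In forward: broadcast the global feature to every node
    B,N,D = qfea.shape
    kfea = kfea.unsqueeze(1).repeat(1, N, 1) #[B,N,D]
    
    # VCR: each node feature (query) attends to the global context (key/value)
    mask1 = torch.ones([B,N]).to(qfea.device)
    for layer in self.VCR_layers:
        qfea = layer(qfea, kfea, mask0, mask1) #[B,N,D]
    
    # VVR: build all N*N node pairs, then attend between the two endpoints
    qfea_end = qfea.repeat(1,1,N).view(B,-1,D) #[B,N*N,D]
    qfea_start = qfea.repeat(1,N,1).view(B,-1,D) #[B,N*N,D]
    for layer in self.VVR_layers:
        qfea_start = layer(qfea_start, qfea_end)#[B,N*N,D]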
```
```qfea_start``` holds the cross-attention output: once reshaped to [B, N, N, D], each element (i, j) is a multi-dimensional vector (of size $D$) that describes the link between the i-th and j-th nodes. By enumerating all existing edges in the graph, we obtain the new edge feature ```index_edge```, which is then passed through an FC or pooling layer (```self.edge_proj4()``` in the code) to generate the global task-specific multi-dimensional edge features ```lr_e_global```.

```python
    edge = self.CrossT(hs, global_g, mask0) #[B,N,N,D]
    
    index_edge = []
    for i,g0 in enumerate(gs):
        index_edge.append(edge[i, g0.all_edges()[0],g0.all_edges()[1],:])
    index_edge = torch.cat(index_edge,0)
    
    lr_e_global = self.edge_proj4(index_edge)
```
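The layout produced by the two `repeat(...).view(...)` calls, and the way existing edges are picked out of the dense [B, N, N, D] tensor, can be checked on a toy example. The sketch below assumes `N = 3` nodes whose features equal their index, and a hypothetical edge list `src`/`dst`:

```python
import torch

B, N, D = 1, 3, 2
# Node k carries the value k in every channel, so entries are easy to read
qfea = torch.arange(N, dtype=torch.float32).view(1, N, 1).repeat(B, 1, D)

qfea_end = qfea.repeat(1, 1, N).view(B, -1, D)     # row i*N + j holds node i
qfea_start = qfea.repeat(1, N, 1).view(B, -1, D)   # row i*N + j holds node j

end_grid = qfea_end.view(B, N, N, D)
start_grid = qfea_start.view(B, N, N, D)
print(end_grid[0, 2, 1], start_grid[0, 2, 1])      # tensor([2., 2.]) tensor([1., 1.])

# Pick the features of the existing edges only, as done for index_edge
src = torch.tensor([0, 0, 1])
dst = torch.tensor([1, 2, 2])
index_edge = start_grid[0, src, dst, :]            # [num_edges, D]
print(index_edge.shape)                            # torch.Size([3, 2])
```

Together, the two grids enumerate every ordered node pair, and indexing with `(src, dst)` keeps one feature vector per edge of the graph.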

#### 3.3.3 Topology Masking (TM)

The TM block finally encodes each task-specific and context-aware edge feature $\mathbf{\bar{e}}_{i,j}$ into a task-specific, topology- and context-aware multi-dimensional edge feature $\mathbf{\hat{e}}_{i,j}$ by incorporating the learned task-specific topology cues as:
\begin{equation}
\mathbf{\hat{e}}_{i,j} = \text{TM}(\mathbf{\bar{e}}_{i,j}, \mathcal{\hat{A}}^{\text{Dis}})
\end{equation}
where $\mathcal{\hat{A}}^{\text{Dis}} \in \mathbb{R}^{N \times N}$ is a matrix denoting the distance between every pair of vertices or every edge's occurrence probability (i.e., the $\mathcal{\hat{A}}^{\text{prob}}$ for graph data). Specifically, for graph data, the TM module encodes task-specific topology cues into the global contextual representation $X$ as $\hat{X} = \text{Conv}(X, \mathcal{\hat{A}})$ (Conv denotes a convolution layer), and then replaces $X$ with $\hat{X}$ during the reasoning of the VCR and VVR. For non-graph data, it encodes a mask $\mathcal{M} \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}|}$ from $\mathcal{\hat{A}}^{\text{Dis}}$ via a linear projector, where each component $\mathcal{M}_{i,j}$ represents a task-specific weight of the edge feature $\mathbf{\bar{e}}_{i,j}$. Then, the final $\mathbf{\hat{e}}_{i,j}$ is obtained by applying $\mathcal{M}_{i,j}$ to weight $\mathbf{\bar{e}}_{i,j}$.

In the code, only the entries of the dense edge tensor that correspond to edges present in the graph ```g``` passed to MEFG are kept:

```python 
    index_edge = []
    for i,g0 in enumerate(gs):
        index_edge.append(edge[i, g0.all_edges()[0],g0.all_edges()[1],:])
    index_edge = torch.cat(index_edge,0)
    
    lr_e_global = self.edge_proj4(index_edge)
```
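For the non-graph case, the masking described above can be sketched as follows. This is a loose illustration rather than the reference implementation: the `projector` layer, its flattened $N^2 \rightarrow N^2$ shape, and the use of pairwise Euclidean distances as $\mathcal{\hat{A}}^{\text{Dis}}$ are all assumptions made for the example:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, K = 4, 8
v = torch.randn(N, K)                   # vertex features (random stand-ins)
e_bar = torch.randn(N, N, K)            # stand-in for the VCR/VVR edge features

A_dis = torch.cdist(v, v)               # [N, N] pairwise Euclidean distances
projector = nn.Linear(N * N, N * N)     # hypothetical linear projector
M = projector(A_dis.flatten()).view(N, N)

# Each edge feature is scaled by its own task-specific weight M[i, j]
e_hat = M.unsqueeze(-1) * e_bar
print(e_hat.shape)                      # torch.Size([4, 4, 8])
```

Broadcasting `M` over the last axis applies one scalar weight per edge to all `K` channels of that edge's feature.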


#### 3.3.4  Multi-Dimensional Edge Feature

The final multi-dimensional edge feature is obtained by combining the local and global features:
```python
    if e is not None:
        e = e + lr_e_local + lr_e_global 
    else:
        e = lr_e_local + lr_e_global
    
    # bn=>relu=>dropout
    e = self.bn_node_lr_e(e)
    e = F.relu(e)
    e = F.dropout(e, 0.1, training=self.training)
```

### 3.4 Summary

At this point, we have learned a task-specific graph ```g=lr_g```, high-dimensional node features ```h```, and multi-dimensional edge features ```e=lr_e```. We can now feed these components to different graph networks (such as GAT or GatedGCN) to perform different tasks.
```python
    # Option 1: stack GatedGCN layers
    self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, self.batch_norm, self.residual) for _ in range(n_layers-1) ]) 

    # Option 2: stack GAT layers instead
    self.layers = nn.ModuleList([CustomGATLayer(hidden_dim * num_heads, hidden_dim, num_heads, dropout, self.batch_norm, self.residual) for _ in range(n_layers-1)])
    
    # Forward pass: every layer updates both the node and the edge features
    for conv in self.layers:
        h, e = conv(g, h, e)
```

## 4 Step-by-Step Implementation on Non-Graph <a  name="Nongraph_implementation"></a>
In this section, we walk you through the core idea of Gratis on non-graph data. Specifically, we start with a random image. We leave out the ground-truth label for now and focus only on learning the **task-specific topology** and the **multi-dimensional edge feature**. The usage of ground-truth labels will be introduced for each specific task in later sections.

Recap:

With an input graph and its task-specific ground truth label, the proposed Gratis helps us

* find an optimal task-specific topology
* generate the multi-dimensional edge feature

But what happens with non-graph input?

![submodules_2](submodules_2.png)


### 4.1 Graph Definition (GD)

Graph Definition (GD) produces a basic graph representation $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$ from the input data $\mathcal{D}^{\text{in}}$. The $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$ is defined by a set of vertex features $\mathcal{V}$, a basic topology (adjacency matrix) $\mathcal{A}$, and a set of basic edge features $\mathcal{E} \subseteq \{ \mathbf{e}_{i,j} = \mathbf{e}(\mathbf{v}_i, \mathbf{v}_j) \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V} \quad \text{and} \quad \mathcal{A}_{i,j} = 1 \}$.

(i) For a non-graph sample represented by a set of vectors/channels (e.g., multi-channel time-series), we directly treat each vector/channel as a vertex, while the global contextual representation matrix $X$ is obtained by concatenating all vectors/channels. (ii) For a non-graph sample labelled with multiple key objects/components (e.g., multiple audio events labelled for an audio signal, or multiple facial action units (AUs) labelled for a face image), we propose a Vertex Feature Extraction (VFE) block which learns each vertex to represent a specific object/component, while the global contextual representation $X$ is directly learned from the input non-graph sample. In both case (i) and case (ii), the number of vertices equals the number of vectors/channels or of labelled objects/components.


The Vertex Feature Extraction (VFE) module is made up of $N$ vertex feature extractors, each of which consists of a fully connected layer (FC) and a global average pooling (GAP) layer. The $i_{th}$ vertex feature extractor takes the global contextual representation $X$ as the input and produces a set of latent feature maps, which are then further processed by the GAP layer into a vector of $K$ dimensions. This vector represents the feature of the $i_{th}$ vertex $\mathbf{v}_i$ in the basic graph $\mathcal{G}^{\text{B}}$.

The global contextual representation $X$ is defined as:
\begin{equation}
  X = \text{Backbone}(\mathcal{D}^{\text{in}})
\end{equation}
In this paper, a CNN or a transformer is employed as the backbone to directly extract $X$ from non-graph data, where the $X$ is a set of latent feature maps whose sizes depend on the input data $\mathcal{D}^{\text{in}}$.

As a result, the vertices set $\mathcal{V}$ contains $N$ vectors (each has $K$ dimensions), representing $N$ vertices and their features. In addition, if the non-graph data is represented by a set of vectors/multi-channel data, we directly treat each vector as a vertex, and the matrix that concatenates all vectors as the global contextual representation $X$. After that, we manually define edge presence (the basic adjacency matrix $\mathcal{A}$) in the basic graph $\mathcal{G}^{\text{B}}$ according to a human interpretable rule $\mathcal{R}$ (e.g., Euclidean distance or correlation between vertex features) depending on the data $\mathcal{D}^{\text{in}}$ and the task, where the basic edge representation between each pair of connected vertices is defined as $1$. These can be formulated as
    \begin{equation}
    \begin{split}
        &\mathcal{V} = \text{VFE}(X) \subseteq \{\mathbf{v}_i \in \mathbb{R}^{1 \times K}  \mid i = 1, 2, \cdots, N \} \\
        &\mathcal{E} \subseteq \{ \mathbf{e}_{i,j} =  1 \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V} \quad \text{and} \quad \mathcal{A}_{i,j} = 1 \}
    \end{split}
    \label{eq:gd-vertex}
    \end{equation}
    which is subject to:
    \begin{equation}
    \begin{split}
        \mathcal{A}_{i,j} = 
        \begin{cases}
        1 & \{\mathbf{v}_i, \mathbf{v}_j \} \in \mathcal{R} \\
        0 &  \text{Otherwise}
        \end{cases}
    \end{split}
    \label{eq:raw_a_nongraph}
    \end{equation}

In summary, the procedure includes:

* Extract the global contextual representation $X$
* Extract the node feature with the VFE module
* Define the edge presence (the basic adjacency matrix $\mathcal{A}$) according to a human interpretable rule $\mathcal{R}$

Let's start. We first need an ```input_image```:

```python
import torch
input_image = torch.randn((1,1,224,224))
```

#### 4.1.1 Global Contextual Representation $X$

In our implementation, we employ a CNN or a transformer as the backbone to directly extract $X$ from non-graph data.

```python
if 'transformer' in backbone:
    if backbone == 'swin_transformer_tiny':
        self.backbone = swin_transformer_tiny()
    elif backbone == 'swin_transformer_small':
        self.backbone = swin_transformer_small()
    else:
        self.backbone = swin_transformer_base()
    self.in_channels = self.backbone.num_features
    self.out_channels = self.in_channels // 2
    self.backbone.head = None

elif 'resnet' in backbone:
    if backbone == 'resnet18':
        self.backbone = resnet18()
    elif backbone == 'resnet101':
        self.backbone = resnet101()
    else:
        self.backbone = resnet50()
    self.in_channels = self.backbone.fc.weight.shape[1]
    self.out_channels = self.in_channels // 4
    self.backbone.fc = None
else:
    raise Exception("Error: wrong backbone name: ", backbone)

```

Using ResNet-50 as an example, the resolution of our [1, 1, 224, 224] input image changes to [1, 64, 112, 112] →  [1, 64, 56, 56] →  [1, 256, 56, 56] → [1, 512, 28, 28] → [1, 1024, 14, 14] → [1, 2048, 7, 7].


In our implementation, the [1, 2048, 7, 7] feature map is first reshaped to [1, 49, 2048] and then converted to a more compact feature representation with an FC layer (LinearBlock below).

```python
import math
import torch.nn as nn

class LinearBlock(nn.Module):
    def __init__(self, in_features,out_features=None,drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        self.fc = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout(drop)
        self.fc.weight.data.normal_(0, math.sqrt(2. / out_features))
        self.bn.weight.data.fill_(1)
        self.bn.bias.data.zero_()

    def forward(self, x):
        x = self.drop(x)
        x = self.fc(x).permute(0, 2, 1)
        x = self.relu(self.bn(x)).permute(0, 2, 1)
        return x
```

Overall, the final resolution of the Global Contextual Representation $X$ is **[1, 49, 512]**.
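The shape bookkeeping from the backbone output to $X$ can be traced with a few lines. The sketch assumes a random tensor in place of the real ResNet-50 output and uses a bare `nn.Linear`, leaving out the BatchNorm, ReLU and Dropout of `LinearBlock`:

```python
import torch
import torch.nn as nn

feat = torch.randn(1, 2048, 7, 7)              # stand-in for the ResNet-50 output
tokens = feat.flatten(2).transpose(1, 2)       # [1, 49, 2048]: one token per spatial location
fc = nn.Linear(2048, 512)                      # 2048 // 4 = 512 output channels
X = fc(tokens)
print(tokens.shape, X.shape)                   # torch.Size([1, 49, 2048]) torch.Size([1, 49, 512])
```

Each of the 49 rows of `X` therefore describes one location of the 7 x 7 feature map with 512 channels.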

#### 4.1.2 Vertex Feature Extraction (VFE)
So far, we have obtained the Global Contextual Representation $X$. We now need a VFE to create the vertices of its corresponding graph.

The proposed VFE module consists of a fully connected layer (FC) and a global average pooling (GAP) layer.


In the code below, the VFE module takes the global $X$ as input and outputs a list of vertices stored in ```f_v```, while ```num_classes``` is the parameter that controls the number of vertices used to construct the graph for this image.

```python
    # Define the VFE module (in __init__), where self.num_classes is the number of nodes
    class_linear_layers = []
    for i in range(self.num_classes):
        # One FC layer per vertex
        layer = LinearBlock(self.in_channels, self.in_channels)
        class_linear_layers += [layer]
    self.class_linears = nn.ModuleList(class_linear_layers)
    
    # Forward pass: x = X is the global contextual representation, [B, 49, C]
    f_u = []
    for i, layer in enumerate(self.class_linears):
        f_u.append(layer(x).unsqueeze(1))   # [B, 1, 49, C]
    f_u = torch.cat(f_u, dim=1)             # [B, N, 49, C]
    # The mean over the 49 locations acts as the GAP layer
    f_v = f_u.mean(dim=-2)                  # [B, N, C]
```

For example, if ```self.num_classes == 4``` (where ```self.num_classes``` is the number of nodes), the input $X$ of size [1, 49, 512] becomes ```f_v``` of size [1, 4, 512], where each node has a 512-dimensional feature.

Note that in the rest of this tutorial, we fix the number of nodes to 4 for this example.
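The VFE shapes for this 4-node example can be verified with a self-contained sketch. It assumes a random $X$ and replaces each `LinearBlock` with a bare `nn.Linear` (no BatchNorm or ReLU), so it shows the FC + GAP structure only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, C = 4, 512
X = torch.randn(1, 49, C)                          # global contextual representation

# One FC extractor per vertex (simplified stand-in for LinearBlock)
extractors = nn.ModuleList([nn.Linear(C, C) for _ in range(num_nodes)])

f_u = torch.stack([fc(X) for fc in extractors], dim=1)   # [1, 4, 49, 512]
f_v = f_u.mean(dim=-2)                                   # GAP over the 49 locations
print(f_u.shape, f_v.shape)   # torch.Size([1, 4, 49, 512]) torch.Size([1, 4, 512])
```

Because every extractor has its own weights, the four vertices receive different features even though they all read the same $X$.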

#### 4.1.3 Define the Edge Presence ($\mathcal{A}$) according to  $\mathcal{R}$



As a result, the vertices set $\mathcal{V}$ contains $N$ vectors (each has $K$ dimensions), representing $N$ vertices and their features. We hereby manually define edge presence (the basic adjacency matrix $\mathcal{A}$) in the basic graph $\mathcal{G}^{\text{B}}$ according to a human interpretable rule $\mathcal{R}$ (e.g., Euclidean distance or correlation between vertex features) depending on the data $\mathcal{D}^{\text{in}}$ and the task, where the basic edge representation between each pair of connected vertices is defined as $1$. These can be formulated as
    \begin{equation}
    \begin{split}
        &\mathcal{V} = \text{VFE}(X) \subseteq \{\mathbf{v}_i \in \mathbb{R}^{1 \times K}  \mid i = 1, 2, \cdots, N \} \\
        &\mathcal{E} \subseteq \{ \mathbf{e}_{i,j} =  1 \mid \mathbf{v}_i, \mathbf{v}_j \in \mathcal{V} \quad \text{and} \quad \mathcal{A}_{i,j} = 1 \}
    \end{split}
    \end{equation}
    which is subject to:
    \begin{equation}
    \begin{split}
        \mathcal{A}_{i,j} = 
        \begin{cases}
        1 & \{\mathbf{v}_i, \mathbf{v}_j \} \in \mathcal{R} \\
        0 &  \text{Otherwise}
        \end{cases}
    \end{split}
    \end{equation}
    

So far, we have obtained the nodes (```f_v```) for constructing a graph from the input image.
It is now time to decide the rule $\mathcal{R}$ that determines the edge presence, as illustrated in the figure below:

![Rule_Graph](Rule_Graph.png)

For this rule, we can use the Euclidean distance or another similarity/distance measure:

```python
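from scipy.spatial import distance

number_vertices = 4
f_e = torch.zeros(f_v.shape[0], number_vertices * number_vertices, 512).cpu().detach().numpy()
for m in range(f_v.shape[0]):
    for i in range(number_vertices):
        for j in range(number_vertices):
            a = f_v[m,i,:]
            b = f_v[m,j,:]
            # Row-major index i * number_vertices + j gives every (i, j) pair its own slot
            f_e[m,i*number_vertices+j,:].fill(distance.euclidean(a.cpu().detach().numpy(), b.cpu().detach().numpy()))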
```
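One possible way to turn pairwise distances into the basic adjacency matrix $\mathcal{A}$ is a simple threshold. The sketch below is only an example of a rule $\mathcal{R}$: the median-distance threshold is an assumption chosen for illustration, and random features stand in for `f_v`:

```python
import torch

torch.manual_seed(0)
f_v = torch.randn(1, 4, 512)                  # vertex features (random stand-ins)

dist = torch.cdist(f_v, f_v)                  # [1, 4, 4] pairwise Euclidean distances
# Hypothetical rule R: connect two distinct vertices if they are closer
# than the median of all off-diagonal distances
off_diag = ~torch.eye(4, dtype=torch.bool)
threshold = dist[0][off_diag].median()
A = ((dist[0] < threshold) & off_diag).int()  # basic adjacency matrix, no self-loops
print(A)
```

Any other interpretable criterion, such as a correlation threshold, could be substituted for the distance test without changing the rest of the pipeline.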
    
At this point, we have obtained the nodes (```f_v```, [1, 4, 512]) and edges (```f_e```, [1, 16, 512]) for the input image. In other words, **we have obtained the basic graph representation $\mathcal{G}^{\text{B}}(\mathcal{V}, \mathcal{E})$ for this image**.

Note that **in experiments, we do not need to pre-define ```f_e``` within the code. We can directly use graph edge modeling** to learn multi-dimensional edge features.

```python
f_e = self.edge_extractor(f_u, x)
f_e = f_e.mean(dim=-2)
```

### 4.2 Task-Specific Topology Prediction (TTP)

Different from graph data whose vertex features are pre-defined and fixed, vertex features of non-graph data are dependent on the learning process of the VFE. We propose to train the VFE as the $\hat{\text{VFE}}(X)$ which additionally encodes task-specific associations among vertices into vertex features, i.e., the learned vertex features encode not only task-specific object representations contained in the input data but also the association among them. In other words, the presence of each edge is decided by the corresponding pair of vertices using a specific rule (e.g., distances or similarity between vertices), where the edge feature of each presented edge is $1$. This process can be denoted as:
    \begin{equation}
    \begin{split}
        &\mathcal{\hat{V}} = \hat{\text{VFE}}(X) \subseteq \{\mathbf{\hat{v}}_i \in \mathbb{R}^{1 \times K}  \mid i = 1, 2, \cdots, N \} \\
        &\mathcal{E}^{\mathcal{\hat{V}}} \subseteq \{ \mathbf{e}^{\text{V}}_{i,j} =  1 \mid \mathbf{\hat{v}}_i, \mathbf{\hat{v}}_j \in \mathcal{\hat{V}} \quad \text{and} \quad \mathcal{\hat{A}}^{\text{V}}_{i,j} = 1 \}
    \end{split}
    \end{equation}
    which is conditioned on:
    \begin{equation}
    \begin{split}
        \mathcal{\hat{A}}^{\text{V}}_{i,j} = 
        \begin{cases}
        1 & \mathbf{\hat{v}}_j \in \mathcal{C}(\mathbf{v}_i) \\
        0 &  \text{Otherwise}
        \end{cases}
    \end{split}
    \label{eq:C-nearest}
    \end{equation}     
    where $\mathcal{C}(\mathbf{v}_i)$ denotes the $C$ nearest neighbour vertices of the vertex $\mathbf{v}_i$. In this paper, the vertex-decided adjacency matrix is obtained by connecting each vertex to its $C$ nearest neighbour vertices. To achieve the $\hat{\text{VFE}}$, we further attach a GCN predictor to produce a prediction from the obtained graph $\mathcal{G}^{\text{V}}(\mathcal{\hat{V}},\mathcal{E}^{V})$, and use the loss between the prediction and labels to supervise the training of the $\hat{\text{VFE}}$. This way, the $\hat{\text{VFE}}$ learns to directly produce both task-specific vertex features $\mathcal{\hat{V}}$ and a task-specific vertex-decided adjacency matrix $\mathcal{\hat{A}}^{\text{V}}$ for the input non-graph data, where the $\mathcal{\hat{A}}^{\text{V}}$ allows the graph $\mathcal{G}^{\text{V}}(\mathcal{\hat{V}},\mathcal{E}^{V})$ to have optimal connectivity/message passing paths among vertices.

```python

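self.gnn = GNN(self.in_channels, self.num_classes)  # the attached GCN predictor (defined in __init__)

# Forward pass: refine the vertex and edge features with the GCN predictor
# f_v, f_e = self.gnn(f_v, torch.Tensor(f_e).cuda())  # variant for an f_e stored as a NumPy array
f_v, f_e = self.gnn(f_v, f_e)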
    
```
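The $C$-nearest-neighbour rule for the vertex-decided adjacency matrix can be sketched as below. The sizes (`N = 4` vertices, `C = 2` neighbours), the Euclidean metric and the random vertex features are assumptions for illustration; the variable names are not taken from the original code:

```python
import torch

torch.manual_seed(0)
N, K, C = 4, 512, 2                        # vertices, feature size, neighbours per vertex
v_hat = torch.randn(N, K)                  # stand-in for the learned vertex features

dist = torch.cdist(v_hat, v_hat)           # [N, N] pairwise Euclidean distances
dist.fill_diagonal_(float('inf'))          # a vertex is not its own neighbour
nn_idx = dist.topk(C, dim=-1, largest=False).indices   # [N, C] closest vertices

A_V = torch.zeros(N, N, dtype=torch.int64)
A_V.scatter_(1, nn_idx, 1)                 # A_V[i, j] = 1 if v_j is among the C nearest of v_i
print(A_V)
```

Note that the resulting matrix is generally not symmetric: $\mathbf{\hat{v}}_j$ being a nearest neighbour of $\mathbf{\hat{v}}_i$ does not imply the reverse, which is consistent with directed edges.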
Meanwhile, for graph representations of non-graph data, $\mathcal{\hat{A}}$ can be formulated as:
\begin{equation}
\mathcal{\hat{A}}_{i,j} = 
\begin{cases}
1 &   \mathcal{\hat{A}}^{\text{V}}_{i,j} = 1 \quad  \text{or} \quad  \mathcal{A}_{i,j} = 1 \\
0 &  \text{Otherwise}
\end{cases}
\label{eq:ttp_nongraph}
\end{equation}
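The equation above is an element-wise logical OR of the two adjacency matrices. A tiny sketch with hand-written 3 x 3 matrices (illustrative values only):

```python
import torch

A   = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 0]])   # basic topology from rule R
A_V = torch.tensor([[0, 0, 1], [0, 0, 1], [1, 0, 0]])   # vertex-decided topology
A_hat = ((A == 1) | (A_V == 1)).int()                   # edge kept if present in either
print(A_hat)
# tensor([[0, 1, 1],
#         [1, 0, 1],
#         [1, 0, 0]], dtype=torch.int32)
```

An edge is therefore only absent from $\mathcal{\hat{A}}$ when neither the manual rule nor the learned vertex features propose it.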

In summary, during the generation of the task-specific graph topology $\mathcal{\hat{A}}$, each edge's presence $\mathcal{\hat{A}}_{i,j}$ is decided by not only the corresponding vertices $\mathbf{v}_i$ and $\mathbf{v}_j$, but also the global contextual information contained in $X$. This further leads $\mathcal{\hat{A}}_{i,j}$ to be globally optimal, i.e., the graph has globally optimal and task-specific message passing paths.


### 4.3 Multi-Dimensional Edge Feature Generation (MEFG)

```python
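import math
import torch
import torch.nn as nn

# MEFG for non-graph data: CrossAttn implements the cross attention used by both VCR and VVR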

class CrossAttn(nn.Module):
    """ cross attention Module"""
    def __init__(self, in_channels):
        super(CrossAttn, self).__init__()
        self.in_channels = in_channels
        self.linear_q = nn.Linear(in_channels, in_channels // 2)
        self.linear_k = nn.Linear(in_channels, in_channels // 2)
        self.linear_v = nn.Linear(in_channels, in_channels)
        self.scale = (self.in_channels // 2) ** -0.5
        self.attend = nn.Softmax(dim=-1)

        self.linear_k.weight.data.normal_(0, math.sqrt(2. / (in_channels // 2)))
        self.linear_q.weight.data.normal_(0, math.sqrt(2. / (in_channels // 2)))
        self.linear_v.weight.data.normal_(0, math.sqrt(2. / in_channels))

    def forward(self, y, x):
        query = self.linear_q(y)
        key = self.linear_k(x)
        value = self.linear_v(x)
        dots = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, value)
        return out


class MEFG(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(MEFG, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.VCR = CrossAttn(self.in_channels)
        self.VVR = CrossAttn(self.in_channels)
        self.edge_proj = nn.Linear(in_channels, in_channels)
        self.bn = nn.BatchNorm2d(self.num_classes * self.num_classes)

        self.edge_proj.weight.data.normal_(0, math.sqrt(2. / in_channels))
        self.bn.weight.data.fill_(1)
        self.bn.bias.data.zero_()

    def forward(self, class_feature, global_feature):
        B, N, D, C = class_feature.shape
        global_feature = global_feature.repeat(1, N, 1).view(B, N, D, C)
        feat = self.VCR(class_feature, global_feature)
        feat_end = feat.repeat(1, 1, N, 1).view(B, -1, D, C)
        feat_start = feat.repeat(1, N, 1, 1).view(B, -1, D, C)
        feat = self.VVR(feat_start, feat_end)
        edge = self.bn(self.edge_proj(feat))
        return edge

```
The VCR block takes vertex features  $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$ and the global contextual representation $X$ as input. It first conducts cross attention between $\mathbf{\hat{v}}_i$ and $X$ as well as $\mathbf{\hat{v}}_j$ and $X$. Here, the vertex features $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$ are independently used as queries to locate vertex-context relationship features $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$ in $X$ (i.e., $X$ is treated as the key and value for attention operations). Mathematically speaking, this process can be represented as:
\begin{equation}
\begin{split}
    \mathcal{F}_{i,x} = \text{VCR}(\mathbf{\hat{v}}_i, X) \\
    \mathcal{F}_{j,x} = \text{VCR}(\mathbf{\hat{v}}_j, X)    
\end{split}
\end{equation}
with the cross attention operation in VCR defined as:
\begin{equation}
     \text{VCR}(A, B) = \text{softmax}\left(\frac{A W_q (B W_k)^T}{\sqrt{d_k}}\right) B W_v
\end{equation}
where $W_q$, $W_k$ and $W_v$ are learnable weight vectors or matrices (depending on the shape of the input data) for the query, key and value encoding, respectively, and $d_k$ is a scaling factor equal to the number of channels of $B$. The resulting $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$ contain the vertex $\mathbf{\hat{v}}_i$-related and vertex $\mathbf{\hat{v}}_j$-related task-specific cues extracted from the global contextual representation $X$.
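
The cross-attention formula above can be sketched numerically as follows. This is a minimal illustration rather than the `CrossAttn` module itself: the function name `cross_attention`, the random weight matrices and all tensor sizes are assumptions chosen only to make the shapes visible.

```python
import math
import torch

def cross_attention(a, b, w_q, w_k, w_v):
    """softmax(A W_q (B W_k)^T / sqrt(d_k)) B W_v, with d_k = channels of B."""
    d_k = b.shape[-1]
    query = a @ w_q                      # queries come from A
    key = b @ w_k                        # keys come from B
    value = b @ w_v                      # values come from B
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    attn = torch.softmax(scores, dim=-1)
    return attn @ value, attn

torch.manual_seed(0)
C = 8                                    # illustrative channel size
v_i = torch.randn(1, C)                  # one vertex feature used as the query
X = torch.randn(5, C)                    # global context with 5 positions
w_q, w_k, w_v = (torch.randn(C, C) for _ in range(3))

f_ix, attn = cross_attention(v_i, X, w_q, w_k, w_v)
print(f_ix.shape)                        # torch.Size([1, 8])
print(attn.sum(dim=-1))                  # attention weights over X sum to 1
```

The output keeps the shape of the query, while every entry is a weighted mixture of the projected context positions, which is why $\mathcal{F}_{i,x}$ can be read as the part of $X$ that is relevant to $\mathbf{\hat{v}}_i$.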


Based on $\mathcal{F}_{i,x}$ and $\mathcal{F}_{j,x}$, the VVR block further extracts task-specific context cues that relate to both vertices. VVR is also a cross-attention block with the same form as VCR. In particular, it takes $\mathcal{F}_{i,x}$ as the query and $\mathcal{F}_{j,x}$ as the key and value, and, separately, $\mathcal{F}_{j,x}$ as the query and $\mathcal{F}_{i,x}$ as the key and value, producing two context-aware vertex-vertex relationship features $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$, respectively. Here, $\mathcal{F}_{i,x,j}$ encodes the $\mathcal{F}_{i,x}$-related cues in $\mathcal{F}_{j,x}$, while $\mathcal{F}_{j,x,i}$ encodes the $\mathcal{F}_{j,x}$-related cues in $\mathcal{F}_{i,x}$. In other words, the context-aware vertex-vertex relationship features $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$ contain cues that not only come from the whole context but also relate to both vertices $\mathbf{\hat{v}}_i$ and $\mathbf{\hat{v}}_j$. We formulate this process as:
\begin{equation}
\begin{split}
\mathcal{F}_{i,x,j} = \text{VVR}(\mathcal{F}_{i,x}, \mathcal{F}_{j,x}) \\
\mathcal{F}_{j,x,i} = \text{VVR}(\mathcal{F}_{j,x}, \mathcal{F}_{i,x})    
\end{split}
\end{equation}
Depending on the data shape, we finally employ either a pooling layer or a fully-connected layer to flatten $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$ into a pair of multi-dimensional edge feature vectors $\mathbf{\bar{e}}_{i,j}$ and $\mathbf{\bar{e}}_{j,i}$ (we denote this operation as $L$):
\begin{equation}
\begin{split}
\mathbf{\bar{e}}_{i,j} = L(\mathcal{F}_{i,x,j}) \\
\mathbf{\bar{e}}_{j,i} = L(\mathcal{F}_{j,x,i})
\end{split}
\end{equation}
As a result, each produced multi-dimensional edge feature encodes task-specific cues, drawn from the whole context of the input data $\mathcal{D}^{\text{in}}$, that relate to both $\mathbf{v}_i$ and $\mathbf{v}_j$.
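
The two options for the operation $L$ can be sketched as below. This is only an illustration under assumed shapes: we suppose each relationship feature is a $D \times C$ map, use average pooling as the pooling variant, and pick the output size of the fully-connected variant arbitrarily. The helper names `pool_L` and `fc_L` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, C = 49, 16                            # illustrative: D positions, C channels
f_ixj = torch.randn(D, C)                # stands in for F_{i,x,j}
f_jxi = torch.randn(D, C)                # stands in for F_{j,x,i}

def pool_L(f):
    """Option 1: L as average pooling over the D positions."""
    return f.mean(dim=0)

fc = nn.Linear(D * C, C)                 # assumed output size C

def fc_L(f):
    """Option 2: L as a fully-connected layer on the flattened feature."""
    return fc(f.reshape(-1))

e_ij, e_ji = pool_L(f_ixj), pool_L(f_jxi)
print(e_ij.shape, e_ji.shape)            # torch.Size([16]) torch.Size([16])
print(fc_L(f_ixj).shape)                 # torch.Size([16])
```

Since $\mathcal{F}_{i,x,j}$ and $\mathcal{F}_{j,x,i}$ generally differ, the two resulting edge vectors differ as well, so the edge features are directed.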

#### 4.3.1 Vertex-Context Relationship Modelling (VCR) and Vertex-Vertex Relationship Modelling (VVR)

The detailed implementation of VCR and VVR is listed in the code block below:


```python
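def forward(self, class_feature, global_feature):
    # class_feature: (B, N, D, C), one D x C feature per vertex (class).
    B, N, D, C = class_feature.shape
    # Tile the global context so that every vertex attends to its own copy.
    global_feature = global_feature.repeat(1, N, 1).view(B, N, D, C)
    # VCR: vertices are the queries, the global context is the key and value.
    feat = self.VCR(class_feature, global_feature)
    # Build all N * N vertex pairs.
    feat_end = feat.repeat(1, 1, N, 1).view(B, -1, D, C)
    feat_start = feat.repeat(1, N, 1, 1).view(B, -1, D, C)
    # VVR: feat_start is the query, feat_end is the key and value.
    feat = self.VVR(feat_start, feat_end)
    # Project and normalise the pair features to obtain the edge features.
    edge = self.bn(self.edge_proj(feat))
    return edge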
```
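
To see which vertex pair ends up in each of the $N \times N$ slots, the two `repeat`/`view` lines can be traced on a toy tensor. The sketch below fills vertex $n$ with the constant value $n$ (an assumption made purely for readability) and reads back the indices.

```python
import torch

B, N, D, C = 1, 3, 2, 4
# Give vertex n the constant value n so that pairs are easy to read off.
feat = torch.arange(N, dtype=torch.float32).view(1, N, 1, 1)
feat = feat.expand(B, N, D, C).contiguous()

# Same two lines as in the forward pass above.
feat_end = feat.repeat(1, 1, N, 1).view(B, -1, D, C)
feat_start = feat.repeat(1, N, 1, 1).view(B, -1, D, C)

# Read the vertex index held at each of the N * N pair slots.
start_idx = feat_start[0, :, 0, 0].long().tolist()
end_idx = feat_end[0, :, 0, 0].long().tolist()
print(list(zip(start_idx, end_idx)))
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 2), (1, 2), (2, 2)]
```

Slot $k$ therefore pairs the query vertex $k \bmod N$ with the key/value vertex $\lfloor k / N \rfloor$. All ordered pairs are covered, including the self-pairs $(i, i)$.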

#### 4.3.2 TM

```python
    mask = self.mask(f_v).view(b,n*n,1)
    f_e = f_e * mask
    f_v, f_e = self.gnn(f_v, f_e)
```
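
The masking line relies on broadcasting: a mask of shape `(b, n*n, 1)` scales all channels of the corresponding edge feature at once. The sketch below illustrates this with a hand-written binary mask. The mask values are a hypothetical example and are not what `self.mask` computes, since that module is not shown here.

```python
import torch

b, n, c = 1, 3, 4
f_e = torch.ones(b, n * n, c)            # one feature vector per vertex pair

# Hypothetical binary topology: drop the self-pair slots, keep all others.
mask = (1.0 - torch.eye(n)).view(b, n * n, 1)

masked = f_e * mask                      # (b, n*n, 1) broadcasts over channels
print(masked.shape)                      # torch.Size([1, 9, 4])
print(masked.view(n, n, c)[..., 0])      # zero diagonal, ones elsewhere
```

An edge whose mask value is zero contributes an all-zero feature vector to the subsequent GNN layers, which is how the mask can remove it from the topology.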

### 4.4 Summary

Combining all these code blocks, we produce the final graph, with nodes and edges stored in ```f_v``` and ```f_e```, respectively, which can then be used for various graph-related tasks:

```python
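    # In __init__: one learnable vector of size in_channels per class.
    self.sc = nn.Parameter(torch.FloatTensor(torch.zeros(self.num_classes, self.in_channels)))

    # In forward: cosine similarity between each node feature and its (ReLU-ed) class vector.
    b, n, c = f_v.shape
    sc = self.sc
    sc = self.relu(sc)
    sc = F.normalize(sc, p=2, dim=-1)
    cl = F.normalize(f_v, p=2, dim=-1)
    cl = (cl * sc.view(1, n, c)).sum(dim=-1, keepdim=False)
    # Edge-level output computed from the edge features.
    cl_edge = self.edge_fc(f_e)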
```
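
The node-level score `cl` in the snippet above is a cosine similarity between each node feature and its class vector. The following sketch checks this equivalence on random data. The tensor sizes are illustrative, and `sc` here is a random non-negative stand-in for `relu(self.sc)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n, c = 2, 4, 8                        # batch, classes (nodes), channels
f_v = torch.randn(b, n, c)               # stands in for the final node features
sc = torch.rand(n, c)                    # stands in for relu(self.sc)

# Same arithmetic as the snippet above.
sc_n = F.normalize(sc, p=2, dim=-1)
cl = (F.normalize(f_v, p=2, dim=-1) * sc_n.view(1, n, c)).sum(dim=-1)

# Equivalent formulation with the built-in cosine similarity.
ref = F.cosine_similarity(f_v, sc.unsqueeze(0).expand_as(f_v), dim=-1)
print(cl.shape)                          # torch.Size([2, 4])
print(torch.allclose(cl, ref, atol=1e-5))  # True
```

Each score lies in $[-1, 1]$, so a downstream loss typically needs to account for this bounded range.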

```python
from model.MEFL import GRATIS
net = GRATIS(num_classes=4, backbone='resnet50')
```

## 5 Experiments<a name="Experiments"></a>

### 5.1 Graph Classification (Graph)


#### Data Preparation

Download the MNIST and CIFAR10 super-pixel datasets and preprocess them with the corresponding notebooks, ```prepare_superpixels_MNIST.ipynb``` and ```prepare_superpixels_CIFAR.ipynb```, provided at https://github.com/graphdeeplearning/benchmarking-gnns/tree/master/data/superpixels or directly use https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/data/script_download_superpixels.sh instead.

Each notebook also performs the training, validation, and testing split and writes a new pkl file. In other words, two files (data/superpixels/CIFAR10.pkl and data/superpixels/MNIST.pkl) will be generated after this step.

#### Training Command

The training commands are included in two bash files, ```MNIST.sh``` and ```CIFAR10.sh```, which list the detailed training command for each method, for example:
`python main_superpixels_graph_classification_best_model.py --dataset MNIST --gpu_id 0 --config 'configs/superpixels_graph_classification_GatedGCN_MNIST_100k.json' --batch_size 32 --dropout 0.1 --max_time 120`


#### Data Loading

In **main_superpixels_graph_classification_best_model.py**, we set up the training environment, including the optimizer, learning rate scheduler, and data loaders.

```python

optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',factor=params['lr_reduce_factor'],patience=params['lr_schedule_patience'],verbose=True)

train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, drop_last=drop_last, collate_fn=dataset.collate)
val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)
test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, drop_last=drop_last, collate_fn=dataset.collate)

```
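
The behaviour of the `ReduceLROnPlateau` scheduler configured above can be illustrated on a toy model. In this sketch the model, the validation losses, and the values `factor=0.5` and `patience=2` are made up for illustration and are not the values from the config files.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                  # toy stand-in model
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

val_losses = [1.0, 0.9, 0.9, 0.9, 0.9]   # made-up validation losses
lrs = []
for loss in val_losses:
    optimizer.step()                     # normally preceded by loss.backward()
    scheduler.step(loss)                 # pass the monitored validation metric
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs)
```

The learning rate stays at its initial value while the loss improves, and is multiplied by `factor` once the loss has failed to improve for more than `patience` consecutive epochs, which here happens at the last epoch.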
#### Testing
The detailed training/validation/testing code is included in **train_superpixels_graph_classification.py** and is called as follows:
```python
from train.train_superpixels_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network
epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch,args)
epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch, args)
_, epoch_test_acc = evaluate_network(model, device, test_loader, epoch, args)
```

### 5.2 Node Classification (Graph)
#### Data Preparation

Download the ```SBM_PATTERN``` and ```SBM_CLUSTER``` datasets using https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/data/script_download_SBMs.sh

The bash file will automatically download ```SBM_CLUSTER.pkl``` and ```SBM_PATTERN.pkl```:

```
downloading SBM_CLUSTER.pkl...
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                        Dload  Upload   Total   Spent    Left  Speed
100 1205M  100 1205M    0     0   124M      0  0:00:09  0:00:09 --:--:--  123M

downloading SBM_PATTERN.pkl...
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                        Dload  Upload   Total   Spent    Left  Speed
100 1886M  100 1886M    0     0   147M      0  0:00:12  0:00:12 --:--:--  145M
```
#### Training Command

The training commands are included in two bash files, ```CLUSTER.sh``` and ```PATTERN.sh```, which list the detailed training command for each method, for example:
```
python main_SBMs_node_classification_best_model.py --dataset SBM_CLUSTER --gpu_id 1 --config 'configs/SBMs_node_clustering_GAT_CLUSTER_500k.json' --batch_size 16 --out_dir ./output/backbone/CLUSTER/gat_2x/ --dropout 0.1 --max_time 60
```

```
python main_SBMs_node_classification_best_model.py --dataset SBM_PATTERN --gpu_id 0 --config 'configs/SBMs_node_clustering_GatedGCN_PATTERN_PE_500k.json' --batch_size 16 --max_time 60
```

### 5.3 Link Prediction (Graph)
#### Data Preparation

Download the ```TSP``` dataset using https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/data/script_download_TSP.sh

```
downloading TSP.pkl...
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                        Dload  Upload   Total   Spent    Left  Speed
100 1783M  100 1783M    0     0  18.2M      0  0:01:37  0:01:37 --:--:-- 16.3M
```
#### Training Command

```
#Gated-GCN base
python main_TSP_edge_classification_best_model.py --dataset TSP \
--gpu_id 0 \
--config 'configs/TSP_edge_classification_GatedGCN_100k.json' --edge_feat True \
--batch_size 16 \
--max_time 60

#GAT base
python main_TSP_edge_classification_best_model.py --dataset TSP \
--gpu_id 1 \
--config 'configs/TSP_edge_classification_GAT_edgereprfeat.json' --edge_feat True \
--batch_size 16 \
--max_time 60
```

### 5.4 Graph Classification (Non-Graph)

### 5.5 Node Classification (Non-Graph)

### 5.6 Link Prediction (Non-Graph)