In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Load the Cora dataset and take its first graph.
dataset = dgl.data.CoraGraphDataset(verbose=False)
g = dataset[0]
num_class = dataset.num_classes

# Node features and labels stored on the graph.
feat, label = g.ndata["feat"], g.ndata["label"]

# Train / validation / test split masks stored on the graph.
train_mask, val_mask, test_mask = (
    g.ndata[key] for key in ("train_mask", "val_mask", "test_mask")
)


In [None]:
# Show the size of the graph loaded above as (number of nodes, number of edges).
# The bare attribute `g.num` does not exist on a DGL graph and raised
# AttributeError.
g.num_nodes(), g.num_edges()

In [None]:
# Coordinates of every non-zero feature entry (one index tensor per axis) ...
indices = torch.where(feat.ne(0))
# ... and the values stored at those coordinates, as a flat tensor.
selected_feat = feat[indices[0], indices[1]]
selected_feat

In [None]:
from dgl import LapPE

# Load the sample graph directly instead of reading `train_dataset`, which is
# only defined by a later cell; this keeps the cell runnable top-to-bottom on a
# fresh kernel. Indexing the dataset yields a pair whose first item is the graph.
g = dgl.data.ZINCDataset(mode="train")[0][0]

# Attach Laplacian positional encodings to the graph: results are stored under
# the node-data keys "eigvec" and "eigval".
transform3 = LapPE(k=3, feat_name="eigvec", eigval_name="eigval", padding=False)
g3 = transform3(g)
print(g3.ndata["eigval"])
print(g3.ndata["eigvec"])


In [None]:
# Toy example: left-pad the last axis of a random integer tensor with zeros so
# that it reaches `max_path_length` entries.
path = torch.randint(low=1, high=8, size=(4, 4, 2))
max_path_length = 5
path_length = path.size(2)
# F.pad takes (left, right) amounts for the last axis; pad on the left only.
p1d = (max_path_length - path_length, 0)
F.pad(path, pad=p1d, mode="constant", value=0)


In [6]:
# Load the three ZINC splits, selected through the `mode` argument.
train_dataset, valid_dataset, test_dataset = (
    dgl.data.ZINCDataset(mode=split) for split in ("train", "valid", "test")
)

In [None]:
# Inspect the first training sample.
train_dataset[0]

In [None]:
# Inspect the second training sample.
train_dataset[1]

In [None]:
# Inspect the training sample at index 10.
train_dataset[10]

In [None]:
# Number of samples in the training split.
len(train_dataset)

In [27]:
# Draw a random ordering of the training-sample indices. A dedicated, seeded
# generator makes the shuffle reproducible across kernel restarts.
indices = torch.randperm(
    len(train_dataset), generator=torch.Generator().manual_seed(0)
)


In [None]:

# Display the permutation of training indices.
indices

In [None]:
# Reorder the training samples according to the permutation. The index tensor
# is converted to plain Python ints so the dataset is not indexed with 0-d
# tensors.
shuffled_train_dataset = [train_dataset[index] for index in indices.tolist()]
# Preview only the first few samples; echoing the whole list would flood the
# cell output.
shuffled_train_dataset[:5]

In [None]:
import torch as th
import torch.nn as nn
from centralityencoding import CentralityEncoder
from spaceencoding import SpatialEncoder
from edgeencoding import EdgeEncoder
from encoder import Encoder

class Graphormer(nn.Module):
    """
    Graphormer model for graph-level prediction.

    Node embeddings (atom embedding plus in/out-degree embedding) are prefixed
    with a learnable graph token and passed through a stack of ``Encoder``
    layers. The attention scores of those layers are biased with a spatial
    (shortest-path distance) encoding and an edge-path encoding. The final
    state of the graph token is projected to the output dimension.

    Args:
        regrees_output_dim (int): Regression output dimension.
        edge_dim (int): Edge feature dimension handed to the edge encoder.
        num_atoms (int): Largest node-feature index; the atom embedding table
            has ``num_atoms + 1`` rows and index 0 is the padding index.
        max_in_degree (int): Maximum in-degree handed to the degree encoder.
        max_out_degree (int): Maximum out-degree handed to the degree encoder.
        num_spatial (int): Maximum distance handed to the spatial encoder.
        multi_hop_max_dist (int): Maximum path length handed to the edge encoder.
        num_encoder_layers (int): Number of encoder layers.
        embedding_dim (int): Embedding dimension.
        ffn_embedding_dim (int): Feed-forward network embedding dimension.
        num_attention_heads (int): Number of attention heads.
        dropout (float): Dropout rate handed to every encoder layer.
        pre_layernorm (bool): Accepted for interface compatibility; this class
            does not read it.
        activation_fn (nn.Module): Activation used in the output head.
    """
    def __init__(
        self,
        regrees_output_dim=1,
        edge_dim=1,
        num_atoms=0,
        max_in_degree=0,
        max_out_degree=0,
        num_spatial=0,
        multi_hop_max_dist=0,
        num_encoder_layers=12,
        embedding_dim=80,
        ffn_embedding_dim=80,
        num_attention_heads=8,
        dropout=0.1,
        pre_layernorm=True,
        activation_fn=nn.GELU(),
    ):
        super().__init__()
        # NOTE(review): this dropout module is created but not applied in forward().
        self.dropout = nn.Dropout(p=dropout)
        self.embedding_dim = embedding_dim
        self.num_heads = num_attention_heads
        self.atom_encoder = nn.Embedding(num_atoms + 1, embedding_dim, padding_idx=0)
        # Single learnable embedding that acts as the virtual graph-level node.
        self.graph_token = nn.Embedding(1, embedding_dim)
        self.degree_encoder = CentralityEncoder(
            max_in_degree=max_in_degree,
            max_out_degree=max_out_degree,
            embedding_dim=embedding_dim,
        )
        self.path_encoder = EdgeEncoder(
            max_len=multi_hop_max_dist,
            feat_dim=edge_dim,
            num_heads=num_attention_heads,
        )
        self.spatial_encoder = SpatialEncoder(
            max_dist=num_spatial, num_heads=num_attention_heads
        )
        # Per-head bias used for every attention edge that touches the graph token.
        self.graph_token_virtual_distance = nn.Embedding(1, num_attention_heads)
        self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.layers = nn.ModuleList(
            [
                Encoder(
                    hidden_size=embedding_dim,
                    ffn_out_size=ffn_embedding_dim,
                    attention_dropout=dropout,
                    num_heads=num_attention_heads,
                )
                for _ in range(num_encoder_layers)
            ]
        )
        self.lm_head_transform_weight = nn.Linear(
            self.embedding_dim, self.embedding_dim
        )
        self.layer_norm = nn.LayerNorm(self.embedding_dim)
        self.activation_fn = activation_fn
        self.embed_out = nn.Linear(self.embedding_dim, regrees_output_dim, bias=False)
        self.lm_output_learned_bias = nn.Parameter(th.zeros(regrees_output_dim))

    def reset_output_layer_parameters(self):
        """
        Reset the parameters of the output layer.

        The learned bias is re-created as zeros with one entry per output
        dimension (matching ``embed_out``), and the output projection weights
        are re-initialised.
        """
        out_weight = self.embed_out.weight
        # Size the bias from the projection itself rather than hard-coding 1,
        # so models with regrees_output_dim > 1 keep a correctly shaped bias.
        self.lm_output_learned_bias = nn.Parameter(
            th.zeros(
                self.embed_out.out_features,
                device=out_weight.device,
                dtype=out_weight.dtype,
            )
        )
        self.embed_out.reset_parameters()

    def forward(
        self,
        node_feat,
        in_degree,
        out_degree,
        path_data,
        dist,
        attn_mask=None,
    ):
        """
        Forward pass for the Graphormer model.

        Args:
            node_feat (Tensor): Node feature indices of shape
                (num_graphs, max_num_nodes, num_node_feats).
            in_degree (Tensor): In-degree indices for the degree encoder.
            out_degree (Tensor): Out-degree indices for the degree encoder.
            path_data (Tensor): Edge features along shortest paths, consumed by
                the edge encoder.
            dist (Tensor): Shortest-path distances of shape
                (num_graphs, max_num_nodes, max_num_nodes).
            attn_mask (Tensor, optional): Mask of shape
                (num_graphs, max_num_nodes + 1, max_num_nodes + 1) in the
                convention of ``ZincDataset.collate``: a non-zero entry marks a
                (query, key) pair that must NOT be attended to.

        Returns:
            Tensor: Graph-level output of shape (num_graphs, regrees_output_dim).
        """
        num_graphs, max_num_nodes, _ = node_feat.shape
        deg_emb = self.degree_encoder(in_degree, out_degree)
        # Sum the embeddings of all node-feature columns into one vector per node.
        node_feat = self.atom_encoder(node_feat.int()).sum(dim=-2)
        node_feat = node_feat + deg_emb
        graph_token_feat = self.graph_token.weight.unsqueeze(0).repeat(num_graphs, 1, 1)
        # Position 0 of every sequence is the graph token.
        x = th.cat([graph_token_feat, node_feat], dim=1)

        # Allocate the bias on the same device / dtype as the activations so the
        # model also works when it has been moved off the default device.
        attn_bias = x.new_zeros(
            num_graphs,
            max_num_nodes + 1,
            max_num_nodes + 1,
            self.num_heads,
        )
        path_encoding = self.path_encoder(dist, path_data)
        spatial_encoding = self.spatial_encoder(dist)
        attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding
        # Every edge from or to the graph token shares one learnable per-head bias.
        t = self.graph_token_virtual_distance.weight.reshape(1, 1, self.num_heads)
        attn_bias[:, 1:, 0, :] = attn_bias[:, 1:, 0, :] + t
        attn_bias[:, 0, :, :] = attn_bias[:, 0, :, :] + t
        # The encoder's attention adds the bias to scores laid out as
        # (batch, heads, query, key), so move the head axis forward.
        attn_bias = attn_bias.permute(0, 3, 1, 2)

        keep_mask = None
        if attn_mask is not None:
            # The encoder's attention blocks positions where the mask equals 0,
            # i.e. it expects "non-zero = keep". Convert from the collate
            # convention ("non-zero = padded") accordingly.
            keep_mask = attn_mask.to(x.device) == 0
            if keep_mask.dim() == 3:
                # Add a head axis so the mask broadcasts over all heads.
                keep_mask = keep_mask.unsqueeze(1)
            # Always allow attending to the graph token. Real nodes may do so
            # anyway; for padded query rows this avoids an all-masked softmax
            # row, which would produce NaN.
            keep_mask[..., 0] = True

        x = self.emb_layer_norm(x)
        for layer in self.layers:
            x = layer(
                x,
                att_mask=keep_mask,
                att_bias=attn_bias,
            )
        # Read out the graph token and apply the output head.
        graph_rep = x[:, 0, :]
        graph_rep = self.layer_norm(
            self.activation_fn(self.lm_head_transform_weight(graph_rep))
        )
        graph_rep = self.embed_out(graph_rep) + self.lm_output_learned_bias

        return graph_rep


In [None]:
# filepath: /home/a373k/Desktop/feb/Graphormer/zincdata.py
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import dgl
import dgl.data

class ZincDataset(torch.utils.data.Dataset):
    """
    Loads subsets of the ZINC train / valid / test splits and batches them.

    On construction every graph gets its all-pairs shortest-path distances and
    shortest paths stored in ``ndata["spd"]`` / ``ndata["path"]``, and dataset
    wide maxima are recorded for sizing the model's embedding tables.

    Args:
        train_size (int): Number of leading training samples to keep.
        valid_size (int): Number of leading validation samples to keep.
        test_size (int): Number of leading test samples to keep.

    Attributes:
        train (list): List of (graph, label) training samples.
        val (list): List of (graph, label) validation samples.
        test (list): List of (graph, label) test samples.
        max_dist (int): Maximum shortest path distance in the dataset.
        max_in_degree (int): Maximum in-degree in the dataset.
        max_out_degree (int): Maximum out-degree in the dataset.
        max_num_nodes (int): Maximum number of nodes in the dataset.
    """
    def __init__(self, train_size=256 * 14, valid_size=256 * 2, test_size=256 * 2):
        # Slicing the DGL dataset yields a (graphs, labels) pair.
        train_dataset = dgl.data.ZINCDataset(mode="train")[:train_size]
        valid_dataset = dgl.data.ZINCDataset(mode="valid")[:valid_size]
        test_dataset = dgl.data.ZINCDataset(mode="test")[:test_size]

        train_samples = list(zip(train_dataset[0], train_dataset[1]))
        valid_samples = list(zip(valid_dataset[0], valid_dataset[1]))
        test_samples = list(zip(test_dataset[0], test_dataset[1]))

        self.train = train_samples
        self.val = valid_samples
        self.test = test_samples
        self.max_dist = 0
        self.max_in_degree = 0
        self.max_out_degree = 0
        self.max_num_nodes = 0

        for dataset in [train_samples, valid_samples, test_samples]:
            for g, _ in dataset:
                # Cache shortest-path distances and paths on the graph so that
                # collate() does not have to recompute them for every batch.
                spd, path = dgl.shortest_dist(g, return_paths=True)
                g.ndata["spd"] = spd
                g.ndata["path"] = path
                self.max_dist = max(self.max_dist, torch.max(spd).item())
                self.max_in_degree = max(
                    self.max_in_degree, torch.max(g.in_degrees()).item()
                )
                self.max_out_degree = max(
                    self.max_out_degree, torch.max(g.out_degrees()).item()
                )
                self.max_num_nodes = max(self.max_num_nodes, g.num_nodes())

    def collate(self, samples):
        """
        Custom collate function to batch graphs, labels, and additional data.

        Args:
            samples (list): List of samples, where each sample is a tuple (graph, label).

        Returns:
            tuple: ``(labels, attn_mask, node_feat, in_degree, out_degree,
            path_data, dist)`` where

            * ``labels`` has shape (num_graphs, -1);
            * ``attn_mask`` has shape (num_graphs, N + 1, N + 1) with 1 marking
              key positions that belong to padding (N is the largest node count
              in the batch, and the extra slot is the graph token);
            * ``node_feat``, ``in_degree`` and ``out_degree`` are zero-padded to
              N nodes, with real values shifted by +1 so 0 means padding;
            * ``path_data`` has shape (num_graphs, N, N, max_dist, edge_feat_dim);
            * ``dist`` has shape (num_graphs, N, N) with -1 for padded pairs.
        """
        graphs, labels = zip(*samples)
        num_graphs = len(graphs)
        num_nodes = [g.num_nodes() for g in graphs]
        max_num_nodes = max(num_nodes)

        attn_mask = torch.zeros(num_graphs, max_num_nodes + 1, max_num_nodes + 1)

        node_feat = []
        in_degree, out_degree = [], []
        path_data = []

        dist = -torch.ones((num_graphs, max_num_nodes, max_num_nodes), dtype=torch.long)

        for i in range(num_graphs):
            # Mark padded KEY positions only. Query rows are deliberately never
            # fully masked: a row in which every key is invalid turns into NaN
            # after softmax and the NaNs then spread through later layers.
            attn_mask[i, :, num_nodes[i] + 1 :] = 1

            # Shift features by one so that index 0 is free to mean "padding".
            nd_feat = graphs[i].ndata["feat"] + 1
            if len(nd_feat.shape) == 1:
                nd_feat = nd_feat.unsqueeze(1)
            node_feat.append(nd_feat)

            # NOTE(review): after the +1 shift the clamp folds the largest
            # observed degree onto the index of (largest - 1). Raising the bound
            # would need a matching embedding-table size in the model — confirm
            # against how the model is constructed before changing it.
            in_degree.append(
                torch.clamp(graphs[i].in_degrees() + 1, min=0, max=self.max_in_degree)
            )
            out_degree.append(
                torch.clamp(graphs[i].out_degrees() + 1, min=0, max=self.max_out_degree)
            )

            # Bring the path tensor to exactly `max_dist` hops: truncate longer
            # paths, pad shorter ones with -1.
            path = graphs[i].ndata["path"]
            path_len = path.size(dim=2)
            max_len = self.max_dist
            if path_len >= max_len:
                shortest_path = path[:, :, :max_len]
            else:
                p1d = (0, max_len - path_len)
                shortest_path = F.pad(path, p1d, "constant", -1)
            # Pad both node axes up to the batch-wide node count, again with -1.
            pad_num_nodes = max_num_nodes - num_nodes[i]
            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
            shortest_path = F.pad(shortest_path, p3d, "constant", -1)

            edata = graphs[i].edata["feat"] + 1
            if len(edata.shape) == 1:
                edata = edata.unsqueeze(-1)
            # Append an all-zero row; every -1 entry in `shortest_path` indexes
            # this last row, so padded hops contribute zero edge features.
            zero_row = torch.zeros(1, edata.shape[1], device=edata.device)
            edata = torch.cat((edata, zero_row), dim=0)
            path_data.append(edata[shortest_path])

            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

        node_feat = pad_sequence(node_feat, batch_first=True)
        in_degree = pad_sequence(in_degree, batch_first=True)
        out_degree = pad_sequence(out_degree, batch_first=True)

        return (
            torch.stack(labels).reshape(num_graphs, -1),
            attn_mask,
            node_feat,
            in_degree,
            out_degree,
            torch.stack(path_data),
            dist,
        )


In [None]:
# filepath: /home/a373k/Desktop/feb/Graphormer/centralityencoding.py
import torch.nn as nn

class CentralityEncoder(nn.Module):
    """
    Embeds node in-degree and out-degree indices and sums the two embeddings.

    Args:
        max_in_degree (int): Largest in-degree index; the in-degree table has
            ``max_in_degree + 1`` rows.
        max_out_degree (int): Largest out-degree index; the out-degree table
            has ``max_out_degree + 1`` rows.
        embedding_dim (int): Width of each embedding vector.
    """
    def __init__(self, max_in_degree, max_out_degree, embedding_dim):
        super().__init__()
        # Index 0 is the padding index in both tables, so it embeds to zeros.
        self.in_degree_embedding_table = nn.Embedding(
            num_embeddings=max_in_degree + 1,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.out_degree_embedding_table = nn.Embedding(
            num_embeddings=max_out_degree + 1,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )

    def forward(self, in_degrees, out_degrees):
        """
        Look up both degree embeddings and add them element-wise.

        Args:
            in_degrees (Tensor): Integer in-degree indices.
            out_degrees (Tensor): Integer out-degree indices.

        Returns:
            Tensor: Sum of the two embeddings, with one trailing axis of size
            ``embedding_dim`` appended to the index shape.
        """
        return self.in_degree_embedding_table(
            in_degrees
        ) + self.out_degree_embedding_table(out_degrees)


In [None]:


class SpatialEncoder(nn.Module):
    """
    Turns shortest-path distances into one learnable bias value per head.

    Args:
        max_dist (int): Upper bound applied to distances before the lookup.
        num_heads (int): Number of attention heads (width of each embedding).
    """
    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        self.num_heads = num_heads
        # Rows: index 0 for distance -1 (padding index, embeds to zeros), then
        # one row for each distance 0 .. max_dist.
        self.embedding_table = nn.Embedding(
            num_embeddings=max_dist + 2,
            embedding_dim=num_heads,
            padding_idx=0,
        )

    def forward(self, dist):
        """
        Embed a tensor of shortest-path distances.

        Args:
            dist (Tensor): Integer distances; values are clipped to the range
                [-1, max_dist] before the lookup.

        Returns:
            Tensor: Embeddings with a trailing axis of size ``num_heads``
            appended to the shape of ``dist``.
        """
        clipped = dist.clamp(min=-1, max=self.max_dist)
        # Shift by one so that -1 lands on the padding row.
        return self.embedding_table(clipped + 1)


In [None]:


class EdgeEncoder(nn.Module):
    """
    Encodes the edge features found along shortest paths as per-head biases.

    Args:
        max_len (int): Number of path positions (hops) that are encoded.
        feat_dim (int): Dimension of the edge features.
        num_heads (int): Number of attention heads.
    """
    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        # One weight vector of size feat_dim for every (hop, head) combination.
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

    def forward(self, dist, path_data):
        """
        Compute the path encoding for every node pair.

        Args:
            dist (Tensor): Shortest-path distances, indexed (batch, x, y).
            path_data (Tensor): Edge features along each shortest path, indexed
                (batch, x, y, hop, feature).

        Returns:
            Tensor: Path encoding indexed (batch, x, y, head).
        """
        # Divisor for the averaging; clipping to at least 1 avoids dividing by
        # zero or by a negative distance.
        hop_count = dist.clamp(min=1, max=self.max_len)
        per_hop_weight = self.embedding_table.weight.reshape(
            self.max_len, self.num_heads, -1
        )
        # Dot product of each hop's edge feature with that hop's per-head
        # weight, summed over hops and feature dimensions.
        summed = th.einsum("bxyld,lhd->bxyh", path_data, per_hop_weight)
        # Broadcast the per-pair divisor over the trailing head axis.
        return summed / hop_count.unsqueeze(-1)


In [None]:
# filepath: /home/a373k/Desktop/feb/Graphormer/encoder.py
import torch
import torch.nn as nn

class FeedForwardNetwork(nn.Module):
    """
    Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Args:
        hidden_size (int): Input and output width.
        ffn_size (int): Width of the intermediate layer.
        encoder_dropout (float): Dropout rate applied to the block's output.
    """
    def __init__(self, hidden_size, ffn_size, encoder_dropout):
        super().__init__()
        self.layer1 = nn.Linear(in_features=hidden_size, out_features=ffn_size)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(in_features=ffn_size, out_features=hidden_size)
        self.fnn_dropout = nn.Dropout(p=encoder_dropout)

    def forward(self, x):
        """
        Apply the feed-forward block to the last axis of ``x``.

        Args:
            x (Tensor): Input whose last axis has size ``hidden_size``.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        expanded = self.gelu(self.layer1(x))
        return self.fnn_dropout(self.layer2(expanded))

class Attention(nn.Module):
    """
    Multi-head self-attention with an optional additive bias and mask.

    Args:
        hidden_size (int): Size of the hidden layer.
        attention_drop (float): Dropout rate applied to the attention weights.
        num_heads (int): Number of attention heads; each head has size
            ``hidden_size // num_heads``.
    """
    def __init__(self, hidden_size, attention_drop, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.att_size = hidden_size // num_heads
        # Standard 1/sqrt(d_head) scaling of the dot-product scores.
        self.scale = self.att_size**-0.5
        self.linear_q = nn.Linear(hidden_size, num_heads * self.att_size)
        self.linear_k = nn.Linear(hidden_size, num_heads * self.att_size)
        self.linear_v = nn.Linear(hidden_size, num_heads * self.att_size)
        self.att_dropout = nn.Dropout(attention_drop)
        self.output_layer = nn.Linear(num_heads * self.att_size, hidden_size)

    def forward(self, h, att_bias=None, mask=None):
        """
        Forward pass for the attention mechanism.

        Args:
            h (Tensor): Input of shape (batch, seq_len, hidden_size).
            att_bias (Tensor, optional): Additive bias, broadcastable to the
                score shape (batch, num_heads, seq_len, seq_len).
            mask (Tensor, optional): Attention mask. Positions where the mask
                equals 0 are blocked; any other value keeps the position. A
                3-D mask is read as (batch, seq_len, seq_len) and shared by all
                heads; otherwise it must broadcast to the score shape.

        Returns:
            Tensor: Output of shape (batch, seq_len, hidden_size). Query rows
            whose keys are all blocked receive zero attention, so their output
            is the bias of the output projection.
        """
        batch_size, seq_len, _ = h.shape

        def split_heads(projected):
            # (batch, seq, heads * d) -> (batch, heads, seq, d)
            return projected.view(
                batch_size, seq_len, self.num_heads, self.att_size
            ).transpose(1, 2)

        q = split_heads(self.linear_q(h))
        k = split_heads(self.linear_k(h))
        v = split_heads(self.linear_v(h))

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if att_bias is not None:
            attn_weights = attn_weights + att_bias

        fully_blocked = None
        if mask is not None:
            blocked = mask == 0
            if blocked.dim() == 3:
                # Insert the head axis; without it a (batch, seq, seq) mask
                # would be aligned against the head axis of the scores.
                blocked = blocked.unsqueeze(1)
            attn_weights = attn_weights.masked_fill(blocked, float('-inf'))
            fully_blocked = blocked.all(dim=-1, keepdim=True)

        attn_weights = torch.softmax(attn_weights, dim=-1)
        if fully_blocked is not None:
            # Softmax over a row of only -inf is NaN; give such rows zero
            # attention instead so the NaN cannot spread to other positions.
            attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)
        attn_weights = self.att_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        # (batch, heads, seq, d) -> (batch, seq, heads * d)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, -1
        )
        attn_output = self.output_layer(attn_output)

        return attn_output

class Encoder(nn.Module):
    """
    Pre-layer-norm transformer encoder layer: self-attention followed by a
    feed-forward network, each wrapped in a residual connection.

    Args:
        hidden_size (int): Size of the hidden layer.
        ffn_out_size (int): Size of the feed-forward layer.
        attention_dropout (float): Dropout rate; used for the attention
            weights, the feed-forward block and both residual branches.
        num_heads (int): Number of attention heads.
    """
    def __init__(self, hidden_size, ffn_out_size, attention_dropout, num_heads):
        super().__init__()
        self.attention = Attention(hidden_size, attention_dropout, num_heads)
        self.ffn = FeedForwardNetwork(hidden_size, ffn_out_size, attention_dropout)
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, x, att_mask=None, att_bias=None):
        """
        Run one encoder layer.

        Args:
            x (Tensor): Input tensor.
            att_mask (Tensor, optional): Mask forwarded to the attention module.
            att_bias (Tensor, optional): Bias forwarded to the attention module.

        Returns:
            Tensor: Output tensor with the same shape as ``x``.
        """
        # Attention sub-layer: normalise, attend, drop, add back the input.
        attended = self.attention(self.layer_norm1(x), att_bias, att_mask)
        after_attention = x + self.dropout(attended)

        # Feed-forward sub-layer with its own residual connection.
        # NOTE(review): the feed-forward block already ends in a dropout, so its
        # output is dropped twice here — confirm this is intended.
        transformed = self.ffn(self.layer_norm2(after_attention))
        return after_attention + self.dropout(transformed)
