In [245]:
import inspect
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, matmul
from torch.nn import (
    BatchNorm1d, Dropout, Embedding, Linear, Module, ModuleList, ReLU, LeakyReLU,Sequential
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import GNNBenchmarkDataset, Planetoid, ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GINEConv, global_add_pool, MessagePassing, BatchNorm
)
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from torch_geometric.typing import Adj
from torch_geometric.utils import degree, sort_edge_index, to_dense_batch

from mamba_ssm import Mamba

In [197]:
# Dataset root directory and the value passed to ZINC's `subset` option.
# NOTE(review): '/temp' is an absolute path outside the project; possibly
# '/tmp' was intended -- confirm before changing.
path, subset = '/temp', True

# Random-walk positional encodings (20 steps), stored on each graph as `pe`.
# Passed as a `pre_transform` to every split below.
transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')

# One dataset per split, built in the order train -> val -> test.
train_dataset, val_dataset, test_dataset = (
    ZINC(path, subset=subset, split=split_name, pre_transform=transform)
    for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; evaluation loaders use larger batches.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=64)
    for eval_dataset in (val_dataset, test_dataset)
)

In [263]:
class MambaConv(torch.nn.Module):
    """Graph layer combining a message-passing step, a Mamba block and an MLP.

    The forward pass runs three stages in sequence.  Each stage adds the
    *layer input* back as a residual and is followed by an optional
    normalization layer:

    1. ``conv`` (skipped when ``conv`` is ``None``),
    2. the Mamba block, applied to the nodes grouped per graph via
       ``to_dense_batch``,
    3. a two-layer feed-forward network.

    Args:
        channels: Feature width of the input and output node embeddings.
        conv: Message-passing layer applied first, or ``None`` to skip it.
        dropout: Dropout probability used after the conv and Mamba stages
            and inside the MLP.
        d_state: Forwarded to ``Mamba(d_state=...)``.
        d_conv: Forwarded to ``Mamba(d_conv=...)``.
        norm: Normalization to resolve via ``normalization_resolver``
            (default ``'batch_norm'``).  ``None`` disables normalization.
    """

    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        dropout: float = 0.0,
        d_state: int = 16,
        d_conv: int = 4,
        norm: Optional[str] = 'batch_norm'
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.dropout = dropout
        # Stored so that __repr__ can report them.
        self.d_state = d_state
        self.d_conv = d_conv

        # The attribute keeps its historical name `self_attn` so that
        # parameter names in existing state dicts do not change.
        self.self_attn = Mamba(
            d_model=channels,
            d_state=d_state,
            d_conv=d_conv,
            expand=2,  # expansion factor
        )

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            torch.nn.ReLU(),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        # Honour the `norm` argument instead of always building BatchNorm.
        # A `None` query yields `None`, which forward() treats as "no norm".
        self.norm1 = normalization_resolver(norm, channels)
        self.norm2 = normalization_resolver(norm, channels)
        self.norm3 = normalization_resolver(norm, channels)

        # Some normalization layers take the batch vector as a second
        # argument; detect that once from the forward signature.
        self.norm_with_batch = False
        if self.norm1 is not None:
            norm_params = inspect.signature(self.norm1.forward).parameters
            self.norm_with_batch = 'batch' in norm_params

    def reset_parameters(self):
        """Re-initialize the learnable parameters of all sub-modules."""
        if self.conv is not None:
            self.conv.reset_parameters()
        # `_reset_parameters` is only called when the block provides it;
        # otherwise fall back to resetting whichever sub-modules define
        # `reset_parameters`.
        # NOTE(review): the fallback does not touch parameters that live
        # directly on the Mamba module -- confirm this is acceptable.
        if hasattr(self.self_attn, '_reset_parameters'):
            self.self_attn._reset_parameters()
        else:
            reset(self.self_attn)
        reset(self.mlp)
        for norm_layer in (self.norm1, self.norm2, self.norm3):
            if norm_layer is not None:
                norm_layer.reset_parameters()

    def _normalize(self, norm_layer, x: Tensor, batch: Optional[Tensor]) -> Tensor:
        """Apply `norm_layer` to `x` (with `batch` if it accepts it); no-op for None."""
        if norm_layer is None:
            return x
        if self.norm_with_batch:
            return norm_layer(x, batch)
        return norm_layer(x)

    def forward(self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None, **kwargs):
        """Run conv -> Mamba -> MLP, each with a residual from the layer input.

        Args:
            x: Node features.
            edge_index: Graph connectivity passed to ``conv``.
            batch: Node-to-graph assignment used to group nodes per graph;
                forwarded as-is to ``to_dense_batch``.
            **kwargs: Extra arguments forwarded to ``conv`` (e.g. edge
                attributes).

        Returns:
            Updated node features.
        """
        original_x = x

        # Stage 1: local message passing.
        if self.conv is not None:
            conv_out = self.conv(x, edge_index, **kwargs)
            x = F.dropout(conv_out, p=self.dropout, training=self.training)
            x = x + original_x
            x = self._normalize(self.norm1, x, batch)

        # Stage 2: Mamba over the nodes of each graph.  `mask` selects the
        # real (non-padding) entries so the result is flattened back to one
        # row per node.
        dense_x, mask = to_dense_batch(x, batch)
        attn_out = self.self_attn(dense_x)[mask]
        x = F.dropout(attn_out, p=self.dropout, training=self.training) + original_x
        x = self._normalize(self.norm2, x, batch)

        # Stage 3: feed-forward network.  The residual is taken from the
        # layer input, not from the stage-2 output.
        out = self.mlp(x) + original_x
        out = self._normalize(self.norm3, out, batch)
        return out

    # NOTE(review): the two methods below look like leftovers from a
    # MessagePassing layer.  Nothing in this class calls them, `self.aggr`
    # is never assigned here, and `matmul` is the one imported from torch.
    # Confirm before relying on them.
    def message(self, x_j, edge_weight):
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x) :
        return matmul(adj_t, x, reduce=self.aggr)

    def permute_within_batch(self, x, batch):
        """Randomly shuffle the rows of `x` among positions sharing a batch id.

        Returns the permuted `x` together with `batch` indexed by the same
        permutation.
        """
        permuted_indices = torch.empty_like(batch, dtype=torch.long)
        for batch_index in torch.unique(batch):
            # Positions belonging to the current batch id.
            indices = torch.where(batch == batch_index)[0]
            # Shuffle those positions among themselves.
            shuffled_indices = indices[torch.randperm(indices.size(0))]
            permuted_indices[indices] = shuffled_indices
        return x[permuted_indices], batch[permuted_indices]

    def __repr__(self):
        conv_type = self.conv.__class__.__name__ if self.conv is not None else 'None'
        return (f"{self.__class__.__name__}("
                f"channels={self.channels}, "
                f"conv={conv_type}, "
                f"d_state={self.d_state}, "
                f"d_conv={self.d_conv})")

In [264]:
class GraphModel(torch.nn.Module):
    """Graph-level regressor: embeddings -> stacked MambaConv layers -> MLP head.

    Node ids and positional encodings are embedded and concatenated into
    `channels`-wide node features, refined by `num_layers` MambaConv layers
    (each wrapping a GINEConv), sum-pooled per graph and mapped to a single
    output value.

    Args:
        channels: Width of the node features inside the network.
        pe_dim: Number of those channels reserved for the projected
            positional encoding; the node embedding gets ``channels - pe_dim``.
        num_layers: Number of MambaConv layers.
        d_state: Forwarded to every MambaConv.
        d_conv: Forwarded to every MambaConv.
        walk_length: Width of the raw positional encoding `pe`.
        num_node_types: Size of the node-id embedding table.
        num_edge_types: Size of the edge-id embedding table.

    The last three defaults keep the sizes that were previously hard-coded.
    """

    def __init__(self, channels: int, pe_dim: int, num_layers: int, d_state: int, d_conv: int,
                 walk_length: int = 20, num_node_types: int = 28, num_edge_types: int = 4):
        super().__init__()

        self.node_emb = Embedding(num_node_types, channels - pe_dim)
        self.pe_lin = Linear(walk_length, pe_dim)
        self.pe_norm = BatchNorm1d(walk_length)
        self.edge_emb = Embedding(num_edge_types, channels)

        self.convs = ModuleList()
        for _ in range(num_layers):
            # Per-layer MLP used inside GINEConv.
            gine_mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = MambaConv(channels, GINEConv(gine_mlp), d_state=d_state, d_conv=d_conv)
            self.convs.append(conv)

        # Prediction head applied to the pooled graph representation.
        self.mlp = Sequential(
            Linear(channels, channels // 2),
            LeakyReLU(0.2),
            Linear(channels // 2, channels // 4),
            LeakyReLU(0.2),
            Linear(channels // 4, 1),
        )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        """Return one prediction row per graph in the batch."""
        # Concatenate the node embedding with the normalized, projected
        # positional encoding along the feature dimension.
        node_part = self.node_emb(x.squeeze(-1))
        pe_part = self.pe_lin(self.pe_norm(pe))
        x = torch.cat((node_part, pe_part), dim=1)
        # Embed the edge attributes.
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch=batch, edge_attr=edge_attr)

        # Sum node features per graph, then apply the prediction head.
        return self.mlp(global_add_pool(x, batch))

In [265]:
def train():
    """Run one training epoch and return the mean L1 loss per graph.

    Uses the notebook-level `model`, `optimizer`, `train_loader` and `device`.
    """
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)
        # Drop only the trailing feature dimension.  A bare squeeze() would
        # also drop the batch dimension for a single-graph batch and make
        # the loss broadcast against `data.y`.
        loss = F.l1_loss(out.squeeze(-1), data.y, reduction='mean')
        loss.backward()
        optimizer.step()
        # Weight the batch mean by its graph count so the final value is a
        # per-graph average even when the last batch is smaller.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)


In [266]:
def test(loader):
    model.eval()
    total_error = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)
            total_error += torch.nn.functional.l1_loss(out.squeeze(), data.y, reduction='sum').item()
    return total_error / len(loader.dataset)


In [267]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphModel(channels=64, pe_dim=8, num_layers=10, d_conv=4, d_state=16)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# The scheduler is stepped on validation MAE, hence mode='min'.
scheduler = ReduceLROnPlateau(optimizer, mode='min')

# Test MAE recorded after every epoch.
arr = []
for epoch in range(1, 6):
    loss = train()
    val_mae, test_mae = test(val_loader), test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    arr.append(test_mae)

# `ordering` is a second name for the same list object as `arr`.
ordering = arr
print(ordering)

Epoch: 01, Loss: 0.6481, Val: 0.6769, Test: 0.6624
Epoch: 02, Loss: 0.5200, Val: 0.4470, Test: 0.4622
Epoch: 03, Loss: 0.4752, Val: 0.4034, Test: 0.3503
Epoch: 04, Loss: 0.4453, Val: 0.4360, Test: 0.4107
Epoch: 05, Loss: 0.4215, Val: 0.6822, Test: 0.6717
[0.6623605422973633, 0.46223797035217284, 0.35028810691833495, 0.4106653003692627, 0.6717489128112792]
