<a href="https://colab.research.google.com/github/Nostal1ga/8959prgram_prepare/blob/main/draft1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import torch.nn.functional as F
from torch import Tensor, matmul
from torch.nn import (
    BatchNorm1d, Dropout, Embedding, Linear, Module, ModuleList, ReLU, LeakyReLU,Sequential
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import GNNBenchmarkDataset, Planetoid, ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GINEConv, global_add_pool, MessagePassing, BatchNorm
)
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from torch_geometric.typing import Adj
from torch_geometric.utils import degree, sort_edge_index, to_dense_batch

from mamba_ssm import Mamba
import inspect


In [129]:
path, subset = '/temp', True

transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
train_dataset = ZINC(path, subset=subset, split='train', pre_transform=transform)
val_dataset = ZINC(path, subset=subset, split='val', pre_transform=transform)
test_dataset = ZINC(path, subset=subset, split='test', pre_transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

In [135]:
class MambaConv(MessagePassing):
    def __init__(
        self,
        channels: int,
        conv: MessagePassing,
        dropout: float = 0.0,
        d_state: int = 16,
        d_conv: int = 4,
        shuffle_ind: int = 0,
        norm: str = 'batch_norm'
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.dropout = dropout
        self.shuffle_ind = shuffle_ind

        self.self_attn = Mamba(
                d_model=channels,
                d_state=d_state,
                d_conv=d_conv,
                expand=2## Expansion factor
        )

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            torch.nn.ReLU(),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        self.norm1, self.norm2, self.norm3 = [BatchNorm(channels) for _ in range(3)]
        self.norm_with_batch = 'batch' in inspect.signature(self.norm1.forward).parameters

    def reset_parameters(self):
        if self.conv:
            self.conv.reset_parameters()
        self.self_attn._reset_parameters()
        self.mlp.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)
        for norm in [self.norm1, self.norm2, self.norm3]:
            norm.reset_parameters()

    def forward(self, x, edge_index, batch=None, **kwargs):
        hs = []

        if self.conv is not None:
            conv_output = self.conv(x, edge_index, **kwargs)
            conv_output = F.dropout(conv_output, p=self.dropout, training=self.training)
            conv_output = conv_output + x  # Residual connection
            conv_output = self._apply_norm(conv_output, self.norm1, batch)
            hs.append(conv_output)

        # Process self-attention
        if self.shuffle_ind == 0:
            h, mask = to_dense_batch(x, batch)
            attn_output = self.self_attn(h)[mask]
        else:
            attn_outputs = [self._process_shuffled_self_attn(x, batch) for _ in range(self.shuffle_ind)]
            attn_output = sum(attn_outputs) / self.shuffle_ind

        attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
        attn_output = attn_output + x  # Residual connection
        attn_output = self._apply_norm(attn_output, self.norm2, batch)
        hs.append(attn_output)

        # Combine all outputs
        combined_output = sum(hs)
        final_output = self.mlp(combined_output) + combined_output  # Pass through MLP and add residual
        final_output = self._apply_norm(final_output, self.norm3, batch)
        return final_output

    def _apply_norm(self, input_tensor, norm_layer, batch):
        #"Applies the given normalization layer"
        if norm_layer is not None:
            if self.norm_with_batch:
                return norm_layer(input_tensor, batch=batch)
            else:
                return norm_layer(input_tensor)
        return input_tensor

    def _process_shuffled_self_attn(self, x, batch):
        #"Processes self-attention for shuffled inputs."
        permuted_indices = self.permute_within_batch(x, batch)
        dense_x, mask = to_dense_batch(x[permuted_indices], batch)
        return self.self_attn(dense_x)[mask][permuted_indices]

    def permute_within_batch(self,x, batch):
        permuted_indices = torch.cat([
            indices[torch.randperm(len(indices))]
            for batch_index in torch.unique(batch)
            for indices in [(batch == batch_index).nonzero().squeeze()]
        ])
        return permuted_indices

    def message(self, x_j, edge_attr, PE_i, PE_j):
        r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
        r_ij = self.mlp(r_ij)  # the MLP is 1 dim --> hidden_dim --> 1 dim
        return ((x_j + edge_attr).relu()) * r_ij

    def __repr__(self):
      conv_type = self.conv.__class__.__name__ if self.conv else 'None'
      return (f"{self.__class__.__name__}("
              f"channels={self.channels}, "
              f"d_state={self.d_state}, "
              f"d_conv={self.d_conv}, ")

In [136]:
class GraphModel(torch.nn.Module):
    def __init__(self, channels: int, pe_dim: int, num_layers: int,shuffle_ind: int, d_state: int, d_conv: int):
        super().__init__()

        self.node_emb = Embedding(28, channels - pe_dim)
        self.pe_lin = Linear(20, pe_dim)
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Embedding(4, channels)
        self.shuffle_ind = shuffle_ind

        self.convs = ModuleList()
        for _ in range(num_layers):
            nn = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = MambaConv(channels, GINEConv(nn),shuffle_ind=self.shuffle_ind,d_state=d_state, d_conv=d_conv)
            self.convs.append(conv)

        # Multi-layer perceptron (MLP) for prediction
        self.mlp = Sequential(
          Linear(channels, channels // 2),
          LeakyReLU(0.2),
          Linear(channels // 2, channels // 4),
          LeakyReLU(0.2),
          Linear(channels // 4, 1),
      )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        # Combine node embeddings and processed pe
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(self.pe_norm(pe))), dim=1)
        # Process edge attributes throug embedding layer
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch=batch, edge_attr=edge_attr)
            # Aggregate node features globally
        return self.mlp(global_add_pool(x, batch))

In [137]:
def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)
        loss = torch.nn.functional.l1_loss(out.squeeze(), data.y, reduction='mean')
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)

In [138]:
def test(loader):
    model.eval()
    total_error = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)
            total_error += torch.nn.functional.l1_loss(out.squeeze(), data.y, reduction='sum').item()
    return total_error / len(loader.dataset)


In [139]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphModel(channels=64, pe_dim=8, num_layers=10,
                   d_conv=4, d_state=16,shuffle_ind=1
                  ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min')
arr = []
for epoch in range(1, 6):
    loss = train()
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    arr.append(test_mae)
ordering = arr
print(ordering)

Epoch: 01, Loss: 0.6346, Val: 0.7450, Test: 0.7867
Epoch: 02, Loss: 0.5235, Val: 0.5546, Test: 0.5288
Epoch: 03, Loss: 0.4841, Val: 0.5907, Test: 0.5963
Epoch: 04, Loss: 0.4592, Val: 0.5049, Test: 0.4638
Epoch: 05, Loss: 0.4353, Val: 0.4535, Test: 0.4539
[0.786665714263916, 0.5287840900421142, 0.5962694263458252, 0.4637624397277832, 0.45388284301757814]
