<a href="https://colab.research.google.com/github/2019mohamed/GraphMixer/blob/main/GraphMixer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [69]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
import networkx as nx
import dgl
import numpy as np

class GlobalAveragePooling(nn.Module):

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=self.dim)


class Classifier(nn.Module):

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.model = nn.Linear(input_dim, num_classes)
        nn.init.zeros_(self.model.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class MLPBlock(nn.Module):

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class MixerBlock(nn.Module):

    def __init__(
        self,
        num_patches: int,
        num_channels: int,
        tokens_hidden_dim: int,
        channels_hidden_dim: int
    ):
        super().__init__()
        self.token_mixing = nn.Sequential(
            nn.LayerNorm(num_channels),
            Rearrange("b p c -> b c p"),
            MLPBlock(num_patches, tokens_hidden_dim),
            Rearrange("b c p -> b p c")
        )
        self.channel_mixing = nn.Sequential(
            nn.LayerNorm(num_channels),
            MLPBlock(num_channels, channels_hidden_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.token_mixing(x)
        x = x + self.channel_mixing(x)
        return x

class Net(nn.Module):
    def __init__(self,
        num_layers:int,
        num_patches: int,
        num_channels: int,
        tokens_hidden_dim: int,
        channels_hidden_dim: int,
        num_classes: int):
        super().__init__()
        self.layers = [MixerBlock(num_patches , num_channels , tokens_hidden_dim , channels_hidden_dim) for _ in range(num_layers)]
        self.mixers = nn.Sequential(*self.layers)
        self.pool = GlobalAveragePooling()
        self.classifier = Classifier(num_channels ,num_classes)

    def forward (self,x):
      out = self.mixers(x)
      pooled = self.pool(out)
      return self.classifier(pooled)


from dgl.data import GINDataset
from dgl.nn.pytorch.conv import GraphConv, SAGEConv
from dgl.nn import SumPooling, AvgPooling, MaxPooling
dataset = GINDataset('PROTEINS' , self_loop= True)
#print(len(dataset[0][0].ndata['attr']),' ',dataset[0])
perm = np.random.permutation(len(dataset))
train_idx = perm[:1001]
test_idx = perm[1001:]

train_emd , test_emd = [], []
train_label , test_label = [],[]

sum_pool = SumPooling()
avg_pool = AvgPooling()
max_pool = MaxPooling()
for i in train_idx:
  emd = []
  g = dataset[i][0]
  l = dataset[i][1].detach().numpy()
  e1 = GraphConv(3 ,  16, norm='both', weight=True, bias=True)
  e1 = e1(g , g.ndata['attr'])
  sum = sum_pool(g , e1)[0].detach().numpy()
  avg = avg_pool(g , e1)[0].detach().numpy()
  max = max_pool(g , e1)[0].detach().numpy()
  emd.append(sum)
  emd.append(avg)
  emd.append(max)
  e1 = SAGEConv(3,  16, 'pool')
  e1 = e1(g , g.ndata['attr'])
  sum = sum_pool(g , e1)[0].detach().numpy()
  avg = avg_pool(g , e1)[0].detach().numpy()
  max = max_pool(g , e1)[0].detach().numpy()
  emd.append(sum)
  emd.append(avg)
  emd.append(max)
  train_emd.append(emd)
  train_label.append(l)

for i in test_idx:
  emd = []
  g = dataset[i][0]
  l = dataset[i][1].detach().numpy()
  e1 = GraphConv(3 ,  16, norm='both', weight=True, bias=True)
  e1 = e1(g , g.ndata['attr'])
  sum = sum_pool(g , e1)[0].detach().numpy()
  avg = avg_pool(g , e1)[0].detach().numpy()
  max = max_pool(g , e1)[0].detach().numpy()
  emd.append(sum)
  emd.append(avg)
  emd.append(max)
  e1 = SAGEConv(3 ,  16,  'pool')
  e1 = e1(g , g.ndata['attr'])
  sum = sum_pool(g , e1)[0].detach().numpy()
  avg = avg_pool(g , e1)[0].detach().numpy()
  max = max_pool(g , e1)[0].detach().numpy()
  emd.append(sum)
  emd.append(avg)
  emd.append(max)
  test_emd.append(emd)
  test_label.append(l)

train_emd = np.array(train_emd)
train_label = np.array(train_label)
test_emd = np.array(test_emd)
test_label = np.array(test_label)

from torch.utils.data import TensorDataset, DataLoader

train_emd = torch.tensor(train_emd)
test_emd = torch.tensor(test_emd)
train_label = torch.tensor(train_label)
test_label = torch.tensor(test_label)

print(train_emd.shape , ' ',train_label.shape)
print(test_emd.shape , ' ',test_label.shape)

train_data = TensorDataset(train_emd , train_label)
test_data = TensorDataset(test_emd , test_label)

train_loader = DataLoader(train_data , batch_size=16)

test_loader = DataLoader(test_data , batch_size=16)


net = Net(7 , 6 , 16 , 256 , 256 ,2 )

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

def train():
  net.train()
  for i , data in enumerate(train_loader):
    x , l = data
    optimizer.zero_grad()
    out = net(x)
    loss = criterion(out, l)
    loss.backward()
    optimizer.step()
  
  return loss.item()

def test():
  net.eval()

  loss = 0
  correct = 0
  with torch.no_grad():
    for i, data in enumerate(test_loader):
      x , l = data
      out = net(x)
      loss += criterion(out, l).item()
      pred = out.argmax(dim = 1, keepdim = True)
      correct += pred.eq (l.view_as(pred)).sum().item()

  return 100. * correct / len(test_emd) , loss / len(test_emd)


for _ in range(1000):
  print(train())
  print(test())


print('ACC AND LOSS' , test())


  




torch.Size([1001, 6, 16])   torch.Size([1001])
torch.Size([112, 6, 16])   torch.Size([112])
0.8021624088287354
(58.92857142857143, 0.04163598269224167)
0.5562059879302979
(76.78571428571429, 0.033701009516205103)
0.5033747553825378
(69.64285714285714, 0.03429496767265456)
0.41816824674606323
(70.53571428571429, 0.03404498392982142)
0.444724977016449
(69.64285714285714, 0.03672621798302446)
0.3365400433540344
(66.96428571428571, 0.03978177798645837)
0.29932549595832825
(69.64285714285714, 0.044087213064943044)
0.10133260488510132
(65.17857142857143, 0.051436668794069974)
0.193223774433136
(66.07142857142857, 0.051226728462747166)
0.14550703763961792
(66.96428571428571, 0.061222371246133535)
0.23124490678310394
(64.28571428571429, 0.07423170070563044)
0.054603829979896545
(69.64285714285714, 0.082057138638837)
0.015208842232823372
(68.75, 0.09235496446490288)
0.16196811199188232
(64.28571428571429, 0.06855205499700137)
0.04414895921945572
(62.5, 0.07212369410055024)
0.009151019155979156


KeyboardInterrupt: ignored