In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch_geometric.datasets import KarateClub

In [2]:
# Choose the compute backend for training: CUDA if present, otherwise
# Apple's MPS backend, otherwise plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [25]:
# Force CPU, overriding whatever backend the auto-detection cell selected.
# NOTE(review): `device` is not referenced by any other cell visible here —
# confirm whether it is still needed.
device = "cpu"

In [26]:
class edge_model(nn.Module):
    """Edge update function (phi_e): a two-layer MLP applied to one edge.

    Args:
        in_dim: length of the concatenated input vector fed to the MLP.
        e_dim: length of the updated edge attribute vector it produces.
    """

    def __init__(self, in_dim, e_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(in_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, e_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, e_k, v_rk, v_sk, u):
        """Return the updated attribute vector for a single edge.

        The four inputs are joined along dimension 0, in the order
        (edge, `v_rk` node, `v_sk` node, global), and pushed through the MLP.
        """
        edge_context = torch.cat((e_k, v_rk, v_sk, u))
        return self.mlp(edge_context)

In [27]:
class node_model(nn.Module):
    """Node update function (phi_v): a two-layer MLP applied to one node.

    Args:
        in_dim: width of the concatenated input row fed to the MLP.
        v_dim: width of the updated node attribute row it produces.
    """

    def __init__(self, in_dim, v_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(in_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, v_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, e_i_agg, v_i, u):
        """Return the updated attribute row for a single node.

        `e_i_agg` is used as given; `v_i` and `u` each get a new leading
        axis so that all three can be joined along dimension 1.
        """
        node_row = v_i[None]
        global_row = u[None]
        node_context = torch.cat((e_i_agg, node_row, global_row), dim=1)
        return self.mlp(node_context)

In [28]:
class global_model(nn.Module):
    """Global update function (phi_u): a two-layer MLP with a single output.

    Args:
        in_dim: width of the concatenated input row fed to the MLP.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(in_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, e_agg, v_agg, u):
        """Return the updated global attribute.

        `e_agg` and `v_agg` are used as given; `u` gets a new leading axis
        so that all three can be joined along dimension 1.
        """
        global_row = u[None]
        graph_context = torch.cat((e_agg, v_agg, global_row), dim=1)
        return self.mlp(graph_context)

In [29]:
class edge_to_node_agg(nn.Module):
    """Edge-to-node aggregation (rho_e_v): sum the edge rows belonging to one node."""

    def forward(self, E_i):
        # E_i is [E_i, F_e]: one edge per row. Summing over the rows while
        # keeping the reduced axis yields a single [1, F_e] row.
        return E_i.sum(dim=0, keepdim=True)

In [30]:
class edge_to_global_agg(nn.Module):
    """Edge-to-global aggregation (rho_e_u): sum every edge row of the graph."""

    def forward(self, E):
        # E is [E, F_e]; reduce over the edge axis but keep it as size 1,
        # so the result is a [1, F_e] row.
        edge_total = E.sum(dim=0, keepdim=True)
        return edge_total

In [31]:
class node_to_global_agg(nn.Module):
    """Node-to-global aggregation (rho_v_u): sum every node row of the graph."""

    def forward(self, V):
        # V is [V, F_v]; reduce over the node axis but keep it as size 1,
        # so the result is a [1, F_v] row.
        node_total = V.sum(dim=0, keepdim=True)
        return node_total

In [47]:
class GraphNet(nn.Module):
    """One graph-network block: update edges, then nodes, then the global attribute.

    The phi/rho names on the sub-modules appear to follow the graph-network
    notation of Battaglia et al. (2018) -- confirm against the intended reference.

    Args:
        Ne: number of edges processed in each forward pass.
        Nn: number of nodes processed in each forward pass.
        n_dim: width of a node attribute vector.
        e_dim: width of an edge attribute vector.
    """

    def __init__(self, Ne, Nn, n_dim, e_dim):
        super().__init__()
        self.n_dim = n_dim
        self.e_dim = e_dim
        self.Nn = Nn
        self.Ne = Ne

        # Each input width is the sum of the pieces that sub-model concatenates;
        # the trailing "+ 1" is the single-element global attribute u.
        self.edge_model = edge_model(e_dim + 2 * n_dim + 1, self.e_dim)            # phi_e
        self.node_model = node_model(e_dim + n_dim + 1, self.n_dim)                # phi_v
        self.global_model = global_model(n_dim + e_dim + 1)                        # phi_u
        self.edge_to_node_agg = edge_to_node_agg()                                 # rho_e_v
        self.edge_to_global_agg = edge_to_global_agg()                             # rho_e_u
        self.node_to_global_agg = node_to_global_agg()                             # rho_v_u

    def forward(self, E, V, u, r, s):
        """Run one block update.

        Args:
            E: edge attributes, one row per edge.
            V: node attributes, one row per node.
            u: global attribute, a 1-D tensor with one element.
            r: per-edge node index used both to look up `v_rk` and to decide
                which node an updated edge is aggregated into.
            s: per-edge node index used to look up `v_sk`.

        Returns:
            Tuple `(E_prime, V_prime, u_prime)` of updated edge rows
            `[Ne, e_dim]`, updated node rows `[Nn, n_dim]` and the updated
            global attribute `[1, 1]`.
        """
        # 1. Compute updated edge attributes, one edge at a time.
        edge_rows = [
            self.edge_model(E[k], V[r[k]], V[s[k]], u)
            for k in range(self.Ne)
        ]
        if edge_rows:
            E_prime = torch.stack(edge_rows, dim=0)
        else:
            # torch.stack rejects an empty list; keep a well-formed [0, e_dim] tensor.
            E_prime = torch.zeros((0, self.e_dim), dtype=V.dtype, device=V.device)

        # Only the first Ne entries of r correspond to rows of E_prime.
        receivers = torch.as_tensor(r, device=E_prime.device)[: self.Ne]

        node_rows = []
        for i in range(self.Nn):
            # 2. Aggregate edge attributes per node. A node with no incoming
            #    edge selects zero rows, whose sum is an all-zero aggregate, so
            #    every node is updated and no row of V_prime is left unset.
            E_prime_i = E_prime[receivers == i]
            e_prime_bar_i = self.edge_to_node_agg(E_prime_i)
            # 3. Compute updated node attributes ([1, n_dim] per node).
            node_rows.append(self.node_model(e_prime_bar_i, V[i], u))
        if node_rows:
            V_prime = torch.cat(node_rows, dim=0)
        else:
            V_prime = torch.zeros((0, self.n_dim), dtype=V.dtype, device=V.device)

        e_prime_bar = self.edge_to_global_agg(E_prime)                             # 4. Aggregate edge attributes globally
        v_prime_bar = self.node_to_global_agg(V_prime)                             # 5. Aggregate node attributes globally
        u_prime = self.global_model(e_prime_bar, v_prime_bar, u)                   # 6. Compute updated global attribute

        return E_prime, V_prime, u_prime

In [48]:
class Decoder(nn.Module):
    """Per-node read-out: a two-layer MLP mapping each node row to `out_dim` scores.

    Args:
        in_dim: width of each incoming node row.
        out_dim: number of scores produced per node.
        Nn: number of node rows decoded per forward pass.
    """

    def __init__(self, in_dim, out_dim, Nn):
        super().__init__()
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.Nn = Nn
        self.mlp = nn.Sequential(
            nn.Linear(self.in_dim, 128),
            nn.ReLU(),
            # The output width must follow out_dim; a fixed width only works
            # for the one value it happens to match.
            nn.Linear(128, self.out_dim))

    def forward(self, V):
        """Decode the first `Nn` rows of `V`.

        Args:
            V: node rows of shape `[>=Nn, in_dim]`.

        Returns:
            Tensor of shape `[Nn, out_dim]`, on the same device as `V`.
        """
        # nn.Linear acts on the last axis, so all rows can be decoded in one
        # call; the result inherits V's device and dtype.
        return self.mlp(V[: self.Nn])

In [49]:
class Classifier(nn.Module):
    """Node classifier: one GraphNet pass followed by a Decoder with 4 outputs per node.

    Args:
        Ne: number of edges, forwarded to GraphNet.
        Nn: number of nodes, forwarded to GraphNet and Decoder.
        n_dim: node attribute width, forwarded to GraphNet and Decoder.
        e_dim: edge attribute width, forwarded to GraphNet.
    """

    def __init__(self, Ne, Nn, n_dim, e_dim):
        super().__init__()
        num_outputs = 4
        self.gn = GraphNet(Ne, Nn, n_dim, e_dim)
        self.dec = Decoder(n_dim, num_outputs, Nn)

    def forward(self, E, V, u, r, s):
        """Return the decoder output for the updated node attributes.

        Only the node part of the GraphNet result is used; the updated edge
        and global attributes are discarded.
        """
        _, node_features, _ = self.gn(E, V, u, r, s)
        return self.dec(node_features)

In [50]:
# Load the karate-club graph and derive the tensors the model consumes.
dataset = KarateClub()
graph = dataset[0]

# edge_index has two rows: the first becomes `s`, the second `r`.
# NOTE(review): presumably source/target node ids per PyG's convention -- confirm.
s, r = graph.edge_index[0], graph.edge_index[1]

V = graph.x                          # node attributes, one row per node
E = torch.ones(s.numel(), 1)         # constant one-element attribute for every edge
Nn, n_dim = V.size(0), V.size(1)
Ne, e_dim = E.size(0), E.size(1)
u = torch.tensor([1.])               # single-element global attribute

In [52]:
# Seed before the model is built so that weight initialisation -- and hence
# the training log below -- is the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = Classifier(Ne, Nn, n_dim, e_dim)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# The labels do not change between epochs, so fetch them once.
labels = dataset[0].y

def accuracy(pred_y, y):
    """Return the fraction of entries of `pred_y` that equal `y`."""
    return (pred_y == y).sum() / len(y)

# Full-batch training. Loss and accuracy are both computed on the same
# labelled nodes, so the reported accuracy is training accuracy.
for epoch in range(51):
    optimizer.zero_grad()
    z = model(E, V, u, r, s)
    loss = criterion(z, labels)
    acc = accuracy(z.argmax(dim=1), labels)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')

Epoch   0 | Loss: 1.41 | Acc: 11.76%
Epoch  10 | Loss: 0.34 | Acc: 88.24%
Epoch  20 | Loss: 0.00 | Acc: 100.00%
Epoch  30 | Loss: 0.00 | Acc: 100.00%
Epoch  40 | Loss: 0.00 | Acc: 100.00%
Epoch  50 | Loss: 0.00 | Acc: 100.00%
