In [571]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub, Planetoid
from processing import edge_index_to_adj_matrix

dataset = KarateClub()
data = dataset[0]
adj_matrix = edge_index_to_adj_matrix(data.edge_index)

In [579]:
g = torch.Generator().manual_seed(42)
mask = torch.zeros(34)
mask[torch.randperm(34, generator=g)[:24]] = True
mask = mask.bool()

In [580]:
class LGCL(nn.Module):
    def __init__(self, input_dim, output_dim, k):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.k = k
        
        """
        3.2 - 1-D Convolutional Neural Networks
        The c(·) function in formula 3

        I think they are using there own implementation??
        But i will use pytorch implementation for now
        """
        self.conv1 = nn.Conv1d(
            in_channels=input_dim, 
            out_channels=output_dim, 
            kernel_size=k+1
        )

    def forward(self, x, adj_matrix):
        """
        3.2 - k-largest Node Selection
        The f(·) function in formula 3

        This is kind of life preliminary step
        The "message passing" occurs during the convolution
        """
        # if len(neighbors_features) < 4:
        #     neighbors_features = F.pad(neighbors_features, pad=(0, 0, 0, k - len(neighbors_features)))
        
        neighbors_idx = adj_matrix == 1
        neighbors_features = x.unsqueeze(0) * neighbors_idx.unsqueeze(-1)
        top_k_features = torch.topk(neighbors_features, k=self.k, dim=1).values
        
        self_and_top_k_features = torch.cat([x.unsqueeze(1), top_k_features], dim=1) # N x N x k vector
        self_and_top_k_features = self_and_top_k_features.transpose(1, 2)
        
        out = self.conv1(self_and_top_k_features)
        out = out.transpose(1, 2)
        
        return out.squeeze(1)

In [586]:
class LGCN(nn.Module):
    def __init__(self, input_dim, output_dim, k):
        super().__init__()

        self.LGCL1 = LGCL(input_dim, 4, k)
        self.LGCL2 = LGCL(25, output_dim, k)

    def forward(self, x, adj_matrix):
        h = self.LGCL1(x, adj_matrix)
        # h = F.relu(h)
        # logits = self.LGCL2(h, adj_matrix)
        
        return h

In [587]:
model = LGCN(34, 4, 1)
targets = F.one_hot(data.y, num_classes=4).float()

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=.01
)

In [588]:
import numpy as np
model.train()
for idx in range(100):
    optimizer.zero_grad()
    preds = model(data.x, adj_matrix)
    # print(preds)
    # break
    loss = loss_fn(preds[mask], targets[mask])
    loss.backward()
    optimizer.step()
    print(loss.item())

1.4479280710220337
1.3827403783798218
1.319880485534668
1.259462594985962
1.2015976905822754
1.1463440656661987
1.0936994552612305
1.0436075925827026
0.9959812760353088
0.9507231712341309
0.9077350497245789
0.8669193387031555
0.8281809687614441
0.7914261817932129
0.756563663482666
0.723504364490509
0.6921618580818176
0.6624529957771301
0.6342976689338684
0.6076192259788513
0.5823437571525574
0.558400571346283
0.5357216000556946
0.5142415165901184
0.4938971996307373
0.4746280908584595
0.4563758373260498
0.4390844404697418
0.4227001965045929
0.4071715772151947
0.39244937896728516
0.3784867823123932
0.3652392327785492
0.3526644706726074
0.3407226502895355
0.3293760120868683
0.3185892403125763
0.3083289861679077
0.29856398701667786
0.28926515579223633
0.2804049551486969
0.27195796370506287
0.26390019059181213
0.2562093734741211
0.2488645762205124
0.24184636771678925
0.2351364940404892
0.2287178635597229
0.22257445752620697
0.21669141948223114
0.21105466783046722
0.20565104484558105
0.20046

In [589]:
model.eval()
out = model(data.x, adj_matrix)
out = out.squeeze(1)
pred = out.argmax(dim=1)
print(pred)
accuracy = (pred == data.y).sum() / len(data.y)
accuracy

tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
        2, 2, 0, 0, 0, 0, 0, 2, 0, 0])


tensor(0.9706)

In [585]:
data.y

tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
        2, 2, 0, 0, 2, 0, 0, 2, 0, 0])