In [1]:
import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch.utils.data import DataLoader
from torch_geometric.utils import to_dense_adj

In [2]:
from helpers import CVFConfigDataset

In [None]:
class VanillaGNNLayer(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = torch.nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x, adjacency):
        x = self.linear(x)
        x = torch.sparse.mm(adjacency, x)
        return x

In [22]:
dataset = CVFConfigDataset()
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [23]:
edge_index = torch.tensor([
    [0, 1],
    [0, 2],
    [0, 3],
    [1, 0],
    [2, 0],
    [3, 0]
], dtype=torch.long)

# # x = torch.tensor([[0], [0], [0], [0]], dtype=torch.float) # configs: [0, 0, 0, 0], [1, 1, 1, 1]
# x = torch.tensor([[0, 1], [0, 1], [0, 1], [0, 1]], dtype=torch.float) # configs: [0, 0, 0, 0], [1, 1, 1, 1]

# y = torch.tensor([[3.0, 1.0]]) # ranks

# data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
# data.train_mask = torch.tensor([1 for _ in range(len(x))])

In [24]:
adjacency = to_dense_adj(edge_index.t().contiguous())[0]
adjacency += torch.eye(len(adjacency))
adjacency

tensor([[1., 1., 1., 1.],
        [1., 1., 0., 0.],
        [1., 0., 1., 0.],
        [1., 0., 0., 1.]])

In [25]:
class VanillaGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gnn1 = VanillaGNNLayer(dim_in, dim_h)
        self.gnn2 = VanillaGNNLayer(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)
    
    def forward(self, x, adjacency):
        h = self.gnn1(x, adjacency)
        h = torch.relu(h)
        h = self.gnn2(h, adjacency)
        h = torch.relu(h)
        # print("h2 r", h)
        h = self.out(h)
        # print(torch.relu(h))
        result = F.softmax(h, dim=0)
        # print("result", result)
        return result
    
    def fit(self, data_loader, epochs):
        # criterion = torch.nn.MSELoss()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        self.train()
        for epoch in range(epochs+1):
            for batch in data_loader:
                x = batch[0][0]
                # y = torch.zeros(4, 1)
                # y[batch[1]] = 1.0
                y = batch[1]
                optimizer.zero_grad()
                out = self(x, adjacency).T
                # print("output", out, "y", y)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
            
            # if loss < 0.009:
            #     print("Loss threshold met.")
            #     break
            print("Loss:", loss)


In [26]:
gnn = VanillaGNN(1, 64, 1)
print(gnn)
gnn.fit(data_loader, epochs=200)

VanillaGNN(
  (gnn1): VanillaGNNLayer(
    (linear): Linear(in_features=1, out_features=64, bias=True)
  )
  (gnn2): VanillaGNNLayer(
    (linear): Linear(in_features=64, out_features=64, bias=True)
  )
  (out): Linear(in_features=64, out_features=1, bias=True)
)
Loss: tensor(0.9940, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7508, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7498, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7493, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7485, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7476, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7468, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7463, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7457, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7453, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7450, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7447, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7446, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7445, grad_fn=<NllLossBackward0>)
Loss: tensor(0.7444, grad_fn=<NllLossBackward0>)
L

In [14]:
# out = gnn(torch.tensor([[3.], [2.], [2.], [2.]]), adjacency)
# print(out)

In [20]:
for batch in data_loader:
    x = batch[0][0]
    y = batch[1]
    predicted = gnn(x, adjacency).argmax(dim=0)
    print("y", y, "predicted", predicted)

y tensor([3]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([1]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([2]) predicted tensor([0])
y tensor([3]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tensor([0])
y tensor([0]) predicted tens