In [1]:
import torch


from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_adj
from torch.utils.data import DataLoader, random_split

In [2]:
from helpers import CVFConfigDataset

In [3]:
class VanillaGNNLayer(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = torch.nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x, adjacency):
        x = self.linear(x)
        # x = torch.sparse.mm(adjacency, x)
        x = torch.matmul(adjacency, x)
        return x

In [4]:
# dataset = CVFConfigDataset(
#     "small_graph_test_config_rank_dataset.csv", "small_graph_edge_index.json", 4
# )
# dataset = CVFConfigDataset(
#     "graph_1_config_rank_dataset.csv", "graph_1_edge_index.json", 10
# )
dataset = CVFConfigDataset(
    "graph_4_config_rank_dataset.csv", "graph_4_edge_index.json", 10
)
# dataset = CVFConfigDataset(
#     "graph_5_config_rank_dataset.csv", "graph_5_edge_index.json", 9
# )
# train_size = int(0.75 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
data_loader = DataLoader(dataset, batch_size=531441, shuffle=False)

In [5]:
# adjacency = to_dense_adj(dataset.edge_index.t().contiguous())[0]
# adjacency += torch.eye(len(adjacency))
# adjacency = adjacency.unsqueeze(0)
# adjacency

# dataset.edge_index

In [6]:
class VanillaGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        # self.gnn1 = VanillaGNN(dim_in, dim_h)
        # self.gnn2 = VanillaGNN(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    # def forward_gnn(self, x, adjacency):
    #     h = self.gnn1(x, adjacency)
    #     h = torch.relu(h)
    #     h = self.gnn2(h, adjacency)
    #     h = torch.relu(h)
    #     h = self.out(h)
    #     # h = torch.sigmoid(h)
    #     h = h.squeeze(-1)
    #     return h

    def forward(self, x, edge_index):
        h = self.gcn1(x, edge_index)
        h = torch.relu(h)
        h = self.gcn2(h, edge_index)
        h = torch.relu(h)
        h = self.out(h)
        h = h.squeeze(-1)
        return h

    def fit(self, data_loader, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        self.train()
        for epoch in range(1, epochs + 1):
            avg_loss = 0
            count = 0
            for batch in data_loader:
                x = batch[0]
                y = batch[1]
                # print(x)
                # print(y)
                optimizer.zero_grad()
                out = self(x, dataset.edge_index.t())
                # out = self(x, adjacency)
                # print("output", out, "y", y)
                loss = criterion(out, y)
                avg_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print("Loss:", avg_loss / count)


In [None]:
gnn = VanillaGNN(1, 64, 1)
print(gnn)

gnn.fit(data_loader, epochs=25)

In [None]:
torch.set_printoptions(profile="full")

total_matched = 0

for batch in data_loader:
    x = batch[0]
    y = batch[1]
    predicted = gnn(x, dataset.edge_index.t())
    predicted = predicted.argmax(dim=1)
    matched = (predicted == y).sum().item()
    total_matched += matched
    # print("y", y, "predicted", predicted, "Matched", matched)

print("Total matched", total_matched, "| Accuracy", round(total_matched/len(dataset), 4) * 100, "%")

Total matched 297134 | Accuracy 45.29 %


In [None]:
gnn.fit(data_loader, epochs=25)

Loss: tensor(1.8076, grad_fn=<DivBackward0>)
Loss: tensor(1.7204, grad_fn=<DivBackward0>)
Loss: tensor(1.7578, grad_fn=<DivBackward0>)
Loss: tensor(1.7384, grad_fn=<DivBackward0>)
Loss: tensor(1.7305, grad_fn=<DivBackward0>)
Loss: tensor(1.7355, grad_fn=<DivBackward0>)
Loss: tensor(1.7189, grad_fn=<DivBackward0>)
Loss: tensor(1.7166, grad_fn=<DivBackward0>)
Loss: tensor(1.7220, grad_fn=<DivBackward0>)
Loss: tensor(1.7182, grad_fn=<DivBackward0>)
Loss: tensor(1.7160, grad_fn=<DivBackward0>)
Loss: tensor(1.7167, grad_fn=<DivBackward0>)
Loss: tensor(1.7119, grad_fn=<DivBackward0>)
Loss: tensor(1.7100, grad_fn=<DivBackward0>)
Loss: tensor(1.7118, grad_fn=<DivBackward0>)
Loss: tensor(1.7113, grad_fn=<DivBackward0>)
Loss: tensor(1.7106, grad_fn=<DivBackward0>)
Loss: tensor(1.7098, grad_fn=<DivBackward0>)
Loss: tensor(1.7074, grad_fn=<DivBackward0>)
Loss: tensor(1.7065, grad_fn=<DivBackward0>)
Loss: tensor(1.7068, grad_fn=<DivBackward0>)
Loss: tensor(1.7063, grad_fn=<DivBackward0>)
Loss: tens