In [1]:
import itertools

import torch

from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import DataLoader, random_split

In [2]:
from helpers import CVFConfigForGCNDataset

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

device(type='cuda')

In [4]:
# Category count passed to the commented-out dataset constructors below.
# NOTE(review): the dataset that is actually constructed in this cell does not
# receive this value, so it is currently unused by the active code path.
color_mapping_categories = 15

# dataset = CVFConfigDataset(
#     "small_graph_test_config_rank_dataset.csv", "small_graph_edge_index.json", 4
# )

# dataset_n1 = CVFConfigForGCNDataset(
#     device,
#     "graph_1_config_rank_dataset.csv",
#     "graph_1_edge_index.json",
#     20
# )

# dataset_n4 = CVFConfigForGCNDataset(
#     device,
#     "graph_4_config_rank_dataset.csv",
#     "graph_4_edge_index.json",
#     color_mapping_categories
# )

# dataset_n5 = CVFConfigForGCNDataset(
#     device,
#     "graph_5_config_rank_dataset.csv",
#     "graph_5_edge_index.json",
#     color_mapping_categories
# )

# dataset_n6 = CVFConfigForGCNDataset(
#     device,
#     "graph_6_config_rank_dataset.csv",
#     "graph_6_edge_index.json",
#     color_mapping_categories,
# )

# dataset_n6b = CVFConfigForGCNDataset(
#     device,
#     "graph_6b_config_rank_dataset.csv",
#     "graph_6b_edge_index.json",
#     color_mapping_categories,
# )

# dataset_n7 = CVFConfigForGCNDataset(
#     device,
#     "graph_7_config_rank_dataset.csv",
#     "graph_7_edge_index.json",
#     color_mapping_categories,
# )

# # dataset_n8 = CVFConfigForGCNDataset(
# #     device,
# #     "graph_8_config_rank_dataset.csv",
# #     "graph_8_edge_index.json",
# #     color_mapping_categories,
# # )

# dataset_pl_n5 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n5_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n5_edge_index.json",
#     color_mapping_categories
# )

# dataset_pl_n6 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n6_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n6_edge_index.json",
#     color_mapping_categories
# )

# dataset_pl_n7 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n7_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n7_edge_index.json",
#     color_mapping_categories
# )

# dataset_pl_n8 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n8_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n8_edge_index.json",
#     color_mapping_categories
# )

# dataset_pl_n9 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n9_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n9_edge_index.json",
#     color_mapping_categories,
# )

# dataset_pl_n12 = CVFConfigForGCNDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n12_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n12_edge_index.json",
#     color_mapping_categories
# )

# dataset_implicit_n15 = CVFConfigForGCNDataset(
#     device,
#     "implicit_graph_n15_config_rank_dataset.csv",
#     "implicit_graph_n15_edge_index.json",
#     color_mapping_categories,
#     one_hot_encode=False
# )

# dataset_implicit_n10 = CVFConfigForGCNDataset(
#     device,
#     "implicit_graph_n10_config_rank_dataset.csv",
#     "implicit_graph_n10_edge_index.json",
#     color_mapping_categories,
#     one_hot_encode=True
# )

# dataset_implicit_n4 = CVFConfigForGCNDataset(
#     device,
#     "implicit_graph_n4_config_rank_dataset.csv",
#     "implicit_graph_n4_edge_index.json",
#     color_mapping_categories,
#     one_hot_encode=True
# )

# dataset_implicit_n5 = CVFConfigForGCNDataset(
#     device,
#     "implicit_graph_n5_config_rank_dataset.csv",
#     "implicit_graph_n5_edge_index.json",
# )

# NOTE(review): despite its name, this variable is loaded from the
# "random_regular_graph_n7_d4" files; the name is kept so that any existing
# references to it keep working.
dataset_implicit_n5 = CVFConfigForGCNDataset(
    device,
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n7_d4_edge_index.json",
)

batch_size = 64
# Fixed seed so the train/test partition (and therefore the reported test
# metrics) is the same after every Restart & Run All.
split_seed = 0

dataset_coll = [dataset_implicit_n5]
train_dataloader_coll = []
test_dataloader_coll = []

for dataset in dataset_coll:
    # 75% / 25% train/test split; the test part absorbs the rounding remainder.
    train_size = int(0.75 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(
        dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(split_seed),
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_dataloader_coll.append(train_loader)
    test_dataloader_coll.append(test_loader)

# One-shot iterators over the training loaders (one per dataset).
train_dataloader_coll_iter = [iter(i) for i in train_dataloader_coll]

In [5]:
def generate_batch():
    """Yield training batches round-robin across all training loaders.

    Fresh iterators are created from ``train_dataloader_coll`` on every call,
    so the generator can be used more than once (e.g. when the training cell
    is re-run) instead of silently yielding nothing after the first pass.

    Yields:
        tuple: ``(batch, di)`` where ``batch`` is the next batch from loader
        number ``di`` and ``di`` is that loader's index in
        ``train_dataloader_coll``.

    Iteration stops after the round in which the first loader runs out of
    batches; loaders that still have data in that round are served first.
    """
    loader_iters = [iter(loader) for loader in train_dataloader_coll]
    if not loader_iters:
        # With no loaders `any([])` is False forever, which would spin endlessly.
        return

    exhausted = [False] * len(loader_iters)
    while not any(exhausted):
        for di, loader_iter in enumerate(loader_iters):
            try:
                batch = next(loader_iter)
            except StopIteration:
                # Remember the exhaustion; the outer loop ends after this round.
                exhausted[di] = True
                continue
            yield batch, di

In [6]:
# Report how many batches each training loader produces per pass.
batches_per_loader = [len(loader) for loader in train_dataloader_coll]
print("Number of batches:", batches_per_loader)

Number of batches: [916]


In [7]:
class VanillaGNN(torch.nn.Module):
    """Two GCN layers, a linear head, and a mean pool over the node axis.

    Every stage (including the linear head) is followed by a ReLU, so the
    pooled prediction is always non-negative.
    """

    def __init__(self, dim_in, dim_h, dim_out):
        """Build the layers.

        Args:
            dim_in: number of input features per node.
            dim_h: width of both GCN layers.
            dim_out: number of outputs produced per node before pooling.
        """
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    def forward(self, x, edge_index):
        """Compute one pooled prediction per input configuration.

        Args:
            x: node features; assumed to be ``(batch, nodes, dim_in)`` (or
                ``(nodes, dim_in)``) — TODO confirm against the dataset.
            edge_index: graph connectivity passed to both GCN layers.

        Returns:
            The node-wise outputs averaged over the node axis.
        """
        h = self.gcn1(x, edge_index)
        h = torch.relu(h)
        h = self.gcn2(h, edge_index)
        h = torch.relu(h)
        h = self.out(h)
        h = torch.relu(h)
        # All nodes belong to a single pooling group (index 0). The index is
        # sized from the second-to-last axis, which is the axis pooled over, and
        # is created directly on h's device with the integer dtype it needs.
        pool_index = torch.zeros(h.size(-2), dtype=torch.long, device=h.device)
        return global_mean_pool(h, pool_index)

    def fit(self, epochs):
        """Train with Adam and MSE loss, printing the mean loss of each epoch.

        Args:
            epochs: number of passes over the batches from ``generate_batch()``.
        """
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.001)
        # tee() buffers the generated batches so that every epoch replays the
        # same batch sequence from a single generate_batch() call.
        dataloaders = itertools.tee(generate_batch(), epochs)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch, di in dataloaders[epoch - 1]:
                x = batch[0]
                y = batch[1]
                y = y.unsqueeze(-1)
                optimizer.zero_grad()
                out = self(x, dataset_coll[di].edge_index)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                # Accumulate a detached value so the running sum does not keep
                # every step's autograd history alive for the whole epoch.
                total_loss += loss.detach()
                count += 1

            if count > 0:
                print("Epoch", epoch, "| Loss:", (total_loss / count).item())

In [8]:
# One scalar feature per node in, 64 hidden units, one scalar prediction out.
gnn = VanillaGNN(dim_in=1, dim_h=64, dim_out=1)
gnn = gnn.to(device)
print(gnn)

gnn.fit(epochs=25)

VanillaGNN(
  (gcn1): GCNConv(1, 64)
  (gcn2): GCNConv(64, 64)
  (out): Linear(in_features=64, out_features=1, bias=True)
)
Epoch 1 | Loss: 1.157930850982666
Epoch 2 | Loss: 0.990382730960846
Epoch 3 | Loss: 0.9840494394302368
Epoch 4 | Loss: 0.9810714721679688
Epoch 5 | Loss: 0.9804043173789978
Epoch 6 | Loss: 0.9799413084983826
Epoch 7 | Loss: 0.9779075384140015
Epoch 8 | Loss: 0.9770877957344055
Epoch 9 | Loss: 0.9765182137489319
Epoch 10 | Loss: 0.9755871891975403
Epoch 11 | Loss: 0.97500079870224
Epoch 12 | Loss: 0.9743167161941528
Epoch 13 | Loss: 0.9737181067466736
Epoch 14 | Loss: 0.9735273718833923
Epoch 15 | Loss: 0.9733970165252686
Epoch 16 | Loss: 0.9733732342720032
Epoch 17 | Loss: 0.9733194708824158
Epoch 18 | Loss: 0.9729306101799011
Epoch 19 | Loss: 0.9728265404701233
Epoch 20 | Loss: 0.9725192785263062
Epoch 21 | Loss: 0.9724025726318359
Epoch 22 | Loss: 0.972720742225647
Epoch 23 | Loss: 0.9726395606994629
Epoch 24 | Loss: 0.9728273153305054
Epoch 25 | Loss: 0.9721915

In [11]:
# testing
# Evaluate the trained model on the held-out split, writing every
# (actual, predicted) pair to test_result.csv and reporting the mean batch
# loss plus the share of predictions that round to the exact target.
import csv

torch.set_printoptions(profile="full")

criterion = torch.nn.MSELoss()
total_loss = 0
total_matched = 0

indx = 0
dataset = dataset_coll[indx]
test_loader = test_dataloader_coll[indx]
test_dataset = test_loader.dataset

count = 0
gnn.eval()
# `with` guarantees the CSV is flushed and closed; no_grad() must be entered as
# a context manager to actually disable autograd during evaluation.
with open("test_result.csv", "w", newline="") as f, torch.no_grad():
    csv_writer = csv.writer(f)
    csv_writer.writerow(["Actual", "Predicted"])

    for batch in test_loader:
        x = batch[0]
        y = batch[1]
        y = y.unsqueeze(-1)
        out = gnn(x, dataset.edge_index)
        csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
        loss = criterion(out, y)
        total_loss += loss
        # A prediction counts as matched when it rounds to the exact target.
        out = torch.round(out)
        matched = (out == y).sum().item()
        total_matched += matched
        count += 1

# Guard the averages so an empty test split reports NaN instead of crashing.
mean_loss = total_loss / count if count > 0 else float("nan")
matched_pct = (
    round(total_matched / len(test_dataset) * 100, 2)
    if len(test_dataset) > 0
    else float("nan")
)

print(
    "Test loss:",
    mean_loss,
    "Total matched",
    total_matched,
    "out of",
    len(test_dataset),
    f"({matched_pct}%)",
)

Test loss: tensor(0.9425, device='cuda:0') Total matched 7550 out of 19532 (38.65%)
