In [1]:
import itertools

import torch

from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import DataLoader, random_split

In [2]:
from helpers import CVFConfigForGCNDataset

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
dataset_graph_6 = CVFConfigForGCNDataset(
    device,
    "graph_6_config_rank_dataset.csv",
    "graph_6_edge_index.json",
)

dataset_graph_7 = CVFConfigForGCNDataset(
    device,
    "graph_7_config_rank_dataset.csv",
    "graph_7_edge_index.json",
)

dataset_graph_8 = CVFConfigForGCNDataset(
    device,
    "graph_8_config_rank_dataset.csv",
    "graph_8_edge_index.json",
)

dataset_graph_10 = CVFConfigForGCNDataset(
    device,
    "graph_10_config_rank_dataset.csv",
    "graph_10_edge_index.json",
)

dataset_rr_n7 = CVFConfigForGCNDataset(
    device,
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n7_d4_edge_index.json",
)

dataset_plc_n8 = CVFConfigForGCNDataset(
    device,
    "graph_powerlaw_cluster_graph_n8_config_rank_dataset.csv",
    "graph_powerlaw_cluster_graph_n8_edge_index.json",
)

batch_size = 64

dataset_coll = [
    dataset_graph_6,
    dataset_graph_7,
    dataset_graph_8,
    dataset_graph_10,
    dataset_rr_n7,
    dataset_plc_n8,
]

In [5]:
train_dataloader_coll = []
test_dataloader_coll = []

for dataset in dataset_coll:
    train_size = int(0.95 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_dataloader_coll.append(train_loader)
    test_dataloader_coll.append(test_loader)

train_dataloader_coll_iter = [iter(i) for i in train_dataloader_coll]

In [6]:
def generate_batch():
    end_loop = [False for _ in range(len(train_dataloader_coll))]
    while not any(end_loop):
        for di, data_loader in enumerate(train_dataloader_coll_iter):
            if end_loop[di]:
                continue
            try:
                batch = next(data_loader)
            except StopIteration:
                end_loop[di] = True
                continue
            yield batch, di

In [7]:
print("Number of batches:", [len(i) for i in train_dataloader_coll])

Number of batches: [72, 171, 266, 475, 1160, 2128]


In [None]:
class VanillaGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.gcn1(x, edge_index)
        print("h", h.shape)
        h = torch.relu(h)
        h = self.gcn2(h, edge_index)
        h = torch.relu(h)
        h = self.out(h)
        h = torch.relu(h)
        h = global_mean_pool(h, torch.zeros(h.size(1)).to(device).long())
        return h

    def fit(self, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.0001)
        # edge_index = dataset.edge_index.t().to(device)
        dataloaders = itertools.tee(generate_batch(), epochs)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch, di in dataloaders[epoch - 1]:
                x = batch[0]
                y = batch[1]
                y = y.unsqueeze(-1)
                optimizer.zero_grad()
                out = self(x, dataset_coll[di].edge_index)
                # print(out.shape, y.shape)
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            if count > 0:
                print(
                    "Training set | Epoch",
                    epoch,
                    "| Loss:",
                    (total_loss / count).item(),
                )

In [9]:
gnn = VanillaGNN(1, 64, 1).to(device)
print(gnn)

gnn.fit(epochs=200)

VanillaGNN(
  (gcn1): GCNConv(1, 64)
  (gcn2): GCNConv(64, 64)
  (out): Linear(in_features=64, out_features=1, bias=True)
)
x torch.Size([64, 9, 1])
h torch.Size([64, 9, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 11, 1])
h torch.Size([64, 11, 64])
x torch.Size([64, 7, 1])
h torch.Size([64, 7, 64])
x torch.Size([64, 8, 1])
h torch.Size([64, 8, 64])
x torch.Size([64, 9, 1])
h torch.Size([64, 9, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 11, 1])
h torch.Size([64, 11, 64])
x torch.Size([64, 7, 1])
h torch.Size([64, 7, 64])
x torch.Size([64, 8, 1])
h torch.Size([64, 8, 64])
x torch.Size([64, 9, 1])
h torch.Size([64, 9, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 10, 1])
h torch.Size([64, 10, 64])
x torch.Size([64, 11, 1])
h torch.Size([64, 11, 64])
x torch.Size([64, 7, 1])
h torch.Size([64,

KeyboardInterrupt: 

In [None]:
# testing
import csv

torch.no_grad()
torch.set_printoptions(profile="full")

f = open("test_result.csv", "w", newline="")
csv_writer = csv.writer(f)
csv_writer.writerow(["Actual", "Predicted"])

criterion = torch.nn.MSELoss()

for indx in range(len(dataset_coll)):
    # indx = 0
    total_loss = 0
    total_matched = 0
    dataset = dataset_coll[indx]
    test_loader = test_dataloader_coll[indx]
    test_dataset = test_loader.dataset

    count = 0
    for batch in test_loader:
        x = batch[0]
        y = batch[1]
        y = y.unsqueeze(-1)
        out = gnn(x, dataset.edge_index)
        csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
        loss = criterion(out, y)
        # print("Loss: ", loss)
        total_loss += loss
        out = torch.round(out)
        matched = (out == y).sum().item()
        total_matched += matched
        count += 1

    print(
        "Indx",
        indx,
        "Test loss:",
        total_loss.detach() / count,
        "Total matched",
        total_matched,
        "out of",
        len(test_dataset),
        f"({round(total_matched/len(test_dataset) * 100, 2)}%)",
    )

f.close()

Indx 0 Test loss: tensor(0.5431, device='cuda:0') Total matched 130 out of 240 (54.17%)
Indx 1 Test loss: tensor(0.3911, device='cuda:0') Total matched 350 out of 576 (60.76%)
Indx 2 Test loss: tensor(0.5350, device='cuda:0') Total matched 436 out of 896 (48.66%)
Indx 3 Test loss: tensor(0.4347, device='cuda:0') Total matched 923 out of 1600 (57.69%)
Indx 4 Test loss: tensor(0.9890, device='cuda:0') Total matched 1501 out of 3907 (38.42%)
Indx 5 Test loss: tensor(1.2696, device='cuda:0') Total matched 2458 out of 7168 (34.29%)
