In [1]:
import torch


from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import ConcatDataset, Sampler, DataLoader, random_split

In [2]:
from helpers import CVFConfigForGCNWSuccDataset

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
dataset_tiny_test = CVFConfigForGCNWSuccDataset(
    device,
    "tiny_graph_test_config_rank_dataset.csv",
    "tiny_graph_edge_index.json",
)

dataset_graph_1 = CVFConfigForGCNWSuccDataset(
    device,
    "graph_1_config_rank_dataset.csv",
    "graph_1_edge_index.json",
)

dataset_rr_n8 = CVFConfigForGCNWSuccDataset(
    device,
    "graph_random_regular_graph_n8_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n8_d4_edge_index.json",
)


batch_size = 10240

dataset_coll = [dataset_rr_n8]

In [5]:
train_sizes = [int(0.9 * len(ds)) for ds in dataset_coll]
test_sizes = [len(ds) - trs for ds, trs in zip(dataset_coll, train_sizes)]

train_test_datasets = [
    random_split(ds, [tr_s, ts])
    for ds, tr_s, ts in zip(dataset_coll, train_sizes, test_sizes)
]

train_datasets = [ds[0] for ds in train_test_datasets]
test_datasets = [ds[1] for ds in train_test_datasets]

In [6]:
datasets = ConcatDataset(train_datasets)
print(len(datasets))

351562


In [7]:
class CustomBatchSampler(Sampler):
    def __init__(self, datasets: ConcatDataset, batch_size: int):
        self.datasets = datasets
        self.batch_size = batch_size

    @property
    def module(self):
        return self._module

    @module.setter
    def module(self, val):
        self._module = val

    def __iter__(self):
        last_accessed = [0] + self.datasets.cumulative_sizes[:]
        end_loop = [False for _ in range(len(self.datasets.datasets))]

        while not all(end_loop):
            for turn in range(len(self.datasets.datasets)):
                if end_loop[turn]:
                    continue

                batch_size = self.batch_size
                if (
                    last_accessed[turn] + batch_size
                    >= self.datasets.cumulative_sizes[turn]
                ):
                    batch_size = (
                        self.datasets.cumulative_sizes[turn] - last_accessed[turn]
                    )
                    end_loop[turn] = True

                # currently explicitly setting edge index before yielding
                # TODO: find a better way to do it
                self.module.edge_index = self.datasets.datasets[turn].dataset.edge_index

                yield list(range(last_accessed[turn], last_accessed[turn] + batch_size))

                last_accessed[turn] += batch_size

In [8]:
batch_sampler = CustomBatchSampler(datasets, batch_size=batch_size)
dataloader = DataLoader(datasets, batch_sampler=batch_sampler)

In [9]:
class VanillaGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    @property
    def edge_index(self):
        return self._edge_index

    @edge_index.setter
    def edge_index(self, val):
        self._edge_index = val

    def forward(self, x):
        h = self.gcn1(x, self.edge_index)
        h = torch.relu(h)
        h = self.gcn2(h, self.edge_index)
        h = torch.relu(h)
        h = self.out(h)
        h = torch.relu(h)
        h = global_mean_pool(h, torch.zeros(h.size(1)).to(device).long())
        return h

    def fit(self, epochs):
        dataloader.batch_sampler.module = self
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.01)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch in dataloader:
                x = batch[0]
                y = batch[1]
                y = y.unsqueeze(-1)
                optimizer.zero_grad()
                out = self(x)
                # print(out, y)
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print(
                "Training set | Epoch",
                epoch,
                "| Loss:",
                round((total_loss / count).item(), 4),
            )

In [None]:
gnn = VanillaGNN(2, 64, 1).to(device)
print(gnn)

gnn.fit(epochs=50)

VanillaGNN(
  (gcn1): GCNConv(2, 64)
  (gcn2): GCNConv(64, 64)
  (out): Linear(in_features=64, out_features=1, bias=True)
)
Training set | Epoch 1 | Loss: 2.1679
Training set | Epoch 2 | Loss: 1.3062


In [None]:
# testing
import csv

torch.no_grad()
torch.set_printoptions(profile="full")

f = open("test_result.csv", "w", newline="")
csv_writer = csv.writer(f)
csv_writer.writerow(["Actual", "Predicted"])

criterion = torch.nn.MSELoss()
total_loss = 0
total_matched = 0

test_concat_datasets = ConcatDataset(test_datasets)
batch_sampler = CustomBatchSampler(test_concat_datasets, batch_size=batch_size)
test_dataloader = DataLoader(test_concat_datasets, batch_sampler=batch_sampler)

test_dataloader.batch_sampler.module = gnn

count = 0
for batch in test_dataloader:
    x = batch[0]
    y = batch[1]
    y = y.unsqueeze(-1)
    out = gnn(x)
    csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
    loss = criterion(out, y)
    total_loss += loss
    out = torch.round(out)
    matched = (out == y).sum().item()
    total_matched += matched
    count += 1

f.close()

print(
    "Test loss:",
    total_loss.detach() / count,
    "Total matched",
    total_matched,
    "out of",
    len(test_concat_datasets),
    f"({round(total_matched/len(test_concat_datasets) * 100, 2)}%)",
)

Test loss: tensor(0.0767, device='cuda:0') Total matched 467 out of 512 (91.21%)
