In [1]:
import random

import torch


from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import ConcatDataset, Sampler, DataLoader, random_split

In [2]:
from helpers import CVFConfigForGCNGridSearchDataset

In [3]:
# Reproducibility: seed every RNG this notebook relies on before any model is
# built. torch.manual_seed seeds the CPU and all CUDA generators. Note that some
# CUDA scatter kernels are nondeterministic, so GPU runs may still differ slightly.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
def load_rr_dataset(num_nodes, degree):
    """Load the random-regular-graph dataset for one (num_nodes, degree) pair.

    Both input files share the stem
    ``graph_random_regular_graph_n{num_nodes}_d{degree}``:
    ``<stem>_config_rank_dataset.csv`` and ``<stem>_edge_index.json``.
    Building both names from one stem keeps the CSV and its graph in sync.
    """
    stem = f"graph_random_regular_graph_n{num_nodes}_d{degree}"
    return CVFConfigForGCNGridSearchDataset(
        device,
        f"{stem}_config_rank_dataset.csv",
        f"{stem}_edge_index.json",
    )


dataset_rr_n4 = load_rr_dataset(4, 3)
dataset_rr_n5 = load_rr_dataset(5, 4)
dataset_rr_n6 = load_rr_dataset(6, 3)
dataset_rr_n7 = load_rr_dataset(7, 4)
dataset_rr_n8 = load_rr_dataset(8, 4)

batch_size = 64

# The batch sampler built below draws from these datasets in list order.
dataset_coll = [
    dataset_rr_n4,
    dataset_rr_n5,
    dataset_rr_n6,
    dataset_rr_n7,
    dataset_rr_n8,
]

In [5]:
class CustomBatchSampler(Sampler):
    """Batch sampler that keeps every batch inside a single sub-dataset.

    ``datasets`` is a ``ConcatDataset`` whose parts each describe one graph.
    Batches are emitted round-robin: one batch from part 0, one from part 1,
    ..., then back to part 0, until every part is exhausted. Each batch is a
    list of at most ``batch_size`` consecutive *global* indices into
    ``datasets``. The last batch of a part may be shorter, and empty parts
    produce no batch at all.

    Every part carries its own ``edge_index``. Just before yielding a batch,
    the sampler copies that part's ``edge_index`` onto ``self.module`` (the
    model), so the model sees the graph that matches the batch it is about to
    receive. This only lines up when the DataLoader consumes each batch before
    requesting the next one, i.e. ``num_workers=0`` (no prefetching).

    Args:
        datasets: the concatenated per-graph datasets.
        batch_size: maximum number of samples per batch; must be positive.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, datasets: ConcatDataset, batch_size: int):
        # A non-positive size would make __iter__ yield empty batches forever.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.datasets = datasets
        self.batch_size = batch_size

    @property
    def module(self):
        """Model whose ``edge_index`` is updated per batch; set before iterating."""
        return self._module

    @module.setter
    def module(self, val):
        self._module = val

    def _bounds(self):
        """Return ``(start, stop)`` global index ranges, one per sub-dataset."""
        stops = self.datasets.cumulative_sizes
        starts = [0] + stops[:-1]
        return list(zip(starts, stops))

    def __len__(self):
        """Number of batches one full iteration yields (used by ``len(DataLoader)``)."""
        return sum(
            (stop - start + self.batch_size - 1) // self.batch_size
            for start, stop in self._bounds()
        )

    def __iter__(self):
        bounds = self._bounds()
        # Next unread global index of each sub-dataset.
        cursors = [start for start, _ in bounds]
        # Empty sub-datasets start out finished so they never yield an empty batch.
        active = [start < stop for start, stop in bounds]

        while any(active):
            for turn, (_, stop) in enumerate(bounds):
                if not active[turn]:
                    continue

                begin = cursors[turn]
                end = min(begin + self.batch_size, stop)
                if end == stop:
                    active[turn] = False

                # Point the model at this sub-dataset's graph before handing out
                # its batch.
                # TODO: find a cleaner channel than mutating the model from the sampler.
                self.module.edge_index = self.datasets.datasets[turn].edge_index

                yield list(range(begin, end))

                cursors[turn] = end

In [6]:

# One ConcatDataset over all graphs; the custom sampler keeps each batch within
# a single graph. num_workers=0 (DataLoader's default) is spelled out because the
# sampler sets the model's edge_index while producing each batch. That only
# matches the batch the model receives when no worker prefetches ahead.
datasets = ConcatDataset(datasets=dataset_coll)
batch_sampler = CustomBatchSampler(datasets=datasets, batch_size=batch_size)
dataloader = DataLoader(datasets, batch_sampler=batch_sampler, num_workers=0)

In [7]:
# Total sample count across all graphs (ConcatDataset's length is its last cumulative size).
datasets.cumulative_sizes[-1]

476227

In [8]:
# Exclusive end index of each graph's slice of `datasets`, in dataset_coll order
# (batch_sampler.datasets is this same ConcatDataset).
datasets.cumulative_sizes

[256, 3381, 7477, 85602, 476227]

In [9]:
# for i, batch in enumerate(dataloader):
#     print(i, len(batch[0]))
#     break

In [None]:
class VanillaGNN(torch.nn.Module):
    """Two-layer GCN that regresses one pooled value per input sample.

    The graph is not an argument of ``forward``. Both GCN layers read
    ``self.edge_index``, which ``CustomBatchSampler`` overwrites before each
    batch, so the model uses the graph that the batch was drawn from.

    Args:
        dim_in: input features per node.
        dim_h: hidden width of both GCN layers.
        dim_out: output features per node, kept after mean-pooling over nodes.
    """

    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    @property
    def edge_index(self):
        """Connectivity passed to both GCN layers; must be set before ``forward``."""
        return self._edge_index

    @edge_index.setter
    def edge_index(self, val):
        self._edge_index = val

    def forward(self, x):
        """Run the GCN stack on ``x`` and mean-pool the per-node outputs.

        ``x`` is presumably ``(batch, num_nodes, dim_in)`` with one graph shared
        by the whole batch. TODO: confirm against the dataset's ``__getitem__``.
        """
        h = torch.relu(self.gcn1(x, self.edge_index))
        h = torch.relu(self.gcn2(h, self.edge_index))
        # NOTE(review): ReLU on the output head clamps every per-node prediction
        # to >= 0 and passes zero gradient for negative pre-activations. Confirm
        # that this is intended for this regression target.
        h = torch.relu(self.out(h))
        # Every node belongs to pooling group 0, so the pool averages over the
        # node dimension (-2; identical to dim 1 for 3-D input, and also correct
        # for 2-D (num_nodes, features) input). The index is built on h's own
        # device rather than a notebook global. size=1 (the number of groups)
        # spares global_mean_pool from computing batch.max() on every call.
        node_group = torch.zeros(h.size(-2), dtype=torch.long, device=h.device)
        return global_mean_pool(h, node_group, size=1)

    def fit(self, epochs, loader=None, lr=0.01, weight_decay=0.001):
        """Train with MSE loss and Adam, printing the mean batch loss per epoch.

        Args:
            epochs: number of full passes over ``loader``.
            loader: DataLoader yielding batches where ``batch[0]`` is the input
                and ``batch[1]`` the target. Its ``batch_sampler`` must accept a
                ``module`` attribute (as ``CustomBatchSampler`` does). Defaults
                to the notebook-level ``dataloader``.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.

        Raises:
            ValueError: if ``loader`` yields no batches in an epoch.
        """
        if loader is None:
            loader = dataloader
        # The sampler writes each batch's edge_index onto this model.
        loader.batch_sampler.module = self
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            self.parameters(), lr=lr, weight_decay=weight_decay
        )
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0.0
            count = 0
            for batch in loader:
                x = batch[0]
                y = batch[1]
                optimizer.zero_grad()
                out = self(x)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                # detach(): summing live loss tensors would chain every batch's
                # autograd graph together for the whole epoch.
                total_loss += loss.detach()
                count += 1

            if count == 0:
                raise ValueError("loader yielded no batches; nothing to train on")

            print(
                "Training set | Epoch",
                epoch,
                "| Loss:",
                round(float(total_loss) / count, 4),
            )

In [None]:
# One scalar feature per node in, 64 hidden channels, one regressed value out.
gnn = VanillaGNN(dim_in=1, dim_h=64, dim_out=1)
gnn = gnn.to(device)
print(gnn)

gnn.fit(epochs=10)

VanillaGNN(
  (gcn1): GCNConv(1, 64)
  (gcn2): GCNConv(64, 64)
  (out): Linear(in_features=64, out_features=1, bias=True)
)
Training set | Epoch 1 | Loss: 1.3497
Training set | Epoch 2 | Loss: 1.2761
Training set | Epoch 3 | Loss: 1.2684
Training set | Epoch 4 | Loss: 1.2581
Training set | Epoch 5 | Loss: 1.2578


In [None]:
# # testing
# import csv

# torch.no_grad()
# torch.set_printoptions(profile="full")

# f = open("test_result.csv", "w", newline="")
# csv_writer = csv.writer(f)
# csv_writer.writerow(["Actual", "Predicted"])

# criterion = torch.nn.MSELoss()
# total_loss = 0
# total_matched = 0

# indx = 0
# dataset = dataset_coll[indx]
# test_loader = test_dataloader_coll[indx]
# test_dataset = test_loader.dataset

# count = 0
# for batch in test_loader:
# 	x = batch[0]
# 	y = batch[1]
# 	# out = gnn(x, dataset.edge_index)
# 	out = gnn(x)
# 	# print(y.shape, out.shape)
# 	csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
# 	loss = criterion(out, y)
# 	total_loss += loss
# 	out = torch.round(out)
# 	matched = (out == y).sum().item()
# 	total_matched += matched
# 	count += 1

# print(
# 	"Test loss:",
# 	total_loss.detach() / count,
# 	"Total matched",
# 	total_matched,
# 	"out of",
# 	len(test_dataset),
# 	f"({round(total_matched/len(test_dataset) * 100, 2)}%)",
# )

In [None]:
# from skorch import NeuralNet

In [None]:
# class CustomNeuralNet(NeuralNet):
#     def fit_loop(self, X, y=None, epochs=None, **fit_params):
#         """The proper fit loop.

#         Contains the logic of what actually happens during the fit
#         loop.

#         Parameters
#         ----------
#         X : input data, compatible with skorch.dataset.Dataset
#           By default, you should be able to pass:

#             * numpy arrays
#             * torch tensors
#             * pandas DataFrame or Series
#             * scipy sparse CSR matrices
#             * a dictionary of the former three
#             * a list/tuple of the former three
#             * a Dataset

#           If this doesn't work with your data, you have to pass a
#           ``Dataset`` that can deal with the data.

#         y : target data, compatible with skorch.dataset.Dataset
#           The same data types as for ``X`` are supported. If your X is
#           a Dataset that contains the target, ``y`` may be set to
#           None.

#         epochs : int or None (default=None)
#           If int, train for this number of epochs; if None, use
#           ``self.max_epochs``.

#         **fit_params : dict
#           Additional parameters passed to the ``forward`` method of
#           the module and to the ``self.train_split`` call.

#         """
#         self.check_data(X, y)
#         self.check_training_readiness()
#         epochs = epochs if epochs is not None else self.max_epochs

#         dataset_train, dataset_valid = self.get_split_datasets(
#             X, y, **fit_params)
#         on_epoch_kwargs = {
#             'dataset_train': dataset_train,
#             'dataset_valid': dataset_valid,
#         }
#         iterator_train = self.get_iterator(dataset_train, training=True)
#         # iterator_train = generate_batch()
#         print(iterator_train.__dict__)
#         iterator_valid = None
#         if dataset_valid is not None:
#             iterator_valid = self.get_iterator(dataset_valid, training=False)

#         for _ in range(epochs):
#             self.notify('on_epoch_begin', **on_epoch_kwargs)

#             self.run_single_epoch(iterator_train, training=True, prefix="train",
#                                   step_fn=self.train_step, **fit_params)

#             self.run_single_epoch(iterator_valid, training=False, prefix="valid",
#                                   step_fn=self.validation_step, **fit_params)

#             self.notify("on_epoch_end", **on_epoch_kwargs)
#         return self

In [None]:
# dataset = dataset_rr_n4

In [None]:
# net = CustomNeuralNet(
# 	VanillaGNN,
# 	# train_split=None,
# 	device=device,
# 	lr=0.01,
# 	batch_size=64,
# 	max_epochs=5,
# 	criterion=torch.nn.MSELoss,
# 	optimizer=torch.optim.Adam,
# 	optimizer__weight_decay=0.01,
# 	module__dim_in=1,
# 	module__dim_h=64,
# 	module__dim_out=1,
# 	module__edge_index=dataset.edge_index,
# )

In [None]:
# net.fit(dataset, y=None)

In [None]:
# net.history