In [None]:
import torch


from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import ConcatDataset, Sampler, DataLoader, random_split

In [2]:
from helpers import CVFConfigForGCNGridSearchDataset

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# (config/rank CSV, edge-index JSON) file pair for every graph used in the
# grid search, listed in the order the datasets are concatenated later.
_RR_DATASET_FILES = (
    (
        "graph_random_regular_graph_n4_d3_config_rank_dataset.csv",
        "graph_random_regular_graph_n4_d3_edge_index.json",
    ),
    (
        "graph_random_regular_graph_n5_d4_config_rank_dataset.csv",
        "graph_random_regular_graph_n5_d4_edge_index.json",
    ),
    (
        "graph_random_regular_graph_n6_d3_config_rank_dataset.csv",
        "graph_random_regular_graph_n6_d3_edge_index.json",
    ),
    (
        "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
        "graph_random_regular_graph_n7_d4_edge_index.json",
    ),
    (
        "graph_random_regular_graph_n8_d4_config_rank_dataset.csv",
        "graph_random_regular_graph_n8_d4_edge_index.json",
    ),
)

# Number of samples per batch handed to the custom batch samplers.
batch_size = 64

# One dataset object per file pair, all created with the selected device.
dataset_coll = [
    CVFConfigForGCNGridSearchDataset(device, rank_csv, edge_json)
    for rank_csv, edge_json in _RR_DATASET_FILES
]

# Keep the individual names available for ad-hoc inspection in later cells.
(
    dataset_rr_n4,
    dataset_rr_n5,
    dataset_rr_n6,
    dataset_rr_n7,
    dataset_rr_n8,
) = dataset_coll

In [5]:
# Split every graph's dataset 90% / 10% into train / test *separately*, so each
# graph is represented in both partitions.
#
# The split is driven by an explicitly seeded generator: without it the
# partition changes on every kernel restart and the losses reported by the grid
# search below cannot be reproduced.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_sizes = [int(0.9 * len(ds)) for ds in dataset_coll]
# The test share is whatever is left, so the two sizes always sum to len(ds).
test_sizes = [len(ds) - trs for ds, trs in zip(dataset_coll, train_sizes)]

train_test_datasets = [
    random_split(ds, [tr_s, ts], generator=split_generator)
    for ds, tr_s, ts in zip(dataset_coll, train_sizes, test_sizes)
]

# random_split returns [train_subset, test_subset] for each dataset.
train_datasets = [ds[0] for ds in train_test_datasets]
test_datasets = [ds[1] for ds in train_test_datasets]

In [6]:
# Merge the per-graph training subsets into one dataset addressed by a single
# global index, then report how many training samples it holds.
datasets = ConcatDataset(train_datasets)
print(len(datasets))

428602


In [7]:
class CustomBatchSampler(Sampler):
    """Batch sampler that never mixes members of a ``ConcatDataset`` in a batch.

    The members of the concatenated dataset are visited in round-robin order.
    On each visit the sampler yields the next ``batch_size`` consecutive global
    indices of that member (the final batch of a member may be shorter).
    Indices are produced in order; nothing is shuffled.  Members that are
    empty are skipped instead of producing an empty batch.

    Just before a batch is yielded, the ``edge_index`` of the member it comes
    from is assigned to ``module.edge_index`` (when a module is attached).

    NOTE(review): this hand-off relies on the consumer processing each batch
    before asking for the next one -- presumably single-process data loading
    without prefetching; confirm before enabling worker processes.

    Parameters
    ----------
    datasets : ConcatDataset
        Concatenated dataset whose members are either ``Subset``-like objects
        exposing the wrapped dataset as ``.dataset`` (as ``random_split``
        returns) or datasets exposing ``edge_index`` directly.
    batch_size : int
        Maximum number of indices per batch; must be positive.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive (the iteration could never advance).
    """

    def __init__(self, datasets: ConcatDataset, batch_size: int):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.datasets = datasets
        self.batch_size = batch_size
        # No model attached yet; iteration then simply skips the edge_index
        # hand-off instead of failing with AttributeError.
        self._module = None

    @property
    def module(self):
        """Model that receives ``edge_index`` before each batch (or ``None``)."""
        return self._module

    @module.setter
    def module(self, val):
        self._module = val

    def _member_bounds(self):
        """Return ``(start, end)`` global index bounds of every member."""
        ends = list(self.datasets.cumulative_sizes)
        starts = [0] + ends[:-1]
        return list(zip(starts, ends))

    def _edge_index_for(self, turn):
        """Return the ``edge_index`` of the dataset behind member ``turn``."""
        member = self.datasets.datasets[turn]
        # Subsets keep the original dataset under `.dataset`; fall back to the
        # member itself when it was not wrapped.
        source = getattr(member, "dataset", member)
        return source.edge_index

    def __len__(self):
        """Number of batches one full iteration yields."""
        # Ceiling division per member: a trailing partial batch still counts.
        return sum(
            -(-(end - start) // self.batch_size)
            for start, end in self._member_bounds()
        )

    def __iter__(self):
        bounds = self._member_bounds()
        # Next global index to hand out for each member.
        cursors = [start for start, _ in bounds]
        # Members that still have indices left, kept in round-robin order.
        pending = [turn for turn, (start, end) in enumerate(bounds) if end > start]

        while pending:
            still_pending = []
            for turn in pending:
                end = bounds[turn][1]
                # Clip the last batch at the member's boundary so a batch never
                # spills into the next dataset.
                stop = min(cursors[turn] + self.batch_size, end)

                # currently explicitly setting edge index before yielding
                # TODO: find a better way to do it
                if self._module is not None:
                    self._module.edge_index = self._edge_index_for(turn)

                yield list(range(cursors[turn], stop))

                cursors[turn] = stop
                if stop < end:
                    still_pending.append(turn)
            pending = still_pending

In [8]:
# Training loader: batching is delegated entirely to the custom sampler, which
# decides which indices form each batch.
batch_sampler = CustomBatchSampler(datasets=datasets, batch_size=batch_size)
dataloader = DataLoader(dataset=datasets, batch_sampler=batch_sampler)

In [9]:
# Inspect the running end offsets of the concatenated datasets held by the
# sampler (one entry per member dataset).
batch_sampler.datasets.cumulative_sizes

[230, 3042, 6728, 77040, 428602]

In [10]:
# for i, batch in enumerate(dataloaders):
#     print(i, len(batch[0]))
#     break

In [11]:
class VanillaGNN(torch.nn.Module):
    """Two ``GCNConv`` layers followed by a linear head and mean pooling.

    The graph connectivity is not passed to ``forward``; it is read from the
    ``edge_index`` attribute, which must be assigned before every forward pass
    (in this notebook the batch sampler does that).

    Parameters
    ----------
    dim_in : int
        Input feature size of the first graph convolution.
    dim_h : int
        Hidden feature size shared by both graph convolutions.
    dim_out : int
        Output feature size of the linear head.
    """

    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h)
        self.gcn2 = GCNConv(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)

    @property
    def edge_index(self):
        """Edge index used by both graph convolutions in ``forward``."""
        return self._edge_index

    @edge_index.setter
    def edge_index(self, val):
        self._edge_index = val

    def forward(self, x):
        """Run the network on ``x`` using the currently assigned ``edge_index``.

        NOTE(review): the pooling index is sized by ``h.size(1)``, so ``x`` is
        presumably laid out as (batch, nodes, features) -- confirm against the
        dataset.
        """
        h = self.gcn1(x, self.edge_index)
        h = torch.relu(h)
        h = self.gcn2(h, self.edge_index)
        h = torch.relu(h)
        h = self.out(h)
        h = torch.relu(h)
        # All-zero assignment vector: every entry along dim 1 is pooled into a
        # single group.  It is created directly on the activations' device
        # rather than on a notebook-global device, so the model also works
        # when it lives somewhere else than that global.
        pool_index = torch.zeros(h.size(1), dtype=torch.long, device=h.device)
        h = global_mean_pool(h, pool_index)
        return h

    def fit(self, epochs):
        """Train with Adam + MSE on the notebook-level ``dataloader``.

        Prints the mean batch loss after every epoch (``nan`` when the loader
        yielded no batch).

        Parameters
        ----------
        epochs : int
            Number of passes over ``dataloader``.
        """
        # Let the sampler push the matching edge_index into this model before
        # each batch.
        dataloader.batch_sampler.module = self
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.001)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0.0
            count = 0
            for batch in dataloader:
                x = batch[0]
                y = batch[1]
                optimizer.zero_grad()
                out = self(x)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                # Accumulate a plain Python float: summing the loss tensors
                # themselves would keep autograd history alive for the whole
                # epoch.
                total_loss += loss.item()
                count += 1

            mean_loss = total_loss / count if count else float("nan")
            print(
                "Training set | Epoch",
                epoch,
                "| Loss:",
                round(mean_loss, 4),
            )

In [12]:
# gnn = VanillaGNN(1, 64, 1).to(device)
# print(gnn)

# gnn.fit(epochs=10)

In [13]:
# testing
# import csv

# torch.no_grad()
# torch.set_printoptions(profile="full")

# f = open("test_result.csv", "w", newline="")
# csv_writer = csv.writer(f)
# csv_writer.writerow(["Actual", "Predicted"])

# criterion = torch.nn.MSELoss()
# total_loss = 0
# total_matched = 0

# Held-out loader.  Its sampler gets its own name: rebinding `batch_sampler`
# here would silently replace the training sampler that earlier cells inspect.
test_concat_datasets = ConcatDataset(test_datasets)
test_batch_sampler = CustomBatchSampler(test_concat_datasets, batch_size=batch_size)
test_dataloader = DataLoader(test_concat_datasets, batch_sampler=test_batch_sampler)

# test_dataloader.batch_sampler.module = gnn

# count = 0
# for batch in test_dataloader:
# 	x = batch[0]
# 	y = batch[1]
# 	out = gnn(x)
# 	# print(y.shape, out.shape)
# 	csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
# 	loss = criterion(out, y)
# 	total_loss += loss
# 	out = torch.round(out)
# 	matched = (out == y).sum().item()
# 	total_matched += matched
# 	count += 1

# print(
# 	"Test loss:",
# 	total_loss.detach() / count,
# 	"Total matched",
# 	total_matched,
# 	"out of",
# 	len(test_concat_datasets),
# 	f"({round(total_matched/len(test_concat_datasets) * 100, 2)}%)",
# )

In [14]:
def get_test_loss(model):
    """Return the mean per-batch MSE of ``model`` over ``test_dataloader``.

    Parameters
    ----------
    model
        Fitted estimator exposing ``predict(x)``.  When it also exposes the
        underlying network as ``module_`` (as ``CustomNeuralNet.fit_loop``
        uses), that network is bound to the test sampler so the sampler pushes
        each batch's ``edge_index`` into the model actually being evaluated.

    Returns
    -------
    torch.Tensor
        Scalar tensor: sum of the batch losses divided by the batch count.

    Raises
    ------
    ZeroDivisionError
        If ``test_dataloader`` yields no batches.
    """
    criterion = torch.nn.MSELoss()

    # Do not rely on an earlier fit having left the sampler pointing at this
    # particular model.
    module = getattr(model, "module_", None)
    if module is not None:
        test_dataloader.batch_sampler.module = module

    count = 0
    total_loss = 0
    # `torch.no_grad()` only has an effect as a context manager (or decorator);
    # calling it as a bare statement does nothing.
    with torch.no_grad():
        for batch in test_dataloader:
            x = batch[0]
            y = batch[1]
            # Put predictions on the targets' device so the loss never mixes
            # devices, independent of any notebook-global device.
            out = torch.as_tensor(
                model.predict(x), dtype=torch.float32, device=y.device
            )
            loss = criterion(out, y)
            total_loss += loss
            count += 1

    loss = total_loss / count
    return loss

In [15]:
from skorch import NeuralNet

In [16]:
class CustomNeuralNet(NeuralNet):
    """``NeuralNet`` whose fit loop iterates the notebook's prebuilt loaders."""

    def fit_loop(self, X, y=None, epochs=None, **fit_params):
        """Run the train/validation epochs on the notebook-level data loaders.

        Unlike a loop that builds its iterators from ``X``/``y``, this override
        only validates ``X`` and ``y`` and then trains on the global
        ``dataloader`` and validates on the global ``test_dataloader``.  Both
        loaders' batch samplers are pointed at ``self.module_`` first, so they
        can set the matching ``edge_index`` on the network before each batch.

        NOTE(review): batches come from those prebuilt loaders, so their
        samplers' batch size applies; this net's own ``batch_size`` does not
        appear to be consulted here -- confirm.

        Parameters
        ----------
        X : input data
            Passed to ``self.check_data`` only.
        y : target data or None
            Passed to ``self.check_data`` only.
        epochs : int or None (default=None)
            If int, train for this number of epochs; if None, use
            ``self.max_epochs``.
        **fit_params : dict
            Forwarded to every ``self.run_single_epoch`` call.

        Returns
        -------
        self
        """
        self.check_data(X, y)
        self.check_training_readiness()
        if epochs is None:
            epochs = self.max_epochs

        # Datasets reported to the epoch begin/end notifications.
        epoch_datasets = dict(
            dataset_train=datasets,
            dataset_valid=test_concat_datasets,
        )

        # Bind this net's module to both samplers (train first, then test).
        for loader in (dataloader, test_dataloader):
            loader.batch_sampler.module = self.module_

        remaining = epochs
        while remaining > 0:
            self.notify("on_epoch_begin", **epoch_datasets)

            self.run_single_epoch(
                dataloader,
                training=True,
                prefix="train",
                step_fn=self.train_step,
                **fit_params
            )
            self.run_single_epoch(
                test_dataloader,
                training=False,
                prefix="valid",
                step_fn=self.validation_step,
                **fit_params
            )

            self.notify("on_epoch_end", **epoch_datasets)
            remaining -= 1

        return self

In [17]:
# dataset = dataset_rr_n4

In [18]:
# net = CustomNeuralNet(
#     VanillaGNN,
#     train_split=None,
#     device=device,
#     lr=0.01,
#     batch_size=32,
#     max_epochs=10,
#     criterion=torch.nn.MSELoss,
#     optimizer=torch.optim.Adam,
#     optimizer__weight_decay=0.01,
#     module__dim_in=1,
#     module__dim_h=32,
#     module__dim_out=1,
# )

In [19]:
# net.fit(datasets, y=None)

In [20]:
# net.history

In [21]:
# net.history.to_file('history')

In [22]:
# from sklearn.model_selection import GridSearchCV

# params = {
#     "lr": [0.01],
#     "max_epochs": [5, 10],
#     "module__dim_in": [1],
#     "module__dim_h": [32],
#     "module__dim_out": [1],
# }

# gs = GridSearchCV(net, params, cv=3, scoring='neg_mean_squared_error')

# gs.fit(datasets, y=None)

# gs.best_params_

In [23]:
# get_test_loss(net)

In [24]:
# Hyper-parameter grid: each key maps to the list of candidate values that the
# search below combines exhaustively.
params = dict(
    lr=[0.01],
    batch_size=[32],
    max_epochs=[10],
    optimizer=[torch.optim.SGD, torch.optim.Adam],
    module__dim_h=[16, 32, 64],
    optimizer__weight_decay=[0.01],
)

In [None]:
import itertools

# Cartesian product of every hyper-parameter's candidate values, following the
# key order of `params`.
param_combinations = list(itertools.product(*params.values()))

# Re-attach the parameter names so each combination can be passed as keyword
# arguments.
param_combinations_dict = [
    dict(zip(params, values)) for values in param_combinations
]

# param_combinations_dict

[{'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.sgd.SGD,
  'module__dim_h': 16,
  'optimizer__weight_decay': 0.01},
 {'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.sgd.SGD,
  'module__dim_h': 32,
  'optimizer__weight_decay': 0.01},
 {'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.sgd.SGD,
  'module__dim_h': 64,
  'optimizer__weight_decay': 0.01},
 {'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.adam.Adam,
  'module__dim_h': 16,
  'optimizer__weight_decay': 0.01},
 {'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.adam.Adam,
  'module__dim_h': 32,
  'optimizer__weight_decay': 0.01},
 {'lr': 0.01,
  'batch_size': 32,
  'max_epochs': 10,
  'optimizer': torch.optim.adam.Adam,
  'module__dim_h': 64,
  'optimizer__weight_decay': 0.01}]

In [None]:
# Manual grid search: fit one net per hyper-parameter combination and report
# its test loss averaged over `avg_for` independent fits.
avg_for = 1
# The loop variable is deliberately NOT called `params`: that name holds the
# hyper-parameter grid, and overwriting it would break a re-run of the cell
# that expands the grid into combinations.
for combo in param_combinations_dict:
    avg_loss = 0.0
    for _ in range(avg_for):
        # NOTE(review): `batch_size` from the grid is handed to the net, but
        # CustomNeuralNet.fit_loop iterates the prebuilt `dataloader`, whose
        # sampler was created with the notebook-level `batch_size` -- confirm
        # the grid value actually takes effect during training.
        net = CustomNeuralNet(
            VanillaGNN,
            train_split=None,
            device=device,
            criterion=torch.nn.MSELoss,
            module__dim_in=1,
            module__dim_out=1,
            **combo
        )
        net.fit(datasets, y=None)
        avg_loss += get_test_loss(net)
    print(combo, avg_loss / avg_for)

  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        [36m1.2959[0m        [32m1.2630[0m  29.8525
      2        [36m1.2121[0m        [32m1.2458[0m  29.5186
      3        [36m1.2078[0m        [32m1.2412[0m  29.4248
      4        [36m1.2067[0m        [32m1.2397[0m  29.4076
      5        [36m1.2061[0m        [32m1.2387[0m  29.3968
      6        [36m1.2057[0m        [32m1.2379[0m  29.4178
      7        [36m1.2054[0m        [32m1.2364[0m  29.4383
      8        [36m1.2051[0m        [32m1.2360[0m  29.4640
      9        [36m1.2049[0m        [32m1.2355[0m  29.4576
     10        [36m1.2048[0m        [32m1.2354[0m  29.4676
{'lr': 0.01, 'batch_size': 32, 'max_epochs': 10, 'optimizer': <class 'torch.optim.sgd.SGD'>, 'module__dim_h': 16, 'optimizer__weight_decay': 0.01} tensor(1.2394, device='cuda:0')
  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1    