In [1]:
import torch
import torch.nn as nn

from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import ConcatDataset, DataLoader, random_split, Sampler

In [2]:
from helpers import  CVFConfigForGCNWSuccLSTMDataset, CVFConfigForGCNWSuccLSTMWNormalizationDataset

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
dataset_s_n7 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "star_graph_n7_config_rank_dataset.csv",
    "star_graph_n7_edge_index.json",
)

# dataset_s_n13 = CVFConfigForGCNWSuccLSTMDataset(
#     device,
#     "star_graph_n13_config_rank_dataset.csv",
#     "star_graph_n13_edge_index.json",
# )

# dataset_s_n15 = CVFConfigForGCNWSuccLSTMDataset(
#     device,
#     "star_graph_n15_config_rank_dataset.csv",
#     "star_graph_n15_edge_index.json",
# )

dataset_rr_n7 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n7_d4_edge_index.json",
)

# dataset_rr_n8 = CVFConfigForGCNWSuccLSTMDataset(
#     device,
#     "graph_random_regular_graph_n8_d4_config_rank_dataset.csv",
#     "graph_random_regular_graph_n8_d4_edge_index.json",
# )

dataset_plc_n7 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "graph_powerlaw_cluster_graph_n7_config_rank_dataset.csv",
    "graph_powerlaw_cluster_graph_n7_edge_index.json",
)

# dataset_plc_n9 = CVFConfigForGCNWSuccLSTMDataset(
#     device,
#     "graph_powerlaw_cluster_graph_n9_config_rank_dataset.csv",
#     "graph_powerlaw_cluster_graph_n9_edge_index.json",
# )

dataset_implict_n5 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "implicit_graph_n5_config_rank_dataset.csv",
    "implicit_graph_n5_edge_index.json",
    "dijkstra",
)

# dataset_implict_n10 = CVFConfigForGCNWSuccWEIDataset(
#     device,
#     "implicit_graph_n10_config_rank_dataset.csv",
#     "implicit_graph_n10_edge_index.json",
#     "dijkstra",
# )

dataset_implict_n7 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "implicit_graph_n7_config_rank_dataset.csv",
    "implicit_graph_n7_edge_index.json",
    "dijkstra",
)

batch_size = 256

dataset_coll = [
    dataset_implict_n5,
    # dataset_implict_n10,
    # dataset_implict_n7,
    # dataset_s_n7,
    # # dataset_s_n13,
    # # dataset_s_n15,
    # dataset_rr_n7,
    # # dataset_rr_n8,
    # dataset_plc_n7,
    # # dataset_plc_n9,
]

In [5]:
train_sizes = [int(0.95 * len(ds)) for ds in dataset_coll]
test_sizes = [len(ds) - trs for ds, trs in zip(dataset_coll, train_sizes)]

train_test_datasets = [
    random_split(ds, [tr_s, ts])
    for ds, tr_s, ts in zip(dataset_coll, train_sizes, test_sizes)
]

train_datasets = [ds[0] for ds in train_test_datasets]
test_datasets = [ds[1] for ds in train_test_datasets]

In [6]:
datasets = ConcatDataset(train_datasets)
print(f"Train Dataset size: {len(datasets):,}")

Train Dataset size: 230


In [7]:
class CustomBatchSampler(Sampler):
    def __init__(self, datasets: ConcatDataset, batch_size: int):
        self.datasets = datasets
        self.batch_size = batch_size

    def __iter__(self):
        last_accessed = [0] + self.datasets.cumulative_sizes[:]
        end_loop = [False for _ in range(len(self.datasets.datasets))]

        while not all(end_loop):
            for turn in range(len(self.datasets.datasets)):
                if end_loop[turn]:
                    continue

                batch_size = self.batch_size
                if (
                    last_accessed[turn] + batch_size
                    >= self.datasets.cumulative_sizes[turn]
                ):
                    batch_size = (
                        self.datasets.cumulative_sizes[turn] - last_accessed[turn]
                    )
                    end_loop[turn] = True

                yield list(range(last_accessed[turn], last_accessed[turn] + batch_size))

                last_accessed[turn] += batch_size

In [8]:
# from normalization import compute_mean_std, NormalizeTransform

# loader = DataLoader(datasets, batch_sampler=CustomBatchSampler(datasets, batch_size=1024))
# mean, std = compute_mean_std(loader)
# print(mean, std)


In [9]:

# transform = NormalizeTransform(mean, std)
# for dataset in train_datasets:
#     dataset.dataset.set_transform(transform)

In [10]:
batch_sampler = CustomBatchSampler(datasets, batch_size=batch_size)
dataloader = DataLoader(datasets, batch_sampler=batch_sampler)

In [11]:
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.GRU(input_size, hidden_size, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        output, _ = self.lstm(x)
        output = self.norm(output)
        output = self.h2o(output)
        output = torch.relu(output)
        output = global_mean_pool(output, torch.zeros(output.size(1)).to(device).long())
        return output

    def fit(self, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.0001)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch in dataloader:
                x = batch[0]
                x = x[0]
                y = batch[1]
                y = y.unsqueeze(-1)
                out = self(x)
                optimizer.zero_grad()
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print(
                "Training set | Epoch",
                epoch,
                "| MSE Loss:",
                round((total_loss / count).item(), 4),
            )

In [12]:
D = 3
H = 64

model = SimpleLSTM(D, H, 1).to(device)
print(model)
print()
print("Total parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print()
model.fit(epochs=1000)

SimpleLSTM(
  (lstm): GRU(3, 64, batch_first=True)
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (h2o): Linear(in_features=64, out_features=1, bias=True)
)

Total parameters: 13,441

Training set | Epoch 1 | MSE Loss: 68.7478
Training set | Epoch 2 | MSE Loss: 68.7478
Training set | Epoch 3 | MSE Loss: 68.7478
Training set | Epoch 4 | MSE Loss: 68.7467
Training set | Epoch 5 | MSE Loss: 60.1434
Training set | Epoch 6 | MSE Loss: 48.1817
Training set | Epoch 7 | MSE Loss: 37.9102
Training set | Epoch 8 | MSE Loss: 31.2793
Training set | Epoch 9 | MSE Loss: 27.1112
Training set | Epoch 10 | MSE Loss: 24.1874
Training set | Epoch 11 | MSE Loss: 21.8355
Training set | Epoch 12 | MSE Loss: 19.7158
Training set | Epoch 13 | MSE Loss: 17.6558
Training set | Epoch 14 | MSE Loss: 15.6163
Training set | Epoch 15 | MSE Loss: 13.7448
Training set | Epoch 16 | MSE Loss: 12.4157
Training set | Epoch 17 | MSE Loss: 11.6562
Training set | Epoch 18 | MSE Loss: 11.2343
Training set | 

# Testing

In [13]:
import csv
import datetime

f = open(f"test_results/test_result_w_succ_diff_nodes_lstm_{datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")}.csv", "w", newline="")
csv_writer = csv.writer(f)
csv_writer.writerow(["Dataset", "Actual", "Predicted"])

criterion = torch.nn.MSELoss()

model.eval()

with torch.no_grad():
    test_concat_datasets = ConcatDataset(test_datasets)
    test_batch_sampler = CustomBatchSampler(test_concat_datasets, batch_size=10240)
    test_dataloader = DataLoader(test_concat_datasets, batch_sampler=test_batch_sampler)

    total_loss = 0
    total_matched = 0
    count = 0
    for batch in test_dataloader:
        x = batch[0]
        x = x[0]
        y = batch[1]
        y = y.unsqueeze(-1)
        out = model(x)
        csv_writer.writerows(
            (i, j.item(), k.item())
            for (i, j, k) in zip(
                 batch[0][1], y.detach().cpu().numpy(), out.detach().cpu().numpy()
            )
        )
        loss = criterion(out, y)
        total_loss += loss
        out = torch.round(out)
        matched = (out == y).sum().item()
        total_matched += matched
        count += 1

    print(
        "Test set",
        "| MSE loss:",
        round((total_loss / count).item(), 4),
        "| Total matched",
        total_matched,
        "out of",
        len(test_concat_datasets),
        f"(Accuracy: {round(total_matched/len(test_concat_datasets) * 100, 2)}%)",
    )

f.close()

Test set | MSE loss: 6.4735 | Total matched 9 out of 13 (Accuracy: 69.23%)


# Testing with Untrained Datasets

In [14]:
dataset_s_n13 = CVFConfigForGCNWSuccLSTMDataset(
    device,
    "star_graph_n13_config_rank_dataset.csv",
    "star_graph_n13_edge_index.json",
)


dataset = dataset_s_n13

# dataset.set_transform(transform)
criterion = torch.nn.MSELoss()

with torch.no_grad():
    test_dataloader = DataLoader(dataset, batch_size=10240)
    total_loss = 0
    total_matched = 0
    count = 0
    for batch in test_dataloader:
        x = batch[0]
        y = batch[1]
        y = y.unsqueeze(-1)
        out = model(x[0])
        loss = criterion(out, y)
        total_loss += loss
        out = torch.round(out)
        matched = (out == y).sum().item()
        total_matched += matched
        count += 1

    print(
        "Test set",
        f"| {dataset.dataset_name}",
        "| MSE loss:",
        round((total_loss / count).item(), 4),
        "| Total matched",
        f"{total_matched:,}",
        "out of",
        f"{len(dataset):,}",
        f"(Accuracy: {round(total_matched/len(dataset) * 100, 2)}%)",
    )

Test set | star_graph_n13 | MSE loss: 6.7323 | Total matched 18,052 out of 53,248 (Accuracy: 33.9%)
