In [1]:
import torch

from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool import global_mean_pool
from torch.utils.data import ConcatDataset, Sampler, DataLoader, random_split

In [2]:
from models_by_hand import GCNConvByHand
from helpers import CVFConfigForGCNWSuccWEIDataset

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
# One dataset per n=7 graph topology: each pairs a configuration/rank CSV with
# the JSON edge index of the graph it was generated on. They are built in the
# order star -> random-regular (d=4) -> powerlaw-cluster.
dataset_s_n7, dataset_rr_n7, dataset_plc_n7 = (
    CVFConfigForGCNWSuccWEIDataset(device, rank_csv, edge_index_json)
    for rank_csv, edge_index_json in (
        (
            "star_graph_n7_config_rank_dataset.csv",
            "star_graph_n7_edge_index.json",
        ),
        (
            "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
            "graph_random_regular_graph_n7_d4_edge_index.json",
        ),
        (
            "graph_powerlaw_cluster_graph_n7_config_rank_dataset.csv",
            "graph_powerlaw_cluster_graph_n7_edge_index.json",
        ),
    )
)


batch_size = 64

# All per-graph datasets, in the same order as above; split and merged later.
dataset_coll = [dataset_s_n7, dataset_rr_n7, dataset_plc_n7]

In [5]:
# Hold out (1 - TRAIN_FRACTION) of every graph's samples for testing.
# A fixed-seed generator makes the split reproducible. With the global RNG,
# re-running this cell after training would reshuffle the partition and leak
# samples the model was trained on into the test set.
TRAIN_FRACTION = 0.95
SPLIT_SEED = 42

split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_sizes = [int(TRAIN_FRACTION * len(ds)) for ds in dataset_coll]
test_sizes = [len(ds) - trs for ds, trs in zip(dataset_coll, train_sizes)]

train_test_datasets = [
    random_split(ds, [tr_s, ts], generator=split_generator)
    for ds, tr_s, ts in zip(dataset_coll, train_sizes, test_sizes)
]

# random_split returns [train_subset, test_subset] for each source dataset.
train_datasets = [split[0] for split in train_test_datasets]
test_datasets = [split[1] for split in train_test_datasets]

In [6]:
# Merge the per-graph training subsets into one dataset for joint training,
# then report how many training samples that yields.
datasets = ConcatDataset(datasets=train_datasets)
print(len(datasets))

100293


In [7]:
# Reshuffled every epoch; consumed by VanillaGNN.fit as its default loader.
dataloader = DataLoader(
    dataset=datasets,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class VanillaGNN(torch.nn.Module):
    """Two-layer GCN regressor with mean pooling.

    Input features pass through two ``GCNConvByHand`` layers (ReLU after
    each), a linear output head and a final ReLU. The result is then averaged
    with ``global_mean_pool`` over dim -2, assumed here to be the node axis.

    Args:
        dim_in: number of input features per node.
        dim_h: hidden width of both GCN layers.
        dim_out: width of the output head.
    """

    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConvByHand(dim_in, dim_h, device=device)
        self.gcn2 = GCNConvByHand(dim_h, dim_h, device=device)
        # NOTE(review): ``linear1`` is never used in ``forward``. It is kept so
        # existing state_dict keys still load, but its parameters get no
        # gradient and are still counted in the parameter total.
        self.linear1 = torch.nn.Linear(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)
        # Initialised so reading ``edge_index`` before it is assigned returns
        # None instead of raising AttributeError.
        self._edge_index = None

    @property
    def edge_index(self):
        """Edge index stored on the model; None until assigned."""
        return self._edge_index

    @edge_index.setter
    def edge_index(self, val):
        self._edge_index = val

    def forward(self, x, A):
        """Run both GCN layers and the head, then mean-pool over dim -2.

        Args:
            x: node features. Assumed (batch, num_nodes, dim_in), as in the
                printed batches; TODO confirm for all datasets.
            A: graph structure, passed unchanged to both GCN layers.

        Returns:
            ``global_mean_pool`` of the ReLU-ed head output. For 3-D input,
            the pooled axis is kept with size 1, i.e. (batch, 1, dim_out).
        """
        h = torch.relu(self.gcn1(x, A))
        h = torch.relu(self.gcn2(h, A))
        h = torch.relu(self.out(h))
        # For multi-dimensional input, global_mean_pool reduces along dim -2,
        # so the all-zero "batch" vector must have that axis's length. It is
        # built on h's own device rather than the notebook-global ``device``.
        node_batch = torch.zeros(h.size(-2), dtype=torch.long, device=h.device)
        # NOTE(review): callers compare this output against ``y.unsqueeze(-1)``.
        # If y is (batch,), that is (batch, 1) vs (batch, 1, 1), which would
        # broadcast in the loss; confirm the shapes match.
        return global_mean_pool(h, node_batch)

    def fit(self, epochs, loader=None, lr=0.01, weight_decay=0.0001):
        """Train with Adam on an MSE objective, printing per-epoch metrics.

        After each epoch this prints the mean per-batch MSE and the mean
        per-batch relative error, sum(|pred - target|) / sum(target).

        Args:
            epochs: number of passes over ``loader``.
            loader: iterable of batches where ``batch[0]`` is a
                ``(features, A)`` pair and ``batch[1]`` holds the targets.
                Defaults to the notebook-level ``dataloader``.
            lr: Adam learning rate.
            weight_decay: Adam weight decay (L2 penalty).

        Raises:
            ValueError: if ``loader`` yields no batches in an epoch.
        """
        if loader is None:
            loader = dataloader
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            self.parameters(), lr=lr, weight_decay=weight_decay
        )
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            relative_loss = 0
            count = 0
            for batch in loader:
                x = batch[0]
                y = batch[1].unsqueeze(-1)
                optimizer.zero_grad()
                out = self(x[0], x[1])
                loss = criterion(out, y)

                # Accumulate detached tensors. Summing the live ``loss``/``out``
                # would keep every batch's autograd graph (and activations)
                # alive until the end of the epoch.
                denom = torch.sum(y)
                if denom == 0:
                    denom = 0.00001  # avoid dividing by zero for all-zero targets
                relative_loss += torch.sum(torch.abs(out.detach() - y)) / denom
                total_loss += loss.detach()
                count += 1

                loss.backward()
                optimizer.step()

            if count == 0:
                raise ValueError("fit(): loader yielded no batches")

            print(
                "Training set | Epoch",
                epoch,
                "| MSE Loss:",
                round((total_loss / count).item(), 4),
                "| Relative Loss:",
                round((relative_loss / count).item(), 4),
            )

In [None]:
# dim_in=3 matches the three feature columns seen in the printed batches;
# 64 hidden units; a single regression output.
gnn = VanillaGNN(dim_in=3, dim_h=64, dim_out=1).to(device)
n_params = sum(param.numel() for param in gnn.parameters())

print(gnn, end="\n\n")
print("Total parameters:", n_params, end="\n\n")

gnn.fit(epochs=1)

VanillaGNN(
  (gcn1): GCNConvByHand(
    (linear): Linear(in_features=3, out_features=16, bias=True)
  )
  (gcn2): GCNConvByHand(
    (linear): Linear(in_features=16, out_features=16, bias=True)
  )
  (linear1): Linear(in_features=16, out_features=16, bias=True)
  (out): Linear(in_features=16, out_features=1, bias=True)
)

Total parameters: 625

tensor([[[2.0000, 1.3333, 5.7143],
         [4.0000, 4.0000, 5.7143],
         [2.0000, 1.6667, 5.7143],
         [0.0000, 0.0000, 5.7143],
         [1.0000, 1.0000, 5.7143],
         [2.0000, 2.3333, 5.7143],
         [3.0000, 3.0000, 5.7143]],

        [[1.0000, 1.2000, 7.4286],
         [0.0000, 0.4000, 7.4286],
         [4.0000, 4.0000, 7.4286],
         [3.0000, 3.0000, 7.4286],
         [0.0000, 0.4000, 7.4286],
         [0.0000, 0.2000, 7.4286],
         [1.0000, 1.2000, 7.4286]],

        [[1.0000, 0.8000, 8.8571],
         [4.0000, 3.2000, 8.8571],
         [2.0000, 2.0000, 8.8571],
         [0.0000, 0.0000, 8.8571],
         [1.0000, 

KeyboardInterrupt: 

In [None]:
# testing: score the trained model on the held-out split and dump
# actual/predicted pairs to CSV.
import csv

criterion = torch.nn.MSELoss()

gnn.eval()

test_concat_datasets = ConcatDataset(test_datasets)
assert len(test_concat_datasets) > 0, "test split is empty"
# The whole held-out set is scored as a single batch.
test_dataloader = DataLoader(
    test_concat_datasets, batch_size=len(test_concat_datasets)
)

# ``with`` closes (and flushes) the CSV even if evaluation raises part-way.
with open("test_result_w_succ_same_nodes.csv", "w", newline="") as f, torch.no_grad():
    csv_writer = csv.writer(f)
    csv_writer.writerow(["Dataset", "Actual", "Predicted"])

    total_loss = 0
    relative_loss = 0
    total_matched = 0
    count = 0
    for batch in test_dataloader:
        x = batch[0]
        y = batch[1].unsqueeze(-1)
        out = gnn(x[0], x[1])
        # Write plain numbers: numpy array rows would be stringified by csv as
        # "[3.]" / "[[2.9]]". Flattening assumes one prediction per sample
        # (dim_out == 1).
        csv_writer.writerows(
            ("", actual, predicted)
            for actual, predicted in zip(
                y.reshape(-1).tolist(), out.reshape(-1).tolist()
            )
        )
        loss = criterion(out, y)
        total_loss += loss
        denom = torch.sum(y)
        if denom == 0:
            denom = 0.00001  # avoid dividing by zero for all-zero targets
        relative_loss += torch.sum(torch.abs(out - y)) / denom
        # Accuracy: a prediction "matches" when it rounds to the target.
        matched = (torch.round(out) == y).sum().item()
        total_matched += matched
        count += 1

print(
    "Test set",
    "| MSE loss:",
    round((total_loss / count).item(), 4),
    "| Relative Loss:",
    round((relative_loss / count).item(), 4),
    "| Total matched",
    total_matched,
    "out of",
    len(test_concat_datasets),
    f"(Accuracy: {round(total_matched/len(test_concat_datasets) * 100, 2)}%)",
)

tensor([[[6.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         ...,
         [1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.5000, 2.4286],
         [1.0000, 1.0000, 2.4286],
         [1.0000, 1.0000, 2.4286],
         ...,
         [0.0000, 0.2500, 2.4286],
         [1.0000, 1.0000, 2.4286],
         [0.0000, 0.2500, 2.4286]],

        [[4.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         ...,
         [1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000]],

        ...,

        [[0.0000, 0.2000, 7.0000],
         [2.0000, 2.0000, 7.0000],
         [2.0000, 2.0000, 7.0000],
         ...,
         [3.0000, 2.6000, 7.0000],
         [3.0000, 2.6000, 7.0000],
         [0.0000, 0.2000, 7.0000]],

        [[1.0000, 1.2000, 6.8571],
         [2.0000, 2.0000, 6.8571],
         [0.