In [1]:
import itertools

import torch
import torch.nn as nn

from torch_geometric.nn.pool import global_mean_pool

In [2]:
from helpers_wo_embedding import CVFConfigDataset

from torch.utils.data import DataLoader, random_split

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
# torch.manual_seed(20)
# for same weight re-initialization

In [5]:
dataset_random_regular_n4 = CVFConfigDataset(
    "coloring",
    "graph_random_regular_graph_n4_d3_config_rank_dataset.csv",
    "graph_random_regular_graph_n4_d3_A.json",
    4,
)

dataset_random_regular_n5 = CVFConfigDataset(
    "coloring",
    "graph_random_regular_graph_n5_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n5_d4_A.json",
    5,
)

dataset_random_regular_n6 = CVFConfigDataset(
    "coloring",
    "graph_random_regular_graph_n6_d3_config_rank_dataset.csv",
    "graph_random_regular_graph_n6_d3_A.json",
    6,
)

dataset_random_regular_n7 = CVFConfigDataset(
    "coloring",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n7_d4_A.json",
    7,
)

dataset_random_regular_n8 = CVFConfigDataset(
    "coloring",
    "graph_random_regular_graph_n8_d4_config_rank_dataset.csv",
    "graph_random_regular_graph_n8_d4_A.json",
    8,
)

In [6]:
dataset_coll = [dataset_random_regular_n4, dataset_random_regular_n5, dataset_random_regular_n6, dataset_random_regular_n7, dataset_random_regular_n8]

batch_size = 50

In [7]:
train_loaders = []
test_loaders = []

for dataset in dataset_coll:
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_loaders.append(train_loader)
    test_loaders.append(test_loader)

train_loaders_iter = [iter(i) for i in train_loaders]

In [8]:
def generate_batch():
    end_loop = [False for _ in range(len(train_loaders))]
    while not any(end_loop):
        for di, data_loader in enumerate(train_loaders_iter):
            if end_loop[di]:
                continue
            try:
                batch = next(data_loader)
            except StopIteration:
                end_loop[di] = True
                continue
            yield batch, di

In [9]:
print("Number of batches:", [len(i) for i in train_loaders])

Number of batches: [5, 50, 66, 1250, 6250]


In [10]:
train_loaders[3].dataset.dataset.A

tensor([[0, 1, 1, 0, 1, 0, 1],
        [1, 0, 0, 0, 1, 1, 1],
        [1, 0, 0, 1, 0, 1, 1],
        [0, 0, 1, 0, 1, 1, 1],
        [1, 1, 0, 1, 0, 1, 0],
        [0, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0, 0, 0]])

In [11]:
class GCNConvByHand(nn.Module):
    """maps D x N to D x N"""

    def __init__(self, dim_in):
        super().__init__()
        self.linear = torch.nn.Linear(dim_in, dim_in, bias=True)
        # torch.nn.init.ones_(self.linear.weight)
        # torch.nn.init.ones_(self.linear.bias)
        # torch.nn.init.xavier_uniform_(self.linear.weight) 

    def forward(self, x, A):
        num_nodes = A.shape[0]
        omega_k = self.linear.weight
        beta_k = self.linear.bias.reshape(-1, 1)
        # print("omega_k", omega_k.shape, "beta_k", beta_k)
        # H_k = x
        x = torch.matmul(
            beta_k, torch.reshape(torch.ones(num_nodes).to(device), (1, -1))
        ) + torch.matmul(omega_k, torch.matmul(x, A + torch.eye(num_nodes).to(device)))
        return x

In [12]:
class GCNByHand(nn.Module):
    def __init__(self, N, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConvByHand(in_channels)
        self.conv2 = GCNConvByHand(in_channels)
        self.linear1 = torch.nn.Linear(N, 16, bias=True)
        self.out = torch.nn.Linear(16, out_channels, bias=True)

    def forward(self, x, A):
        x = self.conv1(x, A)
        x = torch.relu(x)
        x = self.conv2(x, A)
        x = torch.relu(x)
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.out(x)
        # x = torch.relu(x)
        # x = torch.sigmoid(x) * 7
        # print("x after output layer\n", x)
        x = global_mean_pool(
            x, torch.zeros(x.size(1)).long().to(device)
        )  # all are from graph 0, single graph
        return x

    def fit(self, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.01)
        dataloaders = itertools.tee(generate_batch(), epochs)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch, _ in dataloaders[epoch - 1]:
                x = batch[0].to(device)
                y = batch[1].to(device)
                y = y.unsqueeze(0).reshape(-1, 1, 1).float()
                optimizer.zero_grad()
                out = self(x, train_loader.dataset.dataset.A.to(device))
                loss = criterion(out, y)
                total_loss += loss
                loss.backward()
                optimizer.step()
                count += len(batch[1])

            print("Training set | Epoch:", epoch, "Loss:", total_loss / count)

    def fit_old(self, train_loader, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.01, weight_decay=0.01
        )  # weight_decay is a L2 regularization parameter
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch in train_loader:
                x = batch[0].to(device)
                y = batch[1].to(device)
                optimizer.zero_grad()
                out = self(x, train_loader.dataset.dataset.A.to(device))
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print("Training set | Epoch:", epoch, "Loss:", total_loss / count)

            # self.eval()
            # with torch.no_grad():
            #     total_loss = 0
            #     count = 0
            #     for batch in validation_loader:
            #         x = batch[0].to(device)
            #         x = add_graph_properties(x)
            #         y = batch[1].to(device)
            #         y = y.unsqueeze(0).reshape(-1, 1, 1).float()
            #         optimizer.zero_grad()
            #         out = self(x)
            #         loss = criterion(out, y)
            #         total_loss += loss
            #         count += 1

            # print("Validatn set | Epoch:", epoch, "Loss:", total_loss / count)
            # print()

In [13]:
num_nodes = 3       # N
num_features = 1     # D
num_labels = 1       # O

model = GCNByHand(N=num_nodes, in_channels=num_features, out_channels=num_labels)
model.to(device)

GCNByHand(
  (conv1): GCNConvByHand(
    (linear): Linear(in_features=1, out_features=1, bias=True)
  )
  (conv2): GCNConvByHand(
    (linear): Linear(in_features=1, out_features=1, bias=True)
  )
  (linear1): Linear(in_features=3, out_features=16, bias=True)
  (out): Linear(in_features=16, out_features=1, bias=True)
)

A = np.array(A)

H_k = np.array(x)
# H_k

# H_k__A = H_k @ A
# H_k__A, H_k__A.shape

omega_0 = np.array(model.conv1.linear.weight.detach().numpy())
beta_0 = np.array(model.conv1.linear.bias.detach().numpy()).reshape((-1, 1))
# print(omega_0, beta_0)


preactivation = beta_0 @ np.ones(num_nodes).reshape((1, -1)) + omega_0 @ H_k @ (
    A + np.identity(num_nodes)
)
# preactivation

out_wt = np.array(model.out.weight.detach().numpy())
out_bias = np.array(model.out.bias.detach().numpy())
preactivation @ out_wt.transpose() + out_bias

In [14]:
model.fit(train_loader, 20)

TypeError: GCNByHand.fit() takes 2 positional arguments but 3 were given

In [None]:
import csv

# testing
torch.no_grad()
# torch.set_printoptions(profile="full")

f = open("test_result.csv", "w", newline='')
csv_writer = csv.writer(f)
csv_writer.writerow(["Actual", "Predicted"])

total_matched = 0
criterion = torch.nn.MSELoss()

total_loss = 0
for batch in test_loader:
    x = batch[0].to(device)
    # x = x.repeat(1, 8, 1)
    y = batch[1].to(device)
    out = model(x, test_loader.dataset.dataset.A.to(device))
    csv_writer.writerows(zip(y.detach().cpu().numpy(), out.detach().cpu().numpy()))
    loss = criterion(out, y)
    out = torch.round(out)
    matched = (out == y).sum().item()
    total_matched += matched
    total_loss += loss
    loss.backward()

f.close()
print(
    "Total matched",
    total_matched,
    "out of",
    len(test_set),
    "| ",
    "Loss:",
    total_loss / len(test_loader),
    "| Accuracy",
    round(total_matched / len(test_set) * 100, 4),
    "%",
)