In [1]:
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split

In [2]:
from torch_geometric.nn import GCNConv
from dataset import MessagePassingDataset

In [3]:
device = "cuda"

In [4]:
dataset_tiny = MessagePassingDataset(
    device,
    "tiny_graph_test",
    "tiny_graph_test_config_rank_dataset.csv",
    D=3
)

dataset_s_n7 = MessagePassingDataset(
    device,
    "star_graph_n7",
    "star_graph_n7_config_rank_dataset.csv",
    D=7
)

dataset_rr_n7 = MessagePassingDataset(
    device,
    "graph_random_regular_graph_n7_d4",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    D=7
)

dataset = dataset_tiny

In [5]:
train_size = int(0.95 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

print("Train size", len(train_dataset))
print("Test size", len(test_dataset))

Train size 11
Test size 1


In [6]:
batch_size = 1
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [7]:
epochs = 5

In [8]:
class SimpleMPModel(torch.nn.Module):
    def __init__(self, h=16):
        super().__init__()
        self.gcn1 = GCNConv(1, h, bias=True)
        self.gcn2 = GCNConv(h, dataset.D, bias=True)

    def forward(self, x, A):
        output = self.gcn1(x, A)
        output = F.relu(output)
        output = self.gcn2(output, A)
        return output

    def fit(self, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=0.0001)
        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0
            count = 0
            for batch in dataloader:
                x = batch[0]
                y = batch[1]
                print(x, y)
                # A = dataset.A.repeat(x.shape[0], 1, 1).to(device)
                out = self(x, dataset.edge_index)
                optimizer.zero_grad()
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print(
                "Training set | Epoch",
                epoch,
                "| MSE Loss:",
                round((total_loss / count).item(), 4),
            )

In [9]:
model = SimpleMPModel(32).to(device)
print(model)

SimpleMPModel(
  (gcn1): GCNConv(1, 32)
  (gcn2): GCNConv(32, 3)
)


In [10]:
model.fit(epochs)

tensor([[[0.],
         [1.],
         [1.]]], device='cuda:0') tensor([[[-1., -1., -1.],
         [-1., -1., -1.],
         [-1., -1., -1.]]], device='cuda:0')
tensor([[[1.],
         [0.],
         [0.]]], device='cuda:0') tensor([[[-1., -1., -1.],
         [-1., -1., -1.],
         [-1., -1., -1.]]], device='cuda:0')
tensor([[[0.],
         [1.],
         [0.]]], device='cuda:0') tensor([[[ 2.,  1.,  0.],
         [-1., -1., -1.],
         [ 0.,  1.,  1.]]], device='cuda:0')
tensor([[[2.],
         [1.],
         [1.]]], device='cuda:0') tensor([[[-1., -1., -1.],
         [-1., -1., -1.],
         [-1., -1., -1.]]], device='cuda:0')
tensor([[[0.],
         [0.],
         [0.]]], device='cuda:0') tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]], device='cuda:0')
tensor([[[2.],
         [1.],
         [0.]]], device='cuda:0') tensor([[[-1., -1., -1.],
         [-1., -1., -1.],
         [-1., -1., -1.]]], device='cuda:0')
tensor([[[1.],
         [1.],
         [0.]