In [1]:
import torch

from torch.utils.data import DataLoader, random_split

In [2]:
from models_by_hand import GCNConvByHand
from dataset import MessagePassingDataset

In [3]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Dataset named "star_graph_n7", built from the CSV file below and created
# directly on `device`.
# NOTE(review): D=7 presumably matches the 7 in the name -- confirm against
# MessagePassingDataset.
dataset_s_n7 = MessagePassingDataset(
    device, "star_graph_n7", "star_graph_n7_config_rank_dataset.csv", D=7
)

# Every later cell goes through the generic `dataset` alias, so rebinding it
# here is the single place to switch to another dataset.
dataset = dataset_s_n7

In [5]:
# 80/20 train/test split. The test size is computed as the remainder so the two
# sizes always sum to len(dataset), whatever int() rounds the first one to.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split (0 is an arbitrary fixed seed) so the same items land in the
# train and test sets on every "Restart Kernel -> Run All"; without a generator
# the partition changes from run to run and results cannot be compared.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(0),
)

print("Train size", len(train_dataset))
print("Test size", len(test_dataset))

Train size 358
Test size 90


In [6]:
# Mini-batch loader over the training split only; shuffle=True makes the
# loader draw the training items in a new random order each epoch.
batch_size = 64
dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)

In [7]:
# Number of full passes over the training dataloader (passed to `model.fit`).
epochs = 50

In [None]:
class SimpleMPModel(torch.nn.Module):
    """Learned embedding followed by two ``GCNConvByHand`` layers.

    The class reads notebook-level globals instead of taking them as
    arguments: ``dataset`` (``len(dataset)``, ``dataset.D``, ``dataset.A``),
    ``device`` and, inside :meth:`fit`, ``dataloader``. They must be defined
    before the model is built or trained.
    """

    def __init__(self, embedding_dim=3):
        """Create the embedding table and the two graph-convolution layers.

        Args:
            embedding_dim: Width of the embedding vectors fed to the first
                GCN layer. Defaults to 3.
        """
        super().__init__()
        # The table has len(dataset) rows. Like the GCN layers below, it is
        # placed on `device` so the model is consistent even when the caller
        # forgets the trailing `.to(device)`.
        self.embedding = torch.nn.Embedding(len(dataset), embedding_dim).to(device)
        self.gcn1 = GCNConvByHand(
            embedding_dim, dataset.D, bias=True, device=device
        ).to(device)
        self.gcn2 = GCNConvByHand(
            dataset.D, dataset.D, bias=True, device=device
        ).to(device)

    def forward(self, x, A):
        """Embed ``x`` and run it through both GCN layers using ``A``.

        Args:
            x: Integer index tensor for the embedding lookup. Presumably
                shaped ``(batch, nodes, 1)`` so that ``squeeze(-2)`` drops the
                singleton axis left after the lookup -- TODO confirm against
                ``MessagePassingDataset``.
            A: Graph tensor handed unchanged to both GCN layers (``fit``
                passes ``dataset.A`` repeated once per batch item).

        Returns:
            The output of the second GCN layer.
        """
        x = self.embedding(x)
        # Removes the second-to-last axis when it has size 1 (no-op otherwise).
        x = x.squeeze(-2)
        output = self.gcn1(x, A)
        output = self.gcn2(output, A)
        return output

    def fit(self, epochs, lr=0.01, weight_decay=0.00001):
        """Train on the global ``dataloader`` with Adam and an MSE loss.

        Prints the mean per-batch training loss after every epoch. If the
        dataloader yields no batches, the mean is reported as NaN instead of
        raising ``ZeroDivisionError``.

        Args:
            epochs: Number of passes over ``dataloader``.
            lr: Adam learning rate. Defaults to 0.01.
            weight_decay: Adam weight decay. Defaults to 1e-5.
        """
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            self.parameters(), lr=lr, weight_decay=weight_decay
        )
        # Loop-invariant: move the shared graph tensor to the device once
        # rather than once per batch.
        graph = dataset.A.to(device)

        for epoch in range(1, epochs + 1):
            self.train()
            total_loss = 0.0
            count = 0
            for batch in dataloader:
                x = batch[0]
                y = batch[1]
                # One copy of the graph tensor per item; the final batch may
                # be smaller, so the size is taken from the batch itself.
                A = graph.repeat(x.shape[0], 1, 1)

                optimizer.zero_grad()
                out = self(x, A)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()

                # Accumulate a plain float: adding the loss tensor itself
                # would keep every batch's autograd history referenced until
                # the end of the epoch.
                total_loss += loss.item()
                count += 1

            mean_loss = total_loss / count if count else float("nan")
            print(
                "Training set | Epoch",
                epoch,
                "| MSE Loss:",
                round(mean_loss, 4),
            )

In [9]:
# Build the network, move all of its parameters to `device`, then print the
# module tree as a quick sanity check of the layer sizes.
model = SimpleMPModel()
model = model.to(device)
print(model)

SimpleMPModel(
  (embedding): Embedding(448, 3)
  (gcn1): GCNConvByHand(
    (linear): Linear(in_features=3, out_features=7, bias=True)
  )
  (gcn2): GCNConvByHand(
    (linear): Linear(in_features=7, out_features=7, bias=True)
  )
)


In [10]:
# Train for `epochs` epochs; `fit` prints the mean training MSE after each one.
model.fit(epochs)

Training set | Epoch 1 | MSE Loss: 1.6038
Training set | Epoch 2 | MSE Loss: 1.085
Training set | Epoch 3 | MSE Loss: 0.8564
Training set | Epoch 4 | MSE Loss: 0.7158
Training set | Epoch 5 | MSE Loss: 0.6174
Training set | Epoch 6 | MSE Loss: 0.5501
Training set | Epoch 7 | MSE Loss: 0.4943
Training set | Epoch 8 | MSE Loss: 0.4594
Training set | Epoch 9 | MSE Loss: 0.454
Training set | Epoch 10 | MSE Loss: 0.4459
Training set | Epoch 11 | MSE Loss: 0.4396
Training set | Epoch 12 | MSE Loss: 0.4379
Training set | Epoch 13 | MSE Loss: 0.4342
Training set | Epoch 14 | MSE Loss: 0.4297
Training set | Epoch 15 | MSE Loss: 0.4218
Training set | Epoch 16 | MSE Loss: 0.4239
Training set | Epoch 17 | MSE Loss: 0.42
Training set | Epoch 18 | MSE Loss: 0.4234
Training set | Epoch 19 | MSE Loss: 0.409
Training set | Epoch 20 | MSE Loss: 0.409
Training set | Epoch 21 | MSE Loss: 0.4044
Training set | Epoch 22 | MSE Loss: 0.4094
Training set | Epoch 23 | MSE Loss: 0.3989
Training set | Epoch 24 | 