In [16]:
from torch_geometric.loader import DataLoader

from data import load_mnist_graph

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from layers import GraphSAGELayer

import torch
import torch.nn as nn
import torch.nn.functional as F

from train import train_mnist, evaluate_mnist

# Use the GPU when torch reports CUDA as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# (Re)load IPython's autoreload extension and switch it to mode 2, so edits to the
# project-local modules imported in this cell (data, layers, train) are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [17]:
# Load the train / test graph datasets (the call requests subset=10000).
train_dataset, test_dataset = load_mnist_graph(subset=10000)

# Mini-batch loaders: the training split is reshuffled, the held-out split keeps its order.
# NOTE(review): `val_loader` wraps the *test* split returned above — confirm that is intended.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=128,
)
val_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=128,
)



In [18]:
class GraphSAGE(nn.Module):
    """Stack of GraphSAGELayer modules followed by a mean-pool readout and a linear head.

    Args:
        input_features: first constructor argument of the first GraphSAGELayer
            (presumably the per-node input feature size — confirm in `layers`).
        hidden_features: second constructor argument of every GraphSAGELayer and
            the input width of the final linear layer.
        output_features: output width of the final linear layer.
        num_layers: total number of GraphSAGELayer modules; one layer is always
            created, plus ``num_layers - 1`` further hidden-to-hidden layers.
    """

    def __init__(self, input_features, hidden_features, output_features, num_layers):
        super().__init__()
        # Input width of each layer, in construction order: the first layer sees the
        # raw features, every later layer sees the hidden width.
        layer_inputs = [input_features] + [hidden_features] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [GraphSAGELayer(width, hidden_features, aggr='sum') for width in layer_inputs]
        )
        self.linear = nn.Linear(hidden_features, output_features)

    def forward(self, x, edge_index, batch):
        """Run every layer, mean-pool with `batch`, and apply the linear head.

        Args:
            x: node features handed to the first layer.
            edge_index: graph connectivity forwarded unchanged to every layer.
            batch: assignment vector passed to `global_mean_pool`.

        Returns:
            The output of the final linear layer applied to the pooled features.
        """
        hidden = x
        for sage_layer in self.layers:
            # NOTE(review): this calls `propagate` directly instead of the layer's
            # `__call__`/`forward` — confirm GraphSAGELayer is meant to be used this way.
            hidden = sage_layer.propagate(hidden, edge_index)

        # Readout: average the node representations grouped by `batch`.
        pooled = global_mean_pool(hidden, batch)
        return self.linear(pooled)
    

In [19]:

### Reproducibility
# Seed torch's global RNG so that re-running this cell repeats the same parameter
# initialisation (and anything else drawn from that RNG during training).
# GPU kernels may still introduce small run-to-run differences.
seed = 0
torch.manual_seed(seed)

### Max number of epochs
max_epochs = 300

### Number of features
n_features, n_classes = 1, 10
hidden_size = 8

### DEFINE THE MODEL
basic_model = GraphSAGE(n_features, hidden_size, n_classes, num_layers=4).to(device)

### DEFINE LOSS FUNCTION
# NOTE(review): BCEWithLogitsLoss treats each of the 10 outputs as an independent
# binary target; presumably train_mnist supplies one-hot float labels — confirm,
# and consider nn.CrossEntropyLoss if the task is single-label.
loss_fcn = nn.BCEWithLogitsLoss()

### DEFINE OPTIMIZER
optimizer = torch.optim.Adam(basic_model.parameters(), lr=0.01)
# NOTE(review): the scheduler is created but not passed to train_mnist below, so
# unless it is stepped somewhere else the learning rate never changes — confirm.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, min_lr=0.0001)

### TRAIN THE MODEL
# train_mnist returns the list of evaluated epochs and the matching scores.
epoch_list, basic_model_scores = train_mnist(
    basic_model,
    loss_fcn,
    device,
    optimizer,
    max_epochs,
    train_loader,
    val_loader,
)

100%|██████████| 300/300 [13:37<00:00,  2.72s/epoch, f1=0.763, loss=0.145]
