### Imports and Setup

In [None]:
# Downloaded libraries
import torch
from torch import nn
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T


# Local files
from dataset_graphs import NNDataset
from models import Trainer_GCN

In [None]:
# Fraction of the dataset used for training; everything after the
# training prefix is held out for testing.
TRAINING_SPLIT: float = 0.8

In [None]:
# Hyperparameters: 12 passes over the training split, 8 graphs per mini-batch.
num_epoch, batch_size = 12, 8

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Data Loading

In [None]:
# No transform is passed to NNDataset for now. To symmetrize every graph's
# edges instead, use:
#     transform = T.Compose([T.ToUndirected()])
transform = None

In [None]:
# Build the dataset (root one directory up) and size the two splits.
# int() truncates, so any fractional remainder goes to the test split.
# NOTE(review): the split is taken by index, not at random; if NNDataset is
# stored in a meaningful order, consider shuffling before splitting.
nndataset = NNDataset(root="../", transform=transform)

size = len(nndataset)
train_num = int(TRAINING_SPLIT * size)
test_num = size - train_num

print(f"Dataset loaded, {train_num} training samples and {test_num} testing samples")

In [None]:
# Preview of the Data
# Fetch the first sample and leave it as the cell's last expression so the
# notebook renders its repr.

data = nndataset[0]
data

In [None]:
# Display whether the preview graph is undirected; this helps decide whether
# the commented-out ToUndirected transform is worth enabling.
data.is_undirected()

In [None]:
# Index-based split: the first `train_num` graphs train, the rest are held out.
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset=nndataset[:train_num], batch_size=batch_size, shuffle=True)
# The whole test split is served as a single batch. Evaluation loss does not
# depend on sample order (up to float rounding), so no shuffling is needed.
# DataLoader rejects batch_size <= 0, so clamp to 1 for an empty test split.
test_loader = DataLoader(
    dataset=nndataset[train_num:], batch_size=max(test_num, 1), shuffle=False)

### Building the Model

In [None]:
# Instantiate the network on the chosen device; train it with SGD
# (L2 weight decay 5e-4) against a mean-squared-error objective.
model = Trainer_GCN()
model = model.to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, weight_decay=5e-4)
loss_fn = nn.MSELoss()

In [None]:
# Echo the model so the notebook renders its nn.Module repr (layer summary).
model

### Training and Evaluation

In [None]:
# Model Training
# Mini-batch SGD: for each batch, the model returns (out_w, out_b), which are
# compared with data.y_edge and data.y_node respectively. The two MSE terms
# are summed and one optimizer step is taken.
model.train()
for epoch in range(num_epoch):
    print(f"Epoch {epoch + 1} / {num_epoch}:")
    for i, data in enumerate(train_loader):
        # Keep the returned object. That is correct both for in-place `to`
        # (PyG Data/Batch) and for copying `to` (torch.Tensor).
        data = data.to(device)

        # forward propagation
        out_w, out_b = model(data)
        loss = loss_fn(out_b, data.y_node) + loss_fn(out_w, data.y_edge)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print status every 10 batches; `current` counts samples seen before
        # this batch. Use a separate name so `loss` stays a tensor.
        if i % 10 == 0:
            loss_value, current = loss.item(), i * batch_size
            print(
                f"Training Loss: {loss_value:>7f}  [{current:>5d}/{train_num:>5d}]")

In [None]:
# Model Evaluation
# Validation loss over the whole test split: node MSE + edge MSE.
# Iterate with a for-loop; DataLoader iterators have no `.next()` method in
# current PyTorch, so `iter(...).next()` raises AttributeError. Each batch's
# mean-reduced MSE is re-weighted by its element count, so the result equals
# one full-split MSE however test_loader batches the data.
# NOTE(review): assumes model outputs have the same shapes as their targets
# (no broadcasting inside MSELoss); confirm against Trainer_GCN.
model.eval()
node_sq_err, node_count = 0.0, 0
edge_sq_err, edge_count = 0.0, 0
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)

        # forward propagation
        out_w, out_b = model(data)
        node_numel = data.y_node.numel()
        edge_numel = data.y_edge.numel()
        node_sq_err += loss_fn(out_b, data.y_node).item() * node_numel
        edge_sq_err += loss_fn(out_w, data.y_edge).item() * edge_numel
        node_count += node_numel
        edge_count += edge_numel

if node_count == 0 and edge_count == 0:
    # Empty test split: report it instead of crashing on StopIteration.
    print("Validation Loss: n/a (no test samples)")
else:
    # An empty target set gives NaN, as MSELoss does on empty tensors.
    node_term = node_sq_err / node_count if node_count else float("nan")
    edge_term = edge_sq_err / edge_count if edge_count else float("nan")
    loss = node_term + edge_term
    print(f"Validation Loss: {loss:>7f}")


### Comparing Predictions with Targets on a Single Sample

In [None]:
# Take the first sample, move it next to the model, and echo it.
data = nndataset[0].to(device)
data

In [None]:
# Single-sample inference: use eval mode and skip autograd bookkeeping,
# since no backward pass follows.
model.eval()
with torch.no_grad():
    out_w, out_b = model(data)

In [None]:
# Display the first model output; the training loss pairs it with data.y_edge.
out_w

In [None]:
# Display the matching targets (data.y_edge) for side-by-side comparison with out_w.
data.y_edge

In [None]:
# Persist only the learned parameters (state_dict). torch.save does not create
# missing parent directories, so make sure the target directory exists first.
import os

model_path = "../model/model"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Model saved")