### Imports and Setup

In [1]:
# Downloaded libraries
import torch
from torch import nn
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T


# Local files
from dataset_graphs import NNDataset
from models import Trainer_GCN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Debugging Settings
torch.set_printoptions(threshold=12500)

In [3]:
# Constants
TRAINING_SPLIT = 0.8

In [4]:
# Hyperparameters
num_epoch = 5
batch_size = 8

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Data Loading

In [6]:
transform = None
# transform = T.Compose([T.ToUndirected()])

In [7]:
nndataset = NNDataset(root="../", transform=transform)

size = len(nndataset)
train_num = int(size * TRAINING_SPLIT)
test_num = size - train_num

print(
    f"Dataset loaded, {train_num} training samples and {test_num} testing samples")

Dataset loaded, 672 training samples and 168 testing samples


In [8]:
# Preview of the Data

data = nndataset[0]
data

Data(design=[3], edge_index=[2, 480], x=[101, 503], edge_weight=[480, 1], y_node=[101, 1], y_edge=[480, 1], input_mask=[101, 1], num_nodes=101)

In [9]:
data.is_undirected()

False

In [10]:
train_loader = DataLoader(
    dataset=nndataset[:train_num], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    dataset=nndataset[train_num:], batch_size=test_num, shuffle=True)

In [11]:
print("Number of batches:", int(train_num / batch_size))

Number of batches: 84


### Loading the Model

In [12]:
model = Trainer_GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
loss_fn = nn.MSELoss()

In [13]:
model

Trainer_GCN(
  (conv1): GATConv(503, 128, heads=1)
  (conv2): GATConv(128, 128, heads=1)
  (dense_1B): Linear(in_features=128, out_features=1, bias=True)
  (dense_1W): Linear(in_features=128, out_features=1, bias=True)
)

### Training and Evaluation

In [14]:
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for i, data in enumerate(dataloader):
        data.to(device)

        # forward propagation
        out_w, out_b = model(data)
        loss = 0
        loss += loss_fn(out_b, data.y_node)
        loss += loss_fn(out_w, data.y_edge)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print status every n batches
        if i % 35 == 0:
            loss, current = loss.item(), i * batch_size
            print(
                f"Training Loss: {loss:>7f}  [{current:>5d}/{train_num:>5d}]")

In [15]:
def test(dataloader, model, loss_fn):
    model.eval()
    with torch.no_grad():
        data = iter(dataloader).next().to(device)

        # forward propagation
        out_w, out_b = model(data)
        loss = loss_fn(out_b, data.y_node) + loss_fn(out_w, data.y_edge)

        loss = loss.item()
        print(f"Validation Loss: {loss:>7f}")


In [16]:
# Model Training
for epoch in range(num_epoch):
    print(f"Epoch {epoch + 1} / {num_epoch}:")
    train(train_loader, model, loss_fn, optimizer)
    test(test_loader, model, loss_fn)

Epoch 1 / 5:
Training Loss: 0.335180  [    0/  672]
Training Loss: 0.176350  [  280/  672]
Training Loss: 0.151065  [  560/  672]
Validation Loss: 0.135761
Epoch 2 / 5:
Training Loss: 0.163674  [    0/  672]
Training Loss: 0.154293  [  280/  672]
Training Loss: 0.141995  [  560/  672]
Validation Loss: 0.141383
Epoch 3 / 5:
Training Loss: 0.150922  [    0/  672]
Training Loss: 0.204343  [  280/  672]
Training Loss: 0.186679  [  560/  672]
Validation Loss: 0.122192
Epoch 4 / 5:
Training Loss: 0.177199  [    0/  672]
Training Loss: 0.149488  [  280/  672]
Training Loss: 0.187349  [  560/  672]
Validation Loss: 0.129359
Epoch 5 / 5:
Training Loss: 0.133983  [    0/  672]
Training Loss: 0.155119  [  280/  672]
Training Loss: 0.123437  [  560/  672]
Validation Loss: 0.118770


### Comparison Using One Instance of Data

In [17]:
data = nndataset[0]
data = data.to(device)
data

Data(design=[3], edge_index=[2, 480], x=[101, 503], edge_weight=[480, 1], y_node=[101, 1], y_edge=[480, 1], input_mask=[101, 1], num_nodes=101)

In [18]:
out_w, out_b = model(data)

In [19]:
out_w

tensor([[0.0687],
        [0.0846],
        [0.0982],
        [0.0769],
        [0.0726],
        [0.0891],
        [0.0833],
        [0.0696],
        [0.0821],
        [0.0983],
        [0.0883],
        [0.0723],
        [0.0846],
        [0.0756],
        [0.0711],
        [0.0864],
        [0.0695],
        [0.0969],
        [0.0780],
        [0.0768],
        [0.0984],
        [0.0980],
        [0.0852],
        [0.0719],
        [0.0918],
        [0.0758],
        [0.0988],
        [0.0725],
        [0.0764],
        [0.0728],
        [0.0980],
        [0.0827],
        [0.0818],
        [0.0842],
        [0.0753],
        [0.0756],
        [0.0742],
        [0.0689],
        [0.0879],
        [0.0809],
        [0.0718],
        [0.0901],
        [0.0784],
        [0.0963],
        [0.0843],
        [0.1008],
        [0.0976],
        [0.0991],
        [0.1008],
        [0.0953],
        [0.0745],
        [0.0810],
        [0.0860],
        [0.0986],
        [0.0920],
        [0

In [20]:
data.y_edge

tensor([[ 2.7527e-01],
        [ 3.7460e-01],
        [-1.6935e-01],
        [ 2.0094e-01],
        [ 2.4731e-01],
        [ 7.2946e-01],
        [ 7.7207e-01],
        [ 1.5824e-02],
        [-2.5034e-01],
        [-4.8785e-02],
        [-6.3134e-01],
        [ 7.9112e-01],
        [-1.7706e-01],
        [ 3.0932e-01],
        [ 9.2282e-02],
        [ 1.6429e-01],
        [ 3.1813e-01],
        [ 5.8020e-01],
        [ 3.4108e-02],
        [ 8.7037e-01],
        [-4.0597e-01],
        [-1.3940e-01],
        [-4.1489e-02],
        [ 7.2107e-01],
        [-6.0859e-01],
        [-3.5519e-01],
        [ 6.1911e-01],
        [ 8.3324e-02],
        [-4.2840e-01],
        [-2.2025e-01],
        [ 2.9421e-01],
        [ 3.9168e-01],
        [ 6.7508e-01],
        [-4.4235e-01],
        [-6.3572e-01],
        [-6.6365e-01],
        [ 5.1898e-01],
        [-9.0155e-02],
        [-3.5181e-01],
        [-7.2123e-02],
        [ 3.7966e-01],
        [ 7.0443e-01],
        [ 2.2730e-01],
        [-3

In [21]:
torch.save(model.state_dict(), "../model/model")
print("Model saved")

Model saved
