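# Purpose
This notebook trains the `TrainerGCN` model on the `NNDataset` graph dataset, reports the training and validation loss for each epoch, and saves the trained weights.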

### Imports and Setup

In [None]:
cd ..

In [None]:
# Downloaded libraries
import torch
from torch import nn
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T


# Local files
from datasets.graph_data import NNDataset
from models.graph import TrainerGCN

In [None]:
# Debugging Settings
# Tensors with at most this many elements are printed in full instead of being
# summarized with "...", so the prediction/target cells show every value.
FULL_PRINT_THRESHOLD = 12_500
torch.set_printoptions(threshold=FULL_PRINT_THRESHOLD)

In [None]:
# Constants
# Fraction of the dataset used for training; the remaining samples are held out
# and used for validation after every epoch.
TRAINING_SPLIT: float = 0.8

In [None]:
# Hyperparameters: number of passes over the training set, and number of graphs
# per training mini-batch.
num_epoch, batch_size = 8, 8

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Data Loading

In [None]:
# Optional transform passed to NNDataset; None means the graphs are used exactly
# as stored. Uncomment the line below to convert every graph to undirected form
# instead (compare with the `data.is_undirected()` check further down).
transform = None
# transform = T.Compose([T.ToUndirected()])

In [None]:
nndataset = NNDataset(root="../", transform=transform)

# Split sizes: the first `train_num` samples train, the rest validate.
size = len(nndataset)
train_num = int(size * TRAINING_SPLIT)
test_num = size - train_num

# Both partitions must be non-empty. An empty training split yields no batches,
# and an empty test split makes the validation DataLoader (batch_size=test_num)
# fail later with an unhelpful "batch_size=0" error.
if train_num == 0 or test_num == 0:
    raise ValueError(
        f"Dataset of {size} samples is too small for TRAINING_SPLIT={TRAINING_SPLIT}: "
        f"{train_num} training / {test_num} testing samples")

print(
    f"Dataset loaded, {train_num} training samples and {test_num} testing samples")

In [None]:
# Preview of the Data
# Display the first sample so its attributes (node/edge tensors and targets) can
# be inspected before training.

data = nndataset[0]
data

In [None]:
# Check whether the first sample's edges are stored in both directions; this
# tells you whether the optional ToUndirected transform would change anything.
data.is_undirected()

In [None]:
# Contiguous split: the first `train_num` graphs train, the remainder validate.
# NOTE(review): if NNDataset is stored in a meaningful order, consider shuffling
# it before slicing so both partitions are representative.
train_loader = DataLoader(
    dataset=nndataset[:train_num], batch_size=batch_size, shuffle=True)
# The whole validation set is served as a single batch (batch_size=test_num).
# Shuffling it cannot change the evaluated loss, so keep a fixed, reproducible order.
test_loader = DataLoader(
    dataset=nndataset[train_num:], batch_size=test_num, shuffle=False)

In [None]:
# DataLoader keeps the final partial batch (drop_last defaults to False), so the
# true count is ceil(train_num / batch_size). len(train_loader) reports exactly
# that, whereas int(train_num / batch_size) undercounts by one whenever the
# division is not exact.
print("Number of batches:", len(train_loader))

### Loading the Model

In [None]:
model = TrainerGCN().to(device)

# Mean-squared-error criterion, applied to each of the model's two outputs.
loss_fn = nn.MSELoss()

# Adam over every model parameter, with a small L2 penalty via weight_decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)

In [None]:
# Display the model's module structure.
model

### Training and Evaluation

In [None]:
def train(dataloader, model, loss_fn, optimizer, log_interval=35):
    """Run one training epoch over ``dataloader``.

    Every batch is moved to the module-level ``device``. The model's forward
    pass returns ``(out_w, out_b)``: ``out_b`` is scored against
    ``data.y_node`` and ``out_w`` against ``data.y_edge``, and the two losses
    are summed before backpropagation.

    Args:
        dataloader: DataLoader yielding batched graphs. Its ``dataset`` and
            ``batch_size`` are used for progress reporting.
        model: Network whose forward pass returns ``(out_w, out_b)``.
        loss_fn: Criterion applied to each output/target pair.
        optimizer: Optimizer over ``model``'s parameters.
        log_interval: Print the current batch loss every ``log_interval``
            batches (default 35).

    Returns:
        float: Mean of the per-batch training losses for this epoch, or
        ``nan`` if the loader yielded no batches.
    """
    model.train()
    # Read sizes from the loader itself instead of notebook globals, so the
    # function stays correct for any loader passed in.
    total = len(dataloader.dataset)
    running_loss = 0.0
    num_batches = 0

    for i, data in enumerate(dataloader):
        data = data.to(device)

        # forward propagation
        out_w, out_b = model(data)
        loss = loss_fn(out_b, data.y_node) + loss_fn(out_w, data.y_edge)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1

        # Print status every `log_interval` batches; `current` is the number of
        # samples that preceded this batch.
        if i % log_interval == 0:
            current = i * dataloader.batch_size
            print(
                f"Training Loss: {loss_value:>7f}  [{current:>5d}/{total:>5d}]")

    return running_loss / num_batches if num_batches else float("nan")

In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on every batch of ``dataloader`` and print the loss.

    The loss is computed like in training: ``loss_fn(out_b, data.y_node)``
    plus ``loss_fn(out_w, data.y_edge)``. It is averaged over batches, with
    each batch weighted by its number of graphs, so a smaller final batch does
    not skew the result. When the loader serves the whole set as one batch,
    this equals that batch's loss.

    Args:
        dataloader: DataLoader yielding batched graphs (``num_graphs`` is read
            from each batch).
        model: Network whose forward pass returns ``(out_w, out_b)``.
        loss_fn: Criterion applied to each output/target pair.

    Returns:
        float: The averaged validation loss.

    Raises:
        ValueError: If ``dataloader`` yields no graphs.
    """
    model.eval()
    total_loss = 0.0
    total_graphs = 0
    with torch.no_grad():
        # Iterate with a normal for-loop: DataLoader iterators have no
        # Python-2 style `.next()` method on current PyTorch.
        for data in dataloader:
            data = data.to(device)

            # forward propagation
            out_w, out_b = model(data)
            loss = loss_fn(out_b, data.y_node) + loss_fn(out_w, data.y_edge)

            total_loss += loss.item() * data.num_graphs
            total_graphs += data.num_graphs

    if total_graphs == 0:
        raise ValueError("Validation dataloader yielded no samples")

    loss = total_loss / total_graphs
    print(f"Validation Loss: {loss:>7f}")
    return loss


In [None]:
# Model Training: alternate one training epoch with one validation pass,
# numbering epochs from 1 in the printed header.
for epoch_num in range(1, num_epoch + 1):
    print(f"Epoch {epoch_num} / {num_epoch}:")
    train(train_loader, model, loss_fn, optimizer)
    test(test_loader, model, loss_fn)

### Comparing Predictions with Targets on a Single Sample

In [None]:
# Pick the first sample and move it to the same device as the model so its
# predictions can be compared with its targets below.
data = nndataset[0].to(device)
data

In [None]:
# Show the sample's `design` attribute (dataset-specific field from NNDataset;
# its meaning is defined there, not in this notebook).
data.design

In [None]:
# Inspection-only forward pass. eval() puts any train/eval-dependent layers in
# inference mode, and no_grad() avoids building an autograd graph that would
# never be used (and would show up as grad_fn in the printed tensors).
model.eval()
with torch.no_grad():
    out_w, out_b = model(data)

In [None]:
# Predicted edge outputs for this sample; in training these are scored against
# data.y_edge.
out_w

In [None]:
# Ground-truth edge targets that out_w is trained to match.
data.y_edge

In [None]:
from pathlib import Path

# Persist only the learned parameters (state_dict); restore them later with
# model.load_state_dict(torch.load(model_path)).
model_path = Path("../trained_model/gnnmodel")
# torch.save does not create missing directories, so make sure the target exists.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Model saved")