In [1]:
import json
import sys
sys.path.append("../src")

from model_function import build_graph_from_sample

# Construct the GNN Data
with open("../data/casuals_feas.json") as f:
    casuals = json.load(f)

# Remove the climbing point <= 6
casuals = [c for c in casuals if c["features"]["n_points"]>6]

graph_list = [build_graph_from_sample(sample) for sample in casuals]

In [6]:
import torch.optim as optim
from model_function import ClimbGNN
import torch.nn as nn

model = ClimbGNN()
# Use Adam optimizer to update the model parameters
optimizer = optim.Adam(model.parameters(),lr=0.01)

# loss function to calculate the loss
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    total_loss = 0
    correct = 0
    for graph in graph_list:
        # Clear last round gradient
        optimizer.zero_grad()
        # forward: get the predicted result
        out = model(graph)
        # calculate the loss between predicted value and real label
        loss = loss_fn(out,graph.y)
        # calculate the gradient
        loss.backward()
        # Use gradient to update the model parameters
        optimizer.step()
        total_loss += loss.item()

        pred = model(graph).argmax(dim=1)
        correct += int(pred == graph.y)

    print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}, Accuracy: {correct / len(graph_list):.4f}")

Epoch 1: Loss = 1794.3580, Accuracy: 0.9649
Epoch 2: Loss = 694.3920, Accuracy: 0.9653
Epoch 3: Loss = 689.1733, Accuracy: 0.9653
Epoch 4: Loss = 687.7087, Accuracy: 0.9653
Epoch 5: Loss = 686.7751, Accuracy: 0.9653
Epoch 6: Loss = 685.5256, Accuracy: 0.9653
Epoch 7: Loss = 683.4602, Accuracy: 0.9653
Epoch 8: Loss = 682.2746, Accuracy: 0.9653
Epoch 9: Loss = 682.2736, Accuracy: 0.9653
Epoch 10: Loss = 682.2736, Accuracy: 0.9653


In [9]:
import torch

# Save the model
torch.save(model.state_dict(),"climb_gnn_model.pth")

# # Load
# model = ClimbGNN()
# model.load_state_dict(torch.load("climb_gnn_model.pth"))
# model.eval()