# Construct the training data into Data format

In [1]:
import sys
sys.path.append("../src")

import torch
import pandas as pd
from ast import literal_eval
from torch_geometric.data import Data
from reachability_model_function import build_graph_reachability
from dataset_function import generate_labeled_route_no_sides

In [24]:
dataset = pd.read_csv("../data/reachability_dataset.csv")

In [25]:
for name in dataset.columns:
    dataset[name] = dataset[name].apply(literal_eval)

In [None]:
# graph_list: each graph is a training sample
graph_list = []

for _, row in dataset.iterrows():
    labels = generate_labeled_route_no_sides(row["route"], row["hands"], row["feet"], row["climber"])
    data = build_graph_reachability(row["route"], row["hands"], row["feet"], row["climber"], labels)
    graph_list.append(data)

In [37]:
# Save the graph_list to local
torch.save(graph_list, "graph_list.pt")

# Train the Model

In [2]:
# Load pt file
graph_list = torch.load("graph_list.pt")

  graph_list = torch.load("graph_list.pt")


In [3]:
from torch_geometric.loader import DataLoader
from reachability_model_function import ReachabilityGNN
import torch.nn as nn

# shuffle=True: Shuffle the order of the graph at each epoch to enhance the model's generalisation ability.
loader = DataLoader(graph_list, batch_size=8, shuffle=True)
model = ReachabilityGNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(30):
    total_loss = 0
    model.train()
    for batch in loader:
        out = model(batch)
        loss = loss_fn(out, batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

Epoch 0, Loss: 10200.2411
Epoch 1, Loss: 8529.8699
Epoch 2, Loss: 8369.2437
Epoch 3, Loss: 8352.5424
Epoch 4, Loss: 8339.2210
Epoch 5, Loss: 8335.7312
Epoch 6, Loss: 8337.0926
Epoch 7, Loss: 8331.2692
Epoch 8, Loss: 8331.8796
Epoch 9, Loss: 8328.4643
Epoch 10, Loss: 8329.8843
Epoch 11, Loss: 8325.8382
Epoch 12, Loss: 8326.3527
Epoch 13, Loss: 8325.7759
Epoch 14, Loss: 8327.4013
Epoch 15, Loss: 8326.0678
Epoch 16, Loss: 8325.5547
Epoch 17, Loss: 8323.3064
Epoch 18, Loss: 8323.7609
Epoch 19, Loss: 8321.7250
Epoch 20, Loss: 8325.9250
Epoch 21, Loss: 8322.0709
Epoch 22, Loss: 8321.8647
Epoch 23, Loss: 8326.3309
Epoch 24, Loss: 8322.4069
Epoch 25, Loss: 8323.5809
Epoch 26, Loss: 8319.9038
Epoch 27, Loss: 8319.8520
Epoch 28, Loss: 8320.8607
Epoch 29, Loss: 8320.0152


In [4]:
torch.save(model.state_dict(),"reachability_model.pt")