In [1]:
import random

import torch
from climb_mlp_utils import load_and_preprocess_data, train_climb_generator, train_climb_generator_sequential

# Configuration: data location, compute device, and RNG seed for reproducible runs.
DATA_JSON_PATH = 'data/all-data.json'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed Python's and torch's RNGs so train/val splits, training and sampling in
# generate() repeat across Restart & Run All. torch.manual_seed also seeds CUDA.
# NOTE(review): if climb_mlp_utils draws from numpy's RNG it needs seeding too — TODO confirm.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Sequential model with the trainer's default model type (the saved progress bar reports "Training LSTM").
# Uses sequence-specific names so later cells that rebind train_ds/val_ds/hold_map
# with a different split do not silently replace the data this model was trained on.
seq_train_ds, seq_val_ds, seq_hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.2, sequential=True)
# NOTE(review): unlike train_climb_generator, DEVICE is not passed here — confirm whether
# train_climb_generator_sequential accepts a `device` argument. The saved run took ~22 min.
generator_lstm = train_climb_generator_sequential(seq_train_ds, seq_val_ds, seq_hold_map, num_epochs=5000)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 564 sequences...
2862 training moves estimated with dataset augmentation...
Split: 451 Train / 113 Val


Training LSTM: 100%|█| 5000/5000 [21:54<00:00,  3.80it/s, T_MSE=0.0013, V_MSE=0


In [3]:
# Flat MLP model: data is loaded without `sequential=True`, unlike the LSTM/RNN cells.
# 1. Load data with augmentation into MLP-specific names so the sequential splits
#    used by the other cells are not overwritten.
mlp_train_ds, mlp_val_ds, mlp_hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.2)
# 2. Train the model.
# NOTE(review): the saved progress bar shows 500/500 iterations although num_epochs=2000.
# Either the output is stale or the trainer caps epochs, so re-run to confirm.
generator = train_climb_generator(mlp_train_ds, mlp_val_ds, mlp_hold_map, num_epochs=2000, device=DEVICE)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 554 sequences...
2862 training moves estimated with dataset augmentation...
Split: 443 Train / 111 Val


Training: 100%|██| 500/500 [01:37<00:00,  5.12it/s, T_MSE=0.0110, V_MSE=0.0368]


In [4]:
# Plain-RNN variant of the sequential generator. The cell loads its own sequential
# split so it does not depend on whatever earlier cells left in train_ds/val_ds/hold_map.
# NOTE(review): the saved output reports 400/400 iterations despite num_epochs=2000,
# and DEVICE is not passed. Confirm both against train_climb_generator_sequential.
train_ds, val_ds, hold_map = load_and_preprocess_data(
    DATA_JSON_PATH,
    val_split=0.2,
    sequential=True,
)
generator_rnn = train_climb_generator_sequential(
    train_ds,
    val_ds,
    hold_map,
    model_type="rnn",
    num_epochs=2000,
)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 438 sequences...
2862 training moves estimated with dataset augmentation...
Split: 350 Train / 88 Val


Training RNN: 100%|█| 400/400 [00:50<00:00,  7.94it/s, T_MSE=0.0049, V_MSE=0.00


In [7]:
# Compare sequences produced by the sequential models on a fixed set of hold-id pairs.
# Each pair is passed as the two positional arguments of generate(). Presumably these
# are the two starting holds — TODO confirm against generate().
climbs_to_watch = [
    [178, 178], [206, 139], [193, 193], [176, 187], [148, 140],
    [176, 176], [169, 145], [177, 173], [179, 179], [181, 172],
]

# Models to compare, printed in this order for every pair.
# To include the flat MLP trained above, add `"MLP": generator`.
models_to_compare = {"RNN": generator_rnn, "LSTM": generator_lstm}

for climb in climbs_to_watch:
    print(climb)
    for model_name, model in models_to_compare.items():
        print(f"{model_name} Prediction:", model.generate(*climb))

[178, 178]
RNN Prediction: [(178, 178)]
LSTM Prediction: [(178, 178), (178, 139), (50, 139), (50, 92), (74, 92), (74, 54), (51, 54), (51, 13), (54, 13), (54, 11), (18, 11)]
[206, 139]
RNN Prediction: [(206, 139), (50, 139), (50, 10), (2, 10), (46, 10), (46, 59), (46, 105), (21, 105), (21, 109), (21, 201), (221, 201)]
LSTM Prediction: [(206, 139), (82, 139), (82, 51), (2, 51), (2, 31), (2, 13), (30, 13)]
[193, 193]
RNN Prediction: [(193, 193), (193, 60), (26, 60), (26, 17), (16, 17)]
LSTM Prediction: [(193, 193), (193, 60), (69, 60), (69, 17), (16, 17)]
[176, 187]
RNN Prediction: [(176, 187), (94, 187), (94, 90), (52, 90), (52, 66), (52, 45), (45, 45), (45, 66), (45, 88), (90, 88), (86, 88)]
LSTM Prediction: [(176, 187), (68, 187), (68, 75), (19, 75), (19, 20)]
[148, 140]
RNN Prediction: [(148, 140), (90, 140), (90, 92), (29, 92), (29, 38)]
LSTM Prediction: [(148, 140), (148, 90), (110, 90), (110, 16), (26, 16)]
[176, 176]
RNN Prediction: [(176, 176), (120, 176), (120, 90), (94, 90), (1

In [8]:
# RNN-only pass over the same pairs, reusing `climbs_to_watch` from the comparison cell above.
# Each line shows the input pair followed by the generated sequence.
# NOTE(review): the saved outputs here differ from the comparison cell's RNN outputs
# for the same pairs. Either generate() samples stochastically or generator_rnn was
# retrained in between, so seed and re-run to confirm.
for climb in climbs_to_watch:
    print(climb, generator_rnn.generate(*climb))

178 178
[178, 178] [(178, 178), (178, 139), (178, 120), (153, 120), (153, 144), (153, 110), (82, 110), (69, 110), (50, 110), (162, 110)]
206 139
[206, 139] [(82, 139), (82, 86), (82, 60), (32, 60), (32, 43), (53, 43), (52, 43), (52, 88), (88, 88), (88, 128)]
193 193
[193, 193] [(193, 60), (16, 60), (16, 8), (1, 8)]
176 187
[176, 187] [(16, 187), (110, 187), (110, 51), (110, 19), (50, 19), (50, 86), (82, 86), (69, 86), (136, 86), (146, 86)]
148 140
[148, 140] [(90, 140), (111, 140), (118, 140), (82, 140), (122, 140), (122, 69), (69, 69), (69, 15), (69, 48), (122, 48)]
176 176
[176, 176] [(176, 138), (176, 43), (88, 43), (88, 20), (22, 20), (11, 20), (11, 22), (36, 22), (22, 22)]
169 145
[169, 145] [(107, 145), (68, 145), (68, 19), (68, 29), (68, 1), (68, 16), (53, 16), (68, 16), (144, 16), (144, 120)]
177 173
[177, 173] [(173, 173), (114, 173), (114, 75), (176, 75), (176, 68), (176, 140), (176, 120), (120, 120), (120, 148), (148, 148)]
179 179
[179, 179] [(163, 179), (163, 89), (163, 11

In [8]:
# Ad-hoc probe: LSTM-generated sequence for the pair (195, 179).
generator_lstm.generate(195, 179)

[(195, 179),
 (140, 179),
 (140, 87),
 (89, 87),
 (89, 38),
 (38, 38),
 (38, 8),
 (38, 45),
 (38, 38),
 (45, 38)]

In [8]:
# Ad-hoc probe: LSTM-generated sequence for the pair (201, 200).
generator_lstm.generate(201, 200)

[(201, 200),
 (160, 200),
 (138, 200),
 (138, 190),
 (138, 53),
 (19, 53),
 (19, 9),
 (18, 9)]

In [5]:
# Ad-hoc probe: LSTM-generated sequence for the pair (108, 136).
generator_lstm.generate(108, 136)

[(108, 136), (108, 71), (108, 28), (3, 28), (3, 5), (4, 5), (4, 26), (4, 5)]