In [1]:
import torch
from climb_mlp_utils import load_and_preprocess_data, train_climb_generator, train_climb_generator_sequential, ClimbLSTM, ClimbGeneratorSequential

DATA_JSON_PATH = 'data/all-data.json'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
lstm_model = ClimbLSTM()
lstm_model.load_state_dict(torch.load('best_climb_lstm.pth', weights_only=True))
train_ds, val_ds, hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.0, sequential=True)
generator_lstm = ClimbGeneratorSequential(lstm_model, hold_map, 'cpu')

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 816 sequences...
3264 training moves estimated with dataset augmentation...
Split: 816 Train / 0 Val


In [4]:
train_ds, val_ds, hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.01, sequential=True)
generator_lstm = train_climb_generator_sequential(train_ds, val_ds, hold_map, model_type="lstm", num_epochs=5000)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 816 sequences...
3264 training moves estimated with dataset augmentation...
Split: 807 Train / 9 Val


Training LSTM: 100%|█| 5000/5000 [38:42<00:00,  2.15it/s, T_MSE=0.0021, V_MSE=0


In [2]:
# 1. Load Data with Augmentation
train_ds, val_ds, hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.2)
# 2. Train Model
generator = train_climb_generator(train_ds, val_ds, hold_map, num_epochs=5000, device=DEVICE)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 962 sequences...
3848 training moves estimated with dataset augmentation...
Split: 769 Train / 193 Val


Training: 100%|█| 5000/5000 [28:48<00:00,  2.89it/s, T_MSE=0.0131, V_MSE=0.0473


In [4]:
train_ds, val_ds, hold_map = load_and_preprocess_data(DATA_JSON_PATH, val_split=0.2, sequential=True)
generator_rnn = train_climb_generator_sequential(train_ds, val_ds, hold_map, model_type="rnn", num_epochs=2000)

Loading data from data/all-data.json...
Extracted 263 holds...
Extracted 438 sequences...
2862 training moves estimated with dataset augmentation...
Split: 350 Train / 88 Val


Training RNN: 100%|█| 400/400 [00:50<00:00,  7.94it/s, T_MSE=0.0049, V_MSE=0.00


In [3]:
climbs_to_watch = [[178,178],[206,139],[193,193],[176,187],[148,140],[176,176],[169,145],[177,173],[179,179],[181,172]]

for climb in climbs_to_watch:
    print(climb)
    # print("MLP Prediction:",generator.generate(climb[0],climb[1]))
    print("RNN Prediction:",generator_rnn.generate(climb[0],climb[1]))
    print("LSTM Prediction:",generator_lstm.generate(climb[0],climb[1]))

[178, 178]


NameError: name 'generator_rnn' is not defined

In [8]:
climbs_to_watch = [[178,178],[206,139],[193,193],[176,187],[148,140],[176,176],[169,145],[177,173],[179,179],[181,172]]

for climb in climbs_to_watch:
    print(climb[0],climb[1])
    print(climb, generator_rnn.generate(climb[0],climb[1]))

178 178
[178, 178] [(178, 178), (178, 139), (178, 120), (153, 120), (153, 144), (153, 110), (82, 110), (69, 110), (50, 110), (162, 110)]
206 139
[206, 139] [(82, 139), (82, 86), (82, 60), (32, 60), (32, 43), (53, 43), (52, 43), (52, 88), (88, 88), (88, 128)]
193 193
[193, 193] [(193, 60), (16, 60), (16, 8), (1, 8)]
176 187
[176, 187] [(16, 187), (110, 187), (110, 51), (110, 19), (50, 19), (50, 86), (82, 86), (69, 86), (136, 86), (146, 86)]
148 140
[148, 140] [(90, 140), (111, 140), (118, 140), (82, 140), (122, 140), (122, 69), (69, 69), (69, 15), (69, 48), (122, 48)]
176 176
[176, 176] [(176, 138), (176, 43), (88, 43), (88, 20), (22, 20), (11, 20), (11, 22), (36, 22), (22, 22)]
169 145
[169, 145] [(107, 145), (68, 145), (68, 19), (68, 29), (68, 1), (68, 16), (53, 16), (68, 16), (144, 16), (144, 120)]
177 173
[177, 173] [(173, 173), (114, 173), (114, 75), (176, 75), (176, 68), (176, 140), (176, 120), (120, 120), (120, 148), (148, 148)]
179 179
[179, 179] [(163, 179), (163, 89), (163, 11

In [3]:
generator_lstm.generate(211,198,temperature=.01)

[(211, 198), (211, 165), (90, 165), (90, 88), (16, 88), (16, 6)]

In [4]:
generator_lstm.generate(199,159)

[(199, 159),
 (156, 159),
 (85, 159),
 (164, 159),
 (84, 159),
 (84, 66),
 (66, 66),
 (66, 66)]

In [8]:
for i in range(5):
    print(generator_lstm.generate(202,139,temperature=0.02))

[(202, 139), (202, 178), (202, 110), (168, 110), (168, 78), (168, 153), (168, 86), (168, 28), (28, 28), (28, 28)]
[(202, 139), (221, 139), (221, 118), (67, 118), (69, 118), (69, 17), (69, 28), (4, 28), (4, 4)]
[(202, 139), (202, 178), (202, 110), (69, 110), (69, 143), (69, 26), (15, 26), (3, 26), (3, 4)]
[(202, 139), (202, 178), (202, 110), (82, 110), (82, 50), (82, 5), (3, 5)]
[(202, 139), (202, 178), (202, 110), (168, 110), (168, 125), (168, 69), (3, 69), (3, 28), (3, 3)]


In [8]:
generator_lstm.generate(14,14)

[(14, 14), (14, 17), (34, 17), (50, 17), (50, 17)]