In [30]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
from dataset import RecoveryRateDataset
import os
from scaling_utils import compute_mean_and_std, apply_std_scaling
from model import LSTMEstimator
from training_loop import model_train

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
root_dir = os.getcwd().split('Start_PyCharm')[0]
data_file_path = root_dir + 'Start_PyCharm/state_vector_dataframes_as_training_data/data/airport_state_df_w_features.pickle'


SEQUENCE_LENGTH = 5

features_to_drop = ['reg_type', 'reg_bool_type', 'reg_cause']
rr_dataset = RecoveryRateDataset(data_file_path, features_to_drop=features_to_drop, sequence_length=SEQUENCE_LENGTH,
                                 fill_with='backfill')


train_val_test_ratios = (0.6, 0.2, 0.2)
train, val, test = torch.utils.data.random_split(rr_dataset, lengths=[int(len(rr_dataset)*ratio)
                                                                      for ratio in train_val_test_ratios])

#Compute mean and std for standard scaling.
training_mean, training_std = compute_mean_and_std(train)

scaled_training_data = apply_std_scaling(train, training_mean, training_std)
scaled_val_data = apply_std_scaling(val, training_mean, training_std)
scaled_test_data = apply_std_scaling(test, training_mean, training_std)


#Training parameters.
BATCH_SIZE = 64
LR = 0.01
N_EPOCHS = 50

training_data_loader = DataLoader(scaled_training_data, batch_size=BATCH_SIZE, shuffle=True)
val_data_loader = DataLoader(scaled_val_data)
test_data_loader = DataLoader(scaled_test_data)

In [31]:
model = LSTMEstimator(len(rr_dataset.feature_names), initial_dense_layer_size=50, dense_parameter_multiplier=2,
                      dense_layer_count=3, lstm_layer_count=3, lstm_hidden_units=50, sequence_length=SEQUENCE_LENGTH)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
lr_scheduler = None # Optional
loss_func = torch.nn.MSELoss()

model_save_path='best_model'

train_history = []
val_history = []

model_train(model, optimizer, loss_func, N_EPOCHS, training_data_loader, val_data_loader, train_history, val_history, model_save_path)

100%|██████████| 1084/1084 [00:11<00:00, 90.98it/s]



 Starting epoch 0, loss: 0.86113457270379, val_loss: 0.01903850957751274 lr= 0.01

 Validation loss inf --> 0.7828313380863088. Saving best model.


100%|██████████| 1084/1084 [00:16<00:00, 65.24it/s]



 Starting epoch 1, loss: 0.8560404142165536, val_loss: 0.016205113381147385 lr= 0.01


100%|██████████| 1084/1084 [00:18<00:00, 57.77it/s]



 Starting epoch 2, loss: 0.8611819699983975, val_loss: 0.0005399854271672666 lr= 0.01


100%|██████████| 1084/1084 [00:18<00:00, 59.55it/s]



 Starting epoch 3, loss: 0.8745986378555808, val_loss: 0.0007412534905597568 lr= 0.01


100%|██████████| 1084/1084 [00:16<00:00, 65.38it/s]



 Starting epoch 4, loss: 0.974227848074423, val_loss: 0.0033417213708162308 lr= 0.01


100%|██████████| 1084/1084 [00:17<00:00, 63.59it/s]



 Starting epoch 5, loss: 0.9646240338215749, val_loss: 0.04323336109519005 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 72.67it/s]



 Starting epoch 6, loss: 0.9709728184633809, val_loss: 0.043556809425354004 lr= 0.01


100%|██████████| 1084/1084 [00:15<00:00, 71.72it/s]



 Starting epoch 7, loss: 0.979104260456958, val_loss: 0.22224202752113342 lr= 0.01


100%|██████████| 1084/1084 [00:15<00:00, 71.20it/s]



 Starting epoch 8, loss: 0.9531946766063076, val_loss: 0.23949721455574036 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 73.65it/s]



 Starting epoch 9, loss: 0.9474601257606187, val_loss: 0.1415201723575592 lr= 0.01


100%|██████████| 1084/1084 [00:13<00:00, 77.51it/s]



 Starting epoch 10, loss: 0.9775061268923028, val_loss: 0.4015291631221771 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 75.45it/s]



 Starting epoch 11, loss: 1.000826184668774, val_loss: 0.5021031498908997 lr= 0.01


100%|██████████| 1084/1084 [00:13<00:00, 77.50it/s]



 Starting epoch 12, loss: 1.0004991268622259, val_loss: 0.032570235431194305 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 73.09it/s]



 Starting epoch 13, loss: 1.0210225213992639, val_loss: 0.10190396010875702 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 77.36it/s]



 Starting epoch 14, loss: 1.0196811017267597, val_loss: 0.08180105686187744 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 75.81it/s]



 Starting epoch 15, loss: 1.0286114615735313, val_loss: 0.20305150747299194 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 74.60it/s]



 Starting epoch 16, loss: 1.0380197776777058, val_loss: 0.3330211937427521 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 76.97it/s]



 Starting epoch 17, loss: 1.0479175308203785, val_loss: 0.40643101930618286 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 74.66it/s]



 Starting epoch 18, loss: 1.0383080220832817, val_loss: 0.12143111228942871 lr= 0.01


100%|██████████| 1084/1084 [00:15<00:00, 72.07it/s]



 Starting epoch 19, loss: 1.0343477584256677, val_loss: 0.0019087089458480477 lr= 0.01


100%|██████████| 1084/1084 [00:13<00:00, 78.37it/s]



 Starting epoch 20, loss: 1.03435538245908, val_loss: 0.09725890308618546 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 75.24it/s]



 Starting epoch 21, loss: 1.0334483247209079, val_loss: 0.08803870528936386 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 77.14it/s]



 Starting epoch 22, loss: 1.0364935061863427, val_loss: 0.02560782991349697 lr= 0.01


100%|██████████| 1084/1084 [00:13<00:00, 77.65it/s]



 Starting epoch 23, loss: 1.1376558558472407, val_loss: 0.03475113585591316 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 74.29it/s]



 Starting epoch 24, loss: 1.1369168285239226, val_loss: 0.04594578966498375 lr= 0.01


100%|██████████| 1084/1084 [00:15<00:00, 71.16it/s]



 Starting epoch 25, loss: 1.1373353656587566, val_loss: 0.03674255311489105 lr= 0.01


100%|██████████| 1084/1084 [00:14<00:00, 74.41it/s]



 Starting epoch 26, loss: 1.1373854179168978, val_loss: 0.04330998659133911 lr= 0.01


100%|██████████| 1084/1084 [00:15<00:00, 70.92it/s]


KeyboardInterrupt: 