In [5]:
%load_ext autoreload
%autoreload 2

import torch

from torch.utils.data import DataLoader
from dataset import RecoveryRateDataset
import os
from training_utils import compute_mean_and_std, apply_std_scaling, set_seed
from model_lightning import LSTMEstimator
from training_loop import model_train

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

# Single global seed for the whole notebook; the recorded output confirms "Global seed set to 1234".
SEED = 1234
# workers=True presumably also derives seeds for DataLoader worker processes —
# TODO confirm against the installed pytorch_lightning version.
seed_everything(SEED, workers=True)

# NOTE(review): `os` is already imported earlier in this cell; this second import is redundant but harmless.
import os
# The environment variable is set before wandb is imported and initialised.
# Presumably "thread" avoids process-spawning problems inside the notebook kernel — TODO confirm.
os.environ["WANDB_START_METHOD"] = "thread"
import wandb

Global seed set to 1234


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Project root = everything in the current working directory that precedes 'Start_PyCharm'.
# NOTE(review): if the working directory does not contain 'Start_PyCharm', root_dir is the whole
# cwd without a trailing separator and the concatenated path below will not exist.
root_dir = os.getcwd().split('Start_PyCharm')[0]
data_file_path = root_dir + 'Start_PyCharm/state_vector_dataframes_as_training_data/data/airport_state_df_w_features.pickle'


# Sequence length handed to the dataset here and to the model in the training cell.
SEQUENCE_LENGTH = 5

# Columns passed to the dataset as `features_to_drop`.
features_to_drop = ['reg_type', 'reg_bool_type', 'reg_cause']
rr_dataset = RecoveryRateDataset(data_file_path, features_to_drop=features_to_drop, sequence_length=SEQUENCE_LENGTH,
                                 fill_with='backfill')


# Train / validation / test proportions.
train_val_test_ratios = (0.6, 0.2, 0.2)

# random_split requires the lengths to sum to exactly len(dataset). Truncating every ratio with
# int() independently can lose samples (e.g. 7 samples -> 4 + 1 + 1 = 6) and make random_split
# raise ValueError, so the test subset takes whatever remains after train and validation.
n_samples = len(rr_dataset)
n_train = int(n_samples * train_val_test_ratios[0])
n_val = int(n_samples * train_val_test_ratios[1])
n_test = n_samples - n_train - n_val

# A dedicated, seeded generator keeps the partition identical when this cell is re-executed,
# independent of how much of the global RNG stream other cells have consumed.
split_generator = torch.Generator().manual_seed(SEED)
train, val, test = torch.utils.data.random_split(rr_dataset, lengths=[n_train, n_val, n_test],
                                                 generator=split_generator)

# Compute mean and std for standard scaling from the training subset only, then apply the same
# statistics to all three subsets.
training_mean, training_std = compute_mean_and_std(train)

scaled_training_data = apply_std_scaling(train, training_mean, training_std)
scaled_val_data = apply_std_scaling(val, training_mean, training_std)
scaled_test_data = apply_std_scaling(test, training_mean, training_std)

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Training parameters.
BATCH_SIZE = 32
# NOTE(review): LR is only recorded in the wandb config below; nothing in this cell passes it to
# the model or an optimizer — TODO confirm LSTMEstimator actually trains with this learning rate.
LR = 0.001
# Upper bound on training epochs; passed to the Trainer below.
N_EPOCHS = 500

# NOTE(review): the recorded output of this cell contains "can only test a child process"
# assertion messages raised while multi-process DataLoader iterators were being shut down.
# Consider persistent_workers=True or num_workers=0 if they reappear — TODO verify.
training_data_loader = DataLoader(scaled_training_data, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=4)
val_data_loader = DataLoader(scaled_val_data, batch_size=64, num_workers=4)
# Built with DataLoader defaults; not used by the fit() call in this cell.
test_data_loader = DataLoader(scaled_test_data)


# First positional argument is the number of dataset features (presumably the model's input
# width — TODO confirm against LSTMEstimator's signature).
model = LSTMEstimator(len(rr_dataset.feature_names), initial_dense_layer_size=50, dense_parameter_multiplier=2,
                      dense_layer_count=3, lstm_layer_count=3, lstm_hidden_units=50, sequence_length=SEQUENCE_LENGTH)

wandb_instance = wandb.init(project='recovery_rate')

# Hand the run created above to the logger explicitly instead of relying on WandbLogger
# discovering and reusing the already-active global run.
wandb_logger = WandbLogger(project='recovery_rate', log_model="all", experiment=wandb_instance)
# Checkpoint selection is driven by the lowest logged "val_loss".
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

# Record the hyper-parameters and scaling statistics alongside the run.
wandb_logger.experiment.config.update({'batch_size': BATCH_SIZE, 'init_lr': LR, 'scaling_mean': training_mean, 
                     'scaling_std': training_std})

# N_EPOCHS was previously defined but never used, so training ran up to the library's default
# epoch limit; cap it explicitly.
lightning_trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], max_epochs=N_EPOCHS)
lightning_trainer.fit(model, training_data_loader, val_data_loader)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name                      | Type       | Params
---------------------------------------------------------
0 | feature_extracting_layers | Sequential | 25.9 K
1 | lstm_layer                | LSTM       | 91.2 K
2 | estimator_layers          | Sequential | 30.7 K
---------------------------------------------------------
147 K     Trainable params
0         Non-trainable params
147 K     Total params
0.591     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 1234


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>
Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
      File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
assert self._parent_pid == os.getpid(), 'can only test a child process'
    AssertionError: self._shutdown_workers()can only test a child process

  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shut

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>
Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>  File "/usr/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>

    Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>
Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.

    self._shutdown_workers()
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>Exception ignored in:   File "/usr/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

Traceback (most recent call last):
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7fcba6901040>    
Traceback (most recent call last):

  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  File "/home/aksoy/PyEnvs/start/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shut

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [7]:
# Close the active wandb run; the recorded output shows the final upload progress and the
# run history/summary tables (epoch, global step, training_loss, val_loss).
wandb.finish()

VBox(children=(Label(value=' 3.42MB of 3.42MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁██
trainer/global_step,▁▁██
training_loss,█▁
val_loss,█▁

0,1
epoch,1.0
trainer/global_step,4333.0
training_loss,0.69861
val_loss,0.68027
