In [None]:
import os
import pickle
import torch
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter

In [2]:
from utils import *
from models import *

## Experiment 1

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64, epochs = 50,000
* Single LSTM layer with hidden dim = 32, followed by a linear layer
* Time is not included in the prediction

In [None]:
filepath = 'data_raw/1_single.csv'

# Windowing parameters (must match the experiment description above)
WINDOW_SIZE = 10
HORIZON = 10

# Load and preprocess dataset
df = load_and_preprocess_df(filepath, max_speed = 10, sampling_interval = '1s')

# Create windows
X, y = windowify(df, window_size = WINDOW_SIZE, horizon = HORIZON, cols = ['time', 'x', 'y', 'z'])

# Save windows to the working directory.
# The previous literals started with '\w', which is an invalid escape sequence
# (DeprecationWarning / SyntaxWarning). On Windows a leading backslash also means
# "root of the current drive" rather than the notebook's directory.
X_filepath = f'window{WINDOW_SIZE}_horizon_{HORIZON}_file_1_X.pkl'
y_filepath = f'window{WINDOW_SIZE}_horizon_{HORIZON}_file_1_y.pkl'
save_windows(X, y, X_filepath, y_filepath)

# Create train, val and test sets
X_train, y_train, X_val, y_val, X_test, y_test = prepare_data(X, y)

In [15]:
# Instantiate model, optimizer, loss function and TensorBoard writer
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 32
OUTPUT_DIM = 3
NUM_LAYERS = 1
HORIZON = 10  # must match the horizon used when windowing the data
SEED = 42

# Seed torch's RNG so weight initialisation (and any torch-based shuffling during
# training) is reproducible across Restart & Run All
torch.manual_seed(SEED)

model_1 = LSTMModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = HORIZON)
optimizer = torch.optim.Adam(params = model_1.parameters())
loss_fn = torch.nn.functional.mse_loss
# The comment suffix is appended to the auto-generated run folder name so this
# experiment can be told apart from the others in TensorBoard
writer = SummaryWriter(comment = '_model_1')

# Display model's summary
summary(model_1)

Layer (type:depth-idx)                   Param #
LSTMModel                                --
├─LSTM: 1-1                              4,736
├─Linear: 1-2                            99
Total params: 4,835
Trainable params: 4,835
Non-trainable params: 0

In [16]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 50000
try:
    train(model_1, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Always flush/close the TensorBoard writer, even if this long run is
    # interrupted (KeyboardInterrupt) or raises, so logged events are not lost
    writer.close()
# Save model (only reached when training completed normally)
torch.save(model_1, 'model_1.pt')

  0%|                                                                               | 12/50000 [00:00<14:44, 56.55it/s]

Training loss at epoch 0: 125.63140106201172 and val loss: 129.08534240722656


  2%|█▌                                                                           | 1008/50000 [00:17<13:06, 62.28it/s]

Training loss at epoch 1000: 107.7870864868164 and val loss: 111.30915069580078


  4%|███                                                                          | 2007/50000 [00:33<12:43, 62.85it/s]

Training loss at epoch 2000: 93.44876098632812 and val loss: 98.09645080566406


  6%|████▋                                                                        | 3013/50000 [00:50<12:50, 60.98it/s]

Training loss at epoch 3000: 81.551513671875 and val loss: 87.49923706054688


  8%|██████▏                                                                      | 4008/50000 [01:06<13:17, 57.67it/s]

Training loss at epoch 4000: 76.5997314453125 and val loss: 78.41646575927734


 10%|███████▋                                                                     | 5011/50000 [01:22<12:22, 60.56it/s]

Training loss at epoch 5000: 71.7087173461914 and val loss: 70.53697204589844


 12%|█████████▎                                                                   | 6014/50000 [01:39<11:38, 62.93it/s]

Training loss at epoch 6000: 69.49211120605469 and val loss: 63.4941520690918


 14%|██████████▊                                                                  | 7011/50000 [01:55<11:07, 64.39it/s]

Training loss at epoch 7000: 54.9080696105957 and val loss: 57.7381477355957


 16%|████████████▎                                                                | 8007/50000 [02:12<13:39, 51.25it/s]

Training loss at epoch 8000: 49.612205505371094 and val loss: 52.92089080810547


 18%|█████████████▉                                                               | 9010/50000 [02:28<11:19, 60.32it/s]

Training loss at epoch 9000: 42.65231704711914 and val loss: 48.689666748046875


 20%|███████████████▏                                                            | 10007/50000 [02:44<10:34, 63.00it/s]

Training loss at epoch 10000: 32.0245475769043 and val loss: 44.50851821899414


 22%|████████████████▋                                                           | 11012/50000 [03:01<10:10, 63.85it/s]

Training loss at epoch 11000: 44.18962860107422 and val loss: 40.218223571777344


 24%|██████████████████▎                                                         | 12012/50000 [03:17<10:24, 60.81it/s]

Training loss at epoch 12000: 38.53528594970703 and val loss: 36.757259368896484


 26%|███████████████████▊                                                        | 13010/50000 [03:34<09:50, 62.69it/s]

Training loss at epoch 13000: 27.5701847076416 and val loss: 33.401390075683594


 28%|█████████████████████▎                                                      | 14013/50000 [03:50<09:55, 60.45it/s]

Training loss at epoch 14000: 33.37895584106445 and val loss: 30.41472625732422


 30%|██████████████████████▊                                                     | 15011/50000 [04:06<10:09, 57.39it/s]

Training loss at epoch 15000: 14.925990104675293 and val loss: 27.803878784179688


 32%|████████████████████████▎                                                   | 16006/50000 [04:23<09:03, 62.57it/s]

Training loss at epoch 16000: 17.380069732666016 and val loss: 25.085399627685547


 34%|█████████████████████████▊                                                  | 17009/50000 [04:39<09:25, 58.34it/s]

Training loss at epoch 17000: 19.043352127075195 and val loss: 23.048891067504883


 36%|███████████████████████████▍                                                | 18010/50000 [04:56<08:56, 59.57it/s]

Training loss at epoch 18000: 18.2707462310791 and val loss: 21.29773712158203


 38%|████████████████████████████▉                                               | 19006/50000 [05:12<08:53, 58.08it/s]

Training loss at epoch 19000: 18.818836212158203 and val loss: 20.004179000854492


 40%|██████████████████████████████▍                                             | 20013/50000 [05:28<07:48, 63.96it/s]

Training loss at epoch 20000: 15.24288558959961 and val loss: 18.760276794433594


 42%|███████████████████████████████▉                                            | 21013/50000 [05:45<07:51, 61.50it/s]

Training loss at epoch 21000: 15.939201354980469 and val loss: 17.919349670410156


 44%|█████████████████████████████████▍                                          | 22010/50000 [06:01<07:59, 58.43it/s]

Training loss at epoch 22000: 12.790146827697754 and val loss: 17.331464767456055


 46%|██████████████████████████████████▉                                         | 23009/50000 [06:18<07:07, 63.17it/s]

Training loss at epoch 23000: 11.737263679504395 and val loss: 16.273038864135742


 48%|████████████████████████████████████▍                                       | 24007/50000 [06:34<06:55, 62.50it/s]

Training loss at epoch 24000: 13.700789451599121 and val loss: 15.865751266479492


 50%|██████████████████████████████████████                                      | 25011/50000 [06:50<06:46, 61.40it/s]

Training loss at epoch 25000: 17.39759063720703 and val loss: 15.615954399108887


 52%|███████████████████████████████████████▌                                    | 26010/50000 [07:07<06:27, 61.93it/s]

Training loss at epoch 26000: 11.124580383300781 and val loss: 14.91318416595459


 54%|█████████████████████████████████████████                                   | 27012/50000 [07:23<06:04, 63.09it/s]

Training loss at epoch 27000: 12.322161674499512 and val loss: 14.596348762512207


 56%|██████████████████████████████████████████▌                                 | 28010/50000 [07:40<06:03, 60.42it/s]

Training loss at epoch 28000: 11.877010345458984 and val loss: 14.143553733825684


 58%|████████████████████████████████████████████                                | 29007/50000 [07:56<05:29, 63.78it/s]

Training loss at epoch 29000: 12.425487518310547 and val loss: 14.10316276550293


 60%|█████████████████████████████████████████████▌                              | 30011/50000 [08:12<05:32, 60.05it/s]

Training loss at epoch 30000: 12.47988510131836 and val loss: 14.081989288330078


 62%|███████████████████████████████████████████████▏                            | 31009/50000 [08:29<05:25, 58.38it/s]

Training loss at epoch 31000: 11.981176376342773 and val loss: 13.38635540008545


 64%|████████████████████████████████████████████████▋                           | 32008/50000 [08:45<04:46, 62.73it/s]

Training loss at epoch 32000: 12.641252517700195 and val loss: 13.62569522857666


 66%|██████████████████████████████████████████████████▏                         | 33010/50000 [09:02<04:31, 62.52it/s]

Training loss at epoch 33000: 11.698464393615723 and val loss: 13.2427978515625


 68%|███████████████████████████████████████████████████▋                        | 34009/50000 [09:18<04:27, 59.87it/s]

Training loss at epoch 34000: 10.812955856323242 and val loss: 12.795170783996582


 70%|█████████████████████████████████████████████████████▏                      | 35010/50000 [09:34<04:18, 57.93it/s]

Training loss at epoch 35000: 10.254952430725098 and val loss: 12.537083625793457


 72%|██████████████████████████████████████████████████████▋                     | 36013/50000 [09:51<03:47, 61.37it/s]

Training loss at epoch 36000: 10.223186492919922 and val loss: 12.377864837646484


 74%|████████████████████████████████████████████████████████▎                   | 37007/50000 [10:07<03:28, 62.35it/s]

Training loss at epoch 37000: 11.729140281677246 and val loss: 12.403192520141602


 76%|█████████████████████████████████████████████████████████▊                  | 38010/50000 [10:24<03:08, 63.70it/s]

Training loss at epoch 38000: 11.525973320007324 and val loss: 12.055245399475098


 78%|███████████████████████████████████████████████████████████▎                | 39008/50000 [10:40<03:03, 59.92it/s]

Training loss at epoch 39000: 10.284234046936035 and val loss: 12.386114120483398


 80%|████████████████████████████████████████████████████████████▊               | 40013/50000 [10:57<02:42, 61.62it/s]

Training loss at epoch 40000: 10.6660737991333 and val loss: 11.966856002807617


 82%|██████████████████████████████████████████████████████████████▎             | 41006/50000 [11:13<02:35, 57.70it/s]

Training loss at epoch 41000: 10.859447479248047 and val loss: 11.925691604614258


 84%|███████████████████████████████████████████████████████████████▊            | 42007/50000 [11:29<02:08, 62.11it/s]

Training loss at epoch 42000: 11.322870254516602 and val loss: 12.34808349609375


 86%|█████████████████████████████████████████████████████████████████▎          | 43005/50000 [11:46<01:51, 62.58it/s]

Training loss at epoch 43000: 12.11783504486084 and val loss: 11.92151927947998


 88%|██████████████████████████████████████████████████████████████████▉         | 44013/50000 [12:03<01:35, 62.87it/s]

Training loss at epoch 44000: 11.675432205200195 and val loss: 11.694114685058594


 90%|████████████████████████████████████████████████████████████████████▍       | 45006/50000 [12:19<01:20, 62.21it/s]

Training loss at epoch 45000: 10.624152183532715 and val loss: 11.555985450744629


 92%|█████████████████████████████████████████████████████████████████████▉      | 46007/50000 [12:36<01:05, 60.54it/s]

Training loss at epoch 46000: 11.38594913482666 and val loss: 11.48325252532959


 94%|███████████████████████████████████████████████████████████████████████▍    | 47009/50000 [12:53<00:51, 58.31it/s]

Training loss at epoch 47000: 11.057711601257324 and val loss: 11.506850242614746


 96%|████████████████████████████████████████████████████████████████████████▉   | 48011/50000 [13:09<00:34, 57.83it/s]

Training loss at epoch 48000: 10.112452507019043 and val loss: 11.138969421386719


 98%|██████████████████████████████████████████████████████████████████████████▍ | 49011/50000 [13:26<00:16, 60.77it/s]

Training loss at epoch 49000: 10.896770477294922 and val loss: 11.551279067993164


100%|████████████████████████████████████████████████████████████████████████████| 50000/50000 [13:42<00:00, 60.77it/s]


In [None]:
# See results of training
# Load the TensorBoard notebook extension and show the event files under ./runs,
# presumably those written by the SummaryWriter() instances above (their default
# log directory is assumed to be runs/ — confirm if a custom log_dir is used)
%load_ext tensorboard
%tensorboard --logdir=runs

In [20]:
# Compare the NN with the naive forecast on the train, val and test sets.
# NOTE: the printed "loss" is an RMSE, i.e. sqrt of the mean squared error.
def report_rmse(model, X, y, split_name, horizon = 10):
    """Print the RMSE of `model` and of the naive forecast on one data split.

    Args:
        model: trained model; switched to eval mode and called as model(X)
            under torch.inference_mode().
        X: input windows of the split.
        y: target windows of the split.
        split_name: label used in the printed messages ('train', 'val', 'test').
        horizon: forecast horizon passed to naive_forecast.

    Returns:
        (y_pred, loss, y_pred_naive, naive_loss) for the split.
    """
    # Get model's predictions of NN
    model.eval()
    with torch.inference_mode():
        y_pred = model(X)

    # Calculate total (mean) error
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Get naive forecast predictions and loss
    y_pred_naive = naive_forecast(X, horizon = horizon)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

    return y_pred, loss, y_pred_naive, naive_loss


# After the loop these globals hold the test-set results, as the original
# copy-pasted cell left them, so later cells that read them keep working
for split_name, X_split, y_split in [('train', X_train, y_train),
                                     ('val', X_val, y_val),
                                     ('test', X_test, y_test)]:
    y_pred, loss, y_pred_naive, naive_loss = report_rmse(model_1, X_split, y_split, split_name)

Total (mean) train loss of NN: 10.110610961914062
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 11.304615020751953
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 11.819862365722656
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 2 – same as Experiment 1 but with a larger LSTM and longer training

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64, epochs = 100,000 (vs 50,000 in Experiment 1)
* Single LSTM layer with hidden dim = 128 (vs 32), followed by a linear layer
* Time is not included in the prediction

In [26]:
# Instantiate model, optimizer, loss function and TensorBoard writer
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 128
OUTPUT_DIM = 3
NUM_LAYERS = 1
HORIZON = 10  # must match the horizon used when windowing the data
SEED = 42

# Seed torch's RNG so weight initialisation (and any torch-based shuffling during
# training) is reproducible across Restart & Run All.
# NOTE: re-running this cell re-initialises model_2 and discards trained weights.
torch.manual_seed(SEED)

model_2 = LSTMModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = HORIZON)
optimizer = torch.optim.Adam(params = model_2.parameters())
loss_fn = torch.nn.functional.mse_loss
# The comment suffix is appended to the auto-generated run folder name so this
# experiment can be told apart from the others in TensorBoard
writer = SummaryWriter(comment = '_model_2')

# Display model's summary
summary(model_2)

Layer (type:depth-idx)                   Param #
LSTMModel                                --
├─LSTM: 1-1                              68,096
├─Linear: 1-2                            387
Total params: 68,483
Trainable params: 68,483
Non-trainable params: 0

In [22]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_2, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Always flush/close the TensorBoard writer, even if this long run is
    # interrupted (KeyboardInterrupt) or raises, so logged events are not lost
    writer.close()
# Save model (only reached when training completed normally)
torch.save(model_2, 'model_2.pt')

  0%|                                                                             | 24/100000 [00:00<14:18, 116.39it/s]

Training loss at epoch 0: 124.26663208007812 and val loss: 129.07986450195312


  1%|▊                                                                          | 1023/100000 [00:08<13:27, 122.50it/s]

Training loss at epoch 1000: 69.57818603515625 and val loss: 75.8642807006836


  2%|█▌                                                                         | 2022/100000 [00:17<14:25, 113.23it/s]

Training loss at epoch 2000: 48.17258834838867 and val loss: 50.663673400878906


  3%|██▎                                                                        | 3017/100000 [00:25<11:50, 136.43it/s]

Training loss at epoch 3000: 40.97676467895508 and val loss: 36.515682220458984


  4%|███                                                                        | 4019/100000 [00:33<12:13, 130.81it/s]

Training loss at epoch 4000: 21.15785026550293 and val loss: 26.97170066833496


  5%|███▊                                                                       | 5017/100000 [00:42<14:46, 107.15it/s]

Training loss at epoch 5000: 19.73539161682129 and val loss: 21.467247009277344


  6%|████▌                                                                      | 6022/100000 [00:51<14:43, 106.40it/s]

Training loss at epoch 6000: 12.783031463623047 and val loss: 17.371410369873047


  7%|█████▎                                                                     | 7024/100000 [01:01<14:42, 105.33it/s]

Training loss at epoch 7000: 14.159794807434082 and val loss: 15.313224792480469


  8%|██████                                                                      | 8018/100000 [01:11<15:32, 98.69it/s]

Training loss at epoch 8000: 11.615710258483887 and val loss: 13.326959609985352


  9%|██████▊                                                                     | 9016/100000 [01:22<18:07, 83.66it/s]

Training loss at epoch 9000: 10.73275089263916 and val loss: 12.09399127960205


 10%|███████▍                                                                  | 10023/100000 [01:33<14:42, 101.95it/s]

Training loss at epoch 10000: 9.918718338012695 and val loss: 11.389496803283691


 11%|████████▎                                                                  | 11013/100000 [01:45<18:59, 78.08it/s]

Training loss at epoch 11000: 9.349024772644043 and val loss: 10.578091621398926


 12%|█████████                                                                  | 12013/100000 [01:57<16:25, 89.27it/s]

Training loss at epoch 12000: 9.146089553833008 and val loss: 10.378327369689941


 13%|█████████▊                                                                 | 13012/100000 [02:12<23:22, 62.01it/s]

Training loss at epoch 13000: 8.368450164794922 and val loss: 11.405783653259277


 14%|██████████▌                                                                | 14010/100000 [02:28<23:30, 60.95it/s]

Training loss at epoch 14000: 6.9395904541015625 and val loss: 9.219128608703613


 15%|███████████▎                                                               | 15010/100000 [02:44<23:18, 60.76it/s]

Training loss at epoch 15000: 7.880409240722656 and val loss: 8.87732219696045


 16%|████████████                                                               | 16010/100000 [03:02<23:40, 59.15it/s]

Training loss at epoch 16000: 6.72813081741333 and val loss: 8.692522048950195


 17%|████████████▊                                                              | 17007/100000 [03:17<23:11, 59.64it/s]

Training loss at epoch 17000: 6.649219036102295 and val loss: 9.369446754455566


 18%|█████████████▌                                                             | 18013/100000 [03:34<22:36, 60.46it/s]

Training loss at epoch 18000: 6.034568786621094 and val loss: 8.258950233459473


 19%|██████████████▎                                                            | 19005/100000 [03:50<26:40, 50.59it/s]

Training loss at epoch 19000: 6.8388519287109375 and val loss: 8.272220611572266


 20%|███████████████                                                            | 20009/100000 [04:07<21:24, 62.27it/s]

Training loss at epoch 20000: 6.0430216789245605 and val loss: 8.074667930603027


 21%|███████████████▊                                                           | 21013/100000 [04:23<22:04, 59.64it/s]

Training loss at epoch 21000: 6.4010114669799805 and val loss: 7.8642354011535645


 22%|████████████████▌                                                          | 22007/100000 [04:39<23:00, 56.51it/s]

Training loss at epoch 22000: 5.622830390930176 and val loss: 7.730092525482178


 23%|█████████████████▎                                                         | 23011/100000 [04:55<22:01, 58.28it/s]

Training loss at epoch 23000: 5.758708477020264 and val loss: 7.574041366577148


 24%|██████████████████                                                         | 24008/100000 [05:13<22:41, 55.82it/s]

Training loss at epoch 24000: 5.5684404373168945 and val loss: 7.339241981506348


 25%|██████████████████▊                                                        | 25006/100000 [05:29<20:48, 60.08it/s]

Training loss at epoch 25000: 6.188899517059326 and val loss: 7.345407009124756


 26%|███████████████████▌                                                       | 26013/100000 [05:48<20:42, 59.53it/s]

Training loss at epoch 26000: 5.046827793121338 and val loss: 7.163525104522705


 27%|████████████████████▎                                                      | 27010/100000 [06:05<21:04, 57.73it/s]

Training loss at epoch 27000: 4.874624252319336 and val loss: 6.97618293762207


 28%|█████████████████████                                                      | 28011/100000 [06:22<20:18, 59.06it/s]

Training loss at epoch 28000: 5.291135311126709 and val loss: 6.940235614776611


 29%|█████████████████████▊                                                     | 29011/100000 [06:38<19:35, 60.41it/s]

Training loss at epoch 29000: 5.207753658294678 and val loss: 6.78305721282959


 30%|██████████████████████▌                                                    | 30009/100000 [06:54<19:43, 59.14it/s]

Training loss at epoch 30000: 5.251869201660156 and val loss: 6.738524436950684


 31%|███████████████████████▎                                                   | 31006/100000 [07:10<17:43, 64.89it/s]

Training loss at epoch 31000: 6.323537349700928 and val loss: 6.8008575439453125


 32%|████████████████████████                                                   | 32009/100000 [07:26<15:19, 73.94it/s]

Training loss at epoch 32000: 4.9983038902282715 and val loss: 6.783321380615234


 33%|████████████████████████▊                                                  | 33010/100000 [07:42<18:17, 61.05it/s]

Training loss at epoch 33000: 5.364764213562012 and val loss: 6.570031642913818


 34%|█████████████████████████▌                                                 | 34011/100000 [07:59<15:19, 71.76it/s]

Training loss at epoch 34000: 4.475471496582031 and val loss: 6.613394737243652


 35%|██████████████████████████▎                                                | 35006/100000 [08:15<18:19, 59.09it/s]

Training loss at epoch 35000: 4.9574079513549805 and val loss: 6.443047046661377


 36%|███████████████████████████                                                | 36008/100000 [08:32<18:05, 58.93it/s]

Training loss at epoch 36000: 4.703993320465088 and val loss: 6.518394470214844


 37%|███████████████████████████▊                                               | 37011/100000 [08:47<17:09, 61.17it/s]

Training loss at epoch 37000: 4.537790298461914 and val loss: 6.450307369232178


 38%|████████████████████████████▌                                              | 38007/100000 [09:04<17:26, 59.22it/s]

Training loss at epoch 38000: 5.203092098236084 and val loss: 6.282898902893066


 39%|█████████████████████████████▎                                             | 39007/100000 [09:20<16:58, 59.87it/s]

Training loss at epoch 39000: 4.810244560241699 and val loss: 6.277425289154053


 40%|██████████████████████████████                                             | 40010/100000 [09:37<17:07, 58.37it/s]

Training loss at epoch 40000: 4.121532440185547 and val loss: 6.462374210357666


 41%|██████████████████████████████▊                                            | 41012/100000 [09:52<17:14, 57.04it/s]

Training loss at epoch 41000: 4.63636589050293 and val loss: 6.16277551651001


 42%|███████████████████████████████▌                                           | 42008/100000 [10:09<17:22, 55.64it/s]

Training loss at epoch 42000: 4.032390594482422 and val loss: 6.23922872543335


 43%|████████████████████████████████▎                                          | 43011/100000 [10:25<15:42, 60.48it/s]

Training loss at epoch 43000: 4.4239325523376465 and val loss: 6.335793495178223


 44%|█████████████████████████████████                                          | 44006/100000 [10:42<15:51, 58.88it/s]

Training loss at epoch 44000: 4.936915874481201 and val loss: 6.20685338973999


 45%|█████████████████████████████████▊                                         | 45009/100000 [10:58<12:19, 74.34it/s]

Training loss at epoch 45000: 4.746363162994385 and val loss: 6.293546199798584


 46%|██████████████████████████████████▌                                        | 46013/100000 [11:15<15:24, 58.40it/s]

Training loss at epoch 46000: 4.207235813140869 and val loss: 6.3394083976745605


 47%|███████████████████████████████████▎                                       | 47014/100000 [11:31<12:09, 72.63it/s]

Training loss at epoch 47000: 4.68344783782959 and val loss: 6.278355598449707


 48%|████████████████████████████████████                                       | 48012/100000 [11:48<14:22, 60.30it/s]

Training loss at epoch 48000: 4.352856636047363 and val loss: 6.172255992889404


 49%|████████████████████████████████████▊                                      | 49006/100000 [12:05<12:20, 68.88it/s]

Training loss at epoch 49000: 4.1908440589904785 and val loss: 6.2874579429626465


 50%|█████████████████████████████████████▌                                     | 50013/100000 [12:20<13:26, 61.98it/s]

Training loss at epoch 50000: 4.420925140380859 and val loss: 6.274000644683838


 51%|██████████████████████████████████████▎                                    | 51012/100000 [12:37<13:43, 59.50it/s]

Training loss at epoch 51000: 4.201415061950684 and val loss: 6.204073429107666


 52%|███████████████████████████████████████                                    | 52007/100000 [12:53<13:15, 60.34it/s]

Training loss at epoch 52000: 3.9310197830200195 and val loss: 6.388870716094971


 53%|███████████████████████████████████████▊                                   | 53011/100000 [13:10<13:00, 60.20it/s]

Training loss at epoch 53000: 4.959397792816162 and val loss: 6.361212253570557


 54%|████████████████████████████████████████▌                                  | 54009/100000 [13:25<12:30, 61.30it/s]

Training loss at epoch 54000: 3.8838539123535156 and val loss: 6.33784294128418


 55%|█████████████████████████████████████████▎                                 | 55010/100000 [13:42<12:19, 60.86it/s]

Training loss at epoch 55000: 4.984575271606445 and val loss: 6.638656139373779


 56%|██████████████████████████████████████████                                 | 56009/100000 [13:57<12:11, 60.16it/s]

Training loss at epoch 56000: 4.398649215698242 and val loss: 6.2377400398254395


 57%|██████████████████████████████████████████▊                                | 57012/100000 [14:13<11:43, 61.14it/s]

Training loss at epoch 57000: 4.355501174926758 and val loss: 6.218862056732178


 58%|███████████████████████████████████████████▌                               | 58010/100000 [14:29<11:08, 62.79it/s]

Training loss at epoch 58000: 4.091048717498779 and val loss: 6.282528400421143


 59%|████████████████████████████████████████████▎                              | 59010/100000 [14:46<11:36, 58.82it/s]

Training loss at epoch 59000: 4.030118942260742 and val loss: 6.183559894561768


 60%|█████████████████████████████████████████████                              | 60012/100000 [15:02<08:57, 74.42it/s]

Training loss at epoch 60000: 3.5636019706726074 and val loss: 6.074398994445801


 61%|█████████████████████████████████████████████▊                             | 61010/100000 [15:18<10:50, 59.98it/s]

Training loss at epoch 61000: 3.810152053833008 and val loss: 6.204989433288574


 62%|██████████████████████████████████████████████▌                            | 62008/100000 [15:34<08:41, 72.83it/s]

Training loss at epoch 62000: 4.5426859855651855 and val loss: 6.521798133850098


 63%|███████████████████████████████████████████████▎                           | 63008/100000 [15:51<10:27, 58.96it/s]

Training loss at epoch 63000: 4.464957237243652 and val loss: 6.266261577606201


 64%|████████████████████████████████████████████████                           | 64010/100000 [16:07<09:22, 63.96it/s]

Training loss at epoch 64000: 3.642605781555176 and val loss: 6.21234130859375


 65%|████████████████████████████████████████████████▊                          | 65008/100000 [16:23<09:36, 60.75it/s]

Training loss at epoch 65000: 4.134321689605713 and val loss: 6.17870569229126


 66%|█████████████████████████████████████████████████▌                         | 66012/100000 [16:40<09:45, 58.02it/s]

Training loss at epoch 66000: 3.5820465087890625 and val loss: 6.137841701507568


 67%|██████████████████████████████████████████████████▎                        | 67014/100000 [16:56<10:05, 54.49it/s]

Training loss at epoch 67000: 4.163861274719238 and val loss: 6.08320426940918


 68%|███████████████████████████████████████████████████                        | 68013/100000 [17:12<08:49, 60.38it/s]

Training loss at epoch 68000: 3.434666872024536 and val loss: 6.148446083068848


 69%|███████████████████████████████████████████████████▊                       | 69008/100000 [17:27<08:25, 61.32it/s]

Training loss at epoch 69000: 3.5990073680877686 and val loss: 6.100387096405029


 70%|████████████████████████████████████████████████████▌                      | 70009/100000 [17:44<08:29, 58.88it/s]

Training loss at epoch 70000: 4.021248817443848 and val loss: 6.136735916137695


 71%|█████████████████████████████████████████████████████▎                     | 71008/100000 [18:00<08:00, 60.29it/s]

Training loss at epoch 71000: 4.420582294464111 and val loss: 6.1067399978637695


 72%|██████████████████████████████████████████████████████                     | 72011/100000 [18:17<07:50, 59.51it/s]

Training loss at epoch 72000: 4.360509872436523 and val loss: 6.1561150550842285


 73%|██████████████████████████████████████████████████████▊                    | 73009/100000 [18:33<07:27, 60.35it/s]

Training loss at epoch 73000: 3.927830457687378 and val loss: 6.270224571228027


 74%|███████████████████████████████████████████████████████▌                   | 74006/100000 [18:50<06:54, 62.64it/s]

Training loss at epoch 74000: 3.948674440383911 and val loss: 6.235411643981934


 75%|████████████████████████████████████████████████████████▎                  | 75013/100000 [19:06<06:59, 59.59it/s]

Training loss at epoch 75000: 3.833667516708374 and val loss: 6.242660999298096


 76%|█████████████████████████████████████████████████████████                  | 76007/100000 [19:23<06:57, 57.46it/s]

Training loss at epoch 76000: 3.4909679889678955 and val loss: 6.180314064025879


 77%|█████████████████████████████████████████████████████████▊                 | 77010/100000 [19:39<06:36, 57.99it/s]

Training loss at epoch 77000: 3.9919273853302 and val loss: 6.143866062164307


 78%|██████████████████████████████████████████████████████████▌                | 78011/100000 [19:56<06:23, 57.40it/s]

Training loss at epoch 78000: 3.518165111541748 and val loss: 6.266315460205078


 79%|███████████████████████████████████████████████████████████▎               | 79011/100000 [20:13<06:15, 55.92it/s]

Training loss at epoch 79000: 3.9688920974731445 and val loss: 6.386706352233887


 80%|████████████████████████████████████████████████████████████               | 80008/100000 [20:31<06:27, 51.53it/s]

Training loss at epoch 80000: 3.359438896179199 and val loss: 6.193893909454346


 81%|████████████████████████████████████████████████████████████▊              | 81010/100000 [20:47<05:28, 57.75it/s]

Training loss at epoch 81000: 3.368309497833252 and val loss: 6.1989545822143555


 82%|█████████████████████████████████████████████████████████████▌             | 82010/100000 [21:04<05:03, 59.23it/s]

Training loss at epoch 82000: 3.7264609336853027 and val loss: 6.268959045410156


 83%|██████████████████████████████████████████████████████████████▎            | 83006/100000 [21:21<04:41, 60.43it/s]

Training loss at epoch 83000: 2.898634672164917 and val loss: 6.239076137542725


 84%|███████████████████████████████████████████████████████████████            | 84008/100000 [21:38<04:33, 58.43it/s]

Training loss at epoch 84000: 3.7081971168518066 and val loss: 6.172822952270508


 85%|███████████████████████████████████████████████████████████████▊           | 85010/100000 [21:55<04:23, 56.78it/s]

Training loss at epoch 85000: 3.1663310527801514 and val loss: 6.109901428222656


 86%|████████████████████████████████████████████████████████████████▌          | 86009/100000 [22:12<04:01, 57.84it/s]

Training loss at epoch 86000: 3.2536816596984863 and val loss: 6.110992431640625


 87%|█████████████████████████████████████████████████████████████████▎         | 87013/100000 [22:29<03:38, 59.39it/s]

Training loss at epoch 87000: 3.7937498092651367 and val loss: 6.140391826629639


 88%|██████████████████████████████████████████████████████████████████         | 88008/100000 [22:47<03:30, 57.10it/s]

Training loss at epoch 88000: 3.482018232345581 and val loss: 6.155892372131348


 89%|██████████████████████████████████████████████████████████████████▊        | 89011/100000 [23:03<03:09, 57.85it/s]

Training loss at epoch 89000: 3.5793042182922363 and val loss: 6.136404514312744


 90%|███████████████████████████████████████████████████████████████████▌       | 90013/100000 [23:20<02:24, 69.22it/s]

Training loss at epoch 90000: 3.122316837310791 and val loss: 6.178669452667236


 91%|████████████████████████████████████████████████████████████████████▎      | 91007/100000 [23:37<02:28, 60.50it/s]

Training loss at epoch 91000: 3.52815318107605 and val loss: 6.166372776031494


 92%|█████████████████████████████████████████████████████████████████████      | 92008/100000 [23:54<01:53, 70.40it/s]

Training loss at epoch 92000: 3.613193988800049 and val loss: 6.178439617156982


 93%|█████████████████████████████████████████████████████████████████████▊     | 93011/100000 [24:10<01:58, 59.22it/s]

Training loss at epoch 93000: 3.6994659900665283 and val loss: 6.127582550048828


 94%|██████████████████████████████████████████████████████████████████████▌    | 94012/100000 [24:27<01:26, 69.48it/s]

Training loss at epoch 94000: 3.317531108856201 and val loss: 6.320593357086182


 95%|███████████████████████████████████████████████████████████████████████▎   | 95008/100000 [24:44<01:26, 57.43it/s]

Training loss at epoch 95000: 3.4435346126556396 and val loss: 6.119819641113281


 96%|████████████████████████████████████████████████████████████████████████   | 96013/100000 [25:00<00:55, 71.53it/s]

Training loss at epoch 96000: 3.3632118701934814 and val loss: 6.189737319946289


 97%|████████████████████████████████████████████████████████████████████████▊  | 97013/100000 [25:16<00:46, 64.12it/s]

Training loss at epoch 97000: 3.3178772926330566 and val loss: 6.212215900421143


 98%|█████████████████████████████████████████████████████████████████████████▌ | 98012/100000 [25:33<00:27, 71.67it/s]

Training loss at epoch 98000: 3.559656858444214 and val loss: 6.16392707824707


 99%|██████████████████████████████████████████████████████████████████████████▎| 99007/100000 [25:50<00:19, 52.02it/s]

Training loss at epoch 99000: 3.466656446456909 and val loss: 6.040592670440674


100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [26:07<00:00, 63.80it/s]


In [None]:
# See results of training
# Load the TensorBoard notebook extension (if an earlier cell already loaded it,
# IPython only reports that it is already loaded; nothing is re-initialized)
%load_ext tensorboard
# Show TensorBoard inline for runs/, the parent directory SummaryWriter() writes to
# when no log_dir is given, so every experiment's run is listed side by side
%tensorboard --logdir=runs

In [24]:
### TRAIN / VAL / TEST SETS ###
# Compare model_2 with the naive-forecast baseline on every split.
# The reported "loss" is an RMSE: sqrt of mse_loss, which (default reduction)
# averages the squared error over every element of the target tensor.
# The loop replaces three copy-pasted blocks; it visits the splits in the same
# order, so after it finishes y_pred / loss / y_pred_naive / naive_loss hold the
# TEST-set values, exactly as before.
eval_splits = {
    'train': (X_train, y_train),
    'val': (X_val, y_val),
    'test': (X_test, y_test),
}

model_2.eval()
for split_name, (X_split, y_split) in eval_splits.items():
    # Get model's predictions of NN (no autograd tracking needed for evaluation)
    with torch.inference_mode():
        y_pred = model_2(X_split)

    # Calculate total (mean) error of the NN on this split
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y_split))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Get naive forecast predictions and loss on the same split
    # (horizon must match the horizon used when the windows were built)
    y_pred_naive = naive_forecast(X_split, horizon = 10)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y_split))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 3.4459738731384277
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 6.035409450531006
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 6.848249435424805
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 3 - same as Experiment 2 but with three stacked LSTM layers

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Three stacked LSTM layers (NUM_LAYERS = 3) with hidden dim = 128 + Linear Layer
* Time not included in prediction

In [27]:
# Instantiate model, optimizer, loss function and TensorBoard writer
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 128
OUTPUT_DIM = 3
NUM_LAYERS = 3

model_3 = LSTMModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(params = model_3.parameters())
loss_fn = torch.nn.functional.mse_loss
# Tag the run directory (runs/<datetime>_<host>_model_3) so this experiment can be
# told apart from the other runs when viewing --logdir=runs in TensorBoard
writer = SummaryWriter(comment = '_model_3')

# Display model's summary
summary(model_3)

Layer (type:depth-idx)                   Param #
LSTMModel                                --
├─LSTM: 1-1                              332,288
├─Linear: 1-2                            387
Total params: 332,675
Trainable params: 332,675
Non-trainable params: 0

In [28]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_3, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close the writer even if training is interrupted (e.g. KeyboardInterrupt during
    # this long run), so pending TensorBoard events are written and the file released
    writer.close()
# Save the whole model object (not just its state_dict); only reached if training finished
torch.save(model_3, 'model_3.pt')

  0%|                                                                              | 10/100000 [00:00<33:45, 49.36it/s]

Training loss at epoch 0: 122.87010192871094 and val loss: 129.20913696289062


  1%|▊                                                                           | 1005/100000 [00:20<39:52, 41.38it/s]

Training loss at epoch 1000: 72.11249542236328 and val loss: 79.94609069824219


  2%|█▌                                                                          | 2006/100000 [00:47<41:47, 39.08it/s]

Training loss at epoch 2000: 49.36914825439453 and val loss: 55.39445114135742


  3%|██▎                                                                         | 3006/100000 [01:16<46:40, 34.63it/s]

Training loss at epoch 3000: 35.57210159301758 and val loss: 39.48653793334961


  4%|███                                                                         | 4004/100000 [01:48<55:32, 28.80it/s]

Training loss at epoch 4000: 22.037353515625 and val loss: 27.140235900878906


  5%|███▊                                                                        | 5004/100000 [02:22<58:11, 27.21it/s]

Training loss at epoch 5000: 17.54884147644043 and val loss: 17.574317932128906


  6%|████▍                                                                     | 6003/100000 [03:10<1:21:02, 19.33it/s]

Training loss at epoch 6000: 10.554391860961914 and val loss: 11.896607398986816


  7%|█████▏                                                                    | 7005/100000 [04:02<1:18:00, 19.87it/s]

Training loss at epoch 7000: 5.131878852844238 and val loss: 9.054017066955566


  8%|█████▉                                                                    | 8004/100000 [04:55<1:20:26, 19.06it/s]

Training loss at epoch 8000: 5.565798282623291 and val loss: 7.554507255554199


  9%|██████▋                                                                   | 9004/100000 [05:47<1:21:57, 18.51it/s]

Training loss at epoch 9000: 4.22291898727417 and val loss: 6.8184027671813965


 10%|███████▎                                                                 | 10003/100000 [06:42<1:20:00, 18.75it/s]

Training loss at epoch 10000: 3.1615889072418213 and val loss: 6.3672685623168945


 11%|████████                                                                 | 11004/100000 [07:36<1:17:27, 19.15it/s]

Training loss at epoch 11000: 3.3412187099456787 and val loss: 6.090548515319824


 12%|████████▊                                                                | 12003/100000 [08:30<1:23:35, 17.54it/s]

Training loss at epoch 12000: 2.818420886993408 and val loss: 5.576244354248047


 13%|█████████▍                                                               | 13004/100000 [09:32<1:23:25, 17.38it/s]

Training loss at epoch 13000: 3.5797410011291504 and val loss: 5.684665679931641


 14%|██████████▏                                                              | 14003/100000 [10:27<1:14:25, 19.26it/s]

Training loss at epoch 14000: 2.705777883529663 and val loss: 5.261282444000244


 15%|██████████▉                                                              | 15003/100000 [11:27<1:14:48, 18.94it/s]

Training loss at epoch 15000: 2.6546761989593506 and val loss: 5.139700889587402


 16%|███████████▋                                                             | 16003/100000 [12:21<1:14:29, 18.79it/s]

Training loss at epoch 16000: 2.3231022357940674 and val loss: 5.120970249176025


 17%|████████████▍                                                            | 17004/100000 [13:15<1:16:33, 18.07it/s]

Training loss at epoch 17000: 2.1677234172821045 and val loss: 4.779642581939697


 18%|█████████████▏                                                           | 18003/100000 [14:09<1:16:40, 17.82it/s]

Training loss at epoch 18000: 2.3870460987091064 and val loss: 4.952700614929199


 19%|█████████████▊                                                           | 19004/100000 [15:05<1:30:33, 14.91it/s]

Training loss at epoch 19000: 2.2482876777648926 and val loss: 4.978577136993408


 20%|██████████████▌                                                          | 20004/100000 [16:00<1:12:52, 18.30it/s]

Training loss at epoch 20000: 1.956526756286621 and val loss: 4.778539657592773


 21%|███████████████▎                                                         | 21004/100000 [16:56<1:13:34, 17.90it/s]

Training loss at epoch 21000: 2.5560765266418457 and val loss: 4.804751396179199


 22%|████████████████                                                         | 22004/100000 [17:52<1:11:24, 18.21it/s]

Training loss at epoch 22000: 1.9486747980117798 and val loss: 4.805435657501221


 23%|████████████████▊                                                        | 23004/100000 [18:48<1:09:57, 18.34it/s]

Training loss at epoch 23000: 1.7942689657211304 and val loss: 4.588682651519775


 24%|█████████████████▌                                                       | 24003/100000 [19:45<1:13:08, 17.32it/s]

Training loss at epoch 24000: 1.5102074146270752 and val loss: 4.672370433807373


 25%|██████████████████▎                                                      | 25003/100000 [20:41<1:11:35, 17.46it/s]

Training loss at epoch 25000: 1.9284988641738892 and val loss: 4.703434944152832


 26%|██████████████████▉                                                      | 26004/100000 [21:38<1:12:51, 16.93it/s]

Training loss at epoch 26000: 1.8445590734481812 and val loss: 4.561686992645264


 27%|███████████████████▋                                                     | 27004/100000 [22:34<1:03:25, 19.18it/s]

Training loss at epoch 27000: 1.8021986484527588 and val loss: 4.895177364349365


 28%|████████████████████▍                                                    | 28004/100000 [23:32<1:11:21, 16.81it/s]

Training loss at epoch 28000: 1.9799977540969849 and val loss: 4.842864990234375


 29%|█████████████████████▏                                                   | 29003/100000 [24:29<1:07:17, 17.59it/s]

Training loss at epoch 29000: 2.303818702697754 and val loss: 4.768463134765625


 30%|█████████████████████▉                                                   | 30003/100000 [25:26<1:13:37, 15.84it/s]

Training loss at epoch 30000: 2.485551118850708 and val loss: 4.923240661621094


 31%|██████████████████████▋                                                  | 31004/100000 [26:31<1:08:11, 16.86it/s]

Training loss at epoch 31000: 1.8411225080490112 and val loss: 4.579442977905273


 32%|███████████████████████▎                                                 | 32004/100000 [27:28<1:03:41, 17.79it/s]

Training loss at epoch 32000: 1.802718162536621 and val loss: 4.709217548370361


 33%|████████████████████████                                                 | 33004/100000 [28:25<1:01:59, 18.01it/s]

Training loss at epoch 33000: 1.1629706621170044 and val loss: 4.559811592102051


 34%|████████████████████████▊                                                | 34004/100000 [29:29<1:08:44, 16.00it/s]

Training loss at epoch 34000: 1.3807746171951294 and val loss: 4.623510360717773


 35%|█████████████████████████▌                                               | 35003/100000 [30:27<1:00:54, 17.78it/s]

Training loss at epoch 35000: 1.849115014076233 and val loss: 4.771284580230713


 36%|██████████████████████████▎                                              | 36003/100000 [31:25<1:00:47, 17.55it/s]

Training loss at epoch 36000: 1.3057539463043213 and val loss: 4.475335121154785


 37%|███████████████████████████                                              | 37003/100000 [32:22<1:00:46, 17.28it/s]

Training loss at epoch 37000: 1.4687808752059937 and val loss: 4.683213233947754


 38%|████████████████████████████▌                                              | 38003/100000 [33:19<58:17, 17.72it/s]

Training loss at epoch 38000: 1.3363631963729858 and val loss: 4.72328519821167


 39%|█████████████████████████████▎                                             | 39004/100000 [34:18<59:39, 17.04it/s]

Training loss at epoch 39000: 1.461684226989746 and val loss: 4.799561023712158


 40%|██████████████████████████████                                             | 40003/100000 [35:16<59:37, 16.77it/s]

Training loss at epoch 40000: 1.577430248260498 and val loss: 4.544408321380615


 41%|██████████████████████████████▊                                            | 41003/100000 [36:15<59:32, 16.51it/s]

Training loss at epoch 41000: 1.9686402082443237 and val loss: 4.991209506988525


 42%|███████████████████████████████▌                                           | 42004/100000 [37:13<55:42, 17.35it/s]

Training loss at epoch 42000: 1.3272322416305542 and val loss: 4.686424732208252


 43%|███████████████████████████████▍                                         | 43003/100000 [38:11<1:01:31, 15.44it/s]

Training loss at epoch 43000: 1.8231984376907349 and val loss: 4.7533440589904785


 44%|█████████████████████████████████                                          | 44003/100000 [39:09<54:18, 17.18it/s]

Training loss at epoch 44000: 2.3594634532928467 and val loss: 4.935060977935791


 45%|█████████████████████████████████▊                                         | 45003/100000 [40:07<52:36, 17.42it/s]

Training loss at epoch 45000: 1.5082932710647583 and val loss: 4.860526084899902


 46%|██████████████████████████████████▌                                        | 46004/100000 [41:05<52:56, 17.00it/s]

Training loss at epoch 46000: 1.7284197807312012 and val loss: 4.780921459197998


 47%|███████████████████████████████████▎                                       | 47004/100000 [42:02<51:50, 17.04it/s]

Training loss at epoch 47000: 1.1240442991256714 and val loss: 4.640690803527832


 48%|████████████████████████████████████                                       | 48003/100000 [43:00<49:42, 17.44it/s]

Training loss at epoch 48000: 2.0370357036590576 and val loss: 4.623793601989746


 49%|████████████████████████████████████▊                                      | 49004/100000 [43:59<48:57, 17.36it/s]

Training loss at epoch 49000: 2.7412874698638916 and val loss: 4.773900508880615


 50%|█████████████████████████████████████▌                                     | 50003/100000 [44:58<49:58, 16.67it/s]

Training loss at epoch 50000: 1.490937352180481 and val loss: 4.595275402069092


 51%|██████████████████████████████████████▎                                    | 51004/100000 [45:57<47:41, 17.12it/s]

Training loss at epoch 51000: 1.953653335571289 and val loss: 4.5665693283081055


 52%|███████████████████████████████████████                                    | 52004/100000 [46:55<46:45, 17.11it/s]

Training loss at epoch 52000: 2.6906049251556396 and val loss: 4.830073356628418


 53%|███████████████████████████████████████▊                                   | 53004/100000 [47:53<43:55, 17.83it/s]

Training loss at epoch 53000: 1.4513605833053589 and val loss: 4.4917144775390625


 54%|████████████████████████████████████████▌                                  | 54003/100000 [48:52<45:46, 16.75it/s]

Training loss at epoch 54000: 1.202399730682373 and val loss: 4.560072422027588


 55%|█████████████████████████████████████████▎                                 | 55003/100000 [49:50<43:28, 17.25it/s]

Training loss at epoch 55000: 1.1516597270965576 and val loss: 4.619626522064209


 56%|██████████████████████████████████████████                                 | 56003/100000 [50:48<41:32, 17.65it/s]

Training loss at epoch 56000: 2.3994288444519043 and val loss: 4.921635150909424


 57%|██████████████████████████████████████████▊                                | 57003/100000 [51:46<40:52, 17.53it/s]

Training loss at epoch 57000: 1.4790089130401611 and val loss: 4.529455184936523


 58%|███████████████████████████████████████████▌                               | 58003/100000 [52:45<40:38, 17.22it/s]

Training loss at epoch 58000: 1.1086331605911255 and val loss: 4.750185012817383


 59%|████████████████████████████████████████████▎                              | 59003/100000 [53:44<40:53, 16.71it/s]

Training loss at epoch 59000: 1.036939263343811 and val loss: 4.540361404418945


 60%|█████████████████████████████████████████████                              | 60003/100000 [54:44<39:08, 17.03it/s]

Training loss at epoch 60000: 1.3680262565612793 and val loss: 4.563923358917236


 61%|█████████████████████████████████████████████▊                             | 61003/100000 [55:43<37:57, 17.12it/s]

Training loss at epoch 61000: 1.3333467245101929 and val loss: 4.55359411239624


 62%|██████████████████████████████████████████████▌                            | 62003/100000 [56:43<36:12, 17.49it/s]

Training loss at epoch 62000: 1.1844004392623901 and val loss: 4.602910041809082


 63%|███████████████████████████████████████████████▎                           | 63003/100000 [57:42<38:27, 16.03it/s]

Training loss at epoch 63000: 1.734114408493042 and val loss: 4.590167045593262


 64%|████████████████████████████████████████████████                           | 64003/100000 [58:41<34:34, 17.36it/s]

Training loss at epoch 64000: 1.9146438837051392 and val loss: 4.248723030090332


 65%|████████████████████████████████████████████████▊                          | 65003/100000 [59:39<32:45, 17.81it/s]

Training loss at epoch 65000: 1.3001633882522583 and val loss: 4.566241264343262


 66%|████████████████████████████████████████████████▏                        | 66003/100000 [1:00:38<33:21, 16.99it/s]

Training loss at epoch 66000: 1.027006983757019 and val loss: 4.582676410675049


 67%|████████████████████████████████████████████████▉                        | 67003/100000 [1:01:37<32:47, 16.77it/s]

Training loss at epoch 67000: 0.9698165655136108 and val loss: 4.543423175811768


 68%|█████████████████████████████████████████████████▋                       | 68004/100000 [1:02:36<30:03, 17.75it/s]

Training loss at epoch 68000: 0.9069410562515259 and val loss: 4.558467388153076


 69%|██████████████████████████████████████████████████▎                      | 69003/100000 [1:03:35<28:06, 18.38it/s]

Training loss at epoch 69000: 1.0358660221099854 and val loss: 4.469261646270752


 70%|███████████████████████████████████████████████████                      | 70003/100000 [1:04:33<30:25, 16.43it/s]

Training loss at epoch 70000: 11.714941024780273 and val loss: 9.102010726928711


 71%|███████████████████████████████████████████████████▊                     | 71004/100000 [1:05:32<30:08, 16.03it/s]

Training loss at epoch 71000: 1.0393985509872437 and val loss: 4.534965515136719


 72%|████████████████████████████████████████████████████▌                    | 72005/100000 [1:06:31<26:24, 17.67it/s]

Training loss at epoch 72000: 1.2969074249267578 and val loss: 4.845706462860107


 73%|█████████████████████████████████████████████████████▎                   | 73003/100000 [1:07:31<28:45, 15.65it/s]

Training loss at epoch 73000: 0.9955936670303345 and val loss: 4.498397350311279


 74%|██████████████████████████████████████████████████████                   | 74004/100000 [1:08:31<24:35, 17.62it/s]

Training loss at epoch 74000: 0.9107600450515747 and val loss: 4.482819080352783


 75%|██████████████████████████████████████████████████████▊                  | 75004/100000 [1:09:29<24:27, 17.03it/s]

Training loss at epoch 75000: 1.1039457321166992 and val loss: 4.443784713745117


 76%|███████████████████████████████████████████████████████▍                 | 76004/100000 [1:10:28<23:59, 16.67it/s]

Training loss at epoch 76000: 1.0018895864486694 and val loss: 4.569454669952393


 77%|████████████████████████████████████████████████████████▏                | 77004/100000 [1:11:26<22:12, 17.26it/s]

Training loss at epoch 77000: 0.9800419807434082 and val loss: 4.522939682006836


 78%|████████████████████████████████████████████████████████▉                | 78004/100000 [1:12:24<21:27, 17.08it/s]

Training loss at epoch 78000: 0.814008355140686 and val loss: 4.5576348304748535


 79%|█████████████████████████████████████████████████████████▋               | 79004/100000 [1:13:23<20:13, 17.30it/s]

Training loss at epoch 79000: 1.2241861820220947 and val loss: 4.68166446685791


 80%|██████████████████████████████████████████████████████████▍              | 80004/100000 [1:14:22<19:44, 16.88it/s]

Training loss at epoch 80000: 1.0740965604782104 and val loss: 4.499469757080078


 81%|███████████████████████████████████████████████████████████▏             | 81004/100000 [1:15:20<18:34, 17.04it/s]

Training loss at epoch 81000: 1.4781028032302856 and val loss: 4.871643543243408


 82%|███████████████████████████████████████████████████████████▊             | 82004/100000 [1:16:20<17:06, 17.53it/s]

Training loss at epoch 82000: 0.978922426700592 and val loss: 4.549154758453369


 83%|████████████████████████████████████████████████████████████▌            | 83004/100000 [1:17:18<17:24, 16.27it/s]

Training loss at epoch 83000: 0.8244014382362366 and val loss: 4.45994758605957


 84%|█████████████████████████████████████████████████████████████▎           | 84004/100000 [1:18:18<14:57, 17.83it/s]

Training loss at epoch 84000: 2.9063103199005127 and val loss: 4.797712326049805


 85%|██████████████████████████████████████████████████████████████           | 85004/100000 [1:19:17<14:52, 16.80it/s]

Training loss at epoch 85000: 6.191977500915527 and val loss: 4.940947532653809


 86%|██████████████████████████████████████████████████████████████▊          | 86005/100000 [1:20:16<12:42, 18.36it/s]

Training loss at epoch 86000: 1.9599717855453491 and val loss: 5.024346351623535


 87%|███████████████████████████████████████████████████████████████▌         | 87004/100000 [1:21:14<12:19, 17.57it/s]

Training loss at epoch 87000: 1.0706751346588135 and val loss: 4.4179158210754395


 88%|████████████████████████████████████████████████████████████████▏        | 88004/100000 [1:22:12<12:29, 16.00it/s]

Training loss at epoch 88000: 0.8056506514549255 and val loss: 4.6988844871521


 89%|████████████████████████████████████████████████████████████████▉        | 89003/100000 [1:23:11<10:12, 17.96it/s]

Training loss at epoch 89000: 1.7372424602508545 and val loss: 4.714521884918213


 90%|█████████████████████████████████████████████████████████████████▋       | 90003/100000 [1:24:11<09:38, 17.29it/s]

Training loss at epoch 90000: 0.9398044347763062 and val loss: 4.813872814178467


 91%|██████████████████████████████████████████████████████████████████▍      | 91003/100000 [1:25:15<09:01, 16.62it/s]

Training loss at epoch 91000: 1.335444688796997 and val loss: 4.678410053253174


 92%|███████████████████████████████████████████████████████████████████▏     | 92003/100000 [1:26:17<08:19, 16.02it/s]

Training loss at epoch 92000: 1.1044491529464722 and val loss: 4.770938873291016


 93%|███████████████████████████████████████████████████████████████████▉     | 93003/100000 [1:27:19<06:55, 16.83it/s]

Training loss at epoch 93000: 2.502980947494507 and val loss: 4.855745315551758


 94%|████████████████████████████████████████████████████████████████████▌    | 94003/100000 [1:28:21<06:06, 16.35it/s]

Training loss at epoch 94000: 1.3935315608978271 and val loss: 4.564259052276611


 95%|█████████████████████████████████████████████████████████████████████▎   | 95003/100000 [1:29:25<04:56, 16.87it/s]

Training loss at epoch 95000: 0.9227527976036072 and val loss: 4.65894889831543


 96%|██████████████████████████████████████████████████████████████████████   | 96003/100000 [1:30:28<03:58, 16.77it/s]

Training loss at epoch 96000: 1.1871106624603271 and val loss: 4.43036413192749


 97%|██████████████████████████████████████████████████████████████████████▊  | 97003/100000 [1:31:30<03:12, 15.56it/s]

Training loss at epoch 97000: 1.251763939857483 and val loss: 4.5624470710754395


 98%|███████████████████████████████████████████████████████████████████████▌ | 98003/100000 [1:32:32<02:04, 16.04it/s]

Training loss at epoch 98000: 1.2496914863586426 and val loss: 4.864218711853027


 99%|████████████████████████████████████████████████████████████████████████▎| 99003/100000 [1:33:35<01:00, 16.56it/s]

Training loss at epoch 99000: 1.0494455099105835 and val loss: 4.554513454437256


100%|████████████████████████████████████████████████████████████████████████| 100000/100000 [1:34:37<00:00, 17.61it/s]


In [None]:
# See results of training
# Load the TensorBoard notebook extension (if an earlier cell already loaded it,
# IPython only reports that it is already loaded; nothing is re-initialized)
%load_ext tensorboard
# Show TensorBoard inline for runs/, the parent directory SummaryWriter() writes to
# when no log_dir is given, so every experiment's run is listed side by side
%tensorboard --logdir=runs

In [30]:
### TRAIN / VAL / TEST SETS ###
# Compare model_3 with the naive-forecast baseline on every split.
# The reported "loss" is an RMSE: sqrt of mse_loss, which (default reduction)
# averages the squared error over every element of the target tensor.
# The loop replaces three copy-pasted blocks; it visits the splits in the same
# order, so after it finishes y_pred / loss / y_pred_naive / naive_loss hold the
# TEST-set values, exactly as before.
eval_splits = {
    'train': (X_train, y_train),
    'val': (X_val, y_val),
    'test': (X_test, y_test),
}

model_3.eval()
for split_name, (X_split, y_split) in eval_splits.items():
    # Get model's predictions of NN (no autograd tracking needed for evaluation)
    with torch.inference_mode():
        y_pred = model_3(X_split)

    # Calculate total (mean) error of the NN on this split
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y_split))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Get naive forecast predictions and loss on the same split
    # (horizon must match the horizon used when the windows were built)
    y_pred_naive = naive_forecast(X_split, horizon = 10)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y_split))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 1.1314057111740112
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 4.73565149307251
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 6.44343900680542
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 4 - same as Experiment 1 but with a GRU instead of an LSTM

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Single GRU layer with hidden dim = 32 + Linear Layer
* Time not included in prediction

In [32]:
# Instantiate model, optimizer, loss function and TensorBoard writer
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 32
OUTPUT_DIM = 3
NUM_LAYERS = 1

model_4 = GRUModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(params = model_4.parameters())
loss_fn = torch.nn.functional.mse_loss
# Tag the run directory (runs/<datetime>_<host>_model_4) so this experiment can be
# told apart from the other runs when viewing --logdir=runs in TensorBoard
writer = SummaryWriter(comment = '_model_4')

# Display model's summary
summary(model_4)

Layer (type:depth-idx)                   Param #
GRUModel                                 --
├─GRU: 1-1                               3,552
├─Linear: 1-2                            99
Total params: 3,651
Trainable params: 3,651
Non-trainable params: 0

In [33]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_4, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close the writer even if training is interrupted (e.g. KeyboardInterrupt during
    # this long run), so pending TensorBoard events are written and the file released
    writer.close()
# Save the whole model object (not just its state_dict); only reached if training finished
torch.save(model_4, 'model_4.pt')

  0%|                                                                              | 17/100000 [00:00<17:51, 93.29it/s]

Training loss at epoch 0: 127.5784912109375 and val loss: 128.94676208496094


  1%|▊                                                                           | 1007/100000 [00:18<36:32, 45.14it/s]

Training loss at epoch 1000: 112.67626190185547 and val loss: 111.46434020996094


  2%|█▌                                                                          | 2008/100000 [00:37<28:47, 56.71it/s]

Training loss at epoch 2000: 95.45538330078125 and val loss: 98.18897247314453


  3%|██▎                                                                         | 3010/100000 [00:54<27:41, 58.36it/s]

Training loss at epoch 3000: 88.27596282958984 and val loss: 87.7164306640625


  4%|███                                                                         | 4002/100000 [01:12<50:37, 31.61it/s]

Training loss at epoch 4000: 65.14952087402344 and val loss: 78.46461486816406


  5%|███▊                                                                        | 5009/100000 [01:32<34:40, 45.66it/s]

Training loss at epoch 5000: 66.86516571044922 and val loss: 69.8404541015625


  6%|████▌                                                                       | 6009/100000 [01:50<33:16, 47.08it/s]

Training loss at epoch 6000: 63.356689453125 and val loss: 62.264068603515625


  7%|█████▎                                                                      | 7009/100000 [02:09<31:24, 49.34it/s]

Training loss at epoch 7000: 48.2159309387207 and val loss: 56.193424224853516


  8%|██████                                                                      | 8006/100000 [02:30<29:17, 52.34it/s]

Training loss at epoch 8000: 51.31870651245117 and val loss: 51.053375244140625


  9%|██████▊                                                                     | 9013/100000 [02:46<22:49, 66.43it/s]

Training loss at epoch 9000: 38.178489685058594 and val loss: 46.35758972167969


 10%|███████▌                                                                   | 10009/100000 [03:03<24:15, 61.82it/s]

Training loss at epoch 10000: 32.9074592590332 and val loss: 41.92642593383789


 11%|████████▎                                                                  | 11004/100000 [03:30<35:18, 42.02it/s]

Training loss at epoch 11000: 34.872802734375 and val loss: 38.11930847167969


 12%|█████████                                                                  | 12009/100000 [03:54<26:58, 54.38it/s]

Training loss at epoch 12000: 30.490846633911133 and val loss: 34.16361999511719


 13%|█████████▊                                                                 | 13006/100000 [04:24<37:04, 39.10it/s]

Training loss at epoch 13000: 22.12461280822754 and val loss: 31.13693618774414


 14%|██████████▌                                                                | 14009/100000 [04:50<23:53, 60.00it/s]

Training loss at epoch 14000: 18.004993438720703 and val loss: 28.32319450378418


 15%|███████████▎                                                               | 15007/100000 [05:10<25:59, 54.51it/s]

Training loss at epoch 15000: 20.659414291381836 and val loss: 25.684219360351562


 16%|████████████                                                               | 16004/100000 [05:33<54:39, 25.62it/s]

Training loss at epoch 16000: 19.358272552490234 and val loss: 23.838905334472656


 17%|████████████▊                                                              | 17008/100000 [05:59<27:28, 50.36it/s]

Training loss at epoch 17000: 20.967266082763672 and val loss: 22.894681930541992


 18%|█████████████▌                                                             | 18007/100000 [06:22<39:37, 34.49it/s]

Training loss at epoch 18000: 19.613391876220703 and val loss: 20.255918502807617


 19%|██████████████▎                                                            | 19007/100000 [06:42<22:12, 60.79it/s]

Training loss at epoch 19000: 14.979554176330566 and val loss: 18.83750343322754


 20%|███████████████                                                            | 20013/100000 [07:00<22:05, 60.35it/s]

Training loss at epoch 20000: 14.615433692932129 and val loss: 17.8021297454834


 21%|███████████████▊                                                           | 21008/100000 [07:23<29:05, 45.25it/s]

Training loss at epoch 21000: 13.62596321105957 and val loss: 16.964303970336914


 22%|████████████████▌                                                          | 22010/100000 [07:45<24:06, 53.90it/s]

Training loss at epoch 22000: 12.3560791015625 and val loss: 16.561704635620117


 23%|█████████████████▎                                                         | 23006/100000 [08:05<27:31, 46.63it/s]

Training loss at epoch 23000: 12.541886329650879 and val loss: 15.609622955322266


 24%|██████████████████                                                         | 24008/100000 [08:22<21:09, 59.86it/s]

Training loss at epoch 24000: 12.443907737731934 and val loss: 15.66677188873291


 25%|██████████████████▊                                                        | 25007/100000 [08:39<20:27, 61.10it/s]

Training loss at epoch 25000: 11.749882698059082 and val loss: 14.925495147705078


 26%|███████████████████▌                                                       | 26011/100000 [09:00<25:25, 48.51it/s]

Training loss at epoch 26000: 12.890435218811035 and val loss: 14.86113166809082


 27%|████████████████████▎                                                      | 27008/100000 [09:18<27:39, 43.99it/s]

Training loss at epoch 27000: 10.492758750915527 and val loss: 14.236693382263184


 28%|█████████████████████                                                      | 28008/100000 [09:38<19:23, 61.88it/s]

Training loss at epoch 28000: 11.233366012573242 and val loss: 14.407750129699707


 29%|█████████████████████▊                                                     | 29013/100000 [09:56<16:27, 71.92it/s]

Training loss at epoch 29000: 11.53126335144043 and val loss: 14.038284301757812


 30%|██████████████████████▌                                                    | 30008/100000 [10:17<20:31, 56.84it/s]

Training loss at epoch 30000: 11.389464378356934 and val loss: 13.597426414489746


 31%|███████████████████████▎                                                   | 31009/100000 [10:35<21:04, 54.58it/s]

Training loss at epoch 31000: 14.808635711669922 and val loss: 13.924664497375488


 32%|████████████████████████                                                   | 32007/100000 [10:53<23:42, 47.81it/s]

Training loss at epoch 32000: 11.899867057800293 and val loss: 13.303228378295898


 33%|████████████████████████▊                                                  | 33011/100000 [11:12<15:49, 70.54it/s]

Training loss at epoch 33000: 10.274711608886719 and val loss: 12.864304542541504


 34%|█████████████████████████▌                                                 | 34003/100000 [11:31<21:56, 50.13it/s]

Training loss at epoch 34000: 10.172959327697754 and val loss: 12.649398803710938


 35%|██████████████████████████▎                                                | 35012/100000 [11:54<21:38, 50.04it/s]

Training loss at epoch 35000: 9.901692390441895 and val loss: 12.68153190612793


 36%|███████████████████████████                                                | 36007/100000 [12:18<32:49, 32.49it/s]

Training loss at epoch 36000: 10.119041442871094 and val loss: 12.85544204711914


 37%|███████████████████████████▊                                               | 37007/100000 [12:39<19:00, 55.22it/s]

Training loss at epoch 37000: 9.339166641235352 and val loss: 12.626443862915039


 38%|████████████████████████████▌                                              | 38008/100000 [12:57<19:08, 53.99it/s]

Training loss at epoch 38000: 10.410635948181152 and val loss: 12.386277198791504


 39%|█████████████████████████████▎                                             | 39003/100000 [13:15<23:48, 42.71it/s]

Training loss at epoch 39000: 9.63007640838623 and val loss: 12.497380256652832


 40%|██████████████████████████████                                             | 40006/100000 [13:37<16:47, 59.58it/s]

Training loss at epoch 40000: 10.635695457458496 and val loss: 12.707651138305664


 41%|██████████████████████████████▊                                            | 41006/100000 [13:55<18:51, 52.15it/s]

Training loss at epoch 41000: 10.531801223754883 and val loss: 12.02580738067627


 42%|███████████████████████████████▌                                           | 42009/100000 [14:13<14:48, 65.26it/s]

Training loss at epoch 42000: 10.261896133422852 and val loss: 11.977592468261719


 43%|████████████████████████████████▎                                          | 43012/100000 [14:31<18:25, 51.56it/s]

Training loss at epoch 43000: 9.376546859741211 and val loss: 11.859136581420898


 44%|█████████████████████████████████                                          | 44011/100000 [14:50<18:21, 50.83it/s]

Training loss at epoch 44000: 9.860749244689941 and val loss: 11.84019947052002


 45%|█████████████████████████████████▊                                         | 45008/100000 [15:08<17:37, 52.02it/s]

Training loss at epoch 45000: 10.67711353302002 and val loss: 12.2182035446167


 46%|██████████████████████████████████▌                                        | 46010/100000 [15:26<19:19, 46.57it/s]

Training loss at epoch 46000: 10.616949081420898 and val loss: 11.79109001159668


 47%|███████████████████████████████████▎                                       | 47005/100000 [15:45<15:24, 57.34it/s]

Training loss at epoch 47000: 8.726656913757324 and val loss: 11.946609497070312


 48%|████████████████████████████████████                                       | 48008/100000 [16:08<26:28, 32.73it/s]

Training loss at epoch 48000: 10.194705963134766 and val loss: 12.081993103027344


 49%|████████████████████████████████████▊                                      | 49009/100000 [16:30<19:43, 43.07it/s]

Training loss at epoch 49000: 9.141478538513184 and val loss: 12.256695747375488


 50%|█████████████████████████████████████▌                                     | 50003/100000 [16:54<34:16, 24.31it/s]

Training loss at epoch 50000: 8.693004608154297 and val loss: 11.66315746307373


 51%|██████████████████████████████████████▎                                    | 51005/100000 [17:15<13:34, 60.17it/s]

Training loss at epoch 51000: 9.981104850769043 and val loss: 11.649749755859375


 52%|███████████████████████████████████████                                    | 52007/100000 [17:36<14:28, 55.28it/s]

Training loss at epoch 52000: 9.467100143432617 and val loss: 11.374812126159668


 53%|███████████████████████████████████████▊                                   | 53007/100000 [17:59<12:18, 63.62it/s]

Training loss at epoch 53000: 9.240628242492676 and val loss: 11.84437370300293


 54%|████████████████████████████████████████▌                                  | 54011/100000 [18:17<13:30, 56.75it/s]

Training loss at epoch 54000: 9.110916137695312 and val loss: 11.514530181884766


 55%|█████████████████████████████████████████▎                                 | 55005/100000 [18:41<14:15, 52.60it/s]

Training loss at epoch 55000: 9.05030632019043 and val loss: 11.40449047088623


 56%|██████████████████████████████████████████                                 | 56009/100000 [19:03<15:11, 48.26it/s]

Training loss at epoch 56000: 9.16911792755127 and val loss: 12.078532218933105


 57%|██████████████████████████████████████████▊                                | 57007/100000 [19:23<13:39, 52.44it/s]

Training loss at epoch 57000: 9.184112548828125 and val loss: 11.419984817504883


 58%|███████████████████████████████████████████▌                               | 58006/100000 [19:45<14:54, 46.93it/s]

Training loss at epoch 58000: 9.106924057006836 and val loss: 11.28786563873291


 59%|████████████████████████████████████████████▎                              | 59011/100000 [20:04<12:34, 54.34it/s]

Training loss at epoch 59000: 9.164451599121094 and val loss: 11.522204399108887


 60%|█████████████████████████████████████████████                              | 60009/100000 [20:21<10:21, 64.36it/s]

Training loss at epoch 60000: 8.76401424407959 and val loss: 11.546465873718262


 61%|█████████████████████████████████████████████▊                             | 61011/100000 [20:36<09:45, 66.59it/s]

Training loss at epoch 61000: 8.840961456298828 and val loss: 11.426332473754883


 62%|██████████████████████████████████████████████▌                            | 62010/100000 [20:52<09:18, 67.98it/s]

Training loss at epoch 62000: 9.13962173461914 and val loss: 10.949646949768066


 63%|███████████████████████████████████████████████▎                           | 63010/100000 [21:07<09:37, 64.08it/s]

Training loss at epoch 63000: 9.233820915222168 and val loss: 11.248994827270508


 64%|████████████████████████████████████████████████                           | 64011/100000 [21:22<09:20, 64.16it/s]

Training loss at epoch 64000: 8.747279167175293 and val loss: 11.342301368713379


 65%|████████████████████████████████████████████████▊                          | 65011/100000 [21:38<08:38, 67.47it/s]

Training loss at epoch 65000: 8.591496467590332 and val loss: 11.45141315460205


 66%|█████████████████████████████████████████████████▌                         | 66006/100000 [21:53<08:54, 63.66it/s]

Training loss at epoch 66000: 10.000377655029297 and val loss: 11.401185035705566


 67%|██████████████████████████████████████████████████▎                        | 67013/100000 [22:11<08:46, 62.60it/s]

Training loss at epoch 67000: 9.469978332519531 and val loss: 10.999639511108398


 68%|███████████████████████████████████████████████████                        | 68009/100000 [22:31<09:04, 58.72it/s]

Training loss at epoch 68000: 9.768195152282715 and val loss: 11.372659683227539


 69%|███████████████████████████████████████████████████▊                       | 69011/100000 [22:51<09:47, 52.75it/s]

Training loss at epoch 69000: 9.634114265441895 and val loss: 11.210797309875488


 70%|████████████████████████████████████████████████████▌                      | 70006/100000 [23:15<16:19, 30.63it/s]

Training loss at epoch 70000: 9.77009391784668 and val loss: 11.1659517288208


 71%|█████████████████████████████████████████████████████▎                     | 71005/100000 [23:40<15:04, 32.06it/s]

Training loss at epoch 71000: 8.498138427734375 and val loss: 10.777847290039062


 72%|██████████████████████████████████████████████████████                     | 72006/100000 [24:03<11:07, 41.95it/s]

Training loss at epoch 72000: 9.527027130126953 and val loss: 10.903468132019043


 73%|██████████████████████████████████████████████████████▊                    | 73009/100000 [24:23<08:26, 53.33it/s]

Training loss at epoch 73000: 9.352386474609375 and val loss: 11.157946586608887


 74%|███████████████████████████████████████████████████████▌                   | 74009/100000 [24:45<08:30, 50.91it/s]

Training loss at epoch 74000: 10.263832092285156 and val loss: 11.01993465423584


 75%|████████████████████████████████████████████████████████▎                  | 75009/100000 [25:04<06:16, 66.32it/s]

Training loss at epoch 75000: 8.832751274108887 and val loss: 11.160910606384277


 76%|█████████████████████████████████████████████████████████                  | 76005/100000 [25:21<07:52, 50.77it/s]

Training loss at epoch 76000: 8.957518577575684 and val loss: 10.577003479003906


 77%|█████████████████████████████████████████████████████████▊                 | 77005/100000 [25:42<07:08, 53.61it/s]

Training loss at epoch 77000: 10.080341339111328 and val loss: 10.78417682647705


 78%|██████████████████████████████████████████████████████████▌                | 78006/100000 [26:01<10:30, 34.91it/s]

Training loss at epoch 78000: 9.555720329284668 and val loss: 10.505131721496582


 79%|███████████████████████████████████████████████████████████▎               | 79011/100000 [26:19<04:52, 71.83it/s]

Training loss at epoch 79000: 7.9793500900268555 and val loss: 10.96278190612793


 80%|████████████████████████████████████████████████████████████               | 80009/100000 [26:36<05:44, 58.07it/s]

Training loss at epoch 80000: 9.336925506591797 and val loss: 10.600865364074707


 81%|████████████████████████████████████████████████████████████▊              | 81012/100000 [26:55<05:56, 53.25it/s]

Training loss at epoch 81000: 9.525948524475098 and val loss: 11.05627727508545


 82%|█████████████████████████████████████████████████████████████▌             | 82009/100000 [27:14<05:29, 54.53it/s]

Training loss at epoch 82000: 9.296571731567383 and val loss: 10.762364387512207


 83%|██████████████████████████████████████████████████████████████▎            | 83010/100000 [27:33<04:53, 57.87it/s]

Training loss at epoch 83000: 8.405153274536133 and val loss: 10.74075698852539


 84%|███████████████████████████████████████████████████████████████            | 84007/100000 [27:50<04:21, 61.12it/s]

Training loss at epoch 84000: 8.484602928161621 and val loss: 10.61487102508545


 85%|███████████████████████████████████████████████████████████████▊           | 85005/100000 [28:13<07:32, 33.14it/s]

Training loss at epoch 85000: 8.91569995880127 and val loss: 10.800056457519531


 86%|████████████████████████████████████████████████████████████████▌          | 86008/100000 [28:35<04:29, 51.87it/s]

Training loss at epoch 86000: 8.971465110778809 and val loss: 10.554643630981445


 87%|█████████████████████████████████████████████████████████████████▎         | 87008/100000 [28:54<03:49, 56.57it/s]

Training loss at epoch 87000: 8.25892448425293 and val loss: 10.503390312194824


 88%|██████████████████████████████████████████████████████████████████         | 88009/100000 [29:11<04:06, 48.69it/s]

Training loss at epoch 88000: 7.988763809204102 and val loss: 10.36410140991211


 89%|██████████████████████████████████████████████████████████████████▊        | 89003/100000 [29:32<05:08, 35.68it/s]

Training loss at epoch 89000: 9.745327949523926 and val loss: 11.086992263793945


 90%|███████████████████████████████████████████████████████████████████▌       | 90006/100000 [29:51<02:53, 57.75it/s]

Training loss at epoch 90000: 8.830209732055664 and val loss: 10.545106887817383


 91%|████████████████████████████████████████████████████████████████████▎      | 91009/100000 [30:09<03:14, 46.17it/s]

Training loss at epoch 91000: 8.607414245605469 and val loss: 10.236983299255371


 92%|█████████████████████████████████████████████████████████████████████      | 92009/100000 [30:35<02:20, 56.74it/s]

Training loss at epoch 92000: 8.406912803649902 and val loss: 10.524063110351562


 93%|█████████████████████████████████████████████████████████████████████▊     | 93009/100000 [30:56<02:12, 52.61it/s]

Training loss at epoch 93000: 8.966490745544434 and val loss: 10.3887300491333


 94%|██████████████████████████████████████████████████████████████████████▌    | 94008/100000 [31:15<02:07, 46.98it/s]

Training loss at epoch 94000: 8.618167877197266 and val loss: 10.137338638305664


 95%|███████████████████████████████████████████████████████████████████████▎   | 95012/100000 [31:33<01:20, 62.09it/s]

Training loss at epoch 95000: 9.895516395568848 and val loss: 10.423519134521484


 96%|████████████████████████████████████████████████████████████████████████   | 96010/100000 [31:56<01:17, 51.81it/s]

Training loss at epoch 96000: 8.625792503356934 and val loss: 10.109292030334473


 97%|████████████████████████████████████████████████████████████████████████▊  | 97011/100000 [32:17<00:50, 59.42it/s]

Training loss at epoch 97000: 8.856011390686035 and val loss: 10.208110809326172


 98%|█████████████████████████████████████████████████████████████████████████▌ | 98009/100000 [32:41<01:08, 28.89it/s]

Training loss at epoch 98000: 9.000109672546387 and val loss: 10.532247543334961


 99%|██████████████████████████████████████████████████████████████████████████▎| 99004/100000 [32:57<00:19, 49.99it/s]

Training loss at epoch 99000: 8.235865592956543 and val loss: 10.183182716369629


100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [33:17<00:00, 50.07it/s]


In [34]:
### TRAIN / VAL / TEST SETS ###
# Compare model 4 against the naive-forecast baseline on every split.
# The reported "loss" is RMSE: the square root of torch's mean squared error.

def _rmse(pred, target):
    """Return the root-mean-squared error between `pred` and `target` as a 0-d tensor."""
    return torch.sqrt(torch.nn.functional.mse_loss(pred, target))

# Switch to evaluation mode once; the mode persists across all three splits.
model_4.eval()
# Splits are processed in the original order. After the loop, y_pred, loss,
# y_pred_naive and naive_loss hold the test-set values, just as before.
for split_name, X_split, y_split in (
    ('train', X_train, y_train),
    ('val', X_val, y_val),
    ('test', X_test, y_test),
):
    # NN predictions, computed without autograd tracking
    with torch.inference_mode():
        y_pred = model_4(X_split)
    loss = _rmse(y_pred, y_split)
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Naive baseline (project helper from utils) on the same split.
    # horizon must match the horizon the windows were built with.
    y_pred_naive = naive_forecast(X_split, horizon = 10)
    naive_loss = _rmse(y_pred_naive, y_split)
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 8.59819507598877
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 10.029316902160645
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 11.706341743469238
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 5 - same as Experiment 4 but with a bigger GRU cell

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Single GRU cell with hidden dim = 128 + Linear Layer
* Time not included in prediction

In [35]:
# Experiment 5 set-up: hyperparameters, GRU model, Adam optimizer, MSE loss and a
# new SummaryWriter (default log directory) for this run's TensorBoard logs.
# Only the hidden size differs from experiment 4.
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS = 3, 128, 3, 1

model_5 = GRUModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(model_5.parameters())  # first positional arg is `params`
loss_fn = torch.nn.functional.mse_loss
writer = SummaryWriter()

# Last expression of the cell: torchinfo renders the layer / parameter-count table
summary(model_5)

Layer (type:depth-idx)                   Param #
GRUModel                                 --
├─GRU: 1-1                               51,072
├─Linear: 1-2                            387
Total params: 51,459
Trainable params: 51,459
Non-trainable params: 0

In [36]:
# Train model 5 (GRUModel). `train` is the project helper from utils; it receives
# the validation split and the TensorBoard writer so it can log both losses.
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_5, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close (and thereby flush) the writer even if the long run is interrupted,
    # e.g. by a KeyboardInterrupt, so already-logged TensorBoard events are kept.
    writer.close()
# Save the whole pickled model object, not just its state_dict. Loading it later
# requires the GRUModel class to be importable.
# NOTE(review): recent PyTorch versions default torch.load to weights_only=True,
# which rejects full-model pickles; load with weights_only=False (trusted files only).
torch.save(model_5, 'model_5.pt')

  0%|                                                                              | 14/100000 [00:00<25:37, 65.04it/s]

Training loss at epoch 0: 137.29400634765625 and val loss: 129.130126953125


  1%|▋                                                                         | 1005/100000 [00:47<1:09:02, 23.90it/s]

Training loss at epoch 1000: 70.58674621582031 and val loss: 76.07453918457031


  2%|█▌                                                                          | 2006/100000 [01:39<49:59, 32.67it/s]

Training loss at epoch 2000: 35.15278244018555 and val loss: 50.3569450378418


  3%|██▎                                                                         | 3005/100000 [02:21<55:37, 29.06it/s]

Training loss at epoch 3000: 29.78680419921875 and val loss: 36.23373794555664


  4%|███                                                                         | 4006/100000 [02:52<50:05, 31.94it/s]

Training loss at epoch 4000: 21.65787124633789 and val loss: 26.804433822631836


  5%|███▊                                                                        | 5005/100000 [03:40<49:22, 32.06it/s]

Training loss at epoch 5000: 13.381390571594238 and val loss: 19.844377517700195


  6%|████▌                                                                       | 6004/100000 [04:18<49:34, 31.60it/s]

Training loss at epoch 6000: 11.052547454833984 and val loss: 15.538755416870117


  7%|█████▏                                                                    | 7004/100000 [05:07<1:34:27, 16.41it/s]

Training loss at epoch 7000: 8.33516788482666 and val loss: 12.649450302124023


  8%|█████▉                                                                    | 8003/100000 [06:02<1:30:35, 16.93it/s]

Training loss at epoch 8000: 10.056134223937988 and val loss: 10.398635864257812


  9%|██████▊                                                                     | 9006/100000 [06:42<47:53, 31.67it/s]

Training loss at epoch 9000: 7.555899143218994 and val loss: 9.382755279541016


 10%|███████▎                                                                 | 10003/100000 [07:43<1:22:57, 18.08it/s]

Training loss at epoch 10000: 6.1517815589904785 and val loss: 8.486224174499512


 11%|████████                                                                 | 11003/100000 [08:39<1:51:59, 13.24it/s]

Training loss at epoch 11000: 6.910723686218262 and val loss: 7.977690696716309


 12%|████████▊                                                                | 12004/100000 [09:25<1:33:13, 15.73it/s]

Training loss at epoch 12000: 6.059820175170898 and val loss: 7.793821334838867


 13%|█████████▊                                                                 | 13007/100000 [10:03<45:14, 32.05it/s]

Training loss at epoch 13000: 6.259528160095215 and val loss: 7.624335289001465


 14%|██████████▌                                                                | 14005/100000 [10:55<54:40, 26.21it/s]

Training loss at epoch 14000: 6.35766077041626 and val loss: 7.56045389175415


 15%|███████████▎                                                               | 15004/100000 [11:49<45:22, 31.23it/s]

Training loss at epoch 15000: 5.1195573806762695 and val loss: 6.898661136627197


 16%|████████████                                                               | 16008/100000 [12:38<41:59, 33.33it/s]

Training loss at epoch 16000: 5.81182861328125 and val loss: 6.874924182891846


 17%|████████████▊                                                              | 17007/100000 [13:10<36:18, 38.10it/s]

Training loss at epoch 17000: 5.321504592895508 and val loss: 6.688002109527588


 18%|█████████████▌                                                             | 18006/100000 [13:37<37:58, 35.99it/s]

Training loss at epoch 18000: 5.934452056884766 and val loss: 6.534370422363281


 19%|██████████████▎                                                            | 19004/100000 [14:04<36:34, 36.92it/s]

Training loss at epoch 19000: 4.619921684265137 and val loss: 6.516827583312988


 20%|███████████████                                                            | 20004/100000 [14:31<39:31, 33.73it/s]

Training loss at epoch 20000: 4.844102382659912 and val loss: 6.140953063964844


 21%|███████████████▊                                                           | 21008/100000 [14:58<34:48, 37.82it/s]

Training loss at epoch 21000: 5.654003620147705 and val loss: 6.564052104949951


 22%|████████████████▌                                                          | 22006/100000 [15:25<33:27, 38.86it/s]

Training loss at epoch 22000: 4.4343581199646 and val loss: 6.070849418640137


 23%|█████████████████▎                                                         | 23007/100000 [15:52<35:41, 35.96it/s]

Training loss at epoch 23000: 4.851198673248291 and val loss: 5.996455192565918


 24%|██████████████████                                                         | 24005/100000 [16:19<36:25, 34.77it/s]

Training loss at epoch 24000: 4.653863906860352 and val loss: 5.991632461547852


 25%|██████████████████▊                                                        | 25006/100000 [16:46<35:55, 34.79it/s]

Training loss at epoch 25000: 5.07378625869751 and val loss: 6.317183017730713


 26%|███████████████████▌                                                       | 26008/100000 [17:14<33:24, 36.91it/s]

Training loss at epoch 26000: 5.3818488121032715 and val loss: 5.899106025695801


 27%|████████████████████▎                                                      | 27005/100000 [17:41<34:29, 35.27it/s]

Training loss at epoch 27000: 4.275207996368408 and val loss: 5.96077823638916


 28%|█████████████████████                                                      | 28005/100000 [18:09<36:54, 32.51it/s]

Training loss at epoch 28000: 4.9949846267700195 and val loss: 5.710901260375977


 29%|█████████████████████▊                                                     | 29006/100000 [18:36<33:50, 34.97it/s]

Training loss at epoch 29000: 4.760914325714111 and val loss: 5.53432559967041


 30%|██████████████████████▌                                                    | 30006/100000 [19:03<32:19, 36.09it/s]

Training loss at epoch 30000: 4.795402526855469 and val loss: 6.5117716789245605


 31%|███████████████████████▎                                                   | 31005/100000 [19:31<31:00, 37.08it/s]

Training loss at epoch 31000: 4.920835494995117 and val loss: 5.75080680847168


 32%|████████████████████████                                                   | 32007/100000 [19:58<30:23, 37.28it/s]

Training loss at epoch 32000: 4.609650611877441 and val loss: 5.54490852355957


 33%|████████████████████████▊                                                  | 33006/100000 [20:25<30:17, 36.85it/s]

Training loss at epoch 33000: 4.622857093811035 and val loss: 5.633392333984375


 34%|█████████████████████████▌                                                 | 34005/100000 [20:53<27:39, 39.76it/s]

Training loss at epoch 34000: 4.8542680740356445 and val loss: 5.7076096534729


 35%|██████████████████████████▎                                                | 35006/100000 [21:20<28:29, 38.02it/s]

Training loss at epoch 35000: 3.5985114574432373 and val loss: 5.485127925872803


 36%|███████████████████████████                                                | 36006/100000 [21:47<27:40, 38.54it/s]

Training loss at epoch 36000: 4.456745624542236 and val loss: 5.753562927246094


 37%|███████████████████████████▊                                               | 37006/100000 [22:14<27:57, 37.56it/s]

Training loss at epoch 37000: 4.0489277839660645 and val loss: 5.499195098876953


 38%|████████████████████████████▌                                              | 38005/100000 [22:42<25:50, 39.98it/s]

Training loss at epoch 38000: 4.6321539878845215 and val loss: 5.438424110412598


 39%|█████████████████████████████▎                                             | 39005/100000 [23:10<27:33, 36.88it/s]

Training loss at epoch 39000: 3.8436520099639893 and val loss: 5.40513277053833


 40%|██████████████████████████████                                             | 40004/100000 [23:37<27:51, 35.88it/s]

Training loss at epoch 40000: 4.461236000061035 and val loss: 5.357844829559326


 41%|██████████████████████████████▊                                            | 41007/100000 [24:06<27:53, 35.26it/s]

Training loss at epoch 41000: 4.2664079666137695 and val loss: 5.435275077819824


 42%|███████████████████████████████▌                                           | 42009/100000 [24:34<25:56, 37.27it/s]

Training loss at epoch 42000: 3.8447277545928955 and val loss: 5.432535648345947


 43%|████████████████████████████████▎                                          | 43006/100000 [25:02<27:04, 35.09it/s]

Training loss at epoch 43000: 4.276347637176514 and val loss: 5.265191078186035


 44%|█████████████████████████████████                                          | 44007/100000 [25:31<25:35, 36.46it/s]

Training loss at epoch 44000: 4.087008476257324 and val loss: 5.456154823303223


 45%|█████████████████████████████████▊                                         | 45007/100000 [25:59<25:18, 36.22it/s]

Training loss at epoch 45000: 3.7871735095977783 and val loss: 5.5637407302856445


 46%|██████████████████████████████████▌                                        | 46005/100000 [26:27<24:31, 36.69it/s]

Training loss at epoch 46000: 3.7518064975738525 and val loss: 5.311429023742676


 47%|███████████████████████████████████▎                                       | 47005/100000 [26:55<26:44, 33.04it/s]

Training loss at epoch 47000: 4.237796306610107 and val loss: 5.271620273590088


 48%|████████████████████████████████████                                       | 48006/100000 [27:23<23:52, 36.29it/s]

Training loss at epoch 48000: 4.520655155181885 and val loss: 5.385852336883545


 49%|████████████████████████████████████▊                                      | 49008/100000 [27:51<24:31, 34.65it/s]

Training loss at epoch 49000: 4.149858474731445 and val loss: 5.353636264801025


 50%|█████████████████████████████████████▌                                     | 50004/100000 [28:19<26:59, 30.87it/s]

Training loss at epoch 50000: 3.8031740188598633 and val loss: 5.202701091766357


 51%|██████████████████████████████████████▎                                    | 51004/100000 [28:47<22:23, 36.48it/s]

Training loss at epoch 51000: 3.9201226234436035 and val loss: 5.2246785163879395


 52%|███████████████████████████████████████                                    | 52006/100000 [29:16<22:06, 36.17it/s]

Training loss at epoch 52000: 3.3305344581604004 and val loss: 5.180247783660889


 53%|███████████████████████████████████████▊                                   | 53005/100000 [29:44<22:36, 34.65it/s]

Training loss at epoch 53000: 3.77300763130188 and val loss: 5.1631364822387695


 54%|████████████████████████████████████████▌                                  | 54007/100000 [30:13<21:12, 36.14it/s]

Training loss at epoch 54000: 4.734382152557373 and val loss: 5.398687839508057


 55%|█████████████████████████████████████████▎                                 | 55007/100000 [30:41<20:55, 35.83it/s]

Training loss at epoch 55000: 4.0861005783081055 and val loss: 5.315264701843262


 56%|██████████████████████████████████████████                                 | 56006/100000 [31:10<21:22, 34.29it/s]

Training loss at epoch 56000: 3.4702680110931396 and val loss: 5.208903789520264


 57%|██████████████████████████████████████████▊                                | 57005/100000 [31:39<19:55, 35.98it/s]

Training loss at epoch 57000: 3.8807997703552246 and val loss: 5.206263065338135


 58%|███████████████████████████████████████████▌                               | 58005/100000 [32:07<20:53, 33.50it/s]

Training loss at epoch 58000: 3.735567569732666 and val loss: 5.2640604972839355


 59%|████████████████████████████████████████████▎                              | 59007/100000 [32:36<21:33, 31.69it/s]

Training loss at epoch 59000: 4.670029640197754 and val loss: 6.229090213775635


 60%|█████████████████████████████████████████████                              | 60008/100000 [33:06<18:39, 35.73it/s]

Training loss at epoch 60000: 3.9200961589813232 and val loss: 5.3692731857299805


 61%|█████████████████████████████████████████████▊                             | 61006/100000 [33:35<17:41, 36.74it/s]

Training loss at epoch 61000: 3.6724047660827637 and val loss: 5.002171993255615


 62%|██████████████████████████████████████████████▌                            | 62008/100000 [34:05<18:41, 33.88it/s]

Training loss at epoch 62000: 3.656468391418457 and val loss: 5.17384672164917


 63%|███████████████████████████████████████████████▎                           | 63006/100000 [34:34<17:31, 35.18it/s]

Training loss at epoch 63000: 3.1322202682495117 and val loss: 5.162734508514404


 64%|████████████████████████████████████████████████                           | 64005/100000 [35:03<16:29, 36.38it/s]

Training loss at epoch 64000: 3.6055264472961426 and val loss: 5.0714569091796875


 65%|████████████████████████████████████████████████▊                          | 65004/100000 [35:32<17:46, 32.81it/s]

Training loss at epoch 65000: 3.875314474105835 and val loss: 5.4478840827941895


 66%|█████████████████████████████████████████████████▌                         | 66008/100000 [36:01<15:51, 35.72it/s]

Training loss at epoch 66000: 3.645054340362549 and val loss: 5.271327972412109


 67%|██████████████████████████████████████████████████▎                        | 67005/100000 [36:30<15:27, 35.58it/s]

Training loss at epoch 67000: 3.486623764038086 and val loss: 5.107875823974609


 68%|███████████████████████████████████████████████████                        | 68004/100000 [36:59<17:27, 30.53it/s]

Training loss at epoch 68000: 3.409496545791626 and val loss: 5.135953903198242


 69%|███████████████████████████████████████████████████▊                       | 69004/100000 [37:28<15:09, 34.07it/s]

Training loss at epoch 69000: 3.679453134536743 and val loss: 5.063287734985352


 70%|████████████████████████████████████████████████████▌                      | 70007/100000 [37:58<13:53, 36.00it/s]

Training loss at epoch 70000: 3.5668981075286865 and val loss: 5.0621232986450195


 71%|█████████████████████████████████████████████████████▎                     | 71007/100000 [38:27<13:28, 35.84it/s]

Training loss at epoch 71000: 6.029941082000732 and val loss: 7.029469013214111


 72%|██████████████████████████████████████████████████████                     | 72008/100000 [38:56<13:19, 34.99it/s]

Training loss at epoch 72000: 4.131248950958252 and val loss: 5.207336902618408


 73%|██████████████████████████████████████████████████████▊                    | 73004/100000 [39:24<12:54, 34.87it/s]

Training loss at epoch 73000: 4.005063533782959 and val loss: 5.325002670288086


 74%|███████████████████████████████████████████████████████▌                   | 74008/100000 [39:53<12:08, 35.68it/s]

Training loss at epoch 74000: 3.7243611812591553 and val loss: 5.072598934173584


 75%|████████████████████████████████████████████████████████▎                  | 75007/100000 [40:23<12:43, 32.74it/s]

Training loss at epoch 75000: 3.099313735961914 and val loss: 4.9433064460754395


 76%|█████████████████████████████████████████████████████████                  | 76004/100000 [40:52<11:59, 33.34it/s]

Training loss at epoch 76000: 3.172071695327759 and val loss: 5.101936340332031


 77%|█████████████████████████████████████████████████████████▊                 | 77004/100000 [41:21<11:16, 33.98it/s]

Training loss at epoch 77000: 3.643979787826538 and val loss: 5.131873607635498


 78%|██████████████████████████████████████████████████████████▌                | 78007/100000 [41:51<10:17, 35.63it/s]

Training loss at epoch 78000: 3.502868890762329 and val loss: 5.336219787597656


 79%|███████████████████████████████████████████████████████████▎               | 79004/100000 [42:20<09:46, 35.78it/s]

Training loss at epoch 79000: 3.5021767616271973 and val loss: 5.124990940093994


 80%|████████████████████████████████████████████████████████████               | 80004/100000 [42:50<11:38, 28.65it/s]

Training loss at epoch 80000: 3.6654062271118164 and val loss: 5.275032997131348


 81%|████████████████████████████████████████████████████████████▊              | 81006/100000 [43:19<09:23, 33.68it/s]

Training loss at epoch 81000: 3.5227866172790527 and val loss: 5.139111042022705


 82%|█████████████████████████████████████████████████████████████▌             | 82006/100000 [43:48<09:02, 33.19it/s]

Training loss at epoch 82000: 2.935983657836914 and val loss: 5.156865119934082


 83%|██████████████████████████████████████████████████████████████▎            | 83004/100000 [44:17<08:11, 34.60it/s]

Training loss at epoch 83000: 3.2030115127563477 and val loss: 5.131353855133057


 84%|███████████████████████████████████████████████████████████████            | 84008/100000 [44:47<07:25, 35.89it/s]

Training loss at epoch 84000: 3.4740347862243652 and val loss: 5.242495059967041


 85%|███████████████████████████████████████████████████████████████▊           | 85006/100000 [45:16<07:14, 34.55it/s]

Training loss at epoch 85000: 3.4145970344543457 and val loss: 5.316702842712402


 86%|████████████████████████████████████████████████████████████████▌          | 86005/100000 [45:46<06:43, 34.67it/s]

Training loss at epoch 86000: 3.4403276443481445 and val loss: 5.189527988433838


 87%|█████████████████████████████████████████████████████████████████▎         | 87007/100000 [46:15<06:18, 34.34it/s]

Training loss at epoch 87000: 3.084632396697998 and val loss: 5.158740520477295


 88%|██████████████████████████████████████████████████████████████████         | 88005/100000 [46:45<05:49, 34.32it/s]

Training loss at epoch 88000: 3.6290900707244873 and val loss: 5.0219597816467285


 89%|██████████████████████████████████████████████████████████████████▊        | 89004/100000 [47:14<05:14, 35.00it/s]

Training loss at epoch 89000: 3.456634044647217 and val loss: 5.020110130310059


 90%|███████████████████████████████████████████████████████████████████▌       | 90005/100000 [47:44<04:50, 34.36it/s]

Training loss at epoch 90000: 3.22288179397583 and val loss: 4.955758094787598


 91%|████████████████████████████████████████████████████████████████████▎      | 91006/100000 [48:14<04:22, 34.23it/s]

Training loss at epoch 91000: 4.3998494148254395 and val loss: 5.6250834465026855


 92%|█████████████████████████████████████████████████████████████████████      | 92005/100000 [48:43<03:46, 35.29it/s]

Training loss at epoch 92000: 4.3413472175598145 and val loss: 5.824632167816162


 93%|█████████████████████████████████████████████████████████████████████▊     | 93007/100000 [49:13<03:32, 32.96it/s]

Training loss at epoch 93000: 3.723344326019287 and val loss: 5.146908760070801


 94%|██████████████████████████████████████████████████████████████████████▌    | 94007/100000 [49:43<03:00, 33.27it/s]

Training loss at epoch 94000: 3.422746181488037 and val loss: 5.172412872314453


 95%|███████████████████████████████████████████████████████████████████████▎   | 95005/100000 [50:12<02:23, 34.79it/s]

Training loss at epoch 95000: 3.6725995540618896 and val loss: 5.3782219886779785


 96%|████████████████████████████████████████████████████████████████████████   | 96006/100000 [50:42<01:55, 34.63it/s]

Training loss at epoch 96000: 3.618006467819214 and val loss: 5.163796424865723


 97%|████████████████████████████████████████████████████████████████████████▊  | 97008/100000 [51:12<01:22, 36.21it/s]

Training loss at epoch 97000: 3.071510076522827 and val loss: 5.01775598526001


 98%|█████████████████████████████████████████████████████████████████████████▌ | 98007/100000 [51:42<01:03, 31.14it/s]

Training loss at epoch 98000: 3.3404488563537598 and val loss: 5.058478355407715


 99%|██████████████████████████████████████████████████████████████████████████▎| 99005/100000 [52:12<00:27, 36.52it/s]

Training loss at epoch 99000: 3.8491227626800537 and val loss: 5.0261712074279785


100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [52:41<00:00, 31.63it/s]


In [37]:
### TRAIN / VAL / TEST EVALUATION ###
def evaluate_against_naive(model, X, y, split_name, horizon = 10):
    """Compare a trained model with the naive forecast on one data split.

    Both errors are printed as the square root of the mean squared error
    over every element of the prediction tensor (a single global RMSE).

    Args:
        model: trained torch module; it is switched to eval mode here.
        X: input windows for the split.
        y: targets for the split.
        split_name: label used in the printed messages ('train', 'val', 'test').
        horizon: forecast horizon passed to `naive_forecast` (10 in these experiments).

    Returns:
        (y_pred, loss, y_pred_naive, naive_loss) for the split.
    """
    model.eval()
    # inference_mode disables autograd tracking for a cheaper forward pass
    with torch.inference_mode():
        y_pred = model(X)

    # RMSE of the network's predictions
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Baseline from `naive_forecast` (imported from utils; exact rule not visible here — TODO confirm)
    y_pred_naive = naive_forecast(X, horizon = horizon)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

    return y_pred, loss, y_pred_naive, naive_loss


# The loop ends on the test split, so y_pred / loss / y_pred_naive / naive_loss
# keep holding the test-set results for any later cell that reads them.
for split_name, X_split, y_split in [('train', X_train, y_train),
                                     ('val', X_val, y_val),
                                     ('test', X_test, y_test)]:
    y_pred, loss, y_pred_naive, naive_loss = evaluate_against_naive(model_5, X_split, y_split, split_name)

Total (mean) train loss of NN: 3.298945665359497
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 4.999587059020996
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 5.485226631164551
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 6 - same as exp 5 but with three stacked GRU layers

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* 3-layer (stacked) GRU with hidden dim = 128 + Linear Layer
* Time not included in prediction

In [48]:
# Experiment 6 setup: 3 input/output features (presumably x, y, z — time is excluded),
# 3 stacked GRU layers of width 128, forecasting 10 steps ahead.
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS = 3, 128, 3, 3

model_6 = GRUModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)

# Adam with PyTorch's default hyperparameters (no lr / betas given)
optimizer = torch.optim.Adam(model_6.parameters())

# Training objective: mean squared error
loss_fn = torch.nn.functional.mse_loss

# TensorBoard logger; no log_dir is given, so PyTorch's default runs/ location is used
writer = SummaryWriter()

# Last expression -> rendered as the cell output (layer table with parameter counts)
summary(model_6)

Layer (type:depth-idx)                   Param #
GRUModel                                 --
├─GRU: 1-1                               249,216
├─Linear: 1-2                            387
Total params: 249,603
Trainable params: 249,603
Non-trainable params: 0

In [49]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_6, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Flush and close the TensorBoard writer even if this multi-hour run is
    # interrupted (KeyboardInterrupt) or raises, so logged scalars are not lost.
    writer.close()
# Save model only after training completes, so an interrupted run does not
# overwrite model_6.pt with a partially trained network.
# NOTE: this pickles the whole module object; loading it needs the GRUModel class importable.
torch.save(model_6, 'model_6.pt')

  0%|                                                                             | 5/100000 [00:00<1:14:58, 22.23it/s]

Training loss at epoch 0: 124.44750213623047 and val loss: 129.09237670898438


  1%|▋                                                                         | 1004/100000 [00:41<1:09:49, 23.63it/s]

Training loss at epoch 1000: 67.72344207763672 and val loss: 75.68167877197266


  2%|█▍                                                                        | 2002/100000 [01:28<2:14:43, 12.12it/s]

Training loss at epoch 2000: 59.71498107910156 and val loss: 49.551368713378906


  3%|██▏                                                                       | 3003/100000 [02:46<2:36:10, 10.35it/s]

Training loss at epoch 3000: 38.20257568359375 and val loss: 32.83209228515625


  4%|██▉                                                                       | 4002/100000 [04:19<2:28:56, 10.74it/s]

Training loss at epoch 4000: 18.028390884399414 and val loss: 20.530176162719727


  5%|███▋                                                                      | 5002/100000 [05:58<2:30:51, 10.49it/s]

Training loss at epoch 5000: 11.524444580078125 and val loss: 14.651620864868164


  6%|████▍                                                                     | 6003/100000 [07:32<2:29:49, 10.46it/s]

Training loss at epoch 6000: 4.504770755767822 and val loss: 8.12858772277832


  7%|█████▏                                                                    | 7002/100000 [09:08<2:25:23, 10.66it/s]

Training loss at epoch 7000: 4.364590167999268 and val loss: 6.430820465087891


  8%|█████▉                                                                    | 8002/100000 [10:43<2:28:11, 10.35it/s]

Training loss at epoch 8000: 4.879978656768799 and val loss: 5.692862510681152


  9%|██████▋                                                                   | 9003/100000 [12:21<2:34:53,  9.79it/s]

Training loss at epoch 9000: 4.293766498565674 and val loss: 5.367074012756348


 10%|███████▎                                                                 | 10002/100000 [14:00<2:30:20,  9.98it/s]

Training loss at epoch 10000: 2.4719603061676025 and val loss: 4.92047643661499


 11%|████████                                                                 | 11001/100000 [15:37<2:23:33, 10.33it/s]

Training loss at epoch 11000: 2.6897518634796143 and val loss: 4.726069450378418


 12%|████████▊                                                                | 12003/100000 [17:16<2:17:49, 10.64it/s]

Training loss at epoch 12000: 2.7381067276000977 and val loss: 4.8025712966918945


 13%|█████████▍                                                               | 13002/100000 [18:52<2:19:40, 10.38it/s]

Training loss at epoch 13000: 2.7205886840820312 and val loss: 4.784797191619873


 14%|██████████▏                                                              | 14002/100000 [20:28<2:16:46, 10.48it/s]

Training loss at epoch 14000: 3.2942428588867188 and val loss: 7.392183303833008


 15%|██████████▉                                                              | 15002/100000 [22:05<2:13:28, 10.61it/s]

Training loss at epoch 15000: 2.5478689670562744 and val loss: 4.687853813171387


 16%|███████████▋                                                             | 16002/100000 [23:42<2:11:52, 10.62it/s]

Training loss at epoch 16000: 2.483227014541626 and val loss: 4.452582836151123


 17%|████████████▍                                                            | 17002/100000 [25:22<2:13:08, 10.39it/s]

Training loss at epoch 17000: 2.098027229309082 and val loss: 4.383571624755859


 18%|█████████████▏                                                           | 18003/100000 [27:04<2:11:18, 10.41it/s]

Training loss at epoch 18000: 1.8511873483657837 and val loss: 4.416783809661865


 19%|█████████████▊                                                           | 19002/100000 [28:42<2:17:09,  9.84it/s]

Training loss at epoch 19000: 2.0163674354553223 and val loss: 4.119589328765869


 20%|██████████████▌                                                          | 20002/100000 [30:20<2:06:51, 10.51it/s]

Training loss at epoch 20000: 1.8745094537734985 and val loss: 4.267326354980469


 21%|███████████████▎                                                         | 21003/100000 [32:01<2:11:16, 10.03it/s]

Training loss at epoch 21000: 3.3084557056427 and val loss: 5.720621585845947


 22%|████████████████                                                         | 22002/100000 [33:39<2:22:21,  9.13it/s]

Training loss at epoch 22000: 2.5221264362335205 and val loss: 4.401837348937988


 23%|████████████████▊                                                        | 23002/100000 [35:20<2:06:32, 10.14it/s]

Training loss at epoch 23000: 1.8391239643096924 and val loss: 4.4600701332092285


 24%|█████████████████▌                                                       | 24001/100000 [37:01<2:06:14, 10.03it/s]

Training loss at epoch 24000: 2.1314215660095215 and val loss: 4.504339694976807


 25%|██████████████████▎                                                      | 25002/100000 [38:33<2:13:35,  9.36it/s]

Training loss at epoch 25000: 2.1221630573272705 and val loss: 4.226039886474609


 26%|██████████████████▉                                                      | 26002/100000 [40:11<2:07:45,  9.65it/s]

Training loss at epoch 26000: 1.692928671836853 and val loss: 4.253103256225586


 27%|███████████████████▋                                                     | 27002/100000 [41:53<1:58:36, 10.26it/s]

Training loss at epoch 27000: 1.6688616275787354 and val loss: 4.232803821563721


 28%|████████████████████▍                                                    | 28002/100000 [43:34<2:06:01,  9.52it/s]

Training loss at epoch 28000: 2.005094289779663 and val loss: 4.242537975311279


 29%|█████████████████████▏                                                   | 29002/100000 [45:15<2:10:23,  9.07it/s]

Training loss at epoch 29000: 1.7888950109481812 and val loss: 4.23056697845459


 30%|█████████████████████▉                                                   | 30003/100000 [46:56<1:55:29, 10.10it/s]

Training loss at epoch 30000: 1.3867510557174683 and val loss: 4.060686111450195


 31%|██████████████████████▋                                                  | 31002/100000 [48:35<2:03:23,  9.32it/s]

Training loss at epoch 31000: 1.9605321884155273 and val loss: 4.4907050132751465


 32%|███████████████████████▎                                                 | 32002/100000 [50:02<1:21:32, 13.90it/s]

Training loss at epoch 32000: 1.9000755548477173 and val loss: 4.339746952056885


 33%|████████████████████████                                                 | 33002/100000 [51:41<1:50:21, 10.12it/s]

Training loss at epoch 33000: 1.5292408466339111 and val loss: 4.372782230377197


 34%|████████████████████████▊                                                | 34002/100000 [53:17<1:24:20, 13.04it/s]

Training loss at epoch 34000: 2.478476047515869 and val loss: 9.456121444702148


 35%|█████████████████████████▌                                               | 35002/100000 [54:55<1:48:06, 10.02it/s]

Training loss at epoch 35000: 1.4459004402160645 and val loss: 4.351502895355225


 36%|██████████████████████████▎                                              | 36002/100000 [56:35<1:55:22,  9.25it/s]

Training loss at epoch 36000: 2.240300416946411 and val loss: 4.517820835113525


 37%|███████████████████████████                                              | 37002/100000 [58:15<2:02:14,  8.59it/s]

Training loss at epoch 37000: 1.4599552154541016 and val loss: 4.352935314178467


 38%|███████████████████████████▋                                             | 38002/100000 [59:55<1:41:03, 10.23it/s]

Training loss at epoch 38000: 2.790397882461548 and val loss: 4.295698642730713


 39%|███████████████████████████▋                                           | 39002/100000 [1:01:36<1:37:47, 10.40it/s]

Training loss at epoch 39000: 2.3205726146698 and val loss: 4.367804050445557


 40%|████████████████████████████▍                                          | 40002/100000 [1:03:16<1:35:21, 10.49it/s]

Training loss at epoch 40000: 1.4677668809890747 and val loss: 4.432207107543945


 41%|█████████████████████████████                                          | 41003/100000 [1:04:59<1:39:59,  9.83it/s]

Training loss at epoch 41000: 1.9371857643127441 and val loss: 4.6957244873046875


 42%|█████████████████████████████▊                                         | 42003/100000 [1:06:39<1:31:18, 10.59it/s]

Training loss at epoch 42000: 1.4795321226119995 and val loss: 4.729230880737305


 43%|██████████████████████████████▌                                        | 43002/100000 [1:08:18<1:32:51, 10.23it/s]

Training loss at epoch 43000: 1.6383129358291626 and val loss: 4.7775774002075195


 44%|███████████████████████████████▏                                       | 44002/100000 [1:09:59<1:35:44,  9.75it/s]

Training loss at epoch 44000: 1.432542085647583 and val loss: 4.643391132354736


 45%|███████████████████████████████▉                                       | 45003/100000 [1:11:46<1:32:16,  9.93it/s]

Training loss at epoch 45000: 1.5456115007400513 and val loss: 4.585073947906494


 46%|████████████████████████████████▋                                      | 46002/100000 [1:13:29<1:27:53, 10.24it/s]

Training loss at epoch 46000: 1.9347121715545654 and val loss: 4.649713039398193


 47%|█████████████████████████████████▎                                     | 47002/100000 [1:15:15<1:41:34,  8.70it/s]

Training loss at epoch 47000: 1.319542646408081 and val loss: 4.626134872436523


 48%|██████████████████████████████████                                     | 48001/100000 [1:16:58<1:28:21,  9.81it/s]

Training loss at epoch 48000: 6.570464134216309 and val loss: 7.438607692718506


 49%|██████████████████████████████████▊                                    | 49002/100000 [1:18:42<1:25:43,  9.91it/s]

Training loss at epoch 49000: 1.4896119832992554 and val loss: 4.549562454223633


 50%|███████████████████████████████████▌                                   | 50003/100000 [1:20:25<1:28:08,  9.45it/s]

Training loss at epoch 50000: 1.582906723022461 and val loss: 4.461585998535156


 51%|████████████████████████████████████▏                                  | 51003/100000 [1:22:04<1:03:53, 12.78it/s]

Training loss at epoch 51000: 1.3131351470947266 and val loss: 4.340688228607178


 52%|████████████████████████████████████▉                                  | 52003/100000 [1:23:33<1:11:57, 11.12it/s]

Training loss at epoch 52000: 1.6399059295654297 and val loss: 4.450782775878906


 53%|█████████████████████████████████████▋                                 | 53002/100000 [1:25:26<1:29:45,  8.73it/s]

Training loss at epoch 53000: 1.3692549467086792 and val loss: 4.885272979736328


 54%|██████████████████████████████████████▎                                | 54002/100000 [1:27:23<1:58:48,  6.45it/s]

Training loss at epoch 54000: 1.7065696716308594 and val loss: 4.681830883026123


 55%|███████████████████████████████████████                                | 55002/100000 [1:29:19<1:24:25,  8.88it/s]

Training loss at epoch 55000: 1.1104918718338013 and val loss: 4.41502571105957


 56%|███████████████████████████████████████▊                               | 56002/100000 [1:31:21<1:40:18,  7.31it/s]

Training loss at epoch 56000: 1.4114940166473389 and val loss: 4.409677982330322


 57%|████████████████████████████████████████▍                              | 57002/100000 [1:33:39<1:41:23,  7.07it/s]

Training loss at epoch 57000: 1.1266449689865112 and val loss: 4.284292221069336


 58%|█████████████████████████████████████████▏                             | 58002/100000 [1:35:58<2:25:56,  4.80it/s]

Training loss at epoch 58000: 1.0282299518585205 and val loss: 4.428626537322998


 59%|█████████████████████████████████████████▉                             | 59002/100000 [1:38:15<1:27:00,  7.85it/s]

Training loss at epoch 59000: 1.6332314014434814 and val loss: 4.405017852783203


 60%|██████████████████████████████████████████▌                            | 60002/100000 [1:40:09<1:08:56,  9.67it/s]

Training loss at epoch 60000: 1.1480422019958496 and val loss: 4.487508296966553


 61%|███████████████████████████████████████████▎                           | 61002/100000 [1:41:55<1:06:06,  9.83it/s]

Training loss at epoch 61000: 1.511053204536438 and val loss: 5.052248477935791


 62%|████████████████████████████████████████████                           | 62002/100000 [1:43:38<1:05:37,  9.65it/s]

Training loss at epoch 62000: 1.5050528049468994 and val loss: 4.967014312744141


 63%|████████████████████████████████████████████▋                          | 63002/100000 [1:45:24<1:06:55,  9.21it/s]

Training loss at epoch 63000: 1.3345938920974731 and val loss: 4.79898738861084


 64%|█████████████████████████████████████████████▍                         | 64003/100000 [1:47:09<1:01:30,  9.75it/s]

Training loss at epoch 64000: 1.356967568397522 and val loss: 4.59699010848999


 65%|██████████████████████████████████████████████▏                        | 65002/100000 [1:48:58<1:04:41,  9.02it/s]

Training loss at epoch 65000: 1.032823920249939 and val loss: 4.602235317230225


 66%|██████████████████████████████████████████████▊                        | 66002/100000 [1:50:43<1:01:35,  9.20it/s]

Training loss at epoch 66000: 1.4028972387313843 and val loss: 4.6619553565979


 67%|████████████████████████████████████████████████▉                        | 67002/100000 [1:52:27<57:16,  9.60it/s]

Training loss at epoch 67000: 1.0552603006362915 and val loss: 4.624011993408203


 68%|█████████████████████████████████████████████████▋                       | 68002/100000 [1:54:08<42:44, 12.48it/s]

Training loss at epoch 68000: 1.8868030309677124 and val loss: 4.60003662109375


 69%|██████████████████████████████████████████████████▎                      | 69002/100000 [1:55:26<39:56, 12.94it/s]

Training loss at epoch 69000: 1.8172166347503662 and val loss: 4.807940483093262


 70%|███████████████████████████████████████████████████                      | 70002/100000 [1:56:42<39:55, 12.52it/s]

Training loss at epoch 70000: 2.3117849826812744 and val loss: 4.782109260559082


 71%|███████████████████████████████████████████████████▊                     | 71002/100000 [1:57:56<40:22, 11.97it/s]

Training loss at epoch 71000: 3.2967026233673096 and val loss: 4.945722579956055


 72%|████████████████████████████████████████████████████▌                    | 72002/100000 [1:59:11<33:58, 13.73it/s]

Training loss at epoch 72000: 2.058145523071289 and val loss: 4.860897541046143


 73%|█████████████████████████████████████████████████████▎                   | 73002/100000 [2:00:27<33:08, 13.58it/s]

Training loss at epoch 73000: 1.0008302927017212 and val loss: 4.705759048461914


 74%|██████████████████████████████████████████████████████                   | 74002/100000 [2:01:42<34:00, 12.74it/s]

Training loss at epoch 74000: 1.1779991388320923 and val loss: 4.728811740875244


 75%|██████████████████████████████████████████████████████▊                  | 75002/100000 [2:02:56<29:52, 13.95it/s]

Training loss at epoch 75000: 0.8507617712020874 and val loss: 4.781461238861084


 76%|███████████████████████████████████████████████████████▍                 | 76002/100000 [2:04:11<29:19, 13.64it/s]

Training loss at epoch 76000: 1.5009517669677734 and val loss: 4.745961666107178


 77%|████████████████████████████████████████████████████████▏                | 77002/100000 [2:05:55<41:43,  9.19it/s]

Training loss at epoch 77000: 1.2900325059890747 and val loss: 4.6458868980407715


 78%|████████████████████████████████████████████████████████▉                | 78003/100000 [2:07:48<34:07, 10.74it/s]

Training loss at epoch 78000: 2.159477949142456 and val loss: 4.8414387702941895


 79%|█████████████████████████████████████████████████████████▋               | 79002/100000 [2:09:44<49:53,  7.01it/s]

Training loss at epoch 79000: 1.3522088527679443 and val loss: 4.9247145652771


 80%|██████████████████████████████████████████████████████████▍              | 80002/100000 [2:11:42<38:26,  8.67it/s]

Training loss at epoch 80000: 0.8961162567138672 and val loss: 4.729012489318848


 81%|███████████████████████████████████████████████████████████▏             | 81002/100000 [2:13:36<36:51,  8.59it/s]

Training loss at epoch 81000: 1.0310741662979126 and val loss: 4.721801280975342


 82%|███████████████████████████████████████████████████████████▊             | 82002/100000 [2:15:33<44:19,  6.77it/s]

Training loss at epoch 82000: 0.917149543762207 and val loss: 4.929888725280762


 83%|████████████████████████████████████████████████████████████▌            | 83002/100000 [2:17:45<24:16, 11.67it/s]

Training loss at epoch 83000: 1.1950767040252686 and val loss: 4.575008392333984


 84%|█████████████████████████████████████████████████████████████▎           | 84002/100000 [2:19:31<20:21, 13.09it/s]

Training loss at epoch 84000: 0.9640659689903259 and val loss: 4.84622049331665


 85%|██████████████████████████████████████████████████████████████           | 85001/100000 [2:20:56<24:17, 10.29it/s]

Training loss at epoch 85000: 1.4148632287979126 and val loss: 4.9439826011657715


 86%|██████████████████████████████████████████████████████████████▊          | 86002/100000 [2:22:20<35:07,  6.64it/s]

Training loss at epoch 86000: 1.7057324647903442 and val loss: 5.2923102378845215


 87%|███████████████████████████████████████████████████████████████▌         | 87002/100000 [2:23:53<16:17, 13.30it/s]

Training loss at epoch 87000: 1.3370863199234009 and val loss: 4.739258766174316


 88%|████████████████████████████████████████████████████████████████▏        | 88002/100000 [2:25:21<19:39, 10.18it/s]

Training loss at epoch 88000: 0.8869497179985046 and val loss: 4.835869789123535


 89%|████████████████████████████████████████████████████████████████▉        | 89002/100000 [2:26:41<14:14, 12.87it/s]

Training loss at epoch 89000: 1.4108837842941284 and val loss: 4.964170455932617


 90%|█████████████████████████████████████████████████████████████████▋       | 90002/100000 [2:27:56<13:39, 12.20it/s]

Training loss at epoch 90000: 1.0519366264343262 and val loss: 4.892765998840332


 91%|██████████████████████████████████████████████████████████████████▍      | 91002/100000 [2:29:11<11:21, 13.20it/s]

Training loss at epoch 91000: 1.4551938772201538 and val loss: 4.9179887771606445


 92%|███████████████████████████████████████████████████████████████████▏     | 92002/100000 [2:30:25<10:08, 13.14it/s]

Training loss at epoch 92000: 1.1494158506393433 and val loss: 4.653254508972168


 93%|███████████████████████████████████████████████████████████████████▉     | 93002/100000 [2:31:40<08:23, 13.89it/s]

Training loss at epoch 93000: 1.2769442796707153 and val loss: 4.998660564422607


 94%|████████████████████████████████████████████████████████████████████▌    | 94002/100000 [2:32:54<07:26, 13.42it/s]

Training loss at epoch 94000: 1.2185875177383423 and val loss: 4.945343494415283


 95%|█████████████████████████████████████████████████████████████████████▎   | 95002/100000 [2:34:09<06:10, 13.48it/s]

Training loss at epoch 95000: 1.1068943738937378 and val loss: 4.849132061004639


 96%|██████████████████████████████████████████████████████████████████████   | 96002/100000 [2:35:24<04:53, 13.62it/s]

Training loss at epoch 96000: 1.0154929161071777 and val loss: 4.689638614654541


 97%|██████████████████████████████████████████████████████████████████████▊  | 97002/100000 [2:36:39<03:34, 13.98it/s]

Training loss at epoch 97000: 2.2846689224243164 and val loss: 4.769400596618652


 98%|███████████████████████████████████████████████████████████████████████▌ | 98002/100000 [2:37:54<02:27, 13.52it/s]

Training loss at epoch 98000: 0.8380952477455139 and val loss: 4.599262237548828


 99%|████████████████████████████████████████████████████████████████████████▎| 99002/100000 [2:39:08<01:13, 13.65it/s]

Training loss at epoch 99000: 0.8304106593132019 and val loss: 4.6124491691589355


100%|████████████████████████████████████████████████████████████████████████| 100000/100000 [2:40:23<00:00, 10.39it/s]


In [50]:
### TRAIN / VAL / TEST SETS ###
# Compare the NN against the naive forecast on every split.
# The reported metric is the square root of the MSE over all elements (RMSE);
# the printed messages refer to it as "total (mean) loss".
# The loop variables keep the original names, so after the loop y_pred, y_pred_naive,
# loss and naive_loss hold the test-set values, as they did before the refactor.
model_6.eval()
for split_name, X_split, y_split in [('train', X_train, y_train),
                                     ('val', X_val, y_val),
                                     ('test', X_test, y_test)]:
    # Get model's predictions of NN (no autograd bookkeeping needed)
    with torch.inference_mode():
        y_pred = model_6(X_split)

    # Calculate total (mean) error of NN on this split
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y_split))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Get naive forecast predictions and loss on this split
    y_pred_naive = naive_forecast(X_split, horizon = 10)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y_split))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 0.8530263900756836
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 4.64928674697876
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 5.5459370613098145
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 7 - same as Experiment 1 but with an RNN cell instead of an LSTM cell

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Single RNN cell with hidden dim = 32 + Linear Layer
* Time not included in prediction

In [39]:
# Instantiate model, writer, optimizer and loss function
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 32
OUTPUT_DIM = 3
NUM_LAYERS = 1

model_7 = RNNModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(params = model_7.parameters())
loss_fn = torch.nn.functional.mse_loss
# Tag the TensorBoard run with the model name. A bare SummaryWriter() logs to
# runs/<datetime>_<hostname>, which makes runs from different experiments
# indistinguishable in TensorBoard.
writer = SummaryWriter(comment = '_model_7')

# Display model's summary
summary(model_7)

Layer (type:depth-idx)                   Param #
RNNModel                                 --
├─RNN: 1-1                               1,184
├─Linear: 1-2                            99
Total params: 1,283
Trainable params: 1,283
Non-trainable params: 0

In [40]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_7, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close the writer even if training raises or is interrupted (e.g. KeyboardInterrupt
    # during this long run), so whatever train() already logged is flushed to disk.
    writer.close()
# Save model (only reached when training completed without an exception)
torch.save(model_7, 'model_7.pt')

  0%|                                                                             | 38/100000 [00:00<08:03, 206.89it/s]

Training loss at epoch 0: 137.3274688720703 and val loss: 129.20223999023438


  1%|▊                                                                          | 1056/100000 [00:03<05:37, 293.01it/s]

Training loss at epoch 1000: 100.59381103515625 and val loss: 111.73224639892578


  2%|█▌                                                                         | 2030/100000 [00:07<05:57, 273.85it/s]

Training loss at epoch 2000: 96.8466567993164 and val loss: 98.50167846679688


  3%|██▎                                                                        | 3019/100000 [00:15<15:01, 107.56it/s]

Training loss at epoch 3000: 84.02759552001953 and val loss: 87.67036437988281


  4%|███                                                                        | 4029/100000 [00:24<11:42, 136.68it/s]

Training loss at epoch 4000: 76.64600372314453 and val loss: 78.3804702758789


  5%|███▊                                                                       | 5024/100000 [00:31<11:56, 132.58it/s]

Training loss at epoch 5000: 70.76947784423828 and val loss: 69.89482879638672


  6%|████▌                                                                      | 6015/100000 [00:38<11:03, 141.69it/s]

Training loss at epoch 6000: 67.82173156738281 and val loss: 62.653656005859375


  7%|█████▎                                                                     | 7023/100000 [00:47<10:41, 145.05it/s]

Training loss at epoch 7000: 50.548301696777344 and val loss: 56.57125473022461


  8%|██████                                                                     | 8028/100000 [00:55<11:55, 128.49it/s]

Training loss at epoch 8000: 43.335594177246094 and val loss: 51.43598937988281


  9%|██████▊                                                                    | 9028/100000 [01:02<10:22, 146.17it/s]

Training loss at epoch 9000: 53.13309097290039 and val loss: 46.93874740600586


 10%|███████▍                                                                  | 10019/100000 [01:09<10:42, 139.96it/s]

Training loss at epoch 10000: 39.08697509765625 and val loss: 42.70071029663086


 11%|████████▏                                                                 | 11015/100000 [01:18<12:53, 114.97it/s]

Training loss at epoch 11000: 38.081451416015625 and val loss: 39.249290466308594


 12%|████████▉                                                                 | 12030/100000 [01:25<09:05, 161.21it/s]

Training loss at epoch 12000: 25.997623443603516 and val loss: 35.968929290771484


 13%|█████████▋                                                                | 13023/100000 [01:33<10:28, 138.48it/s]

Training loss at epoch 13000: 34.47702407836914 and val loss: 32.84597396850586


 14%|██████████▎                                                               | 14014/100000 [01:41<13:07, 109.19it/s]

Training loss at epoch 14000: 16.5297908782959 and val loss: 30.462583541870117


 15%|███████████▏                                                              | 15034/100000 [01:48<08:04, 175.25it/s]

Training loss at epoch 15000: 23.79917335510254 and val loss: 28.67914390563965


 16%|███████████▊                                                              | 16022/100000 [01:56<11:15, 124.37it/s]

Training loss at epoch 16000: 26.311569213867188 and val loss: 26.3497314453125


 17%|████████████▌                                                             | 17036/100000 [02:04<08:39, 159.82it/s]

Training loss at epoch 17000: 20.011632919311523 and val loss: 24.42633628845215


 18%|█████████████▎                                                            | 18026/100000 [02:11<10:41, 127.79it/s]

Training loss at epoch 18000: 17.615203857421875 and val loss: 22.805788040161133


 19%|██████████████                                                            | 19014/100000 [02:20<10:58, 123.08it/s]

Training loss at epoch 19000: 18.3958683013916 and val loss: 21.867324829101562


 20%|██████████████▊                                                           | 20027/100000 [02:27<08:38, 154.29it/s]

Training loss at epoch 20000: 15.175789833068848 and val loss: 20.426841735839844


 21%|███████████████▌                                                          | 21032/100000 [02:35<09:59, 131.72it/s]

Training loss at epoch 21000: 18.58514976501465 and val loss: 20.09514808654785


 22%|████████████████▎                                                         | 22009/100000 [02:42<08:56, 145.28it/s]

Training loss at epoch 22000: 13.688844680786133 and val loss: 18.774904251098633


 23%|█████████████████                                                         | 23022/100000 [02:50<08:08, 157.58it/s]

Training loss at epoch 23000: 14.280457496643066 and val loss: 18.08045768737793


 24%|█████████████████▊                                                        | 24018/100000 [02:58<09:44, 130.02it/s]

Training loss at epoch 24000: 19.388063430786133 and val loss: 17.448070526123047


 25%|██████████████████▌                                                       | 25024/100000 [03:06<09:26, 132.43it/s]

Training loss at epoch 25000: 13.246664047241211 and val loss: 17.617523193359375


 26%|███████████████████▌                                                       | 26018/100000 [03:14<12:25, 99.27it/s]

Training loss at epoch 26000: 12.17603874206543 and val loss: 16.90740966796875


 27%|███████████████████▉                                                      | 27025/100000 [03:21<09:24, 129.31it/s]

Training loss at epoch 27000: 11.04577350616455 and val loss: 16.22063636779785


 28%|████████████████████▋                                                     | 28027/100000 [03:28<08:40, 138.16it/s]

Training loss at epoch 28000: 13.235444068908691 and val loss: 15.724400520324707


 29%|█████████████████████▍                                                    | 29010/100000 [03:37<08:36, 137.47it/s]

Training loss at epoch 29000: 12.646672248840332 and val loss: 15.8872652053833


 30%|██████████████████████▏                                                   | 30024/100000 [03:44<08:43, 133.55it/s]

Training loss at epoch 30000: 13.592710494995117 and val loss: 15.51380443572998


 31%|██████████████████████▉                                                   | 31018/100000 [03:52<10:40, 107.78it/s]

Training loss at epoch 31000: 12.17455005645752 and val loss: 15.027888298034668


 32%|███████████████████████▋                                                  | 32012/100000 [03:59<06:44, 168.14it/s]

Training loss at epoch 32000: 13.136086463928223 and val loss: 15.320459365844727


 33%|████████████████████████▍                                                 | 33018/100000 [04:07<09:35, 116.40it/s]

Training loss at epoch 33000: 14.652301788330078 and val loss: 15.2644624710083


 34%|█████████████████████████▏                                                | 34019/100000 [04:16<09:03, 121.47it/s]

Training loss at epoch 34000: 12.42431354522705 and val loss: 14.838239669799805


 35%|█████████████████████████▉                                                | 35023/100000 [04:25<08:17, 130.50it/s]

Training loss at epoch 35000: 11.048524856567383 and val loss: 14.684837341308594


 36%|██████████████████████████▋                                               | 36024/100000 [04:35<09:14, 115.48it/s]

Training loss at epoch 36000: 12.402775764465332 and val loss: 14.222508430480957


 37%|███████████████████████████▍                                              | 37016/100000 [04:45<08:38, 121.39it/s]

Training loss at epoch 37000: 11.34086799621582 and val loss: 15.080581665039062


 38%|████████████████████████████▏                                             | 38021/100000 [04:53<07:02, 146.73it/s]

Training loss at epoch 38000: 10.593944549560547 and val loss: 13.797372817993164


 39%|████████████████████████████▊                                             | 39015/100000 [05:01<08:40, 117.11it/s]

Training loss at epoch 39000: 14.023262023925781 and val loss: 14.525871276855469


 40%|██████████████████████████████                                             | 40015/100000 [05:12<12:56, 77.30it/s]

Training loss at epoch 40000: 10.893782615661621 and val loss: 13.843500137329102


 41%|██████████████████████████████▎                                           | 41021/100000 [05:21<09:25, 104.25it/s]

Training loss at epoch 41000: 11.434021949768066 and val loss: 13.489623069763184


 42%|███████████████████████████████                                           | 42017/100000 [05:30<07:28, 129.23it/s]

Training loss at epoch 42000: 11.624723434448242 and val loss: 13.742464065551758


 43%|███████████████████████████████▊                                          | 43014/100000 [05:40<08:31, 111.42it/s]

Training loss at epoch 43000: 13.966865539550781 and val loss: 13.21178913116455


 44%|████████████████████████████████▌                                         | 44017/100000 [05:48<06:39, 140.19it/s]

Training loss at epoch 44000: 10.609554290771484 and val loss: 13.522159576416016


 45%|█████████████████████████████████▎                                        | 45017/100000 [05:58<06:09, 148.96it/s]

Training loss at epoch 45000: 12.690996170043945 and val loss: 13.575767517089844


 46%|██████████████████████████████████                                        | 46019/100000 [06:06<07:32, 119.23it/s]

Training loss at epoch 46000: 11.992045402526855 and val loss: 13.349139213562012


 47%|██████████████████████████████████▊                                       | 47010/100000 [06:15<08:42, 101.44it/s]

Training loss at epoch 47000: 10.852740287780762 and val loss: 13.42092227935791


 48%|███████████████████████████████████▌                                      | 48016/100000 [06:23<07:51, 110.32it/s]

Training loss at epoch 48000: 10.073498725891113 and val loss: 13.05849838256836


 49%|████████████████████████████████████▎                                     | 49023/100000 [06:32<06:47, 124.96it/s]

Training loss at epoch 49000: 10.179281234741211 and val loss: 12.897873878479004


 50%|█████████████████████████████████████                                     | 50015/100000 [06:40<05:43, 145.53it/s]

Training loss at epoch 50000: 11.061434745788574 and val loss: 13.276043891906738


 51%|█████████████████████████████████████▊                                    | 51017/100000 [06:49<06:40, 122.39it/s]

Training loss at epoch 51000: 11.477823257446289 and val loss: 12.846970558166504


 52%|███████████████████████████████████████                                    | 52014/100000 [06:58<09:56, 80.39it/s]

Training loss at epoch 52000: 13.089669227600098 and val loss: 12.937256813049316


 53%|███████████████████████████████████████▏                                  | 53022/100000 [07:07<07:28, 104.73it/s]

Training loss at epoch 53000: 11.994694709777832 and val loss: 13.017019271850586


 54%|███████████████████████████████████████▉                                  | 54010/100000 [07:15<05:39, 135.37it/s]

Training loss at epoch 54000: 11.394187927246094 and val loss: 13.77646255493164


 55%|████████████████████████████████████████▋                                 | 55023/100000 [07:24<05:25, 138.16it/s]

Training loss at epoch 55000: 11.038025856018066 and val loss: 12.660223007202148


 56%|█████████████████████████████████████████▍                                | 56016/100000 [07:30<05:31, 132.56it/s]

Training loss at epoch 56000: 12.897871017456055 and val loss: 12.830877304077148


 57%|██████████████████████████████████████████▏                               | 57019/100000 [07:38<05:58, 119.83it/s]

Training loss at epoch 57000: 11.899423599243164 and val loss: 12.712605476379395


 58%|██████████████████████████████████████████▉                               | 58023/100000 [07:46<05:32, 126.32it/s]

Training loss at epoch 58000: 11.450024604797363 and val loss: 12.503971099853516


 59%|███████████████████████████████████████████▋                              | 59021/100000 [07:54<05:14, 130.48it/s]

Training loss at epoch 59000: 10.493793487548828 and val loss: 12.556475639343262


 60%|████████████████████████████████████████████▍                             | 60018/100000 [08:02<05:05, 130.86it/s]

Training loss at epoch 60000: 10.300625801086426 and val loss: 12.783614158630371


 61%|█████████████████████████████████████████████▏                            | 61019/100000 [08:09<05:20, 121.78it/s]

Training loss at epoch 61000: 12.921077728271484 and val loss: 13.143132209777832


 62%|█████████████████████████████████████████████▉                            | 62012/100000 [08:18<04:49, 131.01it/s]

Training loss at epoch 62000: 12.804099082946777 and val loss: 13.535818099975586


 63%|██████████████████████████████████████████████▋                           | 63013/100000 [08:26<04:08, 148.73it/s]

Training loss at epoch 63000: 11.934626579284668 and val loss: 13.092979431152344


 64%|███████████████████████████████████████████████▍                          | 64027/100000 [08:35<04:14, 141.34it/s]

Training loss at epoch 64000: 11.225589752197266 and val loss: 12.542888641357422


 65%|████████████████████████████████████████████████                          | 65019/100000 [08:43<04:49, 120.98it/s]

Training loss at epoch 65000: 11.151924133300781 and val loss: 12.218255996704102


 66%|████████████████████████████████████████████████▊                         | 66017/100000 [08:52<05:24, 104.79it/s]

Training loss at epoch 66000: 16.438697814941406 and val loss: 15.58419418334961


 67%|█████████████████████████████████████████████████▌                        | 67022/100000 [08:59<03:58, 138.23it/s]

Training loss at epoch 67000: 10.996216773986816 and val loss: 12.09946346282959


 68%|██████████████████████████████████████████████████▎                       | 68019/100000 [09:06<03:55, 135.94it/s]

Training loss at epoch 68000: 11.842275619506836 and val loss: 13.089942932128906


 69%|███████████████████████████████████████████████████                       | 69013/100000 [09:14<04:17, 120.52it/s]

Training loss at epoch 69000: 8.88129997253418 and val loss: 12.168670654296875


 70%|███████████████████████████████████████████████████▊                      | 70025/100000 [09:21<04:00, 124.80it/s]

Training loss at epoch 70000: 11.121689796447754 and val loss: 11.989982604980469


 71%|████████████████████████████████████████████████████▌                     | 71022/100000 [09:28<02:35, 186.53it/s]

Training loss at epoch 71000: 11.54361629486084 and val loss: 12.126992225646973


 72%|█████████████████████████████████████████████████████▎                    | 72021/100000 [09:36<03:42, 126.02it/s]

Training loss at epoch 72000: 11.146788597106934 and val loss: 12.591891288757324


 73%|██████████████████████████████████████████████████████                    | 73023/100000 [09:43<03:15, 138.31it/s]

Training loss at epoch 73000: 11.12774658203125 and val loss: 12.174415588378906


 74%|██████████████████████████████████████████████████████▊                   | 74019/100000 [09:51<03:09, 136.83it/s]

Training loss at epoch 74000: 12.6432523727417 and val loss: 14.127776145935059


 75%|███████████████████████████████████████████████████████▌                  | 75015/100000 [10:00<03:29, 119.00it/s]

Training loss at epoch 75000: 10.548040390014648 and val loss: 12.17220401763916


 76%|████████████████████████████████████████████████████████▎                 | 76017/100000 [10:07<02:40, 149.49it/s]

Training loss at epoch 76000: 9.563618659973145 and val loss: 11.96389389038086


 77%|████████████████████████████████████████████████████████▉                 | 77025/100000 [10:15<02:31, 151.85it/s]

Training loss at epoch 77000: 10.210723876953125 and val loss: 12.42253589630127


 78%|█████████████████████████████████████████████████████████▋                | 78027/100000 [10:23<02:38, 138.65it/s]

Training loss at epoch 78000: 9.498786926269531 and val loss: 12.327790260314941


 79%|███████████████████████████████████████████████████████████▎               | 79012/100000 [10:30<03:30, 99.62it/s]

Training loss at epoch 79000: 11.93457317352295 and val loss: 12.150471687316895


 80%|███████████████████████████████████████████████████████████▏              | 80026/100000 [10:38<02:26, 136.38it/s]

Training loss at epoch 80000: 10.95189380645752 and val loss: 12.111923217773438


 81%|███████████████████████████████████████████████████████████▉              | 81025/100000 [10:45<02:19, 136.09it/s]

Training loss at epoch 81000: 10.667658805847168 and val loss: 11.957500457763672


 82%|████████████████████████████████████████████████████████████▋             | 82021/100000 [10:53<02:22, 126.33it/s]

Training loss at epoch 82000: 11.40322494506836 and val loss: 11.803181648254395


 83%|█████████████████████████████████████████████████████████████▍            | 83019/100000 [11:01<02:10, 130.00it/s]

Training loss at epoch 83000: 10.800077438354492 and val loss: 11.9969482421875


 84%|██████████████████████████████████████████████████████████████▏           | 84018/100000 [11:08<01:50, 144.19it/s]

Training loss at epoch 84000: 10.669028282165527 and val loss: 11.83810806274414


 85%|██████████████████████████████████████████████████████████████▉           | 85019/100000 [11:16<01:37, 152.98it/s]

Training loss at epoch 85000: 11.72805118560791 and val loss: 11.9487886428833


 86%|███████████████████████████████████████████████████████████████▋          | 86021/100000 [11:23<01:34, 147.76it/s]

Training loss at epoch 86000: 9.885876655578613 and val loss: 12.43182373046875


 87%|████████████████████████████████████████████████████████████████▍         | 87023/100000 [11:31<01:31, 141.21it/s]

Training loss at epoch 87000: 11.96380615234375 and val loss: 11.842769622802734


 88%|█████████████████████████████████████████████████████████████████▏        | 88009/100000 [11:39<01:34, 127.20it/s]

Training loss at epoch 88000: 12.932095527648926 and val loss: 11.98551082611084


 89%|█████████████████████████████████████████████████████████████████▊        | 89016/100000 [11:47<01:23, 131.08it/s]

Training loss at epoch 89000: 9.789331436157227 and val loss: 11.686487197875977


 90%|██████████████████████████████████████████████████████████████████▌       | 90015/100000 [11:55<01:28, 112.79it/s]

Training loss at epoch 90000: 10.961058616638184 and val loss: 12.531319618225098


 91%|███████████████████████████████████████████████████████████████████▎      | 91020/100000 [12:04<01:09, 129.67it/s]

Training loss at epoch 91000: 11.735631942749023 and val loss: 12.65945816040039


 92%|████████████████████████████████████████████████████████████████████      | 92015/100000 [12:13<01:00, 131.41it/s]

Training loss at epoch 92000: 10.886600494384766 and val loss: 12.126733779907227


 93%|████████████████████████████████████████████████████████████████████▊     | 93012/100000 [12:21<00:55, 125.42it/s]

Training loss at epoch 93000: 9.230188369750977 and val loss: 11.636738777160645


 94%|█████████████████████████████████████████████████████████████████████▌    | 94015/100000 [12:29<00:49, 121.41it/s]

Training loss at epoch 94000: 10.242606163024902 and val loss: 12.31676197052002


 95%|███████████████████████████████████████████████████████████████████████▎   | 95013/100000 [12:38<00:51, 97.19it/s]

Training loss at epoch 95000: 12.39012622833252 and val loss: 11.742634773254395


 96%|███████████████████████████████████████████████████████████████████████   | 96019/100000 [12:47<00:27, 142.89it/s]

Training loss at epoch 96000: 10.556772232055664 and val loss: 11.849820137023926


 97%|████████████████████████████████████████████████████████████████████████▊  | 97011/100000 [12:55<00:38, 76.84it/s]

Training loss at epoch 97000: 11.273458480834961 and val loss: 11.715954780578613


 98%|████████████████████████████████████████████████████████████████████████▌ | 98024/100000 [13:06<00:15, 124.66it/s]

Training loss at epoch 98000: 11.48328971862793 and val loss: 12.203266143798828


 99%|█████████████████████████████████████████████████████████████████████████▎| 99023/100000 [13:14<00:07, 138.48it/s]

Training loss at epoch 99000: 11.04752254486084 and val loss: 11.70565128326416


100%|█████████████████████████████████████████████████████████████████████████| 100000/100000 [13:22<00:00, 124.66it/s]


In [41]:
### TRAIN / VAL / TEST SETS ###
# Compare the NN against the naive forecast on every split.
# The reported metric is the square root of the MSE over all elements (RMSE);
# the printed messages refer to it as "total (mean) loss".
# The loop variables keep the original names, so after the loop y_pred, y_pred_naive,
# loss and naive_loss hold the test-set values, as they did before the refactor.
model_7.eval()
for split_name, X_split, y_split in [('train', X_train, y_train),
                                     ('val', X_val, y_val),
                                     ('test', X_test, y_test)]:
    # Get model's predictions of NN (no autograd bookkeeping needed)
    with torch.inference_mode():
        y_pred = model_7(X_split)

    # Calculate total (mean) error of NN on this split
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y_split))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Get naive forecast predictions and loss on this split
    y_pred_naive = naive_forecast(X_split, horizon = 10)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y_split))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 10.516054153442383
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 11.543726921081543
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 12.25177001953125
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 8 - same as Experiment 7 but with a larger RNN cell (hidden dim 128 instead of 32)

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Single RNN cell with hidden dim = 128 + Linear Layer
* Time not included in prediction

In [42]:
# Instantiate model, writer, optimizer and loss function
# Model parameters
INPUT_DIM = 3
HIDDEN_DIM = 128
OUTPUT_DIM = 3
NUM_LAYERS = 1

model_8 = RNNModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(params = model_8.parameters())
loss_fn = torch.nn.functional.mse_loss
# Tag the TensorBoard run with the model name. A bare SummaryWriter() logs to
# runs/<datetime>_<hostname>, which makes runs from different experiments
# indistinguishable in TensorBoard.
writer = SummaryWriter(comment = '_model_8')

# Display model's summary
summary(model_8)

Layer (type:depth-idx)                   Param #
RNNModel                                 --
├─RNN: 1-1                               17,024
├─Linear: 1-2                            387
Total params: 17,411
Trainable params: 17,411
Non-trainable params: 0

In [43]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_8, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close the writer even if training raises or is interrupted (e.g. KeyboardInterrupt
    # during this long run), so whatever train() already logged is flushed to disk.
    writer.close()
# Save model (only reached when training completed without an exception)
torch.save(model_8, 'model_8.pt')

  0%|                                                                             | 18/100000 [00:00<09:40, 172.15it/s]

Training loss at epoch 0: 122.07501983642578 and val loss: 129.0240020751953


  1%|▊                                                                          | 1037/100000 [00:05<08:06, 203.43it/s]

Training loss at epoch 1000: 77.44645690917969 and val loss: 77.30046844482422


  2%|█▌                                                                          | 2013/100000 [00:16<23:41, 68.95it/s]

Training loss at epoch 2000: 62.069190979003906 and val loss: 50.833885192871094


  3%|██▎                                                                         | 3013/100000 [00:31<21:47, 74.16it/s]

Training loss at epoch 3000: 30.89476776123047 and val loss: 37.030792236328125


  4%|███                                                                         | 4013/100000 [00:42<20:00, 79.99it/s]

Training loss at epoch 4000: 20.96791648864746 and val loss: 28.136634826660156


  5%|███▊                                                                        | 5013/100000 [00:54<18:51, 83.95it/s]

Training loss at epoch 5000: 15.365533828735352 and val loss: 22.12713623046875


  6%|████▌                                                                       | 6006/100000 [01:07<31:46, 49.31it/s]

Training loss at epoch 6000: 13.731634140014648 and val loss: 17.809921264648438


  7%|█████▎                                                                      | 7016/100000 [01:17<18:02, 85.87it/s]

Training loss at epoch 7000: 11.953675270080566 and val loss: 15.450582504272461


  8%|██████                                                                      | 8017/100000 [01:28<16:11, 94.66it/s]

Training loss at epoch 8000: 11.75827693939209 and val loss: 13.469171524047852


  9%|██████▊                                                                    | 9014/100000 [01:38<14:51, 102.03it/s]

Training loss at epoch 9000: 9.658620834350586 and val loss: 13.822108268737793


 10%|███████▌                                                                   | 10019/100000 [01:49<16:07, 92.98it/s]

Training loss at epoch 10000: 11.191548347473145 and val loss: 11.516709327697754


 11%|████████▎                                                                  | 11016/100000 [02:00<15:45, 94.08it/s]

Training loss at epoch 11000: 10.512160301208496 and val loss: 11.029131889343262


 12%|█████████                                                                  | 12008/100000 [02:11<16:50, 87.09it/s]

Training loss at epoch 12000: 7.7105889320373535 and val loss: 10.6922025680542


 13%|█████████▊                                                                 | 13009/100000 [02:23<16:06, 89.97it/s]

Training loss at epoch 13000: 8.738866806030273 and val loss: 10.123618125915527


 14%|██████████▎                                                               | 14013/100000 [02:33<13:51, 103.47it/s]

Training loss at epoch 14000: 8.892587661743164 and val loss: 9.24692153930664


 15%|███████████▎                                                               | 15016/100000 [02:44<15:15, 92.81it/s]

Training loss at epoch 15000: 7.301180362701416 and val loss: 9.052444458007812


 16%|████████████                                                               | 16012/100000 [02:55<15:07, 92.53it/s]

Training loss at epoch 16000: 7.195394039154053 and val loss: 9.246868133544922


 17%|████████████▊                                                              | 17020/100000 [03:05<14:28, 95.57it/s]

Training loss at epoch 17000: 7.0355658531188965 and val loss: 8.331521034240723


 18%|█████████████▎                                                            | 18014/100000 [03:16<12:55, 105.67it/s]

Training loss at epoch 18000: 6.17590856552124 and val loss: 8.22901439666748


 19%|██████████████▎                                                            | 19018/100000 [03:26<14:43, 91.62it/s]

Training loss at epoch 19000: 7.175817966461182 and val loss: 8.203858375549316


 20%|███████████████                                                            | 20019/100000 [03:36<14:27, 92.19it/s]

Training loss at epoch 20000: 6.863975524902344 and val loss: 8.72012710571289


 21%|███████████████▊                                                           | 21012/100000 [03:47<14:17, 92.16it/s]

Training loss at epoch 21000: 6.420044422149658 and val loss: 8.139350891113281


 22%|████████████████▌                                                          | 22017/100000 [03:58<14:47, 87.89it/s]

Training loss at epoch 22000: 7.605508804321289 and val loss: 8.13785171508789


 23%|█████████████████▎                                                         | 23017/100000 [04:09<14:29, 88.57it/s]

Training loss at epoch 23000: 7.148571491241455 and val loss: 7.824124336242676


 24%|██████████████████                                                         | 24014/100000 [04:19<14:09, 89.43it/s]

Training loss at epoch 24000: 7.663212776184082 and val loss: 8.710278511047363


 25%|██████████████████▊                                                        | 25012/100000 [04:30<13:31, 92.44it/s]

Training loss at epoch 25000: 6.5460309982299805 and val loss: 7.876307010650635


 26%|███████████████████▌                                                       | 26014/100000 [04:41<13:16, 92.94it/s]

Training loss at epoch 26000: 6.086740493774414 and val loss: 7.490286350250244


 27%|████████████████████▎                                                      | 27011/100000 [04:52<13:23, 90.89it/s]

Training loss at epoch 27000: 6.8450822830200195 and val loss: 7.711605072021484


 28%|█████████████████████                                                      | 28014/100000 [05:04<14:59, 80.03it/s]

Training loss at epoch 28000: 5.736630439758301 and val loss: 7.439798355102539


 29%|█████████████████████▊                                                     | 29015/100000 [05:15<13:28, 87.81it/s]

Training loss at epoch 29000: 6.452190399169922 and val loss: 7.864964485168457


 30%|██████████████████████▌                                                    | 30016/100000 [05:26<13:23, 87.06it/s]

Training loss at epoch 30000: 6.328294277191162 and val loss: 7.360725402832031


 31%|███████████████████████▎                                                   | 31018/100000 [05:38<12:54, 89.06it/s]

Training loss at epoch 31000: 6.456418991088867 and val loss: 7.768488883972168


 32%|████████████████████████                                                   | 32018/100000 [05:48<12:27, 91.00it/s]

Training loss at epoch 32000: 5.220505714416504 and val loss: 7.2857842445373535


 33%|████████████████████████▍                                                 | 33007/100000 [05:58<11:02, 101.08it/s]

Training loss at epoch 33000: 8.99113655090332 and val loss: 9.531622886657715


 34%|█████████████████████████▌                                                 | 34011/100000 [06:09<11:53, 92.43it/s]

Training loss at epoch 34000: 5.998592376708984 and val loss: 7.420225620269775


 35%|██████████████████████████▎                                                | 35014/100000 [06:20<12:08, 89.15it/s]

Training loss at epoch 35000: 5.799821376800537 and val loss: 7.744466781616211


 36%|███████████████████████████                                                | 36013/100000 [06:31<12:08, 87.80it/s]

Training loss at epoch 36000: 5.845208168029785 and val loss: 7.202062129974365


 37%|███████████████████████████▊                                               | 37012/100000 [06:43<12:03, 87.10it/s]

Training loss at epoch 37000: 6.404508113861084 and val loss: 6.94272518157959


 38%|████████████████████████████▌                                              | 38014/100000 [06:53<11:08, 92.78it/s]

Training loss at epoch 38000: 5.262579441070557 and val loss: 7.380575180053711


 39%|█████████████████████████████▎                                             | 39017/100000 [07:04<12:00, 84.66it/s]

Training loss at epoch 39000: 7.690083026885986 and val loss: 10.171442985534668


 40%|██████████████████████████████                                             | 40014/100000 [07:15<10:58, 91.11it/s]

Training loss at epoch 40000: 8.463396072387695 and val loss: 9.71226692199707


 41%|██████████████████████████████▊                                            | 41012/100000 [07:26<10:40, 92.17it/s]

Training loss at epoch 41000: 7.647519588470459 and val loss: 7.998712539672852


 42%|███████████████████████████████▌                                           | 42011/100000 [07:37<10:14, 94.37it/s]

Training loss at epoch 42000: 5.912632465362549 and val loss: 7.64201545715332


 43%|████████████████████████████████▎                                          | 43017/100000 [07:48<10:44, 88.48it/s]

Training loss at epoch 43000: 6.704975605010986 and val loss: 8.303386688232422


 44%|█████████████████████████████████                                          | 44011/100000 [07:58<10:19, 90.39it/s]

Training loss at epoch 44000: 6.107734680175781 and val loss: 7.357726097106934


 45%|█████████████████████████████████▊                                         | 45010/100000 [08:09<10:06, 90.71it/s]

Training loss at epoch 45000: 6.22470760345459 and val loss: 7.145258903503418


 46%|██████████████████████████████████▌                                        | 46017/100000 [08:21<10:03, 89.45it/s]

Training loss at epoch 46000: 5.9441351890563965 and val loss: 8.030867576599121


 47%|███████████████████████████████████▎                                       | 47015/100000 [08:32<09:55, 89.03it/s]

Training loss at epoch 47000: 6.078472137451172 and val loss: 7.216975688934326


 48%|████████████████████████████████████                                       | 48010/100000 [08:42<09:38, 89.81it/s]

Training loss at epoch 48000: 6.034104824066162 and val loss: 7.097808361053467


 49%|████████████████████████████████████▊                                      | 49016/100000 [08:53<09:43, 87.43it/s]

Training loss at epoch 49000: 5.550415515899658 and val loss: 7.112685680389404


 50%|█████████████████████████████████████▌                                     | 50017/100000 [09:03<09:06, 91.46it/s]

Training loss at epoch 50000: 5.344174385070801 and val loss: 7.192797660827637


 51%|██████████████████████████████████████▎                                    | 51019/100000 [09:14<08:55, 91.45it/s]

Training loss at epoch 51000: 8.352010726928711 and val loss: 9.172137260437012


 52%|███████████████████████████████████████                                    | 52011/100000 [09:25<08:45, 91.40it/s]

Training loss at epoch 52000: 4.919055938720703 and val loss: 6.99904203414917


 53%|███████████████████████████████████████▊                                   | 53012/100000 [09:34<08:19, 94.13it/s]

Training loss at epoch 53000: 5.527408599853516 and val loss: 6.673937797546387


 54%|████████████████████████████████████████▌                                  | 54014/100000 [09:45<08:06, 94.57it/s]

Training loss at epoch 54000: 5.0729851722717285 and val loss: 6.956459999084473


 55%|█████████████████████████████████████████▎                                 | 55009/100000 [09:56<09:57, 75.26it/s]

Training loss at epoch 55000: 8.531256675720215 and val loss: 7.515799045562744


 56%|██████████████████████████████████████████                                 | 56014/100000 [10:07<07:44, 94.77it/s]

Training loss at epoch 56000: 7.566579341888428 and val loss: 7.110348224639893


 57%|██████████████████████████████████████████▊                                | 57005/100000 [10:18<07:47, 91.91it/s]

Training loss at epoch 57000: 6.862900257110596 and val loss: 7.741878986358643


 58%|███████████████████████████████████████████▌                               | 58014/100000 [10:28<07:43, 90.65it/s]

Training loss at epoch 58000: 5.803601264953613 and val loss: 6.941990375518799


 59%|████████████████████████████████████████████▎                              | 59014/100000 [10:38<07:23, 92.52it/s]

Training loss at epoch 59000: 5.606768608093262 and val loss: 6.4819464683532715


 60%|█████████████████████████████████████████████                              | 60013/100000 [10:48<07:08, 93.32it/s]

Training loss at epoch 60000: 5.25017786026001 and val loss: 6.686904430389404


 61%|█████████████████████████████████████████████▊                             | 61021/100000 [10:59<06:47, 95.56it/s]

Training loss at epoch 61000: 5.426523685455322 and val loss: 7.001711368560791


 62%|██████████████████████████████████████████████▌                            | 62017/100000 [11:10<07:42, 82.18it/s]

Training loss at epoch 62000: 5.965109348297119 and val loss: 6.795691013336182


 63%|███████████████████████████████████████████████▎                           | 63015/100000 [11:20<06:24, 96.19it/s]

Training loss at epoch 63000: 6.2540435791015625 and val loss: 7.6050801277160645


 64%|████████████████████████████████████████████████                           | 64015/100000 [11:31<06:28, 92.62it/s]

Training loss at epoch 64000: 5.164647102355957 and val loss: 6.513245582580566


 65%|████████████████████████████████████████████████▊                          | 65015/100000 [11:42<06:16, 92.96it/s]

Training loss at epoch 65000: 5.457882404327393 and val loss: 6.955188751220703


 66%|█████████████████████████████████████████████████▌                         | 66012/100000 [11:53<06:07, 92.48it/s]

Training loss at epoch 66000: 4.95746374130249 and val loss: 6.579076290130615


 67%|█████████████████████████████████████████████████▌                        | 67011/100000 [12:03<04:41, 117.04it/s]

Training loss at epoch 67000: 5.225900650024414 and val loss: 6.611635208129883


 68%|███████████████████████████████████████████████████                        | 68010/100000 [12:13<05:39, 94.31it/s]

Training loss at epoch 68000: 5.989940643310547 and val loss: 7.457670211791992


 69%|███████████████████████████████████████████████████▊                       | 69018/100000 [12:24<05:27, 94.65it/s]

Training loss at epoch 69000: 7.401248455047607 and val loss: 7.871275424957275


 70%|████████████████████████████████████████████████████▌                      | 70013/100000 [12:35<05:18, 94.25it/s]

Training loss at epoch 70000: 5.299945831298828 and val loss: 6.4073686599731445


 71%|█████████████████████████████████████████████████████▎                     | 71019/100000 [12:46<05:09, 93.76it/s]

Training loss at epoch 71000: 4.960818290710449 and val loss: 6.46466064453125


 72%|█████████████████████████████████████████████████████▎                    | 72018/100000 [12:56<04:19, 107.86it/s]

Training loss at epoch 72000: 5.048752307891846 and val loss: 6.365825653076172


 73%|██████████████████████████████████████████████████████▊                    | 73012/100000 [13:07<04:38, 96.95it/s]

Training loss at epoch 73000: 5.579172134399414 and val loss: 6.626609802246094


 74%|███████████████████████████████████████████████████████▌                   | 74012/100000 [13:18<04:38, 93.31it/s]

Training loss at epoch 74000: 5.378267765045166 and val loss: 6.7051849365234375


 75%|████████████████████████████████████████████████████████▎                  | 75018/100000 [13:29<04:20, 95.88it/s]

Training loss at epoch 75000: 4.36822509765625 and val loss: 6.338540554046631


 76%|█████████████████████████████████████████████████████████                  | 76013/100000 [13:40<04:17, 93.05it/s]

Training loss at epoch 76000: 5.156713485717773 and val loss: 6.7064690589904785


 77%|█████████████████████████████████████████████████████████▊                 | 77011/100000 [13:49<03:57, 96.70it/s]

Training loss at epoch 77000: 5.002598285675049 and val loss: 6.419958591461182


 78%|██████████████████████████████████████████████████████████▌                | 78017/100000 [14:00<03:50, 95.41it/s]

Training loss at epoch 78000: 4.701401710510254 and val loss: 6.431738376617432


 79%|███████████████████████████████████████████████████████████▎               | 79018/100000 [14:11<03:47, 92.15it/s]

Training loss at epoch 79000: 5.828215599060059 and val loss: 7.708326816558838


 80%|████████████████████████████████████████████████████████████               | 80017/100000 [14:22<03:40, 90.56it/s]

Training loss at epoch 80000: 5.310457229614258 and val loss: 7.402798652648926


 81%|████████████████████████████████████████████████████████████▊              | 81017/100000 [14:33<03:22, 93.57it/s]

Training loss at epoch 81000: 5.504295349121094 and val loss: 6.959420204162598


 82%|█████████████████████████████████████████████████████████████▌             | 82010/100000 [14:42<03:12, 93.30it/s]

Training loss at epoch 82000: 6.470843315124512 and val loss: 7.2418341636657715


 83%|██████████████████████████████████████████████████████████████▎            | 83012/100000 [14:53<03:01, 93.46it/s]

Training loss at epoch 83000: 4.8422322273254395 and val loss: 6.678516864776611


 84%|███████████████████████████████████████████████████████████████            | 84019/100000 [15:04<02:48, 94.57it/s]

Training loss at epoch 84000: 5.0853495597839355 and val loss: 6.745916843414307


 85%|███████████████████████████████████████████████████████████████▊           | 85018/100000 [15:15<02:39, 93.96it/s]

Training loss at epoch 85000: 4.595508098602295 and val loss: 6.859393119812012


 86%|███████████████████████████████████████████████████████████████▋          | 86016/100000 [15:26<02:06, 110.49it/s]

Training loss at epoch 86000: 5.994593620300293 and val loss: 7.027231693267822


 87%|█████████████████████████████████████████████████████████████████▎         | 87011/100000 [15:36<02:13, 97.33it/s]

Training loss at epoch 87000: 5.101283550262451 and val loss: 6.954843044281006


 88%|██████████████████████████████████████████████████████████████████         | 88014/100000 [15:47<02:11, 91.20it/s]

Training loss at epoch 88000: 5.593995571136475 and val loss: 6.779867172241211


 89%|█████████████████████████████████████████████████████████████████▊        | 89015/100000 [15:58<01:45, 104.13it/s]

Training loss at epoch 89000: 12.603652000427246 and val loss: 11.758807182312012


 90%|███████████████████████████████████████████████████████████████████▌       | 90016/100000 [16:07<01:45, 94.92it/s]

Training loss at epoch 90000: 5.445714473724365 and val loss: 6.8293256759643555


 91%|███████████████████████████████████████████████████████████████████▎      | 91020/100000 [16:18<01:14, 119.87it/s]

Training loss at epoch 91000: 5.483041763305664 and val loss: 6.983037948608398


 92%|█████████████████████████████████████████████████████████████████████      | 92016/100000 [16:28<01:27, 91.33it/s]

Training loss at epoch 92000: 5.680384159088135 and val loss: 6.8557448387146


 93%|█████████████████████████████████████████████████████████████████████▊     | 93021/100000 [16:39<01:12, 96.23it/s]

Training loss at epoch 93000: 8.98163890838623 and val loss: 9.489646911621094


 94%|██████████████████████████████████████████████████████████████████████▌    | 94010/100000 [16:49<01:04, 93.34it/s]

Training loss at epoch 94000: 4.506268501281738 and val loss: 6.698818683624268


 95%|███████████████████████████████████████████████████████████████████████▎   | 95015/100000 [17:01<01:01, 80.77it/s]

Training loss at epoch 95000: 4.643514633178711 and val loss: 6.657453536987305


 96%|███████████████████████████████████████████████████████████████████████   | 96021/100000 [17:10<00:35, 111.73it/s]

Training loss at epoch 96000: 5.057742118835449 and val loss: 7.361090183258057


 97%|████████████████████████████████████████████████████████████████████████▊  | 97013/100000 [17:21<00:31, 94.29it/s]

Training loss at epoch 97000: 5.906168460845947 and val loss: 7.006463050842285


 98%|█████████████████████████████████████████████████████████████████████████▌ | 98017/100000 [17:32<00:21, 92.55it/s]

Training loss at epoch 98000: 5.187562465667725 and val loss: 7.190399646759033


 99%|██████████████████████████████████████████████████████████████████████████▎| 99014/100000 [17:44<00:10, 91.46it/s]

Training loss at epoch 99000: 4.692294597625732 and val loss: 6.914438247680664


100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [17:54<00:00, 93.05it/s]


In [44]:
### EVALUATE model_8 VS. NAIVE FORECAST ON TRAIN / VAL / TEST ###
# The reported "loss" is RMSE: sqrt of the element-wise mean squared error.
# One loop replaces three copy-pasted stanzas. Splits are visited in the
# original order (train, val, test), so the printed output is unchanged and
# y_pred / loss / y_pred_naive / naive_loss end up holding the test-split
# values, exactly as before.
EVAL_HORIZON = 10  # must match the horizon model_8 was built/trained with

eval_splits = {
    'train': (X_train, y_train),
    'val': (X_val, y_val),
    'test': (X_test, y_test),
}

# eval() only needs to be set once; nothing below switches back to train mode
model_8.eval()
for split_name, (X_split, y_split) in eval_splits.items():
    # NN predictions without autograd tracking
    with torch.inference_mode():
        y_pred = model_8(X_split)
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y_split))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # Naive baseline on the same split for comparison
    y_pred_naive = naive_forecast(X_split, horizon = EVAL_HORIZON)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y_split))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')

Total (mean) train loss of NN: 5.205631732940674
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 6.711753845214844
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 8.766222953796387
Total (mean) test loss of naive forecast: 12.19200611114502


## Experiment 9 - same as experiment 8 but with three stacked RNN layers

* file: 1_single
* window size = 10, horizon = 10, max bird speed = 10 m/s, batch size = 64
* Three stacked RNN layers (num_layers = 3) with hidden dim = 128, followed by a Linear layer
* Time not included in prediction

In [45]:
# Hyper-parameters: 3-d input/output, 128 hidden units, 3 stacked RNN layers
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS = 3, 128, 3, 3

# Build the network, then its Adam optimizer, MSE loss and TensorBoard writer
model_9 = RNNModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, horizon = 10)
optimizer = torch.optim.Adam(params = model_9.parameters())
loss_fn = torch.nn.functional.mse_loss
writer = SummaryWriter()

# Layer-by-layer parameter count, rendered as this cell's output
summary(model_9)

Layer (type:depth-idx)                   Param #
RNNModel                                 --
├─RNN: 1-1                               83,072
├─Linear: 1-2                            387
Total params: 83,459
Trainable params: 83,459
Non-trainable params: 0

In [46]:
# Train model
BATCH_SIZE = 64
NUM_EPOCHS = 100000
try:
    train(model_9, X_train, y_train, BATCH_SIZE, NUM_EPOCHS, optimizer, loss_fn, X_val, y_val, writer)
finally:
    # Close (and flush) the TensorBoard writer even if training raises or is
    # interrupted (e.g. KeyboardInterrupt during a long run); the exception
    # still propagates, so the model is only saved after a completed run.
    writer.close()
# Save model. This pickles the whole module object, so loading it later
# requires the RNNModel class to be importable.
torch.save(model_9, 'model_9.pt')

  0%|                                                                             | 3/100000 [00:00<1:04:37, 25.79it/s]

Training loss at epoch 0: 124.99119567871094 and val loss: 129.02215576171875


  1%|▊                                                                           | 1011/100000 [00:11<18:48, 87.68it/s]

Training loss at epoch 1000: 70.39813995361328 and val loss: 76.48888397216797


  2%|█▌                                                                          | 2009/100000 [00:23<18:53, 86.47it/s]

Training loss at epoch 2000: 54.3271484375 and val loss: 50.96574401855469


  3%|██▎                                                                         | 3015/100000 [00:36<23:02, 70.13it/s]

Training loss at epoch 3000: 30.04841423034668 and val loss: 35.793296813964844


  4%|███                                                                         | 4014/100000 [00:49<19:44, 81.05it/s]

Training loss at epoch 4000: 17.593233108520508 and val loss: 23.83771514892578


  5%|███▊                                                                        | 5017/100000 [01:02<19:52, 79.64it/s]

Training loss at epoch 5000: 12.090405464172363 and val loss: 15.813766479492188


  6%|████▌                                                                       | 6007/100000 [01:19<30:43, 50.98it/s]

Training loss at epoch 6000: 7.620858669281006 and val loss: 11.21883773803711


  7%|█████▎                                                                      | 7005/100000 [01:42<35:17, 43.92it/s]

Training loss at epoch 7000: 5.219340801239014 and val loss: 8.37697696685791


  8%|██████                                                                      | 8009/100000 [02:05<32:16, 47.51it/s]

Training loss at epoch 8000: 3.7919113636016846 and val loss: 7.069969177246094


  9%|██████▊                                                                     | 9007/100000 [02:28<33:43, 44.98it/s]

Training loss at epoch 9000: 4.313110828399658 and val loss: 6.898229122161865


 10%|███████▌                                                                   | 10006/100000 [02:50<43:44, 34.29it/s]

Training loss at epoch 10000: 3.737018346786499 and val loss: 6.230563163757324


 11%|████████▎                                                                  | 11005/100000 [03:12<31:46, 46.67it/s]

Training loss at epoch 11000: 4.151206970214844 and val loss: 6.453159332275391


 12%|█████████                                                                  | 12009/100000 [03:35<30:21, 48.30it/s]

Training loss at epoch 12000: 4.15662145614624 and val loss: 5.624088287353516


 13%|█████████▊                                                                 | 13006/100000 [03:57<31:42, 45.73it/s]

Training loss at epoch 13000: 3.609888792037964 and val loss: 6.342059135437012


 14%|██████████▌                                                                | 14009/100000 [04:19<28:53, 49.62it/s]

Training loss at epoch 14000: 2.840486526489258 and val loss: 5.463471412658691


 15%|███████████▎                                                               | 15005/100000 [04:41<30:39, 46.19it/s]

Training loss at epoch 15000: 2.39438533782959 and val loss: 5.200367450714111


 16%|████████████                                                               | 16010/100000 [05:03<31:37, 44.27it/s]

Training loss at epoch 16000: 2.978919744491577 and val loss: 5.487064361572266


 17%|████████████▊                                                              | 17008/100000 [05:25<29:14, 47.30it/s]

Training loss at epoch 17000: 3.1917741298675537 and val loss: 6.079638481140137


 18%|█████████████▌                                                             | 18007/100000 [05:47<30:16, 45.13it/s]

Training loss at epoch 18000: 2.5405213832855225 and val loss: 5.185304641723633


 19%|██████████████▎                                                            | 19009/100000 [06:10<31:44, 42.52it/s]

Training loss at epoch 19000: 3.7677319049835205 and val loss: 5.297518730163574


 20%|███████████████                                                            | 20006/100000 [06:33<35:57, 37.07it/s]

Training loss at epoch 20000: 2.300691604614258 and val loss: 4.942518711090088


 21%|███████████████▊                                                           | 21006/100000 [06:55<27:53, 47.19it/s]

Training loss at epoch 21000: 2.84053635597229 and val loss: 5.072554111480713


 22%|████████████████▌                                                          | 22009/100000 [07:17<26:24, 49.22it/s]

Training loss at epoch 22000: 2.38649845123291 and val loss: 5.253941059112549


 23%|█████████████████▎                                                         | 23009/100000 [07:39<31:29, 40.74it/s]

Training loss at epoch 23000: 2.3748884201049805 and val loss: 5.149511337280273


 24%|██████████████████                                                         | 24009/100000 [08:01<27:07, 46.69it/s]

Training loss at epoch 24000: 2.0780317783355713 and val loss: 4.890212535858154


 25%|██████████████████▊                                                        | 25005/100000 [08:23<25:45, 48.51it/s]

Training loss at epoch 25000: 2.3785507678985596 and val loss: 4.881599426269531


 26%|███████████████████▌                                                       | 26006/100000 [08:45<26:37, 46.32it/s]

Training loss at epoch 26000: 2.3309006690979004 and val loss: 4.8289713859558105


 27%|████████████████████▎                                                      | 27005/100000 [09:08<27:27, 44.32it/s]

Training loss at epoch 27000: 1.6123504638671875 and val loss: 4.782584190368652


 28%|█████████████████████                                                      | 28011/100000 [09:30<24:18, 49.37it/s]

Training loss at epoch 28000: 2.4152681827545166 and val loss: 4.801499843597412


 29%|█████████████████████▊                                                     | 29007/100000 [09:52<25:31, 46.37it/s]

Training loss at epoch 29000: 2.7441842555999756 and val loss: 5.368369102478027


 30%|██████████████████████▌                                                    | 30011/100000 [10:14<24:32, 47.52it/s]

Training loss at epoch 30000: 1.906887412071228 and val loss: 4.892986297607422


 31%|███████████████████████▎                                                   | 31008/100000 [10:36<24:48, 46.36it/s]

Training loss at epoch 31000: 8.284876823425293 and val loss: 7.1246867179870605


 32%|████████████████████████                                                   | 32010/100000 [10:58<27:57, 40.54it/s]

Training loss at epoch 32000: 2.1222620010375977 and val loss: 4.795051574707031


 33%|████████████████████████▊                                                  | 33005/100000 [11:21<24:20, 45.87it/s]

Training loss at epoch 33000: 3.6015512943267822 and val loss: 5.277484893798828


 34%|█████████████████████████▌                                                 | 34006/100000 [11:43<29:00, 37.93it/s]

Training loss at epoch 34000: 2.9348630905151367 and val loss: 5.17624568939209


 35%|██████████████████████████▎                                                | 35005/100000 [12:06<23:17, 46.49it/s]

Training loss at epoch 35000: 2.7837727069854736 and val loss: 5.002110004425049


 36%|███████████████████████████                                                | 36006/100000 [12:28<22:28, 47.47it/s]

Training loss at epoch 36000: 2.0365304946899414 and val loss: 4.7946062088012695


 37%|███████████████████████████▊                                               | 37005/100000 [12:51<22:30, 46.64it/s]

Training loss at epoch 37000: 1.661686658859253 and val loss: 4.888637542724609


 38%|████████████████████████████▌                                              | 38005/100000 [13:14<22:56, 45.04it/s]

Training loss at epoch 38000: 1.9229917526245117 and val loss: 4.661801338195801


 39%|█████████████████████████████▎                                             | 39007/100000 [13:36<20:51, 48.75it/s]

Training loss at epoch 39000: 2.0981218814849854 and val loss: 5.142847537994385


 40%|██████████████████████████████                                             | 40010/100000 [13:58<24:16, 41.18it/s]

Training loss at epoch 40000: 2.620638370513916 and val loss: 4.718924045562744


 41%|██████████████████████████████▊                                            | 41009/100000 [14:21<21:03, 46.68it/s]

Training loss at epoch 41000: 2.615037679672241 and val loss: 5.64762544631958


 42%|███████████████████████████████▌                                           | 42006/100000 [14:43<21:01, 45.98it/s]

Training loss at epoch 42000: 2.4154562950134277 and val loss: 4.858448028564453


 43%|████████████████████████████████▎                                          | 43009/100000 [15:05<22:43, 41.79it/s]

Training loss at epoch 43000: 2.760479688644409 and val loss: 4.740057945251465


 44%|█████████████████████████████████                                          | 44006/100000 [15:27<20:06, 46.42it/s]

Training loss at epoch 44000: 1.864785075187683 and val loss: 4.770495891571045


 45%|█████████████████████████████████▊                                         | 45006/100000 [15:49<19:37, 46.71it/s]

Training loss at epoch 45000: 2.3587074279785156 and val loss: 5.0957489013671875


 46%|██████████████████████████████████▌                                        | 46007/100000 [16:12<19:25, 46.31it/s]

Training loss at epoch 46000: 2.4498934745788574 and val loss: 5.2354536056518555


 47%|███████████████████████████████████▎                                       | 47008/100000 [16:34<23:09, 38.13it/s]

Training loss at epoch 47000: 2.0926828384399414 and val loss: 4.7922282218933105


 48%|████████████████████████████████████                                       | 48009/100000 [16:56<19:49, 43.70it/s]

Training loss at epoch 48000: 2.0017387866973877 and val loss: 4.899932861328125


 49%|████████████████████████████████████▊                                      | 49008/100000 [17:19<17:46, 47.82it/s]

Training loss at epoch 49000: 2.4686639308929443 and val loss: 4.920495986938477


 50%|█████████████████████████████████████▌                                     | 50006/100000 [17:41<17:50, 46.71it/s]

Training loss at epoch 50000: 2.0027897357940674 and val loss: 5.06218147277832


 51%|██████████████████████████████████████▎                                    | 51007/100000 [18:03<16:45, 48.73it/s]

Training loss at epoch 51000: 1.935333490371704 and val loss: 4.687434196472168


 52%|███████████████████████████████████████                                    | 52008/100000 [18:24<16:46, 47.68it/s]

Training loss at epoch 52000: 1.8282768726348877 and val loss: 5.109348773956299


 53%|███████████████████████████████████████▊                                   | 53010/100000 [18:46<16:09, 48.45it/s]

Training loss at epoch 53000: 1.9413354396820068 and val loss: 4.804687023162842


 54%|████████████████████████████████████████▌                                  | 54010/100000 [19:09<16:32, 46.34it/s]

Training loss at epoch 54000: 1.5088211297988892 and val loss: 4.991514205932617


 55%|█████████████████████████████████████████▎                                 | 55005/100000 [19:30<15:41, 47.78it/s]

Training loss at epoch 55000: 2.1907169818878174 and val loss: 4.877886772155762


 56%|██████████████████████████████████████████                                 | 56009/100000 [19:52<16:10, 45.32it/s]

Training loss at epoch 56000: 2.5207037925720215 and val loss: 4.711607456207275


 57%|██████████████████████████████████████████▊                                | 57007/100000 [20:14<14:49, 48.31it/s]

Training loss at epoch 57000: 2.1142094135284424 and val loss: 4.760734558105469


 58%|███████████████████████████████████████████▌                               | 58010/100000 [20:36<14:19, 48.84it/s]

Training loss at epoch 58000: 3.1606807708740234 and val loss: 4.912818908691406


 59%|████████████████████████████████████████████▎                              | 59008/100000 [20:58<16:57, 40.28it/s]

Training loss at epoch 59000: 1.5369760990142822 and val loss: 4.58259916305542


 60%|█████████████████████████████████████████████                              | 60008/100000 [21:20<14:07, 47.20it/s]

Training loss at epoch 60000: 2.2028849124908447 and val loss: 4.881534576416016


 61%|█████████████████████████████████████████████▊                             | 61009/100000 [21:42<14:25, 45.07it/s]

Training loss at epoch 61000: 1.8403328657150269 and val loss: 5.159483909606934


 62%|██████████████████████████████████████████████▌                            | 62007/100000 [22:03<13:20, 47.48it/s]

Training loss at epoch 62000: 2.104445219039917 and val loss: 4.901021957397461


 63%|███████████████████████████████████████████████▎                           | 63007/100000 [22:25<14:42, 41.91it/s]

Training loss at epoch 63000: 2.4658541679382324 and val loss: 5.092087268829346


 64%|████████████████████████████████████████████████                           | 64007/100000 [22:47<12:36, 47.58it/s]

Training loss at epoch 64000: 2.2027390003204346 and val loss: 5.50402307510376


 65%|████████████████████████████████████████████████▊                          | 65008/100000 [23:09<12:32, 46.52it/s]

Training loss at epoch 65000: 2.031451940536499 and val loss: 5.265924453735352


 66%|█████████████████████████████████████████████████▌                         | 66011/100000 [23:32<12:07, 46.73it/s]

Training loss at epoch 66000: 1.7545177936553955 and val loss: 4.941340923309326


 67%|██████████████████████████████████████████████████▎                        | 67009/100000 [23:53<11:14, 48.91it/s]

Training loss at epoch 67000: 1.2321233749389648 and val loss: 4.873576641082764


 68%|███████████████████████████████████████████████████                        | 68008/100000 [24:15<11:46, 45.29it/s]

Training loss at epoch 68000: 2.1946792602539062 and val loss: 5.17979621887207


 69%|███████████████████████████████████████████████████▊                       | 69008/100000 [24:37<11:13, 46.03it/s]

Training loss at epoch 69000: 1.30522882938385 and val loss: 4.853417873382568


 70%|████████████████████████████████████████████████████▌                      | 70009/100000 [24:58<10:21, 48.28it/s]

Training loss at epoch 70000: 1.8199764490127563 and val loss: 5.025327682495117


 71%|█████████████████████████████████████████████████████▎                     | 71009/100000 [25:20<10:10, 47.47it/s]

Training loss at epoch 71000: 1.7053524255752563 and val loss: 4.928705215454102


 72%|██████████████████████████████████████████████████████                     | 72007/100000 [25:42<11:47, 39.56it/s]

Training loss at epoch 72000: 2.100382089614868 and val loss: 4.826327323913574


 73%|██████████████████████████████████████████████████████▊                    | 73011/100000 [26:08<09:00, 49.96it/s]

Training loss at epoch 73000: 1.7284021377563477 and val loss: 5.023065090179443


 74%|███████████████████████████████████████████████████████▌                   | 74005/100000 [26:29<09:18, 46.57it/s]

Training loss at epoch 74000: 1.6460599899291992 and val loss: 4.951632499694824


 75%|████████████████████████████████████████████████████████▎                  | 75008/100000 [26:50<08:18, 50.16it/s]

Training loss at epoch 75000: 1.8892569541931152 and val loss: 5.273003578186035


 76%|█████████████████████████████████████████████████████████                  | 76008/100000 [27:11<09:55, 40.28it/s]

Training loss at epoch 76000: 1.4651380777359009 and val loss: 4.730506896972656


 77%|█████████████████████████████████████████████████████████▊                 | 77007/100000 [27:33<08:24, 45.56it/s]

Training loss at epoch 77000: 1.5799185037612915 and val loss: 4.677750110626221


 78%|██████████████████████████████████████████████████████████▌                | 78011/100000 [27:55<07:41, 47.63it/s]

Training loss at epoch 78000: 2.8751533031463623 and val loss: 5.11214017868042


 79%|███████████████████████████████████████████████████████████▎               | 79006/100000 [28:15<07:00, 49.88it/s]

Training loss at epoch 79000: 1.8998957872390747 and val loss: 4.832386493682861


 80%|████████████████████████████████████████████████████████████               | 80010/100000 [28:37<06:57, 47.88it/s]

Training loss at epoch 80000: 1.525574803352356 and val loss: 4.891465187072754


 81%|████████████████████████████████████████████████████████████▊              | 81008/100000 [28:59<06:39, 47.51it/s]

Training loss at epoch 81000: 3.6258726119995117 and val loss: 5.264370918273926


 82%|█████████████████████████████████████████████████████████████▌             | 82006/100000 [29:20<06:16, 47.78it/s]

Training loss at epoch 82000: 1.8791024684906006 and val loss: 4.8193206787109375


 83%|██████████████████████████████████████████████████████████████▎            | 83006/100000 [29:41<06:44, 41.98it/s]

Training loss at epoch 83000: 1.7131167650222778 and val loss: 4.914595603942871


 84%|███████████████████████████████████████████████████████████████            | 84011/100000 [30:03<05:28, 48.66it/s]

Training loss at epoch 84000: 0.9977328777313232 and val loss: 4.922224998474121


 85%|███████████████████████████████████████████████████████████████▊           | 85007/100000 [30:25<05:14, 47.67it/s]

Training loss at epoch 85000: 1.5235785245895386 and val loss: 4.941267967224121


 86%|████████████████████████████████████████████████████████████████▌          | 86006/100000 [30:46<04:52, 47.83it/s]

Training loss at epoch 86000: 1.7214980125427246 and val loss: 5.0701904296875


 87%|█████████████████████████████████████████████████████████████████▎         | 87006/100000 [31:08<04:18, 50.23it/s]

Training loss at epoch 87000: 1.499436616897583 and val loss: 4.793976306915283


 88%|██████████████████████████████████████████████████████████████████         | 88010/100000 [31:29<04:14, 47.20it/s]

Training loss at epoch 88000: 3.33591890335083 and val loss: 4.869194030761719


 89%|██████████████████████████████████████████████████████████████████▊        | 89010/100000 [31:50<03:52, 47.29it/s]

Training loss at epoch 89000: 1.3192660808563232 and val loss: 4.871693134307861


 90%|███████████████████████████████████████████████████████████████████▌       | 90006/100000 [32:11<03:20, 49.85it/s]

Training loss at epoch 90000: 1.4693158864974976 and val loss: 4.868380069732666


 91%|████████████████████████████████████████████████████████████████████▎      | 91007/100000 [32:33<02:54, 51.49it/s]

Training loss at epoch 91000: 1.8551462888717651 and val loss: 5.027215957641602


 92%|█████████████████████████████████████████████████████████████████████      | 92010/100000 [32:54<02:40, 49.79it/s]

Training loss at epoch 92000: 1.4743775129318237 and val loss: 5.0551910400390625


 93%|█████████████████████████████████████████████████████████████████████▊     | 93010/100000 [33:16<02:32, 45.98it/s]

Training loss at epoch 93000: 1.1210954189300537 and val loss: 4.928496837615967


 94%|██████████████████████████████████████████████████████████████████████▌    | 94008/100000 [33:38<01:56, 51.22it/s]

Training loss at epoch 94000: 1.7032679319381714 and val loss: 4.5123186111450195


 95%|███████████████████████████████████████████████████████████████████████▎   | 95009/100000 [34:00<01:40, 49.80it/s]

Training loss at epoch 95000: 1.416208028793335 and val loss: 4.797597885131836


 96%|████████████████████████████████████████████████████████████████████████   | 96008/100000 [34:21<01:38, 40.42it/s]

Training loss at epoch 96000: 3.1444132328033447 and val loss: 4.996552467346191


 97%|████████████████████████████████████████████████████████████████████████▊  | 97006/100000 [34:43<01:09, 43.04it/s]

Training loss at epoch 97000: 1.7927377223968506 and val loss: 5.073391437530518


 98%|█████████████████████████████████████████████████████████████████████████▌ | 98005/100000 [35:04<00:41, 48.56it/s]

Training loss at epoch 98000: 1.1924233436584473 and val loss: 4.846593379974365


 99%|██████████████████████████████████████████████████████████████████████████▎| 99007/100000 [35:25<00:19, 50.85it/s]

Training loss at epoch 99000: 2.128781795501709 and val loss: 4.978132247924805


100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [35:46<00:00, 46.58it/s]


In [47]:
### EVALUATION: NN vs naive forecast on TRAIN / VAL / TEST ###
# The reported "loss" below is the RMSE: sqrt of torch.nn.functional.mse_loss, i.e. the
# square root of the mean squared error averaged over every element of the prediction tensor.


def _evaluate_split_vs_naive(model, X, y, split_name, horizon=10):
    """Print the RMSE of `model` and of `naive_forecast` on one data split.

    Args:
        model: trained torch module; it is switched to eval mode and called as `model(X)`.
        X: input windows for this split.
        y: targets for this split; must be shape-compatible with `model(X)` and with
            `naive_forecast(X, horizon)` for `mse_loss`.
        split_name: label interpolated into the printed messages ('train', 'val', 'test').
        horizon: forecast horizon passed to `naive_forecast`; must match the horizon the
            windows were built with (10 in this experiment).

    Returns:
        Tuple (y_pred, loss, y_pred_naive, naive_loss), where both losses are 0-d RMSE tensors.
    """
    model.eval()
    with torch.inference_mode():  # no autograd tracking needed for evaluation
        y_pred = model(X)
    loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred, y))
    print(f'Total (mean) {split_name} loss of NN: {loss.item()}')

    # naive_forecast comes from utils; presumably a simple baseline built from the
    # input window itself — NOTE(review): confirm its definition in utils.
    y_pred_naive = naive_forecast(X, horizon = horizon)
    naive_loss = torch.sqrt(torch.nn.functional.mse_loss(y_pred_naive, y))
    print(f'Total (mean) {split_name} loss of naive forecast: {naive_loss.item()}')
    return y_pred, loss, y_pred_naive, naive_loss


# Same order as before (train, val, test). After the loop, y_pred / loss / y_pred_naive /
# naive_loss hold the TEST-set values, exactly as they did in the original cell.
for split_name, X_split, y_split in [
    ('train', X_train, y_train),
    ('val', X_val, y_val),
    ('test', X_test, y_test),
]:
    y_pred, loss, y_pred_naive, naive_loss = _evaluate_split_vs_naive(
        model_9, X_split, y_split, split_name, horizon = 10
    )

Total (mean) train loss of NN: 1.7667797803878784
Total (mean) train loss of naive forecast: 12.17895221710205
Total (mean) val loss of NN: 4.845053195953369
Total (mean) val loss of naive forecast: 11.918389320373535
Total (mean) test loss of NN: 8.705182075500488
Total (mean) test loss of naive forecast: 12.19200611114502
