### Define Parameters

In [1]:
seq_len = 100
pred_len = 200

batch_size = 1
max_epochs = 5

n_layers = 5

### Import libraries and dataset

In [2]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from tqdm import tqdm

from model import MICN

from weather.dataset import weather_dataset

dataset = weather_dataset(seq_len=seq_len, pred_len=pred_len)

train_set, test_set, valid_set = dataset.split()

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)

Found files: ['mpi_roof_2023b.csv', 'mpi_roof_2023a.csv']
Found a total of 52700 samples.


  dates_df['month'] = dates_df["Date Time"].apply(lambda row:row.month,1)
  dates_df['day'] = dates_df["Date Time"].apply(lambda row:row.day,1)
  dates_df['weekday'] = dates_df["Date Time"].apply(lambda row:row.weekday(),1)
  dates_df['hour'] = dates_df["Date Time"].apply(lambda row:row.hour,1)


In [3]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model = MICN(seq_len, pred_len, device=device)

In [4]:
for seq_times, sequence, true_times, true in train_loader:
    seq_times, sequence, true_times, true = seq_times.to(device), sequence.to(device), true_times.to(device), true.to(device)
    print(model(sequence, seq_times, true_times).shape)
    break

torch.Size([1, 200, 21])


## Train the model

In [5]:
criterion = nn.MSELoss()
criterion = criterion.to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-6)

train_losses = []
val_losses = []

patience = 3
epochs_without_improving = 0

best_epoch = -1
best_val_loss = 0

for epoch in range(max_epochs):
    print("Starting epoch", epoch+1)
    model.train()
    losses = []
    for seq_times, sequence, true_times, true in tqdm(train_loader):
        seq_times, sequence, true_times, true = seq_times.to(device), sequence.to(device), true_times.to(device), true.to(device)
        out = model(sequence, seq_times, true_times)
        
        loss = criterion(out, true)
        losses.append(loss.detach().item())

        loss.backward()
        optim.step()
    avg_train_loss = np.mean(losses)
    print("Average train loss:", avg_train_loss)
    train_losses.append(avg_train_loss)
    
    model.eval()
    losses = []
    with torch.no_grad():
        for seq_times, sequence, true_times, true in tqdm(valid_loader):
            seq_times, sequence, true_times, true = seq_times.to(device), sequence.to(device), true_times.to(device), true.to(device)
            out = model(sequence, seq_times, true_times)
            
            loss = criterion(out, true)
            losses.append(loss.detach().item())
    avg_val_loss = np.mean(losses)
    print("Average validation loss:", avg_val_loss)

    if best_epoch == -1 or avg_val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = avg_val_loss
        epochs_without_improving = 0

        save_name = "MICN_" + str(n_layers) + "_" + str(seq_len) + "_" + str(pred_len) + "_" + str(epoch+1) + ".pt"
        torch.save(model.state_dict(), save_name)
    else:
        epochs_without_improving += 1
    
    if epochs_without_improving > patience:
        print("Stopping at epoch", epoch+1)
        print("Best validation loss reached at epoch", best_epoch)
        model.load_state_dict(torch.load(save_name))
        epoch = best_epoch
        break

Starting epoch 1


100%|██████████| 41860/41860 [16:09<00:00, 43.17it/s]


Average train loss: 54323.79489847576


100%|██████████| 5120/5120 [00:23<00:00, 215.12it/s]


Average validation loss: 67168.76493148804
Starting epoch 2


100%|██████████| 41860/41860 [18:34<00:00, 37.56it/s]


Average train loss: 42066.68485292306


100%|██████████| 5120/5120 [00:28<00:00, 180.66it/s]


Average validation loss: 57284.89844932556
Starting epoch 3


100%|██████████| 41860/41860 [18:31<00:00, 37.67it/s]


Average train loss: 35233.41366925823


100%|██████████| 5120/5120 [00:25<00:00, 203.41it/s]


Average validation loss: 52545.34154955546
Starting epoch 4


100%|██████████| 41860/41860 [17:40<00:00, 39.46it/s]


Average train loss: 31071.596215788963


100%|██████████| 5120/5120 [00:22<00:00, 228.83it/s]


Average validation loss: 49219.14912185669
Starting epoch 5


100%|██████████| 41860/41860 [17:38<00:00, 39.55it/s]


Average train loss: 27952.07534170006


100%|██████████| 5120/5120 [00:26<00:00, 192.01it/s]


Average validation loss: 45738.79900600433


### Plot losses

In [7]:
train_losses

[171420.0625,
 210978.984375,
 227037.78125,
 180569.5625,
 235697.671875,
 284823.03125,
 170144.578125,
 170994.0,
 176804.9375,
 170747.484375,
 170073.71875,
 180382.140625,
 188756.828125,
 185725.171875,
 170832.3125,
 180478.953125,
 195925.140625,
 215321.9375,
 170670.328125,
 181065.03125,
 241899.921875,
 190004.96875,
 171106.421875,
 253246.78125,
 198145.8125,
 275113.65625,
 234361.796875,
 185111.078125,
 183079.296875,
 200120.40625,
 179170.484375,
 190486.46875,
 180761.265625,
 175456.796875,
 181752.625,
 212060.25,
 181025.640625,
 174842.703125,
 170077.640625,
 191998.90625,
 241314.578125,
 166130.515625,
 181979.84375,
 193614.015625,
 192206.03125,
 175546.6875,
 166088.65625,
 169072.359375,
 174355.875,
 186160.5,
 196808.484375,
 178993.328125,
 173149.125,
 192078.484375,
 169305.65625,
 182523.625,
 229481.390625,
 189528.796875,
 246921.90625,
 279249.6875,
 211669.546875,
 170035.9375,
 171069.953125,
 183142.578125,
 212648.484375,
 243099.0,
 203574.

In [6]:
assert len(train_losses) == len(val_losses)
x = [i+1 for i in range(len(train_losses))]
plt.title("Loss over " + str(len(train_losses)) + " iterations")
plt.plot(x, train_losses, color="red", label="Train loss")
plt.plot(x, val_losses, color="green", label="Validation loss")
plt.legend()

plt.show()

AssertionError: 

## Evaluate on the test set

In [None]:
if 'model' not in locals():
    save_name = "MICN_1_100_200_5.pt" # if model isn't defined change this line with the model you want to evaluate
    assert int(save_name[7:10]) == seq_len and int(save_name[11:14]) == pred_len
    model = MICN(seq_len, pred_len, n_layers=int(save_name[5]))
    model.load_state_dict(torch.load(save_name))
if 'device' not in locals():
    device = torch.device("cpu")
    
criterion = nn.MSELoss()
test_losses = []
model.eval()
with torch.no_grad():
    for seq_times, sequence, true_times, true in tqdm(test_loader):
        seq_times, sequence, true_times, true = seq_times.to(device), sequence.to(device), true_times.to(device), true.to(device)
        out = model(sequence, seq_times, true_times)
            
        loss = criterion(out, true)
        test_losses.append(loss.detach().item())
print("Average test loss:", np.mean(test_losses))
print("Average test loss per time step:", np.mean(test_losses)/pred_len)

100%|██████████| 4820/4820 [00:20<00:00, 230.07it/s]

Average test loss: 31780.643736931892
Average test loss per time step: 158.90321868465946





### Example of comparison between true and predicted values

In [None]:
for seq_times, sequence, true_times, true in valid_loader:
    out = model(sequence, seq_times, true_times).detach()[0]
    print("avg loss per single time step", criterion(out, true)/pred_len)
    print("Predicted True")
    for i in range(pred_len):
        print(out[i][1].numpy(), true[0][i][1].numpy())
    break

avg loss per single time step tensor(138.1893)
Predicted True
16.066483 14.58
15.639151 14.6
15.485911 14.65
16.476917 14.66
16.649399 14.65
17.046452 14.66
15.408783 14.66
15.8881855 14.68
17.05221 14.69
17.14469 14.81
16.090284 14.91
15.047693 14.98
16.276562 15.14
16.29803 15.17
17.005116 15.26
17.332403 15.27
16.652067 15.13
17.117342 15.13
16.780056 15.37
16.942886 15.44
16.598015 15.3
17.657436 15.03
16.726807 14.92
18.56475 14.96
18.490435 15.08
17.463781 15.2
15.393438 15.12
18.313736 15.25
18.125072 15.56
18.049274 15.57
17.641932 15.66
18.334373 15.76
17.96554 15.76
17.736752 15.82
17.23885 16.06
18.580141 16.59
17.97163 16.7
18.760561 16.51
18.156963 16.61
17.892675 16.72
18.27993 16.78
23.367998 16.77
17.000864 16.87
18.23381 17.21
17.823301 17.5
19.327467 17.55
20.297537 17.86
19.359241 17.85
17.087908 17.89
17.829935 18.19
15.714462 18.33
17.708551 18.48
18.569725 18.63
17.852985 18.62
16.474009 18.85
15.714851 19.05
17.456217 19.01
17.535336 19.23
17.45758 19.77
17.82967