In [None]:
from copy import copy

import torch
from torch.utils.data import DataLoader

from src.data import PeakWeatherTorchDataset, test_model, plot_predictions
from src.model import train_model, MLPModel, GRUModel

In [None]:
# Forecasting setup: length of the input window and of the forecast horizon.
window = 96
horizon = 24

# Build the base dataset once, for the "temperature" parameter.
train_dataset = PeakWeatherTorchDataset(
    parameter="temperature",
    window=window,
    horizon=horizon,
)

# The validation and test datasets are shallow copies of the base dataset
# (copy.copy) with only their `mode` attribute overridden.
# NOTE(review): this assumes `mode` is what selects the split inside
# PeakWeatherTorchDataset -- confirm against src.data.
val_dataset, test_dataset = copy(train_dataset), copy(train_dataset)
val_dataset.mode, test_dataset.mode = "val", "test"

In [None]:
# Number of samples per batch, shared by all three loaders.
batch_size = 8192

# Only the training loader shuffles its samples.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Validation and test loaders are built identically, with shuffling disabled.
val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

In [None]:
# Seed PyTorch's global RNG before the model is built so that a fresh
# "Restart Kernel -> Run All" repeats the same run: the seed fixes the
# random draws made from here on, which covers the shuffling order of a
# DataLoader created with shuffle=True and no explicit generator.
# NOTE(review): this also fixes the weight initialisation provided
# MLPModel / GRUModel rely on PyTorch's default random init -- confirm
# against src.model.
SEED = 42
torch.manual_seed(SEED)

# Alternative architecture: uncomment the next line (and comment out the
# MLPModel call) to train the GRU instead of the MLP.
# model = GRUModel(horizon=horizon, hidden_size=8, num_layers=2, dropout=0.0)
model = MLPModel(
    window=window, horizon=horizon, hidden_size=16, num_layers=2, dropout=0.0
)

In [None]:
# Train the model. train_model receives the training and validation loaders
# together with the optimisation settings (5 epochs, learning rate 1e-3);
# how the validation loader is used is decided inside src.model.train_model.
# The returned value is rebound to `model`, so re-running this cell passes
# the previously returned object back into train_model.
model = train_model(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    epochs=5,
    lr=1e-3,
)

In [None]:
# Run test_model on the test loader and keep the value it returns.
# NOTE(review): the name `mae` suggests a mean absolute error -- confirm
# against src.data.test_model.
mae = test_model(model, test_dataloader)

In [None]:
# Visualise the model's predictions on the test dataset, requesting 5 samples.
# NOTE(review): how the samples are chosen is decided inside
# src.data.plot_predictions -- confirm.
plot_predictions(model, test_dataset, num_samples=5)