Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ Lightning-Sandbox documentation
:caption: Start here
:glob:

notebooks/*
notebooks/**/*

.. raw:: html
Expand Down
5 changes: 3 additions & 2 deletions flash_tutorials/electricity_forecasting/.meta.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
title: Electricity Price Forecasting with N-BEATS
author: Ethan Harris (ethan@pytorchlightning.ai)
created: 2021-11-23
updated: 2021-11-23
updated: 2021-12-16
license: CC BY-SA
build: 3
tags:
- Tabular
- Forecasting
- Timeseries
description: |
This tutorial covers using Lightning Flash and it's integration with PyTorch Forecasting to train an autoregressive
model (N-BEATS) on hourly electricity pricing data. We show how the built-in interpretability tools from PyTorch
Expand All @@ -15,7 +16,7 @@ description: |
bonus, we show hat we can resample daily observations from the data to discover weekly trends instead.
requirements:
- pandas==1.1.5
- lightning-flash[tabular]>=0.5.2
- lightning-flash[tabular]>=0.6.0
accelerator:
- GPU
- CPU
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%

import os
from typing import Any, Dict

import flash
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -196,9 +197,15 @@ def preprocess(df: pd.DataFrame, frequency: str = "1H") -> pd.DataFrame:
# %%


def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
def plot_interpretation(model_path: str, predict_df: pd.DataFrame, parameters: Dict[str, Any]):
model = TabularForecaster.load_from_checkpoint(model_path)
predictions = model.predict(predict_df)
datamodule = TabularForecastingData.from_data_frame(
parameters=parameters,
predict_data_frame=predict_df,
batch_size=256,
)
trainer = flash.Trainer(gpus=int(torch.cuda.is_available()))
predictions = trainer.predict(model, datamodule=datamodule)
predictions, inputs = convert_predictions(predictions)
model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=0)
plt.show()
Expand All @@ -208,7 +215,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
# And now we run the function to plot the trend and seasonality curves:

# %%
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly)
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly, datamodule.parameters)

# %% [markdown]
# It worked! The plot shows that the `TabularForecaster` does a reasonable job of modelling the time series and also
Expand Down Expand Up @@ -281,7 +288,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
# Now let's look at what it learned:

# %%
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily)
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily, datamodule.parameters)

# %% [markdown]
# Success! We can now also see weekly trends / seasonality uncovered by our new model.
Expand Down