Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Time-LLM #908

Merged
merged 22 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2bbe873
TimeLLM initial commit
marcopeix Feb 19, 2024
0d9a5d9
WIP Add TimeLLM
marcopeix Feb 27, 2024
d187e92
Merge branch 'main' of https://github.com/Nixtla/neuralforecast into …
marcopeix Feb 28, 2024
04a16bc
Add TimeLLM to neuralforecast
marcopeix Mar 1, 2024
3f825da
Add elapsed time of training and fitting
marcopeix Mar 1, 2024
f0cf83c
Merge branch 'main' of https://github.com/Nixtla/neuralforecast into …
marcopeix Mar 1, 2024
d0fcb2c
Add GPT-2 as default LLM, add tests to model, and add dependency to t…
marcopeix Mar 4, 2024
fed8e99
Add transformers to dev requirements
marcopeix Mar 5, 2024
4be3457
Remove TimeLLM test file, as script is in notebook
marcopeix Mar 5, 2024
d4aee03
Remove transformers from core dependency and better handling of defau…
marcopeix Mar 5, 2024
e9450c5
Merge branch 'main' of https://github.com/Nixtla/neuralforecast into …
marcopeix Mar 5, 2024
54d8c24
Merge branch 'main' into feature/time-llm
AzulGarza Mar 5, 2024
86b7447
Remove unused imports
marcopeix Mar 7, 2024
13559ec
Merge branch 'main' of https://github.com/Nixtla/neuralforecast into …
marcopeix Mar 7, 2024
bdf0ae6
Merge branch 'feature/time-llm' of https://github.com/Nixtla/neuralfo…
marcopeix Mar 7, 2024
e6f1a58
Merge branch 'main' into feature/time-llm
marcopeix Mar 7, 2024
ea3ec25
Clean notebook
marcopeix Mar 7, 2024
b1faea4
Fix conflicts
marcopeix Mar 7, 2024
d082d25
Fix conflicts
marcopeix Mar 7, 2024
0610431
Merge branch 'main' of https://github.com/Nixtla/neuralforecast into …
marcopeix Mar 7, 2024
6c1a9e9
Replace env filename with environment-cpu.yml
marcopeix Mar 8, 2024
e81ab9e
Merge branch 'main' into feature/time-llm
jmoralez Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies:
- pytorch-lightning>=2.0.0
- pip
- s3fs
- transformers
- pip:
- nbdev
- black
Expand Down
55 changes: 55 additions & 0 deletions experiments/test_timellm/test_timellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
marcopeix marked this conversation as resolved.
Show resolved Hide resolved

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"

import time
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from neuralforecast import NeuralForecast
from neuralforecast.models import TimeLLM
from neuralforecast.losses.pytorch import MAE
from neuralforecast.tsdataset import TimeSeriesDataset
from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df

from transformers import GPT2Config, GPT2Model, GPT2Tokenizer

AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')

Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test

gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
gpt2 = GPT2Model.from_pretrained('openai-community/gpt2',config=gpt2_config)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2')

prompt_prefix = "The dataset contains data on monthly air passengers. There is a yearly seasonality"

timellm = TimeLLM(h=12,
input_size=36,
llm=gpt2,
llm_config=gpt2_config,
llm_tokenizer=gpt2_tokenizer,
prompt_prefix=prompt_prefix,
max_steps=100,
batch_size=24,
windows_batch_size=24)

nf = NeuralForecast(
models=[timellm],
freq='M'
)

start = time.time()

nf.fit(df=Y_train_df, val_size=12)
forecasts = nf.predict(futr_df=Y_test_df)

end = time.time()

print(f'It took {end-start} seconds to fit and predict with TimeLLM')

print(forecasts)

3 changes: 2 additions & 1 deletion nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
" MLP, NHITS, NBEATS, NBEATSx, DLinear, NLinear,\n",
" TFT, VanillaTransformer,\n",
" Informer, Autoformer, FEDformer,\n",
" StemGNN, PatchTST, TimesNet\n",
" StemGNN, PatchTST, TimesNet, TimeLLM\n",
")"
]
},
Expand Down Expand Up @@ -223,6 +223,7 @@
" 'tft': TFT, 'autotft': TFT,\n",
" 'timesnet': TimesNet, 'autotimesnet': TimesNet,\n",
" 'vanillatransformer': VanillaTransformer, 'autovanillatransformer': VanillaTransformer,\n",
" 'timellm': TimeLLM\n",
"}"
]
},
Expand Down
Binary file added nbs/imgs_models/timellm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading