diff --git a/README.md b/README.md index 7269d80..9eb40d2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ Current repository contains the following models: 2. [MOMENT](https://arxiv.org/abs/2402.03885) 3. [TimesFM](https://arxiv.org/html/2310.10688v2) 4. [Chronos](https://arxiv.org/abs/2403.07815) -5. [MOIRAI](https://arxiv.org/abs/2402.02592) More models will be added soon... @@ -20,6 +19,12 @@ You can add the package to your project by running the following command: pip install git+https://github.com/AdityaLab/Samay.git ``` +For linux users with CUDA installed, you can install the package with GPU support by running: + +```bash +pip install https://github.com/SamayAI/Samay/releases/download/v0.1.0/samay-0.1.0-cp311-cp311-linux_x86_64.whl +``` + **Note:** If the installation fails because rust is missing run: For MacOS: @@ -53,7 +58,9 @@ curl -LsSf https://astral.sh/uv/install.sh | sh uv sync --reinstall ``` -## Usage Example +## Usage Examples + +Check out example notebooks at `examples/` for more detailed examples. We also have google colab notebooks at `examples/colab/`. ### LPTM @@ -61,27 +68,31 @@ uv sync --reinstall ```python from samay.model import LPTMModel -from samay.dataset import LPTMDataset -repo = "lptm" config = { - "context_len": 512, - "horizon_len": 192, - "backend": "gpu", - "per_core_batch_size": 32, - "domain": "electricity", + "task_name": "forecasting", + "forecast_horizon": 192, + "freeze_encoder": True, # Freeze the patch embedding layer + "freeze_embedder": True, # Freeze the transformer encoder + "freeze_head": False, # The linear forecasting head must be trained } - -lptm = LPTMModel(config=config, repo=repo) +model = LPTMModel(config) ``` #### Loading Dataset ```python -train_dataset = LPTMDataset(name="electricity", datetime_col='date', path='data/ETTh1.csv', - mode='train', context_len=config["context_len"], horizon_len=128) -val_dataset = LPTMDataset(name="electricity", datetime_col='date', path='data/ETTh1.csv', - mode='test', context_len=config["context_len"], horizon_len=config["horizon_len"]) +from samay.dataset import LPTMDataset + +train_dataset = LPTMDataset( + name="ett", + datetime_col="date", + path="./data/data/ETTh1.csv", + mode="train", + horizon=192, +) + +finetuned_model = model.finetune(train_dataset) ``` #### Zero-Forecasting @@ -92,8 +103,6 @@ avg_loss, trues, preds, histories = lptm.evaluate(val_dataset) ### TimesFM -Install the package: `pip install git+https://github.com/AdityaLab/Samay.git`. - #### Loading Model ```python @@ -133,7 +142,7 @@ avg_loss, trues, preds, histories = tfm.evaluate(val_dataset) ### Support -Tested on Python 3.12, 3.13 on Linux (CPU + GPU) and MacOS (CPU). Supports NVIDIA GPUs. +Tested on Python 3.11-3.13 on Linux (CPU + GPU) and MacOS (CPU). Supports NVIDIA GPUs. Support for Windows and Apple Silicon GPUs is planned. ## Citation @@ -153,4 +162,4 @@ url={https://openreview.net/forum?id=vMMzjCr5Zj} ## Contact -If you have any feedback or questions, you can contact us via email: . +If you have any feedback or questions, you can contact us via email: , .