Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Current repository contains the following models:
2. [MOMENT](https://arxiv.org/abs/2402.03885)
3. [TimesFM](https://arxiv.org/html/2310.10688v2)
4. [Chronos](https://arxiv.org/abs/2403.07815)
5. [MOIRAI](https://arxiv.org/abs/2402.02592)

More models will be added soon...

Expand All @@ -20,6 +19,12 @@ You can add the package to your project by running the following command:
pip install git+https://github.com/AdityaLab/Samay.git
```

For linux users with CUDA installed, you can install the package with GPU support by running:

```bash
pip install https://github.com/SamayAI/Samay/releases/download/v0.1.0/samay-0.1.0-cp311-cp311-linux_x86_64.whl
```

**Note:** If the installation fails because rust is missing run:

For MacOS:
Expand Down Expand Up @@ -53,35 +58,41 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync --reinstall
```

## Usage Example
## Usage Examples

Check out example notebooks at `examples/` for more detailed examples. We also have google colab notebooks at `examples/colab/`.

### LPTM

#### Loading Model

```python
from samay.model import LPTMModel
from samay.dataset import LPTMDataset

repo = "lptm"
config = {
"context_len": 512,
"horizon_len": 192,
"backend": "gpu",
"per_core_batch_size": 32,
"domain": "electricity",
"task_name": "forecasting",
"forecast_horizon": 192,
"freeze_encoder": True, # Freeze the patch embedding layer
"freeze_embedder": True, # Freeze the transformer encoder
"freeze_head": False, # The linear forecasting head must be trained
}

lptm = LPTMModel(config=config, repo=repo)
model = LPTMModel(config)
```

#### Loading Dataset

```python
train_dataset = LPTMDataset(name="electricity", datetime_col='date', path='data/ETTh1.csv',
mode='train', context_len=config["context_len"], horizon_len=128)
val_dataset = LPTMDataset(name="electricity", datetime_col='date', path='data/ETTh1.csv',
mode='test', context_len=config["context_len"], horizon_len=config["horizon_len"])
from samay.dataset import LPTMDataset

train_dataset = LPTMDataset(
name="ett",
datetime_col="date",
path="./data/data/ETTh1.csv",
mode="train",
horizon=192,
)

finetuned_model = model.finetune(train_dataset)
```

#### Zero-Forecasting
Expand All @@ -92,8 +103,6 @@ avg_loss, trues, preds, histories = lptm.evaluate(val_dataset)

### TimesFM

Install the package: `pip install git+https://github.com/AdityaLab/Samay.git`.

#### Loading Model

```python
Expand Down Expand Up @@ -133,7 +142,7 @@ avg_loss, trues, preds, histories = tfm.evaluate(val_dataset)

### Support

Tested on Python 3.12, 3.13 on Linux (CPU + GPU) and MacOS (CPU). Supports NVIDIA GPUs.
Tested on Python 3.11-3.13 on Linux (CPU + GPU) and MacOS (CPU). Supports NVIDIA GPUs.
Support for Windows and Apple Silicon GPUs is planned.

## Citation
Expand All @@ -153,4 +162,4 @@ url={https://openreview.net/forum?id=vMMzjCr5Zj}

## Contact

If you have any feedback or questions, you can contact us via email: <hkamarthi3@gatech.edu>.
If you have any feedback or questions, you can contact us via email: <hkamarthi3@gatech.edu>, <badityap@cc.gatech.edu>.