The benchmark is tested with Python 3.10+ on Linux/macOS. (For Python 3.10, use `scipy==1.15.3` as pinned in `requirements.txt`.)
```bash
conda create -n spatialepibench python=3.10 -y
conda activate spatialepibench
python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```

If you need a CUDA build or a different torch build than the default resolver chooses, install PyTorch first from the official index, then install the rest:
```bash
# Example (CUDA 12.1)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```

Verify the installation and run a quick smoke test:

```bash
python -c "import torch; print('torch', torch.__version__, 'cuda?', torch.cuda.is_available())"
python run_retrain.py --dataset JHUcase --model Dlinear --epochs 1 --device cpu
```

- If pip reports resolver conflicts, re-run with a clean environment and upgraded build tools:
```bash
conda deactivate 2>/dev/null || true
conda remove -n spatialepibench --all -y
conda create -n spatialepibench python=3.10 -y
conda activate spatialepibench
python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```
- If a package build fails, confirm system compilers are available (on Linux: `build-essential`, `python3-dev`).
- If `--device cuda` fails, check that your NVIDIA driver supports the installed CUDA runtime and verify with:

```bash
nvidia-smi
python -c "import torch; print(torch.cuda.is_available())"
```
```
SpatialEpiBench/
├── README.md
├── requirements.txt
├── run_retrain.py
├── models/
│   ├── AGCRN.py
│   ├── DCRNN.py
│   └── ...
└── rawData/
    └── processed/
        ├── JHUcase.csv
        ├── JHUcase_adj.csv
        ├── ILI2019.csv
        ├── ILI2019_adj.csv
        └── ...
```
Each dataset needs two CSV files:

```
rawData/processed/<dataset>.csv
rawData/processed/<dataset>_adj.csv
```

Example for `--dataset JHUcase`:

```
rawData/processed/JHUcase.csv
rawData/processed/JHUcase_adj.csv
```
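The loader's exact parsing of these files is internal to the repo, but a quick sanity check before a long run can save time. Below is a minimal sketch assuming only that `<dataset>_adj.csv` holds a square node-by-node matrix; the `check_dataset_files` helper and the `header=None` read are illustrative assumptions, not documented behavior:

```python
from pathlib import Path

import pandas as pd

def check_dataset_files(name: str, root: str = "rawData/processed") -> None:
    """Sanity-check that a dataset's two CSV files exist and look plausible."""
    values_path = Path(root) / f"{name}.csv"
    adj_path = Path(root) / f"{name}_adj.csv"
    for path in (values_path, adj_path):
        if not path.exists():
            raise FileNotFoundError(f"missing {path}")
    # Assumption: the adjacency file is a square node-by-node matrix.
    # Adjust header/index handling to match your actual file layout.
    adj = pd.read_csv(adj_path, header=None)
    if adj.shape[0] != adj.shape[1]:
        raise ValueError(f"{adj_path} is not square: {adj.shape}")
    print(f"{name}: found both files, {adj.shape[0]} nodes in adjacency")

check_dataset_files("JHUcase")
```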
Run with default settings:

```bash
python run_retrain.py --dataset JHUcase
```

By default, this uses:
```
model      = AGCRN
device     = cpu
epochs     = 50
lookback   = 28
horizon    = 7
train_rate = 0.6
val_rate   = 0.2
loss       = mse
```
By default, `build_splits()` scales each node's positive observations by that node's standard deviation (zeros stay zero); a sketch of this transform follows. The scaling affects both `y_true` and `y_pred` in the output CSVs.
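A minimal NumPy sketch of this transform, assuming division by the per-node standard deviation of positive values (the actual `build_splits()` may differ in details such as divide-vs-multiply and degrees of freedom):

```python
import numpy as np

def scale_per_node(x: np.ndarray) -> np.ndarray:
    """x: (time, nodes). Divide each node's positive observations by the
    std of that node's positive observations; zeros stay zero."""
    out = x.astype(float).copy()
    for node in range(x.shape[1]):
        positive = x[:, node] > 0
        if positive.any():
            std = x[positive, node].std()
            if std > 0:
                out[positive, node] = x[positive, node] / std
    return out

x = np.array([[0.0, 4.0], [2.0, 0.0], [4.0, 8.0]])  # 3 time steps, 2 nodes
print(scale_per_node(x))  # zeros preserved, positives rescaled per node
```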
To keep raw counts/values, add `--no-scale`:

```bash
python run_retrain.py --dataset JHUcase --no-scale
```

Run on GPU:
```bash
python run_retrain.py \
  --dataset JHUcase \
  --model AGCRN \
  --device cuda
```

Run other models (examples):
```bash
python run_retrain.py \
  --dataset JHUcase \
  --model DCRNN \
  --device cuda \
  --rnn-units 32 \
  --num-rnn-layers 2 \
  --max-diffusion-step 2 \
  --dropout 0.1
```

```bash
python run_retrain.py \
  --dataset JHUcase \
  --model STGCN \
  --device cuda \
  --nhids 32
```

```bash
python run_retrain.py \
  --dataset JHUcase \
  --model GraphWaveNet \
  --device cuda \
  --epochs 30 \
  --blocks 2 \
  --nlayers 4 \
  --residual-channels 4 \
  --dilation-channels 4
```

These arguments are available for all models.
| Argument | Default | Description |
|---|---|---|
| `--dataset` | `JHUcase` | Dataset name under `rawData/processed/`. |
| `--model` | `AGCRN` | Model name to run. |
| `--device` | `cpu` | Device, for example `cpu`, `cuda`, or `cuda:0`. |
| `--epochs` | `50` | Number of training epochs per retraining window. |
| `--lookback` | `28` | Number of historical time steps used as input. |
| `--horizon` | `7` | Forecast horizon used when generating windows. |
| `--train_rate` | `0.6` | Initial train split ratio. |
| `--val_rate` | `0.2` | Initial validation split ratio. |
| `--loss` | `mse` | Loss function. Choices: `mse`, `mse_filtered`. |
| `--retrain-every` | `90` | Number of target time steps predicted before retraining again (see the sketch after this table). |
| `--retrain-train-length` | `180` | Number of previous time steps used for each retraining window. |
| `--no-scale` | off | Disable per-node scaling by the std of positive observations (default is scaled, with zeros preserved). |
| `--use-future-ti` | off | Use future time-index information if supported by the model. |
| `--epi-mode` | `none` | Epidemiological mode. Choices: `none`, `sir_incidence`, `sir_percent`, `ngm`. |
| `--use-einn` | off | Enable EINN alignment. Requires `--epi-mode` set to `sir_incidence`, `sir_percent`, or `ngm`. |
| `--plot` | off | Plot predictions for one selected state/location. |
| `--state2plot` | `None` | State/location name to plot. Used with `--plot`. |
| `--model-kwargs-json` | `None` | Extra model-specific kwargs as a JSON object. |
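To make the interaction of `--retrain-every` and `--retrain-train-length` concrete, here is a hypothetical sketch of the rolling schedule those two descriptions imply; the actual boundary handling in `run_retrain.py` may differ (for example, in offsets at the start and end of the series):

```python
def retrain_windows(n_steps: int, first_eval: int, every: int = 90, train_len: int = 180):
    """Yield (retrain_id, train_start, train_end, eval_start, eval_end)
    as half-open index ranges over the time axis."""
    t, retrain_id = first_eval, 0
    while t < n_steps:
        eval_end = min(n_steps, t + every)
        yield retrain_id, max(0, t - train_len), t, t, eval_end
        t, retrain_id = eval_end, retrain_id + 1

for rid, ts, te, es, ee in retrain_windows(n_steps=500, first_eval=300):
    print(f"retrain {rid}: train [{ts}, {te}), eval [{es}, {ee})")
# retrain 0: train [120, 300), eval [300, 390)
# retrain 1: train [210, 390), eval [390, 480)
# retrain 2: train [300, 480), eval [480, 500)
```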
The runner currently supports:

- `AGCRN`
- `ARIMA`
- `ColaGNN`
- `DCRNN`
- `Dlinear`
- `EARTH`
- `EpiGNN`
- `GraphWaveNet`
- `GTS`
- `MTGNN`
- `STGCN`
- `STNorm`
- `StemGNN`
- `repeat_last`
`ARIMA` and `repeat_last` are baseline models.
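The name suggests `repeat_last` is a persistence baseline that carries each node's last observation across the horizon; a minimal sketch of such a baseline (the repo's actual implementation may differ):

```python
import numpy as np

def repeat_last(history: np.ndarray, horizon: int = 7) -> np.ndarray:
    """history: (time, nodes). Repeat each node's most recent value
    for every step of the forecast horizon."""
    return np.tile(history[-1], (horizon, 1))

history = np.arange(12.0).reshape(6, 2)  # 6 time steps, 2 nodes
print(repeat_last(history, horizon=3))   # three rows of [10., 11.]
```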
Only the selected model's hyperparameters are added to the command-line parser. For example, when you run:

```bash
python run_retrain.py --model GraphWaveNet --help
```

GraphWaveNet-specific arguments will appear.
The runner creates an output folder named `retrain_<dataset>/` (for example, `retrain_JHUcase/`). The main prediction CSV is saved as:

```
retrain_<dataset>/retrain_<dataset>_<tag>.csv
```
The CSV contains rows with fields such as:

| Column | Meaning |
|---|---|
| `retrain_id` | Retraining window index. |
| `timestamp` | Target timestamp. |
| `state_idx` | Node/region index. |
| `state` | Node/region name. |
| `train_start` | First timestamp used in the retraining window. |
| `train_end` | Last timestamp used in the retraining window. |
| `eval_start` | First timestamp evaluated after this retrain. |
| `eval_end` | Last timestamp evaluated after this retrain. |
| `y_true` | Ground-truth target value. |
| `y_pred` | Model prediction. |
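This column layout makes the CSV easy to post-process with pandas. A minimal sketch computing per-state MAE and RMSE (the file name's `<tag>` is left as a placeholder; substitute the tag your run actually produced):

```python
import numpy as np
import pandas as pd

tag = "..."  # substitute the <tag> of your run
df = pd.read_csv(f"retrain_JHUcase/retrain_JHUcase_{tag}.csv")

err = df["y_pred"] - df["y_true"]
df["abs_err"] = err.abs()
df["sq_err"] = err**2

per_state = df.groupby("state").agg(
    mae=("abs_err", "mean"),
    rmse=("sq_err", lambda s: float(np.sqrt(s.mean()))),
)
print(per_state.sort_values("mae"))
```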
If `--plot` and `--state2plot` are provided, a PNG file is also saved beside the CSV.
To add a new model, edit `MODEL_REGISTRY` in `run_retrain.py`.
Example:
MODEL_REGISTRY["NewModel"] = {
"class_path": "models.NewModel:NewModel",
"defaults": {
"hidden_dim": 32,
"dropout": 0.1,
},
}Then you can run:
python run_retrain.py \
--dataset JHUcase \
--model NewModel \
--hidden-dim 64 \
--dropout 0.2Rules:
- The file should be importable from Python.
- The class path format is `module_path:ClassName` (a resolution sketch appears at the end of this section).
- Keys in `defaults` become command-line arguments.
- Underscores in parameter names become hyphens in the CLI.
Example: `"hidden_dim": 32` becomes `--hidden-dim 32`.
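For reference, one plausible way a `module_path:ClassName` entry could be resolved; this is an illustrative sketch, not necessarily what `run_retrain.py` does internally (`NewModel` is the hypothetical model from the registry example above):

```python
import importlib

def resolve_class(class_path: str) -> type:
    """Split 'module_path:ClassName', import the module, return the class."""
    module_path, class_name = class_path.split(":")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# e.g., the hypothetical registry entry from the example above:
ModelCls = resolve_class("models.NewModel:NewModel")
model = ModelCls(hidden_dim=64, dropout=0.2)  # kwargs from defaults/CLI
```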