Commit 8616b7d: Add detailed reproducibility instructions.
JonathanCrabbe committed Jan 31, 2024 (1 parent: 0be30a2)
1 changed file: README.md (9 additions, 4 deletions)
In order to train models, you can simply run the following command:

```
python cmd/train.py
```

By default, this command will train a score model in the time domain with the `ecg` dataset. To modify this behaviour, you can use the [hydra override syntax](https://hydra.cc/docs/advanced/override_grammar/basic/). The following hyperparameters can be modified to retrain all the models appearing in the paper (an illustrative override is given after the table):

| Hyperparameter | Description | Values |
|----------------|-------------|---------------|
| fourier_transform | Whether or not to train a diffusion model in the frequency domain. | true, false |
| datamodule | Name of the dataset to use. | ecg, mimiciii, nasa, nasdaq, usdroughts |
| datamodule.subdataset | For the NASA dataset only. Selects between the charge and discharge subsets. | charge, discharge |
| datamodule.smoother_width | For the ECG dataset only. Width of the Gaussian kernel smoother applied in the frequency domain. | $\mathbb{R}^+$ |
| score_model | The backbone to use for the score model. | default, lstm |
| score_model.lr_max | Max learning rate of the score model reached by the cosine scheduling with warmup. | $\mathbb{R}^+$ |
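
As an illustration, assuming the hyperparameter names in the table map directly to hydra override keys, a frequency-domain model with an LSTM backbone on the NASA charge subset could be trained with a command such as:

```
python cmd/train.py fourier_transform=true datamodule=nasa datamodule.subdataset=charge score_model=lstm score_model.lr_max=1e-3
```

The learning-rate value `1e-3` is a placeholder for illustration rather than a recommended setting; combine only the overrides relevant to the model you want to reproduce.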

At the end of training, your model is stored in the `lightning_logs` directory, in a folder named after the current `run_id`. You can find the `run_id` in the training logs and in the [wandb dashboard](https://wandb.ai/) if you have configured wandb correctly.
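
As a rough sketch (the exact contents are an assumption and depend on your Lightning and wandb configuration), the resulting directory layout looks something like:

```
lightning_logs/
└── <run_id>/           # folder named after the current run_id
    ├── checkpoints/    # assumed default Lightning checkpoint folder with the trained score model
    └── ...             # logs and configuration files
```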

In order to sample from a trained model, you can simply run the following command:

```
python cmd/sample.py model_id=XYZ
```

where `XYZ` is the `run_id` of the model you want to sample from. At the end of sampling, the samples are stored in the `lightning_logs` directory, in a folder named after the current `run_id`.

You can then reproduce the plots in the paper by adding the `run_id` to the `run_list` list in [this notebook](notebooks/results.ipynb) and running all cells; a sketch of this edit is given below.
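
As a sketch, assuming `run_list` is a plain Python list of run-id strings in that notebook, the edit looks like this (the ids below are placeholders for your own runs):

```
# notebooks/results.ipynb: add your own run ids to the existing run_list.
run_list = [
    "abc123",  # placeholder: run_id of a trained time-domain model
    "def456",  # placeholder: run_id of a trained frequency-domain model
]
```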

# 3. Contribute
