Skip to content

Commit

Permalink
Merge pull request #2 from JonathanCrabbe/nico
Browse files Browse the repository at this point in the history
Alternative SDE
  • Loading branch information
JonathanCrabbe committed Dec 27, 2023
2 parents 252e2a0 + 6a53ea5 commit c91794a
Show file tree
Hide file tree
Showing 25 changed files with 2,094 additions and 89 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ conda activate fdiff
```shell
pip install -e .
```

4. If you intend to train models, make sure that wandb is correctly configured on your machine by following [this guide](https://docs.wandb.ai/quickstart).
5. Some of the datasets are automatically downloaded by our scripts via the Kaggle API. Make sure to create a Kaggle API token as explained [here](https://towardsdatascience.com/downloading-datasets-from-kaggle-for-your-ml-project-b9120d405ea4).

Expand Down
1 change: 1 addition & 0 deletions cmd/conf/datamodule/ecg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ _target_: fdiff.dataloaders.datamodules.ECGDatamodule
data_dir: ${hydra:runtime.cwd}/data
random_seed: ${random_seed}
fourier_transform: ${fourier_transform}
standardize: ${standardize}
batch_size: 64
10 changes: 10 additions & 0 deletions cmd/conf/datamodule/synthetic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
_target_: fdiff.dataloaders.datamodules.SyntheticDatamodule
data_dir: ${hydra:runtime.cwd}/data
random_seed: ${random_seed}
fourier_transform: ${fourier_transform}
standardize: ${standardize}
batch_size: 64
max_len: 100
num_samples: 1000


4 changes: 3 additions & 1 deletion cmd/conf/score_model/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ d_model: 72
num_layers: 10
n_head: 12
lr_max: 1.0e-3
fourier_noise_scaling: False
likelihood_weighting: False

defaults:
- noise_scheduler: ddpm
- noise_scheduler: vpsde
2 changes: 2 additions & 0 deletions cmd/conf/score_model/noise_scheduler/customddpm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: fdiff.schedulers.ddpm.CustomDDPMScheduler
num_train_timesteps: 1000
5 changes: 5 additions & 0 deletions cmd/conf/score_model/noise_scheduler/vesde.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: fdiff.schedulers.sde.VEScheduler
eps: 1e-5
sigma_min: 0.01
sigma_max: 2
fourier_noise_scaling: ${score_model.fourier_noise_scaling}
5 changes: 5 additions & 0 deletions cmd/conf/score_model/noise_scheduler/vpsde.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: fdiff.schedulers.sde.VPScheduler
eps: 1e-5
beta_min: 0.1
beta_max: 20
fourier_noise_scaling: ${score_model.fourier_noise_scaling}
2 changes: 2 additions & 0 deletions cmd/conf/train.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
random_seed: 42
fourier_transform: false
standardize: true

defaults:
- _self_
- score_model: default
Expand Down
7 changes: 6 additions & 1 deletion cmd/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def __init__(self, cfg: DictConfig) -> None:
self.fourier_transform: bool = self.datamodule.fourier_transform
self.datamodule.prepare_data()
self.datamodule.setup()

# Get number of steps and samples
self.num_samples: int = cfg.num_samples
self.num_diffusion_steps: int = cfg.num_diffusion_steps
Expand All @@ -67,10 +66,16 @@ def __init__(self, cfg: DictConfig) -> None:

def sample(self) -> None:
# Sample from score model

X = self.sampler.sample(
num_samples=self.num_samples, num_diffusion_steps=self.num_diffusion_steps
)

# Map to the original scale if the input was standardized
if self.datamodule.standardize:
feature_mean, feature_std = self.datamodule.feature_mean_and_std
X = X * feature_std + feature_mean

# If sampling in frequency domain, bring back the sample to time domain
if self.fourier_transform:
X = idft(X)
Expand Down
3 changes: 3 additions & 0 deletions cmd/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, cfg: DictConfig) -> None:
self.score_model = self.score_model(**training_params)

def train(self) -> None:
    """Validate the configuration, then fit the score model on the datamodule.

    Raises:
        AssertionError: If the score model has noise scaling enabled while the
            datamodule does not apply the Fourier transform.
    """
    # Raised explicitly instead of through an ``assert`` statement so the check
    # is not stripped when Python runs with ``-O``. The exception type and the
    # message are kept so callers observe the same failure as before.
    if self.score_model.scale_noise and not self.datamodule.fourier_transform:
        raise AssertionError(
            "You cannot use noise scaling without the Fourier transform."
        )
    self.trainer.fit(model=self.score_model, datamodule=self.datamodule)


Expand Down
556 changes: 556 additions & 0 deletions notebooks/viz.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,6 @@ exclude = [
'^file1\.py$', # TOML literal string (single-quotes, no escaping necessary)
]
ignore_missing_imports = true

[tool.isort]
profile = "black"
107 changes: 105 additions & 2 deletions src/fdiff/dataloaders/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
Expand All @@ -19,19 +20,37 @@ def __init__(
X: torch.Tensor,
y: Optional[torch.Tensor] = None,
fourier_transform: bool = False,
standardize: bool = False,
X_ref: Optional[torch.Tensor] = None,
) -> None:
"""Dataset for diffusion models.
Args:
X (torch.Tensor): Time series that are fed to the model.
y (Optional[torch.Tensor], optional): Potential labels. Defaults to None.
fourier_transform (bool, optional): Performs a Fourier transform on the time series. Defaults to False.
standardize (bool, optional): Standardize each feature in the dataset. Defaults to False.
X_ref (Optional[torch.Tensor], optional): Features used to compute the mean and std. Defaults to None.
"""
super().__init__()
if fourier_transform:
X = dft(X).detach()
self.X = X
self.y = y
self.standardize = standardize
if X_ref is None:
X_ref = X
self.feature_mean = X_ref.mean(dim=0)
self.feature_std = X_ref.std(dim=0)

def __len__(self) -> int:
    """Return the number of entries in ``self.X`` (its length along the first axis)."""
    num_entries = len(self.X)
    return num_entries

def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
    """Return the sample stored at ``index`` as a dictionary.

    Args:
        index (int): Position of the sample in ``self.X`` (and ``self.y``).

    Returns:
        dict[str, torch.Tensor]: ``"X"`` holds ``self.X[index]``, shifted by
        ``self.feature_mean`` and divided by ``self.feature_std`` when
        ``self.standardize`` is set. ``"y"`` is only present when labels exist.
    """
    series = self.X[index]
    if self.standardize:
        series = (series - self.feature_mean) / self.feature_std

    sample = {"X": series}
    if self.y is not None:
        sample["y"] = self.y[index]
    return sample
Expand All @@ -44,6 +63,7 @@ def __init__(
random_seed: int = 42,
batch_size: int = 32,
fourier_transform: bool = False,
standardize: bool = False,
) -> None:
super().__init__()
# Cast data_dir to Path type
Expand All @@ -53,6 +73,7 @@ def __init__(
self.random_seed = random_seed
self.batch_size = batch_size
self.fourier_transform = fourier_transform
self.standardize = standardize
self.X_train = torch.Tensor()
self.y_train: Optional[torch.Tensor] = None
self.X_test = torch.Tensor()
Expand All @@ -72,7 +93,10 @@ def download_data(self) -> None:

def train_dataloader(self) -> DataLoader:
train_set = DiffusionDataset(
X=self.X_train, y=self.y_train, fourier_transform=self.fourier_transform
X=self.X_train,
y=self.y_train,
fourier_transform=self.fourier_transform,
standardize=self.standardize,
)
return DataLoader(
train_set,
Expand All @@ -94,7 +118,11 @@ def test_dataloader(self) -> DataLoader:

def val_dataloader(self) -> DataLoader:
test_set = DiffusionDataset(
X=self.X_test, y=self.y_test, fourier_transform=self.fourier_transform
X=self.X_test,
y=self.y_test,
fourier_transform=self.fourier_transform,
standardize=self.standardize,
X_ref=self.X_train,
)
return DataLoader(
test_set,
Expand All @@ -115,6 +143,16 @@ def dataset_parameters(self) -> dict[str, Any]:
"num_training_steps": len(self.train_dataloader()),
}

@property
def feature_mean_and_std(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Feature mean and standard deviation of a dataset built from the training data.

    A ``DiffusionDataset`` is instantiated from ``X_train`` / ``y_train`` with the
    datamodule's ``fourier_transform`` and ``standardize`` settings; its
    ``feature_mean`` and ``feature_std`` attributes are returned in that order.
    """
    dataset_kwargs = {
        "X": self.X_train,
        "y": self.y_train,
        "fourier_transform": self.fourier_transform,
        "standardize": self.standardize,
    }
    reference_set = DiffusionDataset(**dataset_kwargs)
    mean = reference_set.feature_mean
    std = reference_set.feature_std
    return mean, std


class ECGDatamodule(Datamodule):
def __init__(
Expand All @@ -123,12 +161,14 @@ def __init__(
random_seed: int = 42,
batch_size: int = 32,
fourier_transform: bool = False,
standardize: bool = False,
) -> None:
super().__init__(
data_dir=data_dir,
random_seed=random_seed,
batch_size=batch_size,
fourier_transform=fourier_transform,
standardize=standardize,
)

def setup(self, stage: str = "fit") -> None:
Expand Down Expand Up @@ -161,3 +201,66 @@ def download_data(self) -> None:
@property
def dataset_name(self) -> str:
    """Short string identifier of this dataset."""
    name = "ecg"
    return name


class SyntheticDatamodule(Datamodule):
    """Datamodule serving synthetic sine waves with random frequency and phase.

    ``download_data`` generates the series and writes them to ``train.csv`` and
    ``test.csv`` under ``self.data_dir``; ``setup`` reads these files back into
    tensors. The dataset has no labels.
    """

    def __init__(
        self,
        data_dir: Path | str = Path.cwd() / "data",
        random_seed: int = 42,
        batch_size: int = 32,
        fourier_transform: bool = False,
        standardize: bool = False,
        max_len: int = 100,
        num_samples: int = 1000,
    ) -> None:
        """Initialize the datamodule.

        Args:
            data_dir (Path | str, optional): Forwarded to ``Datamodule``.
            random_seed (int, optional): Forwarded to ``Datamodule``; also seeds
                the generator used by ``download_data``. Defaults to 42.
            batch_size (int, optional): Forwarded to ``Datamodule``.
            fourier_transform (bool, optional): Forwarded to ``Datamodule``.
            standardize (bool, optional): Forwarded to ``Datamodule``.
            max_len (int, optional): Number of time steps per generated series.
                Defaults to 100.
            num_samples (int, optional): Number of series generated for each of
                the train and test splits. Defaults to 1000.
        """
        super().__init__(
            data_dir=data_dir,
            random_seed=random_seed,
            batch_size=batch_size,
            fourier_transform=fourier_transform,
            standardize=standardize,
        )
        self.max_len = max_len
        self.num_samples = num_samples

    def setup(self, stage: str = "fit") -> None:
        """Load the generated CSV files into ``X_train`` and ``X_test``.

        Args:
            stage (str, optional): Not used by this implementation.
                Defaults to "fit".
        """
        # Each CSV row is one univariate series; the files have no header.
        X_train = pd.read_csv(self.data_dir / "train.csv", header=None).values
        X_test = pd.read_csv(self.data_dir / "test.csv", header=None).values

        # Append a channel dimension: (series, time) -> (series, time, 1).
        self.X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(2)
        self.y_train = None
        self.X_test = torch.tensor(X_test, dtype=torch.float32).unsqueeze(2)
        self.y_test = None

    def download_data(self) -> None:
        """Generate the synthetic series and save them as CSV files.

        Same data-generating process as in Fourier flows: each series is
        ``sin(t * frequency + phase)`` for ``t = 0, ..., max_len - 1``, with
        ``phase`` drawn from a standard normal and ``frequency`` from Beta(2, 2).
        """
        # A local generator seeded with ``random_seed`` makes the files
        # reproducible on their own and leaves the global NumPy RNG state
        # untouched, so a run that generates the data and a run that reuses
        # the saved files see the same global random stream afterwards.
        rng = np.random.default_rng(self.random_seed)

        n_generated = 2 * self.num_samples  # For train + test
        phase = rng.normal(size=(n_generated, 1))
        frequency = rng.beta(a=2, b=2, size=(n_generated, 1))
        timesteps = np.arange(self.max_len)
        # Broadcasting (n_generated, 1) with (max_len,) gives (n_generated, max_len).
        X = np.sin(timesteps * frequency + phase)
        X_train = X[: self.num_samples]
        X_test = X[self.num_samples :]

        # Make sure the target directory exists before writing.
        # NOTE(review): assumes ``self.data_dir`` is a ``pathlib.Path`` set by
        # the base class - confirm against ``Datamodule.__init__``.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(X_train).to_csv(
            self.data_dir / "train.csv", index=False, header=False
        )
        pd.DataFrame(X_test).to_csv(
            self.data_dir / "test.csv", index=False, header=False
        )

    @property
    def dataset_name(self) -> str:
        """Short string identifier of this dataset."""
        return "synthetic"
Loading

0 comments on commit c91794a

Please sign in to comment.