Skip to content

Commit

Permalink
Compute feature mean and std by default
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanCrabbe committed Dec 19, 2023
1 parent 19b8471 commit 0c66073
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 67 deletions.
Empty file removed __init__.py
Empty file.
22 changes: 12 additions & 10 deletions src/fdiff/dataloaders/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
Expand All @@ -12,8 +13,6 @@
from fdiff.utils.dataclasses import collate_batch
from fdiff.utils.fourier import dft

import numpy as np


class DiffusionDataset(Dataset):
def __init__(
Expand All @@ -24,17 +23,25 @@ def __init__(
standardize: bool = False,
X_ref: Optional[torch.Tensor] = None,
) -> None:
"""Dataset for diffusion models.
Args:
X (torch.Tensor): Time series that are fed to the model.
y (Optional[torch.Tensor], optional): Potential labels. Defaults to None.
fourier_transform (bool, optional): Performs a Fourier transform on the time series. Defaults to False.
standardize (bool, optional): Standardize each feature in the dataset. Defaults to False.
X_ref (Optional[torch.Tensor], optional): Features used to compute the mean and std. Defaults to None.
"""
super().__init__()
if fourier_transform:
X = dft(X).detach()
self.X = X
self.y = y
self.standardize = standardize
self.feature_mean = torch.empty(size=(self.X.size(1), self.X.size(2)))
self.feature_std = torch.empty(size=(self.X.size(1), self.X.size(2)))
if X_ref is None:
X_ref = X
self.compute_feature_statistics(X_ref)
self.feature_mean = X_ref.mean(dim=0)
self.feature_std = X_ref.std(dim=0)

def __len__(self) -> int:
return len(self.X)
Expand All @@ -48,11 +55,6 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
data["y"] = self.y[index]
return data

def compute_feature_statistics(self, X_ref: torch.Tensor) -> None:
"""Compute the mean and standard deviation of the features."""
self.feature_mean = X_ref.mean(dim=0)
self.feature_std = X_ref.std(dim=0)


class Datamodule(pl.LightningDataModule, ABC):
def __init__(
Expand Down
Empty file removed test.ipynb
Empty file.
6 changes: 0 additions & 6 deletions tests/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,10 @@ def setup(self, stage: str = "fit") -> None:
)
self.X_test = torch.randn_like(self.X_train)
self.y_test = torch.randint_like(self.y_train, low=low, high=high)
self.compute_feature_statistics()

def download_data(self) -> None:
...

def compute_feature_statistics(self) -> None:
"""Compute the mean and standard deviation of the features, along the batch dimension."""
self.feature_mean = self.X_train.mean(dim=0)
self.feature_std = self.X_train.std(dim=0)

@property
def dataset_name(self) -> str:
return "dummy"
Expand Down
53 changes: 2 additions & 51 deletions tests/test_schedulers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from copy import deepcopy
from pathlib import Path

import pytorch_lightning as pl
import torch

from fdiff.dataloaders.datamodules import Datamodule
from fdiff.models.score_models import ScoreModule
from fdiff.sampling.sampler import DiffusionSampler
from fdiff.schedulers.vpsde_scheduler import VPScheduler
from fdiff.utils.dataclasses import DiffusableBatch

from .test_datamodules import DummyDatamodule

n_head = 4
d_model = 8
n_channels = 3
Expand Down Expand Up @@ -119,55 +119,6 @@ def test_score_module_with_vpsde():
# Check the shape of the samples
assert samples.shape == (num_samples, max_len, n_channels)


class DummyDatamodule(Datamodule):
def __init__(
self,
data_dir: Path = Path.cwd() / "data",
random_seed: int = 42,
batch_size: int = batch_size,
max_len: int = max_len,
n_channels: int = n_channels,
fourier_transform: bool = False,
standardize: bool = False,
) -> None:
super().__init__(
data_dir=data_dir,
random_seed=random_seed,
batch_size=batch_size,
fourier_transform=fourier_transform,
)
self.max_len = max_len
self.n_channels = n_channels
self.batch_size = batch_size

def setup(self, stage: str = "fit") -> None:
torch.manual_seed(self.random_seed)
self.X_train = torch.randn(
(10 * self.batch_size, self.max_len, self.n_channels),
dtype=torch.float32,
)
self.y_train = torch.randint(
low=low, high=high, size=(10 * self.batch_size,), dtype=torch.long
)
self.X_test = torch.randn_like(self.X_train)
self.y_test = torch.randint_like(self.y_train, low=low, high=high)
self.compute_feature_statistics()

def download_data(self) -> None:
...

def compute_feature_statistics(self) -> None:
"""Compute the mean and standard deviation
of the features, along the batch dimension."""
self.feature_mean = self.X_train.mean(dim=0)
self.feature_std = self.X_train.std(dim=0)

@property
def dataset_name(self) -> str:
return "dummy"


if __name__ == "__main__":
# test_noise_adder()
test_score_module_with_vpsde()

0 comments on commit 0c66073

Please sign in to comment.