diff --git a/__init__.py b/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/fdiff/dataloaders/datamodules.py b/src/fdiff/dataloaders/datamodules.py
index 6ee53e0..b197e0a 100644
--- a/src/fdiff/dataloaders/datamodules.py
+++ b/src/fdiff/dataloaders/datamodules.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import Any, Optional
 
+import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
 import torch
@@ -12,8 +13,6 @@
 from fdiff.utils.dataclasses import collate_batch
 from fdiff.utils.fourier import dft
 
-import numpy as np
-
 
 class DiffusionDataset(Dataset):
     def __init__(
@@ -24,17 +23,25 @@ def __init__(
         standardize: bool = False,
         X_ref: Optional[torch.Tensor] = None,
     ) -> None:
+        """Dataset for diffusion models.
+
+        Args:
+            X (torch.Tensor): Time series that are fed to the model.
+            y (Optional[torch.Tensor], optional): Potential labels. Defaults to None.
+            fourier_transform (bool, optional): Performs a Fourier transform on the time series. Defaults to False.
+            standardize (bool, optional): Standardize each feature in the dataset. Defaults to False.
+            X_ref (Optional[torch.Tensor], optional): Features used to compute the mean and std. Defaults to None.
+        """
         super().__init__()
         if fourier_transform:
             X = dft(X).detach()
         self.X = X
         self.y = y
         self.standardize = standardize
-        self.feature_mean = torch.empty(size=(self.X.size(1), self.X.size(2)))
-        self.feature_std = torch.empty(size=(self.X.size(1), self.X.size(2)))
         if X_ref is None:
             X_ref = X
-        self.compute_feature_statistics(X_ref)
+        self.feature_mean = X_ref.mean(dim=0)
+        self.feature_std = X_ref.std(dim=0)
 
     def __len__(self) -> int:
         return len(self.X)
@@ -48,11 +55,6 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
             data["y"] = self.y[index]
         return data
 
-    def compute_feature_statistics(self, X_ref: torch.Tensor) -> None:
-        """Compute the mean and standard deviation of the features."""
-        self.feature_mean = X_ref.mean(dim=0)
-        self.feature_std = X_ref.std(dim=0)
-
 
 class Datamodule(pl.LightningDataModule, ABC):
     def __init__(
diff --git a/test.ipynb b/test.ipynb
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py
index 815d0c7..7344f67 100644
--- a/tests/test_datamodules.py
+++ b/tests/test_datamodules.py
@@ -45,16 +45,10 @@ def setup(self, stage: str = "fit") -> None:
         )
         self.X_test = torch.randn_like(self.X_train)
         self.y_test = torch.randint_like(self.y_train, low=low, high=high)
-        self.compute_feature_statistics()
 
     def download_data(self) -> None:
         ...
 
-    def compute_feature_statistics(self) -> None:
-        """Compute the mean and standard deviation of the features, along the batch dimension."""
-        self.feature_mean = self.X_train.mean(dim=0)
-        self.feature_std = self.X_train.std(dim=0)
-
     @property
     def dataset_name(self) -> str:
         return "dummy"
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 80628fb..6ca0d4e 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -1,15 +1,15 @@
 from copy import deepcopy
-from pathlib import Path
 
 import pytorch_lightning as pl
 import torch
 
-from fdiff.dataloaders.datamodules import Datamodule
 from fdiff.models.score_models import ScoreModule
 from fdiff.sampling.sampler import DiffusionSampler
 from fdiff.schedulers.vpsde_scheduler import VPScheduler
 from fdiff.utils.dataclasses import DiffusableBatch
 
+from .test_datamodules import DummyDatamodule
+
 n_head = 4
 d_model = 8
 n_channels = 3
@@ -119,55 +119,6 @@ def test_score_module_with_vpsde():
     # Check the shape of the samples
     assert samples.shape == (num_samples, max_len, n_channels)
 
-
-class DummyDatamodule(Datamodule):
-    def __init__(
-        self,
-        data_dir: Path = Path.cwd() / "data",
-        random_seed: int = 42,
-        batch_size: int = batch_size,
-        max_len: int = max_len,
-        n_channels: int = n_channels,
-        fourier_transform: bool = False,
-        standardize: bool = False,
-    ) -> None:
-        super().__init__(
-            data_dir=data_dir,
-            random_seed=random_seed,
-            batch_size=batch_size,
-            fourier_transform=fourier_transform,
-        )
-        self.max_len = max_len
-        self.n_channels = n_channels
-        self.batch_size = batch_size
-
-    def setup(self, stage: str = "fit") -> None:
-        torch.manual_seed(self.random_seed)
-        self.X_train = torch.randn(
-            (10 * self.batch_size, self.max_len, self.n_channels),
-            dtype=torch.float32,
-        )
-        self.y_train = torch.randint(
-            low=low, high=high, size=(10 * self.batch_size,), dtype=torch.long
-        )
-        self.X_test = torch.randn_like(self.X_train)
-        self.y_test = torch.randint_like(self.y_train, low=low, high=high)
-        self.compute_feature_statistics()
-
-    def download_data(self) -> None:
-        ...
-
-    def compute_feature_statistics(self) -> None:
-        """Compute the mean and standard deviation
-        of the features, along the batch dimension."""
-        self.feature_mean = self.X_train.mean(dim=0)
-        self.feature_std = self.X_train.std(dim=0)
-
     @property
     def dataset_name(self) -> str:
         return "dummy"
-
-
-if __name__ == "__main__":
-    # test_noise_adder()
-    test_score_module_with_vpsde()
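For reviewers, here is a minimal usage sketch of the refactored `DiffusionDataset` (shapes are illustrative; the import path, constructor signature, and `feature_mean`/`feature_std` attributes follow the diff above, where the statistics are now computed directly in `__init__`):

```python
import torch

from fdiff.dataloaders.datamodules import DiffusionDataset

# Toy batch: 32 time series, each with 50 time steps and 3 channels.
X = torch.randn(32, 50, 3)

# With X_ref left as None, the mean/std are taken over the batch dimension of X itself.
dataset = DiffusionDataset(X=X, standardize=True)

print(len(dataset))                # 32
print(dataset.feature_mean.shape)  # torch.Size([50, 3])
print(dataset.feature_std.shape)   # torch.Size([50, 3])
```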