From 0232c640f43a80a65bead3d6c72c68db90b81c9b Mon Sep 17 00:00:00 2001 From: nicolas Date: Mon, 18 Dec 2023 16:23:48 +0000 Subject: [PATCH] cleaned typing --- src/fdiff/dataloaders/datamodules.py | 4 ++-- src/fdiff/models/score_models.py | 7 ++++--- src/fdiff/models/transformer.py | 5 +++-- src/fdiff/schedulers/custom_ddpm_scheduler.py | 2 ++ src/fdiff/schedulers/vpsde_scheduler.py | 4 +--- src/fdiff/utils/losses.py | 4 ++-- 6 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/fdiff/dataloaders/datamodules.py b/src/fdiff/dataloaders/datamodules.py index 4a591e5..6ee53e0 100644 --- a/src/fdiff/dataloaders/datamodules.py +++ b/src/fdiff/dataloaders/datamodules.py @@ -30,8 +30,8 @@ def __init__( self.X = X self.y = y self.standardize = standardize - self.feature_mean = None - self.feature_std = None + self.feature_mean = torch.empty(size=(self.X.size(1), self.X.size(2))) + self.feature_std = torch.empty(size=(self.X.size(1), self.X.size(2))) if X_ref is None: X_ref = X self.compute_feature_statistics(X_ref) diff --git a/src/fdiff/models/score_models.py b/src/fdiff/models/score_models.py index f7f2edf..fc27851 100644 --- a/src/fdiff/models/score_models.py +++ b/src/fdiff/models/score_models.py @@ -99,7 +99,7 @@ def forward(self, batch: DiffusableBatch) -> torch.Tensor: def training_step( self, batch: DiffusableBatch, batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: + ): loss = self.training_loss_fn(self, batch) self.log_dict( @@ -135,8 +135,9 @@ def configure_optimizers(self) -> OptimizerLRScheduler: def set_loss_fn(self) -> tuple[Callable, Callable]: # depending on the scheduler, get the right loss function + scheduler_config = self.noise_scheduler.config # type: ignore if isinstance(self.noise_scheduler, DDPMScheduler): - self.max_time = self.noise_scheduler.config.num_train_timesteps + self.max_time = scheduler_config.num_train_timesteps training_loss_fn = get_ddpm_loss( scheduler=self.noise_scheduler, train=True, max_time=self.max_time @@ -167,7 +168,7 @@ def set_loss_fn(self) -> tuple[Callable, Callable]: "Scheduler not implemented yet, cannot set loss function" ) - def set_time_encoder(self) -> nn.Module: + def set_time_encoder(self): if isinstance(self.noise_scheduler, DDPMScheduler): return TimeEncoding(d_model=self.d_model, max_time=self.max_time) diff --git a/src/fdiff/models/transformer.py b/src/fdiff/models/transformer.py index 65833db..9279db9 100644 --- a/src/fdiff/models/transformer.py +++ b/src/fdiff/models/transformer.py @@ -68,7 +68,6 @@ def __init__(self, d_model: int, scale=30.0): torch.randn((d_model + 1) // 2) * scale, requires_grad=False ) self.dense = nn.Linear(d_model, d_model) - print(self.d_model) def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: time_proj = timesteps[:, None] * self.W[None, :] * 2 * np.pi @@ -78,4 +77,6 @@ def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: t_emb = t_emb.unsqueeze(1) - return x + self.dense(t_emb) + projected_emb: torch.Tensor = self.dense(t_emb) + + return x + projected_emb diff --git a/src/fdiff/schedulers/custom_ddpm_scheduler.py b/src/fdiff/schedulers/custom_ddpm_scheduler.py index 6bafa24..f82f341 100644 --- a/src/fdiff/schedulers/custom_ddpm_scheduler.py +++ b/src/fdiff/schedulers/custom_ddpm_scheduler.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/fdiff/schedulers/vpsde_scheduler.py b/src/fdiff/schedulers/vpsde_scheduler.py index c532b4a..de2aeee 100644 --- a/src/fdiff/schedulers/vpsde_scheduler.py +++ b/src/fdiff/schedulers/vpsde_scheduler.py @@ -71,9 +71,7 @@ def add_noise( sample = mean + noise return sample - def get_beta( - self, timestep: torch.Tensor | float | np.ndarray - ) -> torch.Tensor | float | np.ndarray: + def get_beta(self, timestep: torch.Tensor | float | np.ndarray) -> torch.Tensor: return torch.tensor( self.beta_0 + timestep * (self.beta_1 - self.beta_0), device=self.device ) diff --git a/src/fdiff/utils/losses.py b/src/fdiff/utils/losses.py index 1a335ee..105dcf7 100644 --- a/src/fdiff/utils/losses.py +++ b/src/fdiff/utils/losses.py @@ -33,7 +33,7 @@ def get_sde_loss_fn( else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) ) - def loss_fn(model, batch): + def loss_fn(model, batch) -> torch.Tensor: """Compute the loss function. Args: @@ -118,7 +118,7 @@ def loss_fn(model, batch): def get_ddpm_loss(scheduler, train, max_time): - def loss_fn(model, batch): + def loss_fn(model, batch) -> torch.Tensor: if train: model.train() else: