Commit

cleaned typing
nicolashuynh committed Dec 18, 2023
1 parent 7ed464d commit 0232c64
Showing 6 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/fdiff/dataloaders/datamodules.py
@@ -30,8 +30,8 @@ def __init__(
         self.X = X
         self.y = y
         self.standardize = standardize
-        self.feature_mean = None
-        self.feature_std = None
+        self.feature_mean = torch.empty(size=(self.X.size(1), self.X.size(2)))
+        self.feature_std = torch.empty(size=(self.X.size(1), self.X.size(2)))
         if X_ref is None:
             X_ref = X
         self.compute_feature_statistics(X_ref)
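Replacing the None initializers with pre-allocated tensors keeps feature_mean and feature_std typed as torch.Tensor for the object's whole lifetime, so mypy no longer has to narrow an Optional before they are used. A minimal sketch of the pattern, with a placeholder class name, assuming X is shaped (batch, time, channels) and that compute_feature_statistics (not shown in this diff) fills the buffers with per-feature statistics:

    import torch

    class FeatureDataset:  # placeholder name; the diff shows only __init__
        def __init__(
            self,
            X: torch.Tensor,
            y: torch.Tensor | None = None,
            standardize: bool = True,
            X_ref: torch.Tensor | None = None,
        ) -> None:
            self.X = X
            self.y = y
            self.standardize = standardize
            # Pre-allocating keeps the attributes as torch.Tensor rather than
            # Optional[torch.Tensor], which is what the typing cleanup targets.
            self.feature_mean = torch.empty(size=(X.size(1), X.size(2)))
            self.feature_std = torch.empty(size=(X.size(1), X.size(2)))
            if X_ref is None:
                X_ref = X
            self.compute_feature_statistics(X_ref)

        def compute_feature_statistics(self, X_ref: torch.Tensor) -> None:
            # Assumed behaviour: per-feature statistics over the batch dimension.
            self.feature_mean = X_ref.mean(dim=0)
            self.feature_std = X_ref.std(dim=0)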
7 changes: 4 additions & 3 deletions src/fdiff/models/score_models.py
@@ -99,7 +99,7 @@ def forward(self, batch: DiffusableBatch) -> torch.Tensor:
 
     def training_step(
         self, batch: DiffusableBatch, batch_idx: int, dataloader_idx: int = 0
-    ) -> torch.Tensor:
+    ):
         loss = self.training_loss_fn(self, batch)
 
         self.log_dict(
@@ -135,8 +135,9 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
 
     def set_loss_fn(self) -> tuple[Callable, Callable]:
         # depending on the scheduler, get the right loss function
+        scheduler_config = self.noise_scheduler.config  # type: ignore
         if isinstance(self.noise_scheduler, DDPMScheduler):
-            self.max_time = self.noise_scheduler.config.num_train_timesteps
+            self.max_time = scheduler_config.num_train_timesteps
 
             training_loss_fn = get_ddpm_loss(
                 scheduler=self.noise_scheduler, train=True, max_time=self.max_time
@@ -167,7 +168,7 @@ def set_loss_fn(self) -> tuple[Callable, Callable]:
                 "Scheduler not implemented yet, cannot set loss function"
             )
 
-    def set_time_encoder(self) -> nn.Module:
+    def set_time_encoder(self):
         if isinstance(self.noise_scheduler, DDPMScheduler):
             return TimeEncoding(d_model=self.d_model, max_time=self.max_time)
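Binding self.noise_scheduler.config to a local variable with a single # type: ignore localizes the suppression: diffusers assembles a scheduler's .config dynamically through its ConfigMixin, so static checkers cannot see attributes such as num_train_timesteps. A small illustration of the pattern:

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    # .config is built at runtime, so mypy cannot verify its attributes;
    # one ignore on the binding covers every later access through the local.
    scheduler_config = scheduler.config  # type: ignore
    print(scheduler_config.num_train_timesteps)  # 1000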
5 changes: 3 additions & 2 deletions src/fdiff/models/transformer.py
@@ -68,7 +68,6 @@ def __init__(self, d_model: int, scale=30.0):
             torch.randn((d_model + 1) // 2) * scale, requires_grad=False
         )
         self.dense = nn.Linear(d_model, d_model)
-        print(self.d_model)
 
     def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
         time_proj = timesteps[:, None] * self.W[None, :] * 2 * np.pi
@@ -78,4 +77,6 @@ def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
 
         t_emb = t_emb.unsqueeze(1)
 
-        return x + self.dense(t_emb)
+        projected_emb: torch.Tensor = self.dense(t_emb)
+
+        return x + projected_emb
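For context, the edited class looks like a Gaussian Fourier feature time embedding: fixed random frequencies project scalar diffusion times into sin/cos features, which a linear layer maps back to d_model before being added to the input. A self-contained sketch under that assumption (only the fragments visible in this diff come from the repo; the rest, including the slice for odd d_model, is guessed):

    import numpy as np
    import torch
    import torch.nn as nn

    class GaussianFourierProjection(nn.Module):
        def __init__(self, d_model: int, scale: float = 30.0) -> None:
            super().__init__()
            self.d_model = d_model
            # Fixed random frequencies; requires_grad=False keeps them untrained.
            self.W = nn.Parameter(
                torch.randn((d_model + 1) // 2) * scale, requires_grad=False
            )
            self.dense = nn.Linear(d_model, d_model)

        def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
            # Project times onto the random frequencies, embed as sin/cos pairs.
            time_proj = timesteps[:, None] * self.W[None, :] * 2 * np.pi
            t_emb = torch.cat([torch.sin(time_proj), torch.cos(time_proj)], dim=-1)
            t_emb = t_emb[:, : self.d_model]  # trim extra column when d_model is odd
            t_emb = t_emb.unsqueeze(1)  # broadcast over the sequence dimension
            projected_emb: torch.Tensor = self.dense(t_emb)
            return x + projected_emb

The explicit projected_emb annotation replaces the inlined return so mypy sees a torch.Tensor rather than the Any that an untyped nn.Module.__call__ can propagate.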
2 changes: 2 additions & 0 deletions src/fdiff/schedulers/custom_ddpm_scheduler.py
@@ -1,3 +1,5 @@
+# mypy: ignore-errors
+
 # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
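The added # mypy: ignore-errors line is mypy's inline per-module configuration: placed in the file's first comment block, it suppresses every error mypy would otherwise report for that module, which suits a vendored file like this scheduler (the copyright header marks it as adapted from HuggingFace code). For example:

    # mypy: ignore-errors
    # With the directive above at the top of the module, mypy stays silent
    # about the whole file, even for lines it would normally flag:
    x: int = "not an int"  # otherwise: incompatible types in assignment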
4 changes: 1 addition & 3 deletions src/fdiff/schedulers/vpsde_scheduler.py
@@ -71,9 +71,7 @@ def add_noise(
         sample = mean + noise
         return sample
 
-    def get_beta(
-        self, timestep: torch.Tensor | float | np.ndarray
-    ) -> torch.Tensor | float | np.ndarray:
+    def get_beta(self, timestep: torch.Tensor | float | np.ndarray) -> torch.Tensor:
         return torch.tensor(
             self.beta_0 + timestep * (self.beta_1 - self.beta_0), device=self.device
         )
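The narrowed signature matches the body: torch.tensor(...) returns a Tensor whatever the numeric type of timestep, so the old torch.Tensor | float | np.ndarray return annotation was never accurate. The schedule itself is the linear variance-preserving one, beta(t) = beta_0 + t * (beta_1 - beta_0) for t in [0, 1]. A quick check with commonly used endpoint values (assumed here, not read from the repo):

    import torch

    beta_0, beta_1 = 0.1, 20.0  # typical VP-SDE endpoints; an assumption

    def get_beta(timestep: float) -> torch.Tensor:
        # Linear interpolation between beta_0 and beta_1 over t in [0, 1].
        return torch.tensor(beta_0 + timestep * (beta_1 - beta_0))

    print(get_beta(0.0))  # tensor(0.1000)
    print(get_beta(0.5))  # tensor(10.0500)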
4 changes: 2 additions & 2 deletions src/fdiff/utils/losses.py
@@ -33,7 +33,7 @@ def get_sde_loss_fn(
         else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
     )
 
-    def loss_fn(model, batch):
+    def loss_fn(model, batch) -> torch.Tensor:
         """Compute the loss function.
         Args:
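The reduce_op selected in the context lines above follows the score-SDE convention: torch.mean for the standard weighting, or 0.5 * torch.sum when likelihood weighting is on (the flag name below is assumed from that codebase; the diff shows only the else-branch). A small demonstration of the sum branch:

    import torch

    reduce_mean = False  # assumed flag name
    reduce_op = (
        torch.mean
        if reduce_mean
        else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
    )

    losses = torch.tensor([[0.5, 1.5], [1.0, 2.0]])
    per_example = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    print(per_example)  # tensor([1.0000, 1.5000])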
@@ -118,7 +118,7 @@ def loss_fn(model, batch):
 
 
 def get_ddpm_loss(scheduler, train, max_time):
-    def loss_fn(model, batch):
+    def loss_fn(model, batch) -> torch.Tensor:
         if train:
             model.train()
         else:
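The body of loss_fn beyond the train/eval toggle falls outside this diff; for context, a DDPM training loss of the shape the new -> torch.Tensor annotation describes typically samples a timestep, corrupts the clean signal through the scheduler's forward process, and regresses the model onto the injected noise. A sketch with assumed names (the repo's model consumes a DiffusableBatch rather than the bare tensors used here):

    import torch
    import torch.nn.functional as F

    def ddpm_loss(model, x0: torch.Tensor, scheduler, max_time: int) -> torch.Tensor:
        t = torch.randint(0, max_time, (x0.size(0),), device=x0.device)
        noise = torch.randn_like(x0)
        x_t = scheduler.add_noise(x0, noise, t)  # forward diffusion step
        pred = model(x_t, t)                     # model predicts the noise
        return F.mse_loss(pred, noise)           # epsilon-matching objective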
