generative/networks/nets/diffusion_model_unet.py (18 changes: 13 additions & 5 deletions)
@@ -468,7 +468,8 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period
         embedding_dim: the dimension of the output.
         max_period: controls the minimum frequency of the embeddings.
     """
-    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+    if timesteps.ndim != 1:
+        raise ValueError("Timesteps should be a 1d-array")
 
     half_dim = embedding_dim // 2
     exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
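The replaced check guards the sinusoidal timestep embedding used for the model's time conditioning. For context, here is a minimal self-contained sketch of that pattern; the helper name, the frequency normalization, and the sin/cos concatenation order are assumptions, since the rest of the function body is outside this hunk.

```python
import math

import torch


def sinusoidal_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    # Sketch only: raise a ValueError instead of using `assert`, which is
    # stripped when Python runs with the -O flag.
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/max_period.
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(exponent / half_dim)
    args = timesteps[:, None].float() * freqs[None, :]
    # One sine and one cosine channel per frequency -> (len(timesteps), 2 * half_dim).
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


emb = sinusoidal_timestep_embedding(torch.randint(0, 1000, (4,)).long(), embedding_dim=64)
print(emb.shape)  # torch.Size([4, 64])
```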
@@ -491,7 +492,8 @@ class Downsample(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions.
         num_channels: number of input channels.
-        use_conv: if True uses Convolution instead of Pool average to perform downsampling.
+        use_conv: if True uses Convolution instead of average pooling to perform downsampling. If use_conv is False,
+            the number of output channels must be the same as the number of input channels.
         out_channels: number of output channels.
         padding: controls the amount of implicit zero-paddings on both sides for padding number of points
             for each dimension.
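The added docstring sentence follows from how average pooling works: it halves the spatial size but cannot change the channel count, so out_channels must equal num_channels in that branch. A small plain-PyTorch illustration (not the MONAI Downsample itself):

```python
import torch
import torch.nn as nn

# Average pooling halves the spatial size but cannot change the channel count,
# which is why out_channels must equal num_channels when use_conv=False.
pool = nn.AvgPool2d(kernel_size=2, stride=2)
x = torch.rand(1, 8, 16, 16)
print(pool(x).shape)  # torch.Size([1, 8, 8, 8])
```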
@@ -515,12 +517,17 @@ def __init__(
                 conv_only=True,
             )
         else:
-            assert self.num_channels == self.out_channels
+            if self.num_channels != self.out_channels:
+                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
             self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
 
     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
         del emb
-        assert x.shape[1] == self.num_channels
+        if x.shape[1] != self.num_channels:
+            raise ValueError(
+                f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
+                f"({self.num_channels})"
+            )
         return self.op(x)
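Taken together, the two replaced asserts become explicit ValueErrors: one at construction time (pooling cannot change the channel count) and one at call time (the input must carry the expected number of channels). Below is a 2-D-only sketch of the same checks; it uses plain nn.Conv2d and nn.AvgPool2d instead of MONAI's Convolution and Pool factories, and the class name is invented for illustration.

```python
from __future__ import annotations

import torch
import torch.nn as nn


class SimpleDownsample(nn.Module):
    """Sketch only: 2-D downsampling with the new ValueError checks."""

    def __init__(self, num_channels: int, out_channels: int | None = None, use_conv: bool = False) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        if use_conv:
            self.op = nn.Conv2d(self.num_channels, self.out_channels, kernel_size=3, stride=2, padding=1)
        else:
            # Average pooling cannot change the channel count.
            if self.num_channels != self.out_channels:
                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
            self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[1] != self.num_channels:
            raise ValueError(
                f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
                f"({self.num_channels})"
            )
        return self.op(x)


down = SimpleDownsample(num_channels=8)
print(down(torch.rand(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 8, 8])
```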


@@ -559,7 +566,8 @@ def __init__(
 
     def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
         del emb
-        assert x.shape[1] == self.num_channels
+        if x.shape[1] != self.num_channels:
+            raise ValueError("Input channels should be equal to num_channels")
 
         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
         # https://github.com/pytorch/pytorch/issues/86679
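The two context lines above keep the existing workaround for nearest-neighbour upsampling under bfloat16. A sketch of that cast-and-restore pattern as a standalone helper (the function name and signature are invented for illustration):

```python
import torch
import torch.nn.functional as F


def upsample_nearest_bf16_safe(x: torch.Tensor, scale_factor: float = 2.0) -> torch.Tensor:
    # Sketch only: nearest-neighbour upsampling does not support bfloat16 on
    # some backends (https://github.com/pytorch/pytorch/issues/86679), so
    # compute in float32 and cast back to the original dtype afterwards.
    dtype = x.dtype
    if dtype == torch.bfloat16:
        x = x.to(torch.float32)
    x = F.interpolate(x, scale_factor=scale_factor, mode="nearest")
    if dtype == torch.bfloat16:
        x = x.to(dtype)
    return x


print(upsample_nearest_bf16_safe(torch.rand(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 32, 32])
```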
tests/test_diffusion_model_unet.py (14 changes: 14 additions & 0 deletions)
@@ -240,6 +240,20 @@ def test_shape_unconditioned_models(self, input_param):
         result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())
         self.assertEqual(result.shape, (1, 1, 16, 16))
 
+    def test_timestep_with_wrong_shape(self):
+        net = DiffusionModelUNet(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_res_blocks=1,
+            num_channels=(8, 8, 8),
+            attention_levels=(False, False, False),
+            norm_num_groups=8,
+        )
+        with self.assertRaises(ValueError):
+            with eval_mode(net):
+                net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())
+
     def test_shape_with_different_in_channel_out_channel(self):
         in_channels = 6
         out_channels = 3
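The new test drives the ValueError path by passing a 2-D timestep tensor. The shape contract it exercises is purely about tensor rank and can be seen with plain tensors, independent of the network:

```python
import torch

# Timesteps must be a 1-D tensor with one entry per batch element.
good = torch.randint(0, 1000, (1,)).long()   # ndim == 1 -> accepted
bad = torch.randint(0, 1000, (1, 1)).long()  # ndim == 2 -> ValueError in the embedding function
print(good.ndim, bad.ndim)  # 1 2
```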