Skip to content

Commit

Permalink
update to diffusion to work with more than one channel
Browse files Browse the repository at this point in the history
  • Loading branch information
josegcpa committed May 10, 2024
1 parent e4b65ea commit 148d99b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
3 changes: 2 additions & 1 deletion adell_mri/entrypoints/generative/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def main(arguments):
)
network_config["with_conditioning"] = with_conditioning
network_config["cross_attention_dim"] = 256 if with_conditioning else None
network_config["in_channels"] = len(keys)

all_pids = [k for k in data_dict]

Expand Down Expand Up @@ -320,7 +321,7 @@ def train_loader_call():
size = return_first_not_none(args.pad_size, args.crop_size)
callbacks.append(
LogImageFromDiffusionProcess(
n_images=2,
n_images=1,
size=[int(x) for x in size][: network_config["spatial_dims"]],
)
)
Expand Down
21 changes: 18 additions & 3 deletions adell_mri/utils/pl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def on_validation_epoch_end(
) -> None:
ep = pl_module.current_epoch
if ep % self.every_n_epochs == 0 and ep > 0:
images = pl_module.generate_image(size=self.size, n=self.n_images)
with torch.no_grad():
images = pl_module.generate_image(
size=self.size, n=self.n_images
)
log_image(
trainer,
key="Generated images",
Expand Down Expand Up @@ -351,6 +354,16 @@ def coerce_to_uint8(x: np.ndarray):
return x.astype(np.uint8)


def split_and_cat(x: np.ndarray, split_dim: int, cat_dim: int) -> np.ndarray:
print(x.shape)
arrays = np.split(x, x.shape[split_dim], axis=split_dim)
arrays = np.concatenate(
[arr.squeeze(split_dim) for arr in arrays], cat_dim
)
print(arrays.shape)
return arrays


def log_image(
trainer: Trainer,
key: str,
Expand All @@ -377,15 +390,17 @@ def log_image(
images = images.detach().to("cpu")
if len(images.shape) == 5:
n_slices = images.shape[slice_dim]
slice_idxs = np.arange(0, n_slices, n_slices_out + 2)[1:-1]
slice_idxs = np.linspace(
0, n_slices, num=n_slices_out + 2, dtype=np.int32
)[1:-1]
images = torch.index_select(
images, slice_dim, torch.as_tensor(slice_idxs)
)
images = torch.split(images, 1, dim=slice_dim)
images = torch.cat(images, -2).squeeze(-1)
images = torch.split(images, 1, 0)
images = [x.squeeze(0).permute(1, 2, 0).numpy() for x in images]
images = [coerce_to_uint8(x).squeeze(-1) for x in images]
images = [coerce_to_uint8(split_and_cat(x, -1, 0)) for x in images]
images = [Image.fromarray(x) for x in images]

if caption is not None:
Expand Down

0 comments on commit 148d99b

Please sign in to comment.