
Diffusion Autoencoder tutorial #361

Merged 3 commits into main on Apr 6, 2023
Conversation

SANCHES-Pedro (Contributor)
Create a Diffusion Autoencoder tutorial using the components that we already have for image manipulation.

@marksgraham marksgraham self-assigned this Apr 5, 2023
@marksgraham (Collaborator) left a comment:
Hi @SANCHES-Pedro

Really nice tutorial, the results at the end are very cool! I've left a few comments, and could you also run `./runtests.sh --autofix` to fix some formatting issues?

Feel free to merge once the changes are made.

# ## Setup imports

# %% jupyter={"outputs_hidden": false}
import os
marksgraham (Collaborator):
unused import

Comment on lines 64 to 65
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
marksgraham (Collaborator):
unused imports

# %% jupyter={"outputs_hidden": false}
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
root_dir = '/home/s2086085/pedro_idcom/experiment_data'
marksgraham (Collaborator):
please remove this line specific to your system

Comment on lines 117 to 118
transforms.Lambdad(keys=["image"], func=lambda x: x[channel, :, :, :]),
transforms.AddChanneld(keys=["image"]),
marksgraham (Collaborator):
Gives the warning `<class 'monai.transforms.utility.array.AddChannel'>: Class AddChannel has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.`
You could get rid of this by deleting AddChannel and adding the channel dimension during the Lambdad instead, like:
transforms.Lambdad(keys=["image"], func=lambda x: x[channel, None, :, :, :]),
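The suggested one-step indexing can be checked with plain NumPy: indexing with `None` (i.e. `np.newaxis`) re-inserts the channel axis that integer indexing drops, so the deprecated `AddChannel` transform becomes unnecessary. A minimal sketch (the array shape here is illustrative, not the tutorial's actual volume size):

```python
import numpy as np

# Stand-in 4D volume: (channels, depth, height, width).
x = np.zeros((4, 96, 96, 64))
channel = 0

# Old approach: integer indexing drops the channel axis...
selected = x[channel, :, :, :]        # shape (96, 96, 64)

# ...which is why AddChanneld was needed afterwards. The suggested form
# re-inserts the axis in the same step via `None`:
combined = x[channel, None, :, :, :]  # shape (1, 96, 96, 64)

print(selected.shape, combined.shape)
```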

# 1. `LoadImaged` loads the brain images from files.
# 2. `EnsureChannelFirstd` ensures the original data to construct "channel first" shape.
# 3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.
# 4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.
marksgraham (Collaborator):
hm what do you mean by the original paper here? The CVPR paper only uses computer vision data as far as I can see

SANCHES-Pedro (Contributor, Author):

Ohh, sorry, this was copied from another tutorial

# 6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` ).

# %%
channel = 0 # 0 = Flair
marksgraham (Collaborator):
Says T1-weighted in step 3. of the transforms list but FLAIR here

section="training",
cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise
num_workers=4,
download=False, # Set download to True if the dataset hasn't been downloaded yet
marksgraham (Collaborator):
we could set this to True so it just runs out of the box for users without having to change anything - I think it doesn't actually do the download if it detects the file is already there
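For reference, the suggested change amounts to enabling the download flag by default; MONAI's dataset classes verify any existing local copy and skip re-downloading it. A sketch of the construction with the flag flipped (the `task` value is an assumption based on similar MONAI brain-MRI tutorials, and MONAI must be installed; treat this as a config fragment, not the tutorial's exact cell):

```python
import os
import tempfile

from monai.apps import DecathlonDataset  # assumes MONAI is installed

root_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp()

train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",  # assumption: the Decathlon brain task
    section="training",
    cache_rate=1.0,  # you may need a few GB of RAM; set to 0 otherwise
    num_workers=4,
    download=True,  # skipped if a verified copy already exists in root_dir
)
```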

noise_pred = self.unet(x=xt, timesteps=t, context=latent.unsqueeze(2))
return noise_pred, latent

device = torch.device("cuda:2")
marksgraham (Collaborator):
i think it might be best for users if we default to device 0
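A slightly more defensive default than a bare device index would also let the notebook run on machines without CUDA at all. A minimal sketch of that pattern:

```python
import torch

# Default to the first visible GPU, falling back to CPU so the notebook
# still runs on machines without CUDA (training will just be slow there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
```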

# ## Training a diffusion model and semantic encoder

# %%
n_iterations = 1e4 # training for longer helps a lot with reconstruction quality, even if the loss is already low
marksgraham (Collaborator):
could you just give an estimate of training time in the comments? useful so users know how to change this variable if they just want 'the best results they can get in 10 mins of training' or something similar

marksgraham (Collaborator):
just noticed it says at the end of training, ~3 hours

SANCHES-Pedro (Contributor, Author):
Yes, it's not super quick...

# get latent space of training set
latents_train = []
classes_train = []
for i in range(15): # 15 slices from each volume
marksgraham (Collaborator):
I stared at this a while working out why `i` wasn't being used! Could we replace `i` with `_` to make clear we don't actually use the variable? It's just a way of running the transform 15 times per batch.
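The renaming suggestion in a minimal, self-contained form - `sample_latent` here is a stub standing in for the tutorial's real transform-and-encode step, which is random, so re-running it 15 times yields 15 different slices per volume:

```python
import random

def sample_latent():
    # Stub for "apply the (random) transforms and encode one slice";
    # in the tutorial this would call the semantic encoder on a batch.
    return [random.random() for _ in range(4)]

latents_train = []
# `_` makes it explicit that the loop variable is unused: the loop only
# re-runs the random pipeline a fixed number of times (15 in the tutorial).
for _ in range(15):
    latents_train.append(sample_latent())

print(len(latents_train))
```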

@SANCHES-Pedro SANCHES-Pedro merged commit 02de27c into main Apr 6, 2023
@Warvito Warvito deleted the 346-diffusion-autoencoder branch May 4, 2023 19:06