Diffusion Autoencoder tutorial #361
Conversation
Really nice tutorial, the results at the end are very cool! I've left a few comments, and could you also run `./runtests.sh --autofix` to fix some formatting issues?
Feel free to merge once the changes are made.
# ## Setup imports

# %% jupyter={"outputs_hidden": false}
import os
unused import
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
unused imports
# %% jupyter={"outputs_hidden": false}
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
root_dir = '/home/s2086085/pedro_idcom/experiment_data'
please remove this line specific to your system
transforms.Lambdad(keys=["image"], func=lambda x: x[channel, :, :, :]),
transforms.AddChanneld(keys=["image"]),
Gives the warning: `<class 'monai.transforms.utility.array.AddChannel'>: Class AddChannel has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.`
You could get rid of this by deleting `AddChanneld` and adding the channel during the `Lambdad`, like:
transforms.Lambdad(keys=["image"], func=lambda x: x[channel, None, :, :, :]),
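As a sanity check on the suggested indexing (a minimal sketch using NumPy in place of a MetaTensor; the array shape is illustrative, not taken from the dataset):

```python
import numpy as np

# Simulated multi-channel 3D volume: (channels, depth, height, width)
x = np.zeros((4, 64, 64, 44))
channel = 0

# Old pipeline: selecting the channel drops the axis, so AddChanneld
# was needed afterwards to restore "channel first" shape
selected = x[channel, :, :, :]
print(selected.shape)  # (64, 64, 44)

# Suggested replacement: None inserts the channel axis in the same step
selected_with_channel = x[channel, None, :, :, :]
print(selected_with_channel.shape)  # (1, 64, 64, 44)
```

So the single `Lambdad` produces the same channel-first output as the old `Lambdad` + `AddChanneld` pair, without the deprecation warning.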
# 1. `LoadImaged` loads the brain images from files.
# 2. `EnsureChannelFirstd` ensures the original data to construct "channel first" shape.
# 3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.
# 4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.
hm what do you mean by the original paper here? The CVPR paper only uses computer vision data as far as I can see
Ohh, sorry, this was copied from another tutorial
# 6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2`).

# %%
channel = 0  # 0 = Flair
Says T1-weighted in step 3. of the transforms list but FLAIR here
section="training",
cache_rate=1.0,  # you may need a few Gb of RAM... Set to 0 otherwise
num_workers=4,
download=False,  # Set download to True if the dataset hasn't been downloaded yet
We could set this to True so it just runs out of the box for users without having to change anything - I think it doesn't actually do the download if it detects the file is already there.
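The behaviour the comment describes is download-if-missing. A minimal sketch of that pattern (`fetch_if_missing` is a hypothetical helper for illustration, not the MONAI `DecathlonDataset` API):

```python
import os
import tempfile
import urllib.request


def fetch_if_missing(url: str, path: str) -> str:
    """Download only when the file is not already on disk, so leaving
    download=True on is a no-op once the dataset has been fetched."""
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path


# With the file already cached locally, no network call is made:
with tempfile.NamedTemporaryFile(delete=False) as f:
    cached = f.name
result = fetch_if_missing("https://example.com/Task01_BrainTumour.tar", cached)
print(result == cached)  # True
```

This is why defaulting to `download=True` should be safe for users who already have the data cached.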
noise_pred = self.unet(x=xt, timesteps=t, context=latent.unsqueeze(2))
return noise_pred, latent

device = torch.device("cuda:2")
i think it might be best for users if we default to device 0
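A dependency-free sketch of the fallback logic this suggests (`pick_device` is a hypothetical helper; in the notebook itself the usual idiom would be `torch.device("cuda:0" if torch.cuda.is_available() else "cpu")`):

```python
def pick_device(available=()):
    """Prefer the first CUDA device when any GPU is present, otherwise
    fall back to CPU, so the notebook runs unmodified on machines that
    don't have a third GPU (cuda:2)."""
    if "cuda:0" in available:
        return "cuda:0"
    return "cpu"


print(pick_device(available=("cuda:0", "cuda:1")))  # cuda:0
print(pick_device(available=()))                    # cpu
```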
# ## Training a diffusion model and semantic encoder

# %%
n_iterations = 1e4  # training for longer helps a lot with reconstruction quality, even if the loss is already low
could you just give an estimate of training time in the comments? useful so users know how to change this variable if they just want 'the best results they can get in 10 mins of training' or something similar
just noticed it says at the end of training, ~3 hours
Yes, it's not super quick...
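The back-of-envelope this thread implies, for readers who want to scale `n_iterations` to a time budget (the ~3 h / 1e4 iterations figure is from the comments above; the 10-minute budget is purely illustrative):

```python
# From the thread: roughly 3 hours of training for 1e4 iterations
total_iters = 10_000
total_hours = 3.0
seconds_per_iter = total_hours * 3600 / total_iters  # ~1.08 s/iteration

# Hypothetical budget: "best results I can get in 10 minutes"
budget_minutes = 10
iters_in_budget = int(budget_minutes * 60 / seconds_per_iter)
print(iters_in_budget)  # 555
```

Actual throughput depends on the GPU, so treat the per-iteration cost as an estimate rather than a guarantee.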
# get latent space of training set
latents_train = []
classes_train = []
for i in range(15):  # 15 slices from each volume
I stared at this a while working out why `i` was not being used! Could we replace `i` with `_` to make clear we don't actually use the variable? It's just a way of running the transform 15 times per batch.
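A quick sketch of the suggested rename (`random_slice` is a stand-in for the tutorial's random slicing transform, just to keep the example self-contained):

```python
def random_slice(volume):
    # placeholder for the random transform applied on each pass
    return volume[:1]


batch = list(range(100))
slices = []
for _ in range(15):  # was: for i in range(15) - '_' flags the index as unused
    slices.append(random_slice(batch))
print(len(slices))  # 15
```

Using `_` is the conventional Python signal that the loop exists only to repeat the body, which is exactly the reviewer's point.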
Create a Diffusion Autoencoder tutorial using the components that we already have for image manipulation.