Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# MAISI VAE Training Tutorial

This tutorial illustrates how to train the VAE model in MAISI. The VAE model is used for latent feature compression, which significantly reduces the memory usage of the diffusion model.

## Setup environment

To run this tutorial, please install `xformers` by following the [official guide](https://github.com/facebookresearch/xformers#installing-xformers).

In [None]:
# Each line first tries to import the package in a fresh interpreter; the `||`
# makes pip run only when that import fails, so existing installs are left untouched.
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
# TODO: remove print statement after generative dir is renamed to generation
# (the print statement in question is the `print(generative.__version__)` in the `generative` check below)
# xformers is fetched from the PyTorch wheel index for the `cu121` build.
!python -c "import xformers" || pip install -q xformers --index-url https://download.pytorch.org/whl/cu121
!python -c "import generative; print(generative.__version__)" || pip install -q "monai-generative"
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [None]:
import argparse
import json
import os
import tempfile
import glob
from pathlib import Path

import torch
from torch.nn import L1Loss, MSELoss
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

import monai
from monai.apps import download_url
from monai.config import print_config
from monai.transforms import LoadImage, Orientation
from monai.utils import set_determinism
from monai.data import load_decathlon_datalist
from monai.apps import download_and_extract
from monai.data import CacheDataset, DataLoader
from monai.losses.perceptual import PerceptualLoss
from monai.losses.adversarial_loss import PatchAdversarialLoss
from monai.inferers.inferer import SlidingWindowInferer, SimpleInferer

from generative.networks.nets import PatchDiscriminator

from scripts.utils import define_instance, load_autoencoder_ckpt, KL_loss, dynamic_infer
from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
from scripts.transforms import VAE_Transform

# Print the MONAI configuration report (via `monai.config.print_config`) so the environment is recorded with the run.
print_config()

## Setup data directory
You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
This allows you to save results and reuse downloads.
If not specified, a temporary directory will be used.

In [None]:
# Resolve the root directory used for dataset downloads.
# `MONAI_DATA_DIRECTORY` takes precedence (and is created if missing);
# otherwise a fresh temporary directory is used.
# The previous hard-coded, machine-specific path that unconditionally overrode
# this choice has been removed so the cell works on any machine.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

## Download dataset

This tutorial shows how to train a VAE with both CT and MRI data. We use MSD 09 Spleen segmentation and MSD 01 Brats16&17 Brain Tumor segmentation as examples. Users can choose their own training datasets.

Downloads and extracts the dataset.
The dataset comes from http://medicaldecathlon.com/.

In [None]:
# ---- Dataset 1: MSD Task09 Spleen (CT) ----
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_path_1 = os.path.join(root_dir, "Task09_Spleen")
# Download and extract only when the extracted folder is not there yet.
if not os.path.exists(data_path_1):
    download_and_extract(resource, compressed_file, root_dir, md5)

# Sorted listing keeps the train/val split stable across runs.
train_images_1 = sorted(glob.glob(os.path.join(data_path_1, "imagesTr", "*.nii.gz")))
data_dicts_1 = [{"image": nifti_path} for nifti_path in train_images_1]
# The first 80% of the files go to training, the remainder to validation.
len_train = int(0.8 * len(data_dicts_1))
train_files_1 = data_dicts_1[:len_train]
val_files_1 = data_dicts_1[len_train:]

# ---- Dataset 2: MSD Task01 BrainTumour / Brats (MRI) ----
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar"
md5 = "240a19d752f0d9e9101544901065d872"

compressed_file = os.path.join(root_dir, "Task01_BrainTumour.tar")
data_path_2 = os.path.join(root_dir, "Task01_BrainTumour")
if not os.path.exists(data_path_2):
    download_and_extract(resource, compressed_file, root_dir, md5)

train_images_2 = sorted(glob.glob(os.path.join(data_path_2, "imagesTr", "*.nii.gz")))
data_dicts_2 = [{"image": nifti_path} for nifti_path in train_images_2]
len_train = int(0.8 * len(data_dicts_2))
train_files_2 = data_dicts_2[:len_train]
val_files_2 = data_dicts_2[len_train:]

# Registry of all datasets, keyed by an integer id; add further entries here
# (each with a name, its train/val file lists and its modality) to train on more data.
datasets = {
    1: {
        "data_name": "Dataset 1 MSD09 Spleen Abdomen CT",
        "train_files": train_files_1,
        "val_files": val_files_1,
        "modality": "ct",
    },
    2: {
        "data_name": "Dataset 2 MSD01 Brats Brain MRI",
        "train_files": train_files_2,
        "val_files": val_files_2,
        "modality": "mri",
    },
}

## Read in environment setting, including data directory, model directory, and output directory

The information for data directory, model directory, and output directory is saved in `./configs/environment_maisi_vae_train.json`.

In [None]:
# Single namespace that collects every setting so later cells can read them as attributes.
args = argparse.Namespace()

environment_file = "./configs/environment_maisi_vae_train.json"
# Read through a context manager so the file handle is closed as soon as parsing is done.
with open(environment_file, "r", encoding="utf-8") as env_fp:
    env_dict = json.load(env_fp)
# Copy every top-level entry of the environment file onto `args` and echo it.
for k, v in env_dict.items():
    setattr(args, k, v)
    print(f"{k}: {v}")

# model path
Path(args.model_dir).mkdir(parents=True, exist_ok=True)
trained_g_path = os.path.join(args.model_dir, "autoencoder.pt")
trained_d_path = os.path.join(args.model_dir, "discriminator.pt")
print(f"Trained model will be saved as {trained_g_path} and {trained_d_path}.")

# initialize tensorboard writer; events go into an "autoencoder" sub-folder of args.tfevent_path
Path(args.tfevent_path).mkdir(parents=True, exist_ok=True)
tensorboard_path = os.path.join(args.tfevent_path, "autoencoder")
Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
tensorboard_writer = SummaryWriter(tensorboard_path)
print(f"Tensorboard event will be saved as {tensorboard_path}.")

## Read in configuration setting, including network definition, body region and anatomy to generate, etc.

The information used for both training and inference, like network definition, is stored in `"./configs/config_maisi.json"`. Training and inference should use the same "./configs/config_maisi.json".

The information for the training hyperparameters and data processing parameters, like learning rate and patch size, is stored in `"./configs/config_maisi_vae_train.json"`. Please feel free to play with it.

Dataset preprocessing:
- `"random_aug"`: bool, whether to add random data augmentation for training data.
- `"spacing_type"`: choose from `"original"` (no resampling involved), `"fixed"` (all images resampled to same voxel size), and `"rand"` (images randomly resampled).
- `"spacing"`: None or list of three floats. If `"spacing_type"` is `"fixed"`, all the images will be interpolated to the voxel size of `"spacing"`.
- `"select_channel"`: int, if multi-channel MRI, which channel it will select.

Training configs:
- `"patch_size"`: training patch size. For the released model, we first trained the autoencoder with small patch size ([128,128,128]), then increased to larger patch size ([256,256,128]).
- `"val_patch_size"`: Size of validation patches. If None, will use the whole volume for validation. If given, will central crop a patch for validation.
- `"val_sliding_window_patch_size"`: if the validation patch is too large, will use sliding window inference. Choose small value if GPU memory is small.
- `"perceptual_weight"`: perceptual loss weight.
- `"kl_weight"`: KL loss weight, an important hyper-parameter. If too large, the decoder cannot reconstruct good results from the latent space. If too small, the latent space will not be regularized enough for the diffusion model.
- `"adv_weight"`: adversarial loss weight.
- `"recon_loss"`: choose from `"l1"` and `"l2"`.
- `"val_interval"`: int, do validation every `"val_interval"` epochs.
- `"cache"`: float between 0 and 1, dataloader cache, choose small value if CPU memory is small.

In [None]:
# Network definition shared by training and inference: every top-level entry becomes an attribute of `args`.
config_file = "./configs/config_maisi.json"
# Context managers close the config files right after they are parsed.
with open(config_file, "r", encoding="utf-8") as config_fp:
    config_dict = json.load(config_fp)
for k, v in config_dict.items():
    setattr(args, k, v)

# Training-specific settings: the "data_option" section (data preprocessing) and the
# "autoencoder_train" section (training hyperparameters) are both flattened onto `args`.
config_train_file = "./configs/config_maisi_vae_train.json"
with open(config_train_file, "r", encoding="utf-8") as config_train_fp:
    config_train_dict = json.load(config_train_fp)
for k, v in config_train_dict["data_option"].items():
    setattr(args, k, v)
    print(f"{k}: {v}")
for k, v in config_train_dict["autoencoder_train"].items():
    setattr(args, k, v)
    print(f"{k}: {v}")

print("Network definition and training hyperparameters have been loaded.")

## Set deterministic training for reproducibility

In [None]:
# Fix the random seed to 0 through MONAI's `set_determinism` so that repeated runs are reproducible.
set_determinism(seed=0)

## Initialize networks

In [None]:
# Both networks are trained on the GPU.
device = torch.device("cuda")

# Override `num_splits` in the autoencoder definition to 1 before the network is built from it.
args.autoencoder_def["num_splits"] = 1
autoencoder = define_instance(args, "autoencoder_def")
autoencoder = autoencoder.to(device)

# Patch discriminator used for the adversarial loss: single-channel input and output,
# three layers, 32 base channels, instance normalisation.
discriminator_norm = "INSTANCE"
discriminator = PatchDiscriminator(
    spatial_dims=args.spatial_dims,
    in_channels=1,
    out_channels=1,
    num_channels=32,
    num_layers_d=3,
    norm=discriminator_norm,
)
discriminator = discriminator.to(device)

## Build training dataset

In [None]:
# Per-modality containers for the training and validation file lists (filled in below)
train_files = {modality_key: [] for modality_key in ("ct", "mri")}
val_files = {modality_key: [] for modality_key in ("ct", "mri")}

# Helper that tags every entry of a datalist with its class (modality) name
def add_assigned_class_to_datalist(datalist, classname):
    """Store ``classname`` under the ``"class"`` key of every entry in ``datalist``.

    Args:
        datalist: iterable of dict-like entries; each entry is modified in place.
        classname: value written to ``entry["class"]``.

    Returns:
        The very same ``datalist`` object that was passed in.
    """
    for entry in datalist:
        entry["class"] = classname
    return datalist

# Walk over every registered dataset and sort its files into the per-modality lists
for data_id, dataset in datasets.items():
    data_name = dataset["data_name"]
    train_files_i, val_files_i = dataset["train_files"], dataset["val_files"]
    print(f"{data_name}: number of training data is {len(train_files_i)}.")
    print(f"{data_name}: number of val data is {len(val_files_i)}.")

    # tag each file with the dataset's modality, then append it to that modality's list
    modality = dataset["modality"]
    train_files[modality].extend(add_assigned_class_to_datalist(train_files_i, modality))
    val_files[modality].extend(add_assigned_class_to_datalist(val_files_i, modality))

# Report the totals per modality
for modality, modality_train_files in train_files.items():
    print(f"Total number of training data for {modality} is {len(modality_train_files)}.")
    print(f"Total number of val data for {modality} is {len(val_files[modality])}.")

# Merge the modalities into single training / validation lists (CT entries first, then MRI)
train_files_combined = [*train_files["ct"], *train_files["mri"]]
val_files_combined = [*val_files["ct"], *val_files["mri"]]

## Define data transforms

In [None]:
# Training pipeline: configured from the data options loaded into `args`.
# NOTE(review): `select_channel` is fixed to 0 here although the config description lists a
# "select_channel" option -- confirm whether the configured value should be used instead.
train_transform_kwargs = {
    "is_train": True,
    "random_aug": args.random_aug,  # whether apply random data augmentation for training
    "k": 4,  # patches should be divisible by k
    "patch_size": args.patch_size,
    "val_patch_size": args.val_patch_size,
    "output_dtype": torch.float16,  # final data type
    "spacing_type": args.spacing_type,
    "spacing": args.spacing,
    "image_keys": ["image"],
    "label_keys": [],
    "additional_keys": [],
    "select_channel": 0,
}
train_transform = VAE_Transform(**train_transform_kwargs)

# Validation pipeline: no random augmentation; patch size / spacing options are not passed,
# so VAE_Transform's own defaults apply for them.
val_transform_kwargs = {
    "is_train": False,
    "random_aug": False,
    "k": 4,  # patches should be divisible by k
    "val_patch_size": args.val_patch_size,  # if None, will validate on whole image volume
    "output_dtype": torch.float16,  # final data type
    "image_keys": ["image"],
    "label_keys": [],
    "additional_keys": [],
    "select_channel": 0,
}
val_transform = VAE_Transform(**val_transform_kwargs)

## Build data loader

In [None]:
# Training set: cached dataset served by a shuffling loader that drops the last incomplete batch
print(f"Total number of training data is {len(train_files_combined)}.")
dataset_train = CacheDataset(
    data=train_files_combined,
    transform=train_transform,
    cache_rate=args.cache,
    num_workers=8,
)
dataloader_train = DataLoader(
    dataset_train, batch_size=args.batch_size, num_workers=4, shuffle=True, drop_last=True
)

# Validation set: same caching settings, fixed order, every sample kept
print(f"Total number of validation data is {len(val_files_combined)}.")
dataset_val = CacheDataset(
    data=val_files_combined,
    transform=val_transform,
    cache_rate=args.cache,
    num_workers=8,
)
dataloader_val = DataLoader(dataset_val, batch_size=args.val_batch_size, num_workers=4, shuffle=False)

## Visualize the data

In [None]:
# Show the first sample of the training set and then of the validation set:
# print its shape, locate the centre used for slicing, and plot the resulting views.
for split_name, split_dataset, plot_title in (
    ("Train", dataset_train, "training image"),
    ("Val", dataset_val, "validation image"),
):
    example_vis_img = split_dataset[0]["image"]
    print(f"{split_name} image shape {example_vis_img.shape}")
    # the leading dimension is squeezed away before the centre location is searched
    center_loc_axis = find_label_center_loc(example_vis_img.squeeze(0))
    vis_image = get_xyz_plot(example_vis_img, center_loc_axis, mask_bool=False)
    show_image(vis_image, title=plot_title)

## Training config

In [None]:
# Reconstruction loss: L2 when the config asks for "l2", L1 (mean reduction) for any other value
use_l2_recon = args.recon_loss == "l2"
if use_l2_recon:
    intensity_loss = MSELoss()
else:
    intensity_loss = L1Loss(reduction="mean")
print("Use l2 loss" if use_l2_recon else "Use l1 loss")

# Adversarial loss with the least-squares criterion
adv_loss = PatchAdversarialLoss(criterion="least_squares")

# Perceptual loss network, kept in eval mode on the training device
perceptual_net = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True, fake_3d_ratio=0.2)
loss_perceptual = perceptual_net.eval().to(device)

# Adam optimizers for generator (autoencoder) and discriminator;
# a larger epsilon is used when mixed precision is enabled
adam_eps = 1e-06 if args.amp else 1e-08
optimizer_g = torch.optim.Adam(params=autoencoder.parameters(), lr=args.lr, eps=adam_eps)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=args.lr, eps=adam_eps)

def warmup_rule(epoch):
    """Return the learning-rate multiplier for the given epoch.

    The multiplier is 0.01 for epochs below 10, 0.1 for epochs below 20,
    and 1.0 from epoch 20 onwards.
    """
    if epoch < 10:
        return 0.01
    if epoch < 20:
        return 0.1
    return 1.0

# Step-wise warm-up of the generator learning rate, driven by `warmup_rule`
scheduler_g = lr_scheduler.LambdaLR(optimizer_g, lr_lambda=warmup_rule)

# Gradient scalers are only created when mixed precision (AMP) is enabled;
# generator and discriminator each get their own scaler with identical settings.
if args.amp:
    scaler_kwargs = {"init_scale": 2.0 ** 8, "growth_factor": 1.5}
    scaler_g = GradScaler(**scaler_kwargs)
    scaler_d = GradScaler(**scaler_kwargs)


## Training
**Time cost:** 

**GPU memory:** 

In [None]:
# Book-keeping for the training run
val_interval = args.val_interval
best_val_recon_epoch_loss = 10000000.0
total_step = 0
start_epoch = 0

# Validation inferer: sliding-window inference when a window size is configured
# (windows are computed on `device`, the stitched output is kept on the CPU);
# otherwise a plain single-pass inferer.
if args.val_sliding_window_patch_size:
    val_inferer = SlidingWindowInferer(
        roi_size=args.val_sliding_window_patch_size,
        sw_batch_size=1,
        progress=False,
        overlap=0.0,
        device=torch.device("cpu"),
        sw_device=device,
    )
else:
    val_inferer = SimpleInferer()

def loss_weighted_sum(losses):
    """Combine the VAE loss terms into a single value.

    Args:
        losses: mapping with the keys ``'recons_loss'``, ``'kl_loss'`` and ``'p_loss'``.

    Returns:
        ``recons_loss + args.kl_weight * kl_loss + args.perceptual_weight * p_loss``.
    """
    weighted_kl = args.kl_weight * losses['kl_loss']
    weighted_perceptual = args.perceptual_weight * losses['p_loss']
    return losses['recons_loss'] + weighted_kl + weighted_perceptual

# Training and validation loops.
# Each epoch trains the autoencoder (generator) and the patch discriminator on every training
# batch, logs the losses to tensorboard, and every `val_interval` epochs evaluates the
# autoencoder on the validation set, saves checkpoints and logs example reconstructions.
for epoch in range(start_epoch, args.n_epochs):
    # get_last_lr() is the supported way to read the learning rate currently set by the scheduler
    print("lr:", scheduler_g.get_last_lr())
    autoencoder.train()
    discriminator.train()
    train_epoch_losses = {
        'recons_loss': 0,
        'kl_loss': 0,
        'p_loss': 0
    }
    num_train_steps = 0

    for step, batch in enumerate(dataloader_train):
        images = batch["image"].to(device).contiguous()

        # ---- Train Generator ----
        optimizer_g.zero_grad(set_to_none=True)
        # autocast wraps only the forward pass and the loss computation;
        # backward and optimizer steps run outside of it.
        with autocast(enabled=args.amp):
            reconstruction, z_mu, z_sigma = autoencoder(images)
            losses = {
                'recons_loss': intensity_loss(reconstruction, images),
                'kl_loss': KL_loss(z_mu, z_sigma),
                'p_loss': loss_perceptual(reconstruction.float(), images.float())
            }
            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
            generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g = loss_weighted_sum(losses) + args.adv_weight * generator_loss

        if args.amp:
            scaler_g.scale(loss_g).backward()
            scaler_g.unscale_(optimizer_g)
            scaler_g.step(optimizer_g)
            scaler_g.update()
        else:
            loss_g.backward()
            optimizer_g.step()

        # ---- Train Discriminator ----
        # The generator backward pass above also back-propagates through the discriminator and
        # leaves gradients on its parameters. Clear them here, right before the discriminator
        # update, so that only the discriminator loss drives this step.
        optimizer_d.zero_grad(set_to_none=True)
        with autocast(enabled=args.amp):
            logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
            loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
            logits_real = discriminator(images.contiguous().detach())[-1]
            loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
            loss_d = (loss_d_fake + loss_d_real) * 0.5

        if args.amp:
            scaler_d.scale(loss_d).backward()
            scaler_d.step(optimizer_d)
            scaler_d.update()
        else:
            loss_d.backward()
            optimizer_d.step()

        # Log training loss
        total_step += 1
        num_train_steps += 1
        for loss_name, loss_value in losses.items():
            tensorboard_writer.add_scalar(f"train_{loss_name}_iter", loss_value.item(), total_step)
            train_epoch_losses[loss_name] += loss_value.item()
        tensorboard_writer.add_scalar("train_adv_loss_iter", generator_loss, total_step)
        tensorboard_writer.add_scalar("train_fake_loss_iter", loss_d_fake, total_step)
        tensorboard_writer.add_scalar("train_real_loss_iter", loss_d_real, total_step)

    scheduler_g.step()
    # Report per-batch averages (not sums) so the values are comparable with the validation
    # losses below; max(..., 1) guards against an empty training loader.
    for key in train_epoch_losses:
        train_epoch_losses[key] /= max(num_train_steps, 1)
    print(f"Epoch {epoch} train_vae_loss {loss_weighted_sum(train_epoch_losses)}: {train_epoch_losses}.")
    for loss_name, loss_value in train_epoch_losses.items():
        tensorboard_writer.add_scalar(f"train_{loss_name}_epoch", loss_value, epoch)

    # Validation
    if epoch % val_interval == 0:
        autoencoder.eval()
        val_epoch_losses = {
            'recons_loss': 0,
            'kl_loss': 0,
            'p_loss': 0
        }
        num_val_batches = 0  # batches that were actually evaluated
        val_loader_iter = iter(dataloader_val)
        for _ in range(len(dataloader_val)):
            # Best effort: a sample that fails to load is skipped instead of aborting validation.
            # Only `Exception` is caught so that KeyboardInterrupt still stops the run.
            try:
                batch = next(val_loader_iter)
            except StopIteration:
                break
            except Exception as err:
                print("Error for one data in val_loader, skip:", repr(err))
                continue
            with torch.no_grad():
                with autocast(enabled=args.amp):
                    images = batch["image"]
                    # Keep the latents of the validation batch: the KL term and the scale-factor
                    # monitor below must not reuse z_mu / z_sigma of the last training batch.
                    # (assumes dynamic_infer returns the same (reconstruction, z_mu, z_sigma)
                    # triple as the autoencoder forward pass -- confirm in scripts.utils)
                    reconstruction, z_mu, z_sigma = dynamic_infer(val_inferer, autoencoder, images)
                    reconstruction = reconstruction.to(device)
                    z_mu, z_sigma = z_mu.to(device), z_sigma.to(device)
                    images_on_device = images.to(device)
                    val_epoch_losses['recons_loss'] += intensity_loss(reconstruction, images_on_device).item()
                    val_epoch_losses['kl_loss'] += KL_loss(z_mu, z_sigma).item()
                    val_epoch_losses['p_loss'] += loss_perceptual(reconstruction, images_on_device).item()
            num_val_batches += 1

        # The latest weights are saved at every validation epoch
        torch.save(autoencoder.state_dict(), trained_g_path)
        torch.save(discriminator.state_dict(), trained_d_path)

        if num_val_batches == 0:
            # Nothing was evaluated: skip metrics and plots rather than report stale values.
            print("No validation batch could be evaluated, skip validation metrics for this epoch.")
            continue

        # Average over the batches that were evaluated, not over the attempted ones
        for key in val_epoch_losses:
            val_epoch_losses[key] /= num_val_batches

        val_loss_g = loss_weighted_sum(val_epoch_losses)
        print(f"Epoch {epoch} val_vae_loss {val_loss_g}: {val_epoch_losses}.")

        if val_loss_g < best_val_recon_epoch_loss:
            best_val_recon_epoch_loss = val_loss_g
            # Best-so-far autoencoder gets its own, epoch-tagged checkpoint file
            trained_g_path_epoch = f"{trained_g_path[:-3]}_epoch{epoch}.pt"
            torch.save(autoencoder.state_dict(), trained_g_path_epoch)
            print("Got best val vae loss.")
            print("Save trained autoencoder to", trained_g_path_epoch)
            print("Save trained discriminator to", trained_d_path)

        for loss_name, loss_value in val_epoch_losses.items():
            tensorboard_writer.add_scalar(loss_name, loss_value, epoch)

        # Monitor scale_factor (computed from the latent mean of the last validation batch)
        # We'd like to tune kl_weights in order to make scale_factor close to 1.
        scale_factor_sample = 1. / z_mu.flatten().std()
        tensorboard_writer.add_scalar("val_one_sample_scale_factor", scale_factor_sample, epoch)

        # Monitor reconstruction result (first sample of the last validation batch)
        center_loc_axis = find_label_center_loc(images[0, 0, ...])
        vis_image = get_xyz_plot(images[0, ...], center_loc_axis, mask_bool=False)
        vis_recon_image = get_xyz_plot(reconstruction[0, ...], center_loc_axis, mask_bool=False)

        # the plots are transposed with [2, 0, 1] before being handed to add_image
        tensorboard_writer.add_image(
            "val_orig_img",
            vis_image.transpose([2, 0, 1]),
            epoch,
        )
        tensorboard_writer.add_image(
            "val_recon_img",
            vis_recon_image.transpose([2, 0, 1]),
            epoch,
        )

        show_image(vis_image, title="val image")
        show_image(vis_recon_image, title="val recon result")