Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Training a 3D Diffusion Model for Generating 3D Images with Various Sizes and Spacings

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"

## Setup imports

In [2]:
import copy
import os
import json
import numpy as np
import nibabel as nib
import subprocess

from monai.data import create_test_image_3d
from monai.config import print_config

# Print the installed MONAI version together with the versions of its dependencies,
# so the environment this notebook was run in is recorded alongside the results.
print_config()

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


MONAI version: 1.3.1+27.g8cfbcbab
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8cfbcbabd1529ef4090fb6f7ffbeef47d6b70cc2
MONAI __file__: /localhome/<username>/miniconda3/envs/monai-dev/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.13.1
Pillow version: 10.3.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.18.1+cu121
tqdm version: 4.66.4
lmdb version: 1.4.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.14.1
pynrrd version: 1.0.0
clearml version: 1.16.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Simulate a special dataset

Training a diffusion model normally takes a long time. To keep this tutorial fast, we simulate a small dataset and run training for only a few epochs. Because the training is so short, the generated results should not be expected to be of high quality, but the entire pipeline will complete within minutes!

`sim_datalist` describes the simulated dataset: it lists 2 training images. The size of each simulated image along its three axes is defined by `sim_dim`.

In [3]:
# Datalist of the simulated dataset: two training volumes, each entry keyed by "image".
sim_datalist = {"training": [{"image": f"tr_image_{index:03d}.nii.gz"} for index in (1, 2)]}

# Size of every simulated volume along its three axes.
sim_dim = 128, 160, 96

## Generate images

Now we can use MONAI's `create_test_image_3d` function and nibabel's `nib.Nifti1Image` class to generate the simulated 3D images and save them under `work_dir`.

In [4]:
# Working directory of the demo and the folder that receives the simulated images.
work_dir = "./helloworld_work_dir"
dataroot_dir = os.path.join(work_dir, "sim_dataroot")

# exist_ok=True keeps the cell safe to re-run when the folders already exist.
os.makedirs(work_dir, exist_ok=True)
os.makedirs(dataroot_dir, exist_ok=True)

# Save the datalist next to the data so the training scripts can read it.
datalist_file = os.path.join(work_dir, "sim_datalist.json")
with open(datalist_file, "w") as f:
    json.dump(sim_datalist, f)

# Seed ONE generator outside the loop. Re-creating RandomState(42) for every image
# restarts the random sequence each time, which makes all simulated volumes identical.
# A single seeded generator keeps the run reproducible while giving each image its own content.
random_state = np.random.RandomState(42)

for d in sim_datalist["training"]:
    height, width, depth = sim_dim
    im, _ = create_test_image_3d(
        height, width, depth, rad_max=10, num_seg_classes=1, random_state=random_state
    )
    image_fpath = os.path.join(dataroot_dir, d["image"])
    # Identity affine: the saved NIfTI file has unit voxel spacing and no rotation/offset.
    nib.save(nib.Nifti1Image(im, affine=np.eye(4)), image_fpath)

print("Generated simulated images.")

Generated simulated images.


## Set up directories and configurations

In [5]:
# Template configuration files that are adapted below for this demo.
env_config_path = "./configs/environment_maisi_diff_model_train.json"
model_config_path = "./configs/config_maisi_diff_model_train.json"

# Read both templates; edits are made on deep copies so the loaded originals stay untouched.
with open(env_config_path, "r") as env_template:
    env_config = json.load(env_template)

with open(model_config_path, "r") as model_template:
    model_config = json.load(model_template)

env_config_out = copy.deepcopy(env_config)
model_config_out = copy.deepcopy(model_config)

# Point the environment at the simulated data and move its output folders under work_dir.
env_config_out.update(
    {
        "data_base_dir": dataroot_dir,
        "embedding_base_dir": os.path.join(work_dir, env_config_out["embedding_base_dir"]),
        "json_data_list": datalist_file,
        "model_dir": os.path.join(work_dir, env_config_out["model_dir"]),
        "output_dir": os.path.join(work_dir, env_config_out["output_dir"]),
        "trained_autoencoder_path": None,
    }
)

# Make sure every folder the scripts write into exists.
for dir_key in ("embedding_base_dir", "model_dir", "output_dir"):
    os.makedirs(env_config_out[dir_key], exist_ok=True)

# Save the adapted environment configuration inside work_dir.
env_config_filepath = os.path.join(work_dir, "environment_maisi_diff_model_train.json")
with open(env_config_filepath, "w") as env_out_file:
    json.dump(env_config_out, env_out_file, sort_keys=True, indent=4)

# Demo-sized overrides of the model configuration (only 2 training epochs).
model_config_out["autoencoder_def"]["num_splits"] = 4
model_config_out["diffusion_unet_train"]["n_epochs"] = 2

# Save the adapted model configuration inside work_dir.
model_config_filepath = os.path.join(work_dir, "config_maisi_diff_model_train.json")
with open(model_config_filepath, "w") as model_out_file:
    json.dump(model_config_out, model_out_file, sort_keys=True, indent=4)

# Show what work_dir contains after the setup.
print(os.listdir(work_dir))

['models', 'sim_datalist.json', 'embeddings', 'environment_maisi_diff_model_train.json', 'predictions', 'sim_dataroot', 'config_maisi_diff_model_train.json']


In [6]:
def run_torchrun(script, script_args, num_gpus=1):
    """Run a script with `torchrun` on a single node and echo its output line by line.

    Args:
        script: path of the Python script handed to `torchrun`.
        script_args: sequence of command-line arguments appended after `script`.
        num_gpus: value passed to `--nproc_per_node`.

    Returns:
        None. The output of the subprocess is printed as it arrives. A non-zero exit
        code is reported with a printed message instead of an exception.
    """
    num_nodes = 1

    # Build the torchrun command; unpacking accepts any sequence of arguments, not only lists.
    torchrun_command = [
        "torchrun",
        "--nproc_per_node",
        str(num_gpus),
        "--nnodes",
        str(num_nodes),
        script,
        *script_args,
    ]

    # Run the child with OMP_NUM_THREADS=1, leaving the parent environment untouched.
    env = os.environ.copy()
    env["OMP_NUM_THREADS"] = "1"

    # stderr is merged into stdout on purpose: with a separate stderr pipe that nobody reads
    # while stdout is being streamed, a child that writes a lot to stderr (warnings,
    # progress bars) fills the pipe buffer and blocks forever.
    with subprocess.Popen(
        torchrun_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env
    ) as process:
        try:
            # Iterating the pipe yields lines as they arrive and stops at end-of-file.
            for line in process.stdout:
                print(line.strip())
        except Exception as e:
            print(f"An error occurred: {e}")
    # Leaving the `with` block closes the pipe and waits for the child, so returncode is set.

    if process.returncode != 0:
        print(f"torchrun exited with non-zero return code {process.returncode}.")
    return

## Step 1: Create Training Data

In [7]:
print("Creating training data...")

# Launch settings for torchrun: the script to run, its command-line arguments, and the GPU count.
script = "scripts/diff_model_create_training_data.py"  # Replace with your script
script_args = [
    "--env_config",
    env_config_filepath,
    "--model_config",
    model_config_filepath,
]
num_gpus = 1  # Adjust based on the number of GPUs you want to use

run_torchrun(script, script_args, num_gpus=num_gpus)

Creating training data...
Using device cuda:0
The trained_autoencoder_path does not exist!
filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']
old [128, 160, 96] [1.0, 1.0, 1.0]
new (128, 128, 128) [[ 1.     0.     0.     0.   ]
[ 0.     1.25   0.     0.125]
[ 0.     0.     0.75  -0.125]
[ 0.     0.     0.     1.   ]]
out_filename: ./helloworld_work_dir/./embeddings/tr_image_001_emb.nii.gz
z: torch.Size([1, 4, 32, 32, 32]), torch.float32
old [128, 160, 96] [1.0, 1.0, 1.0]
new (128, 128, 128) [[ 1.     0.     0.     0.   ]
[ 0.     1.25   0.     0.125]
[ 0.     0.     0.75  -0.125]
[ 0.     0.     0.     1.   ]]
out_filename: ./helloworld_work_dir/./embeddings/tr_image_002_emb.nii.gz
z: torch.Size([1, 4, 32, 32, 32]), torch.float32



## Create .json files for embedding files

In [8]:
def list_gz_files(folder_path):
    """Return the path of every file ending in ".gz" under `folder_path`, searched recursively."""
    return [
        os.path.join(current_dir, filename)
        for current_dir, _subdirs, filenames in os.walk(folder_path)
        for filename in filenames
        if filename.endswith(".gz")
    ]


def create_json_files(gz_files):
    """Write a "<file>.json" sidecar next to each NIfTI file listed in `gz_files`.

    Every sidecar stores the first three entries of the image shape ("dim"), the first
    three voxel sizes from the header as floats ("spacing"), and the fixed
    "top_region_index" / "bottom_region_index" vectors. The dictionary is also printed.
    """
    for nifti_path in gz_files:
        volume = nib.load(nifti_path)

        # The region can be selected from one of four regions from top to bottom:
        # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region,
        # [0,0,1,0] is the abdomen region, and [0,0,0,1] is the lower body region.
        data = {
            "dim": volume.shape[:3],
            "spacing": [float(zoom) for zoom in volume.header.get_zooms()[:3]],
            "top_region_index": [0, 1, 0, 0],  # chest region
            "bottom_region_index": [0, 0, 1, 0],  # abdomen region
        }
        print(f"{data}")

        # The sidecar keeps the full original filename and appends ".json" to it.
        with open(nifti_path + ".json", "w") as sidecar:
            json.dump(data, sidecar, indent=4)


# Collect every .gz file under the embedding folder and write a .json sidecar next to each one.
folder_path = env_config_out["embedding_base_dir"]
gz_files = list_gz_files(folder_path)
create_json_files(gz_files)

print("Completed creating .json files for all embedding files.")

{'dim': (32, 32, 32), 'spacing': [1.0, 1.25, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}
{'dim': (32, 32, 32), 'spacing': [1.0, 1.25, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}
Completed creating .json files for all embedding files.


## Step 2: Train the Model

In [9]:
print("Training the model...")

# Launch settings for torchrun: the script to run, its command-line arguments, and the GPU count.
script = "scripts/diff_model_train.py"  # Replace with your script
script_args = [
    "--env_config",
    env_config_filepath,
    "--model_config",
    model_config_filepath,
]
num_gpus = 1  # Adjust based on the number of GPUs you want to use

run_torchrun(script, script_args, num_gpus=num_gpus)

Training the model...
Using cuda:0 of 1
[config] ckpt_folder -> ./helloworld_work_dir/./models.
[config] data_root -> ./helloworld_work_dir/./embeddings.
[config] data_list -> ./helloworld_work_dir/sim_datalist.json.
[config] lr -> 0.0001.
[config] num_epochs -> 2.
[config] num_train_timesteps -> 1000.
num_files_train: 2
cache_rate: 0
num_images_per_batch -> 1.
training from scratch.
Scaling factor set to 0.8950040340423584.
Rank 0: scale_factor -> 0.8950040340423584.
optimizer -> Adam; lr -> 0.0001.
total number of training steps: 4.0.
torch.set_float32_matmul_precision -> highest.
epoch 1/2, lr 0.0001.
[2024-07-12 13:06:50] epoch 1, iter 1/2, loss: 0.7960, lr: 0.000100000000.
[2024-07-12 13:06:50] epoch 1, iter 2/2, loss: 0.7952, lr: 0.000056250000.
epoch 1 average loss: 0.7956.
epoch 2/2, lr 2.5e-05.
[2024-07-12 13:06:51] epoch 2, iter 1/2, loss: 0.7916, lr: 0.000025000000.
[2024-07-12 13:06:51] epoch 2, iter 2/2, loss: 0.7911, lr: 0.000006250000.
epoch 2 average loss: 0.7913.



## Step 3: Infer using the Trained Model

In [10]:
print("Running inference...")

# Launch settings for torchrun: the script to run, its command-line arguments, and the GPU count.
script = "scripts/diff_model_infer.py"  # Replace with your script
script_args = [
    "--env_config",
    env_config_filepath,
    "--model_config",
    model_config_filepath,
]
num_gpus = 1  # Adjust based on the number of GPUs you want to use

run_torchrun(script, script_args, num_gpus=num_gpus)

print("Completed all steps.")

Running inference...
a_min: -1000, a_max: 1000, b_min: 0, b_max: 1.
Using cuda:0 of 1
random seed: 0
[config] ckpt_filepath -> ./helloworld_work_dir/./models/diff_unet_ckpt.pt.
[config] random_seed -> 0.
[config] output_prefix -> unet_3d.
[config] output_size -> (128, 128, 128).
[config] out_spacing -> (1.0, 1.25, 0.75).
The trained_autoencoder_path does not exist!
checkpoints ./helloworld_work_dir/./models/diff_unet_ckpt.pt loaded.
num_downsample_level -> 4.
divisor -> 4.
num_train_timesteps -> 1000.
scale_factor -> 0.8950040340423584.
noise: cuda:0 torch.float32 <class 'torch.Tensor'>
top_region_index_tensor: [  0. 100.   0.   0.].
bottom_region_index_tensor: [  0.   0. 100.   0.].
spacing_tensor: [100. 125.  75.].
Saved ./helloworld_work_dir/./predictions/unet_3d_seed0_size128x128x128_spacing1.00x1.25x0.75_20240712130709.nii.gz.


  0%|                                                                                  | 0/10 [00:00<?, ?it/s]
 10%|███████▍                              