In [None]:
import copy
import os
import json
import numpy as np
import nibabel as nib
import subprocess

from monai.data import create_test_image_3d
from monai.config import print_config

# Call MONAI's print_config (imported from monai.config above) so the configuration it reports is shown with the notebook output
print_config()

## Simulate a special dataset

It is well known that AI models take time to train. To provide a "Hello World!" experience of the MAISI diffusion model workflow in this notebook, we will simulate a small dataset and run training for only a few epochs. Because the dataset and the training budget are so small, high performance should not be expected, but the entire pipeline will complete within minutes!

`sim_datalist` describes the simulated dataset. It lists 2 training images; no labels, testing images, or cross-validation folds are defined in this demo. The size of each simulated image is defined by `sim_dim`.

In [None]:
# Datalist of the simulated dataset: one {"image": <file name>} entry per training volume
sim_datalist = {
    "training": [{"image": file_name} for file_name in ("tr_image_001.nii.gz", "tr_image_002.nii.gz")]
}

# Size of every simulated volume along its three axes
sim_dim = (128, 160, 96)

## Generate images and labels

Now we can use MONAI's `create_test_image_3d` function and nibabel's `nib.Nifti1Image` class to generate the simulated 3D images and save them under `work_dir`.

In [None]:
# Root folder for everything this notebook writes to disk
work_dir = "./helloworld_work_dir"
# exist_ok=True keeps the cell re-runnable without a separate isdir() check
os.makedirs(work_dir, exist_ok=True)

# Sub-folder that receives the simulated NIfTI images
dataroot_dir = os.path.join(work_dir, "sim_dataroot")
os.makedirs(dataroot_dir, exist_ok=True)

# Save the datalist as JSON so it can be referenced by file path later on
datalist_file = os.path.join(work_dir, "sim_datalist.json")
with open(datalist_file, "w") as f:
    json.dump(sim_datalist, f)

# Seed the generator ONCE, outside the loop. Re-creating RandomState(42) for
# every image restarts the same random sequence each time, which makes all
# simulated images identical. A single seeded generator keeps the run
# reproducible while giving each image different content.
sim_rng = np.random.RandomState(42)

for d in sim_datalist["training"]:
    # create_test_image_3d returns an (image, segmentation) pair; only the image is kept
    im, _ = create_test_image_3d(
        sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=1, random_state=sim_rng
    )
    image_fpath = os.path.join(dataroot_dir, d["image"])
    # Identity affine: the simulated volume carries no real-world orientation or spacing
    nib.save(nib.Nifti1Image(im, affine=np.eye(4)), image_fpath)

print("Generated simulated images.")

In [None]:
# Template configuration files shipped alongside this notebook
env_config_path = "./configs/environment_maisi_diff_model_train.json"
model_config_path = "./configs/config_maisi_diff_model_train.json"

# Read both templates from disk
with open(env_config_path, "r") as env_template_fp:
    env_config = json.load(env_template_fp)

with open(model_config_path, "r") as model_template_fp:
    model_config = json.load(model_template_fp)

# Edit deep copies so the loaded templates themselves stay untouched
env_config_out = copy.deepcopy(env_config)
model_config_out = copy.deepcopy(model_config)

# Point the environment at the simulated data and datalist
env_config_out["data_base_dir"] = dataroot_dir
env_config_out["json_data_list"] = datalist_file
env_config_out["trained_autoencoder_path"] = None

# Re-root the embedding / model / output folders under work_dir
managed_dir_keys = ("embedding_base_dir", "model_dir", "output_dir")
for dir_key in managed_dir_keys:
    env_config_out[dir_key] = os.path.join(work_dir, env_config_out[dir_key])

# Make sure each of those folders exists
for dir_key in managed_dir_keys:
    os.makedirs(env_config_out[dir_key], exist_ok=True)

# Write the customised environment configuration into work_dir
env_config_filepath = os.path.join(work_dir, "environment_maisi_diff_model_train.json")
with open(env_config_filepath, "w") as env_out_fp:
    json.dump(env_config_out, env_out_fp, sort_keys=True, indent=4)

# Demo-specific override of the model configuration
model_config_out["autoencoder_def"]["num_splits"] = 4

# Write the customised model configuration into work_dir
model_config_filepath = os.path.join(work_dir, "config_maisi_diff_model_train.json")
with open(model_config_filepath, "w") as model_out_fp:
    json.dump(model_config_out, model_out_fp, sort_keys=True, indent=4)

# Show what now lives directly under work_dir
print(os.listdir(work_dir))

In [None]:
# Step 1: Create Training Data
print("Creating training data...")

# Define the arguments for torchrun
num_nodes = 1
num_gpus = 2  # Adjust based on the number of GPUs you want to use
script = "scripts/diff_model_create_training_data.py"
# Pass the customised configuration files written under work_dir
script_args = [
    "--env_config", env_config_filepath,
    "--model_config", model_config_filepath
]

# Build the torchrun command as an argument list (no shell involved)
torchrun_command = [
    "torchrun",
    "--nproc_per_node", str(num_gpus),
    "--nnodes", str(num_nodes),
    script
] + script_args

# Execute the command, capturing stdout/stderr as text
result = subprocess.run(torchrun_command, capture_output=True, text=True)

# Print the output and errors
print("Output:\n", result.stdout)
print("Errors:\n", result.stderr)

# Fail loudly if torchrun exited with a non-zero status; otherwise the
# following cells would run even though this step did not succeed.
# Raises subprocess.CalledProcessError on failure.
result.check_returncode()

In [None]:
# Step 2: Train the Model
print("Training the model...")
# Use the customised configuration files written under work_dir (the same ones
# passed to Step 1), not the unmodified templates in ./configs; the templates
# do not contain the work_dir-based paths set up for this demo.
# NOTE(review): diff_model_train is not defined or imported in the visible
# cells — confirm where it comes from and add the import to the first cell.
diff_model_train(env_config_filepath, model_config_filepath)

In [None]:
# Step 3: Infer using the Trained Model
print("Running inference...")
# Use the customised configuration files written under work_dir (the same ones
# passed to Step 1), not the unmodified templates in ./configs.
# NOTE(review): diff_model_infer, ckpt_filepath and amp are not defined or
# imported in the visible cells — confirm their source/values and define them
# before this call so the notebook survives Restart & Run All.
diff_model_infer(env_config_filepath, model_config_filepath, ckpt_filepath, amp)

print("Completed all steps.")