# Medical Image Generation with Generative AI

This guide walks you through generating synthetic medical images on NVIDIA DGX Cloud, using DGX systems to run generative AI models for medical imaging applications.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/monai-cloud-api/blob/main/notebooks/Medical%20Image%20Generation%20with%20MAISI.ipynb)

## Table of Contents

- Introduction
- Setup
- Generative AI Experiment Creation
- Generate Medical Images
- Monitoring Job Status
- Download the Job Log
- Download the Generated Medical Images
- Visualize the Generated Medical Images
- Cleaning Up
- Conclusion

## Introduction

Synthetic medical image generation using generative AI has emerged as a powerful technique in the field of medical imaging. It allows researchers, healthcare professionals, and developers to generate realistic and high-fidelity medical images, such as CT scans, without the need for extensive data collection or patient involvement.

CT (Computed Tomography) scans are widely used in medical diagnosis, treatment planning, and research. They provide detailed cross-sectional images of the body, allowing healthcare professionals to visualize internal structures and identify abnormalities.

Traditionally, obtaining a large dataset of CT scans for research or training purposes can be challenging due to privacy concerns, limited access to patient data, and the time-consuming process of acquiring scans. Synthetic medical image generation addresses these challenges by leveraging generative AI models to generate synthetic CT scans that closely resemble real patient scans.

Generative AI models, such as generative adversarial networks (GANs), variational autoencoders (VAEs), and diffusion models, learn the underlying patterns and structures of real CT scans from a training dataset. They can then generate new CT scans with similar characteristics, including anatomical structures, tissue densities, and noise patterns.

By the end of this guide, you will be able to generate synthetic CT scans using generative AI models on NVIDIA DGX Cloud. These synthetic scans can be used for a variety of applications, including medical imaging research, algorithm development, education, and training.

![image.png](attachment:image.png)

By generating synthetic CT scans, researchers and developers can:

- Augment limited datasets: Synthetic CT scans can be used to augment small or imbalanced datasets, improving the performance and generalization of machine learning models.
- Privacy-preserving research: Synthetic CT scans reduce the need to access or share sensitive patient data, which eases privacy compliance and enables collaborative research.
- Simulation and testing: Synthetic CT scans can be used for simulating different clinical scenarios, testing algorithms, and evaluating the performance of medical imaging systems.
- Education and training: Synthetic CT scans provide a valuable resource for medical education and training, allowing students and healthcare professionals to practice interpreting and analyzing scans.

To get started, make sure you have generated your credentials by following the step-by-step guide on [Generating and Managing Your Credentials](./Generating%20and%20Managing%20Your%20Credentials.ipynb). These credentials will be required for accessing the NVIDIA DGX Cloud and running the generative AI experiments.

In this guide, we explore synthetic medical image generation with generative AI, focusing on CT scans. We cover the setup, creating an experiment from a pretrained generative AI base experiment, generating synthetic CT scans on NVIDIA DGX Cloud, and downloading and visualizing the results.

Let's embark on this exciting journey of synthetic medical image generation and unlock new possibilities in medical imaging research and applications!

## Setup

```python
import json
import os
import time

import matplotlib.pyplot as plt
import nibabel as nib
import requests
from libcloud.storage.providers import get_driver
from libcloud.storage.types import Provider
```

#### Required Parameters

```python
# API endpoint and credentials
host_url = "https://api.monai.ngc.nvidia.com"
ngc_api_key = os.environ.get("MONAI_API_KEY", "<YOUR_API_KEY>")  # we recommend using environment variables for API keys, but you can also hardcode them here

# The cloud storage provider used in this notebook. Currently only `aws` and `azure` are supported.
cloud_type = "azure"  # cloud storage provider: "aws" or "azure"
# Provider-specific key names for the storage credentials in the experiment's cloud details
cloud_account = "access_key" if cloud_type == "aws" else "account_name"
cloud_secret = "secret_key" if cloud_type == "aws" else "access_key"

# Cloud storage credentials, needed for storing the data and results of the experiments
access_id = "<user name for the remote storage object>"  # Replace with the actual access ID
access_secret = "<secret for the remote storage object>"  # Replace with the actual access secret
cs_bucket = "<bucket/container name to push the experiment job data to>"  # Replace with the actual bucket or container name

# Job configuration
timeout = 3600  # Time (in seconds) to wait for a job to complete
num_inference_steps = None  # [advanced parameter] Number of inference steps to run. If None, the default of 1000 is used.
```
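Before calling the API, it can help to fail fast if a placeholder value is still in place. The sketch below uses a hypothetical helper, `find_placeholders`, which is not part of the MONAI Cloud API; it assumes placeholders follow the `<...>` convention used in the cell above.

```python
def find_placeholders(params):
    """Return the names of parameters whose values still look like "<...>" placeholders.

    Hypothetical helper: assumes placeholders are strings wrapped in angle brackets.
    """
    return [
        name
        for name, value in params.items()
        if isinstance(value, str) and value.startswith("<") and value.endswith(">")
    ]


# Example usage with illustrative values
params = {
    "ngc_api_key": "<YOUR_API_KEY>",
    "cs_bucket": "my-bucket",
    "timeout": 3600,
}
missing = find_placeholders(params)
if missing:
    print("Please fill in:", ", ".join(missing))
```

In the notebook you could pass `{"ngc_api_key": ngc_api_key, "access_id": access_id, "access_secret": access_secret, "cs_bucket": cs_bucket}` and stop if the returned list is not empty.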

#### Log In to NGC and Set Up the API

```python
# Exchange the NGC API key for a JWT
api_url = f"{host_url}/api/v1"
response = requests.post(f"{api_url}/login", json={"ngc_api_key": ngc_api_key})
response.raise_for_status()
login = response.json()
assert "user_id" in login, "user_id is not in response."
assert "token" in login, "token is not in response."
user_id = login["user_id"]
token = login["token"]

# Construct the base URL and headers used by all subsequent calls
ngc_org = "iasixjqzw1hj"
base_url = f"{api_url}/orgs/{ngc_org}"
headers = {"Authorization": f"Bearer {token}"}
print("API calls will be forwarded to", base_url)
```
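All later cells build their endpoints from `base_url` with the same pattern: `/experiments`, `/experiments/{experiment_id}/jobs`, `/jobs/{job_id}/logs`, and the `:cancel` action. A small helper makes that URL scheme explicit. `api_endpoint` is a hypothetical name for illustration; the paths it produces are exactly the ones used in this notebook's cells.

```python
def api_endpoint(base_url, experiment_id=None, job_id=None, suffix=""):
    """Build the REST endpoints used in this notebook (hypothetical helper).

    `suffix` is appended verbatim, e.g. "/jobs", "/logs" or ":cancel".
    """
    url = f"{base_url}/experiments"
    if experiment_id is not None:
        url += f"/{experiment_id}"
        if job_id is not None:
            url += f"/jobs/{job_id}"
    return url + suffix


base = "https://api.monai.ngc.nvidia.com/api/v1/orgs/example_org"  # illustrative org name
print(api_endpoint(base, "exp123", suffix="/jobs"))
print(api_endpoint(base, "exp123", "job456", "/logs"))
```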

## Generative AI Experiment Creation


#### Find the base experiment for Generative AI

```python
endpoint = f"{base_url}/experiments"
response = requests.get(endpoint, headers=headers)
assert response.status_code == 200, f"Listing experiments failed, got {response.json()}."
res = response.json()

# Keep only the MONAI Generative AI (MAISI) base experiments
gen_ai_base_exps = [
    p for p in res["experiments"]
    if p["network_arch"] == "monai_genai" and p["name"] == "MONAI GenerativeAI"
]
assert len(gen_ai_base_exps) > 0, "No base experiment found for Generative AI"
print("List of available base experiments for MAISI:")
for exp in gen_ai_base_exps:
    print(f"  {exp['id']}: {exp['name']} v{exp['version']}")
base_experiment = sorted(gen_ai_base_exps, key=lambda x: x["version"])[-1]  # Take the latest version (versions compared as strings)
version = base_experiment["version"]
base_exp_maisi = base_experiment["id"]
print("-----------------------------------------------------------------------------------------")
print(f"Base experiment ID for '{base_experiment['name']}' v{version}: {base_exp_maisi}")
print("-----------------------------------------------------------------------------------------")
```
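The cell above picks the latest base experiment by sorting version strings, which compares them character by character, so `"1.10"` sorts before `"1.9"`. If the service's versions are dot-separated integers (an assumption; the format is not documented here), a numeric sort key avoids this. `version_key` is a hypothetical helper, and it raises `ValueError` for non-numeric components.

```python
def version_key(version):
    """Sort key for dotted version strings such as "1.10" (assumes integer components)."""
    return tuple(int(part) for part in version.split("."))


experiments = [
    {"id": "a", "version": "1.9"},
    {"id": "b", "version": "1.10"},
]

# String comparison treats "1.10" as older than "1.9"
print(sorted(experiments, key=lambda x: x["version"])[-1]["version"])  # 1.9
# Numeric comparison picks the newer version
print(sorted(experiments, key=lambda x: version_key(x["version"]))[-1]["version"])  # 1.10
```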

#### Create Generative AI Experiment

```python
experiment_cloud_details = {
    "cloud_type": cloud_type,
    "cloud_file_type": "folder",  # Use "file" if the data is a tar.gz archive, otherwise "folder"
    "cloud_specific_details": {
        "cloud_bucket_name": cs_bucket,  # Bucket/container to upload results to
        cloud_account: access_id,  # Access ID for the storage provider
        cloud_secret: access_secret,  # Access secret for the storage provider
    },
}

data = {
    "name": "generative_ai_experiment",
    "description": "MONAI Generative AI Experiment",
    "type": "medical",
    "base_experiment": [base_exp_maisi],
    "network_arch": "monai_genai",
    "cloud_details": experiment_cloud_details,
}

endpoint = f"{base_url}/experiments"
response = requests.post(endpoint, json=data, headers=headers)
assert response.status_code == 201, f"Experiment creation failed, got {response.json()}."
res = response.json()
experiment_id = res["id"]
print("Experiment creation succeeded with experiment ID:", experiment_id)
print("--------------------------------------------------------------------------------------")
print(json.dumps(res, indent=2))
print("--------------------------------------------------------------------------------------")
```
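The only provider-specific part of `cloud_details` is the pair of key names under `cloud_specific_details`. The sketch below wraps that mapping in a hypothetical helper, `build_cloud_details`, using the key names from the comments in the Required Parameters cell (Azure: `account_name` / `access_key`; AWS: `access_key` / `secret_key`). It is an illustration, not an official client function.

```python
def build_cloud_details(cloud_type, bucket, access_id, access_secret):
    """Build the `cloud_details` payload for an experiment (hypothetical helper)."""
    # Key names for (access ID, access secret) per provider
    key_names = {
        "azure": ("account_name", "access_key"),
        "aws": ("access_key", "secret_key"),
    }
    if cloud_type not in key_names:
        raise ValueError(f"Unsupported cloud_type: {cloud_type!r} (expected 'aws' or 'azure')")
    id_key, secret_key = key_names[cloud_type]
    return {
        "cloud_type": cloud_type,
        "cloud_file_type": "folder",
        "cloud_specific_details": {
            "cloud_bucket_name": bucket,
            id_key: access_id,
            secret_key: access_secret,
        },
    }


print(build_cloud_details("aws", "my-bucket", "AKIA...", "secret")["cloud_specific_details"])
```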

## Generate Medical Images

```python
# Supported values for the `body_region` and `organ_list` fields of a generate job
supported_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"]
supported_organs = [
    "liver", "kidney", "spleen", "pancreas", "right kidney", "aorta", "inferior vena cava",
    "right adrenal gland", "left adrenal gland", "gallbladder", "esophagus", "stomach", "duodenum",
    "left kidney", "bladder", "prostate or uterus", "portal vein and splenic vein", "rectum",
    "small bowel", "lung", "bone", "brain", "lung tumor", "pancreatic tumor", "hepatic vessel",
    "hepatic tumor", "colon cancer primaries", "left lung upper lobe", "left lung lower lobe",
    "right lung upper lobe", "right lung middle lobe", "right lung lower lobe",
    "vertebrae L5", "vertebrae L4", "vertebrae L3", "vertebrae L2", "vertebrae L1",
    "vertebrae T12", "vertebrae T11", "vertebrae T10", "vertebrae T9", "vertebrae T8", "vertebrae T7",
    "vertebrae T6", "vertebrae T5", "vertebrae T4", "vertebrae T3", "vertebrae T2", "vertebrae T1",
    "vertebrae C7", "vertebrae C6", "vertebrae C5", "vertebrae C4", "vertebrae C3", "vertebrae C2",
    "vertebrae C1", "trachea", "left iliac artery", "right iliac artery", "left iliac vena",
    "right iliac vena", "colon",
    "left rib 1", "left rib 2", "left rib 3", "left rib 4", "left rib 5", "left rib 6",
    "left rib 7", "left rib 8", "left rib 9", "left rib 10", "left rib 11", "left rib 12",
    "right rib 1", "right rib 2", "right rib 3", "right rib 4", "right rib 5", "right rib 6",
    "right rib 7", "right rib 8", "right rib 9", "right rib 10", "right rib 11", "right rib 12",
    "left humerus", "right humerus", "left scapula", "right scapula", "left clavicula",
    "right clavicula", "left femur", "right femur", "left hip", "right hip", "sacrum",
    "left gluteus maximus", "right gluteus maximus", "left gluteus medius", "right gluteus medius",
    "left gluteus minimus", "right gluteus minimus", "left autochthon", "right autochthon",
    "left iliopsoas", "right iliopsoas", "left atrial appendage", "brachiocephalic trunk",
    "left brachiocephalic vein", "right brachiocephalic vein", "left common carotid artery",
    "right common carotid artery", "costal cartilages", "heart", "left kidney cyst",
    "right kidney cyst", "prostate", "pulmonary vein", "skull", "spinal cord", "sternum",
    "left subclavian artery", "right subclavian artery", "superior vena cava", "thyroid gland",
    "vertebrae S1", "bone lesion", "kidney mass", "liver tumor", "vertebrae L6", "airway",
]
```
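A typo in a body region or organ name is easy to make. The guide does not say how the service responds to unsupported names, so checking the request locally against the lists above is a cheap way to catch mistakes before submitting a job. `validate_generation_request` is a hypothetical helper; in the notebook you would pass `supported_body_region` and `supported_organs`, while the example below uses a short excerpt to stay self-contained.

```python
def validate_generation_request(body_regions, organs, supported_regions, supported_organs):
    """Return a list of problems with the requested body regions and organs (empty if valid)."""
    problems = []
    for region in body_regions:
        if region not in supported_regions:
            problems.append(f"unsupported body region: {region!r}")
    for organ in organs:
        if organ not in supported_organs:
            problems.append(f"unsupported organ: {organ!r}")
    return problems


regions = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"]
organs = ["liver", "kidney", "spleen"]  # excerpt of supported_organs, for illustration only
print(validate_generation_request(["chest"], ["liver"], regions, organs))  # []
print(validate_generation_request(["neck"], ["livers"], regions, organs))
```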

```python
data = {
    "action": "generate",
    "specs": {
        "num_output_samples": 1,                    # Number of output samples
        "body_region": ["chest"],                   # Body regions (see `supported_body_region` above)
        "organ_list": ["liver"],                    # Organs (see `supported_organs` above)
        "num_inference_steps": num_inference_steps  # Number of inference steps (usually left as None)
    },
}
endpoint = f"{base_url}/experiments/{experiment_id}/jobs"
response = requests.post(endpoint, json=data, headers=headers)

assert response.status_code == 201, f"Create job failed, got {response.json()}."
job_id = response.json()  # The response body is the new job's ID
print(f"Job creation succeeded with job ID: {job_id}.")
```

## Monitoring Job Status

```python
def wait_for_job(endpoint, headers, timeout=1800, interval=5, target_status="Done"):
    """Poll the job endpoint until it reaches `target_status` or `timeout` seconds elapse.

    Returns the last observed status.
    """
    expected = ["Pending", "Running", "Done"]
    assert target_status in expected, f"Invalid target status: {target_status}"
    # Statuses that may legitimately occur before the target status is reached
    status_before_target = expected[:expected.index(target_status)]
    start_time = time.time()
    print(f"Waiting for job to reach state {target_status} ...")
    status = None
    while True:
        response = requests.get(endpoint, headers=headers)
        response.raise_for_status()
        status_new = response.json()["status"].title()
        if status_new not in status_before_target:
            # Any status outside the expected progression (e.g. "Error") means the job failed
            assert status_new == target_status, f"Job failed with status: {status_new}"
            print(f"\nJob reached target status: {status_new}")
            return status_new
        if time.time() - start_time > timeout:
            print(f"\nJob timeout after {timeout} seconds with last status {status_new}.")
            return status_new
        # Print the status when it changes, otherwise a progress dot
        if status_new != status:
            print(f"\n{status_new}", end="", flush=True)
        else:
            print(".", end="", flush=True)
        status = status_new
        time.sleep(interval)
```
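`wait_for_job` polls at a fixed interval. For long generation jobs, a common alternative is exponential backoff, which polls often at first and less often later. The sketch below is an illustrative variant, not part of the MONAI Cloud API: `poll_until` is a hypothetical name, and `get_status`, `sleep`, and `clock` are injected so the logic can be exercised without network access.

```python
import time


def poll_until(get_status, done_statuses, timeout=1800, initial_interval=5, max_interval=60,
               sleep=time.sleep, clock=time.monotonic):
    """Call `get_status()` until it returns a status in `done_statuses` or `timeout` elapses.

    The wait between polls doubles after each call, capped at `max_interval`.
    Returns the last status observed.
    """
    start = clock()
    interval = initial_interval
    while True:
        status = get_status()
        if status in done_statuses:
            return status
        if clock() - start > timeout:
            return status
        sleep(interval)
        interval = min(interval * 2, max_interval)
```

In the notebook, `get_status` could be `lambda: requests.get(endpoint, headers=headers).json()["status"].title()` with `done_statuses={"Done", "Error"}`, both of which are statuses the cells in this guide already check for.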

```python
endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}"
response = requests.get(endpoint, headers=headers)

assert response.status_code == 200, f"Failed to get job status, got {response.json()}."
for k, v in response.json().items():
    if k == "result" and isinstance(v, dict):
        # Print the nested result fields with indentation
        print("result:")
        for k1, v1 in v.items():
            print(f"    {k1}: {v1}")
    else:
        print(f"{k}: {v}")

wait_for_job(endpoint, headers, timeout=timeout)
```

## Download the Job Log

Once the job is complete, the generated images are available in your cloud storage. You can also download the job log to examine the job's output.

```python
endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}"
response = requests.get(endpoint, headers=headers)
assert response.status_code == 200, f"Failed to get job status, got {response.json()}."
status = response.json()["status"].title()
if status in ["Running", "Done", "Error"]:
    endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}/logs"
    response = requests.get(endpoint, headers=headers)
    assert response.status_code == 200, f"Failed to get job logs, got {response.json()}."
    print(response.text)
else:
    print(f"Job status: {status}, logs are not available.")
```

## Download the Generated Medical Images

Download the generated medical images from the cloud storage to your local machine for further analysis, visualization, and integration into medical imaging applications.

```python
folder = f"shared/orgs/{ngc_org}/users/{user_id}/jobs/{job_id}/generative_ai_v{version}/output"

if cloud_type == "aws":
    cs_driver = get_driver(Provider.S3)
elif cloud_type == "azure":
    cs_driver = get_driver(Provider.AZURE_BLOBS)
else:
    raise ValueError(f"Unsupported cloud_type: {cloud_type!r}")

driver = cs_driver(access_id, access_secret, region="us-west-1")  # For S3, set region to your bucket's region
container = driver.get_container(container_name=cs_bucket)

# Download every object under the job's output folder into the current directory
file_objects = driver.list_container_objects(container=container, ex_prefix=folder)
for obj in file_objects:
    print(f"Downloading object: {obj.name}")
    obj.download(os.path.basename(obj.name), overwrite_existing=True)
```
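The download loop saves each object under its base name, which flattens the output folder: if two objects in different subfolders share a file name, the later download overwrites the earlier one. The sketch below maps object names to local paths that keep the subfolders below the prefix. `local_path_for` is a hypothetical helper, and the object names in the example are illustrative, not actual MAISI output names.

```python
import os


def local_path_for(object_name, prefix, dest_dir="."):
    """Map a cloud object name to a local path, keeping subfolders below `prefix`.

    Object names are assumed to be trusted (no ".." components are filtered out).
    """
    prefix = prefix.rstrip("/") + "/"  # match whole folder names only
    relative = object_name[len(prefix):] if object_name.startswith(prefix) else object_name
    # Strip leading slashes so os.path.join does not discard dest_dir
    return os.path.join(dest_dir, relative.lstrip("/"))


print(local_path_for("shared/out/sample_0/ct_image.nii.gz", "shared/out", "downloads"))
```

With libcloud you would create the parent folder with `os.makedirs(os.path.dirname(path), exist_ok=True)` and then call `obj.download(path, overwrite_existing=True)`. Note that the visualization cell below looks for files in the current directory, so its file search would need to match the new layout.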

## Visualize the Generated Medical Images

```python
# Find the downloaded image and label files in the current directory
image_file = sorted(f for f in os.listdir() if f.endswith("_image.nii.gz"))[0]
label_file = sorted(f for f in os.listdir() if f.endswith("_label.nii.gz"))[0]

# Load each volume once instead of on every slice
volumes = [nib.load(f).get_fdata() for f in (image_file, label_file)]

# Nine slices along the second axis (192 to 320 in steps of 16), shown as a 3x3 grid
# for the image (left three columns) and a 3x3 grid for the label (right three columns)
slice_indices = list(range(192, 321, 16))
fig, axes = plt.subplots(nrows=3, ncols=6, figsize=(10, 5))

for idx, slice_index in enumerate(slice_indices):
    for i, volume in enumerate(volumes):
        ax = axes[idx // 3, idx % 3 + i * 3]
        ax.imshow(volume[:, slice_index, :], cmap="gray" if i == 0 else "viridis")
        ax.axis("off")  # Hide the axes ticks
axes[0, 1].set_title("Generated Image Slices")
axes[0, 4].set_title("Generated Label Slices")
plt.tight_layout()
plt.show()
```
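The hard-coded slice indices assume the generated volume has at least 321 voxels along its second axis; a smaller volume would raise an `IndexError`. One option is to derive the indices from the volume's shape. `evenly_spaced_slices` is a hypothetical helper, and the 25% margin is an arbitrary illustrative default that keeps the slices near the middle of the volume.

```python
def evenly_spaced_slices(axis_length, count=9, margin=0.25):
    """Pick `count` evenly spaced slice indices, skipping `margin` of the axis at each end."""
    if count < 1:
        raise ValueError("count must be at least 1")
    if not 0 <= margin < 0.5:
        raise ValueError("margin must be in [0, 0.5)")
    start = int(axis_length * margin)
    stop = int(axis_length * (1 - margin)) - 1
    if count == 1:
        return [(start + stop) // 2]
    step = (stop - start) / (count - 1)
    return [round(start + i * step) for i in range(count)]


print(evenly_spaced_slices(512))
```

With the defaults, a 512-voxel axis gives nine indices from 128 to 383. In the notebook you could use `slice_indices = evenly_spaced_slices(nib.load(image_file).shape[1])` before plotting.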

## Cleaning Up

Once all jobs are finished, delete the experiment. If the job has not finished yet, it is canceled first.

```python
endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}"
response = requests.get(endpoint, headers=headers)
# If the job is not done, cancel it first (status normalized with .title(), as in the cells above)
if response.json()["status"].title() != "Done":
    endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}:cancel"
    response = requests.post(endpoint, headers=headers)
    assert response.status_code == 200, f"Cancel job failed, got {response.json()}."
    print(response)

endpoint = f"{base_url}/experiments/{experiment_id}"
response = requests.delete(endpoint, headers=headers)
assert response.status_code == 200, f"Delete experiment failed, got {response.json()}."
print(response)
```

## Conclusion

In this guide, we explored the potential of generative AI in medical imaging. Using a pretrained MAISI base experiment on NVIDIA DGX Cloud, we generated a synthetic CT image together with its label map. This has broad implications for medical research and training, because it makes it possible to generate large datasets while reducing the need for patient involvement and the associated privacy concerns.

However, it's important to note that while the results are promising, the technology is not without its limitations and ethical considerations. The quality of the generated images is highly dependent on the quality and diversity of the training data. Additionally, care must be taken to ensure that the synthetic images do not misrepresent or oversimplify complex medical conditions.

In conclusion, Generative AI holds great promise in the field of medical imaging, offering a powerful tool for research, training, and potentially even diagnosis and treatment planning. However, as with any powerful tool, it must be used responsibly and ethically.

