# Training a MONAI Segmentation Bundle

This tutorial shows how to train a typical MONAI segmentation bundle on NVIDIA DGX Cloud for medical imaging applications. We use the MONAI VISTA-3D bundle as the example.

## Table of Contents

1. [Login with NGC Key](#Login-with-NGC-Key)
2. [Datasets Creation](#Datasets-Creation)
3. [Experiment Creation](#Experiment-Creation)
4. [Monitoring Job Status](#Monitoring-Job-Status)
5. [Bundle Download](#Bundle-Download)
6. [Clean Up](#Clean-Up)

<a id='Setup'></a>

## Setup

In [None]:
import requests
import json
import time

In [None]:
# Provide the following parameters before running this notebook.
host_url = "<monai service API address>"
ngc_api_key = "<your NGC API key>"
# Object storage info
access_id = "<user name for the object storage>"
access_secret = "<secret for the object storage>"
train_manifest_url = "<train manifest url>"
val_manifest_url = "<validation manifest url>"
# Training parameters
train_epochs = 10


## Login with NGC Key

In [None]:
# Exchange the NGC API key for a JWT
data = json.dumps({"ngc_api_key": ngc_api_key})
response = requests.post(f"{host_url}/api/v1/login", data=data)
print(response.status_code)
assert response.status_code == 201, f"Login failed, got status code: {response.status_code}."
res = response.json()
assert "user_id" in res, "user_id is not in response."
user_id = res["user_id"]
print("User ID:", user_id)
assert "token" in res, "token is not in response."
token = res["token"]
print("JWT:", token)

# Set the base URL for all user-scoped API calls
base_url = f"{host_url}/api/v1/users/{user_id}"
print("API calls will be forwarded to", base_url)

headers = {"Authorization": f"Bearer {token}"}


## Datasets Creation

### **1. Remote Object Storage as a Data Source**

The MONAI Cloud platform supports several cloud object storage services, including Azure Blob Storage and Amazon S3, so you can choose the one that best fits your project. The example below uses Azure Blob Storage.

**Steps:**
1. [Creating a Storage Account and Container](https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-introduction)
   - **Storage Account**: Start by creating a new storage account in your Azure portal. This account will host your blob storage containers.
   - **Container Creation**: Within your storage account, create a new container. This container will hold your datasets.

2. [Container URL](https://learn.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata)
   - Each container has a unique URL of the form `https://<storage-account-name>.blob.core.windows.net/<container-name>`. You will use this URL to access your data.

#### Obtaining Credentials

- **Access Keys**: Access your storage account and navigate to the [Access keys](https://learn.microsoft.com/en-us/azure/storage/common/storage-account-keys-manage?tabs=azure-portal) section. Here, you will find the necessary credentials to access your Blob Storage programmatically.
- **Shared Access Signature (SAS)**: Alternatively, you can create a SAS for more granular control over permissions and access duration.

#### Creating a Manifest JSON File

In the root directory of each dataset (for example, the root of your Azure container), create a `manifest.json` file that lists the dataset's image and label files. The file format is as follows:

```json
{
    "root_path": "https://[your-storage-account-name].blob.core.windows.net/[your-container-name]",
    "data": [
        {
            "image": {
                "path": ["path/to/your/image_1"],
                "id": "unique-uuid-1"
            },
            "label": {
                "path": ["path/to/your/label_1"],
                "id": "unique-uuid-2"
            }
        }
    ]
}
```

Add one object like the one above to the `data` list for each image/label pair.

- Each dataset (training, testing, etc.) should have its own root directory.
- All of a dataset's files should be under its root directory.
- The root directory should contain a `manifest.json` file.
- The `manifest.json` file should contain a "data" field, which is a list of all the data entries.
- Each data entry should contain "image" and "label" fields.
- Each "image"/"label" field should contain a "path" field, which is a list of relative paths to the image/label files.
- Each "image"/"label" field should also contain an "id" field; if you do not have one, use a random UUID.
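It can help to check a manifest against these rules before uploading it. The sketch below uses two hypothetical helpers that are not part of the MONAI service API: `build_manifest` assembles a manifest from (image, label) relative paths and assigns each one a random UUID with `uuid.uuid4()`, and `validate_manifest` reports missing fields and non-relative paths. The root path and file names in the usage lines are placeholders.

```python
import json
import uuid


def build_manifest(root_path, pairs):
    """Build a manifest dict from (image_path, label_path) pairs.

    Paths should be relative; every image and label gets a random UUID as its "id".
    """
    data = []
    for image_path, label_path in pairs:
        data.append(
            {
                "image": {"path": [image_path], "id": str(uuid.uuid4())},
                "label": {"path": [label_path], "id": str(uuid.uuid4())},
            }
        )
    return {"root_path": root_path, "data": data}


def validate_manifest(manifest):
    """Check a manifest dict for the required fields; return a list of problems (empty if none)."""
    if not isinstance(manifest.get("data"), list):
        return ['missing "data" list']
    problems = []
    for i, entry in enumerate(manifest["data"]):
        for field in ("image", "label"):
            item = entry.get(field)
            if not isinstance(item, dict):
                problems.append(f'entry {i}: missing "{field}"')
                continue
            paths = item.get("path")
            if not isinstance(paths, list) or not paths:
                problems.append(f'entry {i}: "{field}" needs a non-empty "path" list')
            elif any(p.startswith(("/", "http://", "https://")) for p in paths):
                problems.append(f'entry {i}: "{field}" paths must be relative')
            if not item.get("id"):
                problems.append(f'entry {i}: "{field}" is missing "id"')
    return problems


# Placeholder root path and file names
manifest = build_manifest(
    "https://mystorageaccount.blob.core.windows.net/monai-train",
    [("images/case_001.nii.gz", "labels/case_001.nii.gz")],
)
assert validate_manifest(manifest) == [], validate_manifest(manifest)
print(json.dumps(manifest, indent=4))
```

Save the printed JSON as `manifest.json` in the dataset's root directory.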


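After preparing your datasets, set the following variables in [Setup](#Setup), where `train_manifest_url` and `val_manifest_url` are the URLs of the training and validation `manifest.json` files: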

```python
access_id = ...
access_secret = ...
train_manifest_url = ...
val_manifest_url = ...
```
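If you would rather not paste credentials into the notebook, one option is to read these values from environment variables. The sketch below assumes hypothetical variable names (`MONAI_ACCESS_ID`, `MONAI_ACCESS_SECRET`, `MONAI_TRAIN_MANIFEST_URL`, `MONAI_VAL_MANIFEST_URL`); any names work as long as they are exported in the shell that launches Jupyter.

```python
import os

# Hypothetical environment variable names; rename them to suit your setup
ENV_VARS = {
    "access_id": "MONAI_ACCESS_ID",
    "access_secret": "MONAI_ACCESS_SECRET",
    "train_manifest_url": "MONAI_TRAIN_MANIFEST_URL",
    "val_manifest_url": "MONAI_VAL_MANIFEST_URL",
}


def load_storage_settings(env=None):
    """Read the object storage settings from environment variables.

    Raises KeyError naming every variable that is missing or empty.
    """
    env = os.environ if env is None else env
    missing = [var for var in ENV_VARS.values() if not env.get(var)]
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    return {name: env[var] for name, var in ENV_VARS.items()}
```

In the Setup cell you could then write `settings = load_storage_settings()` and assign `access_id = settings["access_id"]`, and so on, instead of hard-coding the values.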

### **2. Create the Training and Validation Datasets**

In [None]:
# Training dataset
data = {
    "name": "MONAI_seg_train",
    "description":"Object storage dataset for training",
    "type": "semantic_segmentation",
    "format": "monai",
    "client_url": train_manifest_url,
    "client_id": access_id,
    "client_secret": access_secret,
}
data = json.dumps(data)

endpoint = f"{base_url}/datasets"
response = requests.post(endpoint, data=data, headers=headers)

assert response.status_code == 201, f"Create train dataset failed, got {response.json()}."
res = response.json()
train_dataset_id = res["id"]
print("Train dataset creation succeeded with dataset ID:", train_dataset_id)
print("---------------------------------\n")
print(json.dumps(res, indent=2))

# Validation dataset
data = {
    "name": "MONAI_seg_val",
    "description":"Object storage dataset for validation",
    "type": "semantic_segmentation",
    "format": "monai",
    "client_url": val_manifest_url,
    "client_id": access_id,
    "client_secret": access_secret,
}
data = json.dumps(data)

endpoint = f"{base_url}/datasets"
response = requests.post(endpoint, data=data, headers=headers)

assert response.status_code == 201, f"Create val dataset failed, got {response.json()}."
res = response.json()
val_dataset_id = res["id"]
print("Validation dataset creation succeeded with dataset ID:", val_dataset_id)
print("---------------------------------\n")
print(json.dumps(res, indent=2))

## Experiment Creation

Create an experiment based on a MONAI segmentation bundle. In this notebook, we use the VISTA-3D bundle (`monai_vista3d`).

### **1. List Available Base Experiments**

In [None]:
endpoint = f"{base_url}/experiments"
response = requests.get(endpoint, headers=headers)
assert response.status_code == 200, f"List Base Experiments failed, got {response.json()}."
res = response.json()

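def find_base_experiment(experiments, network_arch):
    """Return the ID of the root base experiment (one with no base experiment of its own) for a network architecture."""
    matches = [p for p in experiments if p["network_arch"] == network_arch and not p["base_experiment"]]
    assert matches, f"No base experiment found for {network_arch}."
    return matches[0]["id"]

# VISTA-3D
ptm_vista = find_base_experiment(res, "monai_vista3d")
print(f"Base Experiment ID for VISTA Experiment: {ptm_vista}")

# DeepEdit (listed for reference; not used further in this notebook)
ptm_annotation = find_base_experiment(res, "monai_annotation")
print(f"Base Experiment ID for DeepEdit(Annotation) Experiment: {ptm_annotation}")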


### **2. Create Experiment**

In [None]:
data = {
  "name": "my_vista",
  "description": "based on vista",
  "network_arch": "monai_vista3d",
  "type": "medical",
  "base_experiment": [ ptm_vista ],
  "eval_dataset": val_dataset_id,
  "train_datasets": [ train_dataset_id ],
}

endpoint = f"{base_url}/experiments"
response = requests.post(endpoint, json=data, headers=headers)
assert response.status_code == 201, f"Create experiment failed, got {response.json()}."
res = response.json()
experiment_id = res["id"]
model_network = res["network_arch"]
print("Experiment creation succeeded with experiment ID: ", experiment_id)
print("---------------------------------\n")
print(json.dumps(res, indent=2))


### **3. Run a DGX Train Job**

In [None]:
train_spec = {
    "train#trainer#max_epochs": train_epochs,
    "val_interval": 1,
}

data = {"action": "train", "specs": train_spec}
endpoint = f"{base_url}/experiments/{experiment_id}/jobs"
response = requests.post(endpoint, json=data, headers=headers)

assert response.status_code == 201, f"Run DGX train job failed, got {response.json()}."
job_id = response.json()
print("Job creation succeeded with job ID:", job_id)
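The keys in `train_spec`, such as `train#trainer#max_epochs`, appear to follow the MONAI bundle config convention in which `#` separates the levels of a nested config ID (here, `max_epochs` of the `trainer` under `train`). If you prefer writing overrides as a nested dictionary, the hypothetical helper below flattens it into `#`-joined keys. It is a convenience sketch, not part of the MONAI service API, and because it treats every nested dict as a path segment it cannot express an override whose value is itself a dict.

```python
def flatten_spec(nested, prefix=""):
    """Flatten a nested dict of overrides into '#'-separated config keys.

    Assumption: every nested dict denotes a config path, so an override whose
    value is itself a dict cannot be expressed with this helper.
    """
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}#{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_spec(value, full_key))
        else:
            flat[full_key] = value
    return flat


# Same overrides as train_spec above, with max_epochs set to 10
print(flatten_spec({"train": {"trainer": {"max_epochs": 10}}, "val_interval": 1}))
# {'train#trainer#max_epochs': 10, 'val_interval': 1}
```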


## Monitoring Job Status

In [None]:
# Helper function: poll a job until it is no longer Pending or Running
def wait_for_job(endpoint, headers, timeout):
    start_time = time.time()
    response = requests.get(endpoint, headers=headers)
    assert response.status_code == 200, f"Failed to get job status, got {response.json()}."
    status = response.json()["status"].title()
    print("Waiting for job to complete...")
    print(status, end="", flush=True)
    while True:
        if status not in ["Pending", "Running"]:
            assert status == "Done", f"Job failed with status: {status}"
            break
        time.sleep(5)
        response = requests.get(endpoint, headers=headers)
        assert response.status_code == 200, f"Failed to get job status, got {response.json()}."
        status_new = response.json()["status"].title()
        if status_new != status:
            status = status_new
            print(f"\n{status}", end="", flush=True)
        else:
            print(".", end="", flush=True)
        if time.time() - start_time > timeout:
            raise TimeoutError(f"Job timeout after {timeout} seconds.")
    print("\nJob completed successfully!")

# Print the job details while the job is running
endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}"
response = requests.get(endpoint, headers=headers)

assert response.status_code == 200, f"Failed to get job status, got {response.json()}."
for k, v in response.json().items():
    if k == "result" and isinstance(v, dict):
        print("result:")
        for k1, v1 in v.items():
            print(f"    {k1}: {v1}")
    else:
        print(f"{k}: {v}")

print("------------------------------------------------------------------------")
wait_for_job(endpoint, headers, 600)

## Bundle Download

Download the trained bundle from the job.

In [None]:
endpoint = f"{base_url}/experiments/{experiment_id}/jobs/{job_id}:download"
response = requests.get(endpoint, data=json.dumps({"export_type": "monai_bundle"}), headers=headers)
assert response.status_code == 200, f"Failed to download bundle, got {response.json()}."
with open(f"{job_id}.tar.gz", "wb") as fp:
    fp.write(response.content)
print("Downloaded!")
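The cell above saves the bundle as `<job_id>.tar.gz` in the working directory. A minimal sketch for unpacking it with Python's standard `tarfile` module follows; the output directory name is an arbitrary choice, and the layout of the extracted files depends on what the service packs into the archive. Because the archive comes from a remote service, the sketch refuses members whose paths would land outside the target directory and, on Python versions that provide extraction filters, also applies the `"data"` filter.

```python
import os
import tarfile


def extract_bundle(archive_path, out_dir):
    """Extract a .tar.gz archive into out_dir and return the extracted member names."""
    out_dir = os.path.abspath(out_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            target = os.path.abspath(os.path.join(out_dir, member.name))
            # Refuse names such as "../x" or absolute paths that would land outside out_dir
            if os.path.commonpath([out_dir, target]) != out_dir:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):
            # Python versions with extraction filters: also reject unsafe links and special files
            tar.extractall(out_dir, members=members, filter="data")
        else:
            tar.extractall(out_dir, members=members)
    return [m.name for m in members]


# Usage in this notebook (the directory name is arbitrary):
# names = extract_bundle(f"{job_id}.tar.gz", f"bundle_{job_id}")
# print("\n".join(names))
```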

## Clean Up

Delete the experiment after all jobs are done.

In [None]:
endpoint = f"{base_url}/experiments/{experiment_id}"
response = requests.delete(endpoint, headers=headers)
assert response.status_code == 200, f"Delete experiment failed, got {response.json()}."
print(response)

Delete datasets after the experiment is done.

In [None]:
# train dataset
endpoint = f"{base_url}/datasets/{train_dataset_id}"
response = requests.delete(endpoint, headers=headers)
assert response.status_code == 200, f"Delete train dataset failed, got {response.json()}."
print(response)

# validation dataset
endpoint = f"{base_url}/datasets/{val_dataset_id}"
response = requests.delete(endpoint, headers=headers)
assert response.status_code == 200, f"Delete val dataset failed, got {response.json()}."
print(response)