### Notebook to demonstrate Image Classification workflow

Transfer learning is the process of transferring learned features from one application to another. It is a commonly used training technique in which a model trained on one task is re-trained for a different task. The Train Adapt Optimize (TAO) Toolkit is a simple, easy-to-use, Python-based AI toolkit for taking purpose-built AI models and customizing them with your own data.

![image](https://d29g4g2dyqv443.cloudfront.net/sites/default/files/akamai/TAO/tlt-tao-toolkit-bring-your-own-model-diagram.png)

### Sample prediction for an Image Classification model
<img align="center" src="../example_images/sample_image_classification.jpg">

### The workflow in a nutshell

- Pulling datasets from the cloud
- Getting a pretrained model (PTM) from NGC
- Model Actions
    - Train (Normal/AutoML)
    - Evaluate
    - Prune, retrain
    - Export
    - TAO-Deploy
    - Inference on TAO, TRT
    - Delete experiments/dataset

### Table of contents

1. [FIXMEs](#head-1)
1. [Login](#head-2)
1. [Set dataset formats](#head-3)
1. [Create and pull train dataset](#head-4)
1. [Create and pull val dataset](#head-5)
1. [Create and pull test dataset](#head-6)
1. [List the created datasets](#head-7)
1. [Create an experiment](#head-8)
1. [List experiments](#head-9)
1. [Assign train, eval datasets](#head-10)
1. [Assign PTM](#head-11)
1. [Actions](#head-13)
1. [Train](#head-14)
1. [Delete experiment](#head-21)
1. [Delete dataset](#head-22)

### Requirements
Please find the server requirements [here](https://docs.nvidia.com/tao/tao-toolkit/text/tao_toolkit_api/api_setup.html#)

In [None]:
import json
import os
import requests
import time
from IPython.display import clear_output
# remove_corrupted_images is imported inside the optional cells that use it,
# so the notebook runs even when that helper module is not available

### To see the dataset folder structure required by the models supported in this notebook, see the notebooks under dataset_prepare, for example [this notebook](../dataset_prepare/classification.ipynb)

### FIXMEs <a class="anchor" id="head-1"></a>

1. Assign a model_name in FIXME 1
1. Assign the functionID of the helm chart function in FIXME 2
1. Assign the versionID of the helm chart function in FIXME 3
1. Assign the ngc_key variable in FIXME 4
1. Assign the ngc_org_name variable in FIXME 5
1. Set cloud storage details in FIXME 6
1. Assign path of datasets relative to the bucket in FIXME 7

#### Choose a classification model

In [None]:
# Choose the model to train (model_name)
# Available models (#FIXME 1):
# 1. classification_pyt - https://docs.nvidia.com/tao/tao-toolkit/text/image_classification.html
# 2. classification_tf2 - https://docs.nvidia.com/tao/tao-toolkit/text/image_classification_tf2.html

model_name = "classification_pyt" # FIXME1 (Add the model name from the above mentioned list)

#### Set API service's host information

In [None]:
functionId = "9c252c9c-6559-4b16-b464-cbc87fc4ab7a" # FIXME2
version_id = "4d0faf19-1443-42b2-8abd-40ab8297ef8a" # FIXME3

#### Set NGC Personal key for authentication and NGC org to access API services

In [None]:
ngc_key = "" # FIXME4 (add your NGC Personal key)

In [None]:
ngc_org_name = "ea-tlt" # FIXME5 your NGC ORG

In [None]:
# Invoke the NVCF helm chart deployment with the given request body
def invoke_function(request_body):
    url = f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/{functionId}/versions/{version_id}"

    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        "Authorization": f"Bearer {ngc_key}",
    }

    response = requests.post(url, headers=headers, json=request_body)

    if not response.ok:
        print("Request failed.")
        print("Response status code:", response.status_code)
        print("Response text:", response.text)
    # Always return the response so callers can inspect status codes (e.g. 404 while a spec is downloading)
    return response
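Every cell below hand-builds a `super_data` dictionary for `invoke_function`. A small builder can cut that repetition. This is a hypothetical helper, not part of the TAO API: it only reproduces the keys used in this notebook and JSON-encodes `request_body` the way the create, partial_update, and job_run cells do (the list-experiments cells pass a plain dict instead, so check which form the service expects for those calls).

```python
import json


def build_request(api_endpoint, ngc_key, kind=None, handler_id=None,
                  request_body=None, **extra):
    """Build a super_data payload for invoke_function (hypothetical helper)."""
    payload = {"api_endpoint": api_endpoint, "ngc_key": ngc_key}
    if kind is not None:
        payload["kind"] = kind
    if handler_id is not None:
        payload["handler_id"] = handler_id
    if request_body is not None:
        # The create/partial_update/job_run cells send request_body as a JSON string
        payload["request_body"] = json.dumps(request_body)
    # Pass-through keys such as action, is_job, job_id, is_base_experiment
    payload.update(extra)
    return payload
```

For example, the dataset progress check could send `build_request("retrieve", ngc_key, kind="datasets", handler_id=train_dataset_id)`.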


### Login <a class="anchor" id="head-2"></a>

In [None]:
# Validate NGC_PERSONAL_KEY
login_metadata = {"ngc_org_name": ngc_org_name,
                   "ngc_key": ngc_key}
super_data = {
    "api_endpoint": "login",
    "request_body": json.dumps(login_metadata),
    "ngc_key": ngc_key
}
response = invoke_function(super_data)
print(response)
print(response.json())
assert response.status_code in (200, 201)
assert "token" in response.json().keys()
token = response.json()["token"]
print("JWT",token)

### Get NVCF GPU details <a class="anchor" id="head-2"></a>

Use one of the keys of the response JSON as the `platform_id` when you run each job.

In [None]:
# # Valid only for the NVCF backend during TAO-API helm deployment (currently)
# # Note: base_url and headers are not defined in this notebook; set them before uncommenting
# endpoint = f"{base_url}:gpu_types"
# response = requests.get(endpoint, headers=headers)

# assert response.ok
# print(response)
# print(json.dumps(response.json(), indent=4))
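A sketch of choosing a `platform_id` from the gpu_types response. It assumes the response is a dictionary keyed by platform ID whose values are dictionaries carrying the GPU type under a field named `gpu_type`; both the response shape and the field name are assumptions, so adjust them to the actual output printed above.

```python
def pick_platform_id(gpu_types, gpu_type, field="gpu_type"):
    """Return the first platform_id whose metadata has field == gpu_type.

    Assumption: gpu_types is {platform_id: {field: ..., ...}, ...}.
    """
    for platform_id, details in gpu_types.items():
        if isinstance(details, dict) and details.get(field) == gpu_type:
            return platform_id
    raise KeyError(f"No platform with {field}={gpu_type!r}")
```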

### Create cloud workspace
This workspace is where your datasets reside and where the results of TAO API jobs are pushed.

If you want separate workspaces for datasets and experiments, duplicate the workspace creation cells and adjust the metadata accordingly.

In [None]:
# FIXME 6: Dataset cloud bucket details used to download datasets for experiments (can be read-only)
cloud_metadata = {
    "name": "AWS workspace info",  # A Representative name for this cloud info
    "cloud_type": "aws",  # If it's AWS, HuggingFace or Azure
    "cloud_specific_details": {
        "cloud_region": "us-west-1",
        "cloud_bucket_name": "",  # FIXME 6
        "access_key": "", # FIXME 6
        "secret_key": "", # FIXME 6
    }
}
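Before creating the workspace, it can help to fail fast if FIXME 6 was left blank. The hypothetical `missing_cloud_fields` helper below checks the three AWS fields used above; other cloud types may need different keys.

```python
def missing_cloud_fields(cloud_metadata,
                         required=("cloud_bucket_name", "access_key", "secret_key")):
    """Return the required cloud_specific_details fields that are empty or absent."""
    details = cloud_metadata.get("cloud_specific_details", {})
    return [key for key in required if not details.get(key)]
```

Usage: `assert not missing_cloud_fields(cloud_metadata), "Fill in FIXME 6 first"`.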

In [None]:
super_data = {
    "api_endpoint": "create",
    "kind": "workspaces",
    "request_body": json.dumps(cloud_metadata), 
    "ngc_key": ngc_key
}

response = invoke_function(super_data)
print(response)
print(json.dumps(response.json(), indent=4))

assert response.status_code in (200, 201)
assert "id" in response.json().keys()
workspace_id = response.json()["id"]

#### Set dataset path (path within cloud bucket)

In [None]:
# FIXME 7: Set paths relative to the cloud bucket
train_dataset_path = "/data/classification_train"
eval_dataset_path = "/data/classification_val"
test_dataset_path = "/data/classification_test"

### Set dataset formats <a class="anchor" id="head-3"></a>

In [None]:
# Dataset type and format used by the train and val datasets
ds_type = "image_classification"
ds_format = model_name

### Create and pull train dataset <a class="anchor" id="head-4"></a>

In [None]:
train_dataset_metadata = {"name":"Train dataset",
                          "description":"My train dataset",
                          "type":ds_type,
                          "format":ds_format,
                          "workspace":workspace_id,
                          "cloud_file_path": train_dataset_path,
                          "use_for": ["training"]
                         }
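The train, val, and test dataset cells build the same request body with different values. A hypothetical builder that produces exactly the same keys keeps the three cells consistent:

```python
def make_dataset_metadata(name, description, ds_type, ds_format,
                          workspace_id, cloud_file_path, use_for):
    """Build a dataset-creation request body with the keys used in this notebook."""
    return {
        "name": name,
        "description": description,
        "type": ds_type,
        "format": ds_format,
        "workspace": workspace_id,
        "cloud_file_path": cloud_file_path,
        "use_for": list(use_for),  # e.g. ["training"], ["evaluation"], ["testing"]
    }
```

For example: `make_dataset_metadata("Val dataset", "My val dataset", ds_type, ds_format, workspace_id, eval_dataset_path, ["evaluation"])`.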
super_data = {
    "api_endpoint": "create",
    "kind": "datasets",
    "request_body": json.dumps(train_dataset_metadata),
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
print(response)
print(json.dumps(response.json(), indent=4))

assert response.status_code in (200, 201)
assert "id" in response.json().keys()
train_dataset_id = response.json()["id"]

In [None]:
# Check progress

while True:
    clear_output(wait=True)
    super_data = {
        "api_endpoint": "retrieve",
        "kind": "datasets",
        "handler_id": train_dataset_id,
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)
    assert response.status_code in (200, 201)

    print(json.dumps(response.json(), indent=4))
    if response.json().get("status") == "invalid_pull":
        raise ValueError("Dataset pull failed")
    if response.json().get("status") == "pull_complete":
        break
    time.sleep(5)
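The progress loops in this notebook poll forever if the pull never reaches `pull_complete` or `invalid_pull`. A sketch of the same polling with a timeout follows. `wait_for_status` is a hypothetical helper, and `fetch_status` is any callable that returns the current status string (for example, one that calls `invoke_function` with the retrieve payload and returns `response.json().get("status")`). The `sleep` and `clock` parameters exist only so the loop can be tested without waiting.

```python
import time


def wait_for_status(fetch_status, done=("pull_complete",), failed=("invalid_pull",),
                    interval=5.0, timeout=3600.0, sleep=time.sleep, clock=time.monotonic):
    """Poll fetch_status() until it reports a done or failed status, or time runs out."""
    deadline = clock() + timeout
    while True:
        status = fetch_status()
        if status in failed:
            raise ValueError(f"Dataset pull failed with status {status!r}")
        if status in done:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"Status still {status!r} after {timeout} seconds")
        sleep(interval)
```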

#### Uncomment if you want to remove corrupted images in your dataset

In [None]:
# # Creates a data-services experiment and runs the job that removes corrupted images
# try:
#     from remove_corrupted_images import remove_corrupted_images_workflow
#     train_dataset_id = remove_corrupted_images_workflow(base_url, headers, workspace_id, train_dataset_id)
# except Exception as e:
#     raise e

### Create and pull val dataset <a class="anchor" id="head-5"></a>

In [None]:
val_dataset_metadata = {"name":"Val dataset",
                        "description":"My val dataset",
                        "type":ds_type,
                        "format":ds_format,
                        "workspace":workspace_id,
                        "cloud_file_path": eval_dataset_path,
                        "use_for": ["evaluation"]
                   }
super_data = {
    "api_endpoint": "create",
    "kind": "datasets",
    "request_body": json.dumps(val_dataset_metadata),
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
print(response)
print(json.dumps(response.json(), indent=4))

assert response.status_code in (200, 201)
assert "id" in response.json().keys()
eval_dataset_id = response.json()["id"]

In [None]:
# Check progress

while True:
    clear_output(wait=True)
    super_data = {
        "api_endpoint": "retrieve",
        "kind": "datasets",
        "handler_id": eval_dataset_id,
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)
    assert response.status_code in (200, 201)

    print(json.dumps(response.json(), indent=4))
    if response.json().get("status") == "invalid_pull":
        raise ValueError("Dataset pull failed")
    if response.json().get("status") == "pull_complete":
        break
    time.sleep(5)

#### Uncomment if you want to remove corrupted images in your dataset

In [None]:
# # Creates a data-services experiment and runs the job that removes corrupted images
# try:
#     from remove_corrupted_images import remove_corrupted_images_workflow
#     eval_dataset_id = remove_corrupted_images_workflow(base_url, headers, workspace_id, eval_dataset_id)
# except Exception as e:
#     raise e

### Create and pull test dataset <a class="anchor" id="head-6"></a>

In [None]:
# Create the test dataset used for inference
ds_type = "image_classification"
ds_format = "raw"
test_dataset_metadata = {"name":"Test dataset",
                        "description":"My test dataset",
                        "type":ds_type,
                        "format":ds_format,
                        "workspace":workspace_id,
                        "cloud_file_path": test_dataset_path,
                        "use_for": ["testing"]
                        }
super_data = {
    "api_endpoint": "create",
    "kind": "datasets",
    "request_body": json.dumps(test_dataset_metadata),
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
print(response)
print(json.dumps(response.json(), indent=4))

assert response.status_code in (200, 201)
assert "id" in response.json().keys()
test_dataset_id = response.json()["id"]

In [None]:
# Check progress

while True:
    clear_output(wait=True)
    super_data = {
        "api_endpoint": "retrieve",
        "kind": "datasets",
        "handler_id": test_dataset_id,
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)
    assert response.status_code in (200, 201)

    print(json.dumps(response.json(), indent=4))
    if response.json().get("status") == "invalid_pull":
        raise ValueError("Dataset pull failed")
    if response.json().get("status") == "pull_complete":
        break
    time.sleep(5)

#### Uncomment if you want to remove corrupted images in your dataset

In [None]:
# # Creates a data-services experiment and runs the job that removes corrupted images
# try:
#     from remove_corrupted_images import remove_corrupted_images_workflow
#     test_dataset_id = remove_corrupted_images_workflow(base_url, headers, workspace_id, test_dataset_id)
# except Exception as e:
#     raise e

### List the created datasets <a class="anchor" id="head-7"></a>

In [None]:
super_data = {
    "api_endpoint": "retrieve",
    "kind": "datasets",
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
print(response)
assert response.status_code in (200, 201)

# print(json.dumps(response.json(), indent=4)) ## Uncomment for verbose list output
print("id\t\t\t\t\t type\t\t\t format\t\t name")
for rsp in response.json()["datasets"]:
    rsp_keys = rsp.keys()
    assert "id" in rsp_keys
    assert "type" in rsp_keys
    assert "format" in rsp_keys
    assert "name" in rsp_keys
    print(rsp["id"],"\t",rsp["type"],"\t",rsp["format"],"\t\t",rsp["name"])
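The tab-separated listing above misaligns when IDs or names vary in length. A sketch that computes column widths from the data instead (hypothetical helper, using the same four keys the cell asserts):

```python
def format_dataset_table(datasets):
    """Render id/type/format/name columns padded to the widest value in each column."""
    columns = ("id", "type", "format", "name")
    rows = [[str(ds[col]) for col in columns] for ds in datasets]
    # Width of each column: its header or its longest cell, whichever is wider
    widths = [max([len(col)] + [len(row[i]) for row in rows]) for i, col in enumerate(columns)]
    lines = ["  ".join(col.ljust(w) for col, w in zip(columns, widths))]
    lines += ["  ".join(cell.ljust(w) for cell, w in zip(row, widths)) for row in rows]
    return "\n".join(line.rstrip() for line in lines)
```

Usage: `print(format_dataset_table(response.json()["datasets"]))`.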

### Create an experiment <a class="anchor" id="head-8"></a>

In [None]:
encode_key = "nvidia_tlt"
checkpoint_choose_method = "best_model"

experiment_metadata = {"network_arch":model_name,
                       "encryption_key":encode_key,
                       "checkpoint_choose_method":checkpoint_choose_method,
                       "workspace": workspace_id}
super_data = {
    "api_endpoint": "create",
    "kind": "experiments",
    "request_body": json.dumps(experiment_metadata),
    "ngc_key": ngc_key,    
}
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))
assert "id" in response.json().keys()
experiment_id = response.json()["id"]

### List experiments <a class="anchor" id="head-9"></a>

In [None]:
params = {"network_arch": model_name}
super_data = {
    "api_endpoint": "retrieve",
    "kind": "experiments",
    "request_body": params,
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)

print(response)
assert response.status_code in (200, 201)
# print(json.dumps(response.json(), indent=4)) ## Uncomment for verbose list output
print("name\t\t\t id\t\t\t\t\t network architecture")
for rsp in response.json()["experiments"]:
    rsp_keys = rsp.keys()
    assert "name" in rsp_keys and "id" in rsp_keys and "network_arch" in rsp_keys
    print(rsp["name"], rsp["id"], rsp["network_arch"])

### Assign train, eval datasets <a class="anchor" id="head-10"></a>

In [None]:
dataset_information = {"train_datasets":[train_dataset_id],
                       "eval_dataset":eval_dataset_id,
                       "inference_dataset":test_dataset_id,
                       "calibration_dataset":train_dataset_id}
super_data = {
    "api_endpoint": "partial_update",
    "kind": "experiments",
    "handler_id": experiment_id,
    "request_body": json.dumps(dataset_information),
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))

### Assign PTM <a class="anchor" id="head-11"></a>

Search NGC for the pretrained models (PTMs) available for the chosen classification model.

In [None]:
# List all pretrained models for the chosen network architecture
params = {"network_arch": model_name}
super_data = {
    "api_endpoint": "retrieve",
    "kind": "experiments",
    "request_body": params,
    "is_base_experiment": True,
    "ngc_key": ngc_key,
}
response = invoke_function(super_data)
assert response.status_code in (200, 201)

response_json = response.json()["experiments"]

for rsp in response_json:
    rsp_keys = rsp.keys()
    if "encryption_key" not in rsp.keys():
        assert "name" in rsp_keys and "version" in rsp_keys and "ngc_path" in rsp_keys
        print(f'PTM Name: {rsp["name"]}; PTM version: {rsp["version"]}; NGC PATH: {rsp["ngc_path"]}')

In [None]:
# Map each classification model to its default pretrained model (PTM)
# To use a different PTM backbone, pick one from the output of the previous cell and update this map.
# Changing the backbone may also require matching changes to the default spec/config for train, evaluate, etc.
# For example, if you switch the PTM to resnet34, manually set the config key num_layers (if it exists) to 34.
pretrained_map = {"classification_tf2" : "pretrained_classification_tf2:efficientnet_b0",
                  "classification_pyt" : "pretrained_fan_classification_imagenet:fan_hybrid_tiny",
                  }
no_ptm_models = set()

In [None]:
# Get pretrained model for classification
if model_name not in no_ptm_models:
    params = {"network_arch": model_name}
    super_data = {
        "api_endpoint": "retrieve",
        "kind": "experiments",
        "request_body": params,
        "is_base_experiment": True,
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)
    assert response.status_code in (200, 201)

    response_json = response.json()["experiments"]

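    # Search for the PTM with the given NGC path
    ptm = []
    for rsp in response_json:
        rsp_keys = rsp.keys()  # refresh the keys for each experiment
        assert "ngc_path" in rsp_keys
        if rsp["ngc_path"].endswith(pretrained_map[model_name]):
            assert "id" in rsp_keys
            ptm_id = rsp["id"]
            ptm = [ptm_id]
            print("Metadata for model with requested NGC Path")
            print(rsp)
            break
    assert ptm, f"No PTM found whose NGC path ends with {pretrained_map[model_name]}"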
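The PTM search above can also be written as a pure function that is easy to test. This hypothetical helper uses the same `ngc_path` and `id` keys and returns `None` when nothing matches:

```python
def find_ptm_id(experiments, ngc_path_suffix):
    """Return the id of the first base experiment whose ngc_path ends with the suffix."""
    for exp in experiments:
        # Treat a missing or null ngc_path as an empty string
        if (exp.get("ngc_path") or "").endswith(ngc_path_suffix):
            return exp["id"]
    return None
```

Usage: `find_ptm_id(response.json()["experiments"], pretrained_map[model_name])`.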

In [None]:
if model_name not in no_ptm_models:
    ptm_information = {"base_experiment":ptm}
    super_data = {
        "api_endpoint": "partial_update",
        "kind": "experiments",
        "handler_id": experiment_id,
        "request_body": json.dumps(ptm_information),
        "ngc_key": ngc_key,
    }
    
    response = invoke_function(super_data)

    assert response.status_code in (200, 201)
    print(response)
    print(json.dumps(response.json(), indent=4))

### Actions <a class="anchor" id="head-13"></a>

For all actions:
1. Get default spec schema and derive the default values
2. Modify defaults if needed
3. Post spec dictionary to the service
4. Run model action
5. Monitor job using retrieve
6. Download results using job download endpoint (if needed)
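Step 1 relies on the service returning defaults; the specs_schema cell below reads them from the response's `default` key. If you only had the schema itself, defaults could be derived by walking it. The sketch below assumes a JSON-Schema-like layout (`properties` for nested objects, `default` for leaf values); that layout is an assumption about the response, and `defaults_from_schema` is a hypothetical name.

```python
def defaults_from_schema(schema):
    """Collect default values from a JSON-Schema-like dict, recursing into objects."""
    if "properties" in schema:
        result = {}
        for key, sub_schema in schema["properties"].items():
            value = defaults_from_schema(sub_schema)
            # Skip fields that declare no default
            if value is not None:
                result[key] = value
        return result
    return schema.get("default")
```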

In [None]:
job_map = {}

### Train <a class="anchor" id="head-14"></a>

In [None]:
# Get default spec schema

while True:
    super_data = {
        "api_endpoint": "specs_schema",
        "kind": "experiments",
        "handler_id": experiment_id,
        "action": "train",
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)
    if response.status_code == 404:
        if "Base spec file download state is " in response.json().get("error_desc", ""):
            print("Base experiment spec file is being downloaded")
            time.sleep(2)
            continue
        else:
            break
    else:
        break
assert response.status_code in (200, 201)

print(response)
train_specs = response.json()["default"]
print(json.dumps(train_specs, sort_keys=True, indent=4))

In [None]:
# Override any of the parameters listed in the previous cell as required
# Example for classification_pyt
if model_name == "classification_pyt":
    train_specs["train"]["train_config"]["runner"]["max_epochs"] = 10
    train_specs["train"]["num_gpus"] = 1
# Example for classification_tf2
elif model_name == "classification_tf2":
    train_specs["train"]["num_epochs"] = 80

print(json.dumps(train_specs, sort_keys=True, indent=4))
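Deeply nested overrides such as `train_specs["train"]["train_config"]["runner"]["max_epochs"]` are easy to mistype, and plain assignment silently creates a new key. A hypothetical `set_nested` helper that takes a dotted path and raises `KeyError` for any key missing from the default spec:

```python
def set_nested(specs, dotted_key, value):
    """Set specs[a][b][c] = value for dotted_key "a.b.c", refusing unknown keys."""
    *parents, leaf = dotted_key.split(".")
    node = specs
    for key in parents:
        node = node[key]  # KeyError here means the path is not in the default spec
    if leaf not in node:
        raise KeyError(f"{dotted_key!r} is not in the default spec")
    node[leaf] = value
    return specs
```

For example: `set_nested(train_specs, "train.train_config.runner.max_epochs", 10)`.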

In [None]:
# Run action
parent = None
action = "train"
train_request_body = {
    "parent_job_id": parent,
    "action": action,
    "specs": train_specs,
    # Pick a platform_id from the output of the gpu_types endpoint, based on GPU type and instance type
    "platform_id": "9af1aa90-8ea5-5a11-98d9-3879cd0da92c",
}
super_data = {
    "api_endpoint": "job_run",
    "kind": "experiments",
    "handler_id": experiment_id,
    "action": action,
    "request_body": json.dumps(train_request_body),
    "ngc_key": ngc_key,
}
	
response = invoke_function(super_data)
assert response.status_code in (200, 201)
assert response.json()

print(response)
print(json.dumps(response.json(), indent=4))

job_map["train_" + model_name] = response.json()
print(job_map)

In [None]:
# Monitor job status by repeatedly running this cell
# For automl: Training times for different models benchmarked on 1 GPU V100 machine can be found here: https://docs.nvidia.com/tao/tao-toolkit/text/automl/automl.html#results-of-automl-experiments

job_id = job_map["train_" + model_name]

while True:    
    clear_output(wait=True)
    super_data = {
        "api_endpoint": "retrieve",
        "kind": "experiments",
        "handler_id": experiment_id,
        "is_job": True,
        "job_id": job_id,
        "ngc_key": ngc_key,
    }
    response = invoke_function(super_data)

    if "error_desc" in response.json().keys() and response.json()["error_desc"] in ("Job trying to retrieve not found", "No AutoML run found"):
        print("Job is being created")
        time.sleep(5)
        continue
    print(response)
    print(json.dumps(response.json(), sort_keys=True, indent=4))
    assert "status" in response.json().keys() and response.json().get("status") != "Error"
    if response.json().get("status") in ["Done","Error", "Canceled", "Paused"] or response.status_code not in (200,201):
        break
    time.sleep(15)
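The monitoring loop mixes three cases: a job that is not created yet, a running job, and a finished one. A sketch that separates them, using the error messages and terminal statuses from the cell above (`classify_job_response` is a hypothetical name):

```python
TERMINAL_STATUSES = {"Done", "Error", "Canceled", "Paused"}
NOT_CREATED_YET = {"Job trying to retrieve not found", "No AutoML run found"}


def classify_job_response(body):
    """Return 'pending', 'running', or the terminal status for a job-retrieve body."""
    # The job (or AutoML run) was submitted but is not visible yet
    if body.get("error_desc") in NOT_CREATED_YET:
        return "pending"
    status = body.get("status")
    if status is None:
        raise ValueError(f"Response has no status: {body}")
    # These statuses end the monitoring loop above
    if status in TERMINAL_STATUSES:
        return status
    return "running"
```

In the loop, sleep and poll again on `pending` or `running`, and break on anything else.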

### Delete experiment <a class="anchor" id="head-21"></a>

In [None]:
super_data = {
    "api_endpoint": "delete",
    "kind": "experiments",
    "handler_id": experiment_id,
    "ngc_key": ngc_key,
}
	
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))

### Delete dataset <a class="anchor" id="head-22"></a>

#### Delete train dataset

In [None]:
super_data = {
    "api_endpoint": "delete",
    "kind": "datasets",
    "handler_id": train_dataset_id,
    "ngc_key": ngc_key,
}
	
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))

#### Delete val dataset

In [None]:
super_data = {
    "api_endpoint": "delete",
    "kind": "datasets",
    "handler_id": eval_dataset_id,
    "ngc_key": ngc_key,
}
	
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))

#### Delete test dataset

In [None]:
super_data = {
    "api_endpoint": "delete",
    "kind": "datasets",
    "handler_id": test_dataset_id,
    "ngc_key": ngc_key,
}
	
response = invoke_function(super_data)
assert response.status_code in (200, 201)

print(response)
print(json.dumps(response.json(), indent=4))
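The delete-experiment cell and the three delete-dataset cells can be collapsed into one loop. The hypothetical `cleanup` helper below deletes in the same order as the cells above (experiment first, since it references the datasets) and returns the IDs whose delete did not return 200 or 201. It takes the invoke function as a parameter so it can be exercised without a live service.

```python
def cleanup(invoke, ngc_key, experiment_id, dataset_ids):
    """Delete the experiment, then each dataset; return ids whose delete failed."""
    failed = []
    targets = [("experiments", experiment_id)] + [("datasets", ds_id) for ds_id in dataset_ids]
    for kind, handler_id in targets:
        response = invoke({"api_endpoint": "delete", "kind": kind,
                           "handler_id": handler_id, "ngc_key": ngc_key})
        # Keep going after a failure so the remaining resources are still deleted
        if response is None or response.status_code not in (200, 201):
            failed.append(handler_id)
    return failed
```

Usage: `cleanup(invoke_function, ngc_key, experiment_id, [train_dataset_id, eval_dataset_id, test_dataset_id])`.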