# Demo 4: Accelerating ML inference using Triton inference servers

In this demo, we show how analysis workflows can be accelerated by offloading ML inference to Triton inference servers running on GPUs.

## 1. Loading events

In [None]:
import numpy as np
import pandas as pd
import torch

from submodule.event_selection import load_events
from submodule.dnn_model import NeuralNet

sources = ["data", "ttbar", "dy"]
server = "file:/depot/cms/purdue-af/demos"  # no trailing slash: "/" is added when building the file path
model_dir = "/depot/cms/purdue-af/demos/"
dfs = {}

features = ['mu1_pt', 'mu1_eta', 'mu2_pt', 'mu2_eta', 'dimuon_mass', 'met']

# load datasets for inference
for src in sources:
    dfs[src] = load_events(f"{server}/{src}.root")[features]


## 2. Outsourcing ML inference to remote GPUs via Triton servers
ML inference typically runs much faster on GPUs than on CPUs. However, computing clusters usually have a limited number of GPUs, so it is not possible to guarantee GPU access for all users at all times.

One way to use GPUs to accelerate inference without reserving GPU nodes for individual users is to send inference requests to dedicated inference servers that are permanently attached to GPUs.

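To be evaluated by a Triton server, a model must be saved in the format expected by the corresponding Triton backend (for PyTorch models, a TorchScript file), together with a `config.pbtxt` file that describes its inputs and outputs: [see an example of how to do this in PyTorch](https://medium.com/@furcifer/deploying-triton-inference-server-in-5-minutes-67aa09a84ca6).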

The saved models must be placed in a model repository that is visible to the Triton server(s). Currently, the repository is hosted on Purdue's shared project storage (Depot): `/depot/cms/purdue-af/triton/models/`. A repository with write access for non-Purdue users will be set up in the future.
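The directory layout Triton expects in a model repository can be sketched as follows. This is a minimal sketch, assuming a TorchScript model served by the PyTorch backend (`platform: "pytorch_libtorch"`) that takes a 2D FP64 tensor of shape (events, features), like the model used below. The helper `prepare_model_repository`, the output data type and shape, and any repository or model names you pass in are assumptions for illustration; they are not the actual configuration of `test-model`.

```python
import os

# Hypothetical config template: the tensor names follow the demo model
# (INPUT__0 / OUTPUT__0), but the data types and the output shape are
# assumptions, not the real test-model configuration.
CONFIG_TEMPLATE = """name: "{name}"
platform: "pytorch_libtorch"
max_batch_size: 0
input [
  {{
    name: "INPUT__0"
    data_type: TYPE_FP64
    dims: [ -1, {n_features} ]
  }}
]
output [
  {{
    name: "OUTPUT__0"
    data_type: TYPE_FP64
    dims: [ -1, 1 ]
  }}
]
"""


def prepare_model_repository(repo_dir, model_name, n_features, version=1):
    """Create <repo_dir>/<model_name>/<version>/ and write <model_name>/config.pbtxt.

    Returns the version directory, where the TorchScript file should be
    saved as model.pt (the default file name for the PyTorch backend).
    """
    model_dir = os.path.join(repo_dir, model_name)
    version_dir = os.path.join(model_dir, str(version))
    os.makedirs(version_dir, exist_ok=True)
    config = CONFIG_TEMPLATE.format(name=model_name, n_features=n_features)
    with open(os.path.join(model_dir, "config.pbtxt"), "w") as f:
        f.write(config)
    return version_dir
```

After creating the directories, the traced model can be saved into the returned directory, e.g. `torch.jit.trace(model.double().eval(), example_input).save(os.path.join(version_dir, "model.pt"))`, where `example_input` is a float64 tensor with one row per event. With `max_batch_size: 0`, Triton does not add an implicit batch dimension, so the leading `-1` in `dims` is the number of events in each request.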

Several Triton servers are currently available, each corresponding to a different GPU or GPU partition. To select a server, set `triton_address` to its address; this demo uses the load balancer for the 10 GB A100 partition:

In [None]:
import sys

import tritonclient.grpc as grpcclient

# Triton load balancer running on the A100 GPU partition with 10 GB of memory
triton_address = 'triton-10gb.cms.geddes.rcac.purdue.edu:8001'

print(f"Connecting to Triton inference server at {triton_address}")

# gRPC keepalive settings for the client channel
keepalive_options = grpcclient.KeepAliveOptions(
    keepalive_time_ms=2**31 - 1,
    keepalive_timeout_ms=20000,
    keepalive_permit_without_calls=False,
    http2_max_pings_without_data=2
)

# Create Triton client
try:
    triton_client = grpcclient.InferenceServerClient(
        url=triton_address,
        verbose=False,
        keepalive_options=keepalive_options
    )
except Exception as e:
    print("Channel creation failed: " + str(e))
    sys.exit()
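Creating the client does not by itself confirm that the server is reachable or that the model is loaded. A quick health check before sending data can be sketched as below, assuming a `tritonclient` client object, which provides `is_server_live()`, `is_server_ready()` and `is_model_ready(model_name)`. The helper name `check_triton_model` is hypothetical.

```python
def check_triton_model(client, model_name="test-model"):
    """Return True if the server is live and ready and `model_name` is loaded.

    `client` is a tritonclient InferenceServerClient (gRPC or HTTP). Any
    exception raised while querying (e.g. a connection error) is reported
    and treated as "not ready".
    """
    try:
        if not client.is_server_live():
            print("Triton server is not live")
            return False
        if not client.is_server_ready():
            print("Triton server is not ready")
            return False
        if not client.is_model_ready(model_name):
            print(f"Model '{model_name}' is not ready on the server")
            return False
    except Exception as e:  # e.g. tritonclient.utils.InferenceServerException
        print(f"Could not query Triton server: {e}")
        return False
    return True
```

For this demo, the call would be `check_triton_model(triton_client, "test-model")`, run before the inference cell.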

In [None]:
import os


def inference_triton(inp):
    label, df = inp

    # Inputs and outputs must be compatible with the model metadata
    # stored in /depot/cms/purdue-af/triton/models/test-model/config.pbtxt
    data = df.values.astype(np.float64)  # the model expects FP64 input
    inputs = [grpcclient.InferInput('INPUT__0', list(data.shape), "FP64")]
    outputs = [grpcclient.InferRequestedOutput('OUTPUT__0')]

    # Load input data
    inputs[0].set_data_from_numpy(data)

    # Run inference on the Triton server.
    # Models are stored in /depot/cms/purdue-af/triton/models/
    results = triton_client.infer(
        model_name="test-model",
        inputs=inputs,
        outputs=outputs,
        headers={'test': '1'},
    )
    scores = results.as_numpy('OUTPUT__0').flatten()

    # Save DNN outputs to a file (change save_dir to a directory you can write to)
    save_dir = "/depot/cms/users/dkondra/dnn_outputs_triton"
    os.makedirs(save_dir, exist_ok=True)
    np.save(f"{save_dir}/{label}.npy", scores)
    print(label, scores)


print("\nDatasets:", list(dfs.keys()))
for label, df in dfs.items():
    inference_triton([label, df])
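The cell above sends each dataset to the server in a single request. For large datasets, splitting the events into fixed-size batches keeps each request, and the memory it occupies on the shared GPU partition, bounded. The sketch below shows one way to do this sequentially; `infer_in_batches`, `make_triton_infer_fn` and the default batch size are hypothetical choices, and the tensor names assume the same `INPUT__0` / `OUTPUT__0` model as above.

```python
import numpy as np


def infer_in_batches(data, infer_fn, batch_size=10_000):
    """Run `infer_fn` on consecutive row batches of `data` and return the
    flattened scores concatenated in the original row order."""
    data = np.ascontiguousarray(data, dtype=np.float64)  # model input is FP64
    scores = [
        np.asarray(infer_fn(data[start:start + batch_size])).ravel()
        for start in range(0, len(data), batch_size)
    ]
    if not scores:  # empty input: no requests are sent
        return np.empty(0, dtype=np.float64)
    return np.concatenate(scores)


def make_triton_infer_fn(client, model_name="test-model"):
    """Wrap a tritonclient gRPC client as a `batch -> scores` function,
    using the tensor names of the demo model."""
    # Imported here so that the batching logic above runs without tritonclient
    import tritonclient.grpc as grpcclient

    def infer_fn(batch):
        inp = grpcclient.InferInput("INPUT__0", list(batch.shape), "FP64")
        inp.set_data_from_numpy(batch)
        out = grpcclient.InferRequestedOutput("OUTPUT__0")
        result = client.infer(model_name=model_name, inputs=[inp], outputs=[out])
        return result.as_numpy("OUTPUT__0")

    return infer_fn
```

With the client from this notebook, a call would look like `scores = infer_in_batches(dfs["dy"].values, make_triton_infer_fn(triton_client))`. Passing the inference function as an argument keeps the batching logic independent of Triton, so it can be checked locally with a stand-in function.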

## 3. Plotting DNN outputs
Run this cell after the Triton example to plot the DNN outputs. The cell reads the Triton outputs; to plot the outputs of the Dask parallelization example instead, change `load_path` to the directory where that example saved its outputs. The two examples use different models, so their outputs will not look the same. The models are generic and have no physics meaning.

In [None]:
import matplotlib.pyplot as plt

bins = np.linspace(0, 1, 100)
dnn_outputs = {}

# Load the DNN scores saved by the inference step
for src in sources:
    load_path = f"/depot/cms/users/dkondra/dnn_outputs_triton/{src}.npy"
    dnn_outputs[src] = np.load(load_path)

plt.figure(figsize=(5, 4))
for src in ["dy", "ttbar", "data"]:
    plt.hist(dnn_outputs[src], bins, alpha=0.3, label=src, density=True)
plt.xlabel('DNN score')
plt.ylabel('Normalized events')  # histograms are normalized to unit area
plt.legend(loc='upper left')
plt.show()