# Deploying a Machine Learning Model with Gradio

![gradio-icon](https://github.com/gradio-app/gradio/raw/main/readme_files/gradio.svg)

[Gradio](https://gradio.app/) is a Python library for building customizable, interactive web interfaces for machine learning models. Users can enter data, see the results, and interact with a model in real time. Gradio suits a wide range of tasks, such as image classification, text generation, and sentiment analysis. Because an interface simply wraps a Python function, it works with any framework, including TensorFlow, PyTorch, and other machine learning libraries.

Gradio connects the interface to your code through a small API. It processes the user's input, passes it to your function, and displays the returned result. Calling `launch()` serves the interface as a local web application, and the same app can be deployed to a server for wider use. This makes it easy to share and demonstrate a model to people who do not want to run code themselves.

Once you are familiar with Gradio, you can deploy and share your own app on a cloud host or on [Hugging Face Spaces](https://huggingface.co/spaces)!

See the full [documentation](https://gradio.app/docs/) and the [quickstart examples](https://gradio.app/quickstart) to learn how to use Gradio with your own app.

```bash
pip install -q gradio
```



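```python
import os

# os.path.abspath('') is the current working directory. Its parent is the project
# root, assuming the notebook is run from a subfolder of the project.
ROOT_DIR = os.path.dirname(os.path.abspath(''))
PRETRAINED_MODEL = os.path.join(ROOT_DIR, 'pretrained/simple-lightning-epoch100/resnet18_epoch99.ckpt')
```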
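The path above only resolves correctly if the notebook is started from a direct subfolder of the project. A minimal sketch, assuming that layout, that computes the same root with `pathlib` and fails early with a readable error when the checkpoint is missing. `resolve_checkpoint` is a hypothetical helper, not part of the project code.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_checkpoint(relative_path: str, root: Optional[Path] = None) -> Path:
    """Resolve a checkpoint path under the project root and fail early if it is missing.

    By default the root is the parent of the current working directory, the same
    directory that os.path.dirname(os.path.abspath('')) returns.
    """
    if root is None:
        root = Path.cwd().parent
    path = (Path(root) / relative_path).resolve()
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path


# Both expressions name the parent of the working directory
assert Path(os.path.dirname(os.path.abspath(''))) == Path.cwd().parent
```

For example, `resolve_checkpoint('pretrained/simple-lightning-epoch100/resnet18_epoch99.ckpt')` raises a `FileNotFoundError` showing the full path it tried, which makes a wrong working directory easy to spot.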


## Launching a Simple Interface

Before integrating our model, let's see how Gradio works with a minimal image classification app. The [`Interface`](https://gradio.app/docs/#interface) takes `image_classifier` as its `fn`: a callable that receives the uploaded image from the [`Image`](https://gradio.app/docs/#image) input component (as a NumPy array by default) and returns a `dict` mapping each class name to a confidence, which the [`Label`](https://gradio.app/docs/#label) output component displays.

You can also open the local URL in a browser to view the app.

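```python
import gradio as gr

def image_classifier(inp):
    # Dummy classifier: ignores the image and always returns the same confidences
    return {'cat': 0.3, 'dog': 0.7}

# The shortcuts "image" and "label" create gr.Image and gr.Label components
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
demo.launch()
```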

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
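`image_classifier` only has to return a `dict` of class names and confidences. The `Label` component ranks the entries by confidence, best first, and when `num_top_classes` is set it shows only that many. The hypothetical `top_k` helper below sketches that ranking with the standard library, which is handy for checking what the UI will display without launching it. It illustrates the behaviour and is not Gradio's own code.

```python
from typing import Dict, List, Tuple


def top_k(confidences: Dict[str, float], k: int) -> List[Tuple[str, float]]:
    """Return the k (label, confidence) pairs with the highest confidence, best first."""
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


print(top_k({'cat': 0.3, 'dog': 0.7}, k=1))  # [('dog', 0.7)]
```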




## Integrating the Model

### Loading Checkpoint

We trained the ResNet18 model in the previous section with PyTorch Lightning. A Lightning checkpoint stores the whole training state (such as the epoch counter and optimizer state) alongside the model weights, so we have to extract the weights and load them into a plain ResNet18.
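A Lightning checkpoint is a plain dictionary: besides `state_dict` it holds entries such as `epoch`, `global_step`, and `optimizer_states`. The weights inside `state_dict` are keyed by their attribute path in the `LightningModule`, so if the network was stored as `self.net` (as the `net.` prefix handled in the next cell suggests), every key starts with `net.`. The sketch below uses a toy dictionary, not a real checkpoint, to show why removing only the *leading* prefix is safer than `str.replace`: a hypothetical key like `net.subnet.weight` would otherwise be mangled. `strip_prefix` is our own illustrative helper.

```python
from typing import Any, Dict


def strip_prefix(state_dict: Dict[str, Any], prefix: str = "net.") -> Dict[str, Any]:
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged. Only the leading
    occurrence is removed, unlike str.replace, which drops every occurrence.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# A toy stand-in for a Lightning checkpoint (real values would be tensors)
toy_checkpoint = {
    "epoch": 99,
    "state_dict": {"net.conv1.weight": 1, "net.subnet.weight": 2},
}
print(strip_prefix(toy_checkpoint["state_dict"]))
# {'conv1.weight': 1, 'subnet.weight': 2}
print("net.subnet.weight".replace("net.", ""))  # subweight
```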

```python
import torch
from src.models import ResNet18

model = ResNet18(3, 10)

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(PRETRAINED_MODEL, map_location='cpu')
# print(checkpoint.keys())

# The state dict keys look like `net.<layer_name>` because the LightningModule
# stored the network as `net`. Our plain ResNet18 has no `net.` prefix, so we
# strip it from the start of each key.
state_dict = checkpoint['state_dict']
for key in list(state_dict.keys()):
    if key.startswith('net.'):
        state_dict[key[len('net.'):]] = state_dict.pop(key)

model.load_state_dict(state_dict)
model.eval()  # switch BatchNorm/Dropout layers to inference behaviour

# Output index i corresponds to the i-th class name in alphabetical order
class_names = ['apple_pie', 'bibimbap', 'cannoli', 'edamame', 'falafel', 'french_toast', 'ice_cream', 'ramen', 'sushi', 'tiramisu']
class_names.sort()
```
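`class_names.sort()` matters because output index `i` of the model must refer to the same class it did during training. The training code is not shown here; assuming it used torchvision's `ImageFolder`, which assigns indices to the alphabetically sorted class-folder names, the sketch below reproduces that convention with the standard library. `folder_class_indices` is a hypothetical helper written for illustration.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def folder_class_indices(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Assign class indices from sorted subdirectory names (ImageFolder-style)."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    # Create class folders in a deliberately unsorted order
    for name in ["sushi", "apple_pie", "ramen"]:
        os.mkdir(os.path.join(root, name))
    classes, class_to_idx = folder_class_indices(root)
    print(classes)       # ['apple_pie', 'ramen', 'sushi']
    print(class_to_idx)  # {'apple_pie': 0, 'ramen': 1, 'sushi': 2}
```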

### Launching a Classification Model

Now we integrate the application with our model by adding the preprocessing, inference, and postprocessing steps.
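Two details of the preprocessing are worth spelling out. `CenterCrop` cuts the central `INPUT_SIZE` region out of the upload without resizing it first, so a large photo is reduced to a small patch from its middle; whether a `transforms.Resize` step belongs before it depends on the training transforms in `src.dataset`, which are not shown here. `ToTensor` then scales pixel values to [0, 1], and `Normalize` subtracts the per-channel mean and divides by the standard deviation. The standard-library sketch below reproduces both computations; the helper names and the mean/std values in the example are placeholders, not the project's `RGB_MEAN` and `RGB_STD`.

```python
from typing import Sequence, Tuple


def center_crop_box(width: int, height: int, size: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of a size x size crop centred in the image.

    Offsets use int(round((dim - size) / 2)), the same rounding torchvision's
    center_crop applies. Assumes both sides are at least `size` pixels
    (torchvision pads smaller images instead of cropping them).
    """
    left = int(round((width - size) / 2.0))
    top = int(round((height - size) / 2.0))
    return left, top, left + size, top + size


def normalize_pixel(rgb: Sequence[int], mean: Sequence[float], std: Sequence[float]) -> Tuple[float, ...]:
    """Scale uint8 channels to [0, 1] (as ToTensor does), then apply (x - mean) / std (as Normalize does)."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


# A 640x480 upload keeps only its central 224x224 region
print(center_crop_box(640, 480, 224))  # (208, 128, 432, 352)
# Placeholder statistics, not the project's RGB_MEAN / RGB_STD
print(normalize_pixel((255, 0, 0), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))  # (1.0, -1.0, -1.0)
```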

```python
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from src.dataset import RGB_MEAN, RGB_STD, INPUT_SIZE

# Should mirror the evaluation transforms used during training
transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),            # NumPy H x W x C uint8 array -> PIL image
    transforms.CenterCrop(INPUT_SIZE),  # keep the central region (no resizing)
    transforms.ToTensor(),              # PIL image -> float tensor C x H x W in [0, 1]
    transforms.Normalize(mean=RGB_MEAN, std=RGB_STD)
])


def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess the input image.

    Note that the input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback (H x W x 3, uint8).

    Returns
    -------
    torch.Tensor
        Normalized image with a batch dimension, shape (1, 3, H, W).
    """

    image = transformation_pipeline(image)
    image = torch.unsqueeze(image, 0)  # add a batch dimension

    return image


def image_classifier(inp):
    """Image Classifier Function.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from callback

    Returns
    -------
    Dict
        A dictionary mapping each class name to its probability
    """

    # Submitting without an image passes None; show an error message in the UI
    if inp is None:
        raise gr.Error("Please upload an image first.")

    # preprocess (ToTensor already produces float32)
    image = preprocess_image(inp)

    # inference: gradients are not needed
    with torch.no_grad():
        logits = model(image)

    # postprocess: softmax over classes, then take the single image in the batch
    probs = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    labeled_result = dict(zip(class_names, probs))

    return labeled_result

demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
demo.launch()
```

Running on local URL:  http://127.0.0.1:7886

To create a public link, set `share=True` in `launch()`.
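The postprocessing in `image_classifier` turns the model's raw logits into probabilities with softmax and pairs them with `class_names`. As a framework-free illustration (not the code the app runs), here is the same step with the standard library, using the usual trick of subtracting the largest logit so `exp` cannot overflow. The logits in the example are made up, and `label_scores` is a hypothetical helper.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert logits to probabilities; subtracting the max avoids overflow in exp."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_scores(logits: Sequence[float], class_names: Sequence[str]) -> Dict[str, float]:
    """Pair each class name with its probability, the dict format gr.Label displays."""
    if len(logits) != len(class_names):
        raise ValueError("expected one logit per class name")
    return dict(zip(class_names, softmax(logits)))


print(softmax([0.0, 0.0]))  # [0.5, 0.5]
scores = label_scores([2.0, 0.5, -1.0], ["ramen", "sushi", "tiramisu"])
print(max(scores, key=scores.get))  # ramen
```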




## Custom Gradio App with Blocks

[Blocks](https://gradio.app/docs/#blocks) is Gradio's low-level API. It gives you more control over layout and data flow than `Interface`, while still being written entirely in Python.

Let's build the same pipeline as before using `Blocks` instead of `Interface`. This version also includes example images so you can try the app without uploading your own.

```python
import os

# Sort the file names so the examples always appear in the same order
sample_files = sorted(os.listdir('samples'))
sample_files = [os.path.join('samples', path) for path in sample_files]
```
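`os.listdir` returns every entry in the folder in arbitrary order, so stray files such as `.DS_Store` would end up as broken examples. A slightly more defensive variant, sketched below, keeps only files with common image extensions and sorts them. The extension list and the `list_images` helper are assumptions for illustration; adjust them to what `samples/` actually contains.

```python
import os
from typing import List

# Hypothetical allow-list; adjust to the formats actually stored in samples/
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def list_images(folder: str) -> List[str]:
    """Return the image files in `folder` as joined paths, in a stable sorted order."""
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    ]
```

With this helper, the cell above reduces to `sample_files = list_images('samples')`.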

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():  # one row holding the input column and the prediction
        with gr.Column():  # first item: image input with buttons below it
            inp = gr.Image(label="image", image_mode="RGB")
            with gr.Row():  # buttons side by side under the image
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

        # second item: label showing the top-3 predictions
        out = gr.Label(label="prediction", num_top_classes=3)

    # Define the buttons' behaviour
    submit_btn.click(fn=image_classifier, inputs=inp, outputs=out)
    # Returning None for each output clears the image and the prediction
    clear_btn.click(fn=lambda: (None, None), inputs=None, outputs=[inp, out])

    # Clickable example images that fill the image input
    gr.Markdown("## Image Examples")
    gr.Examples(sample_files, inputs=[inp], label="Image Examples")

demo.launch()
```


Running on local URL:  http://127.0.0.1:7901

To create a public link, set `share=True` in `launch()`.


