# Multimodal Models from NVIDIA AI Catalog and LangChain Agent 

## Prerequisites

To run this notebook, [follow the setup steps here](https://python.langchain.com/docs/integrations/text_embedding/nvidia_ai_endpoints#setup) and generate an API key from the [NVIDIA API Catalog](https://build.nvidia.com/).

This notebook uses the following components:

- An LLM: [**meta/llama-3.1-405b-instruct**](https://build.nvidia.com/meta/llama-3_1-405b-instruct)

- [**Deplot**](https://build.nvidia.com/google/google-deplot) from the NVIDIA AI Catalog, as one of the tools

- [**Fuyu**](https://build.nvidia.com/adept/fuyu-8b) from the NVIDIA AI Catalog, as one of the tools

- Gradio as a simple user interface for uploading images

The goal, illustrated below, is a UI that lets users upload an image of their choice and has the agent choose a tool to perform visual reasoning on it.

![interactive UI](./data/imgs/visual_reasoning.png)    
Note: Because the models are accessed through the NVIDIA AI Catalog API, no local GPU is required to run this notebook.


In [None]:
# Install python packages.
!pip install gradio==3.48.0

## Step 1  - Export the NVIDIA_API_KEY
You can supply the NVIDIA_API_KEY directly in this notebook by running the cell below.

In [None]:
import getpass
import os

# del os.environ['NVIDIA_API_KEY']  ## delete key and reset
if os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
    print("Valid NVIDIA_API_KEY already in environment. Delete to reset")
else:
    nvapi_key = getpass.getpass("NVAPI Key (starts with nvapi-): ")
    assert nvapi_key.startswith("nvapi-"), f"{nvapi_key[:5]}... is not a valid key"
    os.environ["NVIDIA_API_KEY"] = nvapi_key
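# Keep a module-level copy so the tool functions below can read the key,
# whichever branch ran above.
nvapi_key = os.environ["NVIDIA_API_KEY"]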

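The tool functions later in this notebook read the key from a module-level variable. As a minimal sketch of an alternative, the key could be read from the environment on demand and turned into request headers in one place. The helpers `require_api_key` and `build_headers` are hypothetical names introduced here for illustration; they are not part of the original notebook.

```python
import os


def require_api_key(env_var: str = "NVIDIA_API_KEY") -> str:
    """Hypothetical helper: return the key stored in env_var, or raise."""
    key = os.environ.get(env_var, "")
    if not key.startswith("nvapi-"):
        raise RuntimeError(f"{env_var} is missing or does not start with 'nvapi-'")
    return key


def build_headers(api_key: str) -> dict:
    """Hypothetical helper: the same two headers the tool functions send."""
    return {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
```

Calling `build_headers(require_api_key())` inside each tool would fail early with a clear message if the key was never set.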
## Step 2 - Wrap the Fuyu API call in a function and verify it by supplying an image to get a response

In [None]:
import base64
import io

import requests
from PIL import Image

def img2base64_string(img_path):
    image = Image.open(img_path)
    if image.width > 800 or image.height > 800:
        image.thumbnail((800, 800))
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="JPEG", quality=85)
    image_base64 = base64.b64encode(buffered.getvalue()).decode()
    return image_base64
    
def fetch_outputs(output):    
    result=output['choices'][0]['message']['content'] 
    result = result.replace('\\','').replace('"','')
    return result

def ImageCaptionTool(img_path: str) -> str:
    """
    describe an image and return text
    Args:
        img_path : path to image location
    """
    
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/adept/fuyu-8b"
    
    image_b64 = img2base64_string(img_path)
    
    headers = {
      "Authorization": f"Bearer {nvapi_key}",
      "Accept": "application/json"
    }
    
    payload = {
      "messages": [
        {
          "role": "user",
          # img2base64_string re-encodes the image as JPEG, so the data URI says image/jpeg
          "content": f'describe this image <img src="data:image/jpeg;base64,{image_b64}" />'
        }
      ],
      "max_tokens": 1024,
      "temperature": 0.20,
      "top_p": 0.70,
      "seed": 0,
      "stream": False
    }
    
    response = requests.post(invoke_url, headers=headers, json=payload)

    if response.status_code == 200 :    
        output=response.json()
        result = fetch_outputs(output)
    else:
        result = 'something went wrong, please try again !'        
    return result


Fetch a test image of a pair of white sneakers and verify that the function works.

In [None]:
!wget "https://docs.google.com/uc?export=download&id=12ZpBBFkYu-jzz1iz356U5kMikn4uN9ww" -O ./data/imgs/jordan.png

In [None]:
img_path="./data/imgs/jordan.png"

out=ImageCaptionTool(img_path)
out

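`fetch_outputs` assumes the response always contains `choices[0]['message']['content']`. As a sketch only, a defensive variant could return `None` when the response has an unexpected shape. The name `fetch_outputs_safe` is hypothetical, and the sample dictionaries below are hand-written for illustration rather than real API output.

```python
from typing import Any, Optional


def fetch_outputs_safe(output: Any) -> Optional[str]:
    """Hypothetical variant of fetch_outputs that returns None on an unexpected shape."""
    try:
        content = output["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        return None
    if not isinstance(content, str):
        return None
    # Same cleanup as fetch_outputs: drop backslashes and double quotes.
    return content.replace("\\", "").replace('"', "")


# Hand-written sample in the shape fetch_outputs expects.
sample = {"choices": [{"message": {"content": 'a pair of \\"white\\" sneakers'}}]}
print(fetch_outputs_safe(sample))           # a pair of white sneakers
print(fetch_outputs_safe({"choices": []}))  # None
```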
## Step 3 - Wrap the Deplot API call in a function and verify it by supplying an image to get a response

In [None]:

def Tabular2TextTool(img_path: str) -> str:
    """
    understand tabular image and return text
    Args:
        img_path : path to image location
    """
    image_b64 = img2base64_string(img_path)

    invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"

    headers = {
      "Authorization": f"Bearer {nvapi_key}",
      "Accept": "application/json"
    }
    
    payload = {
      "messages": [
        {
          "role": "user",
          # img2base64_string re-encodes the image as JPEG, so the data URI says image/jpeg
          "content": f'Generate underlying data table of the figure below: <img src="data:image/jpeg;base64,{image_b64}" />'
        }
      ],
      "max_tokens": 1024,
      "temperature": 0.20,
      "top_p": 0.20,
      "stream": False
    }
    
    response = requests.post(invoke_url, headers=headers, json=payload)
    

    if response.status_code == 200 :    
        output=response.json()
        result = fetch_outputs(output)
    else:
        result = 'something went wrong, please try again !'        
    return result


In [None]:
!wget https://developer-blogs.nvidia.com/wp-content/uploads/2024/01/DePlot-bar-chart-example.png -O ./data/imgs/chart_example.png

In [None]:
img_path="./data/imgs/chart_example.png"

out=Tabular2TextTool(img_path=img_path)
out

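`ImageCaptionTool` and `Tabular2TextTool` differ only in the endpoint URL, the instruction text, and the sampling parameters. One possible way to share the rest is sketched below. The helper `build_vlm_payload` is a hypothetical name, not part of the original notebook; its field names mirror the payloads above, and the default MIME type is `image/jpeg` because `img2base64_string` saves the image as JPEG.

```python
def build_vlm_payload(instruction: str, image_b64: str,
                      mime: str = "image/jpeg", **sampling) -> dict:
    """Hypothetical helper: build the chat payload shared by both tools."""
    content = f'{instruction} <img src="data:{mime};base64,{image_b64}" />'
    payload = {
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 1024,
        "stream": False,
    }
    # Per-tool sampling parameters such as temperature, top_p, or seed.
    payload.update(sampling)
    return payload


# Usage with a placeholder base64 string instead of a real image.
caption_payload = build_vlm_payload("describe this image", "QUJD",
                                    temperature=0.20, top_p=0.70, seed=0)
print(caption_payload["messages"][0]["content"])
```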
---
## Step 4 - Construct the agent via an LCEL parallel chain

Let's review the conceptual flow below, which shows how **lcel_agent_chain** is constructed:

![parallel chain](./data/imgs/parallel_chains.png)


- We will use the [meta/llama-3.1-405b-instruct](https://build.nvidia.com/meta/llama-3_1-405b-instruct) model as the main LLM for the agent
- We will use _**with_structured_output**_ to format the user input query and form **format_chain**
- We will use _**bind_tools**_ to bind ImageCaptionTool and Tabular2TextTool as tools to our LLM and form **tool_chain**
- We will write an output parser that combines the two branches and construct our agent via LCEL as **lcel_agent_chain**

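As a rough mental model only, and not LangChain's implementation, the parallel chain can be pictured in plain Python: every branch receives the same input, the results are collected in a dictionary keyed by branch name, and that dictionary is handed to the output parser. The functions `fake_format`, `fake_tools`, and `fake_parser` are hypothetical stand-ins for the real chains. This sketch runs the branches one after another, whereas LangChain's parallel runnable may execute them concurrently.

```python
from typing import Any, Callable, Dict


def run_parallel(branches: Dict[str, Callable[[str], Any]], user_input: str) -> Dict[str, Any]:
    """Feed the same input to every branch and collect the results by name."""
    return {name: branch(user_input) for name, branch in branches.items()}


def fake_format(query: str) -> str:
    # Stand-in for format_chain: treat the text after the last comma as the path.
    return query.rsplit(",", 1)[-1].strip()


def fake_tools(query: str) -> str:
    # Stand-in for tool_chain: pick a tool name with a simple keyword rule.
    return "Tabular2TextTool" if "chart" in query.lower() else "ImageCaptionTool"


def fake_parser(results: Dict[str, Any]) -> str:
    # Stand-in for the output parser: combine both branch results.
    return f"{results['tools']}({results['format']})"


results = run_parallel({"format": fake_format, "tools": fake_tools},
                       "describe this image, ./data/imgs/jordan.png")
print(fake_parser(results))  # ImageCaptionTool(./data/imgs/jordan.png)
```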

### Initiate [meta/llama-3.1-405b-instruct](https://build.nvidia.com/meta/llama-3_1-405b-instruct) as the main LLM 

In [None]:
# test run to confirm that you can generate a response successfully
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


## use meta/llama-3.1-405b-instruct model as our main LLM
llm = ChatNVIDIA(model="meta/llama-3.1-405b-instruct", max_tokens=1024)


### Make **format_chain** via __**with_structured_output**__ for formatting the user input query

In [None]:
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import RunnableParallel

## structural output  
class StructureOutput(BaseModel):     
    img_path: str = Field(description="path to the input img")

## use .with_structured_output to format input user query
llm_with_structured_output = llm.with_structured_output(StructureOutput)     
format_chain = ChatPromptTemplate.from_template("format the user input query : {input}") |llm_with_structured_output


In [None]:
out1=format_chain.invoke("describe this image, ./data/imgs/jordan.png")
out1

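Structured output depends on the model returning a well-formed object. As a hedged sketch, a deterministic fallback could pull the path out of the query with a regular expression. The helper `extract_img_path` and its pattern are assumptions made for illustration; the pattern only covers paths without spaces that end in a common image extension.

```python
import re
from typing import Optional

# Illustrative pattern: a run of non-space, non-comma characters ending in an image extension.
_IMG_PATH = re.compile(r"[^\s,;]+\.(?:png|jpe?g|gif|bmp|webp)\b", re.IGNORECASE)


def extract_img_path(query: str) -> Optional[str]:
    """Hypothetical fallback: return the first image-like path in the query, or None."""
    match = _IMG_PATH.search(query)
    return match.group(0) if match else None


print(extract_img_path("describe this image, ./data/imgs/jordan.png"))  # ./data/imgs/jordan.png
```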
### Make **tool_chain** via __**bind_tools**__ to bind ImageCaptionTool and Tabular2TextTool as tools to our LLM

In [None]:
llm_with_tools=llm.bind_tools([ImageCaptionTool ,Tabular2TextTool])

In [None]:
tool_chain = ChatPromptTemplate.from_template("Select appropriate tool for the input user query : {input}") | llm_with_tools


In [None]:
out2=tool_chain.invoke("describe this image, ./data/imgs/jordan.png")
out2.tool_calls

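Each entry in `tool_calls` is a dictionary with `name`, `args`, and `id` keys. The sketch below reads the first selection from a hand-written sample; the values are made up for illustration and are not actual model output, and `first_tool_call` is a hypothetical helper.

```python
from typing import Any, Dict, List, Optional, Tuple


def first_tool_call(tool_calls: List[Dict[str, Any]]) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Hypothetical helper: return (name, args) of the first tool call, or None."""
    if not tool_calls:
        return None
    call = tool_calls[0]
    return call["name"], call.get("args", {})


# Hand-written sample in the shape of a tool_calls entry.
sample_calls = [
    {
        "name": "ImageCaptionTool",
        "args": {"img_path": "./data/imgs/jordan.png"},
        "id": "call_0",
    }
]
print(first_tool_call(sample_calls))
print(first_tool_call([]))  # None
```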
### Combine the 2 chains to form **lcel_agent_chain**

In [None]:

def custom_output_parser(agent_output):    
    if len(agent_output['tools'].tool_calls) > 0 :
        tool_selected=agent_output['tools'].tool_calls[0]['name']
        tool_input = agent_output['format'].img_path
        if tool_selected=='ImageCaptionTool':
            output=ImageCaptionTool(img_path=tool_input)
        elif tool_selected=='Tabular2TextTool':
            output=Tabular2TextTool(img_path=tool_input)
        else:
            output=f"The selected tool '{tool_selected}' is not among the available tools; please check that the tools are bound correctly"
    else :
        output = "No tool selected; please check that the tools are bound correctly"
    return output
    
lcel_agent_chain = RunnableParallel( format=format_chain, tools=tool_chain) | custom_output_parser


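The if/elif dispatch in `custom_output_parser` grows with every new tool. One possible alternative, sketched here with hypothetical stub functions in place of the real API-calling tools, is a dictionary that maps tool names to callables. `TOOL_REGISTRY`, `dispatch`, and the two stubs are illustrative names, not part of the original notebook.

```python
from typing import Callable, Dict


def caption_stub(img_path: str) -> str:
    # Stand-in for ImageCaptionTool; makes no network call.
    return f"caption for {img_path}"


def table_stub(img_path: str) -> str:
    # Stand-in for Tabular2TextTool; makes no network call.
    return f"table for {img_path}"


TOOL_REGISTRY: Dict[str, Callable[[str], str]] = {
    "ImageCaptionTool": caption_stub,
    "Tabular2TextTool": table_stub,
}


def dispatch(tool_name: str, img_path: str) -> str:
    """Look up the selected tool by name and call it with the image path."""
    tool = TOOL_REGISTRY.get(tool_name)
    if tool is None:
        return f"Unknown tool: {tool_name}"
    return tool(img_path)


print(dispatch("Tabular2TextTool", "./data/imgs/chart_example.png"))
```

With this layout, adding a tool means adding one dictionary entry rather than another branch.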
In [None]:
## test run
agent_output=lcel_agent_chain.invoke("describe this image, ./data/imgs/jordan.png")
agent_output

---
## Step 5 - Wrap the **lcel_agent_chain** in a Python function to prepare for Gradio UI integration

In [None]:
import os

def interface(img_path):
    if img_path is None:  # Gradio passes None when no image was uploaded
        output="Did you forget to upload an image?"
    else :
        output=lcel_agent_chain.invoke(f"Describe this image located here : {img_path}")
    
    print(output)
    return output

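Before invoking the chain, it can help to check the uploaded path. A minimal sketch follows, assuming a hypothetical `validate_img_path` helper and an illustrative set of allowed extensions; neither comes from the original notebook.

```python
import os
from typing import Optional

# Illustrative list; adjust to the formats you expect to receive.
ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp"}


def validate_img_path(img_path: Optional[str]) -> Optional[str]:
    """Hypothetical helper: return an error message, or None if the path looks usable."""
    if not img_path:
        return "Did you forget to upload an image?"
    if not os.path.isfile(img_path):
        return f"File not found: {img_path}"
    ext = os.path.splitext(img_path)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        return f"Unsupported file type: {ext or 'none'}"
    return None
```

`interface` could call this first and return the message directly whenever it is not `None`, which avoids spending an LLM call on an unusable input.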
## Step 6 - A simple Gradio UI so we can interactively upload an arbitrary image

In [None]:
import gradio as gr
ImageCaptionApp = gr.Interface(fn=interface ,
                    inputs=[ gr.Image(label="Upload image", type="filepath")],
                    outputs=[gr.Textbox(label="Agent Output")],
                    title="langchain LCEL agent",
                    description="combine langchain agent using tools for image reasoning",
                    allow_flagging="never")

ImageCaptionApp.launch(share=True)
