<a href="https://colab.research.google.com/github/colinmcnamara/austin_langchain/blob/main/labs/LangChain_105/105-streamlit_ollama_llava_auto1111.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# graph.py imports langchain_community and langgraph directly, so install them
# explicitly instead of relying on them being pulled in transitively.
%pip install -q langchain langchain-community langgraph streamlit

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m181.5/181.5 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m49.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.1/82.1 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.4/49.4 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
%%writefile graph.py
import base64
import json
import operator
import os
import requests
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, FunctionMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolInvocation, ToolExecutor
from numpy import random
from typing import Optional, TypedDict, Annotated, Sequence, Dict

# environment variable names that may override the service endpoints
a_1111_env_key = "AUTOMATIC1111_HOST_URL"
ollama_env_key = "OLLAMA_HOST_URL"

# base urls for automatic 1111 and ollama: the environment value when the
# variable is present, otherwise the local default
a_1111_base_url = os.environ.get(a_1111_env_key, "http://localhost:7860")
ollama_base_url = os.environ.get(ollama_env_key, "http://localhost:11434")

# bakllava model for image to text
image_llm = Ollama(model="bakllava",
                   base_url=ollama_base_url,
                   num_predict=100)

# llama2 model for text chat
text_llm = ChatOllama(model="llama2", base_url=ollama_base_url)

# mistral 7b functioncall model for function calling.
# base_url is passed here like for the two models above so that an
# OLLAMA_HOST_URL override applies to all three models, not just two of them.
fc_llm = ChatOllama(model="klcoder/mistral-7b-functioncall",
                    base_url=ollama_base_url,
                    format="json", num_predict=100)


# config class for automatic 1111 image generation parameters with defaults.
# txt2image builds one of these from the prompt plus any keyword overrides and
# posts its .dict() as the JSON body of /sdapi/v1/txt2img.
# Comments are used instead of a class docstring so the pydantic schema of the
# model is not affected.
class Config(BaseModel):
    prompt: str
    negative_prompt: str = ''
    sampler_name: str = 'DPM++ 2M Karras'
    checkpoint_name: str = 'dreamshaperXL_v21TurboDPMSDE'
    batch_size: int = 1
    steps: int = 20
    # None means "not chosen"; txt2image fills in a random seed in that case
    seed: Optional[int] = None
    # float rather than int so fractional overrides such as 7.5 are kept
    # instead of being coerced to a whole number
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    denoising_strength: float = 0.7
    enable_hr: bool = False
    # float for the same reason as cfg_scale (e.g. 1.5)
    hr_scale: float = 2.0
    hr_upscaler: str = '4xUltrasharp_4xUltrasharpV10'
    hr_sampler_name: str = 'DPM++ 2M Karras'
    send_images: bool = True
    save_images: bool = True


# argument schema for txt2image tool
# tool_to_definition serialises this schema (minus its title) into the
# function-calling prompt, so the description text below is runtime data.
class Txt2ImageInput(BaseModel):
    # the only argument exposed to the function-calling model
    prompt: str = Field(
        description="MidJourney style prompt for image generation")


# txt2image tool
# The docstring is the tool description that tool_to_definition places in the
# function-calling prompt, so its wording is intentionally left untouched.
# NOTE(review): the return annotation says Sequence[str], but the parsed JSON
# body of the response is returned as-is; app.py reads its "images" and
# "parameters" keys.
@tool("txt2image", args_schema=Txt2ImageInput)
def txt2image(prompt: str, **kwargs) -> Sequence[str]:
    """An image generation tool that takes in a prompt as string and returns a list of images encoded in base64 string. The prompt is transformed from simple English to a comma separate MidJourney image generation prompt."""
    config = Config(prompt=prompt, **kwargs)
    if config.seed is None:
        # Uniform draw from [0, 2**32). The previous normal-distribution draw
        # produced negative and arbitrarily large values. The explicit 64-bit
        # dtype keeps the upper bound valid where numpy's default int is 32-bit.
        config.seed = int(random.randint(0, 2**32, dtype="int64"))
    response = requests.post(a_1111_base_url + "/sdapi/v1/txt2img",
                             json=config.dict())
    # Fail with the HTTP error instead of handing an error payload (or an
    # unparsable body) to the caller as if it were a generation result.
    response.raise_for_status()
    return response.json()


# argument schema for image2text tool
# tool_to_definition serialises this schema (minus its title) into the
# function-calling prompt, so the description texts below are runtime data.
class Image2TxtInput(BaseModel):
    # the question to ask about the image
    prompt: str = Field(description="Question regarding the image")
    # the image itself; call_fc_model / call_tool overwrite this with the
    # image held in the graph state when one is present
    image: str = Field(description="Base64 encoded image")


# image2txt tool
# The docstring is the tool description that tool_to_definition places in the
# function-calling prompt, so its wording is intentionally left untouched.
@tool("image2txt", args_schema=Image2TxtInput)
def image2txt(prompt: str, image: str) -> str:
    """An image description tool that takes in a question about an image or a picture as a prompt and returns the answer as string"""
    no_image_error = "No image available within context. Upload an image or generate using prompt to describe it."
    if not image:
        return no_image_error
    try:
        # validate=True rejects characters outside the base64 alphabet. The
        # default lenient mode silently drops them, which lets strings such as
        # a made-up file name pass this check and reach the model as an "image".
        base64.b64decode(image, validate=True)
    except (TypeError, ValueError):
        # binascii.Error is a ValueError; TypeError covers non-string input
        return no_image_error
    bound = image_llm.bind(images=[image])
    response: str = bound.invoke(prompt)
    return response.strip()


# function to transform tool schema for function calling model
def tool_to_definition(tool):
    """Render a tool as the JSON definition shown to the function-calling model.

    Args:
        tool: object exposing ``name``, ``description`` and an ``args_schema``
            whose ``schema()`` returns a dict.

    Returns:
        A JSON string (indent 2) with the keys ``name``, ``description`` and
        ``parameters``; ``parameters`` is the argument schema without its
        top-level ``title`` entry.
    """
    schema = tool.args_schema.schema()
    parameters = {key: value for key, value in schema.items() if key != 'title'}
    definition = {
        'name': tool.name,
        # Drop a leading "<signature> - " prefix if there is one. Splitting
        # only once keeps any later " - " that belongs to the description
        # itself, which the previous split(" - ")[-1] cut away.
        'description': tool.description.split(" - ", 1)[-1],
        'parameters': parameters,
    }
    return json.dumps(definition, indent=2)


# registered tools, their JSON definitions (joined by blank lines) for the
# function-calling prompt, and the executor that runs them
tools = [image2txt, txt2image]
tool_descriptions = "\n\n".join(map(tool_to_definition, tools))
tool_executor = ToolExecutor(tools)


# output format for function calling chain
# Passed to JsonOutputParser as pydantic_object when fc_chain is built.
# NOTE(review): the field here is called `function`, while the prompt template
# asks the model for a 'name' key and call_fc_model reads response['name'] -
# confirm which spelling is intended.
class OutputFormat(BaseModel):
    function: Optional[str] = Field(description="Name of the function to call")
    arguments: Optional[Dict] = Field(
        description="Arguments or parameters to pass to the function")


# prompt template for function calling chain
# {tools} is pre-filled with tool_descriptions through partial_variables, so
# only {question} has to be supplied when the chain is invoked. The doubled
# braces in the template stand for literal braces in the rendered prompt.
fc_prompt = PromptTemplate.from_template("""SYSTEM: You are a helpful assistant with access to the following functions. Use them if required -
{tools}

The output needs to be in the following format:
{{
    'name': <function name>,
    'arguments': <arguments to pass to the function>
}}

For questions not related to image generation or not about an image, respond with an empty json object.

User: {question}
FUNCTION: """, partial_variables={"tools": tool_descriptions})

# function calling chain: prompt -> function-calling model -> JSON parser
fc_chain = fc_prompt | fc_llm | JsonOutputParser(pydantic_object=OutputFormat)

# prompt template for chat chain
text_prompt = PromptTemplate.from_template("""You are a helpful agent. 
Respond to user questions honestly and truthfully.

Human: {question}
AI: """)

# text chat chain
text_chain = text_prompt | text_llm | StrOutputParser()


# state definition for langgraph agent
class AgentState(TypedDict):
    # message list annotated with operator.add as its combining function
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # image made available to the tools; app.py stores a base64 string here
    image: Optional[str]


# node function to execute tools
def call_tool(state):
    """Run the tool requested by the last message and wrap its result.

    Args:
        state: graph state with a ``messages`` list and, optionally, an
            ``image`` entry.

    Returns:
        ``{"messages": [FunctionMessage]}`` whose content is the JSON-encoded
        tool result and whose name is the tool name, or ``{"messages": []}``
        when the last message is a HumanMessage (nothing to execute).
    """
    last_message = state['messages'][-1]

    if isinstance(last_message, HumanMessage):
        return {"messages": []}

    # Work on a copy: adding the image below must not modify the
    # additional_kwargs dict that belongs to the message stored in the state.
    tool_input = dict(last_message.additional_kwargs)

    if state.get("image") is not None:
        tool_input["image"] = state["image"]

    action = ToolInvocation(
        tool=last_message.name,
        tool_input=tool_input,
    )

    response = tool_executor.invoke(action)
    function_message = FunctionMessage(
        content=json.dumps(response), name=action.tool)
    return {"messages": [function_message]}


# node function to call function calling model
def call_fc_model(state):
    """Ask the function-calling model whether the last message needs a tool.

    Args:
        state: graph state with a ``messages`` list and, optionally, an
            ``image`` entry.

    Returns:
        ``{"messages": [AIMessage]}`` with content ``"function"``, the tool
        name as ``name`` and the tool arguments as ``additional_kwargs`` when
        the model names one of the registered tools; otherwise
        ``{"messages": [last_message]}`` so the router sees the human message.
    """
    last_message = state["messages"][-1]

    response = fc_chain.invoke({"question": last_message.content})

    # The model output is free-form JSON: it may not be an object at all, and
    # it may name a tool that does not exist.
    known_tools = [t.name for t in tools]
    if not isinstance(response, dict) or response.get("name") not in known_tools:
        return {"messages": [last_message]}

    # "arguments" can be missing, null or not an object; fall back to an empty
    # dict instead of failing with KeyError/TypeError.
    raw_args = response.get("arguments")
    args = dict(raw_args) if isinstance(raw_args, dict) else {}
    if state.get("image") is not None:
        args["image"] = state["image"]
    return {"messages": [AIMessage(name=response["name"], content="function", additional_kwargs=args)]}


# node function for chat model
def call_model(state):
    """Answer the most recent message with the text chat chain.

    Returns ``{"messages": [AIMessage]}`` holding the chain's reply.
    """
    question = state["messages"][-1].content
    answer = text_chain.invoke({"question": question})
    return {"messages": [AIMessage(content=answer)]}


# conditional logic to determine which edge to take based on last message
def is_function_call(state):
    """Return the routing key for the last message in the state.

    ``"human"`` for a HumanMessage, ``"function"`` for an AIMessage whose
    content is exactly ``"function"``, and ``"end"`` for anything else.
    """
    latest = state["messages"][-1]

    if isinstance(latest, HumanMessage):
        return "human"

    wants_tool = isinstance(latest, AIMessage) and latest.content == "function"
    return "function" if wants_tool else "end"


# langgraph agent graph definition
workflow = StateGraph(AgentState)
# three nodes: the function-calling model, the chat model and the tool runner
workflow.add_node("functions", call_fc_model)
workflow.add_node("model", call_model)
workflow.add_node("tools", call_tool)
# every run starts at the function-calling model
workflow.set_entry_point("functions")
# route on the key returned by is_function_call
workflow.add_conditional_edges(
    "functions",
    is_function_call,
    {
        "human": "model",
        "function": "tools",
        "end": END
    }
)
# both the tool runner and the chat model finish the run
workflow.add_edge("tools", END)
workflow.add_edge("model", END)
app = workflow.compile()


Overwriting graph.py


In [2]:
%%writefile app.py
import base64
import json
import streamlit as st
from graph import app, ollama_base_url, a_1111_base_url
from langchain_core.messages import AIMessage, HumanMessage, FunctionMessage, BaseMessage
from pandas.io.common import BytesIO
from PIL import Image

# page chrome
st.set_page_config(page_title="LangChain with Automatic 1111 API")
st.title("LangChain with Automatic 1111 API")

# seed the session on first load: a greeting, and no upload / image yet
if "messages" not in st.session_state:
    st.session_state.messages = [AIMessage(content="How can I help you?")]

for _key in ("uploaded_file", "image"):
    if _key not in st.session_state:
        st.session_state[_key] = None

# replay the stored conversation: messages carrying an image are drawn as an
# image (plus a "Parameters" expander when params were stored), all others as
# plain text
for past in st.session_state.messages:
    extras = past.additional_kwargs
    if "image" in extras:
        st.chat_message(past.type).image(extras["image"], width=512)
        if "params" in extras:
            with st.chat_message(past.type).expander("Parameters"):
                st.code(extras["params"])
    else:
        st.chat_message(past.type).write(past.content)

# input for the graph run; filled in below once the user submits a prompt
state = {}

# show which backends this session talks to
st.sidebar.text(f"Ollama\n{ollama_base_url}")
st.sidebar.text(f"Automatic 1111\n{a_1111_base_url}")

uploaded_file = st.sidebar.file_uploader("Upload an image file",
                                         type=["jpg", "png"])

# only react when the upload differs from the one already recorded
if uploaded_file and st.session_state.uploaded_file != uploaded_file:
    st.session_state.uploaded_file = uploaded_file

    # keep the upload as base64 text so it can be handed to the tools
    raw_bytes = uploaded_file.getvalue()
    st.session_state.image = base64.b64encode(raw_bytes).decode()

    upload_message = HumanMessage(
        content=uploaded_file.name,
        additional_kwargs={"image": uploaded_file},
    )
    st.session_state.messages.append(upload_message)
    st.chat_message("user").image(uploaded_file, width=512)

# handle a newly submitted chat prompt: run the graph and render its last
# message according to its type
if prompt := st.chat_input():
    human_message = HumanMessage(content=prompt)
    state["messages"] = [human_message]
    st.session_state.messages.append(human_message)
    st.chat_message("human").write(prompt)

    # pass the current image (uploaded or last generated) along, if any
    if st.session_state.image is not None:
        state["image"] = st.session_state.image

    response = app.invoke(state)
    last_message = response["messages"][-1]

    if isinstance(last_message, AIMessage):
        # plain chat answer
        st.chat_message("assistant").write(last_message.content)
        st.session_state.messages.append(last_message)
    elif isinstance(last_message, FunctionMessage):
        if last_message.name == "image2txt":
            # The tool result arrives JSON-encoded (call_tool uses json.dumps).
            # Decode it properly so escapes such as \n or \" are rendered as
            # text; fall back to stripping the quotes if it is not a JSON
            # string.
            raw = str(last_message.content)
            try:
                decoded = json.loads(raw)
            except json.JSONDecodeError:
                decoded = None
            last_message.content = decoded if isinstance(decoded, str) else raw.strip('"')
            st.chat_message(last_message.name).write(last_message.content)
            st.session_state.messages.append(last_message)
        elif last_message.name == "txt2image":
            content = json.loads(str(last_message.content))
            images = content.get("images") if isinstance(content, dict) else None
            if not images:
                # e.g. an error payload from the image backend
                st.chat_message(last_message.name).error(
                    "Image generation did not return any images.")
            else:
                imageb64 = images[0]
                raw_params = content.get("parameters") or {}
                # show only the parameters that carry a meaningful value
                params = json.dumps({p: raw_params[p] for p in raw_params if (
                    raw_params[p] is not None and
                    raw_params[p] != 0 and
                    raw_params[p] is not False and
                    raw_params[p] != "" and
                    raw_params[p] != [] and
                    raw_params[p] != {}
                )}, indent=2)
                image = Image.open(BytesIO(base64.b64decode(imageb64)))
                st.chat_message(last_message.name).image(image, width=512)
                with st.chat_message(last_message.name).expander("Parameters"):
                    st.code(params)
                st.session_state.messages.append(
                    BaseMessage(
                        name=last_message.name,
                        type=last_message.name,
                        content="",
                        additional_kwargs={
                            "params": params,
                            "image": image,
                        }
                    )
                )
                # the generated image becomes the image later questions refer to
                st.session_state.image = imageb64


Overwriting app.py


### Download and run ollama

Below, we:
1. download the ollama binary
2. make it executable
3. start ollama in the background
4. download the hosted bakllava, llama2, and function-calling models

In [3]:
%%capture
!curl -L https://ollama.ai/download/ollama-linux-amd64 -o ollama
!chmod +x ollama
!./ollama serve &>/content/ollama_logs.txt &
# the server was only just started in the background; give it a moment to come
# up before the pull commands below try to reach it
!sleep 5
!./ollama pull bakllava
!./ollama pull llama2
!./ollama pull klcoder/mistral-7b-functioncall

### Start the Streamlit app in the background

In [None]:
# run app.py with streamlit as a background job, sending its output to
# /content/logs.txt
!streamlit run app.py &>/content/logs.txt &

## Find the IP of your instance

In [5]:
# fetch the address reported by ipv4.icanhazip.com and print a reminder of
# where to paste it
!curl ipv4.icanhazip.com
!echo "Copy this IP into the webpage that opens below"

35.227.3.82
Copy this IP into the webpage that opens below


## Expose the Streamlit app on port 8501

In [None]:
# run localtunnel through npx for local port 8501; it prints the public url
!npx localtunnel --port 8501

[K[?25hnpx: installed 22 in 3.73s
your url is: https://common-mails-doubt.loca.lt
