## 1. Setup Environment

Before diving into the core functionalities, let's set up our environment by importing the necessary libraries and configuring essential settings.

In [1]:
import os
import uuid
import pandas as pd
import re
import base64
import htmltabletomd
import logging
import requests
import time
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [2]:
from io import BytesIO, StringIO
from PIL import Image as PILImage
from IPython.display import display, Markdown, HTML, clear_output
from openai import OpenAIError, RateLimitError

In [3]:
# LangChain and related libraries
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_community.storage import RedisStore
from langchain_community.utilities.redis import get_client
from langchain_chroma import Chroma
from langchain_core.documents import Document

## 2. Data Loading and Preprocessing

In this section, we'll load multimodal data (PDFs containing text, tables, and images) and preprocess it for further analysis.

### 2.1. Load Multimodal Data

We'll start by locating all PDF files within the specified directory and its subdirectories. This setup ensures that we process all relevant documents while excluding hidden files and directories.

In [4]:
# Remove existing figures to ensure a clean workspace.
# ./figures is the directory the PDF loader below writes extracted images to and the
# image-summary step later reads from, so stale files from a previous run would
# otherwise be summarized again.
!rm -rf ./figures

In [5]:
# Root folder that is searched (recursively) for PDF documents
pdf_dir = './references'

# Paths of every visible PDF found under pdf_dir
pdf_files = []

for root, dirs, files in os.walk(pdf_dir):
    # Prune hidden directories in place so os.walk never descends into them
    dirs[:] = [d for d in dirs if not d.startswith('.')]
    # Keep visible files whose extension is .pdf (case-insensitive)
    pdf_files.extend(
        os.path.join(root, name)
        for name in files
        if not name.startswith('.') and name.lower().endswith('.pdf')
    )

### 2.2. Extract and Partition Text, Tables, and Images

Next, we'll extract the content from each PDF using UnstructuredPDFLoader. The loader is configured to extract text, tables, and images, and to partition the content into manageable chunks based on titles.

In [6]:
# Loader options applied identically to every PDF
loader_kwargs = dict(
    strategy='hi_res',
    extract_images_in_pdf=True,
    infer_table_structure=True,
    skip_infer_table_types=[],
    chunking_strategy="by_title",     # Section-based chunking
    max_characters=8000,              # Max size of chunks
    new_after_n_chars=4000,           # Preferred size of chunks
    combine_text_under_n_chars=2000,  # Combine smaller chunks
    mode='elements',
    image_output_dir_path='./figures',
)

# Accumulates the elements loaded from all PDFs
data = []

for pdf_file in pdf_files:
    print(f'Loading {pdf_file}')
    loader = UnstructuredPDFLoader(file_path=pdf_file, **loader_kwargs)
    data.extend(loader.load())

Loading ./references/Table with grids.pdf


In [7]:
# Partition the loaded elements by their 'category' metadata:
# 'CompositeElement' entries are text chunks, 'Table' entries are tables.
# Elements of any other category are dropped.
docs = [element for element in data if element.metadata['category'] == 'CompositeElement']
tables = [element for element in data if element.metadata['category'] == 'Table']

# Display the number of documents and tables extracted
len(docs), len(tables)

(6, 2)

In [8]:
# Replace each table's page_content with a Markdown rendering of the HTML
# stored in its 'text_as_html' metadata (easier to read and to prompt with)
for table_doc in tables:
    html_table = table_doc.metadata['text_as_html']
    table_doc.page_content = htmltabletomd.convert_table(html_table)

## 3. Connecting to the Language Model

To interact with the OpenAI language models, we'll establish a connection using the OpenAI API. You'll be prompted to enter your API key securely.

In [9]:
from getpass import getpass

# Prompt the user to enter their OpenAI API Key securely (getpass does not echo the input)
OPENAI_KEY = getpass('Enter Open AI API Key: ')
# Export the key through the environment; call_openai_api_with_image reads it back
# from OPENAI_API_KEY when building the Authorization header.
os.environ['OPENAI_API_KEY'] = OPENAI_KEY

Enter Open AI API Key:  ········


In [10]:
# Initialize the ChatOpenAI model with desired parameters.
# This instance is the model step of the summarization chain defined below.
chatgpt = ChatOpenAI(model_name='gpt-4o', temperature=0)

## 4. Generating Summaries for Multimodal Data

Summarizing the extracted data is crucial for efficient retrieval. We'll generate summaries for texts, tables, and images to optimize them for semantic retrieval.

### 4.1. Create Text and Table Summaries

Using a tailored prompt, we'll instruct the language model to generate detailed summaries of text and tables. These summaries are designed to be easily embedded and retrieved later.

In [11]:
# Instruction given to the model for every text chunk or table; the raw
# element is substituted for the {element} placeholder.
prompt_text = """
You are an assistant tasked with summarizing tables and text particularly for semantic retrieval.
These summaries will be embedded and used to retrieve the raw text or table elements.
Give a detailed summary of the table or text below that is well optimized for retrieval.
For any tables also add in a one line description of what the table is about besides the summary.
Do not add additional words like Summary: etc.

Table or text chunk:
{element}
"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Summarization pipeline: wrap the input as {"element": ...}, fill the prompt,
# call the chat model, then return the reply as a plain string.
summarize_chain = {"element": RunnablePassthrough()} | prompt | chatgpt | StrOutputParser()

In [12]:
def summarize_with_retry(docs, chain, max_retries=5):
    """
    Summarize each item in ``docs`` with ``chain``, retrying on rate limits.

    docs: iterable of inputs passed one at a time to ``chain.invoke``.
    chain: object exposing ``invoke(doc)`` that returns the summary.
    max_retries: maximum number of attempts per document.

    Returns a list with one entry per input, in order: the summary on success,
    or ``None`` when the document failed (non-rate-limit error, or the rate
    limit was still hit on the last attempt). Failures are reported via print.

    Requires ``RateLimitError`` (from the ``openai`` package) to be imported at
    module level.
    """
    default_wait = 5  # seconds to wait when the error carries no usable hint
    wait_hint = re.compile(r'Please try again in (\d+(\.\d+)?)s')

    summaries = []
    for idx, doc in enumerate(docs):
        summary = None
        for attempt in range(1, max_retries + 1):
            try:
                # Attempt to summarize the document
                summary = chain.invoke(doc)
                break
            except RateLimitError as e:
                if attempt == max_retries:
                    # No attempt left: sleeping here would only delay the next document
                    print(f"Failed to process document {idx + 1} after {max_retries} retries.")
                    break
                # Use the wait time recommended in the error message when present
                hint = wait_hint.search(str(e))
                wait_time = float(hint.group(1)) if hint else default_wait
                print(f"Rate limit hit when processing document {idx + 1}. Waiting for {wait_time} seconds before retrying.")
                time.sleep(wait_time)
            except Exception as e:
                # Any other error is not retried
                print(f"An error occurred when processing document {idx + 1}: {e}")
                break
        else:
            # Only reached when the loop body never ran (max_retries <= 0)
            print(f"Failed to process document {idx + 1} after {max_retries} retries.")
        summaries.append(summary)
    return summaries

In [13]:
# Raw contents to summarize: text chunks and Markdown tables
text_docs, table_docs = (
    [d.page_content for d in docs],
    [t.page_content for t in tables],
)

# Summarize texts first, then tables, both through the retrying helper
text_summaries, table_summaries = (
    summarize_with_retry(chunks, summarize_chain)
    for chunks in (text_docs, table_docs)
)

# Display the number of summaries generated
print(f"Generated {len(text_summaries)} text summaries and {len(table_summaries)} table summaries.")

Generated 6 text summaries and 2 table summaries.


### 4.2. Create Image Summaries

Images require special handling. We'll encode images to Base64 and generate summaries that describe their content, making them suitable for retrieval-based tasks.

In [14]:
# Function to encode images to Base64
def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a Base64 text string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

In [15]:
# Function to generate image summaries using the language model
def image_summarize(img_base64, prompt):
    """
    Ask the chat model to describe one image.

    img_base64: Base64-encoded image, sent as a JPEG data URL.
    prompt: instruction text sent alongside the image.

    Returns the content of the model's reply.
    """
    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
    }

    chat = ChatOpenAI(model="gpt-4o", temperature=0)
    reply = chat.invoke([HumanMessage(content=[text_part, image_part])])
    return reply.content

In [16]:
# Function to generate summaries for all images in a directory
def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images.

    path: Directory holding the .jpg/.jpeg files extracted by Unstructured.

    Returns ``(img_base64_list, image_summaries)``: two parallel lists, ordered
    by sorted file name. Both are empty when ``path`` does not exist (e.g. no
    PDF contained an image, so the directory was never created) or holds no
    JPEG files.
    """

    # Lists to store Base64 encoded images and their summaries
    img_base64_list = []
    image_summaries = []

    # Define the prompt for image summarization
    prompt = """You are an assistant tasked with summarizing images for retrieval.
                Remember these images could potentially contain graphs, charts or tables also.
                These summaries will be embedded and used to retrieve the raw image for question answering.
                Give a detailed summary of the image that is well optimized for retrieval.
                Do not add additional words like Summary: etc.
             """

    # A missing directory simply means there is nothing to summarize
    if not os.path.isdir(path):
        return img_base64_list, image_summaries

    # Process each JPEG in the directory (extension matched case-insensitively)
    for img_file in sorted(os.listdir(path)):
        if img_file.lower().endswith((".jpg", ".jpeg")):
            img_path = os.path.join(path, img_file)
            base64_image = encode_image(img_path)
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(base64_image, prompt))

    return img_base64_list, image_summaries

In [17]:
# Path to the directory containing extracted images
# (same location the PDF loader was given as image_output_dir_path)
IMG_PATH = './figures'

# Generate Base64 encoded images and their summaries.
# The two lists are parallel: imgs_base64[i] is described by image_summaries[i].
imgs_base64, image_summaries = generate_img_summaries(IMG_PATH)

# Display the number of images processed
print(f"Processed {len(imgs_base64)} images and generated {len(image_summaries)} summaries.")

Processed 4 images and generated 4 summaries.


## 5. Building Vector Retrievers

Vector retrievers play a pivotal role in RAG by enabling efficient and relevant information retrieval. We'll build a multimodal multi-vector retriever that indexes the text, table, and image summaries together and returns the corresponding raw content.

### 5.1. Access Embedding Model

We'll use OpenAI's embedding model to convert our summaries into vector representations suitable for retrieval.

In [18]:
# Initialize the OpenAI embedding model.
# It is passed to the Chroma vector store below as its embedding function.
openai_embed_model = OpenAIEmbeddings(model='text-embedding-3-large')

### 5.2. Create Utility Functions

Utility functions will assist in managing documents and integrating them with our vector store and document store.

In [19]:
def create_multi_vector_retriever(
    docstore, vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images, k=5
):
    """
    Build a MultiVectorRetriever that indexes summaries and returns raw contents.

    docstore: store that receives ``(doc_id, raw content)`` pairs via ``mset``.
    vectorstore: store that receives the summary Documents via ``add_documents``.
    text_summaries / texts: parallel lists of summaries and raw text chunks.
    table_summaries / tables: parallel lists of summaries and raw tables.
    image_summaries / images: parallel lists of summaries and Base64 images.
    k: number of results requested from the vector store (default 5).

    Entries whose summary is missing (``None`` or empty, e.g. a failed
    summarization) are skipped together with their content, so the two stores
    never go out of step. Returns the populated retriever.
    """
    id_key = "doc_id"

    # Initialize the MultiVectorRetriever without search_kwargs
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        id_key=id_key,
    )

    # Set the number of documents to retrieve directly on the retriever
    retriever.search_kwargs['k'] = k

    # Helper function to add documents to the retriever
    def add_documents(retriever, doc_summaries, doc_contents):
        # Pair each summary with its own content and drop pairs without a summary;
        # zip also guards against the two lists differing in length.
        pairs = [
            (summary, content)
            for summary, content in zip(doc_summaries, doc_contents)
            if summary
        ]
        if not pairs:
            return

        doc_ids = [str(uuid.uuid4()) for _ in pairs]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for doc_id, (summary, _) in zip(doc_ids, pairs)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(
            [(doc_id, content) for doc_id, (_, content) in zip(doc_ids, pairs)]
        )

    # Add text summaries and their contents
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    # Add table summaries and their contents
    if table_summaries:
        add_documents(retriever, table_summaries, tables)
    # Add image summaries and their contents
    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever

### 5.3. Initiate Vectorstores: Chroma

Chroma serves as our vector store, indexing the summaries and their embeddings for efficient retrieval.

In [20]:
# Initialize the Chroma vectorstore for multimodal data.
# Summaries of texts, tables, and images are all embedded into this one collection;
# the collection metadata selects cosine as the HNSW distance space.
chroma_db_multimodal = Chroma(
    collection_name="mm_rag",
    embedding_function=openai_embed_model,
    collection_metadata={"hnsw:space": "cosine"},
)

### 5.4. Initiate Docstores: Redis

Docstores store the raw documents corresponding to the summaries. We'll use Redis for the multimodal retriever.

**Note:** Before proceeding, ensure that Redis Stack Server is installed and running. You can set it up by executing the following commands in JupyterLab's terminal:

```bash
# 1. Import the GPG key for the Redis repository
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg

# 2. Add the Redis repository to your sources list
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] \
https://packages.redis.io/deb $(lsb_release -cs) main" | \
sudo tee /etc/apt/sources.list.d/redis.list

# 3. Update package lists
sudo apt-get update

# 4. Install Redis Stack Server
sudo apt-get install redis-stack-server

# 5. Start Redis Stack Server in the background
redis-stack-server --daemonize yes
```

In [21]:
# NOTE: both names are already imported in the setup cell; repeating them here is harmless.
from langchain_community.utilities.redis import get_client
from langchain_community.storage import RedisStore

# Initialize Redis client (expects a server listening on localhost:6379)
client = get_client('redis://localhost:6379')

# Initialize RedisStore for multimodal retriever
redis_store = RedisStore(client=client)  # Alternative stores like filestore or memorystore can also be used

### 5.5. Create Retrievers

With our vector store and document store set up, we'll create the multimodal retriever.

In [22]:
# Create the multimodal retriever.
# Each summaries list is paired positionally with the raw contents that follow it:
# text chunks, Markdown tables, and Base64-encoded images.
retriever_multimodal = create_multi_vector_retriever(
    redis_store,
    chroma_db_multimodal,
    text_summaries,
    text_docs,
    table_summaries,
    table_docs,
    image_summaries,
    imgs_base64,
)

# Display the multimodal retriever
retriever_multimodal

MultiVectorRetriever(vectorstore=<langchain_chroma.vectorstores.Chroma object at 0xffc0c898ab00>, docstore=<langchain_community.storage.redis.RedisStore object at 0xffc0cafbebf0>, search_kwargs={'k': 5})

## 6. Interactive Chat Interface for Querying

Now that our data is preprocessed and our retriever is set up, we'll proceed to create an interactive chat interface within the Jupyter notebook. This interface will allow users to input queries with both text and images and view the responses from the multimodal RAG agent.

### 6.1. Configure Logging

In [23]:
# Root logging configuration: only CRITICAL records, printed as "LEVEL: message"
logging.basicConfig(level=logging.CRITICAL, format='%(levelname)s: %(message)s')

# Silence chatty third-party loggers to reduce clutter
for noisy_name in ('openai', 'urllib3', 'requests', 'httpx'):
    logging.getLogger(noisy_name).disabled = True

# Dedicated application logger; it reports warnings and above
logger = logging.getLogger('Interactive_Chat')
logger.setLevel(logging.WARNING)

### 6.2. Define Helper Functions

In [24]:
def extract_answer(text):
    """
    Return the agent's answer with leading and trailing whitespace removed.
    """
    cleaned = text.strip()
    return cleaned

In [25]:
def encode_image_file(uploaded_file_content, max_size=(400, 400)):
    """
    Encode an uploaded image to a Base64 JPEG string after shrinking it.

    uploaded_file_content: raw bytes (or bytes-like buffer) of the uploaded file.
    max_size: (width, height) bound for the thumbnail; aspect ratio is kept.

    Returns the Base64 string, or ``None`` if the image cannot be read or
    encoded (the error is logged as a warning).
    """
    try:
        img = PILImage.open(BytesIO(uploaded_file_content))
        # JPEG cannot store alpha or palette modes (e.g. PNG with transparency);
        # without this conversion such uploads fail to save and are dropped.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        # Resize the image while maintaining aspect ratio
        img.thumbnail(max_size, PILImage.LANCZOS)
        # Save the image to a BytesIO object
        buffered = BytesIO()
        img.save(buffered, format="JPEG", quality=85)
        # Encode the image to Base64 and decode to string
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_str
    except Exception as e:
        logger.warning(f"Error encoding image: {e}")
        return None

In [26]:
def resize_base64_image(img_base64, max_size=(400, 400)):
    """
    Resize a Base64-encoded image to fit within ``max_size`` and re-encode it as JPEG.

    img_base64: Base64 string of the source image.
    max_size: (width, height) bound for the thumbnail; aspect ratio is kept.

    Returns the Base64 string of the resized JPEG. If anything fails, a warning
    is logged and the original string is returned unchanged.
    """
    try:
        img_data = base64.b64decode(img_base64)
        img = PILImage.open(BytesIO(img_data))
        # JPEG cannot store alpha or palette modes; convert so such images are
        # actually resized instead of falling back to the full-size original.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        img.thumbnail(max_size, PILImage.LANCZOS)
        buffered = BytesIO()
        img.save(buffered, format="JPEG", quality=85)
        resized_img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return resized_img_str
    except Exception as e:
        # Log any exceptions during resizing
        logger.warning(f"Error resizing image: {e}")
        return img_base64  # Return original if resizing fails

In [27]:
def looks_like_base64(sb):
    """Return True when ``sb`` consists only of Base64 alphabet characters plus up to two '=' of padding."""
    return bool(re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb))

In [28]:
def is_image_data(b64data):
    """
    Report whether Base64 data decodes to bytes that begin with a known image signature.

    The first 8 decoded bytes are compared against the JPEG, PNG, GIF, and RIFF
    (used by WebP) magic numbers. Data that cannot be decoded yields False.
    """
    known_signatures = (
        b"\xff\xd8\xff",        # jpg
        b"\x89PNG\r\n\x1a\n",   # png
        b"GIF8",                # gif
        b"RIFF",                # webp
    )
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
    except Exception:
        return False
    return header.startswith(known_signatures)

In [29]:
def detect_markdown_table(text):
    """
    Detects if the text starts with a Markdown-formatted table.

    A table is recognized when the second line of ``text`` is a header
    separator with at least two columns. Every alignment form is accepted:
    ``---``, ``:---``, ``---:`` and ``:---:``, with or without outer pipes.

    Returns True if such a separator is found, otherwise False.
    """
    lines = text.strip().split('\n')
    if len(lines) >= 2:
        # Check for header separator line (e.g., | --- | --- | or | :--- | ---: |)
        header_line = lines[1].strip()
        # Colons are optional on both sides of the dashes; previously a leading
        # colon was mandatory, so plain "| --- | --- |" tables went undetected.
        if re.match(r'^\s*\|?\s*:?-+:?\s*(\|\s*:?-+:?\s*)+\|?\s*$', header_line):
            return True
    return False

In [30]:
def split_docs_into_images_texts_tables(docs):
    """
    Sort retrieved items into images, texts, and tables.

    Each item may be a Document or raw content (str or bytes). Returns a dict
    with keys 'images' (Base64 strings), 'texts' (strings), and 'tables'
    (dicts holding 'content' and 'metadata').
    """
    images, texts, tables = [], [], []

    for doc in docs:
        # Unpack content and metadata; raw items carry no metadata
        if isinstance(doc, Document):
            content, metadata = doc.page_content, doc.metadata
        else:
            content, metadata = doc, {}

        # Byte content is decoded to text
        if isinstance(content, bytes):
            content = content.decode('utf-8', errors='ignore')

        # A table is identified by its metadata category, an HTML <table> tag,
        # or a Markdown header separator, checked in that order
        is_table = (
            metadata.get('category', '').lower() == 'table'
            or '<table' in content.lower()
            or detect_markdown_table(content)
        )
        if is_table:
            tables.append({'content': content, 'metadata': metadata})
            continue

        # Strip a data-URL prefix so only the Base64 payload remains
        if content.startswith('data:image'):
            content = content.split(',', 1)[1]

        # Anything that is not a recognizable Base64 image is treated as text
        bucket = images if looks_like_base64(content) and is_image_data(content) else texts
        bucket.append(content)

    return {'images': images, 'texts': texts, 'tables': tables}

In [31]:
def limit_text_length(text, max_words=100):
    """
    Cap ``text`` at ``max_words`` whitespace-separated words.

    Text within the limit is returned unchanged. Longer text is cut to the
    first ``max_words`` words (joined by single spaces) followed by a short
    truncation note.
    """
    words = text.split()
    if len(words) <= max_words:
        return text
    kept = ' '.join(words[:max_words])
    return kept + '... (content truncated; see original reference for full text)'

In [32]:
def display_base64_image(img_base64):
    """
    Render a Base64-encoded image with matplotlib (6x6 figure, axes hidden).

    Any failure while decoding or drawing is logged as a warning instead of raised.
    """
    try:
        decoded_stream = BytesIO(base64.b64decode(img_base64))
        picture = PILImage.open(decoded_stream)
        plt.figure(figsize=(6,6))
        plt.imshow(picture)
        plt.axis('off')
        plt.show()
    except Exception as e:
        logger.warning(f"Error displaying image: {e}")

In [33]:
def call_openai_api_with_image(messages, max_tokens=500, model="gpt-4o", timeout=60):
    """
    Makes a direct API call to OpenAI's Chat Completion endpoint with structured messages.

    messages: list of chat messages (dicts with "role" and "content").
    max_tokens: upper bound on tokens in the completion.
    model: name of the chat model to query.
    timeout: seconds to wait for the HTTP request before giving up, so a
        stalled connection cannot block the calling UI callback forever.

    Returns the decoded JSON response, or ``None`` when the API key is missing
    or the request fails (the reason is logged).
    """
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        logger.error("OpenAI API key not set in environment variables.")
        return None

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0,
        "top_p": 1
    }

    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        # `response` is always bound here: HTTPError is raised by raise_for_status()
        logger.warning(f"HTTP error occurred: {http_err} - Response: {response.text}")
    except Exception as err:
        # Covers timeouts, connection failures, and undecodable response bodies
        logger.warning(f"Other error occurred: {err}")
    return None

### 6.3. Define RAG Functions for Interactive Use

We'll define the RAG function that backs the interactive interface. It retrieves the relevant documents, keeps track of the conversation so far, and interacts with the language model to generate answers.

In [34]:
# Add a global conversation memory.
# Each entry is a dict describing one turn, with "user" and "agent" keys and an
# optional "user_image" (Base64) key; rag_qa appends to it and the chat history
# display reads from it.
conversation_history = []

In [35]:
def rag_qa(question, image_base64=None):
    """
    Answer ``question`` with the multimodal RAG pipeline and record the turn.

    question: the user's question text.
    image_base64: optional Base64 image attached to the question.

    Retrieves documents, splits them into texts, tables, and images, sends them
    with the question and prior conversation to the chat model, and appends the
    turn to ``conversation_history``.

    Returns ``(answer_text, sources)``. On failure ``answer_text`` is
    'Error generating response.'; the turn is still recorded (flagged with
    ``"error": True``) so the chat history shows it, but flagged turns are
    left out of the context sent with later questions.
    """
    # Ensure we reference the global conversation_history
    global conversation_history

    def record_turn(agent_text, failed=False):
        # Store one turn; the "error" flag marks answers that are not real model output
        turn = {
            "user": question,
            "agent": agent_text
        }
        if image_base64:
            turn["user_image"] = image_base64
        if failed:
            turn["error"] = True
        conversation_history.append(turn)

    # Build an 'all_history' string of user Q and agent A from conversation_history
    conversation_context = ""
    for turn in conversation_history:
        if turn.get("error"):
            continue  # do not feed placeholder error answers back to the model
        conversation_context += f"User: {turn['user']}\n"
        conversation_context += f"Agent: {turn['agent']}\n\n"

    # Retrieve relevant documents.
    # NOTE(review): the retriever already has search_kwargs['k'] set when it is built;
    # confirm whether invoke() honors this extra k argument or ignores it.
    retrieved_docs = retriever_multimodal.invoke(question, k = 5)

    # Split documents into images, texts, and tables
    sources = split_docs_into_images_texts_tables(retrieved_docs)

    # Limit text sources to avoid exceeding token limits
    sources['texts'] = [limit_text_length(text) for text in sources['texts']]

    # Build the context text
    formatted_texts = "\n".join(sources['texts'])
    context_text = f"Context documents:\n{formatted_texts}"

    # Include conversation_context in the final prompt
    prompt_text = f"""{conversation_context}
You are a friendly and engaging assistant that answers questions based solely on the input provided so far.
Above is the conversation so far, followed by the user's new question.
Use any relevant context documents, images, and tables to answer thoroughly.

New user question:
{question}

{context_text}
"""

    # Prepare the messages
    messages = []

    # Include the question image if provided
    if image_base64:
        resized_image_base64 = resize_base64_image(image_base64)
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "User question includes an image."
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{resized_image_base64}"}
                }
            ]
        })

    # Add images from retrieved sources
    for image_data in sources['images']:
        resized_img_base64 = resize_base64_image(image_data)
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{resized_img_base64}"}
                }
            ]
        })

    # Add tables from retrieved sources
    for table_dict in sources['tables']:
        table_content = table_dict['content']
        messages.append({
            "role": "user",
            "content": table_content
        })

    # Add the main prompt
    messages.append({
        "role": "user",
        "content": prompt_text
    })

    # Make the API call
    response_json = call_openai_api_with_image(messages, max_tokens=500)

    if not response_json:
        answer_text = 'Error generating response.'
        record_turn(answer_text, failed=True)
        return answer_text, sources

    # Extract the answer
    try:
        answer_text = response_json['choices'][0]['message']['content'].strip()
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        # TypeError/AttributeError cover an unexpected shape or a null content field
        logger.warning(f"Error parsing response: {e} - Response: {response_json}")
        answer_text = 'Error generating response.'
        record_turn(answer_text, failed=True)
        return answer_text, sources

    # Update conversation_history with the new turn
    record_turn(answer_text)

    return answer_text, sources

### 6.4. Create the Interactive Interface

We'll set up the interactive widgets and define a function to handle user input and display the outputs.

In [36]:
def display_retrieved_sources(sources):
    """
    Display the retrieved sources grouped as texts, images, then tables.

    sources: dict with keys 'texts' (strings), 'images' (Base64 strings), and
        'tables' (dicts holding a 'content' string).

    HTML tables are rendered as DataFrames. Other table content, or HTML that
    cannot be parsed, is rendered as Markdown, falling back to plain text.
    """
    # Display Text Sources
    for i, text in enumerate(sources['texts']):
        display(Markdown(f"**Text Source {i+1}:**"))
        display(Markdown(text))
        print()  # Add a blank line for spacing

    # Display Image Sources
    for i, img_base64 in enumerate(sources['images']):
        display(Markdown(f"**Image Source {i+1}:**"))
        display_base64_image(img_base64)
        print()  # Add a blank line for spacing

    # Display Table Sources
    for i, table_dict in enumerate(sources['tables']):
        display(Markdown(f"**Table Source {i+1}:**"))
        table_content = table_dict['content']

        # Only attempt HTML parsing when the content actually contains a <table> tag
        parsed_tables = []
        if '<table' in table_content.lower():
            try:
                parsed_tables = pd.read_html(StringIO(table_content))
            except (ValueError, ImportError) as e:
                # ValueError: no table found; ImportError: no HTML parser installed
                logger.warning(f"Could not parse table as HTML: {e}")

        if parsed_tables:
            for table in parsed_tables:
                display(table)
        else:
            # Render as a Markdown table
            try:
                display(Markdown(table_content))
            except Exception:
                # If all else fails, display the raw content
                print(table_content)
        print()

In [37]:
# Output area that holds the rendered conversation
chat_history_output = widgets.Output()

def update_chat_history_display():
    """
    Redraw the chat history widget from ``conversation_history``.

    Shows a placeholder when there are no turns. Otherwise each turn is shown
    as the user's message, the user's image if one was attached, and the
    agent's reply.
    """
    global conversation_history
    with chat_history_output:
        clear_output()
        if conversation_history:
            # No line between turns, show user image if any
            for turn_number, turn in enumerate(conversation_history, start=1):
                display(Markdown(f"**User (Turn {turn_number}):** {turn['user']}"))
                if "user_image" in turn:
                    # display the user's uploaded image in chat history
                    display_base64_image(turn["user_image"])
                display(Markdown(f"**Agent:** {turn['agent']}"))
        else:
            display(Markdown("_No conversation history yet._"))

In [38]:
# Create text input for the user's question.
# on_submit reads (and strips) this widget's value when Submit is clicked.
question_input = widgets.Textarea(
    value='',
    placeholder='Type your question here...',
    description='Question:',
    layout=widgets.Layout(width='80%', height='80px')
)

In [39]:
# Create file upload widget for the user's image.
# Handlers below read the first entry of image_upload.value and its 'content' field.
image_upload = widgets.FileUpload(
    accept='image/*',  # Accept images only
    multiple=False,
    description='Upload Image'
)

In [40]:
# Create a button to clear the uploaded image (wired to on_clear_image_clicked below)
clear_image_button = widgets.Button(
    description='Clear Image',
    button_style='warning'
)

In [41]:
# Create output widget to display the uploaded image (a preview shown before submitting)
uploaded_image_output = widgets.Output()

In [42]:
# Function to handle changes in the image upload widget
def on_image_upload_change(change):
    """
    Refresh the upload preview whenever the upload widget's value changes.

    The preview area is always cleared first. If a file is present, it is
    encoded and displayed; an image that fails to encode shows nothing.
    """
    with uploaded_image_output:
        uploaded_image_output.clear_output()
        if not image_upload.value:
            # Nothing uploaded: leave the preview empty
            return
        first_upload = image_upload.value[0]
        image_base64 = encode_image_file(first_upload['content'])
        if image_base64:
            display_base64_image(image_base64)

In [43]:
# Observe changes to the image upload widget: on_image_upload_change runs
# whenever the widget's 'value' trait changes
image_upload.observe(on_image_upload_change, names='value')

In [44]:
# Function to clear the uploaded image
def on_clear_image_clicked(b):
    """Reset the upload widget to empty and clear the image preview.

    b: the clicked button (unused).
    """
    image_upload.value = ()
    with uploaded_image_output:
        uploaded_image_output.clear_output()

In [45]:
# Link the clear image button to the function (runs on every click)
clear_image_button.on_click(on_clear_image_clicked)

In [46]:
# Create a button to submit the query (wired to on_submit below)
submit_button = widgets.Button(
    description='Submit',
    button_style='success'
)

In [47]:
# Create output widget to display the results
# (validation messages, retrieved sources, and the memory-cleared notice)
output = widgets.Output()

In [48]:
# Button that wipes the conversation memory
clear_memory_button = widgets.Button(description='Clear Memory', button_style='danger')

def on_clear_memory_clicked(b):
    """
    Empty the conversation history, confirm in the results area, and redraw the chat history.

    b: the clicked button (unused).
    """
    global conversation_history
    # Empty the list in place so every reference to it sees the reset
    del conversation_history[:]
    with output:
        clear_output()
        display(Markdown("**Memory has been cleared.**"))
    update_chat_history_display()

clear_memory_button.on_click(on_clear_memory_clicked)

In [49]:
# Define the function to handle the submission
def on_submit(button):
    """
    Handle a click on Submit: run the RAG query and refresh the display.

    button: the clicked button (unused).

    Reads the question and optional uploaded image, calls ``rag_qa``, shows the
    retrieved sources in the results area, resets the upload widget, and
    redraws the chat history. An empty question only shows a prompt to enter
    one. If ``rag_qa`` did not record the turn in ``conversation_history``, the
    returned answer is shown in the results area so it is not lost.
    """
    with output:
        clear_output()
        # Get the user's question
        question_text = question_input.value.strip()

        # Check if a question was entered
        if not question_text:
            display(Markdown("**Please enter a question.**"))
            return

        # Get the uploaded image if any (None when absent or when encoding fails)
        image_base64 = None
        if image_upload.value:
            uploaded_file = image_upload.value[0]
            image_base64 = encode_image_file(uploaded_file['content']) or None

        # Multimodal RAG agent
        turns_before = len(conversation_history)
        mm_answer, mm_sources = rag_qa(question_text, image_base64)

        # The answer is normally shown through the chat history; if no turn was
        # recorded (e.g. the API call failed), surface it here instead.
        if len(conversation_history) == turns_before:
            display(Markdown(f"**Agent:** {mm_answer}"))

        # Single line separator before Retrieved Sources
        display(Markdown("---"))
        display_retrieved_sources(mm_sources)

        # Clear the image upload for the next query
        image_upload.value = ()
        with uploaded_image_output:
            uploaded_image_output.clear_output()

    update_chat_history_display()

In [50]:
# Link the submit button to the function (runs on every click)
submit_button.on_click(on_submit)

### 6.5. Test the Interactive Chat Interface

In [51]:
# Display the widgets, stacked top to bottom: question box, upload controls,
# image preview, action buttons, chat history, then the results area
ui_layout = widgets.VBox([
    question_input,
    widgets.HBox([image_upload, clear_image_button]),
    uploaded_image_output,
    widgets.HBox([submit_button, clear_memory_button]),
    chat_history_output,
    output
])

display(ui_layout)

VBox(children=(Textarea(value='', description='Question:', layout=Layout(height='80px', width='80%'), placehol…