In [None]:
from google.colab import userdata
# Read the API keys from the Colab secrets store instead of hardcoding them.
# groq_api is exported as GROQ_API_KEY and lc_api as LANGCHAIN_API_KEY in the next cell.
groq_api = userdata.get('groq_api_key')
lc_api = userdata.get('lc_api_key')

In [None]:
import os
# Export the tracing settings and API keys in one step; the keys come from
# the Colab secrets read in the previous cell.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_API_KEY": lc_api,
        "LANGCHAIN_PROJECT": "colab-test",
        "GROQ_API_KEY": groq_api,
    }
)

In [None]:
!pip install -Uq "unstructured[all-docs]" pillow lxml
!pip install -Uq chromadb tiktoken
!pip install -Uq langchain langchain-community langchain-groq


In [None]:
# Install Poppler (a system package, hence apt-get rather than pip) for PDF processing
# NOTE(review): presumably needed by the PDF partitioning step below — confirm.
!apt-get install -y poppler-utils



# Extract the data

Extract the elements of the PDF that we will use in the retrieval process.

### Partition PDF text, images or tables

In [None]:
from unstructured.partition.pdf import partition_pdf

# The PDF and any extracted artefacts live in a single Drive folder.
# NOTE(review): this path assumes Google Drive is already mounted at
# /content/drive; no mount call is visible in this notebook — confirm.
output_path = "/content/drive/MyDrive/storypdf/"
file_path = f"{output_path}short-story.pdf"

chunks = partition_pdf(
    filename=file_path,

    # Table extraction is switched off; uncomment both lines to enable it.
    #infer_table_structure=True,               # extract tables
    #strategy='hi_res',                        # mandatory to infer tables

    # Image extraction: add 'Table' to the list to also extract images of tables.
    extract_image_block_types=["Image"],
    image_output_dir_path=output_path,         # if None, images and tables will be saved in base64
    extract_image_block_to_payload=True,       # if true, will extract base64 for API usage

    # Chunking: group elements under their title, with size limits per chunk.
    chunking_strategy="by_title",              # or 'basic'
    max_characters=10000,                      # defaults to 500
    combine_text_under_n_chars=2000,           # defaults to 0
    new_after_n_chars=6000,
)

In [None]:
# List the distinct element types returned by partition_pdf
# (a single type was observed for this PDF).
{str(type(el)) for el in chunks}

In [None]:
# Each CompositeElement contains a bunch of related elements.
# This makes it easy to use these elements together in a RAG pipeline.
# Here we inspect the original elements that were merged into the second chunk.

chunks[1].metadata.orig_elements

In [None]:
# This is what an extracted image looks like.
# It contains the base64 representation only because we set the param extract_image_block_to_payload=True
# NOTE(review): chunks[1] and chunk_images[0] assume the second chunk exists and
# holds at least one image — true for this PDF, confirm before reusing on others.

elements = chunks[1].metadata.orig_elements
chunk_images = [el for el in elements if 'Image' in str(type(el))]
chunk_images[0].to_dict()

### Separate extracted elements into text, images and tables

In [None]:
# Split the chunks into tables and texts by looking at each chunk's type name.
# The two filters are independent, mirroring two separate membership checks.
tables = [chunk for chunk in chunks if "Table" in str(type(chunk))]
texts = [chunk for chunk in chunks if "CompositeElement" in str(type(chunk))]

In [None]:
# Get the images from the CompositeElement objects

def get_images_base64(chunks):
  """Collect the base64 payload of every image nested inside the given chunks.

  Only chunks whose type name contains "CompositeElement" are inspected; within
  each one, every original element whose type name contains "Image" contributes
  its ``metadata.image_base64`` value, in document order.
  """
  return [
      el.metadata.image_base64
      for chunk in chunks
      if "CompositeElement" in str(type(chunk))
      for el in chunk.metadata.orig_elements
      if "Image" in str(type(el))
  ]

# Base64 payloads of all images found in the partitioned chunks, in document order.
images = get_images_base64(chunks)

### Check what the images look like

In [None]:
import base64
from IPython.display import Image, display

def display_base64_image(base64_code):
  """Decode a base64-encoded image and render it in the notebook output."""
  display(Image(data=base64.b64decode(base64_code)))

# Preview the first extracted image (assumes at least one image was extracted).
display_base64_image(images[0])

# Summarize the data

### text and table summaries

In [None]:
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

In [None]:
# prompt
# The single placeholder {element} receives the table or text chunk to summarize.
prompt_text = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.

Respond only with the summary, no additionnal comment.
Do not start your message by saying "Here is a summary" or anything like that.
Just give the summary as it is.

Table or text chunk: {element}

"""

prompt = ChatPromptTemplate.from_template(prompt_text)

# Summary Chain
# The leading dict maps the raw chain input onto the prompt's {element} variable;
# the model output is then reduced to a plain string.
model = ChatGroq(temperature=0.5, model="llama-3.1-8b-instant")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

In [None]:
# Summarize text
# The chunk objects are passed as-is and at most 3 requests run concurrently.
text_summaries = summarize_chain.batch(texts, {"max_concurrency": 3})

# Summarize tables
# Tables are summarized from their HTML rendering rather than from the element itself.
tables_html = [table.metadata.text_as_html for table in tables]
table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 3})

In [None]:
# Inspect the table summaries (one entry per element collected in `tables`).
table_summaries

### Image summaries

In [None]:
# Multimodal prompt: a text instruction plus the image supplied as a base64 data URL.
# NOTE(review): the next cell rebuilds an equivalent prompt/chain and recomputes
# `image_summaries`, so running both cells sends every image to the model twice.
# This cell also rebinds `prompt` and `model`, which the text-summary cells defined earlier.
prompt_template = """Describe the image in detail. For context,
              the image part of the story for kids."""
messages = [
    (
        "user",
        [
            {"type": "text", "text": prompt_template},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{image}"},
            },
        ],
    )
]

prompt = ChatPromptTemplate.from_messages(messages)

model = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct")
chain = prompt | model | StrOutputParser()


# Each raw base64 string is passed directly as the chain input for the {image} variable.
image_summaries = chain.batch(images)

In [None]:

# Step 1: multimodal prompt made of a text instruction and an image placeholder.
image_prompt_parts = [
    {"type": "text", "text": """Describe the image in detail. For context,
              the image part of the story for kids."""},
    {
        "type": "image_url",
        "image_url": {"url": "data:image/jpeg;base64,{image}"},
    },
]
prompt = ChatPromptTemplate.from_messages([("user", image_prompt_parts)])

# Step 2: vision-capable model and the chain that turns its reply into a string.
model = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct")
chain = prompt | model | StrOutputParser()

# Step 3: wrap every base64 image under the prompt's "image" key
# (strip whitespace if needed / img.strip()).
# Example: images = ["/9j/4AAQSkZJRgABAQ...", "iVBORw0KGgoAAAANSUhEUg..."]
image_dicts = [{"image": img} for img in images]

# Step 4: summarize all images in one batch call.
image_summaries = chain.batch(image_dicts)

# Step 5: print results
#for i, summary in enumerate(image_summaries):
#    print(f"\n📄 Image {i+1} Summary:\n{summary}")


In [None]:
# Spot-check the second image summary (index 1 assumes at least two images were summarized).
print(image_summaries[1])


# Load data and summaries to vectorstore

### Create the vectorstore

In [None]:
# NOTE(review): presumably required by the HuggingFace BGE embeddings created in the next cell — confirm.
%pip install --upgrade --quiet  sentence_transformers

In [None]:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# Embedding model used to index the summaries in the vector store below.
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}  # run the embedding model on CPU
encode_kwargs = {"normalize_embeddings": True}  # ask the encoder to normalize the embeddings
hf = HuggingFaceBgeEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

In [None]:
# Only needed for the GoogleGenerativeAIEmbeddings alternative, which is commented out in the next cell.
%pip install --upgrade --quiet  langchain-google-genai

In [None]:

import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
#from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever


# The vectorstore to use to index the child chunks
# (here the "child chunks" are the generated summaries; the HuggingFace embeddings are used)
#embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-exp-03-07")
embeddings = hf
vectorstore = Chroma(collection_name="multi_modal_rag", embedding_function=embeddings)

# The storage layer for the parent documents
# (in-memory only, so its contents are lost when the kernel restarts)
store = InMemoryStore()
# Metadata key that links each summary in the vectorstore to its parent in the docstore.
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

### Load the summaries and link them to the original data

In [None]:
def _index_with_summaries(summaries, originals):
    """Index `summaries` in the vectorstore and store `originals` in the docstore.

    Each original gets a fresh UUID. The i-th summary is embedded with that UUID
    under `id_key`, so a vectorstore hit can be resolved back to its original.

    Args:
        summaries: summary strings, aligned by position with `originals`.
        originals: the objects to hand back at retrieval time (text chunks,
            base64 image strings, ...).

    Returns:
        A `(ids, summary_docs)` tuple: the generated ids and the Document objects
        built from the summaries.
    """
    ids = [str(uuid.uuid4()) for _ in originals]
    summary_docs = [
        Document(page_content=summary, metadata={id_key: ids[i]})
        for i, summary in enumerate(summaries)
    ]
    # Skip empty batches (e.g. a PDF without images): there is nothing to index,
    # and a vector store may reject an empty add.
    if summary_docs:
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(ids, originals)))
    return ids, summary_docs


# Add texts
doc_ids, summary_texts = _index_with_summaries(text_summaries, texts)

# Add image summaries
img_ids, summary_img = _index_with_summaries(image_summaries, images)

# Add tables (disabled: table extraction is commented out in the partition step)
#table_ids, summary_tables = _index_with_summaries(table_summaries, tables)

### Check retrieval

In [None]:
# Retrieve
# The query is matched against the indexed summaries; the retriever is configured
# with the docstore and id_key so hits map back to the stored originals.
docs = retriever.invoke(
    "Who is the little princess?"
)

In [None]:
# Print every retrieved document followed by a dashed rule.
separator = "-" * 80
for doc in docs:
  print(f"{doc!s}\n\n{separator}")

In [None]:
from langchain_core.documents import Document
def extract_page_numbers_from_chunk(chunk):
  """Return the set of page numbers of the original elements merged into `chunk`.

  Reads ``metadata.page_number`` from every entry of
  ``chunk.metadata.orig_elements``; duplicates collapse because a set is returned.
  """
  return {element.metadata.page_number for element in chunk.metadata.orig_elements}

extract_page_numbers_from_chunk(chunks[0])

def display_chunk_pages(chunk):
  """Print the page numbers that `chunk` spans.

  The original definition had no body, which is a SyntaxError and stopped the
  whole cell from running. This is a minimal text-only implementation: it lists
  the pages in ascending order and does not render the PDF pages themselves.
  """
  # Elements may carry no page number; leave those out so sorting cannot fail.
  page_numbers = sorted(
      page for page in extract_page_numbers_from_chunk(chunk) if page is not None
  )
  print("Chunk page(s):", ", ".join(str(page) for page in page_numbers) or "unknown")


# RAG pipeline

In [None]:

from base64 import b64decode
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda



In [None]:



def parse_docs(docs):
  """Split retrieved documents into base64-encoded images and text elements.

  A document counts as an image only when it is a non-empty string (or bytes)
  that passes strict base64 validation. Everything else, including element
  objects, ordinary sentences and empty strings, is returned under "texts".

  Strict validation matters: the default, non-validating decoder silently drops
  characters outside the base64 alphabet, so a plain sentence whose letters
  happen to form a valid base64 length would be treated as an image.

  Args:
    docs: iterable of documents returned by the retriever.

  Returns:
    dict with keys "images" (list of base64 strings) and "texts" (list of the
    remaining documents), each in the original order.
  """
  images = []
  texts = []
  for doc in docs:
    is_image = False
    if isinstance(doc, (str, bytes)) and doc.strip():
      try:
        # Surrounding whitespace is tolerated; anything else that is not
        # base64 raises binascii.Error, which is a ValueError subclass.
        b64decode(doc.strip(), validate=True)
        is_image = True
      except ValueError:
        is_image = False
    if is_image:
      images.append(doc)
    else:
      texts.append(doc)
  return {"images": images, "texts": texts}

def build_prompt(kwargs):
    """Build the multimodal prompt for the final RAG answer.

    Args:
        kwargs: dict with
            "context": the output of `parse_docs`, i.e. a dict with "texts"
                and "images" lists;
            "question": the user's question.

    Returns:
        A ChatPromptTemplate wrapping one HumanMessage whose content is the text
        prompt followed by one image_url entry per retrieved image.
    """
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Concatenate the text of every retrieved text document. Objects without a
    # `.text` attribute (e.g. plain strings) are used via their str() form.
    context_text = "".join(
        getattr(text_element, "text", None) or str(text_element)
        for text_element in docs_by_type["texts"]
    )

    #construct prompt with context (including images)
    prompt_template = f"""
    Answer the question based only on the following context, which can include text, and the below image.
    Context: {context_text}
    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template}]

    for image in docs_by_type["images"]:
        prompt_content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
        )

    # Return unconditionally. The original returned from inside the
    # "images present" branch, so a retrieval without images produced None
    # and broke the chain.
    return ChatPromptTemplate.from_messages(
        [
            HumanMessage(content=prompt_content)
        ]
    )

# Plain RAG chain: question in, answer string out.
chain = (
    {"context": retriever | RunnableLambda(parse_docs), "question": RunnablePassthrough()}
    | RunnableLambda(build_prompt)
    | ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct")
    | StrOutputParser()
)

# Same pipeline, but the retrieved context and the question are kept next to
# the answer, which is added under the "response" key.
answer_step = (
    RunnableLambda(build_prompt)
    | ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct")
    | StrOutputParser()
)
chain_with_sources = {
    "context": retriever | RunnableLambda(parse_docs),
    "question": RunnablePassthrough(),
} | RunnablePassthrough().assign(response=answer_step)



In [None]:
# Ask a question through the plain RAG chain; the result is the answer string.
response = chain.invoke(
    " what is the second story about?"
)
print(response)

In [None]:
# Ask through the chain that also returns the retrieved context.
response = chain_with_sources.invoke(
    "summarise the last story and show the images"
)
print("response:", response['response'])

# Show the supporting material: each text chunk with its page number, then the images.
print("\n\ncontext:")
retrieved_context = response['context']
for text in retrieved_context['texts']:
    print(text.text)
    print("Page number: ", text.metadata.page_number)
    print("\n" + "-"*50 + "\n")
for image in retrieved_context['images']:
    display_base64_image(image)