In [46]:
import fitz  # PyMuPDF
from PIL import Image
import io
import os
from dotenv import load_dotenv

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_cohere import ChatCohere, CohereEmbeddings

# Load environment variables from a local .env file; presumably this supplies the
# Google Gemini and Cohere API keys used by the clients below — TODO confirm names.
load_dotenv()

True

In [47]:
# Path of the source PDF, and the per-page text records collected from it.
data = "data/attention_is_all_you_need.pdf"
text_data = list()

In [48]:
# Extract each page's text and every embedded image from the PDF.
# ``text_data`` is rebuilt here so that re-running this cell does not append a
# second copy of every page to the list left over from a previous run.
text_data = []

with fitz.open(data) as pdf_file:
    # Directory that receives the extracted image files.
    os.makedirs("extracted_images", exist_ok=True)

    for page_number, page in enumerate(pdf_file, start=1):
        # Page text, keyed by its 1-based page number.
        text_data.append({"response": page.get_text().strip(), "name": page_number})

        for image_index, img in enumerate(page.get_images(full=True), start=1):
            xref = img[0]  # first tuple entry is the image's XREF
            base_image = pdf_file.extract_image(xref)
            # Write the raw bytes in their native format (``ext``) rather than
            # decoding and re-encoding through PIL: PIL cannot open some PDF image
            # formats (e.g. JBIG2) and would lossily re-compress JPEGs.
            image_path = (
                f"extracted_images/image_{page_number}_{image_index}.{base_image['ext']}"
            )
            with open(image_path, "wb") as image_file:
                image_file.write(base_image["image"])

In [49]:
# Gemini chat model, used both to summarise the extracted images and to answer
# the final RAG question.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0,   # keep generations as deterministic as possible
    max_retries=2,
    timeout=None,    # request timeout explicitly left unset
)

In [50]:
import mimetypes
import base64
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser

In [51]:
def image_to_base64_uri(file_path: str) -> str:
    """Encode a file on disk as a base64 ``data:`` URI.

    The MIME type is guessed from the file name's extension only.

    Args:
        file_path: Path of the file to encode.

    Returns:
        A string of the form ``data:<mime-type>;base64,<payload>``.

    Raises:
        ValueError: If no MIME type can be guessed from ``file_path``.
    """
    guessed_type = mimetypes.guess_type(file_path)[0]
    if guessed_type is None:
        raise ValueError(f"Unable to determine file type: {file_path}")

    with open(file_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode('utf-8')

    return f"data:{guessed_type};base64,{payload}"

In [52]:
# Summarisation prompt sent alongside each extracted image. Implicit string
# concatenation is used instead of backslash continuations, which had embedded
# the continuation lines' indentation (runs of spaces) into the prompt itself.
prompt_text = (
    "You are an assistant tasked with summarizing tables, images and text for retrieval. "
    "These summaries will be embedded and used to retrieve the raw text or table elements. "
    "Give a concise summary of the table, image or text that is well optimized for retrieval. "
    "Table or text or image:"
)

In [53]:
# Summarise every extracted image with the multimodal LLM; the summaries are
# what gets embedded for retrieval later.
img_data = []

# The chain does not depend on the image, so build it once instead of per file.
chain = llm | StrOutputParser()

# os.listdir returns entries in arbitrary order; sort so that img_data (and
# therefore img_data[0]) is reproducible across runs and machines.
for img in sorted(os.listdir("extracted_images")):
    image_uri = image_to_base64_uri(f"extracted_images/{img}")

    message = HumanMessage(
        content=[
            {
                "type": "text",
                "text": prompt_text,
            },
            {
                "type": "image_url",
                "image_url": {"url": image_uri},
            },
        ]
    )
    response = chain.invoke([message])
    img_data.append({"response": response, "name": img})

In [54]:
# Spot-check one generated summary: a {"response": <summary text>, "name": <file name>} dict.
img_data[0]

{'response': "This image illustrates the Transformer neural network architecture, featuring an encoder-decoder structure. The encoder processes input embeddings combined with positional encodings through N identical layers, each containing a Multi-Head Attention block and a Feed Forward block, both followed by Add & Norm. The decoder, also with N identical layers, takes shifted output embeddings and positional encodings, utilizing a Masked Multi-Head Attention block, a Multi-Head Attention block, and a Feed Forward block, all followed by Add & Norm. The decoder's final output passes through a Linear layer and Softmax to produce output probabilities.",
 'name': 'image_3_1.png'}

In [55]:
# Embedding model used by the vector store.
embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")


def _records_to_documents(records):
    """Wrap ``{"response", "name"}`` dicts as Documents, keeping ``name`` as metadata."""
    return [
        Document(page_content=record['response'], metadata={"name": record['name']})
        for record in records
    ]


# Page texts and image summaries become Documents tagged with their source name.
docs_list = _records_to_documents(text_data)
img_list = _records_to_documents(img_data)

# Splitter whose chunk sizes are measured with a tiktoken encoder.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=400,
    chunk_overlap=50,
)

doc_splits, img_splits = (
    text_splitter.split_documents(documents) for documents in (docs_list, img_list)
)

In [60]:
# Index page-text chunks and image summaries together in one Chroma collection.
COLLECTION_NAME = "multi_model_rag"

# The default in-memory Chroma client is shared within the kernel process, so an
# existing collection with this name survives between cell runs and
# from_documents would append a second copy of every chunk. Drop it first so the
# cell is idempotent under re-execution.
Chroma(
    collection_name=COLLECTION_NAME,
    embedding_function=embedding_model,
).delete_collection()

vectorstore = Chroma.from_documents(
    documents=doc_splits + img_splits,  # text and image-summary splits together
    collection_name=COLLECTION_NAME,
    embedding=embedding_model,
)

retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 1},  # number of documents to retrieve
)

In [61]:
# Question to answer from the indexed paper.
query = "What is the BLEU score of the Transformer (base model)?"
# Most similar chunks for the question (how many is set by the retriever's k).
docs = retriever.invoke(input=query)

In [62]:
# Prompt: ground the answer in the retrieved documents. The previous wording
# ("based upon your knowledge") invited the model to ignore the retrieved context.
system = """You are an assistant for question-answering tasks. Answer the question using only the retrieved documents provided.
If the documents do not contain the answer, say that you don't know.
Use three-to-five sentences maximum and keep the answer concise."""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved documents: \n\n <docs>{documents}</docs> \n\n User question: <question>{question}</question>"),
    ]
)

# Chain
rag_chain = prompt | llm | StrOutputParser()

# Pass every retrieved chunk (not just docs[0]) so that raising the retriever's k
# actually adds context; this also avoids an IndexError if nothing was retrieved.
context = "\n\n".join(doc.page_content for doc in docs)

# Run
generation = rag_chain.invoke({"documents": context, "question": query})
print(generation)

The Transformer (base model) achieved a BLEU score of 27.3 on the English-to-German (EN-DE) newstest2014 test. For the English-to-French (EN-FR) newstest2014 test, its BLEU score was 38.1.
