
    PyMuPDF: For extracting text and images from PDFs.
    Gemini 2.0 Flash model: To summarize images and tables.
    Hugging Face sentence-transformers (all-MiniLM-L6-v2): For embedding document splits.
    Chroma Vectorstore: To store and retrieve document embeddings.
    Cohere Command R+: To generate answers from the retrieved context.
    LangChain: To orchestrate the retrieval and generation pipeline.


In [45]:
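# %pip install langchain langchain-community langchain-text-splitters langchain-huggingface langchain-cohere chromadb sentence-transformers tiktoken google-generativeai pillow pymupdf python-dotenv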

In [46]:
import io
import os

import fitz  # PyMuPDF
import google.generativeai as genai
from dotenv import load_dotenv
from PIL import Image
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# Load API keys (e.g. GOOGLE_API_KEY, COHERE_API_KEY) from a local .env file
load_dotenv()

True

In [47]:
text_data = []
img_data = []

In [49]:
# Create a directory to store the images
os.makedirs("extracted_images", exist_ok=True)

with fitz.open('training_documents/Transfer Learning.pdf') as pdf_file:
    # Loop through every page in the PDF (page numbers are 1-based)
    for page_number, page in enumerate(pdf_file, start=1):
        # Keep the page text, tagged with its page number
        text = page.get_text().strip()
        text_data.append({"response": text, "name": page_number})

        # Loop through all images embedded on the page
        for image_index, img in enumerate(page.get_images(full=True), start=1):
            xref = img[0]  # cross-reference (xref) number of the image object
            base_image = pdf_file.extract_image(xref)
            image_bytes = base_image["image"]  # raw image bytes
            image_ext = base_image["ext"]  # original format, e.g. "png" or "jpeg"

            # Write the bytes unchanged; re-encoding through PIL may fail for formats it cannot save
            with open(f"extracted_images/image_{page_number}_{image_index}.{image_ext}", "wb") as out:
                out.write(image_bytes)
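The page-text entries are keyed by page number, while each image summary will only carry its file name. Because the extraction loop encodes the page and position in the name (`image_<page>_<index>.<ext>`), a small helper can recover the page so figure summaries and page text can be linked. `parse_image_name` is a hypothetical helper, not part of the original notebook, and it assumes exactly that naming pattern.

```python
import re
from pathlib import Path

# Matches names produced by the extraction cell: image_<page>_<index>.<ext>
IMAGE_NAME_RE = re.compile(r"^image_(\d+)_(\d+)\.(\w+)$")


def parse_image_name(filename: str) -> dict:
    """Recover page number, image index and extension from an extracted image's file name."""
    match = IMAGE_NAME_RE.match(Path(filename).name)  # ignore any leading directory
    if match is None:
        raise ValueError(f"unexpected image file name: {filename!r}")
    page, index, ext = match.groups()
    return {"page": int(page), "index": int(index), "ext": ext}


print(parse_image_name("extracted_images/image_3_2.png"))
# {'page': 3, 'index': 2, 'ext': 'png'}
```

For example, `parse_image_name(img)["page"]` could be stored as an extra metadata field when the image summaries are wrapped as `Document`s.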
        

In [50]:
api_key = os.getenv('GOOGLE_API_KEY')

genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name="gemini-2.0-flash")

In [51]:
SUMMARY_PROMPT = (
    "You are an AI assistant helping build a retrieval system from academic papers. "
    "The input is a table or figure image extracted from a paper. "
    "Summarize the image with reference to the core topic or claim being visualized. "
    "Include comparisons, axes, legends, and what this visual proves or supports in the context of the paper. "
    "Your summary will be embedded and must serve as a high-quality retrieval chunk. "
    "Be specific, concise, and factually grounded."
)

# Summarize each extracted image; sorted() gives a stable, reproducible order
for img in sorted(os.listdir("extracted_images")):
    with Image.open(os.path.join("extracted_images", img)) as image:
        response = model.generate_content([image, SUMMARY_PROMPT])
    img_data.append({"response": response.text, "name": img})
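Each run of the summarization loop sends every extracted image to Gemini again, which costs time and API quota. A minimal sketch of a content-hash cache, assuming summaries may be stored in a local JSON file; `cached_summaries` and its `summarize` callback are hypothetical names, and the callback stands in for the `model.generate_content(...)` call.

```python
import hashlib
import json
from pathlib import Path
from typing import Callable


def cached_summaries(image_dir: str, cache_file: str,
                     summarize: Callable[[Path], str]) -> dict[str, str]:
    """Summarize each image at most once; reuse stored summaries on later runs."""
    cache_path = Path(cache_file)
    cache = json.loads(cache_path.read_text()) if cache_path.exists() else {}
    results = {}
    for path in sorted(Path(image_dir).iterdir()):
        if not path.is_file():
            continue
        # Key on the file contents, so identical images are summarized only once
        digest = hashlib.sha256(path.read_bytes()).hexdigest()
        if digest not in cache:
            cache[digest] = summarize(path)
        results[path.name] = cache[digest]
    cache_path.write_text(json.dumps(cache, indent=2))  # persist for the next run
    return results
```

In the notebook, `summarize` could open the image with PIL and return `model.generate_content([image, prompt]).text`; the returned dict maps each file name to its summary, which is the shape `img_data` needs. Hashing the bytes also means an image repeated on several pages (a logo, say) is sent to the model only once.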

In [52]:
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {'device': 'cuda'}  # switch to 'cpu' if no CUDA GPU is available
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

# Wrap each page's text and each image summary in a LangChain Document
docs_list = [Document(page_content=text['response'], metadata={"name": text['name']}) for text in text_data]
img_list = [Document(page_content=img['response'], metadata={"name": img['name']}) for img in img_data]

# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=400, chunk_overlap=50
)

doc_splits = text_splitter.split_documents(docs_list)
img_splits = text_splitter.split_documents(img_list)
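Here `chunk_size=400` and `chunk_overlap=50` are measured in tiktoken tokens: each chunk holds at most 400 tokens, and neighbouring chunks may share up to 50. `RecursiveCharacterTextSplitter` actually tries paragraph, line and word separators first and merges the pieces up to the limit, so its boundaries are not fixed windows; the sketch below only illustrates the size and overlap arithmetic with a plain sliding window over a token list. One caveat worth checking: the all-MiniLM-L6-v2 model card states that input longer than 256 word pieces is truncated, so the tail of a 400-token chunk may not contribute to its embedding (tiktoken and WordPiece counts differ, so this is approximate).

```python
def sliding_window_chunks(tokens: list[str], chunk_size: int, chunk_overlap: int) -> list[list[str]]:
    """Fixed-size windows in which consecutive chunks share `chunk_overlap` tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
    return chunks


tokens = [f"t{i}" for i in range(1000)]
chunks = sliding_window_chunks(tokens, chunk_size=400, chunk_overlap=50)
print([(c[0], c[-1]) for c in chunks])
# [('t0', 't399'), ('t350', 't749'), ('t700', 't999')]
```

With a 1,000-token page the windows advance by 350 tokens, so three chunks cover it and the last one is shorter than 400.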

In [53]:
# Add both the text and image-summary splits to the vector store
vectorstore = Chroma.from_documents(
    documents=doc_splits + img_splits,
    collection_name="multi_modal_rag",
    embedding=embeddings,
)

retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 1},  # number of chunks to retrieve
)
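With `search_type="similarity"` and `k=1`, the retriever returns the single chunk whose embedding is nearest to the query embedding. A brute-force sketch of that ranking, using plain Python lists as stand-in vectors: it normalizes both sides and ranks by squared L2 distance. For unit vectors, ||a - b||^2 = 2 - 2 cos(a, b), so this ordering is the same as ranking by cosine similarity. That is relevant to `encode_kwargs={'normalize_embeddings': False}`: Chroma collections use L2 distance by default, and on unnormalized vectors L2 and cosine rankings can disagree. Whether it changes anything here depends on whether the model already returns unit-length vectors.

```python
import math


def l2_normalize(vector: list[float]) -> list[float]:
    """Scale a vector to unit length (raises on the zero vector)."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def top_k(query: list[float], corpus: dict[str, list[float]], k: int = 1) -> list[tuple[str, float]]:
    """Return the k corpus ids with the smallest squared L2 distance to the query, after normalization."""
    q = l2_normalize(query)
    scored = []
    for doc_id, vector in corpus.items():
        v = l2_normalize(vector)
        squared_l2 = sum((a - b) ** 2 for a, b in zip(q, v))
        scored.append((doc_id, squared_l2))
    scored.sort(key=lambda pair: pair[1])  # closest first
    return scored[:k]


# Toy 2-D "embeddings" (hypothetical values, not real MiniLM vectors)
corpus = {"page_5": [1.0, 0.0], "image_3_1.png": [0.6, 0.8]}
print(top_k([0.8, 0.6], corpus, k=1))
```

In this toy case the image-summary vector wins with a squared distance of about 0.08, which is 2 - 2 x 0.96, matching its cosine similarity of 0.96 with the query.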

In [57]:
query = (
    "Interpret Figure 3 from the Transfer Learning PDF: "
    "what do the boxes and arrows represent, and how do they illustrate knowledge transfer "
    "from the source to the target domain?"
)  
docs = retriever.invoke(query)
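The query names Figure 3 explicitly, but embedding similarity alone has no notion of figure numbers, and with `k=1` only one candidate survives. A possible lexical tweak, sketched under the assumption that a few more candidates are retrieved first: move chunks that mention the requested figure number to the front. `referenced_figures` and `rerank_by_figure` are hypothetical helpers, and an image summary only matches if the summary text itself mentions the figure number.

```python
import re

# Matches "Figure 3", "Fig. 3" or "fig 3" (case-insensitive)
FIGURE_RE = re.compile(r"\b(fig(?:ure)?\.?)\s*(\d+)", re.IGNORECASE)


def referenced_figures(text: str) -> set[int]:
    """Return the set of figure numbers mentioned in a piece of text."""
    return {int(match.group(2)) for match in FIGURE_RE.finditer(text)}


def rerank_by_figure(query: str, chunks: list[str]) -> list[str]:
    """Stable re-sort: chunks mentioning a figure the query asks about come first."""
    wanted = referenced_figures(query)
    if not wanted:
        return list(chunks)  # nothing to boost; keep the retriever's order
    # False sorts before True, so matching chunks move to the front
    return sorted(chunks, key=lambda chunk: not (referenced_figures(chunk) & wanted))


candidates = ["Table 1 shows accuracy", "Figure 3: Text prompt template", "see Fig 4"]
print(rerank_by_figure("Interpret Figure 3", candidates))
```

Because `sorted` is stable, chunks that do not mention the figure keep their original similarity order behind the matches.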

In [58]:
print(docs[0].page_content)
print(docs[0].metadata)

A PREPRINT - JANUARY 14, 2025
Figure 3: Text prompt template for transfer learning using DistilGPT2 [22].
(a) All LLM weights are trainable
(b) All LLM weights are frozen
Figure 4: Convergence plots of proposed transfer learning of tabular data using a large language model (LLM)
Although in-context learning via FeatLLM uses one of the most up-to-date LLMs, its performance may be significantly
limited by the maximum token size. In-context learning via LLM API does not involve any finetuning or transfer
learning. Without transfer learning, text input prompts (0.733 (0.021)) are able to achieve a better performance than the
best overall GBT model (0.711(0.039)) on the blood transfusion data set. The shot and sample ratio (ssr) is 60% for this
data set. The performance of in-context learning is on par with GBT on the dermatology (ssr: 11.5%), breast cancer
(ssr: 8.4%), and diabetes (31%) data sets. The LLM may have some prior knowledge about these medical domains or
data sets, which may ha
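The printed chunk appears to be the page text that contains the Figure 3 caption rather than a summary of a figure image, and only one chunk is returned (`k=1`). When several chunks are retrieved, labelling each with its metadata helps the model (and the reader) tell page text from image summaries. `format_docs` is a hypothetical helper; the `Doc` dataclass is a stand-in so the sketch runs without LangChain, but the function only reads `page_content` and `metadata`, which LangChain `Document` objects also provide.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Minimal stand-in for a LangChain Document: text plus a metadata dict."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_docs(docs: list[Doc]) -> str:
    """Join retrieved chunks into one context string, labelling each with its source."""
    blocks = []
    for i, doc in enumerate(docs, start=1):
        source = doc.metadata.get("name", "unknown")  # page number or image file name
        blocks.append(f"[{i}] source: {source}\n{doc.page_content.strip()}")
    return "\n\n".join(blocks)


print(format_docs([
    Doc("Figure 3: Text prompt template ...", {"name": 5}),
    Doc("Summary of a figure image ...", {"name": "image_5_1.png"}),
]))
```

Passing `format_docs(docs)` as the `"documents"` input of the chain would give the model every retrieved chunk together with its source label.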

In [59]:
from langchain_cohere import ChatCohere  # reads COHERE_API_KEY from the environment

# Prompt
system = """You are an assistant for question-answering tasks based on academic papers. 
Always ground your answer strictly in the retrieved evidence below—do not use outside knowledge.
If evidence is from an image or table summary, reference the figure or table number if present."""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved documents: \n\n <docs>{documents}</docs> \n\n User question: <question>{question}</question>"),
    ]
)

# LLM
llm = ChatCohere(
    model="command-r-plus",
    temperature=0,  # greedy decoding for more reproducible answers
)

# Chain
rag_chain = prompt | llm | StrOutputParser()

# Run
generation = rag_chain.invoke({"documents": "\n\n".join(doc.page_content for doc in docs), "question": query})
print(generation)

Figure 3 illustrates the process of transfer learning using a Large Language Model (LLM). The boxes represent different components or steps in the process, and the arrows indicate the direction of knowledge transfer and the flow of information.

Here's a breakdown of the elements in the figure:

- The leftmost box represents the "Source Domain," which contains the knowledge and information from the pre-trained LLM. This box has an arrow pointing to the right, indicating the transfer of knowledge to the target domain.
- The rightmost box is the "Target Domain," which is the specific task or domain to which we want to apply the knowledge from the source domain.
- The arrows connecting the two domains represent the knowledge transfer process. The arrow from the source to the target indicates that the knowledge learned from the source domain is being applied to the target domain.
- The arrow from the target domain back to the source domain indicates that the target task also influences the