# In [None]:
import os
import uuid
import base64
import fitz  # PyMuPDF
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from PyPDF2 import PdfReader
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage

# --- 1. MMD and PDF Processing Functions ---

def read_mmd_file(mmd_path):
    r"""
    Extract \begin{table} ... \end{table} blocks from an MMD file.

    A table starts at a line whose stripped content is ``\begin{table}`` and
    runs through the next line whose stripped content is ``\end{table}``.
    The single line that follows the end marker is kept with the table
    (the original code labels it the caption; it is stored verbatim whatever
    it contains). If that following line is itself a ``\begin{table}``
    marker, it starts a new table instead of being absorbed as a caption.

    Args:
        mmd_path: Path to the UTF-8 encoded MMD file.

    Returns:
        A list of strings, one per table, each made of the raw lines from the
        begin marker through the line after the end marker (when present).
        Returns an empty list if the file cannot be read. A table that is
        still open at end of file is discarded.
    """
    try:
        with open(mmd_path, encoding='utf-8') as f:
            lines = f.readlines()
    except Exception as e:
        # Best effort: report the problem and let the caller carry on.
        print(f"Error reading MMD file {mmd_path}: {e}")
        return []

    begin_marker = "\\begin{table}"
    end_marker = "\\end{table}"

    res = []
    tmp = []
    # OUTSIDE: not in a table; INSIDE: collecting table lines;
    # CAPTION: table closed, waiting for the one line that follows it.
    state = "OUTSIDE"
    for line in lines:
        # Compare on stripped text so trailing whitespace or a missing final
        # newline does not hide a marker.
        marker = line.strip()
        if state == "INSIDE":
            tmp.append(line)
            if marker == end_marker:
                state = "CAPTION"
        elif marker == begin_marker:
            if state == "CAPTION":
                # Previous table has no caption line; emit it on its own
                # rather than merging it with the table that starts here.
                res.append(''.join(tmp))
            tmp = [line]
            state = "INSIDE"
        elif state == "CAPTION":
            tmp.append(line)
            res.append(''.join(tmp))
            tmp = []
            state = "OUTSIDE"

    # A table closed on the very last line has no following line; keep it.
    if state == "CAPTION":
        res.append(''.join(tmp))

    return res

def process_mmd_pdf_pairs(pdf_path, mmd_path):
    """
    Extract text, tables and images for one PDF and its companion MMD file.

    1. Page text is read from the PDF with PyPDF2.
    2. Tables are read from the MMD file with read_mmd_file.
    3. Embedded images are read from the PDF with PyMuPDF (fitz) and
       base64-encoded.

    Each step is best effort: a failure is printed and the step contributes
    whatever it had collected so far (nothing, for the text step).

    Args:
        pdf_path: Path to the PDF file.
        mmd_path: Path to the MMD file that belongs to the PDF.

    Returns:
        A tuple ``(texts, tables, images, pdf_id)`` where
        - ``texts`` is a list of ``{"content": str, "metadata": {"pdf_id",
          "page_num"}}``, one per page that yielded text; ``page_num`` is the
          page's 1-based position in the PDF,
        - ``tables`` is a list of ``{"content": str, "metadata": {"pdf_id"}}``,
        - ``images`` is a list of ``{"content": base64 str, "metadata":
          {"pdf_id", "page_num", "image_index", "format"}}``,
        - ``pdf_id`` is a random UUID4 string generated for this call.
    """
    pdf_id = str(uuid.uuid4())
    texts_with_metadata = []
    tables_with_metadata = []
    images_with_metadata = []

    # --- 1. Process PDF file, extract text ---
    try:
        reader = PdfReader(pdf_path)
        page_records = []
        # Enumerate the real pages (not the filtered list) so page_num stays
        # correct when an earlier page has no extractable text, and call
        # extract_text() only once per page.
        for page_index, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text:
                page_records.append({
                    "content": page_text,
                    "metadata": {"pdf_id": pdf_id, "page_num": page_index + 1},
                })
        # Only publish the text once every page was read without error.
        texts_with_metadata.extend(page_records)
    except Exception as e:
        print(f"Error processing PDF text for {pdf_path} with PyPDF2: {e}")

    # --- 2. Process MMD file, extract tables ---
    tables = read_mmd_file(mmd_path)
    tables_with_metadata.extend(
        {"content": table, "metadata": {"pdf_id": pdf_id}} for table in tables
    )

    # --- 3. Process PDF file, extract images ---
    try:
        # The context manager closes the document even if a page fails to load.
        with fitz.open(pdf_path) as doc:
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                image_list = page.get_images(full=True)

                for img_index, img in enumerate(image_list):
                    xref = img[0]  # first entry of each image tuple is its xref
                    try:
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]
                        image_ext = base_image["ext"]

                        # Encode image bytes to base64 string
                        base64_image = base64.b64encode(image_bytes).decode('utf-8')

                        images_with_metadata.append({
                            "content": base64_image,
                            "metadata": {
                                "pdf_id": pdf_id,
                                "page_num": page_num + 1,
                                "image_index": img_index,
                                "format": image_ext,
                            },
                        })
                    except Exception as e:
                        # One unreadable image must not stop the others.
                        print(f"Error extracting image (xref={xref}) from {pdf_path}: {e}")
    except Exception as e:
        print(f"Error opening or processing PDF with PyMuPDF for {pdf_path}: {e}")

    return texts_with_metadata, tables_with_metadata, images_with_metadata, pdf_id

def process_folder(input_dir, output_dir, max_workers=4):
    """
    Process all PDF-MMD pairs in the folders in parallel.

    A PDF ``<name>.pdf`` in ``input_dir`` is paired with ``<name>.mmd`` in
    ``output_dir``; PDFs without a matching MMD file are reported and skipped.
    Each pair is handed to process_mmd_pdf_pairs on a thread pool.

    Args:
        input_dir: Directory that contains the ``.pdf`` files.
        output_dir: Directory that contains the ``.mmd`` files.
        max_workers: Maximum number of worker threads.

    Returns:
        A tuple ``(all_texts, all_tables, all_images, pdf_ids)`` with the
        per-pair results concatenated in PDF file-name order. A pair whose
        processing raises is reported and left out.
    """
    # Sort so results come back in the same order on every run
    # (os.listdir order is arbitrary).
    pdf_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.pdf'))
    # A set gives constant-time lookup when matching each PDF to its MMD.
    mmd_files = {f for f in os.listdir(output_dir) if f.endswith('.mmd')}

    print(f"Found {len(pdf_files)} PDF files and {len(mmd_files)} MMD files in the folders.")

    all_texts = []
    all_tables = []
    all_images = []
    pdf_ids = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for pdf_file in pdf_files:
            pdf_name = os.path.splitext(pdf_file)[0]
            mmd_file = f"{pdf_name}.mmd"

            if mmd_file not in mmd_files:
                print(f"No matching MMD file found for: {pdf_file}")
                continue

            pdf_path = os.path.join(input_dir, pdf_file)
            mmd_path = os.path.join(output_dir, mmd_file)
            futures.append(executor.submit(process_mmd_pdf_pairs, pdf_path, mmd_path))

        # Collect in submission order so the output order is deterministic.
        for future in tqdm(futures, desc="Processing PDF-MMD-Image pairs"):
            try:
                texts, tables, images, pdf_id = future.result()
            except Exception as e:
                print(f"Error in processing future: {e}")
                continue
            all_texts.extend(texts)
            all_tables.extend(tables)
            all_images.extend(images)
            pdf_ids.append(pdf_id)

    return all_texts, all_tables, all_images, pdf_ids

# --- 2. Main Execution Script ---

# Set input and output directories.
# input_dir is scanned for .pdf files and output_dir for the matching .mmd
# files; here both point at the same folder.
input_dir = "greenhouse/test"
output_dir = "greenhouse/test"

# Process the entire folder
texts, tables, images, pdf_ids = process_folder(input_dir, output_dir)

# Print examples of extracted results
print(f"\nProcessed {len(pdf_ids)} PDF-MMD pairs.")
print(f"Total text chunks: {len(texts)}")
print(f"Total tables: {len(tables)}")
print(f"Total images: {len(images)}")

print("\n--- Example Texts (first 3) ---")
for text in texts[:3]:
    print(f"Text (PDF_ID: {text['metadata']['pdf_id']}, Page: {text['metadata']['page_num']}): {text['content'][:100]}...")

print("\n--- Example Tables (first 3) ---")
for table in tables[:3]:
    print(f"Table (PDF_ID: {table['metadata']['pdf_id']}): {table['content'][:100]}...")

print("\n--- Example Images (first 3) ---")
for image in images[:3]:
    print(f"Image (PDF_ID: {image['metadata']['pdf_id']}, Page: {image['metadata']['page_num']}, Format: {image['metadata']['format']}): {image['content'][:50]}... (base64 data)")

# --- 3. LangChain Summarization Section ---

# Check if content was extracted, skip summarization if not.
# SystemExit is raised directly because the exit() helper is injected by the
# `site` module for interactive use and is not guaranteed to exist in scripts.
if not texts and not tables and not images:
    print("\nNo content extracted. Skipping summarization.")
    raise SystemExit

# Define the model
# Ensure you have set the OPENAI_API_KEY environment variable
# gpt-4o-mini is a multimodal model that can handle text and images
try:
    model = ChatOpenAI(temperature=0, model="gpt-4o-mini")
except Exception as e:
    print(f"Error initializing OpenAI model: {e}")
    print("Please ensure OPENAI_API_KEY is set correctly.")
    # Non-zero status so callers can tell the run failed.
    raise SystemExit(1)


# --- Text Summarization Chain ---
# Summarize every extracted text chunk with the chat model; the resulting
# list of summaries is printed and kept in text_summaries.
if not texts:
    print("\nNo texts found to summarize.")
else:
    # Prompt template with a single {element} placeholder for the chunk.
    text_prompt_text = """You are an assistant tasked with summarizing text. \
    Give a concise summary of the text chunk. Text chunk: {element}"""
    text_prompt = ChatPromptTemplate.from_template(text_prompt_text)

    # Pipeline: wrap the raw string as {"element": ...} -> prompt -> model
    # -> plain string output.
    text_summarize_chain = (
        {"element": lambda x: x}
        | text_prompt
        | model
        | StrOutputParser()
    )

    print("\n--- Generating Text Summaries ---")
    text_chunks = [entry['content'] for entry in texts]
    # At most 5 requests are in flight at once.
    text_summaries = text_summarize_chain.batch(text_chunks, {"max_concurrency": 5})
    print("Text Summaries:")
    print(text_summaries)

# --- Table Summarization Chain ---
# Summarize every extracted table with the chat model; the resulting list of
# summaries is printed and kept in table_summaries.
if not tables:
    print("\nNo tables found to summarize.")
else:
    # Prompt template with a single {element} placeholder for the table text.
    table_prompt_text = """You are an assistant tasked with summarizing tables. \
    Give a concise summary of the table by forming logical and corresponding relationships rather than broad summaries. Table chunk: {element}"""
    table_prompt = ChatPromptTemplate.from_template(table_prompt_text)

    # Pipeline: wrap the raw string as {"element": ...} -> prompt -> model
    # -> plain string output.
    table_summarize_chain = (
        {"element": lambda x: x}
        | table_prompt
        | model
        | StrOutputParser()
    )

    print("\n--- Generating Table Summaries ---")
    table_chunks = [entry['content'] for entry in tables]
    # At most 5 requests are in flight at once.
    table_summaries = table_summarize_chain.batch(table_chunks, {"max_concurrency": 5})
    print("Table Summaries:")
    print(table_summaries)

# --- (New Feature) Image Summarization Chain ---
# Summarize every extracted image with the multimodal chat model; the
# resulting list (one summary per entry of `images`, same order) is printed
# and kept in image_summaries.
if images:
    def create_image_message(base64_image, image_format="png"):
        """
        Create a message object for the multimodal model.

        Args:
            base64_image: The image bytes encoded as a base64 string.
            image_format: Image file extension used to build the data-URI
                MIME type (e.g. "png", "jpeg"). Defaults to "png".

        Returns:
            A HumanMessage holding the instruction text and the image as a
            base64 data URI.
        """
        # "jpg" is a file extension, not a registered MIME subtype.
        mime_subtype = "jpeg" if image_format == "jpg" else image_format
        return HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": "You are an assistant tasked with summarizing images. Give a concise, one-sentence summary of what this image depicts."
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{mime_subtype};base64,{base64_image}"
                    },
                },
            ]
        )

    def build_image_messages(image):
        """
        Turn one extracted-image record into the message list for the model.

        Chat models take a list of messages, not a single bare message.
        """
        image_format = image["metadata"].get("format") or "png"
        return [create_image_message(image["content"], image_format)]

    # Image summarization chain
    # Input: one image record ({"content": base64 str, "metadata": {...}})
    # 1. (build_image_messages) Convert the record to a [HumanMessage] list
    # 2. (model) Send the messages to the multimodal model
    # 3. (StrOutputParser) Parse the model's text response
    #
    # A named function is piped here on purpose: writing
    # `lambda x: f(x) | model | parser` makes the whole pipe expression the
    # lambda's body, leaving a plain function with no .batch().
    image_summarize_chain = build_image_messages | model | StrOutputParser()

    print("\n--- Generating Image Summaries ---")
    # Note: Image summarization can be slower and consume more tokens than text.
    # NOTE(review): formats the model endpoint does not accept would make the
    # request fail -- confirm which extracted formats need converting.
    # Lower concurrency to avoid rate limits.
    image_summaries = image_summarize_chain.batch(images, {"max_concurrency": 2})
    print("Image Summaries:")
    print(image_summaries)

