!pip install vllm transformers triton PyPDF2 Pillow sentence_transformers numpy typing faiss-gpu semchunk gradio docling pymupdf4llm fitz frontend tools

In [None]:
import gradio as gr
from PyPDF2 import PdfReader
import semchunk
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from vllm import LLM, SamplingParams
from typing import List, Tuple
import time
#from docling.document_converter import DocumentConverter
#import fitz
import pymupdf4llm  # PyMuPDF4LLM
import pymupdf as fitz
import os

# Ensure the local `static` directory exists. `exist_ok=True` avoids the
# check-then-create race of testing os.path.exists() first.
os.makedirs('static', exist_ok=True)

In [None]:
# Load the generation model (served through vLLM, fp16, 8192-token context)
# and the multilingual sentence-embedding model used for retrieval.
# NOTE(review): the checkpoint name suggests a LLaVA vision-language model;
# this notebook only sends it text prompts — confirm that is intended.
llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    dtype='half',
    max_model_len=8192,
)

embedder = SentenceTransformer(
    'paraphrase-multilingual-mpnet-base-v2',
)

In [None]:
# FAISS index initialization.
# Exact (brute-force) L2 index whose dimensionality matches the embedder's
# output. It is module-level, so every call that adds embeddings to it
# accumulates into the same index.
dim = embedder.get_sentence_embedding_dimension()
print(dim)
index = faiss.IndexFlatL2(dim)
def process_pdf(
    file_path: str,
    query: str = None,
    chunk_size: int = 200,
    top_k: int = 5,
) -> List[str]:
    """Convert a PDF to Markdown, chunk and embed it, and optionally retrieve chunks.

    Every chunk embedding is appended to the module-level FAISS ``index``.
    Retrieval, however, is done against this document's chunks only, so the
    returned chunks always belong to ``file_path``.

    Args:
        file_path: Path of the PDF to read.
        query: Optional question. When non-empty, only the chunks most similar
            to it are returned.
        chunk_size: Maximum chunk size, measured with the ``umarbutler/emubert``
            tokenizer that semchunk is configured with.
        top_k: Maximum number of chunks returned for a query.

    Returns:
        Up to ``top_k`` chunks nearest to ``query`` (closest first) when a
        query is given; otherwise all chunks of the document.

    Raises:
        ValueError: If no text could be extracted from the PDF.
    """
    # Convert the PDF to Markdown. The context manager closes the document
    # even if the conversion raises.
    with fitz.open(file_path) as doc:
        markdown_text = pymupdf4llm.to_markdown(doc)

    # Chunking
    chunker = semchunk.chunkerify('umarbutler/emubert', chunk_size)
    chunks = chunker(markdown_text)
    print(chunks)
    if not chunks:
        # E.g. an image-only PDF. Fail with a readable message instead of an
        # opaque shape error from FAISS on an empty embedding batch.
        raise ValueError("No text could be extracted from the PDF")

    # Embedding and indexing (FAISS expects float32 arrays).
    chunks_embeddings = np.asarray(embedder.encode(chunks, batch_size=32), dtype="float32")
    index.add(chunks_embeddings)

    if not query:
        return chunks

    k = min(top_k, len(chunks))
    if k < 1:
        return []

    # Search a per-document index. Ids from the shared `index` are offsets
    # across every document added so far, so they cannot be used to index this
    # call's `chunks`.
    doc_index = faiss.IndexFlatL2(chunks_embeddings.shape[1])
    doc_index.add(chunks_embeddings)
    query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
    _, indices = doc_index.search(query_embedding, k)

    # FAISS pads missing results with -1, which would otherwise select chunks[-1].
    return [chunks[idx] for idx in indices[0] if idx >= 0]

In [None]:
# Plain-text record of the chat, one string per turn (e.g. "User: <text>"),
# replayed into prompts by generate_answer(). Module-level, so it is shared by
# everyone using the app.
conversation_history = []


def generate_answer(
    question: str,
    context: str = "",
    max_history_turns: int = 10,
    max_tokens: int = 256,
) -> str:
    """Generate an answer to ``question`` with the vLLM model ``llm``.

    With a non-empty ``context``, the model is instructed to answer only from
    that context. Otherwise a general-purpose prompt is used. In both cases the
    most recent ``conversation_history`` entries are included in the prompt.

    Args:
        question: The user's question.
        context: Retrieved document text to ground the answer in; "" for none.
        max_history_turns: Number of most recent history entries to include.
            A value of 0 or less omits the history.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The text of the first generated completion.
    """
    # Replay only recent turns. The history grows without bound, and the full
    # transcript would eventually exceed the model's context window
    # (max_model_len). The guard is needed because `lst[-0:]` is the whole list.
    recent_turns = conversation_history[-max_history_turns:] if max_history_turns > 0 else []
    # One turn per line, so the model can tell turns apart.
    history_text = "\n".join(recent_turns)

    if context:
        prompt = f"""[INST] You are an AI assistant specialized in analyzing documents. Your task is to answer the following question based solely on the provided context. Follow these guidelines:

1. Carefully analyze the given context.
2. If the context contains relevant information, provide a clear and concise answer.
3. If the context lacks relevant information, explicitly state: "The provided context does not contain information to answer this question."
4. Do not use any external knowledge or make assumptions beyond the given context.
5. If the question requires clarification, ask for more specific details.

Previous conversation:
{history_text}

Context: {context}

Question: {question}

Answer: [/INST]"""
    else:
        prompt = f"""[INST] You are an AI assistant with expertise in various fields. Please follow these steps to respond:

1. Analyze the user's question carefully.
2. Identify the main topic and any subtopics in the question.
3. Provide a clear, concise, and informative answer.
4. If the question is ambiguous, ask for clarification before answering.
5. If you're unsure about any part of the answer, explicitly state your uncertainty.

Previous conversation:
{history_text}

User's question: {question}

Your response: [/INST]"""

    inputs = {"prompt": prompt}
    sampling_params = SamplingParams(temperature=0.2, max_tokens=max_tokens)
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    return outputs[0].outputs[0].text

# Question and retrieved context captured by add_message() for the bot() call
# that follows it in the Gradio event chain.
# NOTE: module-level like `conversation_history`, so shared by all sessions.
_pending_question = ""
_pending_context = ""


def add_message(history, message):
    """Record a submitted chat message and prepare retrieval context for ``bot``.

    Each uploaded file is shown in the chat and passed to ``process_pdf``. When
    the message also carries text, the chunks retrieved for that text are kept
    as context for the answer that ``bot`` generates next.

    Args:
        history: Gradio "messages"-style chat history (list of role/content
            dicts). Mutated in place and returned.
        message: Value of the ``gr.MultimodalTextbox``, with "text" and
            "files" keys.

    Returns:
        The updated history, and a cleared textbox that stays non-interactive
        until the answer has been streamed.
    """
    global _pending_question, _pending_context

    text = message.get("text") or ""
    files = message.get("files") or []

    retrieved = []
    for file in files:
        history.append({"role": "user", "content": {"path": file}})
        try:
            # Indexes the PDF. With a non-empty query, returns the chunks most
            # similar to it.
            chunks = process_pdf(file, text)
        except Exception as e:
            history.append({"role": "assistant", "content": f"Invalid PDF file: {str(e)}"})
            continue
        if text:
            retrieved.extend(chunks)

    # Add the question once, even when several files were uploaded with it.
    if text:
        history.append({"role": "user", "content": text})
        conversation_history.append(f"User: {text}")

    _pending_question = text
    _pending_context = " ".join(retrieved)
    return history, gr.MultimodalTextbox(value=None, interactive=False)


def bot(history: list):
    """Generate an answer to the pending question and stream it into ``history``.

    The question and its retrieved context come from the preceding
    ``add_message`` call. When none is pending (e.g. the endpoint is called
    directly through the API), a trailing plain-text user message in
    ``history`` is answered instead, without document context. If there is
    nothing to answer (such as a file-only upload), ``history`` is yielded
    unchanged.

    Args:
        history: Gradio "messages"-style chat history. Mutated in place.

    Yields:
        ``history`` after each streamed character of the answer.
    """
    global _pending_question, _pending_context

    question, context = _pending_question, _pending_context
    # Consume the pending turn so it is never answered twice.
    _pending_question, _pending_context = "", ""

    if not question and history:
        last = history[-1]
        if last.get("role") == "user" and isinstance(last.get("content"), str):
            question = last["content"]

    if not question:
        yield history
        return

    answer = generate_answer(question, context)
    # Record the reply so later prompts include both sides of the conversation.
    conversation_history.append(f"Assistant: {answer}")

    history.append({"role": "assistant", "content": ""})
    for character in answer:
        history[-1]["content"] += character
        time.sleep(0.05)
        yield history

In [None]:
# Gradio UI: a chat window plus a multimodal textbox for questions and PDF uploads.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages")
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="multiple",
        # Uploaded files are handed to process_pdf(), so only offer PDFs.
        file_types=[".pdf"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    # On submit: record the message (add_message), stream the answer (bot),
    # then re-enable the textbox that add_message made non-interactive.
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

demo.launch()