<a href="https://colab.research.google.com/github/RyuMinHo/GAI_project/blob/main/gui_gai.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gradio as gr
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from PIL import Image
import pymupdf as fitz
import pymupdf4llm
import hashlib
import logging

# Initialize the LLM
# One vLLM engine is created at import time; RAGPipeline and
# LLaVAImageQAProcessor below both store this same module-level `llm`.
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", dtype='half', max_model_len=8192)
# Sampling settings reused for every generate() call made in this file.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
# Show INFO-level records (e.g. the "has already been processed" notice).
logging.basicConfig(level=logging.INFO)

# RAGPipeline class
class RAGPipeline:
    """Retrieval-augmented question answering over uploaded PDF files.

    PDFs are converted to markdown, split into line-level chunks, embedded
    with a sentence-transformer and stored in a flat L2 FAISS index. A query
    is answered by retrieving the nearest chunks and passing them to the
    shared LLM as context.
    """

    def __init__(self):
        self.llm = llm
        self.sampling_params = sampling_params
        self.embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
        # chunks[i] is the text whose embedding is row i of self.index, so the
        # two must always be extended together.
        self.chunks = []
        # Maps content hash -> path of every PDF that was indexed successfully.
        self.processed_files = {}

    def get_file_hash(self, file_path: str) -> str:
        """Return the MD5 hex digest of the file's bytes.

        The digest is used only to detect re-uploads of the same content,
        not for anything security-sensitive. The file is read in blocks so
        large PDFs are not loaded into memory at once.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()

    def indexing_pdf(self, pdf_path: str):
        """Convert a PDF to markdown and add its non-blank lines to the index.

        A file whose content hash was already indexed is skipped. The hash is
        recorded only after indexing succeeded, so a file that failed midway
        can be retried.
        """
        file_hash = self.get_file_hash(pdf_path)
        if file_hash in self.processed_files:
            logging.info("%s has already been processed.", pdf_path)
            return

        doc = fitz.open(pdf_path)
        try:
            markdown_text = pymupdf4llm.to_markdown(doc)
        finally:
            # Release the document even if the markdown conversion raises.
            doc.close()

        # Blank lines carry no content; embedding them only pollutes retrieval.
        chunks = [line for line in markdown_text.split('\n') if line.strip()]
        if chunks:
            # FAISS expects a C-contiguous float32 matrix.
            embeddings = np.ascontiguousarray(
                self.embedder.encode(chunks), dtype=np.float32
            )
            self.index.add(embeddings)
            # Extend only after add() succeeded so positions stay aligned
            # with the index rows.
            self.chunks.extend(chunks)
        else:
            logging.warning("%s contains no extractable text.", pdf_path)

        self.processed_files[file_hash] = pdf_path

    def process_query(self, query: str, top_k: int = 5):
        """Return up to ``top_k`` indexed chunks nearest to ``query``.

        Returns an empty list when nothing has been indexed yet or when
        ``top_k`` is not positive.
        """
        if not self.chunks or top_k <= 0:
            return []

        # Never ask for more neighbours than there are chunks.
        k = min(top_k, len(self.chunks))
        query_embedding = np.ascontiguousarray(
            self.embedder.encode([query]), dtype=np.float32
        )
        _, indices = self.index.search(query_embedding, k)
        # FAISS pads missing neighbours with -1; indexing a Python list with
        # -1 would silently return the last chunk, so drop those labels.
        return [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

    def prompt_template(self, query: str, context: list):
        """Build the instruction prompt from the retrieved context and query.

        The chunks are joined with newlines so the model sees plain text
        rather than the repr of a Python list.
        """
        if isinstance(context, str):
            context_text = context
        else:
            context_text = "\n".join(str(chunk) for chunk in context)
        prompt = f"""
        [INST]
        Answer the question based on the following context:
        {context_text}

        Question: {query}
        [/INST]
        """
        return prompt

    def generate_response(self, query: str, context: list):
        """Run the LLM on the templated prompt and return the generated text."""
        prompt = self.prompt_template(query, context)
        output = self.llm.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def answer_query(self, query: str, top_k: int = 5):
        """Retrieve context for ``query`` and return the LLM's answer."""
        context = self.process_query(query, top_k)
        return self.generate_response(query, context)


# LLaVAImageQAProcessor class
class LLaVAImageQAProcessor:
    """Answers a question about a single image using the shared LLaVA model."""

    def __init__(self):
        self.llm = llm
        self.sampling_params = sampling_params

    def process_image(self, image_path, question):
        """Return the model's answer to ``question`` about the image.

        Args:
            image_path: Path of the image file to load.
            question: User question inserted into the prompt.

        Returns:
            The generated text of the first completion.
        """
        prompt = f"""[INST] <image>
        Explain this image in detail:

        Question: {question}
        [/INST]"""

        # Image.open is lazy and keeps the file open; decode inside the
        # context manager so the handle is released deterministically.
        # convert() returns an in-memory copy that outlives the closed file
        # and gives alpha/palette images a uniform 3-channel form.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
        outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)
        return outputs[0].outputs[0].text


# UI integration
def create_ui():
    """Build the Gradio Blocks app that combines image QA and PDF RAG chat.

    "General Chat" with an uploaded .jpg/.jpeg/.png asks the image model;
    "RAG Chat" with an uploaded .pdf indexes the file and answers from it.
    Any other combination yields an "Unsupported input or mode." reply.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    pipeline = RAGPipeline()
    image_processor = LLaVAImageQAProcessor()
    image_suffixes = ('.jpg', '.jpeg', '.png')

    with gr.Blocks() as demo:
        with gr.Row():
            chat_mode = gr.Radio(
                choices=["General Chat", "RAG Chat"],
                value="General Chat",
                label="Mode"
            )

        with gr.Column():
            chatbot = gr.Chatbot()

            with gr.Row():
                file_upload = gr.File(label="Upload File")
                msg_input = gr.Textbox(placeholder="Type your message here...")
                send_btn = gr.Button("Send")

        def process_input(mode, uploaded, message, history):
            """Route one message to the right backend and extend the chat."""
            # gr.File passes a tempfile-like object with .name in Gradio 3.x
            # and a plain path string in 4.x; accept both.
            path = getattr(uploaded, "name", uploaded) if uploaded else None
            # Lower-case so extensions such as ".JPG" or ".PDF" are accepted.
            lowered = path.lower() if isinstance(path, str) else ""

            if mode == "General Chat" and lowered.endswith(image_suffixes):
                response = image_processor.process_image(path, message)
            elif mode == "RAG Chat" and lowered.endswith('.pdf'):
                pipeline.indexing_pdf(path)
                response = pipeline.answer_query(message)
            else:
                response = "Unsupported input or mode."

            # Return the new value directly: Component.update() was removed in
            # Gradio 4. Appending to the incoming history keeps earlier turns
            # visible instead of replacing the conversation on every send.
            # NOTE(review): assumes the Chatbot uses (user, bot) pair history,
            # as the original code did - confirm for the installed Gradio.
            return list(history or []) + [(message, response)]

        send_btn.click(
            process_input,
            inputs=[chat_mode, file_upload, msg_input, chatbot],
            outputs=chatbot
        )

    return demo


# Build the UI and launch it only when this file is executed as a script.
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()
