<a href="https://colab.research.google.com/github/RyuMinHo/GAI_project/blob/main/prompt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pymupdf pymupdf4llm gradio faiss-gpu sentence-transformers semchunk vllm triton

In [None]:
import gradio as gr
import faiss
import numpy as np
import spacy
from sentence_transformers import SentenceTransformer
import os
import time
import semchunk
import pymupdf as fitz
import pymupdf4llm
from vllm import LLM, SamplingParams
from typing import List, Tuple, Dict, Optional
from PIL import Image
import hashlib
import logging

In [2]:
# Load the LLaVA v1.6 (Mistral 7B) checkpoint once at module level through vLLM,
# in half precision and with an 8192-token context limit. The classes in this
# file reuse this single instance instead of loading the model again.
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", dtype='half', max_model_len=8192)
# Sampling settings reused for the generate() calls made in this file.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
# Configure the root logger so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# The bare string below is a descriptive note for the class that follows
# ("Pipeline class for RAG over PDF files"); it has no runtime effect.
"""
PDF 파일 RAG를 위한 Pipeline class
"""
class RAGPipeline:
    """Retrieval-augmented generation pipeline over PDF files.

    PDFs are converted to Markdown, split into chunks, embedded with a
    SentenceTransformer model and stored in a FAISS L2 index. A query is
    answered by retrieving the nearest chunks and passing them, together with
    category-specific instructions, to the shared LLM.
    """

    # Keyword lists used by classify_query(); the first category (in this
    # order) with a keyword contained in the query wins.
    _CATEGORY_KEYWORDS = {
        "Summary Request": ["summary", "summarize", "overview", "key points", "main content", "gist"],
        "Information Retrieval": ["definition", "meaning", "explain", "reason", "clarify", "purpose"],
        "Calculation/Analysis Request": ["analysis", "calculate", "steps", "process", "formula", "problem"],
        "Comparison Request": ["compare", "difference", "similarity", "better", "features", "versus"],
        "Task Instruction": ["write", "code", "example", "generate", "task", "perform", "how to", "steps to", "create"],
    }

    # Extra prompt instructions appended for each query category.
    _INSTRUCTIONS = {
        "Summary Request": (
            "- Summarize the key concepts of the context in a concise and clear manner.\n"
            "- Avoid repeating the same information.\n"
            "- Summarize each paragraph into one sentence to maintain clarity and focus on the main points.\n"
            "- Provide a step-by-step summary of the process or procedure described in the context, if applicable.\n"
            "- You are an AI personal assistant. Respond in a clear and professional manner."
        ),
        "Information Retrieval": (
            "- Provide a clear and concise answer to the user's question based on the given context.\n"
            "- Focus on accuracy and omit irrelevant information to provide a precise answer.\n"
            "- If the context lacks sufficient information, respond with 'The information is insufficient.'\n"
            "- Do not fabricate or add information that is not present in the given context.\n"
            "- For example, if the context states 'June 5th' as the date, your response should also use 'June 5th' rather than rephrasing it as 'June 5th, 2001' or '5th of June.'\n"
            "- You are an AI personal assistant. Respond in a clear and professional manner."
        ),
        "Calculation/Analysis Request": (
            "- Write the calculation or analysis process step-by-step in the format 'step 1:', 'step 2:', etc.\n"
            "- Ensure that your explanation is logical and easy to understand.\n"
            "- If the context lacks information required for the calculation or analysis, clearly identify which part of the information is missing and explain why it is needed.\n"
            "- Clearly state the final result of the calculation.\n"
            "- Accurately indicate the units (e.g., $, %, km) of the results.\n"
            "- Assume you are a teacher explaining the calculation or analysis to a student, and provide a clear and helpful explanation."
        ),
        "Comparison Request": (
            "- Explain the similarities and differences between the items or concepts in a logical and structured way.\n"
            "- If necessary, provide advantages and disadvantages for each item in a table or bullet-point format.\n"
            "- Maintain fairness and objectivity, avoiding any biased statements.\n"
            "- Clearly state the criteria for comparison (e.g., speed, cost, efficiency) to ensure clarity.\n"
            "- Use concise and clear language to make the comparison easy to understand.\n"
            "- Assume you are a teacher explaining the comparison to a student, and provide a clear and insightful explanation."
        ),
        "Task Instruction": (
            "- Explain the steps required to perform the requested task in a clear and sequential manner.\n"
            "- Include code snippets, examples, or commands where necessary to make the instructions actionable.\n"
            "- Ensure the explanation is logical and easy to follow, avoiding unnecessary complexity.\n"
            "- If the context lacks sufficient information, specify what additional information is required to proceed.\n"
            "- Present the response as a professional AI assistant, ensuring the instructions are clean and precise.\n"
            "- Clearly describe the expected outcome or result of the task.\n"
            "- Provide additional tips or warnings if applicable."
        ),
    }

    def __init__(self):
        """Wire up the shared LLM, the embedder, the chunker and an empty index."""
        # Reuse the module-level llava-hf/llava-v1.6-mistral-7b-hf instance.
        self.llm = llm

        # Sampling parameters used for generation.
        self.sampling_params = sampling_params

        # Embedding model, chunker and a flat L2 index sized to the embedding.
        self.embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.chunker = semchunk.chunkerify('gpt-4', 200)
        self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
        # self.chunks[i] is the text whose embedding is row i of self.index.
        self.chunks = []
        self.processed_files = {}  # {file_hash: file_path}

    def get_file_hash(self, file_path: str) -> str:
        """Return the MD5 hex digest of the file's contents.

        Gradio stores uploads under temporary paths, so the content hash
        (rather than the path) is used to detect files that were already
        indexed. The file is read in fixed-size pieces so large PDFs are not
        loaded into memory at once.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()

    def indexing_pdf(self, pdf_path: List[str]):
        """Index the given PDF files.

        For every file not seen before (by content hash):
        - convert the PDF to Markdown with pymupdf4llm,
        - split the Markdown into chunks with semchunk,
        - embed the chunks with the SentenceTransformer model,
        - add the embeddings to the FAISS index.

        Args:
            pdf_path: Paths of the PDF files to index. A single path string is
                also accepted and treated as a one-element list.

        Failures are logged per file and do not stop the remaining files. A
        file is recorded as processed only after its embeddings were added,
        so a failed file can be retried and ``self.chunks`` stays aligned
        with the index rows.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(pdf_path, str):
            pdf_path = [pdf_path]

        for pdf in pdf_path:
            try:
                # Skip files whose content was already indexed.
                file_hash = self.get_file_hash(pdf)
                if file_hash in self.processed_files:
                    logging.info(f"{pdf} has already been processed before")
                    continue

                logging.info(f"Processing new file: {pdf}")

                # Load the PDF and convert it to Markdown; the context manager
                # closes the document even when the conversion raises.
                with fitz.open(pdf) as doc:
                    markdown_text = pymupdf4llm.to_markdown(doc)

                # Split the Markdown text into chunks.
                chunks = self.chunker(markdown_text)
                if not chunks:
                    # Nothing to embed; remember the file so it is not re-read.
                    logging.warning(f"No text chunks extracted from {pdf}; nothing indexed")
                    self.processed_files[file_hash] = pdf
                    continue

                # Embed the chunks and add them to the index first; the
                # bookkeeping below only happens once that has succeeded.
                chunks_embeddings = self.embedder.encode(chunks)
                self.index.add(chunks_embeddings)
                self.chunks.extend(chunks)
                self.processed_files[file_hash] = pdf
            except Exception as e:
                logging.error(f"Error in indexing {pdf}: {e}")

        logging.info(f"Processed {len(pdf_path)} files. Total unique files: {len(self.processed_files)}")

    def process_query(self, query: str, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` indexed chunks closest to ``query``.

        The query is embedded with the SentenceTransformer model and searched
        in the FAISS index. An empty list is returned when nothing has been
        indexed or ``top_k`` is not positive.
        """
        available = min(self.index.ntotal, len(self.chunks))
        if top_k <= 0 or available == 0:
            return []

        query_embedding = self.embedder.encode([query])
        _, indices = self.index.search(query_embedding, min(top_k, available))
        # FAISS pads missing neighbours with -1, which would otherwise be read
        # as "last chunk" by Python's negative indexing.
        return [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

    @staticmethod
    def classify_query(query):
        """Classify ``query`` into a category by case-insensitive keyword match.

        Returns the first category with a keyword contained in the query, or
        "Uncategorized" when no keyword matches.
        """
        query = query.lower()
        for category, keywords in RAGPipeline._CATEGORY_KEYWORDS.items():
            if any(keyword in query for keyword in keywords):
                return category
        return "Uncategorized"

    def prompt_template(self, query: str, context: List[str]) -> str:
        """Build the LLM prompt for ``query`` and the retrieved ``context``.

        The prompt contains a system message, the numbered context chunks, the
        query, general instructions and the instructions specific to the
        query's category. Uncategorized queries get no category-specific
        instructions.
        """
        query_type = self.classify_query(query)
        instructions = self._INSTRUCTIONS.get(query_type, "")

        system_message = """You are an AI assistant tasked with answering questions based on provided context. Your role is to:
                            1. Carefully analyze the given context
                            2. Provide accurate and relevant information
                            3. Synthesize a coherent response
                            4. Maintain objectivity and clarity
                            If the context doesn't contain sufficient information, state so clearly."""

        context_str = "\n".join([f"Context {i+1}: {ctx}" for i, ctx in enumerate(context)])

        prompt = f"""[INST] {system_message}

            Relevant information:
            {context_str}

            User's Quetion: {query}

            Instructions:
            - Answer the query using only the information provided in the context.
            - If the context doesn't contain enough information to fully answer the query, acknowledge this limitation in your response.
            - Provide a concise yet comprehensive answer.
            - Do not introduce information not present in the given context.
            - Privide in complete sentences in English always.
            - Check once again your response so that the user can be provided precise information.
            {instructions}

            Please provide your response below:
            [/INST]"""

        return prompt

    def generate_response(self, query: str, context: List[str]) -> str:
        """Generate an answer to ``query`` from ``context`` with the LLM."""
        prompt = self.prompt_template(query, context)
        output = self.llm.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def answer_query(self, query: str, top_k: int = 5) -> str:
        """Retrieve the ``top_k`` closest chunks and answer ``query`` with them."""
        retrieved_contexts = self.process_query(query, top_k)
        return self.generate_response(query, retrieved_contexts)


# Keep the module-level name available for code that calls classify_query()
# directly rather than through the class.
classify_query = RAGPipeline.classify_query

class LLaVAImageQAProcessor:
    """Answer questions about an image with the shared LLaVA model."""

    def __init__(self):
        """Reuse the module-level LLM instance and sampling parameters."""
        self.llm = llm
        self.sampling_params = sampling_params

    def get_prompt(self, question: str):
        """Return the prompt wrapping ``question`` and the ``<image>`` placeholder."""
        return f"""[INST]
                    Explain me about this image precisely in bullet points.
                    Your response shuold be in complete sentences.
                    <image>\n{question} [/INST]"""

    def process_image(self, image_path, question):
        """Run the model on the image at ``image_path`` and return its answer text.

        Args:
            image_path: Path of the image file to open with PIL.
            question: Question about the image, inserted into the prompt.

        Returns:
            The text of the first completion produced by the model.
        """
        prompt = self.get_prompt(question)

        # Keep the image open only for the duration of the generate() call so
        # the underlying file handle is released even if generation fails.
        with Image.open(image_path) as image:
            inputs = {
                "prompt": prompt,
                "multi_modal_data": {"image": image},
            }
            outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)

        return outputs[0].outputs[0].text
