# In [None]:
import gradio as gr
import faiss
import numpy as np
import spacy
from sentence_transformers import SentenceTransformer
import os
import time
import semchunk
import pymupdf as fitz
import pymupdf4llm
from vllm import LLM, SamplingParams
from typing import List, Tuple, Dict, Optional
from PIL import Image
import hashlib
import logging

# In [None]:
# Load the vision-language model once, at import time. Both RAGPipeline and
# LLaVAImageQAProcessor below store this same engine object (self.llm = llm),
# so the weights are loaded a single time for the whole notebook.
# dtype='half' requests half-precision weights; max_model_len=8192 bounds the
# model's context length. NOTE(review): semantics per vLLM `LLM` options — confirm.
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", dtype='half', max_model_len=8192)
# Shared sampling configuration: temperature 0.7, at most 512 generated tokens.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
# Configure the root logger at INFO so the pipeline's logging.info(...) messages
# are emitted (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
class RAGPipeline:
    """Pipeline for retrieval-augmented generation (RAG) over PDF files.

    PDFs are converted to Markdown, split into chunks, embedded with a
    SentenceTransformer and stored in a FAISS L2 index. A query retrieves the
    nearest chunks, which are inserted as context into the prompt sent to the
    shared vLLM model.
    """

    def __init__(self):
        # Reuse the module-level vLLM engine (llava-hf/llava-v1.6-mistral-7b-hf).
        self.llm = llm

        # Sampling parameters shared with the module-level configuration.
        self.sampling_params = sampling_params

        # Embedding model, text chunker and vector index.
        self.embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
        self.chunker = semchunk.chunkerify('gpt-4', 200)
        self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
        # self.chunks[i] is the text of the vector stored at FAISS position i;
        # the two must always be extended together and in the same order.
        self.chunks = []
        self.processed_files = {}  # {file_hash: file_path}

    def get_file_hash(self, file_path: str) -> str:
        """Return the MD5 hex digest of the file's contents.

        Gradio stores uploads under temporary paths, so the content hash (not the
        path) is used to detect files that were already indexed. MD5 serves only
        as a de-duplication key here, not as a security measure.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            # Hash in 1 MiB pieces so large PDFs are not read into memory at once.
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()

    def indexing_pdf(self, pdf_path: List[str]):
        """Index PDF files into the vector store.

        For every file that has not been indexed before (judged by content hash):
        convert it to Markdown with pymupdf4llm, split the Markdown into chunks
        with semchunk, embed the chunks with the SentenceTransformer and add the
        embeddings to the FAISS index.

        Args:
            pdf_path: Paths of the PDF files to index. A single path string, or
                ``None`` (nothing uploaded), is also accepted.

        A failure on one file is logged and that file is skipped. It is not
        recorded as processed, so uploading it again retries indexing.
        """
        if pdf_path is None:
            pdf_paths = []
        elif isinstance(pdf_path, str):
            # Iterating a bare string would walk its characters.
            pdf_paths = [pdf_path]
        else:
            pdf_paths = list(pdf_path)

        for pdf in pdf_paths:
            try:
                # Skip files that were already indexed.
                file_hash = self.get_file_hash(pdf)
                if file_hash in self.processed_files:
                    logging.info(f"{pdf} has already been processed before")
                    continue

                logging.info(f"Processing new file: {pdf}")

                # Convert the PDF to Markdown; the context manager closes the
                # document even if the conversion raises.
                with fitz.open(pdf) as doc:
                    markdown_text = pymupdf4llm.to_markdown(doc)

                chunks = self.chunker(markdown_text)
                if chunks:
                    embeddings = np.asarray(self.embedder.encode(chunks), dtype="float32")
                    self.index.add(embeddings)
                    # Record the texts only after their vectors were added, so a
                    # failure above cannot leave self.chunks misaligned with FAISS.
                    self.chunks.extend(chunks)
                else:
                    logging.warning(f"No text extracted from {pdf}")

                # Mark the file as processed only once indexing succeeded.
                self.processed_files[file_hash] = pdf
            except Exception as e:
                # Best-effort per file: one bad upload must not abort the batch.
                logging.exception(f"Error in indexing {pdf}: {e}")

        logging.info(f"Processed {len(pdf_paths)} files. Total unique files: {len(self.processed_files)}")

    def process_query(self, query: str, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` indexed chunks nearest to ``query``.

        The query is embedded with the same SentenceTransformer as the chunks and
        searched in the FAISS index. Results are ordered as FAISS returns them
        (closest first). An empty list is returned when nothing is indexed or
        ``top_k`` is not positive.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        query_embedding = np.asarray(self.embedder.encode([query]), dtype="float32")
        _, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
        # FAISS pads missing results with -1; drop them instead of letting
        # Python's negative indexing return the last chunk.
        return [self.chunks[i] for i in indices[0] if i >= 0]

    def prompt_template(self, query: str, context: List[str]) -> str:
        """Build the LLM prompt from ``query`` and the retrieved ``context``.

        The prompt contains a system message, the numbered context passages, the
        user's question and answering instructions, wrapped in ``[INST] ... [/INST]``.
        """
        system_message = """You are an AI assistant tasked with answering questions based on provided context. Your role is to:
                            1. Carefully analyze the given context
                            2. Provide accurate and relevant information
                            3. Synthesize a coherent response
                            4. Maintain objectivity and clarity
                            If the context doesn't contain sufficient information, state so clearly."""

        context_str = "\n".join(f"Context {i+1}: {ctx}" for i, ctx in enumerate(context))

        prompt = f"""[INST] {system_message}

            Relevant information:
            {context_str}

            User's Question: {query}

            Instructions:
            - Answer the query using only the information provided in the context.
            - If the context doesn't contain enough information to fully answer the query, acknowledge this limitation in your response.
            - Provide a concise yet comprehensive answer.
            - Do not introduce information not present in the given context.
            - Provide in complete sentences in English always.
            - Check once again your response so that the user can be provided precise information.

            Please provide your response below:
            [/INST]"""

        return prompt

    def generate_response(self, query: str, context: List[str]) -> str:
        """Generate an answer to ``query`` from ``context`` with the LLM.

        Returns the text of the first completion for the single prompt.
        """
        prompt = self.prompt_template(query, context)
        output = self.llm.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def answer_query(self, query: str, top_k: int = 5) -> str:
        """Retrieve the ``top_k`` nearest chunks and generate an answer from them."""
        retrieved_contexts = self.process_query(query, top_k)
        return self.generate_response(query, retrieved_contexts)
    
class LLaVAImageQAProcessor:
    """Answer a question about one image with the shared llava-v1.6 vLLM engine."""

    def __init__(self):
        # Reuse the module-level vLLM engine and sampling parameters.
        self.llm = llm
        self.sampling_params = sampling_params

    def get_prompt(self, question: str) -> str:
        """Build the instruction prompt for ``question``.

        The ``<image>`` placeholder marks where the image input goes; the whole
        prompt is wrapped in ``[INST] ... [/INST]``.
        """
        return f"""[INST] 
                    Explain me about this image precisely in bullet points. 
                    Your response should be in complete sentences.
                    <image>\n{question} [/INST]"""

    def process_image(self, image_path, question):
        """Generate an answer to ``question`` about an image.

        Args:
            image_path: Path to an image file, or an already-loaded PIL image.
            question: The user's question about the image.

        Returns:
            The text of the first completion.

        Raises:
            ValueError: If no image was provided (e.g. nothing uploaded).
        """
        if image_path is None:
            raise ValueError("No image provided")

        if isinstance(image_path, Image.Image):
            image = image_path.convert("RGB")
        else:
            # Image.open is lazy and keeps the file open; convert() forces the
            # pixel data to load, so the handle can be closed right away.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

        prompt = self.get_prompt(question)
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        }

        outputs = self.llm.generate(inputs, sampling_params=self.sampling_params)
        return outputs[0].outputs[0].text
