<a href="https://colab.research.google.com/github/RyuMinHo/GAI_project/blob/main/GAI_project_merge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install gradio vllm transformers triton PyPDF2 Pillow sentence_transformers numpy typing faiss-gpu spacy pymupdf4llm fitz frontend tools semchunk



In [2]:
import gradio as gr
import faiss
import numpy as np
import spacy
from sentence_transformers import SentenceTransformer
import os
import time
import semchunk
import pymupdf as fitz
import pymupdf4llm
from vllm import LLM, SamplingParams
from typing import List, Tuple, Dict, Optional
from PIL import Image
import hashlib
import logging
import torch
import gc

Exception in thread Thread-5 (attachment_entry):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/debugpy/server/api.py", line 237, in listen
    sock, _ = endpoints_listener.accept()
  File "/usr/lib/python3.10/socket.py", line 293, in accept
    fd, addr = self._accept()
TimeoutError: timed out

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/google/colab/_debugpy.py", line 52, in attachment_entry
    debugpy.listen(_dap_port)
  File "/usr/local/lib/python3.10/dist-packages/debugpy/public_api.py", line 31, in wrapper
    return wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/debugpy/server/api.py", line 143, in debug
    log.reraise

In [3]:
# Initialize module-level globals.
# A single vLLM engine is created here; RAGPipeline, LLaVAImageQAProcessor and
# the plain-chat path in create_ui all reference this same `llm` object.
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", dtype='half', max_model_len=8192)
# Sampling settings passed along with every `llm.generate` call in this file.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
# Configure the root logger to emit INFO-level records and above.
logging.basicConfig(level=logging.INFO)

"""
PDF 파일 RAG를 위한 Pipeline class
"""
class RAGPipeline:
    """Retrieval-augmented generation (RAG) pipeline for PDF files.

    PDFs are converted to markdown, split into chunks, embedded and stored in
    a FAISS L2 index. A query is answered by retrieving the closest chunks and
    passing them, together with the question, to the shared LLM.
    """

    # Sentence-transformers model used for both chunk and query embeddings.
    EMBEDDING_MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2'

    # Embedding model shared by every instance so that creating another
    # pipeline (one is created per chat session) does not load a second copy
    # of the model.
    _shared_embedder = None

    def __init__(self):
        """Set up the shared LLM handles plus a per-instance chunk store and index."""
        # Reuse the module-level llava-hf/llava-v1.6-mistral-7b-hf engine.
        self.llm = llm

        # Reuse the module-level sampling parameters.
        self.sampling_params = sampling_params

        # Embedding model (loaded once, on first use, then shared).
        if RAGPipeline._shared_embedder is None:
            RAGPipeline._shared_embedder = SentenceTransformer(self.EMBEDDING_MODEL_NAME)
        self.embedder = RAGPipeline._shared_embedder

        self.chunker = semchunk.chunkerify('gpt-4', 200)
        self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
        # self.chunks[i] is the text whose embedding is row i of self.index.
        self.chunks = []
        self.processed_files = {}  # {file_hash: file_path}

    def get_file_hash(self, file_path: str) -> str:
        """Return the hex MD5 digest of the file's bytes (used only for de-duplication).

        The file is read in fixed-size blocks so large PDFs are not loaded
        into memory in one piece.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()

    def indexing_pdf(self, pdf_path: List[str]):
        """Chunk, embed and index every PDF in ``pdf_path`` that is not indexed yet.

        Errors are logged per file and do not stop the remaining files. A file
        is recorded in ``processed_files`` only after it has been handled
        successfully, so a file that failed can be retried later.
        """
        for pdf in pdf_path:
            try:
                file_hash = self.get_file_hash(pdf)
                if file_hash in self.processed_files:
                    logging.info(f"{pdf} has already been processed before")
                    continue

                logging.info(f"Processing new file: {pdf}")

                doc = fitz.open(pdf)
                try:
                    markdown_text = pymupdf4llm.to_markdown(doc)
                finally:
                    # Close the document even when the conversion fails.
                    doc.close()

                chunks = self.chunker(markdown_text)
                if chunks:
                    chunks_embeddings = self.embedder.encode(chunks)
                    self.index.add(chunks_embeddings)
                    # Extend the chunk list only after the vectors were added,
                    # so chunk positions always match index rows.
                    self.chunks.extend(chunks)
                else:
                    logging.warning(f"No text chunks extracted from {pdf}")

                self.processed_files[file_hash] = pdf
            except Exception as e:
                logging.error(f"Error in indexing {pdf}: {e}")

        logging.info(f"Processed {len(pdf_path)} files. Total unique files: {len(self.processed_files)}")

    def process_query(self, query: str, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` indexed chunks closest to ``query``.

        Returns an empty list when nothing has been indexed or ``top_k`` is
        not positive.
        """
        # Never ask for more neighbours than the index holds.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        query_embedding = self.embedder.encode([query])
        _, indices = self.index.search(query_embedding, k)
        # Skip out-of-range labels (e.g. -1 for "no result"); a negative value
        # would otherwise silently index from the end of self.chunks.
        return [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

    def prompt_template(self, query: str, context: List[str]) -> str:
        """Build the ``[INST] ... [/INST]`` prompt from the question and retrieved chunks."""
        system_message = """You are an AI assistant tasked with answering questions based on provided context. Your role is to:
                            1. Carefully analyze the given context
                            2. Provide accurate and relevant information
                            3. Synthesize a coherent response
                            4. Maintain objectivity and clarity
                            If the context doesn't contain sufficient information, state so clearly."""

        context_str = "\n".join(f"Context {i+1}: {ctx}" for i, ctx in enumerate(context))

        prompt = f"""[INST] {system_message}

            Relevant information:
            {context_str}

            User's Question: {query}

            Instructions:
            - Answer the query using only the information provided in the context.
            - If the context doesn't contain enough information to fully answer the query, acknowledge this limitation in your response.
            - Provide a concise yet comprehensive answer.
            - Do not introduce information not present in the given context.
            - Provide in complete sentences in English always.
            - Check once again your response so that the user can be provided precise information.

            Please provide your response below:
            [/INST]"""

        return prompt

    def generate_response(self, query: str, context: List[str]) -> str:
        """Generate an answer to ``query`` from the given context chunks."""
        prompt = self.prompt_template(query, context)
        output = self.llm.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def answer_query(self, query: str, top_k: int = 5) -> str:
        """Retrieve the ``top_k`` closest chunks and answer ``query`` with them."""
        retrieved_contexts = self.process_query(query, top_k)
        return self.generate_response(query, retrieved_contexts)

class LLaVAImageQAProcessor:
    """Answer questions about a single image with the shared LLaVA engine."""

    # Instruction used when the caller supplies no question.
    DEFAULT_INSTRUCTION = "Explain me about this image precisely in bullet points."

    def __init__(self):
        """Bind the module-level LLM engine and sampling parameters."""
        self.llm = llm
        self.sampling_params = sampling_params

    def get_prompt(self, question: str):
        """Build the ``[INST] ... [/INST]`` prompt for an image question.

        The prompt contains an ``<image>`` placeholder marking where the
        image passed in ``multi_modal_data`` belongs. When ``question`` is
        empty or blank, the default "explain this image" instruction is used.
        """
        asked = (question or "").strip()
        if asked:
            instruction = f"{asked}\nAnswer precisely in bullet points."
        else:
            instruction = self.DEFAULT_INSTRUCTION
        return (
            "[INST] <image>\n"
            f"{instruction}\n"
            "Your response should be in complete sentences.\n"
            "[/INST]"
        )

    def process_image(self, image: "Image.Image", question: str) -> str:
        """Run the model on ``image`` with ``question`` and return the stripped text.

        Returns "Failed to generate response." when the engine yields no
        output. Exceptions from the engine propagate to the caller; Python
        garbage collection and the CUDA cache are flushed either way.
        """
        prompt = self.get_prompt(question)
        try:
            inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
            outputs = self.llm.generate(inputs, self.sampling_params)
            return outputs[0].outputs[0].text.strip() if outputs else "Failed to generate response."
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

INFO 12-03 09:49:49 config.py:350] This model supports multiple tasks: {'embedding', 'generate'}. Defaulting to 'generate'.
INFO 12-03 09:49:49 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='llava-hf/llava-v1.6-mistral-7b-hf', speculative_config=None, tokenizer='llava-hf/llava-v1.6-mistral-7b-hf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=llava-hf/llav

tokenizer_config.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

INFO 12-03 09:49:50 selector.py:135] Using Flash Attention backend.
INFO 12-03 09:49:50 model_runner.py:1072] Starting to load model llava-hf/llava-v1.6-mistral-7b-hf...
INFO 12-03 09:49:51 weight_utils.py:243] Using model weights format ['*.safetensors']


model-00001-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/70.2k [00:00<?, ?B/s]

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


INFO 12-03 09:51:53 model_runner.py:1077] Loading model weights took 14.0785 GB
INFO 12-03 09:51:55 worker.py:232] Memory profiling results: total_gpu_memory=39.56GiB initial_memory_usage=14.62GiB peak_torch_memory=14.94GiB memory_usage_post_profile=14.66GiB non_torch_memory=0.57GiB kv_cache_size=20.09GiB gpu_memory_utilization=0.90
INFO 12-03 09:51:56 gpu_executor.py:113] # GPU blocks: 10288, # CPU blocks: 2048
INFO 12-03 09:51:56 gpu_executor.py:117] Maximum concurrency for 8192 tokens per request: 20.09x
INFO 12-03 09:51:58 model_runner.py:1400] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 12-03 09:51:58 model_runner.py:1404] If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
I

In [6]:
class SessionManager:
    """Keep the per-session chat state and track which session is current.

    Each session is a dict with the keys ``history``, ``mode``,
    ``mode_locked``, ``rag_pipeline`` and ``img_processor``.
    """

    def __init__(self):
        """Start with a single session named "Example"."""
        self.sessions = {"Example": self._new_session_state()}
        self.current_session = "Example"

    @staticmethod
    def _new_session_state() -> dict:
        """Return the state dict for a brand-new session."""
        return {
            "history": [],
            "mode": "General Chat",
            "mode_locked": False,
            "rag_pipeline": RAGPipeline(),
            "img_processor": LLaVAImageQAProcessor()
        }

    def create_session(self, session_name: str) -> bool:
        """Create ``session_name`` and make it current.

        Returns False, creating nothing, when the name is empty/blank or
        already in use.
        """
        if not session_name or not session_name.strip():
            return False
        if session_name in self.sessions:
            return False

        self.sessions[session_name] = self._new_session_state()
        self.current_session = session_name
        return True

    def delete_session(self, session_name: str) -> Tuple[Optional[str], Optional[dict]]:
        """Delete ``session_name`` and return ``(current_name, current_state)``.

        Returns ``(None, None)`` without deleting anything when the session
        is unknown or is the only one left. The current session changes only
        if the deleted session was the current one; it then becomes the first
        remaining session.
        """
        if len(self.sessions) <= 1 or session_name not in self.sessions:
            return None, None

        del self.sessions[session_name]
        if self.current_session not in self.sessions:
            self.current_session = next(iter(self.sessions))
        return self.current_session, self.sessions[self.current_session]

    def get_session(self, session_name: str) -> Optional[dict]:
        """Return the state dict for ``session_name``, or None if it does not exist."""
        return self.sessions.get(session_name)

def create_ui():
    """Build the Gradio Blocks chat UI and return it (not yet launched).

    The UI has a session sidebar (create / select / delete) and a chat area
    with two modes: "General Chat" (text or image questions) and "RAG Chat"
    (questions over uploaded PDFs). Per-session state lives in a
    SessionManager created here and captured by the event handlers.
    """
    session_manager = SessionManager()

    custom_css = """
    .message-box {
        display: flex;
        align-items: center;
        gap: 0.5rem;
    }
    .file-btn {
        max-width: 40px;
    }
    .send-btn {
        max-width: 40px;
    }
    .selected-file {
        margin: 0.5rem 0;
        padding: 0.3rem;
        background: #f0f0f0;
        border-radius: 4px;
        font-size: 0.9em;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                new_session_btn = gr.Button("+ New Session")
                session_title_input = gr.Textbox(
                    label="Session Title",
                    visible=False
                )
                with gr.Column(elem_classes="session-container"):
                    gr.Markdown("Sessions")
                    session_list = gr.Radio(
                        choices=["Example"],
                        value="Example",
                        label=""
                    )
                    delete_btn = gr.Button("🗑️ Delete Session")

            with gr.Column(scale=3):
                current_title = gr.Markdown("## Example")

                with gr.Row():
                    chat_mode = gr.Radio(
                        choices=["General Chat", "RAG Chat"],
                        value="General Chat",
                        label=""
                    )

                chatbot = gr.Chatbot(
                    height=400,
                    render_markdown=True,
                    show_copy_button=True,
                    bubble_full_width=False
                )

                # File upload, message input and send button on a single row.
                with gr.Row():
                    # Image upload for General Chat mode.
                    with gr.Column(scale=1, visible=True) as general_upload:
                        file_upload_image = gr.UploadButton(
                            "📎",
                            file_types=[".jpg", ".jpeg", ".png"]
                        )

                    # PDF upload for RAG Chat mode.
                    with gr.Column(scale=1, visible=False) as rag_upload:
                        file_upload_pdf = gr.File(
                            label="PDF",
                            file_types=[".pdf"],
                            file_count="multiple"
                        )

                    # Message input box.
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="메시지를 입력하세요...",
                        container=False,
                        scale=8
                    )

                    # Send button.
                    send_btn = gr.Button("↑", scale=1)

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")

                with gr.Row(visible=True) as general_file_info:
                    selected_image = gr.Textbox(
                        label="Selected Image",
                        interactive=False
                    )

                with gr.Row(visible=False) as rag_file_info:
                    selected_pdf = gr.Textbox(
                        label="Selected PDF",
                        interactive=False
                    )

        def _file_path(file_obj):
            """Return the path of an uploaded file.

            Accepts either an object exposing ``.name`` or a plain path
            string, since the upload value type may differ between Gradio
            versions.
            """
            return getattr(file_obj, "name", file_obj)

        def _mode_updates(active_mode):
            """Return updates for the mode radio and the four mode-dependent containers."""
            is_general = active_mode == "General Chat"
            return [
                gr.update(value=active_mode),
                gr.update(visible=is_general),
                gr.update(visible=not is_general),
                gr.update(visible=is_general),
                gr.update(visible=not is_general)
            ]

        def process_message(message, file_image, files_pdf, mode, history, session_name):
            """Produce the assistant's reply text for one user turn.

            Any exception is logged and turned into an error string so the
            chat keeps working. ``mode`` and ``history`` are accepted for
            interface compatibility; the session's stored mode is what
            decides the branch.
            """
            try:
                session = session_manager.get_session(session_name)
                if not session:
                    return "세션을 찾을 수 없습니다."

                message = message or ""
                current_mode = session["mode"]

                if current_mode == "General Chat":
                    image_path = _file_path(file_image) if file_image else None
                    if image_path and str(image_path).lower().endswith(('.jpg', '.jpeg', '.png')):
                        with Image.open(image_path) as image:
                            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
                            question = message if message.strip() else "이 이미지에 대해 설명해주세요."
                            return session["img_processor"].process_image(rgb_image, question)

                    prompt = f"[INST] {message} [/INST]"
                    output = llm.generate([prompt], sampling_params)
                    return output[0].outputs[0].text.strip()

                # RAG Chat mode
                if files_pdf:
                    pdf_paths = [_file_path(f) for f in files_pdf]
                    session["rag_pipeline"].indexing_pdf(pdf_paths)
                    # Only report the upload when no question came with it;
                    # otherwise fall through and answer the question.
                    if not message.strip():
                        return f"{len(pdf_paths)}개의 PDF가 성공적으로 처리되었습니다. 이제 문서에 대해 질문할 수 있습니다."

                return session["rag_pipeline"].answer_query(message)

            except Exception as e:
                logging.error(f"메시지 처리 중 오류: {str(e)}")
                return f"오류가 발생했습니다: {str(e)}"

        def chat_mode_change(mode, session_name):
            """Apply a mode change, or revert it when the session's mode is locked."""
            session = session_manager.get_session(session_name)
            if not session:
                return [gr.update()] * 5

            if session["mode_locked"]:
                # Warn only for a real change attempt; this handler also runs
                # when the radio is set programmatically to the stored mode.
                if mode != session["mode"]:
                    gr.Warning("대화가 시작된 후에는 모드를 변경할 수 없습니다. 새 세션을 만들어주세요.")
                return _mode_updates(session["mode"])

            session["mode"] = mode
            return _mode_updates(mode)

        def send_message(message, file_image, files_pdf, session_name, mode, history):
            """Handle a send action: get the reply and append the turn to the chat history.

            Returns values for (chatbot, msg, file_upload_image,
            file_upload_pdf, selected_image, selected_pdf); the inputs are
            cleared after every send.
            """
            message = message or ""
            history = history or []
            if not message.strip() and not (file_image or files_pdf):
                return history, "", None, None, "", ""

            try:
                session = session_manager.get_session(session_name)
                if not session:
                    return history, "", None, None, "", ""

                # The first real message fixes the session's mode.
                if not session["mode_locked"]:
                    session["mode_locked"] = True
                    session["mode"] = mode

                current_mode = session["mode"]
                response = process_message(message, file_image if current_mode == "General Chat" else None,
                                        files_pdf if current_mode == "RAG Chat" else None,
                                        current_mode, history, session_name)

                # History entries are (user, assistant) pairs: the user's
                # content goes in the first slot, the reply in the second.
                user_text = message if message.strip() else None
                if current_mode == "General Chat" and file_image:
                    history.append(((_file_path(file_image),), None))
                    history.append((user_text, response))
                elif current_mode == "RAG Chat" and files_pdf:
                    pdf_names = [_file_path(f) for f in files_pdf]
                    history.append((f"Uploaded PDFs: {', '.join(pdf_names)}", None))
                    history.append((user_text, response))
                else:
                    history.append((message, response))

                session["history"] = history

                return history, "", None, None, "", ""

            except Exception as e:
                logging.error(f"메시지 전송 중 오류: {str(e)}")
                return history, "", None, None, "", ""

        def clear_chat(session_name=None):
            """Clear the visible chat and the stored history of the selected session."""
            session = session_manager.get_session(session_name)
            if session:
                session["history"] = []
            return [], "", ""

        def add_session(title):
            """Create a session from the title box; returns updates for (title box, session list)."""
            name = (title or "").strip()
            if not name:
                gr.Warning("세션 이름을 입력해주세요.")
                return gr.update(), gr.update()
            if not session_manager.create_session(name):
                gr.Warning("이미 존재하는 세션 이름입니다.")
                return gr.update(), gr.update()
            # Selecting the new entry triggers the session list's change
            # handler, which refreshes the title, chat and mode.
            return (
                gr.update(value="", visible=False),
                gr.update(choices=list(session_manager.sessions.keys()), value=name)
            )

        def switch_session(session_name):
            """Show the selected session; returns updates for (title, chatbot, mode radio)."""
            session = session_manager.get_session(session_name)
            if not session:
                return gr.update(), gr.update(), gr.update()
            session_manager.current_session = session_name
            return f"## {session_name}", session["history"], gr.update(value=session["mode"])

        def delete_session(session_name):
            """Delete the selected session; returns updates for (session list, title, chatbot, mode radio)."""
            next_name, next_session = session_manager.delete_session(session_name)
            if next_name is None:
                # Unknown session, or it is the only one left.
                gr.Warning("세션을 삭제할 수 없습니다.")
                return [gr.update()] * 4
            return [
                gr.update(choices=list(session_manager.sessions.keys()), value=next_name),
                f"## {next_name}",
                next_session["history"],
                gr.update(value=next_session["mode"])
            ]

        # Event bindings
        new_session_btn.click(
            lambda: gr.update(visible=True),
            outputs=session_title_input
        )

        session_title_input.submit(
            add_session,
            inputs=[session_title_input],
            outputs=[session_title_input, session_list]
        )

        session_list.change(
            switch_session,
            inputs=[session_list],
            outputs=[current_title, chatbot, chat_mode]
        )

        delete_btn.click(
            delete_session,
            inputs=[session_list],
            outputs=[session_list, current_title, chatbot, chat_mode]
        )

        send_btn.click(
            send_message,
            inputs=[msg, file_upload_image, file_upload_pdf, session_list, chat_mode, chatbot],
            outputs=[chatbot, msg, file_upload_image, file_upload_pdf, selected_image, selected_pdf]
        )

        msg.submit(
            send_message,
            inputs=[msg, file_upload_image, file_upload_pdf, session_list, chat_mode, chatbot],
            outputs=[chatbot, msg, file_upload_image, file_upload_pdf, selected_image, selected_pdf]
        )

        chat_mode.change(
            chat_mode_change,
            inputs=[chat_mode, session_list],
            outputs=[chat_mode, general_upload, rag_upload, general_file_info, rag_file_info]
        )

        clear_btn.click(
            clear_chat,
            inputs=[session_list],
            outputs=[chatbot, selected_image, selected_pdf]
        )

        return demo

# Run the GUI: build the Gradio app and launch it.
demo = create_ui()
demo.launch()

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 14.81 MiB is free. Process 18308 has 39.53 GiB memory in use. Of the allocated memory 38.14 GiB is allocated by PyTorch, with 58.00 MiB allocated in private pools (e.g., CUDA Graphs), and 141.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)