# 06. 統合演習: RAG + Gradio UI

このノートブックでは、モデルを追加学習（ファインチューニング）せずに **RAG + Gradio** を統合し、RAG の有無による応答品質の違いと、回答の根拠となる検索文書の表示を確認します。


In [None]:
# 編集禁止セル
import os
import sys

from google.colab import drive

if not os.path.isdir('/content/drive'):
    drive.mount('/content/drive')

repo_path = '/content/llm_lab'
if os.path.exists(repo_path):
    !rm -rf {repo_path}
!git clone -b stable-base https://github.com/akio-kobayashi/llm_lab.git {repo_path}
os.chdir(repo_path)

!pip install -q -U transformers accelerate bitsandbytes sentence-transformers faiss-cpu gradio

if 'src' not in sys.path:
    sys.path.append(os.path.abspath('src'))

from src.common import load_llm
from src.rag import FaissRAGPipeline
from src.ui import create_gradio_ui

print('セットアップが完了しました。')


In [None]:
# Do not edit this cell (編集禁止セル).
# Load the LLM and build the retrieval index used by the UI callbacks.
# Reset the globals first so that, if loading fails on a re-run, objects from
# a previous run are not silently left in place.
base_model, tokenizer, rag_pipeline = None, None, None

# use_4bit=True presumably requests a 4-bit quantized model — TODO confirm in
# src.common.load_llm.
base_model, tokenizer = load_llm(use_4bit=True)

# Build the search index over the sample document collection (a JSONL file);
# FaissRAGPipeline is presumably FAISS-backed — verify in src.rag.
rag_pipeline = FaissRAGPipeline()
rag_pipeline.build_index('data/docs/anime_docs_sample.jsonl')

print('モデルとRAGインデックスの準備が完了しました。')


In [None]:
# 編集禁止セル
import torch

def safe_generate_local(model, tokenizer, prompt, max_new_tokens=128, repetition_penalty=1.05):
    """Greedily generate a continuation of *prompt* and return only the new text.

    Any exception raised while tokenizing, generating or decoding is caught and
    returned as an ``"Error: ..."`` string instead of being raised, so a
    failure shows up as text in the UI rather than breaking the callback.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tokenizer: Matching tokenizer: callable with ``return_tensors='pt'``,
            with ``decode``, ``pad_token_id`` and ``eos_token_id``.
        prompt: Full prompt text to condition on.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Passed through to ``generate``.

    Returns:
        The decoded continuation with the prompt tokens excluded, stripped of
        surrounding whitespace, or ``"Error: <message>"`` on failure.
    """
    try:
        inputs = tokenizer(prompt, return_tensors='pt')
        # Move the inputs to the device of the model's first parameter.
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # Some tokenizers define no pad token; fall back to EOS.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        with torch.inference_mode():
            gen_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                do_sample=False,  # greedy decoding: deterministic output
                # KV cache disabled; presumably a deliberate stability
                # workaround (it slows generation) — confirm before changing.
                use_cache=False,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                # Guard against NaN/inf logits crashing generation.
                remove_invalid_values=True,
                renormalize_logits=True,
            )

        # Decoder-only models return prompt + continuation; drop the prompt
        # tokens so callers (and the UI) see only the answer. Encoder-decoder
        # models return only the decoder output, so nothing is sliced there.
        output_ids = gen_ids[0]
        config = getattr(model, 'config', None)
        if not getattr(config, 'is_encoder_decoder', False):
            output_ids = output_ids[inputs['input_ids'].shape[-1]:]

        return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    except Exception as e:
        # Deliberate best-effort: surface the failure as text for the UI.
        return f"Error: {e}"

def generate_plain(q):
    """Answer *q* with the base model alone, without any retrieved context.

    This is the non-RAG callback handed to the Gradio UI. It relies on the
    notebook globals ``base_model`` and ``tokenizer``.
    """
    instruction_prompt = f"### 指示:\n{q}\n\n### 応答:\n"
    answer = safe_generate_local(base_model, tokenizer, instruction_prompt)
    return answer.strip()

def generate_rag(q):
    """Answer *q* using the top retrieved documents as context.

    Retrieves 2 documents with ``rag_pipeline.search``, builds a
    context-augmented prompt with ``rag_pipeline.create_prompt_with_context``,
    and generates the answer with the base model.

    Returns:
        ``(answer, evidence)``, where *evidence* is the ``'text'`` entry of
        each retrieved document joined by blank lines.
    """
    hits = rag_pipeline.search(q, top_k=2)
    augmented_prompt = rag_pipeline.create_prompt_with_context(q, hits)
    answer = safe_generate_local(base_model, tokenizer, augmented_prompt).strip()
    # Each hit must support ['text']; presumably dicts from src.rag — verify.
    evidence = "\n\n".join(hit['text'] for hit in hits)
    return answer, evidence


In [None]:
# Do not edit this cell (編集禁止セル).
# Wire the plain and RAG answer functions into the Gradio UI and launch it.
# Per the Gradio docs, share=True requests a public link and debug=True blocks
# the cell so errors are printed in the notebook output.
demo = create_gradio_ui(
    generate_func_rag=generate_rag,
    generate_func_plain=generate_plain,
)
demo.launch(debug=True, share=True)
