In [1]:
%%capture
!pip install unsloth_zoo unsloth vllm
!pip install bitsandbytes datasets rank_bm25 datasets scikit-learn gradio

In [2]:
from unsloth import FastLanguageModel
import pandas as pd
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import gradio as gr
from sentence_transformers import SentenceTransformer
import numpy as np

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-20 08:36:21 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-20 08:36:21 [__init__.py:239] Automatically detected platform cuda.


# Load Document Corpus

In [3]:
# Download the statutory-law dataset and build the retrieval corpus: one string
# per training row, formatted as "<text> >> <fulltext>".
lb_all = load_dataset("ShoAnn/indonesian-criminal-law-statutory")

# Read each column once and zip them, instead of re-indexing the split and
# column on every iteration of the comprehension.
train_split = lb_all["train"]
corpus = [
    f"{text} >> {fulltext}"
    for text, fulltext in zip(train_split["text"], train_split["fulltext"])
]

In [7]:
def build_sparse_index():
    """Fit the two lexical indexes (TF-IDF and BM25) over the global ``corpus``.

    Returns:
        tuple: ``(tfidf_vectorizer, tfidf_matrix, bm25)`` -- the fitted
        ``TfidfVectorizer``, the document-term matrix it produced for the
        corpus, and a ``BM25Okapi`` index built from lowercased,
        whitespace-split documents.
    """
    # --- TF-IDF ---
    print("Starting TF-IDF vectorization (indexing)...")
    vectorizer = TfidfVectorizer(lowercase=True)
    doc_term_matrix = vectorizer.fit_transform(corpus)
    print(f"TF-IDF matrix shape: {doc_term_matrix.shape}")
    print("TF-IDF indexing finished.")

    # --- BM25 ---
    print("Starting BM25 indexing...")
    tokenized_docs = []
    for document in corpus:
        tokenized_docs.append(document.lower().split())
    bm25_index = BM25Okapi(tokenized_docs)
    print("BM25 index built.")

    return vectorizer, doc_term_matrix, bm25_index

def prepare_model_and_embedding():
    """Load both dense retrievers and embed the global ``corpus`` with each.

    Returns:
        tuple: ``(base_model, model, corpus_embeddings_base, corpus_embeddings)``
        where ``base_model`` is the ``indobenchmark/indobert-base-p1`` encoder,
        ``model`` is the ``ShoAnn/indobert-base-p1-legalqa-retriever`` encoder,
        and the two embedding tensors are the corpus encoded by each of them
        in that order.
    """
    encoder_ids = (
        "indobenchmark/indobert-base-p1",
        "ShoAnn/indobert-base-p1-legalqa-retriever",
    )
    # Load both encoders first, then encode with each, in the same order.
    encoders = [SentenceTransformer(encoder_id) for encoder_id in encoder_ids]
    embeddings = [encoder.encode(corpus, convert_to_tensor=True) for encoder in encoders]

    base_encoder, tuned_encoder = encoders
    base_embeddings, tuned_embeddings = embeddings
    return base_encoder, tuned_encoder, base_embeddings, tuned_embeddings

def load_unsloth_model(model_name, max_seq_length=2048, load_in_4bit=True, load_in_8bit=False, full_finetuning=False):
    """Load a model/tokenizer pair through Unsloth.

    Model names containing "gemma" are loaded with ``FastModel`` and get the
    "gemma-3" chat template applied to their tokenizer; every other name is
    loaded with ``FastLanguageModel`` and its tokenizer is returned untouched.

    Args:
        model_name: Repository id passed to ``from_pretrained``.
        max_seq_length: Forwarded to the Unsloth loader.
        load_in_4bit: Forwarded to the Unsloth loader.
        load_in_8bit: Forwarded to the Unsloth loader.
        full_finetuning: Forwarded to the Unsloth loader.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    is_gemma = "gemma" in model_name
    if is_gemma:
        from unsloth import FastModel
    loader_cls = FastModel if is_gemma else FastLanguageModel

    load_kwargs = dict(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        full_finetuning=full_finetuning,
    )
    model, tokenizer = loader_cls.from_pretrained(**load_kwargs)

    from unsloth.chat_templates import get_chat_template
    if not is_gemma:
        return model, tokenizer
    return model, get_chat_template(tokenizer, chat_template="gemma-3")

def load_transformers_model(base_model_id, model_id="", peft=False):
    """Load a 4-bit quantized causal LM with transformers, optionally with a PEFT adapter.

    Args:
        base_model_id: Repository id of the base causal LM.
        model_id: Repository id of the PEFT adapter (required when ``peft`` is
            True). When given, the tokenizer is loaded from this repository;
            when empty, the tokenizer falls back to ``base_model_id``.
        peft: If True, wrap the base model with the adapter from ``model_id``.

    Returns:
        tuple: ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        ValueError: If ``peft`` is True but no ``model_id`` was supplied.
    """
    # Fail fast, before downloading a multi-GB base model, when the adapter id is missing.
    if peft and not model_id:
        raise ValueError("peft=True requires model_id to point at the PEFT adapter repository.")

    # NOTE(review): the recorded run reports "Bfloat16 = FALSE" for the GPU in use;
    # confirm that a bfloat16 compute dtype is appropriate on that hardware.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)

    model = PeftModel.from_pretrained(base_model, model_id) if peft else base_model
    model.eval()

    # The default model_id is "", which is not a loadable repository; use the
    # base model's tokenizer whenever no separate model_id was given.
    tokenizer = AutoTokenizer.from_pretrained(model_id or base_model_id)
    return model, tokenizer

def load_model(name):
    """Resolve a generator display name to a loaded ``(model, tokenizer)`` pair.

    Routing:
      * Names containing "unsloth" are loaded through ``load_unsloth_model``,
        from the "ShoAnn" namespace when the name contains "legalqa" and from
        the "unsloth" namespace otherwise.
      * A name equal to the short name of a known base model loads that base
        model through ``load_transformers_model`` without an adapter.
      * A name that starts with a known base model's short name (for example
        with a "-legalqa" suffix) loads that base model plus the adapter
        repository "ShoAnn/<name>".
      * Any other name is treated as a full repository id and loaded as-is.

    Args:
        name: Generator name, e.g. an entry of ``generator_model_list``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    if "unsloth" in name:
        namespace = "ShoAnn" if "legalqa" in name else "unsloth"
        return load_unsloth_model(f"{namespace}/{name}")

    base_model_ids = ("SeaLLMs/SeaLLMs-v3-7B-Chat", "aisingapore/Llama-SEA-LION-v3.5-8B-R")
    for base_model_id in base_model_ids:
        short_name = base_model_id.split("/", 1)[1]
        if name == short_name:
            # Plain base model: pass the id for the tokenizer too, no adapter.
            return load_transformers_model(base_model_id, base_model_id)
        if name.startswith(short_name):
            # Fine-tuned variant: the short base name is a prefix of the display
            # name. Presumably "ShoAnn/<name>" hosts a PEFT adapter -- confirm.
            return load_transformers_model(base_model_id, f"ShoAnn/{name}", peft=True)

    # Unknown name: assume it is already a full repository id of a plain model.
    return load_transformers_model(name, name)

# System-style instruction prepended to every user turn sent to the generator.
instruction_prompt = "You are an AI Legal Assistant. Your task is to carefully analyze the provided **Legal Basis** below and answer given **Question** based *solely* and *exclusively* on the information contained within that context in Indonesian Language."


def format_to_message(question, input):
    """Build a single-turn chat message list for the generator.

    Args:
        question: The user's question.
        input: The retrieved legal-basis text placed after the question.

    Returns:
        list: One ``user`` message whose ``content`` is a single text part
        combining the instruction prompt, the question, and the legal basis.
    """
    prompt_text = f'{instruction_prompt} \n **Question** \n{question} \n**Legal Basis** \n{input}'
    user_turn = {
        'content': [{'type': 'text', 'text': prompt_text}],
        'role': 'user',
    }
    return [user_turn]

# Retriever choices offered in the UI dropdown. retrieve_documents() matches
# these by substring ("bert", "legalqa") rather than by exact name.
retriever_model_list = [
    "TF-IDF",
    "BM25",
    "indobenchmark/indobert-base-p1",
    "ShoAnn/indobert-base-p1-legalqa-retriever",
]

# Generator choices offered in the UI dropdown. These are short display names;
# load_model() routes them by the substrings "unsloth" and "legalqa".
generator_model_list = [
    "SeaLLMs-v3-7B-Chat",
    "Llama-SEA-LION-v3.5-8B-R",
    "gemma-3-4b-it-unsloth-bnb-4bit",
    "Qwen3-8B-unsloth-bnb-4bit",
    "SeaLLMs-v3-7B-Chat-legalqa",
    "Llama-SEA-LION-v3.5-8B-R-legalqa",
    "gemma-3-4b-it-unsloth-bnb-4bit-legalqa",
    "Qwen3-8B-unsloth-bnb-4bit-legalqa"
]



In [5]:
# Build every retrieval index once, up front. These module-level names are read
# as globals by retrieve_documents(). Requires the cell that defines
# prepare_model_and_embedding() and build_sparse_index() to have run first.
base_model, model, corpus_embeddings_base, corpus_embeddings = prepare_model_and_embedding()
tfidf_vectorizer, tfidf_matrix, bm25 = build_sparse_index()



Starting TF-IDF vectorization (indexing)...
TF-IDF matrix shape: (5474, 13320)
TF-IDF indexing finished.
Starting BM25 indexing...
BM25 index built.


# Retrieval, Generation, and Gradio App

In [8]:
# --- 2. RAG Core Logic ---

def retrieve_documents(query_text, retriever="indobert", k=5):
    """Retrieve the top ``k`` corpus documents for a query with the chosen retriever.

    The retriever is selected by case-insensitive substring of ``retriever``:
      * contains "bert": dense retrieval. The fine-tuned encoder and its corpus
        embeddings are used when the name also contains "legalqa"; otherwise
        the base encoder and its embeddings are used.
      * contains "bm25": scores from the global ``bm25`` index, with the query
        tokenized the same way as the indexed corpus (lowercase + split).
      * anything else: TF-IDF, scoring the transformed query against the
        global ``tfidf_matrix``.

    Args:
        query_text: The user query.
        retriever: Retriever name (see above).
        k: Maximum number of documents to return; clamped to the corpus size.

    Returns:
        list: ``(document_text, score)`` tuples sorted by descending score, or
        an empty list when the query or the corpus is empty or ``k`` < 1.
    """
    if not query_text or not corpus:
        return []

    retriever_key = retriever.lower()
    if "bert" in retriever_key:
        use_finetuned = "legalqa" in retriever_key
        encoder = model if use_finetuned else base_model
        doc_embeddings = corpus_embeddings if use_finetuned else corpus_embeddings_base
        query_embedding = encoder.encode([query_text], convert_to_tensor=True)  # (1, embedding_dim)
        # Score with the SAME encoder and embeddings that were selected above.
        similarity_scores = encoder.similarity(query_embedding, doc_embeddings)[0]
    elif "bm25" in retriever_key:
        query_tokens = query_text.lower().split()
        similarity_scores = torch.as_tensor(bm25.get_scores(query_tokens))
    else:
        query_vector = tfidf_vectorizer.transform([query_text])  # sparse, (1, num_features)
        # Sparse dot product per document. With TfidfVectorizer's default L2
        # row normalization this equals cosine similarity.
        similarity_scores = torch.as_tensor((tfidf_matrix @ query_vector.T).toarray().ravel())

    # torch.topk raises if k exceeds the number of scores.
    top_k = min(int(k), len(corpus))
    if top_k <= 0:
        return []

    scores, indices = torch.topk(similarity_scores, k=top_k)
    return [
        (corpus[doc_index], float(score))
        for score, doc_index in zip(scores.tolist(), indices.tolist())
    ]


# Holds at most one loaded generator, keyed by its display name, so repeated
# requests with the same selection do not reload the model.
_GENERATOR_CACHE = {}


def _get_generator(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, loading it only on a cache miss."""
    if model_name not in _GENERATOR_CACHE:
        import gc

        # Drop the previously selected generator before loading another one so
        # two large models are not kept alive at the same time.
        _GENERATOR_CACHE.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        _GENERATOR_CACHE[model_name] = load_model(model_name)
    return _GENERATOR_CACHE[model_name]


def rag_app_interface(user_query, selected_retriever_model, selected_generator_model_name, num_docs_to_retrieve):
    """Run one retrieve-then-generate request for the Gradio UI.

    Args:
        user_query: The question typed by the user.
        selected_retriever_model: Retriever name passed to ``retrieve_documents``.
        selected_generator_model_name: Generator name passed to ``load_model``.
        num_docs_to_retrieve: Number of documents to retrieve (cast to ``int``).

    Returns:
        tuple: ``(retrieved_docs_markdown, full_prompt, answer)`` -- always
        three strings, matching the three output components of the UI.
    """
    if not user_query:
        return "Please enter a query.", "", ""  # Matches 3 outputs

    # 1. Retrieve relevant documents.
    retrieved_docs_with_scores = retrieve_documents(
        user_query, selected_retriever_model, k=int(num_docs_to_retrieve)
    )

    # Format retrieved documents twice: once for display, once as prompt context.
    formatted_retrieved_docs_display = "## Retrieved Documents:\n\n"
    context_for_prompt = ""
    if retrieved_docs_with_scores:
        for i, (doc_text, doc_score) in enumerate(retrieved_docs_with_scores):
            formatted_retrieved_docs_display += f"**Document {i+1} (Score: {doc_score:.4f}):**\n{doc_text.strip()}\n\n"
            context_for_prompt += f"Document {i+1}:\n{doc_text.strip()}\n\n"
    else:
        formatted_retrieved_docs_display += "No relevant documents found."
        context_for_prompt = "No relevant documents found."

    # 2. Build the prompt with the generator's own chat template.
    generator_llm, generator_tokenizer = _get_generator(selected_generator_model_name)
    chat_messages = format_to_message(user_query, context_for_prompt)
    formatted_prompt_str = generator_tokenizer.apply_chat_template(
        chat_messages,
        tokenize=False,
        add_generation_prompt=True  # Adds the prompt for the assistant to start generating
    )

    # 3. Generate an answer.
    inputs = generator_tokenizer([formatted_prompt_str], return_tensors="pt").to(generator_llm.device)
    input_ids_length = inputs.input_ids.shape[1]

    # Ensure pad_token_id is set for open-ended generation.
    if generator_tokenizer.pad_token_id is None:
        generator_tokenizer.pad_token_id = generator_tokenizer.eos_token_id

    with torch.no_grad():  # Inference only; no gradients needed.
        outputs_tokens = generator_llm.generate(
            **inputs,
            max_new_tokens=512,
            # temperature/top_p/top_k only take effect when sampling is on.
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
            pad_token_id=generator_tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    generated_ids = outputs_tokens[0][input_ids_length:]
    final_answer_text = generator_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return formatted_retrieved_docs_display, formatted_prompt_str, final_answer_text

# --- 4. Gradio UI Definition ---

# Two-column layout: inputs (query, retriever, generator, document count) on
# the left, results (retrieved documents, full prompt, answer) on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG Legal Chatbot")

    with gr.Row():
        with gr.Column(scale=2):
            # UI labels are in Indonesian: "Pertanyaan" = question.
            query_input = gr.Textbox(label="Pertanyaan:")
            # Retriever choice; defaults to the first list entry.
            retriever_selector = gr.Dropdown(
                label="Pilih Model Retriever:",
                choices=retriever_model_list,
                value=retriever_model_list[0]
            )
            # Generator choice; defaults to the first list entry.
            generator_selector = gr.Dropdown(
                label="Pilih Model Generator:",
                choices=generator_model_list,
                value=generator_model_list[0]
            )
            # Number of documents to retrieve (1-5, default 3).
            docs_slider = gr.Slider(
                minimum=1, maximum=5, value=3, step=1,
                label="Jumlah Dokumen yang Ingin Diambil:"
            )
            submit_button = gr.Button("Submit", variant="primary")

        with gr.Column(scale=2):
            retrieved_docs_output = gr.Markdown(label="Retrieved Documents")
            full_prompt_output = gr.Markdown(label="Full Prompt to Generator")
            answer_output = gr.Markdown(label="Generated Answer")

    # The order of inputs/outputs must match rag_app_interface's parameter
    # order and its 3-tuple return value.
    submit_button.click(
        fn=rag_app_interface,
        inputs=[query_input, retriever_selector, generator_selector, docs_slider],
        outputs=[retrieved_docs_output, full_prompt_output, answer_output]
    )

# --- 5. Launch the App ---
# share=True requests a temporary public URL, which is useful on Colab.
# For local use, omit share=True or set it to False.
# With debug=True the cell keeps running so errors and logs stay visible;
# interrupt the cell to stop the server.
print("Launching Gradio app...")
demo.launch(share=True, debug=True) # debug=True can be helpful during development

Launching Gradio app...
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://6fb7a49d0dafa29e37.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


==((====))==  Unsloth 2025.5.6: Fast Gemma3 patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://6fb7a49d0dafa29e37.gradio.live


