In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%%capture
# !pip install unsloth_zoo unsloth vllm
!pip install -U bitsandbytes datasets

In [3]:
# from unsloth import FastModel
# from unsloth.chat_templates import get_chat_template
# model, tokenizer = FastModel.from_pretrained(
#     model_name = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
#     max_seq_length = 2048, # Choose any for long context!
#     load_in_4bit = True,  # 4 bit quantization to reduce memory
#     load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
#     full_finetuning = False, # [NEW!] We have full finetuning now!
#     # token = "hf_...", # use one if using gated models
# )

In [5]:
#@title for non unsloth models
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import PeftModel
import torch

# base
model_id = "aisingapore/Llama-SEA-LION-v3.5-8B-R"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)

# fine-tuned
model_base = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
peft_model_id = "ShoAnn/Llama-SEA-LION-v3.5-8B-R-legalqa"
model = PeftModel.from_pretrained(model_base, peft_model_id)
model.merge_and_unload()

model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



# **Load data**

In [6]:
#@title Load ground truth

In [7]:
# from unsloth.chat_templates import get_chat_template
# tokenizer = get_chat_template(
#     tokenizer,
#     chat_template = "gemma-3",
#     # enable_thinking = False
# )

In [None]:
from datasets import DatasetDict, load_dataset
hub_dataset = load_dataset("ShoAnn/legalqa_klinik_hukumonline")
test = hub_dataset["test"]

# Gather the splits into a DatasetDict
dataset = DatasetDict({
    "test": test,
})

instruction_prompt = "You are an AI Legal Assistant. Your task is to carefully analyze the provided **Legal Basis** below and answer given **Question** based *solely* and *exclusively* on the information contained within that context in Indonesian Language."
def format_to_conv(examples):
    questions = examples["question"]
    inputs       = examples["context"]
    outputs      = examples["answer"]
    conversation = []
    for question, input, output in zip(questions, inputs, outputs):
        conversation.append([ # This line has been changed
            {'content': f'{instruction_prompt} \n **Question** \n{question} \n**Legal Basis** \n{input}', 'role': 'user'},
        ])
    return { "conversations" : conversation, }

def formatting_prompts_func(examples):
   convos = examples["conversations"]
   texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True, thinking_mode="off") for convo in convos]
   return { "text" : texts, }

conv_test = dataset['test'].map(format_to_conv, batched = True,)
test = conv_test.map(formatting_prompts_func, batched = True)

README.md:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

In [None]:
test['text'][0]

In [None]:
#@title Load inference result from my model

In [None]:
import os
import json
from tqdm import tqdm
import pandas as pd
import torch

def save_checkpoint(inference_results, checkpoint_file):
    """Save the current state of inference"""
    inference_results.to_csv(checkpoint_file, index=False)

def load_checkpoint(checkpoint_file):
    """Load the previous state of inference"""
    if os.path.exists(checkpoint_file):
        return pd.read_csv(checkpoint_file)
    return pd.DataFrame()

def run_inference_with_checkpointing(dataset, model, tokenizer, checkpoint_file='checkpoint.csv',
                                   save_frequency=100):
    """
    Run inference with periodic checkpointing and text streaming

    Args:
        dataset: The input dataset
        model: The model to use for inference
        tokenizer: The tokenizer to use
        checkpoint_file: Path to save/load checkpoint
        save_frequency: How often to save checkpoints (in number of rows)
    """
    # Load checkpoint if it exists
    inference_results = load_checkpoint(checkpoint_file)

    # Find where to resume from
    processed_texts = set(inference_results['text'].tolist() if not inference_results.empty else [])
    remaining_rows = [row for row in dataset if row['text'] not in processed_texts]

    print(f"Resuming from {len(processed_texts)} previously processed items")

    # Create a text streamer for real-time output
    # streamer = TextStreamer(tokenizer, skip_special_tokens=True)

    try:
        for i, row in enumerate(tqdm(remaining_rows)):

            # llama 3.1 & cendol inference setup
            # inputs = tokenizer([row['text']], return_tensors="pt").to("cuda")
            # outputs = model.generate(**inputs, use_cache=True, max_new_tokens=2048)
            # outputs_text = tokenizer.batch_decode([outputs[0]], skip_special_tokens=True)
            # end llama 3.1 & cendol inference setup

            # gemma3 setup
            outputs = model.generate(
                **tokenizer([row['text']], return_tensors = "pt").to("cuda"),
                max_new_tokens = 1024, # Increase for longer outputs!
                # Recommended Gemma-3 settings!
                # temperature = 1.0, top_p = 0.95, top_k = 64,
                #qwen3 temp
                temperature = 0.7, top_p = 0.8, top_k = 20
            )
            outputs_text = tokenizer.batch_decode(outputs)
            row['model_answer'] = outputs_text

            temp_df = pd.DataFrame([row])
            inference_results = pd.concat([inference_results, temp_df], ignore_index=True)

            # Save checkpoint periodically
            if (i + 1) % save_frequency == 0:
                save_checkpoint(inference_results, checkpoint_file)
                print(f"\nCheckpoint saved at {i + 1} items")

    except KeyboardInterrupt:
        print("\nInterrupted by user. Saving checkpoint...")
        save_checkpoint(inference_results, checkpoint_file)
        raise
    except Exception as e:
        print(f"\nError occurred: {str(e)}. Saving checkpoint...")
        save_checkpoint(inference_results, checkpoint_file)
        raise

    # Save final checkpoint
    save_checkpoint(inference_results, checkpoint_file)
    print("\nInference completed and final results saved")

    return inference_results

In [None]:
model.eval()
inference_results = run_inference_with_checkpointing(
    dataset=test,
    model=model,
    tokenizer=tokenizer,
    checkpoint_file=f'/content/drive/MyDrive/model_outputs/inference_result_{model_id.split("/")[1]}-finetuned.csv',
    save_frequency=10
)

Resuming from 0 previously processed items


  0%|          | 0/112 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  1%|          | 1/112 [00:30<56:08, 30.35s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  2%|▏         | 2/112 [01:10<1:06:01, 36.01s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  3%|▎         | 3/112 [02:06<1:22:01, 45.15s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  4%|▎         | 4/112 [03:07<1:32:40, 51.49s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  4%|▍         | 5/112 [04:33<1:54:02, 63.94s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  5%|▌         | 6/112 [05:45<1:57:27, 66.49s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  6%|▋         | 7/112 [06:32<1:45:45, 60.43s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  7%|▋         | 8/112 [08:19<2:10:2


Checkpoint saved at 10 items


 10%|▉         | 11/112 [11:22<1:46:35, 63.32s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 11%|█         | 12/112 [13:04<2:04:52, 74.93s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 12%|█▏        | 13/112 [13:40<1:44:23, 63.26s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 12%|█▎        | 14/112 [14:36<1:39:59, 61.22s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 13%|█▎        | 15/112 [15:04<1:22:19, 50.93s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 14%|█▍        | 16/112 [16:18<1:32:42, 57.94s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 15%|█▌        | 17/112 [17:29<1:37:54, 61.83s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 16%|█▌        | 18/112 [18:24<1:33:58, 59.98s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 17%|█▋        |


Checkpoint saved at 20 items


 19%|█▉        | 21/112 [22:27<1:54:06, 75.23s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 20%|█▉        | 22/112 [23:06<1:36:42, 64.47s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 21%|██        | 23/112 [25:07<2:00:51, 81.48s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 21%|██▏       | 24/112 [26:39<2:04:13, 84.70s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 22%|██▏       | 25/112 [27:31<1:48:36, 74.90s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 23%|██▎       | 26/112 [28:50<1:49:08, 76.15s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 24%|██▍       | 27/112 [29:42<1:37:22, 68.73s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 25%|██▌       | 28/112 [30:18<1:22:43, 59.09s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 26%|██▌       |


Checkpoint saved at 30 items


 28%|██▊       | 31/112 [33:48<1:30:20, 66.92s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 29%|██▊       | 32/112 [34:12<1:11:57, 53.97s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 29%|██▉       | 33/112 [35:04<1:10:27, 53.51s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 30%|███       | 34/112 [35:38<1:01:46, 47.52s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 31%|███▏      | 35/112 [37:12<1:19:08, 61.67s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 32%|███▏      | 36/112 [38:30<1:24:15, 66.52s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 33%|███▎      | 37/112 [39:28<1:19:48, 63.85s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 34%|███▍      | 38/112 [40:14<1:12:08, 58.49s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
 35%|███▍      |