In [1]:
!pip install bitsandbytes
!pip install trl
!pip install peft
!pip install accelerate



In [2]:
pip install python-dotenv

Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import transformers
import torch
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [2]:
from dotenv import load_dotenv
load_dotenv()

os.environ["HF_TOKEN_REPO"] = os.getenv("HF_TOKEN_REPO")
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")

In [3]:
model_id = "jaimadhukar/lawvista"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
GEMMA_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>' }}"
    "{% elif message['role'] == 'model' %}"
    "{{ '<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>' }}"
    "{% else %}"
    "{{ raise('Only user and model roles are supported!') }}"
    "{% endif %}"
    "{% endfor %}"
)
tokenizer.chat_template = GEMMA_CHAT_TEMPLATE

In [6]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token 

In [7]:
print("TOK:", tokenizer)
print("SPECIAL TOKENS:", tokenizer.special_tokens_map)

TOK: GemmaTokenizerFast(name_or_path='jaimadhukar/lawvista', vocab_size=256000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<bos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	5: AddedToken("<2mass>", rstrip=False, lstrip=False, sin

In [8]:
os.environ["WANDB_DISABLED"] = "false"

In [9]:
lora_config = LoraConfig(
    r = 8,
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type = "CAUSAL_LM",
    lora_dropout=0.05
)

In [10]:
from datasets import load_dataset

In [11]:
data = load_dataset("jaimadhukar/legal-advisor-gemma-chat-cleaned")

In [12]:
data = data.map(lambda samples: tokenizer(samples["text"], truncation=True), batched=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [13]:
def format_func(batch):
    return batch["text"]

In [14]:
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    formatting_func=format_func,
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=750,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        output_dir="outputs_dir",
        optim="paged_adamw_8bit",
    )
)

You passed a dataset that is already processed (contains an `input_ids` field) together with a formatting function. Therefore `formatting_func` will be ignored. Either remove the `formatting_func` or pass a dataset that is not already processed.


Truncating train dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [15]:
data["train"]

Dataset({
    features: ['text', 'input_ids', 'attention_mask'],
    num_rows: 1000
})

In [16]:
data

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1000
    })
})

In [17]:
data["train"]["text"][0]

"<start_of_turn>user\nHello, I am facing some legal issues related to my business in India. I believe there have been violations of Business and Corporate Laws.\n<end_of_turn>\n<start_of_turn>model\nI see. Can you please provide me with more details about the specific violations you are concerned about?  \n<end_of_turn>\n<start_of_turn>user\nSure. I suspect that there have been instances of fraud and mismanagement within my company. I believe some of the directors may have engaged in insider trading and other unethical practices.\n<end_of_turn>\n<start_of_turn>model\nThat's a serious allegation. It's important to address these issues promptly to protect your business and its reputation. Have you gathered any evidence to support your claims?  \n<end_of_turn>\n<start_of_turn>user\nYes, I have some documents and emails that suggest wrongdoing. I also have witness statements from employees who have raised concerns about the conduct of certain directors.\n<end_of_turn>\n<start_of_turn>model

In [18]:
trainer.train()

Step,Training Loss
10,1.4669
20,1.0572
30,0.8755
40,0.8045
50,0.7551
60,0.7522
70,0.7088
80,0.7237
90,0.6639
100,0.6559


TrainOutput(global_step=750, training_loss=0.4868308331171671, metrics={'train_runtime': 1146.3518, 'train_samples_per_second': 2.617, 'train_steps_per_second': 0.654, 'total_flos': 4.277434108806144e+16, 'train_loss': 0.4868308331171671, 'epoch': 3.0})

In [22]:
output_dir = "outputs_dir"
max_steps = 750
api_key = os.environ.get('HF_TOKEN')

def run_inference_test(model, tokenizer):
    """
    Merges LoRA weights and runs a clean test prompt using the correct 
    Gemma chat format and stop sequence.
    """
    from peft import PeftModel
    print("\n--- Running Inference Test ---")
    
    # 1. Merge LoRA weights (use the last checkpoint path)
    final_checkpoint = os.path.join(output_dir, f"checkpoint-{max_steps}")
    print(f"Loading merged model from checkpoint: {final_checkpoint}")

    # Reload the base model and merge the adapter weights
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id, 
        quantization_config=bnb_config, 
        device_map={"":0}, 
        token=api_key
    )
    
    # Check if the checkpoint exists before loading PeftModel
    if not os.path.exists(final_checkpoint):
        print("Checkpoint not found. Skipping inference test.")
        return

    merged_model = PeftModel.from_pretrained(base_model, final_checkpoint)
    merged_model = merged_model.merge_and_unload()
    merged_model.eval()

    # 2. Define the test prompt using the official chat template
    test_conversation = [
        {"role": "user", "content": "I am facing issues with fraud by a business partner in India. What is my first step?"},
    ]

    # Convert conversation to the exact input string format
    # This line now works because tokenizer.chat_template is set!
    prompt = tokenizer.apply_chat_template(test_conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(merged_model.device)

    # 3. Generation parameters
    # CRITICAL: Use the EOS token as the stop sequence to prevent endless generation or junk tokens.
    stop_token_id = tokenizer.eos_token_id
    
    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs, 
            max_new_tokens=256, 
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=stop_token_id
        )

    # 4. Decode and print the result
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    
    # Clean up the output to show only the model's response part
    response_start_tag = "<start_of_turn>model\n"
    if response_start_tag in response:
        response = response.split(response_start_tag, 1)[-1]
    
    response_end_tag = "<end_of_turn>"
    if response_end_tag in response:
        response = response.split(response_end_tag, 1)[0].strip()

    print("\n[Generated Response]")
    print(response)

In [23]:
final_checkpoint = os.path.join(outputs_dir, f"checkpoint-750")
final_checkpoint

'outputs_dir/checkpoint-750'

In [24]:
run_inference_test(model, tokenizer)


--- Running Inference Test ---
Loading merged model from checkpoint: outputs_dir/checkpoint-750




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]




[Generated Response]
<bos><start_of_turn>user
I am facing issues with fraud by a business partner in India. What is my first step?


In [27]:
text = "Hello, I am facing some legal issues related to my business in India. I believe there have been violations of Business and Corporate Laws."
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Hello, I am facing some legal issues related to my business in India. I believe there have been violations of Business and Corporate Laws.
I see. Can you provide me with more details about the specific violations you believe have occurred?


In [28]:
final_checkpoint = os.path.join(output_dir, f"checkpoint-750")
merged_model_dir = "merged_legal_gemma"

In [29]:
model_id

'jaimadhukar/lawvista'

In [30]:
base_model = AutoModelForCausalLM.from_pretrained(
        model_id, 
        torch_dtype=torch.bfloat16, 
        device_map="auto",
        token=api_key
    )

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [32]:
from peft import PeftModel

In [33]:
merged_model = PeftModel.from_pretrained(base_model, final_checkpoint)
merged_model = merged_model.merge_and_unload()
merged_model.eval()



GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=3072, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear4bit(in_features=24576, out_features=3072, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((3072,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((3072,), eps=1e-06)
      )
    )
    

In [34]:
merged_model.save_pretrained(merged_model_dir)
tokenizer.save_pretrained(merged_model_dir)

('merged_legal_gemma/tokenizer_config.json',
 'merged_legal_gemma/special_tokens_map.json',
 'merged_legal_gemma/chat_template.jinja',
 'merged_legal_gemma/tokenizer.json')

In [37]:
repo_id = "jaimadhukar/gemma-7b-legal-advisor-in"

In [38]:
merged_model.push_to_hub(repo_id, token=os.environ["HF_TOKEN_REPO"])
tokenizer.push_to_hub(repo_id, token=os.environ["HF_TOKEN_REPO"])

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

CommitInfo(commit_url='https://huggingface.co/jaimadhukar/gemma-7b-legal-advisor-in/commit/74bf5742469d1a04b612d56ec0a85bf68efda158', commit_message='Upload tokenizer', commit_description='', oid='74bf5742469d1a04b612d56ec0a85bf68efda158', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jaimadhukar/gemma-7b-legal-advisor-in', endpoint='https://huggingface.co', repo_type='model', repo_id='jaimadhukar/gemma-7b-legal-advisor-in'), pr_revision=None, pr_num=None)