In [None]:
!pip install bitsandbytes
!pip install trl
!pip install peft
!pip install accelerate

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2
Collecting trl
  Downloading trl-0.25.1-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.25.1-py3-none-any.whl (465 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.5/465.5 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.25.1


In [None]:
import os
import transformers
import torch
from google.colab import userdata
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [None]:
# Copy the Hugging Face token from Colab's secret store into the process
# environment; later cells read it back via os.environ["HF_TOKEN"].
os.environ.update(HF_TOKEN=userdata.get("HF_TOKEN"))

In [None]:
# Base checkpoint to fine-tune, plus the bitsandbytes settings used to load it:
# 4-bit weights in NF4 format, with bfloat16 as the compute dtype.
model_id = "google/gemma-7b"
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
# Load the tokenizer and the quantized base model, authenticating with the
# token exported earlier. device_map={"": 0} places the whole model on device 0.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=os.environ["HF_TOKEN"],
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=os.environ["HF_TOKEN"],
    device_map={"": 0},
    quantization_config=bnb_config,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Baseline sanity check: let the base model continue a short legal prompt
# (up to 250 new tokens) before any fine-tuning is applied.
text = "What are legal lawsuits,"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)

outputs = model.generate(max_new_tokens=250, **inputs)
print(
    tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
    )
)

What are legal lawsuits, and what are the different types of legal lawsuits?

Legal lawsuits are legal actions that are taken against a person or company for a wrong or injury that has been caused.

The different types of legal lawsuits are:

1. Breach of contract
2. Breach of warranty
3. Negligence
4. Fraud
5. Product liability
6. Personal injury
7. Wrongful death
8. Medical malpractice
9. Employment discrimination
10. Consumer protection

<h2><strong>Breach of contract</strong></h2>

A breach of contract is a legal action that is taken against a person or company for failing to fulfill their obligations under a contract.

<h2><strong>Breach of warranty</strong></h2>

A breach of warranty is a legal action that is taken against a person or company for failing to fulfill their obligations under a warranty.

<h2><strong>Negligence</strong></h2>

Negligence is a legal action that is taken against a person or company for failing to exercise reasonable care in their actions.

<h2><strong>F

In [None]:
# Export WANDB_DISABLED with the string value "false".
# NOTE(review): with this value the training cell's saved output shows wandb
# prompting for an API key; if the intent was to switch W&B off, the value
# would presumably need to be "true" -- confirm.
os.environ.update(WANDB_DISABLED="false")

In [None]:
# LoRA adapter settings: rank-8 adapters on the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down) for causal-LM training.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [None]:
from datasets import load_dataset

In [None]:
# Stream the corpus instead of downloading it in full, and materialise only
# the first 15,000 records of the train split as an in-memory list.
stream = load_dataset("AnuragB/Indian-legal", streaming=True, split="train")
small = [record for record in stream.take(15000)]

Resolving data files:   0%|          | 0/7130 [00:00<?, ?it/s]

In [None]:
# Confirm how many records were actually materialised from the stream.
print(len(small))

15000


In [None]:
# Peek at the first record to see the schema (a dict with a 'text' field).
small[0]

{'text': 'Appeal No. LXVI of 1949.'}

In [None]:
# Peek at index 14999, the last of the 15,000 records taken from the stream.
small[14999]

{'text': 'Sec tions 83A and 83 B of the Indian Companies Act can only be supported as valid on the ground that they regulate the management of companies and are, therefore, within the said entry.'}

In [None]:
def tokenize_fn(examples):
    """Tokenize ``examples["text"]`` and attach causal-LM labels.

    The text is truncated to at most 1024 tokens and overflowing tokens are
    not returned. ``labels`` is a copy of ``input_ids``, so every position is
    a training target. Uses the notebook-level ``tokenizer``.

    Returns the tokenizer output with the extra ``labels`` entry.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=1024,
        truncation=True,
        return_overflowing_tokens=False,
    )
    # Copy rather than alias, so labels and inputs can be altered independently.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded

In [None]:
# Tokenize every record one at a time; the result is a plain Python list of
# per-example encodings that is later handed to the Trainer.
tokenized_dataset = list(map(tokenize_fn, small))
print(len(tokenized_dataset))

15000


In [None]:
from transformers import TrainingArguments, Trainer
from peft import get_peft_model

In [None]:
# Short continued-pretraining run: 150 optimizer steps, each accumulating
# 4 micro-batches of 1 sequence per device, in bf16. Logs every 10 steps and
# writes a checkpoint under ./gemma-cpt-legal every 50 steps.
training_args = TrainingArguments(
    output_dir="./gemma-cpt-legal",
    max_steps=150,
    warmup_steps=5,
    learning_rate=2e-4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    bf16=True,
    logging_steps=10,
    save_steps=50,
)


In [None]:
# Wrap the quantized base model with the LoRA adapters and report how many
# parameters are trainable.
# NOTE: this rebinds `model`, so re-running the cell wraps the already-wrapped
# model again -- re-run the model-loading cell first.
model = get_peft_model(model=model, peft_config=lora_config)
model.print_trainable_parameters()

trainable params: 25,001,984 || all params: 8,562,682,880 || trainable%: 0.2920


In [None]:
# Build the Trainer around the LoRA-wrapped model and the tokenized records.
# No eval dataset or custom collator is supplied.
trainer = Trainer(
    train_dataset=tokenized_dataset,
    args=training_args,
    model=model,
)

In [None]:
# Run the fine-tuning loop configured in training_args; the returned
# TrainOutput is displayed as the cell's result.
trainer.train()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmadhukarjai9[0m ([33mmadhukarjai9-thapar-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
10,3.112
20,2.8867
30,2.657
40,2.6684
50,2.6136
60,2.7403
70,2.6219
80,2.4514
90,2.6863
100,2.5142


TrainOutput(global_step=150, training_loss=2.598929748535156, metrics={'train_runtime': 1615.1684, 'train_samples_per_second': 0.371, 'train_steps_per_second': 0.093, 'total_flos': 1124959109806080.0, 'train_loss': 2.598929748535156, 'epoch': 0.04})

In [None]:
# Qualitative check after fine-tuning: let the adapted model continue a prompt
# written in the style of the legal corpus (up to 250 new tokens).
text = "Appeal from the High Court of judicature, Bombay, in a reference under section,"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)

outputs = model.generate(max_new_tokens=250, **inputs)
print(
    tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
    )
)

Appeal from the High Court of judicature, Bombay, in a reference under section, 529 of the Indian Penal Code. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal dismissed. Appeal d

In [None]:
import torch
from datasets import load_dataset
import math

# Evaluate on 100 records that come *after* the 15,000 records materialised
# for training, so the reported loss/perplexity is not measured on text that
# was part of the fine-tuning pool. `stream` is the streaming dataset opened
# when the training records were loaded.
eval_data = list(stream.skip(15000).take(100))

def compute_perplexity(model, tokenizer, dataset):
    """Compute the token-weighted mean loss and perplexity of ``model``.

    Each item's ``"text"`` is tokenized (truncated to 512 tokens) and scored
    with the model using the input ids as labels. Per-sequence losses are
    weighted by the number of predicted tokens before averaging, so long and
    short sequences contribute in proportion to their length; perplexity is
    ``exp`` of that average.

    Args:
        model: causal LM exposing ``eval()``, ``device`` and a forward call
            that returns an object with a ``.loss`` tensor.
        tokenizer: callable returning a mapping that contains ``"input_ids"``.
        dataset: iterable of mappings with a ``"text"`` entry.

    Returns:
        ``(avg_loss, ppl)`` as Python floats.

    Raises:
        ValueError: if no sequence in ``dataset`` has a token to predict.
    """
    model.eval()
    total_nll = 0.0
    total_targets = 0

    for item in dataset:
        enc = tokenizer(item["text"], return_tensors="pt", truncation=True, max_length=512)
        # A sequence of length L offers L-1 next-token targets.
        # Assumes .loss is the mean over those shifted targets, as Hugging Face
        # causal-LM heads compute it -- confirm for other model types.
        n_targets = enc["input_ids"].shape[1] - 1
        if n_targets < 1:
            # Nothing to predict (e.g. empty text); the loss would be undefined.
            continue
        enc = {k: v.to(model.device) for k, v in enc.items()}
        with torch.no_grad():
            loss = model(**enc, labels=enc["input_ids"]).loss
        # Undo the per-sequence mean so the final average is per token.
        total_nll += loss.item() * n_targets
        total_targets += n_targets

    if total_targets == 0:
        raise ValueError("dataset contains no sequence with a token to predict")

    avg_loss = total_nll / total_targets
    ppl = math.exp(avg_loss)
    return avg_loss, ppl

# Score the evaluation records with the current model and report both metrics.
loss, ppl = compute_perplexity(model=model, tokenizer=tokenizer, dataset=eval_data)
print(f"Evaluation Loss: {loss}")
print(f"Perplexity: {ppl}")


Evaluation Loss: 2.6548352360725405
Perplexity: 14.22264248783265


In [None]:
from peft import PeftModel

# Reload a fresh 4-bit copy of the base model and fold the LoRA weights from
# the step-150 checkpoint (written under the training output_dir) into it.
model_id = "google/gemma-7b"
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=os.environ["HF_TOKEN"],
    device_map={"": 0},
    quantization_config=bnb_config,
)

lora_path = "./gemma-cpt-legal/checkpoint-150"

# Attach the adapter to the base model, then merge it so the result no longer
# needs the PEFT wrapper.
# NOTE(review): the merge happens on 4-bit quantized weights; presumably this
# introduces rounding error versus merging into a bf16 base -- confirm whether
# a quantized merged model is what should be published.
merged_model = PeftModel.from_pretrained(base_model, lora_path).merge_and_unload()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



In [None]:
# Export the two upload-related secrets from Colab's secret store under the
# same names, in the same order as before.
for secret_name in ("HF_TOKEN_WRITE", "HF_TOKEN_REPO"):
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
# Write the merged weights and the tokenizer files side by side into
# ./merged-model; the tokenizer call's return value is the cell output.
merged_model.save_pretrained(save_directory="merged-model")
tokenizer.save_pretrained(save_directory="merged-model")

('merged-model/tokenizer_config.json',
 'merged-model/special_tokens_map.json',
 'merged-model/tokenizer.model',
 'merged-model/added_tokens.json',
 'merged-model/tokenizer.json')

In [None]:
from huggingface_hub import whoami, HfApi
# Show only the authenticated account name. The full whoami() payload also
# carries the e-mail address, user/org ids and token metadata, which would be
# stored in the notebook's saved output and leaked when it is shared.
print(whoami()["name"])

{'type': 'user', 'id': '674b62961e1ca74a4121fac0', 'name': 'jaimadhukar', 'fullname': 'Jai Madhukar', 'email': 'madhukarjai9@gmail.com', 'emailVerified': True, 'canPay': False, 'periodEnd': None, 'isPro': False, 'avatarUrl': '/avatars/28b9b9c36b69f49a11c575c8d84ca4be.svg', 'orgs': [{'type': 'org', 'id': '69176f0e5c9c18a331f20174', 'name': 'jaimadhukar007', 'fullname': 'Jai Madhukar', 'email': 'madhukarjai9@gmail.com', 'canPay': False, 'periodEnd': None, 'avatarUrl': 'https://www.gravatar.com/avatar/a4ce9981e44d4b4db84007037e5680ac?d=retro&size=100', 'roleInOrg': 'admin', 'isEnterprise': False}], 'auth': {'type': 'access_token', 'accessToken': {'displayName': 'HF_TOKEN', 'role': 'read', 'createdAt': '2025-11-14T16:00:15.370Z'}}}


In [None]:

# Upload the merged model, then the tokenizer, to the same Hub repository,
# authenticating with the repo token exported earlier.
merged_model.push_to_hub(
    repo_id="jaimadhukar/lawvista",
    token=os.environ["HF_TOKEN_REPO"],
)
tokenizer.push_to_hub(
    repo_id="jaimadhukar/lawvista",
    token=os.environ["HF_TOKEN_REPO"],
)

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...0002-of-00002.safetensors:   0%|          |  610kB /  934MB            

  ...0001-of-00002.safetensors:   0%|          |  623kB / 5.00GB            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...pg_n4ys0h/tokenizer.model:   0%|          | 16.2kB / 4.24MB            

  ...mpg_n4ys0h/tokenizer.json:  73%|#######3  | 25.1MB / 34.4MB            

CommitInfo(commit_url='https://huggingface.co/jaimadhukar/lawvista/commit/f7faea133ece393101f95c614ff6f70c407f8e7b', commit_message='Upload tokenizer', commit_description='', oid='f7faea133ece393101f95c614ff6f70c407f8e7b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jaimadhukar/lawvista', endpoint='https://huggingface.co', repo_type='model', repo_id='jaimadhukar/lawvista'), pr_revision=None, pr_num=None)