In [1]:
# Install required packages
!pip install --no-deps peft accelerate bitsandbytes
!pip install py7zr

Collecting bitsandbytes
  Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.44.1
Collecting py7zr
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting texttable (from py7zr)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting pycryptodomex>=3.16.0 (from py7zr)
  Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting pyzstd>=0.15.9 (from py7zr)
  Downloading pyzstd-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting pyppmd<1.2.0,>=1.1.0 (from py7zr)
  Downloading pyppmd-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux20

In [2]:
%%capture
# The %%capture magic above hides the (long) pip output of this cell.
# Step 1: install the released Unsloth package (pulls in its dependencies).
!pip install unsloth
# Also get the latest nightly Unsloth!
# Step 2: replace the released package with the current GitHub main branch.
# The order matters: the uninstall must run after step 1 and before the git install.
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# Install Flash Attention 2 for softcapping support
import torch
# Only install flash-attn when the first component of the CUDA device capability is >= 8.
# The recorded run used a Tesla T4 (CUDA 7.5, "FA2 = False" in the load banner),
# so this branch was skipped there.
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"

In [3]:

# Load the base model and tokenizer through Unsloth.
# The three settings below are module-level names; max_seq_length is reused
# later when the SFTTrainer is built.
from unsloth import FastLanguageModel
import torch
max_seq_length = 512 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.


# from_pretrained returns a (model, tokenizer) pair; both are used as globals
# by the training and inference cells further down.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

In [4]:
# Apply PEFT (Parameter Efficient Fine-Tuning) to the loaded model.
# `model` is rebound to the adapter-wrapped model, so re-running this cell
# without reloading the base model would wrap it a second time.
# NOTE: with this reduced target_modules list the recorded run logged
# "patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers",
# i.e. Unsloth's manual autograd patches were not applied.
model = FastLanguageModel.get_peft_model(
    model,
    r=8,  # Reduced LoRA rank for lower VRAM usage
    target_modules=[
        "q_proj", "v_proj", "gate_proj",
    ],  # Minimal modules for task-specific fine-tuning
    lora_alpha=16,  # Scaling factor for LoRA; unchanged
    lora_dropout=0,  # LoRA dropout disabled (0 means no dropout is applied)
    bias="none",  # No additional bias to reduce memory
    use_gradient_checkpointing="unsloth",  # Optimized gradient checkpointing
    random_state=3407,  # Ensure reproducibility
    use_rslora=False,  # Disabling Rank Stabilized LoRA (default)
    loftq_config=None,  # Disabling LoftQ (default)
)


Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2024.11.10 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [5]:
from datasets import load_dataset


In [7]:
# Load the first 1,000 training examples of the CaseHOLD dataset.
# The dataset repository ships a custom loading script (casehold.py). Passing
# trust_remote_code=True accepts it up front, so the cell no longer stops on an
# interactive "[y/N]" prompt and the notebook can run unattended.
# Review the script at https://hf.co/datasets/casehold/casehold before trusting it.
dataset = load_dataset(
    "casehold/casehold",
    split="train[:1000]",
    trust_remote_code=True,
)

casehold.py:   0%|          | 0.00/8.68k [00:00<?, ?B/s]

The repository for casehold/casehold contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/casehold/casehold.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


train.csv:   0%|          | 0.00/85.3M [00:00<?, ?B/s]

val.csv:   0%|          | 0.00/10.4M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/10.6M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [9]:
# Show the first two records: citing context, gold holding, and all five candidates
print("\n## Sample Examples")
thin_rule = "-" * 50
for sample_no in (1, 2):
    record = dataset[sample_no - 1]
    print(f"\nExample {sample_no}:")
    print("Citing Context:")
    print(thin_rule)
    print(record['citing_prompt'])
    print("\nCorrect Holding:")
    print(thin_rule)
    # 'label' holds the index (0-4) of the correct holding
    gold_idx = int(record['label'])
    print(record[f'holding_{gold_idx}'])
    print("\nAll Holdings Options:")
    for option_no in range(5):
        print(f"Option {option_no}: {record[f'holding_{option_no}']}")
    print("=" * 70)


## Sample Examples

Example 1:
Citing Context:
--------------------------------------------------
Drapeau’s cohorts, the cohort would be a “victim” of making the bomb. Further, firebombs are inherently dangerous. There is no peaceful purpose for making a bomb. Felony offenses that involve explosives qualify as “violent crimes” for purposes of enhancing the sentences of career offenders. See 18 U.S.C. § 924(e)(2)(B)(ii) (defining a “violent felony” as: “any crime punishable by imprisonment for a term exceeding one year ... that ... involves use of explosives”). Courts have found possession of a'bomb to be a crime of violence based on the lack of a nonviolent purpose for a bomb and the fact that, by its very nature, there is a substantial risk that the bomb would be used against the person or property of another. See United States v. Newman, 125 F.3d 863 (10th Cir.1997) (unpublished) (<HOLDING>); United States v. Dodge, 846 F.Supp. 181,

Correct Holding:
--------------------------------

In [10]:
def format_instruction(example):
    """Build the supervised training text for one CaseHOLD-style record.

    Args:
        example: mapping with the keys 'citing_prompt', 'holding_0' ..
            'holding_4' and 'label' (index of the correct holding, anything
            accepted by int()).

    Returns:
        dict with a single key "text": the instruction, the context, the five
        numbered options and the text of the correct holding.
    """
    # One "Option N: ..." line per candidate holding, in index order.
    option_lines = []
    for position in range(5):
        option_lines.append(f"Option {position}: {example[f'holding_{position}']}")

    # The label selects which candidate is the gold answer.
    answer_key = f"holding_{int(example['label'])}"
    answer_text = example[answer_key]

    sections = (
        "Given this legal citation context, select the correct holding:",
        f"Context: {example['citing_prompt']}",
        "Options:\n" + "\n".join(option_lines),
        f"Correct holding: {answer_text}",
    )
    # Sections are separated by one blank line.
    return {"text": "\n\n".join(sections)}

In [11]:
# Add a "text" column holding the training prompt for every record
formatted_dataset = dataset.map(format_instruction)
# Show the first formatted record as a sanity check
first_training_text = formatted_dataset[0]['text']
print("\n## Training Format Example", first_training_text, sep="\n")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]


## Training Format Example
Given this legal citation context, select the correct holding:

Context: Drapeau’s cohorts, the cohort would be a “victim” of making the bomb. Further, firebombs are inherently dangerous. There is no peaceful purpose for making a bomb. Felony offenses that involve explosives qualify as “violent crimes” for purposes of enhancing the sentences of career offenders. See 18 U.S.C. § 924(e)(2)(B)(ii) (defining a “violent felony” as: “any crime punishable by imprisonment for a term exceeding one year ... that ... involves use of explosives”). Courts have found possession of a'bomb to be a crime of violence based on the lack of a nonviolent purpose for a bomb and the fact that, by its very nature, there is a substantial risk that the bomb would be used against the person or property of another. See United States v. Newman, 125 F.3d 863 (10th Cir.1997) (unpublished) (<HOLDING>); United States v. Dodge, 846 F.Supp. 181,

Options:
Option 0: holding that possession of a

In [12]:

# Build the supervised fine-tuning trainer over the formatted CaseHOLD subset.
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported


# NOTE(review): no dataset_text_field is passed; presumably SFTTrainer falls back
# to the "text" column produced by format_instruction — confirm for the installed trl version.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset,
    max_seq_length=max_seq_length,
    dataset_num_proc=2,  # Use 2 processors for dataset preprocessing
    packing=False,  # Packing disabled: examples are not concatenated into shared sequences
    args=TrainingArguments(
        per_device_train_batch_size=1,  # Smallest per-device batch to limit GPU memory use
        gradient_accumulation_steps=8,  # 1 x 8 gives an effective batch size of 8
        warmup_steps=5,
        max_steps=50,  # Reduced steps for faster completion (50 x 8 = 400 of the 1,000 examples)
        learning_rate=2e-4,  # Learning rate; can be adjusted if needed
        fp16=not is_bfloat16_supported(),  # Enable FP16 if bfloat16 not supported
        bf16=is_bfloat16_supported(),  # Enable bfloat16 if supported
        logging_steps=5,  # Log every 5 steps
        optim="adamw_8bit",  # Optimizer for memory efficiency
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,  # For reproducibility
        output_dir="./legal_holdings_model",  # Directory for model checkpoints
        report_to="none",  # Disable external reporting (e.g., WandB)
    ),
)


Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [13]:
#@title Show current memory stats
# Baseline readings taken before training; start_gpu_memory and max_memory are
# read again by the post-training statistics cell.
gpu_stats = torch.cuda.get_device_properties(0)
reserved_bytes_before = torch.cuda.max_memory_reserved()
start_gpu_memory = round(reserved_bytes_before / 1024 ** 3, 3)  # bytes -> GB
max_memory = round(gpu_stats.total_memory / 1024 ** 3, 3)  # bytes -> GB
print(
    f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.",
    f"{start_gpu_memory} GB of memory reserved.",
    sep="\n",
)

GPU = Tesla T4. Max memory = 14.748 GB.
2.695 GB of memory reserved.


In [14]:
# Run the fine-tuning loop configured on `trainer`. The result is kept because
# its .metrics['train_runtime'] is reported by the statistics cell that follows.
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 8
\        /    Total batch size = 8 | Total steps = 50
 "-____-"     Number of trainable parameters = 4,816,896


Step,Training Loss
5,2.3398
10,2.2959
15,2.1964
20,2.0727
25,1.9448
30,1.8413
35,1.8825
40,1.9349
45,1.8542
50,1.879


In [15]:
#@title Show final memory and time stats
# Peak reserved GPU memory after training, compared with the pre-training baseline
used_memory = round(torch.cuda.max_memory_reserved() / 1024 ** 3, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

train_seconds = trainer_stats.metrics['train_runtime']
report_lines = [
    f"{train_seconds} seconds used for training.",
    f"{round(train_seconds/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
]
print("\n".join(report_lines))

458.2832 seconds used for training.
7.64 minutes used for training.
Peak reserved memory = 3.902 GB.
Peak reserved memory for training = 1.207 GB.
Peak reserved memory % of max memory = 26.458 %.
Peak reserved memory for training % of max memory = 8.184 %.


In [16]:
# Save the model
# Writes the trainer's model to a local directory, separate from the
# "./legal_holdings_model" checkpoint directory used during training.
# NOTE(review): since `model` is a PEFT-wrapped model, this presumably stores only
# the LoRA adapter weights rather than the full base model — confirm.
trainer.save_model("./legal_holdings_model_final")



In [17]:
# First prepare the model for inference
# `model` is rebound to whatever for_inference returns; select_holding reads this global.
# NOTE(review): this relies on for_inference returning the model object rather than
# None — confirm for the installed Unsloth version.
model = FastLanguageModel.for_inference(model)

In [18]:
def select_holding(context, holdings, max_length=512, max_new_tokens=256):
    """Ask the fine-tuned model which holding fits a citing context.

    Builds the same prompt layout used for training (instruction, context,
    numbered options, trailing "Correct holding:" cue), samples a completion
    and returns only the text the model generated after the prompt.

    Uses the module-level `tokenizer` and `model`.

    Args:
        context: the citing context text.
        holdings: iterable of candidate holding strings; they are numbered
            from 0 in iteration order.
        max_length: maximum number of prompt tokens; longer prompts are
            truncated by the tokenizer (the trailing cue may then be cut off).
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The generated continuation, with special tokens removed and
        surrounding whitespace stripped.
    """
    options = "\n".join(f"Option {idx}: {holding}" for idx, holding in enumerate(holdings))
    prompt = f"Given this legal citation context, select the correct holding:\n\nContext: {context}\n\nOptions:\n{options}\n\nCorrect holding:"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length
    ).to("cuda")

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.3,
        do_sample=True,
        top_p=0.9,
        num_return_sequences=1
    )

    # generate() returns the prompt tokens followed by the new tokens. Decode only
    # the new ones instead of splitting the full text on "Correct holding:":
    # splitting returns the whole prompt when truncation dropped the cue, and picks
    # a wrong segment when the model repeats the prompt template in its output.
    prompt_token_count = inputs["input_ids"].shape[1]
    completion_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

In [20]:
# Custom test cases following CaseHOLD structure
# Each case uses the same keys the notebook reads from the dataset:
# 'citing_prompt', 'holding_0' .. 'holding_4', and 'label'.
# 'label' is a string; consumers convert it with int() to index the holdings.
# The triple-quoted prompts keep their leading newline and source indentation,
# and, unlike the dataset contexts shown earlier, contain no <HOLDING> placeholder.
custom_test_cases = [
    {
        "citing_prompt": """
        In Smith v. Tech Corp, the plaintiff alleged copyright infringement of their software code.
        The defendant argued that their use of similar code structures fell under fair use doctrine
        as they only used fundamental programming concepts that were necessary for the function.
        The court must determine whether the use of basic programming patterns can be protected
        under copyright law. Previous cases like Johnson Controls, Inc. v. Phoenix Control Systems
        have addressed similar issues regarding the scope of copyright protection in software.
        """,
        "holding_0": "Basic programming structures and patterns necessary for function are not protected by copyright as they fall under the merger doctrine.",
        "holding_1": "All code structures, regardless of their fundamental nature, are protected by copyright law.",
        "holding_2": "Copyright protection extends only to the creative elements of software, not to functional elements required for operation.",
        "holding_3": "Fair use doctrine does not apply to any form of software code copying.",
        "holding_4": "Software copyright cases must be evaluated on a case-by-case basis without general rules.",
        "label": "2"  # holding_2 is correct
    },

    {
        "citing_prompt": """
        The defendant in Jones v. Social Media Platform was banned from a social media website
        for violating community guidelines. They filed suit claiming First Amendment violations,
        arguing that the platform's status as a major communication channel makes it equivalent
        to a public forum. The platform contends that as a private company, it has the right to
        moderate content on its site. Similar issues were addressed in Manhattan Community Access
        Corp. v. Halleck regarding private entities and First Amendment obligations.
        """,
        "holding_0": "Social media platforms, regardless of size, are subject to First Amendment restrictions.",
        "holding_1": "Private companies operating social media platforms have the right to moderate content without First Amendment constraints.",
        "holding_2": "Only government-operated social media accounts are subject to First Amendment restrictions.",
        "holding_3": "Social media platforms become public forums once they reach a certain size.",
        "holding_4": "Content moderation decisions must follow government guidelines.",
        "label": "1"  # holding_1 is correct
    },

    {
        "citing_prompt": """
        In Healthcare Data Inc. v. Medical Records Corp, the plaintiff seeks a preliminary
        injunction to prevent a former employee from working for a competitor, citing a
        non-compete agreement. The employee argues that the agreement is overly broad and
        prevents them from working in their field of expertise. The court must balance the
        protection of trade secrets with the employee's right to work. Previous rulings in
        BDO Seidman v. Hirshberg addressed similar concerns about the scope of non-compete
        agreements.
        """,
        "holding_0": "Non-compete agreements are always enforceable if signed voluntarily.",
        "holding_1": "Non-compete agreements are never enforceable as they restrict trade.",
        "holding_2": "Non-compete agreements must be narrowly tailored to protect legitimate business interests while not unduly restricting employee rights.",
        "holding_3": "Employers have unlimited rights to restrict former employees' future employment.",
        "holding_4": "Non-compete agreements only apply to senior executives.",
        "label": "2"  # holding_2 is correct
    }
]


In [21]:

def test_legal_model(test_case):
    """Print one test case, the model's answer, and the expected holding.

    Args:
        test_case: mapping with 'citing_prompt', 'holding_0' .. 'holding_4'
            and 'label' (index of the correct holding, convertible with int()).

    Returns:
        None; everything is written to standard output.
    """
    holdings = [test_case[f'holding_{j}'] for j in range(5)]
    # Strip once and use the same text for display and for the model, so the
    # prompt does not start with the blank line and indentation left over from
    # the triple-quoted literal.
    context = test_case['citing_prompt'].strip()

    print("Context:")
    print("-" * 50)
    print(context)
    print("\nAvailable Holdings:")
    print("-" * 50)
    for idx, holding in enumerate(holdings):
        print(f"Option {idx}: {holding}")

    print("\nModel Selection:")
    print("-" * 50)
    model_selection = select_holding(context, holdings)
    print(model_selection)

    print("\nCorrect Holding:")
    print("-" * 50)
    print(test_case[f'holding_{int(test_case["label"])}'])
    print("=" * 70)

# Run every custom case through the model, numbering the cases from 1
print("\n## Testing Legal Holdings Model with Custom Cases")
case_number = 0
for test_case in custom_test_cases:
    case_number += 1
    print(f"\nCustom Test Case {case_number}:")
    test_legal_model(test_case)


## Testing Legal Holdings Model with Custom Cases

Custom Test Case 1:
Context:
--------------------------------------------------
In Smith v. Tech Corp, the plaintiff alleged copyright infringement of their software code. 
        The defendant argued that their use of similar code structures fell under fair use doctrine 
        as they only used fundamental programming concepts that were necessary for the function. 
        The court must determine whether the use of basic programming patterns can be protected 
        under copyright law. Previous cases like Johnson Controls, Inc. v. Phoenix Control Systems 
        have addressed similar issues regarding the scope of copyright protection in software.

Available Holdings:
--------------------------------------------------
Option 0: Basic programming structures and patterns necessary for function are not protected by copyright as they fall under the merger doctrine.
Option 1: All code structures, regardless of their fundamental nat

In [25]:
# Upload the fine-tuned model and tokenizer to the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable so that no
# credential is stored in the notebook. Any token that was ever committed in
# this notebook must be treated as leaked and revoked in the Hub settings.
import os

HF_REPO_ID = "AagamShah08/llama3_3B_LegalQA"
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "Set the HF_TOKEN environment variable to a Hugging Face write token before pushing."
    )

model.push_to_hub(HF_REPO_ID, token=hf_token) # Online saving
tokenizer.push_to_hub(HF_REPO_ID, token=hf_token) # Online saving

README.md:   0%|          | 0.00/582 [00:00<?, ?B/s]

  0%|          | 0/1 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/19.3M [00:00<?, ?B/s]

Saved model to https://huggingface.co/AagamShah08/llama3_3B_LegalQA


  0%|          | 0/1 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]