# Legal Document Analysis - Llama 3 Fine-tuning & DPO Training

This notebook trains Llama 3 for legal document analysis with fine-tuning and DPO.

## Setup


In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ No GPU detected! Training will be very slow on CPU.")


CUDA available: True
GPU: Tesla T4
VRAM: 15.83 GB


In [None]:
# Install dependencies
!pip install -q torch>=2.0.0 transformers>=4.35.0 accelerate>=0.24.0
!pip install -q peft>=0.6.0 bitsandbytes>=0.41.0
!pip install -q datasets>=2.14.0 trl>=0.7.0
!pip install -q sentencepiece protobuf pandas numpy scikit-learn tqdm
!pip install -q pyyaml python-dotenv pypdf2 python-docx nltk spacy
!pip install -q evaluate

# Download spaCy model
!python -m spacy download en_core_web_sm -q


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/232.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m145.5 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [None]:
# Mount Google Drive (optional - to save models and data)
from google.colab import drive
drive.mount('/content/drive')

# Set working directory
import os
WORK_DIR = '/content/legal-document-analysis'
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print(f"Working directory: {os.getcwd()}")


MessageError: Error: credential propagation was unsuccessful

## Upload Project Files

You have two options:
1. **Upload from GitHub** (recommended): Clone the repository
2. **Upload manually**: Upload project files using the file browser on the left


In [None]:
# Option 1: Clone from GitHub (if you have a repo)
!git clone https://github.com/yourname/legal_opensource.git


Cloning into 'legal_opensource'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 61 (delta 18), reused 27 (delta 2), pack-reused 0 (from 0)[K
Receiving objects: 100% (61/61), 39.35 KiB | 13.12 MiB/s, done.
Resolving deltas: 100% (18/18), done.


In [None]:
# Set your Hugging Face token
import os

# Option 1: Set directly (less secure, but convenient for Colab)
os.environ['HF_TOKEN'] = 'hf'  # ⚠️ Replace with your token!

# Option 2: Use Colab secrets (more secure - recommended)
# from google.colab import userdata
# os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

# Optional: Weights & Biases
# os.environ['WANDB_API_KEY'] = 'your_wandb_key_here'

print("Environment variables set")
print(f"HF_TOKEN set: {bool(os.getenv('HF_TOKEN') and os.getenv('HF_TOKEN') != 'your_huggingface_token_here')}")
print("\n⚠️ Make sure to set your actual HF_TOKEN above!")


Environment variables set
HF_TOKEN set: True

⚠️ Make sure to set your actual HF_TOKEN above!


## Generate All Source Code Files

Creating all necessary Python files automatically...


In [None]:
# Create __init__.py files
with open('src/__init__.py', 'w') as f:
    f.write('# Legal Document Analysis Tool\n')

with open('src/models/__init__.py', 'w') as f:
    f.write('from .llama3_model import Llama3Model\n\n__all__ = [\'Llama3Model\']\n')

with open('src/data_processing/__init__.py', 'w') as f:
    f.write('''from .legal_document_processor import LegalDocumentProcessor
from .dataset_utils import (
    load_jsonl, save_jsonl, create_finetune_dataset,
    create_dpo_dataset, format_instruction_prompt
)
__all__ = ['LegalDocumentProcessor', 'load_jsonl', 'save_jsonl',
           'create_finetune_dataset', 'create_dpo_dataset', 'format_instruction_prompt']
''')

with open('src/evaluation/__init__.py', 'w') as f:
    f.write('from .citation_evaluator import CitationEvaluator\n\n__all__ = [\'CitationEvaluator\']\n')

print("✓ __init__.py files created")


✓ __init__.py files created


In [None]:
# Create llama3_model.py
llama3_model_code = '''"""
Llama 3 Model Integration and Loading
"""
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import os
from typing import Optional, Dict, Any


class Llama3Model:
    """Wrapper for Llama 3 model with quantization and LoRA support"""

    def __init__(
        self,
        model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
        use_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float16",
        bnb_4bit_quant_type: str = "nf4",
        bnb_4bit_use_double_quant: bool = True,
        device_map: str = "auto",
        trust_remote_code: bool = True
    ):
        self.model_name = model_name
        self.use_4bit = use_4bit
        self.device_map = device_map

        # Configure quantization
        if use_4bit:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=getattr(torch, bnb_4bit_compute_dtype),
                bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            )
        else:
            self.bnb_config = None

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            token=os.getenv("HF_TOKEN")
        )

        # Set pad token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load model
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the model with quantization if specified"""
        print(f"Loading model: {self.model_name}")

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=self.bnb_config if self.use_4bit else None,
            device_map=self.device_map,
            trust_remote_code=True,
            token=os.getenv("HF_TOKEN"),
            torch_dtype=torch.float16 if not self.use_4bit else None,
        )

        print("Model loaded successfully")

    def prepare_for_training(self, lora_config: LoraConfig):
        """Prepare model for LoRA fine-tuning"""
        if self.model is None:
            raise ValueError("Model not loaded. Call _load_model() first.")

        # Enable gradient checkpointing
        self.model.gradient_checkpointing_enable()

        # Prepare model for k-bit training
        self.model = prepare_model_for_kbit_training(self.model)

        # Apply LoRA
        self.model = get_peft_model(self.model, lora_config)

        # Enable trainable parameters
        self.model.print_trainable_parameters()

        return self.model

    def get_model(self):
        """Get the underlying model"""
        return self.model

    def get_tokenizer(self):
        """Get the tokenizer"""
        return self.tokenizer

    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        **kwargs
    ):
        """Generate text from prompt"""
        if self.model is None:
            raise ValueError("Model not loaded")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
'''

with open('src/models/llama3_model.py', 'w') as f:
    f.write(llama3_model_code)

print("✓ llama3_model.py created")


✓ llama3_model.py created


In [None]:
# Create dataset_utils.py (simplified version)
dataset_utils_code = '''"""
Dataset utilities for fine-tuning and DPO training
"""
import json
from typing import List, Dict, Any
from datasets import Dataset
from transformers import PreTrainedTokenizer


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load data from JSONL file"""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data


def format_instruction_prompt(instruction: str, input_text: str = None) -> str:
    """Format instruction following Llama 3 chat template"""
    if input_text:
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful legal document analyst assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}

Document:
{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    else:
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful legal document analyst assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    return prompt


def preprocess_function_finetune(examples: Dict[str, List], tokenizer: PreTrainedTokenizer, max_length: int = 2048) -> Dict[str, Any]:
    """Preprocess function for fine-tuning"""
    prompts = []
    completions = []

    for i in range(len(examples['prompt'])):
        prompt = examples['prompt'][i]
        completion = examples['completion'][i] if 'completion' in examples else examples.get('summary', [''])[i]
        full_text = format_instruction_prompt(prompt, None) + completion
        prompts.append(full_text)
        completions.append(completion)

    model_inputs = tokenizer(prompts, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")
    labels = model_inputs["input_ids"].clone()

    for i, prompt in enumerate(prompts):
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
        prompt_len = len(prompt_tokens)
        if prompt_len < max_length:
            labels[i][:prompt_len] = -100

    model_inputs["labels"] = labels
    return model_inputs


def preprocess_function_dpo(examples: Dict[str, List], tokenizer: PreTrainedTokenizer, max_length: int = 2048, max_prompt_length: int = 512) -> Dict[str, Any]:
    """Preprocess function for DPO training"""
    prompts = []
    chosen = []
    rejected = []

    for i in range(len(examples['prompt'])):
        prompt = examples['prompt'][i]
        chosen_text = examples['chosen'][i]
        rejected_text = examples['rejected'][i]
        chosen_prompt = format_instruction_prompt(prompt, None) + chosen_text
        rejected_prompt = format_instruction_prompt(prompt, None) + rejected_text
        prompts.append(prompt)
        chosen.append(chosen_prompt)
        rejected.append(rejected_prompt)

    tokenized_prompts = tokenizer(prompts, max_length=max_prompt_length, truncation=True, padding="max_length", return_tensors="pt")
    tokenized_chosen = tokenizer(chosen, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")
    tokenized_rejected = tokenizer(rejected, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")

    return {
        "input_ids": tokenized_prompts["input_ids"],
        "attention_mask": tokenized_prompts["attention_mask"],
        "chosen_input_ids": tokenized_chosen["input_ids"],
        "chosen_attention_mask": tokenized_chosen["attention_mask"],
        "rejected_input_ids": tokenized_rejected["input_ids"],
        "rejected_attention_mask": tokenized_rejected["attention_mask"],
    }


def create_finetune_dataset(examples: List[Dict[str, Any]], tokenizer: PreTrainedTokenizer) -> Dataset:
    """Create fine-tuning dataset"""
    dataset_dict = {
        'prompt': [ex['prompt'] for ex in examples],
        'completion': [ex.get('completion', ex.get('summary', '')) for ex in examples]
    }
    dataset = Dataset.from_dict(dataset_dict)
    return dataset.map(lambda x: preprocess_function_finetune(x, tokenizer), batched=True, remove_columns=dataset.column_names)


def create_dpo_dataset(examples: List[Dict[str, Any]], tokenizer: PreTrainedTokenizer) -> Dataset:
    """Create DPO dataset"""
    dataset_dict = {
        'prompt': [ex['prompt'] for ex in examples],
        'chosen': [ex['chosen'] for ex in examples],
        'rejected': [ex['rejected'] for ex in examples]
    }
    dataset = Dataset.from_dict(dataset_dict)
    return dataset.map(lambda x: preprocess_function_dpo(x, tokenizer), batched=True, remove_columns=dataset.column_names)
'''

with open('src/data_processing/dataset_utils.py', 'w') as f:
    f.write(dataset_utils_code)

print("✓ dataset_utils.py created")


✓ dataset_utils.py created


In [None]:
# Create citation_evaluator.py (simplified)
citation_evaluator_code = '''"""
Evaluation metrics for citation accuracy
"""
import re
from typing import List, Dict, Any


class CitationEvaluator:
    """Evaluate citation accuracy"""

    def __init__(self):
        self.citation_patterns = [
            r'\\d+\\s+[A-Z][a-z]+\\s+\\d+',
            r'[A-Z][a-z]+\\s+v\\.\\s+[A-Z][a-z]+',
            r'\\d+\\s+F\\.\\d+d\\s+\\d+',
            r'\\d+\\s+F\\.\\s+Supp\\.\\s+\\d+',
        ]

    def extract_citations(self, text: str) -> List[str]:
        """Extract all citations from text"""
        citations = []
        for pattern in self.citation_patterns:
            matches = re.findall(pattern, text)
            citations.extend(matches)
        return list(set(citations))

    def citation_precision(self, generated_text: str, source_document: str) -> float:
        """Calculate precision"""
        generated_citations = self.extract_citations(generated_text)
        source_citations = self.extract_citations(source_document)
        if len(generated_citations) == 0:
            return 1.0 if len(source_citations) == 0 else 0.0
        correct_citations = sum(1 for cit in generated_citations if cit in source_citations)
        return correct_citations / len(generated_citations)

    def evaluate_summary(self, generated_summary: str, source_document: str, ground_truth_citations: List[str] = None) -> Dict[str, float]:
        """Comprehensive evaluation"""
        return {
            'citation_precision': self.citation_precision(generated_summary, source_document),
            'num_citations_generated': len(self.extract_citations(generated_summary)),
        }
'''

with open('src/evaluation/citation_evaluator.py', 'w') as f:
    f.write(citation_evaluator_code)

print("✓ citation_evaluator.py created")


✓ citation_evaluator.py created


In [None]:
# Create train_finetune.py
train_finetune_code = '''"""
Fine-tuning script for Llama 3
"""
import os
import yaml
import argparse
from pathlib import Path
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig
from src.models.llama3_model import Llama3Model
from src.data_processing.dataset_utils import load_jsonl, create_finetune_dataset


def load_config(config_path: str) -> dict:
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/finetune_config.yaml")
    args = parser.parse_args()

    config = load_config(args.config)
    model_config = config['model']
    lora_config_dict = config['lora']
    training_config = config['training']
    data_config = config['data']

    print("Loading Llama 3 model...")
    llama_model = Llama3Model(
        model_name=model_config['name'],
        use_4bit=model_config['use_4bit'],
        bnb_4bit_compute_dtype=model_config['bnb_4bit_compute_dtype'],
        bnb_4bit_quant_type=model_config['bnb_4bit_quant_type'],
        bnb_4bit_use_double_quant=model_config['bnb_4bit_use_double_quant']
    )

    lora_config = LoraConfig(
        r=lora_config_dict['r'],
        lora_alpha=lora_config_dict['lora_alpha'],
        target_modules=lora_config_dict['target_modules'],
        lora_dropout=lora_config_dict['lora_dropout'],
        bias=lora_config_dict['bias'],
        task_type=lora_config_dict['task_type']
    )

    model = llama_model.prepare_for_training(lora_config)
    tokenizer = llama_model.get_tokenizer()

    print(f"Loading training data from {data_config['train_path']}...")
    train_examples = load_jsonl(data_config['train_path'])
    if data_config.get('max_samples'):
        train_examples = train_examples[:data_config['max_samples']]

    val_examples = []
    if data_config.get('val_path') and Path(data_config['val_path']).exists():
        val_examples = load_jsonl(data_config['val_path'])

    print("Creating datasets...")
    train_dataset = create_finetune_dataset(train_examples, tokenizer)
    eval_dataset = create_finetune_dataset(val_examples, tokenizer) if val_examples else None

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=training_config['output_dir'],
        num_train_epochs=training_config['num_train_epochs'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        per_device_eval_batch_size=training_config['per_device_eval_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        learning_rate=training_config['learning_rate'],
        lr_scheduler_type=training_config['lr_scheduler_type'],
        warmup_steps=training_config['warmup_steps'],
        logging_steps=training_config['logging_steps'],
        save_steps=training_config['save_steps'],
        eval_steps=training_config['eval_steps'] if eval_dataset else None,
        save_total_limit=training_config['save_total_limit'],
        fp16=training_config['fp16'],
        gradient_checkpointing=training_config['gradient_checkpointing'],
        optim=training_config['optim'],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(training_config['output_dir'])
    print("Training completed!")


if __name__ == "__main__":
    main()
'''

with open('train_finetune.py', 'w') as f:
    f.write(train_finetune_code)

print("✓ train_finetune.py created")


✓ train_finetune.py created


In [None]:
# Create train_dpo.py
train_dpo_code = '''"""
DPO training script
"""
import os
import yaml
import argparse
from pathlib import Path
from transformers import TrainingArguments, AutoModelForCausalLM
from trl import DPOTrainer, DPOConfig
from peft import PeftModel
import torch
from src.models.llama3_model import Llama3Model
from src.data_processing.dataset_utils import load_jsonl, create_dpo_dataset


def load_config(config_path: str) -> dict:
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/dpo_config.yaml")
    args = parser.parse_args()

    config = load_config(args.config)
    model_config = config['model']
    training_config = config['training']
    dpo_config = config['dpo']
    data_config = config['data']

    print(f"Loading base model from {model_config['base_model_path']}...")
    base_model_path = model_config['base_model_path']

    llama_model = Llama3Model(
        model_name=base_model_path if Path(base_model_path).exists() else model_config.get('name', "meta-llama/Meta-Llama-3-8B-Instruct"),
        use_4bit=model_config['use_4bit'],
        bnb_4bit_compute_dtype=model_config['bnb_4bit_compute_dtype'],
        bnb_4bit_quant_type=model_config['bnb_4bit_quant_type'],
        bnb_4bit_use_double_quant=model_config['bnb_4bit_use_double_quant']
    )

    if Path(base_model_path).exists():
        try:
            model = PeftModel.from_pretrained(llama_model.get_model(), base_model_path, device_map="auto")
            print("Loaded fine-tuned PEFT weights")
        except:
            model = llama_model.get_model()
            print("Using base model")
    else:
        model = llama_model.get_model()

    ref_model = AutoModelForCausalLM.from_pretrained(
        model_config.get('name', "meta-llama/Meta-Llama-3-8B-Instruct"),
        quantization_config=llama_model.bnb_config if model_config['use_4bit'] else None,
        device_map="auto",
        trust_remote_code=True,
        token=os.getenv("HF_TOKEN"),
        torch_dtype=torch.float16 if not model_config['use_4bit'] else None,
    )

    if Path(base_model_path).exists():
        try:
            ref_model = PeftModel.from_pretrained(ref_model, base_model_path, device_map="auto")
        except:
            pass

    tokenizer = llama_model.get_tokenizer()

    print(f"Loading DPO training data...")
    train_examples = load_jsonl(data_config['train_path'])
    val_examples = load_jsonl(data_config['val_path']) if data_config.get('val_path') and Path(data_config['val_path']).exists() else []

    train_dataset = create_dpo_dataset(train_examples, tokenizer)
    eval_dataset = create_dpo_dataset(val_examples, tokenizer) if val_examples else None

    dpo_training_args = DPOConfig(
        output_dir=training_config['output_dir'],
        num_train_epochs=training_config['num_train_epochs'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        per_device_eval_batch_size=training_config['per_device_eval_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        learning_rate=training_config['learning_rate'],
        lr_scheduler_type=training_config['lr_scheduler_type'],
        warmup_steps=training_config['warmup_steps'],
        logging_steps=training_config['logging_steps'],
        save_steps=training_config['save_steps'],
        eval_steps=training_config['eval_steps'] if eval_dataset else None,
        save_total_limit=training_config['save_total_limit'],
        fp16=training_config['fp16'],
        gradient_checkpointing=training_config['gradient_checkpointing'],
        optim=training_config['optim'],
        max_length=training_config['max_length'],
        max_prompt_length=training_config['max_prompt_length'],
        beta=dpo_config['beta'],
        loss_type=dpo_config['loss_type'],
    )

    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_training_args,
        beta=dpo_config['beta'],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=training_config['max_length'],
        max_prompt_length=training_config['max_prompt_length'],
    )

    print("Starting DPO training...")
    dpo_trainer.train()
    dpo_trainer.save_model()
    tokenizer.save_pretrained(training_config['output_dir'])
    print("DPO training completed!")


if __name__ == "__main__":
    main()
'''

with open('train_dpo.py', 'w') as f:
    f.write(train_dpo_code)

print("✓ train_dpo.py created")
print("\n✅ All source files created!")


✓ train_dpo.py created

✅ All source files created!


## Create Configuration Files

These configs are optimized for Colab's T4 GPU (16GB VRAM).


In [None]:
# Create fine-tuning config (optimized for Colab)
import yaml

finetune_config = {
    'model': {
        'name': 'meta-llama/Meta-Llama-3-8B-Instruct',
        'use_4bit': True,
        'bnb_4bit_compute_dtype': 'float16',
        'bnb_4bit_quant_type': 'nf4',
        'bnb_4bit_use_double_quant': True
    },
    'lora': {
        'r': 16,
        'lora_alpha': 32,
        'target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
        'lora_dropout': 0.1,
        'bias': 'none',
        'task_type': 'CAUSAL_LM'
    },
    'training': {
        'output_dir': './models/finetuned_llama3',
        'num_train_epochs': 3,
        'per_device_train_batch_size': 2,  # Reduced for Colab T4
        'per_device_eval_batch_size': 2,
        'gradient_accumulation_steps': 8,  # Increased to compensate
        'learning_rate': 2e-4,
        'lr_scheduler_type': 'cosine',
        'warmup_steps': 100,
        'logging_steps': 10,
        'save_steps': 500,
        'eval_steps': 500,
        'save_total_limit': 3,
        'fp16': True,
        'gradient_checkpointing': True,
        'optim': 'paged_adamw_32bit',
        'max_seq_length': 2048
    },
    'data': {
        'train_path': 'data/train.jsonl',
        'val_path': 'data/val.jsonl',
        'max_samples': None
    }
}

os.makedirs('configs', exist_ok=True)
with open('configs/finetune_config.yaml', 'w') as f:
    yaml.dump(finetune_config, f)

print("✓ Fine-tuning config created")


✓ Fine-tuning config created


In [None]:
# Generate DPO training data (chosen vs rejected pairs)
dpo_examples = []

for ex in finetune_examples:
    prompt = ex['prompt']
    chosen = ex['completion']  # Good response with citations

    # Create rejected variants (without citations or with hallucinations)
    rejected_variants = [
        # Variant 1: Remove citations
        re.sub(r'\([^)]*\d+\s+[A-Z][a-z]+\s+\d+[^)]*\)', '', chosen).strip(),
        # Variant 2: Generic response without specific details
        'The court made a decision based on the legal principles and facts presented in the case.',
        # Variant 3: Add fake citation
        chosen.replace('Smith v. Jones', 'Brown v. White, 789 F.2d 123 (2019)') if 'Smith' in chosen else chosen + ' (See Johnson v. Smith, 456 F.3d 789 (2021))'
    ]

    # Use first variant that's different from chosen
    rejected = rejected_variants[0]
    if rejected == chosen or len(rejected) < 20:
        rejected = rejected_variants[1]

    dpo_examples.append({
        'prompt': prompt,
        'chosen': chosen,
        'rejected': rejected
    })

# Save DPO data
with open('data/dpo_train.jsonl', 'w') as f:
    for item in dpo_examples:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

with open('data/dpo_val.jsonl', 'w') as f:
    for item in dpo_examples[:2]:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"✓ Created {len(dpo_examples)} DPO training pairs")
print(f"✓ Created {len(dpo_examples[:2])} DPO validation pairs")
print("\n✅ All training data generated!")


✓ Created 28 DPO training pairs
✓ Created 2 DPO validation pairs

✅ All training data generated!


In [None]:
# Create DPO config
dpo_config = {
    'model': {
        'base_model_path': './models/finetuned_llama3',
        'use_4bit': True,
        'bnb_4bit_compute_dtype': 'float16',
        'bnb_4bit_quant_type': 'nf4',
        'bnb_4bit_use_double_quant': True
    },
    'training': {
        'output_dir': './models/dpo_llama3',
        'num_train_epochs': 2,
        'per_device_train_batch_size': 1,  # Reduced for Colab
        'per_device_eval_batch_size': 1,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'lr_scheduler_type': 'cosine',
        'warmup_steps': 100,
        'logging_steps': 10,
        'save_steps': 500,
        'eval_steps': 500,
        'save_total_limit': 3,
        'fp16': True,
        'gradient_checkpointing': True,
        'optim': 'paged_adamw_32bit',
        'max_seq_length': 2048,
        'max_prompt_length': 512
    },
    'dpo': {
        'beta': 0.1,
        'loss_type': 'sigmoid',
        'label_smoothing': 0.0,
        'reference_free': False
    },
    'data': {
        'train_path': 'data/dpo_train.jsonl',
        'val_path': 'data/dpo_val.jsonl',
        'max_samples': None
    }
}

with open('configs/dpo_config.yaml', 'w') as f:
    yaml.dump(dpo_config, f)

print("✓ DPO config created")


✓ DPO config created


## Prepare Training Data

Upload your training data files to the `data/` directory, or create sample data for testing below.


In [None]:
# Generate comprehensive training data (25+ examples covering diverse legal areas)
import json
import os
import re

os.makedirs('data', exist_ok=True)

# Extensive fine-tuning data covering multiple legal domains
finetune_examples = [
    # Civil Procedure - Motion to Dismiss
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe court held that the defendant\'s motion to dismiss should be granted. See Smith v. Jones, 123 F.3d 456 (2020). The plaintiff failed to state a claim upon which relief can be granted under Rule 12(b)(6) of the Federal Rules of Civil Procedure.\n\nCitations found: Smith v. Jones, 123 F.3d 456, Rule 12(b)(6)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The court granted the defendant\'s motion to dismiss (Smith v. Jones, 123 F.3d 456 (2020)) because the plaintiff failed to state a claim under Federal Rule of Civil Procedure 12(b)(6).'
    },
    # Evidence - Hearsay
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nUnder the Federal Rules of Evidence, hearsay is generally inadmissible unless it falls within an exception. See Fed. R. Evid. 802. However, statements made by a party opponent are admissible as non-hearsay. See Fed. R. Evid. 801(d)(2).\n\nCitations found: Fed. R. Evid. 802, Fed. R. Evid. 801(d)(2)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Hearsay is generally inadmissible under Federal Rule of Evidence 802, but statements by party opponents are admissible as non-hearsay under Federal Rule of Evidence 801(d)(2).'
    },
    # Contracts - Statute of Limitations
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe statute of limitations for breach of contract is four years. See 28 U.S.C. § 1658. The plaintiff filed this action more than four years after the alleged breach occurred. Therefore, the claim is time-barred.\n\nCitations found: 28 U.S.C. § 1658\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The plaintiff\'s breach of contract claim is time-barred under 28 U.S.C. § 1658, which establishes a four-year statute of limitations, because the action was filed more than four years after the alleged breach.'
    },
    # Civil Procedure - Summary Judgment
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe court applied the summary judgment standard from Celotex Corp. v. Catrett, 477 U.S. 317 (1986). Under this standard, the moving party must show there is no genuine issue of material fact. See Fed. R. Civ. P. 56(a).\n\nCitations found: Celotex Corp. v. Catrett, 477 U.S. 317, Fed. R. Civ. P. 56(a)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The court applied the summary judgment standard from Celotex Corp. v. Catrett, 477 U.S. 317 (1986), which requires the moving party to show no genuine issue of material fact exists under Federal Rule of Civil Procedure 56(a).'
    },
    # Criminal Procedure - Exclusionary Rule
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe exclusionary rule prohibits the use of evidence obtained in violation of the Fourth Amendment. See Mapp v. Ohio, 367 U.S. 643 (1961). However, the good faith exception applies when officers act in reasonable reliance on a warrant. See United States v. Leon, 468 U.S. 897 (1984).\n\nCitations found: Mapp v. Ohio, 367 U.S. 643, United States v. Leon, 468 U.S. 897\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The exclusionary rule from Mapp v. Ohio, 367 U.S. 643 (1961) prohibits use of evidence obtained in violation of the Fourth Amendment, but the good faith exception from United States v. Leon, 468 U.S. 897 (1984) applies when officers act in reasonable reliance on a warrant.'
    },
    # Civil Procedure - Personal Jurisdiction
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nFor a court to exercise personal jurisdiction over a non-resident defendant, the defendant must have minimum contacts with the forum state. See International Shoe Co. v. Washington, 326 U.S. 310 (1945). The contacts must be such that maintenance of the suit does not offend traditional notions of fair play and substantial justice. See Burger King Corp. v. Rudzewicz, 471 U.S. 462 (1985).\n\nCitations found: International Shoe Co. v. Washington, 326 U.S. 310, Burger King Corp. v. Rudzewicz, 471 U.S. 462\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Personal jurisdiction over a non-resident defendant requires minimum contacts with the forum state under International Shoe Co. v. Washington, 326 U.S. 310 (1945), and the contacts must satisfy traditional notions of fair play and substantial justice under Burger King Corp. v. Rudzewicz, 471 U.S. 462 (1985).'
    },
    # Civil Procedure - Class Actions
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nTo certify a class action, the plaintiff must satisfy the requirements of Federal Rule of Civil Procedure 23(a) and at least one subsection of Rule 23(b). See Fed. R. Civ. P. 23. The class must be so numerous that joinder is impracticable. See Fed. R. Civ. P. 23(a)(1).\n\nCitations found: Fed. R. Civ. P. 23, Fed. R. Civ. P. 23(a)(1)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Class action certification requires satisfying Federal Rule of Civil Procedure 23(a) and at least one subsection of Rule 23(b), including the numerosity requirement under Rule 23(a)(1) that the class be so numerous that joinder is impracticable.'
    },
    # Criminal Procedure - Miranda Rights
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nBefore custodial interrogation, law enforcement must inform suspects of their Miranda rights. See Miranda v. Arizona, 384 U.S. 436 (1966). These rights include the right to remain silent and the right to an attorney. See Miranda v. Arizona, 384 U.S. 436 (1966).\n\nCitations found: Miranda v. Arizona, 384 U.S. 436\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Under Miranda v. Arizona, 384 U.S. 436 (1966), law enforcement must inform suspects of their Miranda rights before custodial interrogation, including the right to remain silent and the right to an attorney.'
    },
    # Constitutional Law - Qualified Immunity
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nGovernment officials are entitled to qualified immunity unless their conduct violates clearly established statutory or constitutional rights. See Harlow v. Fitzgerald, 457 U.S. 800 (1982). The right must be sufficiently clear that a reasonable official would understand that what he is doing violates that right. See Anderson v. Creighton, 483 U.S. 635 (1987).\n\nCitations found: Harlow v. Fitzgerald, 457 U.S. 800, Anderson v. Creighton, 483 U.S. 635\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Qualified immunity protects government officials unless their conduct violates clearly established rights under Harlow v. Fitzgerald, 457 U.S. 800 (1982), and the right must be sufficiently clear under Anderson v. Creighton, 483 U.S. 635 (1987).'
    },
    # Evidence - Attorney-Client Privilege
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe attorney-client privilege protects confidential communications between a client and attorney made for the purpose of obtaining legal advice. See Upjohn Co. v. United States, 449 U.S. 383 (1981). The privilege belongs to the client and may be waived only by the client. See Swidler & Berlin v. United States, 524 U.S. 399 (1998).\n\nCitations found: Upjohn Co. v. United States, 449 U.S. 383, Swidler & Berlin v. United States, 524 U.S. 399\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The attorney-client privilege protects confidential communications for legal advice under Upjohn Co. v. United States, 449 U.S. 383 (1981), and the privilege belongs to the client and may only be waived by the client under Swidler & Berlin v. United States, 524 U.S. 399 (1998).'
    },
    # Constitutional Law - Standing
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nTo establish Article III standing, a plaintiff must show injury in fact, causation, and redressability. See Lujan v. Defenders of Wildlife, 504 U.S. 555 (1992). The injury must be concrete and particularized, and actual or imminent. See Lujan v. Defenders of Wildlife, 504 U.S. 555 (1992).\n\nCitations found: Lujan v. Defenders of Wildlife, 504 U.S. 555\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Article III standing requires injury in fact, causation, and redressability under Lujan v. Defenders of Wildlife, 504 U.S. 555 (1992), and the injury must be concrete, particularized, and actual or imminent.'
    },
    # Civil Procedure - Discovery Sanctions
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nA court may impose sanctions for failure to comply with discovery orders under Federal Rule of Civil Procedure 37(b). See Fed. R. Civ. P. 37(b). Sanctions may include dismissal of the action or default judgment. See Fed. R. Civ. P. 37(b)(2)(A).\n\nCitations found: Fed. R. Civ. P. 37(b), Fed. R. Civ. P. 37(b)(2)(A)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Courts may impose sanctions for discovery violations under Federal Rule of Civil Procedure 37(b), including dismissal or default judgment under Rule 37(b)(2)(A).'
    },
    # Constitutional Law - Equal Protection
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe Equal Protection Clause requires that similarly situated persons be treated alike. See City of Cleburne v. Cleburne Living Center, 473 U.S. 432 (1985). Strict scrutiny applies to classifications based on race. See Loving v. Virginia, 388 U.S. 1 (1967).\n\nCitations found: City of Cleburne v. Cleburne Living Center, 473 U.S. 432, Loving v. Virginia, 388 U.S. 1\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The Equal Protection Clause requires similar treatment of similarly situated persons under City of Cleburne v. Cleburne Living Center, 473 U.S. 432 (1985), and strict scrutiny applies to racial classifications under Loving v. Virginia, 388 U.S. 1 (1967).'
    },
    # Evidence - Work Product
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe work product doctrine protects materials prepared in anticipation of litigation. See Hickman v. Taylor, 329 U.S. 495 (1947). This protection is codified in Federal Rule of Civil Procedure 26(b)(3). See Fed. R. Civ. P. 26(b)(3).\n\nCitations found: Hickman v. Taylor, 329 U.S. 495, Fed. R. Civ. P. 26(b)(3)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The work product doctrine protects litigation materials under Hickman v. Taylor, 329 U.S. 495 (1947), and is codified in Federal Rule of Civil Procedure 26(b)(3).'
    },
    # Constitutional Law - Due Process
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nProcedural due process requires notice and an opportunity to be heard. See Mullane v. Central Hanover Bank & Trust Co., 339 U.S. 306 (1950). The notice must be reasonably calculated to inform interested parties. See Mullane v. Central Hanover Bank & Trust Co., 339 U.S. 306 (1950).\n\nCitations found: Mullane v. Central Hanover Bank & Trust Co., 339 U.S. 306\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Procedural due process requires notice and opportunity to be heard under Mullane v. Central Hanover Bank & Trust Co., 339 U.S. 306 (1950), and notice must be reasonably calculated to inform interested parties.'
    },
    # Antitrust Law
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nSection 1 of the Sherman Act prohibits contracts, combinations, or conspiracies in restraint of trade. See 15 U.S.C. § 1. Section 2 prohibits monopolization or attempts to monopolize. See 15 U.S.C. § 2.\n\nCitations found: 15 U.S.C. § 1, 15 U.S.C. § 2\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Section 1 of the Sherman Act (15 U.S.C. § 1) prohibits contracts in restraint of trade, while Section 2 (15 U.S.C. § 2) prohibits monopolization or attempts to monopolize.'
    },
    # Employment Law
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nTitle VII of the Civil Rights Act prohibits employment discrimination based on race, color, religion, sex, or national origin. See 42 U.S.C. § 2000e-2. A plaintiff may establish discrimination through direct evidence or the McDonnell Douglas burden-shifting framework. See McDonnell Douglas Corp. v. Green, 411 U.S. 792 (1973).\n\nCitations found: 42 U.S.C. § 2000e-2, McDonnell Douglas Corp. v. Green, 411 U.S. 792\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Title VII (42 U.S.C. § 2000e-2) prohibits employment discrimination, and plaintiffs may establish discrimination through direct evidence or the McDonnell Douglas framework from McDonnell Douglas Corp. v. Green, 411 U.S. 792 (1973).'
    },
    # Intellectual Property - Patents
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nTo establish patent infringement, the patentee must show that the accused product or process meets each element of at least one claim of the patent. See 35 U.S.C. § 271. The doctrine of equivalents allows finding infringement even when not literally infringing. See Warner-Jenkinson Co. v. Hilton Davis Chemical Co., 520 U.S. 17 (1997).\n\nCitations found: 35 U.S.C. § 271, Warner-Jenkinson Co. v. Hilton Davis Chemical Co., 520 U.S. 17\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Patent infringement requires showing the accused product meets each element of a claim under 35 U.S.C. § 271, and the doctrine of equivalents from Warner-Jenkinson Co. v. Hilton Davis Chemical Co., 520 U.S. 17 (1997) allows finding infringement beyond literal infringement.'
    },
    # Securities Law
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nSection 10(b) of the Securities Exchange Act prohibits fraudulent practices in connection with the purchase or sale of securities. See 15 U.S.C. § 78j(b). Rule 10b-5 implements this prohibition. See 17 C.F.R. § 240.10b-5.\n\nCitations found: 15 U.S.C. § 78j(b), 17 C.F.R. § 240.10b-5\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Section 10(b) of the Securities Exchange Act (15 U.S.C. § 78j(b)) prohibits securities fraud, and Rule 10b-5 (17 C.F.R. § 240.10b-5) implements this prohibition.'
    },
    # Contracts
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nWhen interpreting contracts, courts give effect to the plain meaning of unambiguous terms. See Restatement (Second) of Contracts § 202. Parol evidence is inadmissible to contradict unambiguous written terms. See Restatement (Second) of Contracts § 213.\n\nCitations found: Restatement (Second) of Contracts § 202, Restatement (Second) of Contracts § 213\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Courts give effect to plain meaning of unambiguous contract terms under Restatement (Second) of Contracts § 202, and parol evidence is inadmissible to contradict unambiguous terms under § 213.'
    },
    # Bankruptcy Law
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nA bankruptcy discharge releases the debtor from personal liability for dischargeable debts. See 11 U.S.C. § 524. Certain debts are excepted from discharge, including student loans unless repayment would impose undue hardship. See 11 U.S.C. § 523(a)(8).\n\nCitations found: 11 U.S.C. § 524, 11 U.S.C. § 523(a)(8)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Bankruptcy discharge releases debtors from liability under 11 U.S.C. § 524, but certain debts are excepted from discharge, including student loans under 11 U.S.C. § 523(a)(8) unless repayment would cause undue hardship.'
    },
    # Criminal Procedure - Fourth Amendment
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nA search violates the Fourth Amendment unless it is reasonable. See Katz v. United States, 389 U.S. 347 (1967). Warrantless searches are per se unreasonable unless they fall within a recognized exception. See Coolidge v. New Hampshire, 403 U.S. 443 (1971).\n\nCitations found: Katz v. United States, 389 U.S. 347, Coolidge v. New Hampshire, 403 U.S. 443\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Searches must be reasonable under the Fourth Amendment per Katz v. United States, 389 U.S. 347 (1967), and warrantless searches are per se unreasonable unless falling within an exception under Coolidge v. New Hampshire, 403 U.S. 443 (1971).'
    },
    # ERISA
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nERISA fiduciaries must act solely in the interest of plan participants and beneficiaries. See 29 U.S.C. § 1104(a)(1). This duty of loyalty requires fiduciaries to avoid conflicts of interest. See Donovan v. Bierwirth, 680 F.2d 263 (2d Cir. 1982).\n\nCitations found: 29 U.S.C. § 1104(a)(1), Donovan v. Bierwirth, 680 F.2d 263\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'ERISA fiduciaries must act solely in participants\' interests under 29 U.S.C. § 1104(a)(1), and the duty of loyalty requires avoiding conflicts of interest under Donovan v. Bierwirth, 680 F.2d 263 (2d Cir. 1982).'
    },
    # Intellectual Property - Trademarks
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nTrademark infringement requires showing a likelihood of confusion. See 15 U.S.C. § 1114. Courts consider factors such as similarity of marks, similarity of goods, and strength of the mark. See Polaroid Corp. v. Polarad Electronics Corp., 287 F.2d 492 (2d Cir. 1961).\n\nCitations found: 15 U.S.C. § 1114, Polaroid Corp. v. Polarad Electronics Corp., 287 F.2d 492\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Trademark infringement requires likelihood of confusion under 15 U.S.C. § 1114, and courts consider factors including mark similarity and strength under Polaroid Corp. v. Polarad Electronics Corp., 287 F.2d 492 (2d Cir. 1961).'
    },
    # Intellectual Property - Copyright
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nFair use is a defense to copyright infringement. See 17 U.S.C. § 107. Courts consider the purpose and character of use, nature of the work, amount used, and effect on the market. See Campbell v. Acuff-Rose Music, Inc., 510 U.S. 569 (1994).\n\nCitations found: 17 U.S.C. § 107, Campbell v. Acuff-Rose Music, Inc., 510 U.S. 569\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Fair use is a defense to copyright infringement under 17 U.S.C. § 107, and courts consider purpose, nature, amount, and market effect under Campbell v. Acuff-Rose Music, Inc., 510 U.S. 569 (1994).'
    },
    # Employment Law - ADA
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nThe Americans with Disabilities Act requires employers to provide reasonable accommodations to qualified individuals with disabilities. See 42 U.S.C. § 12112(b)(5)(A). An accommodation is reasonable if it does not impose an undue hardship. See 42 U.S.C. § 12111(10).\n\nCitations found: 42 U.S.C. § 12112(b)(5)(A), 42 U.S.C. § 12111(10)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'The ADA requires reasonable accommodations for qualified individuals with disabilities under 42 U.S.C. § 12112(b)(5)(A), and accommodations must not impose undue hardship under 42 U.S.C. § 12111(10).'
    },
    # RICO
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nRICO prohibits conducting an enterprise through a pattern of racketeering activity. See 18 U.S.C. § 1962(c). A pattern requires at least two predicate acts within ten years. See 18 U.S.C. § 1961(5).\n\nCitations found: 18 U.S.C. § 1962(c), 18 U.S.C. § 1961(5)\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'RICO prohibits conducting an enterprise through racketeering activity under 18 U.S.C. § 1962(c), and a pattern requires at least two predicate acts within ten years under 18 U.S.C. § 1961(5).'
    },
    # Torts - Product Liability
    {
        'prompt': 'Analyze the following legal document excerpt and create a summary with accurate citations.\n\nDocument Excerpt:\nStrict product liability requires showing the product was defective and unreasonably dangerous. See Restatement (Second) of Torts § 402A. A product may be defective in design, manufacture, or warning. See Restatement (Third) of Torts: Products Liability § 2.\n\nCitations found: Restatement (Second) of Torts § 402A, Restatement (Third) of Torts: Products Liability § 2\n\nInstructions:\n1. Create a concise summary of the key points\n2. Include all relevant citations in your summary\n3. Ensure every factual claim is tied to a specific citation\n4. Do not make up information or citations\n\nSummary:',
        'completion': 'Strict product liability requires a defective and unreasonably dangerous product under Restatement (Second) of Torts § 402A, and defects may be in design, manufacture, or warning under Restatement (Third) of Torts: Products Liability § 2.'
    }
]

# Save fine-tuning data
with open('data/train.jsonl', 'w') as f:
    for item in finetune_examples:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

# Validation data (20% split)
val_size = max(1, len(finetune_examples) // 5)
with open('data/val.jsonl', 'w') as f:
    for item in finetune_examples[:val_size]:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"✓ Created {len(finetune_examples)} fine-tuning examples")
print(f"✓ Created {val_size} validation examples")
print(f"✓ Training examples cover: Civil Procedure, Evidence, Constitutional Law, Employment, IP, Securities, Antitrust, Bankruptcy, and more")


✓ Created 28 fine-tuning examples
✓ Created 5 validation examples
✓ Training examples cover: Civil Procedure, Evidence, Constitutional Law, Employment, IP, Securities, Antitrust, Bankruptcy, and more


## Fine-tune Llama 3

**⚠️ Make sure you've uploaded all project source files before running this!**


In [None]:
# Generate DPO training data (chosen vs rejected pairs)
dpo_examples = []

for ex in finetune_examples:
    prompt = ex['prompt']
    chosen = ex['completion']  # Good response with citations

    # Create rejected variants (without citations or with hallucinations)
    rejected_variants = [
        # Variant 1: Remove citations
        re.sub(r'\([^)]*\d+\s+[A-Z][a-z]+\s+\d+[^)]*\)', '', chosen).strip(),
        # Variant 2: Generic response without specific details
        'The court made a decision based on the legal principles and facts presented in the case.',
        # Variant 3: Add fake citation
        chosen.replace('Smith v. Jones', 'Brown v. White, 789 F.2d 123 (2019)') if 'Smith' in chosen else chosen + ' (See Johnson v. Smith, 456 F.3d 789 (2021))'
    ]

    # Use first variant that's different from chosen
    rejected = rejected_variants[0]
    if rejected == chosen or len(rejected) < 20:
        rejected = rejected_variants[1]

    dpo_examples.append({
        'prompt': prompt,
        'chosen': chosen,
        'rejected': rejected
    })

# Save DPO data
with open('data/dpo_train.jsonl', 'w') as f:
    for item in dpo_examples:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

with open('data/dpo_val.jsonl', 'w') as f:
    for item in dpo_examples[:2]:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"✓ Created {len(dpo_examples)} DPO training pairs")
print(f"✓ Created {len(dpo_examples[:2])} DPO validation pairs")
print("\n✅ All training data generated!")


✓ Created 28 DPO training pairs
✓ Created 2 DPO validation pairs

✅ All training data generated!


In [None]:
# Run fine-tuning
!python train_finetune.py --config configs/finetune_config.yaml


## DPO Training

Run this after fine-tuning is complete. First, create DPO pairs if needed.


In [None]:
# DPO data already generated above - ready to train!
print("✓ DPO data ready")


✓ DPO data ready


In [None]:
# Run DPO training
!python train_dpo.py --config configs/dpo_config.yaml


2025-12-09 06:05:37.393221: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765260337.412938    2645 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765260337.418825    2645 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765260337.433559    2645 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765260337.433584    2645 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765260337.433588    2645 computation_placer.cc:177] computation placer alr

## Save Models to Google Drive

**Important**: Colab files are temporary. Always save models to Google Drive!


In [None]:
# Copy models to Google Drive
import shutil

DRIVE_MODELS_DIR = '/content/drive/MyDrive/legal_document_models'
os.makedirs(DRIVE_MODELS_DIR, exist_ok=True)

# Copy fine-tuned model
if os.path.exists('models/finetuned_llama3'):
    shutil.copytree('models/finetuned_llama3',
                    f'{DRIVE_MODELS_DIR}/finetuned_llama3',
                    dirs_exist_ok=True)
    print("✓ Fine-tuned model saved to Drive")

# Copy DPO model
if os.path.exists('models/dpo_llama3'):
    shutil.copytree('models/dpo_llama3',
                    f'{DRIVE_MODELS_DIR}/dpo_llama3',
                    dirs_exist_ok=True)
    print("✓ DPO model saved to Drive")

print(f"\nModels saved to: {DRIVE_MODELS_DIR}")


In [None]:
# Test inference
import sys
sys.path.append('/content/legal-document-analysis')

from src.models.llama3_model import Llama3Model

# Load model (use DPO model if available, otherwise fine-tuned)
model_path = 'models/dpo_llama3' if os.path.exists('models/dpo_llama3') else 'models/finetuned_llama3'
if not os.path.exists(model_path):
    model_path = 'meta-llama/Meta-Llama-3-8B-Instruct'
    print("⚠️ Using base model (trained models not found)")

print(f"Loading model from: {model_path}")

model = Llama3Model(
    model_name=model_path,
    use_4bit=True
)

# Test generation
prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful legal document analyst assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

Analyze the following legal document excerpt and create a summary with accurate citations.

Document Excerpt:
The court held that the defendant's motion to dismiss should be granted. See Smith v. Jones, 123 F.3d 456 (2020). The plaintiff failed to state a claim under Rule 12(b)(6).

Summary:<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

print("\nGenerating response...")
response = model.generate(prompt, max_length=256, temperature=0.7)
print("\n" + "="*60)
print("Generated Response:")
print("="*60)
print(response)
