In [None]:
# LexiBrief: AI-Powered Legal Document Summarizer

# This Colab notebook sets up and runs the LexiBrief project, which uses the Mistral-7B-Instruct model for legal document summarization. The notebook will:

# 1. Set up the environment and dependencies
# 2. Clone the project from GitHub
# 3. Configure the model for CPU/GPU training
# 4. Train and evaluate the model
# 5. Launch the Gradio interface for testing


In [None]:
# Check if we're running in Colab
# Environment bootstrap: when running inside Google Colab, install the
# dependencies, fetch (or update) the project checkout, and report which
# accelerator is available. Outside Colab the cell only prints a notice.
import sys
import os
# True when the `google.colab` module has already been loaded into this kernel.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Setting up Colab environment...")

    # Pin websockets to match Gradio
    !pip install -q "websockets>=10.0,<12.0"

    # Install Gradio first to lock its dependencies
    !pip install -q "gradio==3.40.1"

    # Install remaining project-specific dependencies
    # NOTE(review): these packages carry no version specifiers, so the resolved
    # versions can differ from one run to the next — consider pinning.
    !pip install -q transformers pandas numpy tqdm \
        peft datasets accelerate bitsandbytes evaluate \
        rouge_score nltk wandb python-dotenv requests PyYAML \
        scipy sentencepiece

    # Clone or pull repo
    if not os.path.exists('lexibrief-ai'):
        print("\nCloning LexiBrief repository...")
        !git clone https://github.com/AryanT7/lexibrief-ai.git
    else:
        print("\nRepository already exists, updating...")
        # Enter the checkout only for the pull, then step back out so the
        # working directory is the same on both branches.
        %cd lexibrief-ai
        !git pull
        %cd ..

    # Confirm GPU
    print("\nVerifying GPU setup...")
    import torch
    if torch.cuda.is_available():
        # Report the first CUDA device's name and the total device count.
        print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
    else:
        print("No GPU found. Using CPU.")
else:
    print("Running in local environment")


In [None]:
# Fix numpy conflict
# Removes whichever numpy is currently installed and installs the pinned
# 1.26.4 release in its place.
# NOTE(review): if an earlier cell already imported numpy (directly or through
# another library), the kernel probably needs a restart before the new version
# is picked up — confirm.
!pip uninstall -y numpy
!pip install -q numpy==1.26.4

In [None]:
# Clean up broken install
!pip uninstall -y transformers
!rm -rf /usr/local/lib/python3.11/dist-packages/transformers

!pip install -q transformers


In [None]:
# Import required libraries
import os
import sys
import yaml
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import gradio as gr

# Add the project root to Python path
# NOTE(review): '.' is the current working directory. The setup cell clones the
# repository into ./lexibrief-ai and stays in (or returns to) the parent
# directory, so confirm that '.' rather than 'lexibrief-ai' is the intended root.
sys.path.append('.')

In [None]:
# Configure training settings based on available hardware
def setup_training_config():
    """Assemble the run configuration, persist it as YAML, and return it.

    Batch sizes, mixed precision and the target device are chosen according to
    whether CUDA is available. The resulting dictionary is written to
    ``configs/training_config.yaml`` (the directory is created if missing).

    Returns:
        dict: nested configuration with ``model``, ``training``, ``lora`` and
        ``hardware`` sections.
    """
    has_cuda = torch.cuda.is_available()

    # Hardware-dependent values, resolved once up front.
    batch_size = 4 if has_cuda else 1
    precision = 'fp16' if has_cuda else 'no'
    device_name = 'cuda' if has_cuda else 'cpu'

    model_section = {
        'name': 'mistralai/Mistral-7B-Instruct-v0.1',
        'load_in_4bit': True,
        'use_flash_attention': True
    }
    training_section = {
        'per_device_train_batch_size': batch_size,
        'per_device_eval_batch_size': batch_size,
        'gradient_accumulation_steps': 4,
        'num_train_epochs': 3,
        'learning_rate': 2e-4,
        'max_grad_norm': 0.3,
        'warmup_ratio': 0.03
    }
    lora_section = {
        'r': 64,
        'lora_alpha': 16,
        'target_modules': ["q_proj", "k_proj", "v_proj", "o_proj"],
        'bias': "none",
        'task_type': "CAUSAL_LM"
    }
    hardware_section = {
        'mixed_precision': precision,
        'device': device_name
    }

    config_updates = {
        'model': model_section,
        'training': training_section,
        'lora': lora_section,
        'hardware': hardware_section
    }

    # Persist the configuration next to the notebook.
    os.makedirs('configs', exist_ok=True)
    with open('configs/training_config.yaml', 'w') as config_file:
        yaml.dump(config_updates, config_file, default_flow_style=False)

    return config_updates

# Set up the configuration
config = setup_training_config()
print("\nTraining Configuration:")
print(yaml.dump(config, default_flow_style=False))


In [None]:
# Login to Hugging Face
from huggingface_hub import login

# login() is called without arguments, so the access token is supplied at run
# time rather than stored in the notebook. Never paste a token into a cell or
# a comment: anything committed with the notebook must be treated as leaked
# and revoked in the Hugging Face account settings.
login()


In [None]:
# Upgrade the `datasets` package to the newest available release.
# NOTE(review): the upgrade is unpinned, so the installed version can change
# between runs — consider pinning once a known-good version is identified.
!pip install -U datasets

In [None]:
# Load and prepare the dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import torch

def prepare_training_data(max_length=1024, max_summary_tokens=256):
    """Load billsum and turn it into fixed-length causal-LM training examples.

    Each example is laid out as ``<prompt> <summary> <eos> <padding>`` where the
    prompt is the same instruction template used at inference time. Loss is
    only computed on the summary (and its end-of-sequence token): prompt and
    padding positions carry the label ``-100``.

    Compared with tokenizing ``prompt + summary`` in one go and truncating on
    the right, this guarantees that:

    * a long document never pushes the summary out of the window, because the
      document is shortened first and room is reserved for the summary;
    * the prompt/summary boundary is known exactly from the prompt's token
      count, instead of searching for the ids of a stand-alone ``"Summary:"``
      string, which need not match how that text is tokenized in context;
    * padding never contributes to the loss.

    Args:
        max_length: total number of tokens per example (prompt + summary + pad).
        max_summary_tokens: tokens reserved for the summary, including EOS.

    Returns:
        tuple: ``(train_dataset, eval_dataset, tokenizer)``; the datasets hold
        ``input_ids``, ``attention_mask`` and ``labels`` columns.

    Raises:
        ValueError: if ``max_length`` leaves no room for any document text.
    """
    print("Loading billsum dataset...")
    # Use the default download mode so an existing local cache is reused
    # instead of re-downloading the corpus on every run.
    dataset = load_dataset("billsum", keep_in_memory=True)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])
    tokenizer.pad_token = tokenizer.eos_token

    prompt_template = (
        "Instruction: Summarize the following legal document concisely."
        "\n\nDocument: {text}\n\nSummary:"
    )
    # Tokens taken by the template itself (special tokens included), measured
    # with an empty document.
    template_len = len(tokenizer(prompt_template.format(text=""))["input_ids"])
    # Small slack: re-tokenizing a decoded, shortened document can yield a few
    # more tokens than were kept.
    doc_budget = max_length - max_summary_tokens - template_len - 8
    if doc_budget <= 0:
        raise ValueError(
            "max_length is too small to fit the prompt template and the summary budget"
        )

    def encode_example(text, summary):
        """Return (input_ids, attention_mask, labels) lists of length max_length."""
        # Shorten only documents that exceed the budget; short ones are used verbatim.
        doc_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(doc_ids) > doc_budget:
            text = tokenizer.decode(doc_ids[:doc_budget], skip_special_tokens=True)

        # Tokenize the prompt exactly as it is tokenized at inference time.
        prompt_ids = tokenizer(prompt_template.format(text=text))["input_ids"]

        # Summary tokens followed by EOS so the model learns where to stop.
        room = max(max_length - len(prompt_ids), 0)
        keep = max(min(max_summary_tokens, room) - 1, 0)
        summary_ids = tokenizer(summary, add_special_tokens=False)["input_ids"][:keep]
        summary_ids.append(tokenizer.eos_token_id)

        input_ids = (prompt_ids + summary_ids)[:max_length]
        # Prompt positions are ignored by the loss; summary positions are targets.
        labels = ([-100] * len(prompt_ids) + summary_ids)[:max_length]
        attention_mask = [1] * len(input_ids)

        # Right-pad to a fixed length; padded positions are masked everywhere.
        pad_count = max_length - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * pad_count
        attention_mask = attention_mask + [0] * pad_count
        labels = labels + [-100] * pad_count

        return input_ids, attention_mask, labels

    def preprocess_function(examples):
        """Encode one batch of raw rows into model-ready columns."""
        batch = {'input_ids': [], 'attention_mask': [], 'labels': []}
        for text, summary in zip(examples['text'], examples['summary']):
            input_ids, attention_mask, labels = encode_example(text, summary)
            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)
            batch['labels'].append(labels)
        return batch

    # Process datasets
    train_dataset = dataset['train'].map(
        preprocess_function,
        batched=True,
        batch_size=8,  # Process in smaller batches
        remove_columns=dataset['train'].column_names,
        desc="Processing training data"
    )

    eval_dataset = dataset['test'].map(
        preprocess_function,
        batched=True,
        batch_size=8,  # Process in smaller batches
        remove_columns=dataset['test'].column_names,
        desc="Processing evaluation data"
    )

    print(f"\nDataset Statistics:")
    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")

    return train_dataset, eval_dataset, tokenizer

# Prepare the datasets
train_dataset, eval_dataset, tokenizer = prepare_training_data()


In [None]:
# Initialize and prepare the model for training
def setup_model():
    """Load the base model named in ``config`` and wrap it with LoRA adapters.

    Reads the module-level ``config`` dictionary: ``config['model']`` selects
    the checkpoint and whether to quantize to 4 bit, ``config['hardware']``
    selects the dtype, and ``config['lora']`` supplies the adapter settings.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    # Imported locally so this cell does not depend on edits to the import cell.
    from transformers import BitsAndBytesConfig

    print("Loading base model...")
    use_fp16 = config['hardware']['mixed_precision'] == 'fp16'
    compute_dtype = torch.float16 if use_fp16 else torch.float32

    # Quantization is requested through a BitsAndBytesConfig object; passing
    # `load_in_4bit=` straight to from_pretrained is deprecated. The compute
    # dtype is set explicitly so 4-bit matmuls run in the same precision as
    # the rest of the model.
    quantization_config = None
    if config['model']['load_in_4bit']:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )

    model = AutoModelForCausalLM.from_pretrained(
        config['model']['name'],
        quantization_config=quantization_config,
        device_map='auto',
        torch_dtype=compute_dtype
    )

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Configure LoRA
    peft_config = LoraConfig(
        r=config['lora']['r'],
        lora_alpha=config['lora']['lora_alpha'],
        target_modules=config['lora']['target_modules'],
        bias=config['lora']['bias'],
        task_type=config['lora']['task_type']
    )

    # Get PEFT model
    model = get_peft_model(model, peft_config)

    # Print trainable parameters
    model.print_trainable_parameters()

    return model

# Set up the model
model = setup_model()


In [None]:
# Set up and run training
def train_model():
    """Fine-tune the global ``model`` on the prepared datasets and save it.

    Uses the module-level ``config``, ``model``, ``train_dataset``,
    ``eval_dataset`` and ``tokenizer``. Checkpoints go to ``./results``; the
    final model is written to ``./final_model``.

    Returns:
        Trainer: the trainer instance after training has finished.
    """
    # Create output directory
    os.makedirs("./results", exist_ok=True)

    # Honour the precision chosen in the hardware section of the config, which
    # was previously computed but never handed to the Trainer.
    use_fp16 = config['hardware']['mixed_precision'] == 'fp16'

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=config['training']['num_train_epochs'],
        per_device_train_batch_size=config['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=config['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
        learning_rate=config['training']['learning_rate'],
        max_grad_norm=config['training']['max_grad_norm'],
        warmup_ratio=config['training']['warmup_ratio'],
        fp16=use_fp16,
        logging_steps=10,
        # Evaluation and saving settings
        eval_strategy="steps",  # Evaluate every eval_steps
        save_strategy="steps",       # Save every save_steps
        eval_steps=50,              # Evaluate every 50 steps
        save_steps=50,              # Save every 50 steps
        save_total_limit=2,         # Keep at most 2 checkpoints on disk
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,    # Lower loss is better
        # Logging settings
        report_to="none",          # Disable wandb logging
        # Push to hub settings
        push_to_hub=False,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    trainer.train()

    print("\nSaving model...")
    trainer.save_model("./final_model")

    return trainer

# Run training
trainer = train_model()


In [None]:
# Test the model with a sample document
def test_model():
    """Generate and print a summary for a small built-in sample document.

    Uses the module-level ``model`` and ``tokenizer``. Only the newly generated
    tokens are decoded, so the printed summary does not repeat the prompt.
    """
    test_text = """
    SECTION 1. SHORT TITLE.
    This Act may be cited as the "Sample Legal Document Act of 2024".
    
    SECTION 2. PURPOSE.
    The purpose of this Act is to demonstrate the capabilities of the LexiBrief model
    in summarizing legal documents effectively and accurately.
    """

    # Prepare input
    prompt = f"Instruction: Summarize the following legal document concisely.\n\nDocument: {test_text}\n\nSummary:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    # Switch to inference mode in case the model is still in training mode.
    model.eval()

    # Generate summary
    print("Generating summary...")
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,  # temperature has no effect unless sampling is enabled
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode and print summary
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids = outputs[0][prompt_length:]
    summary = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print("\nTest Results:")
    print("Original Text:")
    print(test_text)
    print("\nGenerated Summary:")
    print(summary)

# Run test
test_model()


In [None]:
# Launch the Gradio interface for interactive testing
def create_demo():
    """Build the Gradio interface that summarizes a pasted legal document.

    The inner ``summarize`` callback uses the module-level ``model`` and
    ``tokenizer``.

    Returns:
        gr.Interface: the configured (not yet launched) demo.
    """
    prompt_template = (
        "Instruction: Summarize the following legal document concisely."
        "\n\nDocument: {text}\n\nSummary:"
    )
    max_prompt_tokens = 1024

    def summarize(text):
        """Return the generated summary for ``text`` ("" for blank input)."""
        if not text or not text.strip():
            return ""

        # Shorten an over-long document *before* building the prompt. Truncating
        # the finished prompt on the right would cut off the trailing
        # "Summary:" cue the model was trained to continue from.
        template_len = len(tokenizer(prompt_template.format(text=""))["input_ids"])
        # Small slack: re-tokenizing decoded text can yield a few extra tokens.
        doc_budget = max(max_prompt_tokens - template_len - 8, 1)
        doc_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(doc_ids) > doc_budget:
            text = tokenizer.decode(doc_ids[:doc_budget], skip_special_tokens=True)

        prompt = prompt_template.format(text=text)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        # Switch to inference mode in case the model is still in training mode.
        model.eval()

        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,  # temperature has no effect unless sampling is enabled
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )

        # generate() returns prompt + continuation; show only the continuation.
        generated_ids = outputs[0][prompt_length:]
        summary = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return summary

    # Create Gradio interface
    demo = gr.Interface(
        fn=summarize,
        inputs=gr.Textbox(lines=10, label="Input Legal Document"),
        outputs=gr.Textbox(label="Generated Summary"),
        title="LexiBrief: Legal Document Summarizer",
        description="Enter a legal document and get a concise summary.",
        examples=[
            ["SECTION 1. SHORT TITLE.\nThis Act may be cited as the 'Sample Legal Document Act of 2024'.\n\nSECTION 2. PURPOSE.\nThe purpose of this Act is to demonstrate the capabilities of the LexiBrief model in summarizing legal documents effectively and accurately."]
        ]
    )
    return demo

# Launch the interface
demo = create_demo()
demo.launch(share=True)
