In [None]:
# LexiBrief: AI-Powered Legal Document Summarizer

This Colab notebook sets up and runs the LexiBrief project. LexiBrief fine-tunes the Mistral-7B-Instruct model with LoRA adapters on the BillSum dataset of legislative bills so that it can summarize legal documents. The notebook will:

1. Set up the environment and dependencies
2. Clone the project from GitHub
3. Configure the model for CPU/GPU training
4. Train and evaluate the model
5. Launch the Gradio interface for testing

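Note: Run this notebook in a GPU-enabled Colab runtime. The 4-bit quantized model loading used for training requires a CUDA GPU, and fine-tuning a 7B-parameter model on CPU is impractical. In Colab, the first code cell installs pinned dependencies and then restarts the runtime; after the restart, continue from the next cell.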


In [None]:
# Check if we're running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Setting up Colab environment...")
    
    # Install all required packages in one go to minimize installation time
    !pip install -q torch --extra-index-url https://download.pytorch.org/whl/cu118 \
        transformers==4.30.2 \
        numpy==1.24.3 \
        pandas==2.0.3 \
        tqdm==4.65.0 \
        scikit-learn \
        gradio==3.40.1 \
        peft==0.4.0 \
        datasets==2.12.0 \
        accelerate==0.21.0 \
        bitsandbytes==0.41.0 \
        PyYAML==6.0
    
    print("\nRestarting runtime to apply changes...")
    import IPython
    IPython.Application.instance().kernel.do_shutdown(True)
else:
    print("Running in local environment")


In [None]:
# Verify the installation and setup environment
import torch
import os
from tqdm.notebook import tqdm

print("Checking environment setup...")

# Verify GPU availability
print("\nGPU Information:")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU Memory Usage: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB")

# Clone repository if it doesn't exist
if not os.path.exists('lexibrief-ai'):
    print("\nCloning LexiBrief repository...")
    !git clone https://github.com/AryanT7/lexibrief-ai.git
    %cd lexibrief-ai
else:
    print("\nRepository already exists, updating...")
    %cd lexibrief-ai
    !git pull

print("\nSetup completed successfully!")


In [None]:
# Import required libraries
import os
import sys
import yaml
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import gradio as gr

# Add the project root to the Python path so repository modules are importable.
# The current working directory is expected to be the cloned repo at this point.
# An absolute path keeps the entry valid if the cwd changes later, and the
# membership check keeps re-runs of this cell from adding duplicate entries.
PROJECT_ROOT = os.path.abspath('.')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


In [None]:
# Configure training settings based on available hardware
def setup_training_config(config_path='configs/training_config.yaml'):
    """Build the training configuration for the current hardware and save it as YAML.

    Args:
        config_path: Destination of the YAML file, relative to the current working
            directory unless absolute. Missing parent directories are created.
            An existing file at this path is overwritten.

    Returns:
        dict: The configuration that was written, with sections ``model``,
        ``training``, ``lora`` and ``hardware``.
    """
    has_cuda = torch.cuda.is_available()
    batch_size = 4 if has_cuda else 1

    config_updates = {
        'model': {
            'name': 'mistralai/Mistral-7B-Instruct-v0.1',
            # bitsandbytes 4-bit loading needs a CUDA GPU; on CPU the model must be
            # loaded unquantized.
            'load_in_4bit': has_cuda,
            # NOTE(review): not read by any cell in this notebook — confirm whether
            # the repository's own scripts consume it.
            'use_flash_attention': True
        },
        'training': {
            'per_device_train_batch_size': batch_size,
            'per_device_eval_batch_size': batch_size,
            'gradient_accumulation_steps': 4,
            'num_train_epochs': 3,
            'learning_rate': 2e-4,
            'max_grad_norm': 0.3,
            'warmup_ratio': 0.03
        },
        'lora': {
            'r': 64,
            'lora_alpha': 16,
            'target_modules': ["q_proj", "k_proj", "v_proj", "o_proj"],
            'bias': "none",
            'task_type': "CAUSAL_LM"
        },
        'hardware': {
            # fp16 mixed precision is only meaningful on CUDA.
            'mixed_precision': 'fp16' if has_cuda else 'no',
            'device': 'cuda' if has_cuda else 'cpu'
        }
    }

    # Create the target directory if needed (dirname is '' for a bare filename).
    config_dir = os.path.dirname(config_path)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)

    with open(config_path, 'w') as f:
        yaml.dump(config_updates, f, default_flow_style=False)

    return config_updates


# Set up the configuration
config = setup_training_config()
print("\nTraining Configuration:")
print(yaml.dump(config, default_flow_style=False))


In [None]:
# Load and prepare the dataset
from datasets import load_dataset
from transformers import AutoTokenizer

PROMPT_TEMPLATE = "Instruction: Summarize the following legal document concisely.\n\nDocument: {}\n\nSummary:"


def build_causal_lm_features(tokenizer, texts, summaries, max_length=1024,
                             max_summary_length=256, boundary_margin=8):
    """Tokenize (document, summary) pairs as fixed-length causal-LM training examples.

    Each example is ``prompt + summary + EOS``, right-padded to ``max_length``.
    Labels copy the summary and EOS tokens. The prompt and the padding are set to
    -100, so the loss is computed only on the summary. ``input_ids``,
    ``attention_mask`` and ``labels`` therefore always share one length, as a
    causal LM's shifted loss requires.

    Args:
        tokenizer: A Hugging Face-style tokenizer that exposes ``__call__``
            (returning ``{"input_ids": ...}``), ``decode``, ``eos_token_id`` and a
            non-None ``pad_token_id``.
        texts: Documents to summarize.
        summaries: Reference summaries, aligned with ``texts``.
        max_length: Total sequence length (prompt + summary) after padding.
        max_summary_length: Maximum summary tokens, including the appended EOS.
        boundary_margin: Extra tokens reserved in the prompt budget. Re-tokenizing
            the formatted prompt can merge tokens differently at the document
            boundary, and the margin absorbs that.

    Returns:
        dict: ``input_ids``, ``attention_mask`` and ``labels``, each a list with
        one ``max_length``-long list per example.

    Raises:
        ValueError: If ``max_summary_length`` is < 2 or >= ``max_length``.
    """
    if max_summary_length < 2 or max_summary_length >= max_length:
        raise ValueError("max_summary_length must be at least 2 and smaller than max_length")

    features = {"input_ids": [], "attention_mask": [], "labels": []}
    texts = list(texts)
    summaries = list(summaries)
    if not texts:
        return features

    max_prompt_length = max_length - max_summary_length
    template_tokens = len(tokenizer(PROMPT_TEMPLATE.format(""))["input_ids"])
    document_budget = max(max_prompt_length - template_tokens - boundary_margin, 0)

    # Truncate the *document* rather than the whole prompt. Otherwise long bills
    # lose the trailing "Summary:" cue the model is trained to answer.
    document_ids = tokenizer(texts, add_special_tokens=False)["input_ids"]
    documents = [
        text if len(ids) <= document_budget
        else tokenizer.decode(ids[:document_budget], skip_special_tokens=True)
        for text, ids in zip(texts, document_ids)
    ]
    prompt_ids = tokenizer(
        [PROMPT_TEMPLATE.format(doc) for doc in documents],
        truncation=True,
        max_length=max_prompt_length,
    )["input_ids"]
    # Leave one slot for the EOS that teaches the model where a summary ends.
    summary_ids = tokenizer(
        summaries,
        add_special_tokens=False,
        truncation=True,
        max_length=max_summary_length - 1,
    )["input_ids"]

    for prompt, summary in zip(prompt_ids, summary_ids):
        target = list(summary) + [tokenizer.eos_token_id]
        sequence = list(prompt) + target
        pad = max_length - len(sequence)
        # Fixed-length padding keeps labels aligned with input_ids under the
        # Trainer's default collator, which does not pad the labels column.
        features["input_ids"].append(sequence + [tokenizer.pad_token_id] * pad)
        features["attention_mask"].append([1] * len(sequence) + [0] * pad)
        features["labels"].append([-100] * len(prompt) + target + [-100] * pad)

    return features


def prepare_training_data(max_length=1024, max_summary_length=256):
    """Load BillSum and tokenize its train/test splits for causal-LM fine-tuning.

    Args:
        max_length: Total tokenized length (prompt + summary) of every example.
        max_summary_length: Maximum summary tokens per example, including EOS.

    Returns:
        tuple: ``(train_dataset, eval_dataset, tokenizer)``. The first two are
        the tokenized ``train`` and ``test`` splits.
    """
    print("Loading billsum dataset...")
    dataset = load_dataset("billsum")

    # Load tokenizer; reuse EOS as the pad token so sequences can be padded.
    tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])
    tokenizer.pad_token = tokenizer.eos_token

    def preprocess_function(examples):
        return build_causal_lm_features(
            tokenizer,
            examples['text'],
            examples['summary'],
            max_length=max_length,
            max_summary_length=max_summary_length,
        )

    # Process datasets; the raw text columns are dropped after tokenization.
    train_dataset = dataset['train'].map(
        preprocess_function,
        batched=True,
        remove_columns=dataset['train'].column_names
    )

    eval_dataset = dataset['test'].map(
        preprocess_function,
        batched=True,
        remove_columns=dataset['test'].column_names
    )

    print("\nDataset Statistics:")
    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")

    return train_dataset, eval_dataset, tokenizer

# Tokenize both BillSum splits and keep the tokenizer for the inference cells below.
(train_dataset, eval_dataset, tokenizer) = prepare_training_data()


In [None]:
# Initialize and prepare the model for training
def setup_model():
    """Load the base model named in ``config`` and wrap it with LoRA adapters.

    When ``config['model']['load_in_4bit']`` is set and a CUDA GPU is present,
    the base weights are loaded 4-bit quantized (NF4 with double quantization).
    The compute dtype follows ``config['hardware']['mixed_precision']``. In all
    other cases the model is loaded unquantized.

    Returns:
        The PEFT-wrapped model. It prints its trainable-parameter summary.
    """
    from transformers import BitsAndBytesConfig

    print("Loading base model...")
    compute_dtype = torch.float16 if config['hardware']['mixed_precision'] == 'fp16' else torch.float32
    # bitsandbytes 4-bit kernels require CUDA; fall back to unquantized loading otherwise.
    use_4bit = bool(config['model']['load_in_4bit']) and torch.cuda.is_available()

    load_kwargs = {'device_map': 'auto', 'torch_dtype': compute_dtype}
    if use_4bit:
        # An explicit quantization config is needed here. The bare `load_in_4bit`
        # flag leaves the 4-bit compute dtype at its float32 default, which is
        # slow on GPU.
        load_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
    model = AutoModelForCausalLM.from_pretrained(config['model']['name'], **load_kwargs)
    # The KV cache is not needed for training, and Transformers reports it as
    # incompatible with gradient checkpointing.
    model.config.use_cache = False

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Configure LoRA
    peft_config = LoraConfig(
        r=config['lora']['r'],
        lora_alpha=config['lora']['lora_alpha'],
        target_modules=config['lora']['target_modules'],
        bias=config['lora']['bias'],
        task_type=config['lora']['task_type']
    )

    # Get PEFT model
    model = get_peft_model(model, peft_config)

    # Print trainable parameters
    model.print_trainable_parameters()

    return model


# Set up the model
model = setup_model()


In [None]:
# Set up and run training
def train_model():
    """Fine-tune the module-level ``model`` on the tokenized BillSum splits.

    Reads ``config``, ``model``, ``tokenizer``, ``train_dataset`` and
    ``eval_dataset`` from earlier cells. The run evaluates and checkpoints every
    50 steps into ``./results`` and reloads the best checkpoint at the end. It
    then saves the model to ``./final_model``; for a PEFT model this stores the
    adapter weights.

    Returns:
        Trainer: The trainer used for the run.
    """
    # Honour config['hardware']['mixed_precision']. The config selects 'no' on CPU,
    # where fp16 AMP is not available.
    use_fp16 = config['hardware']['mixed_precision'] == 'fp16'

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=config['training']['num_train_epochs'],
        per_device_train_batch_size=config['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=config['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
        learning_rate=config['training']['learning_rate'],
        max_grad_norm=config['training']['max_grad_norm'],
        warmup_ratio=config['training']['warmup_ratio'],
        fp16=use_fp16,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        # save_steps must match eval_steps so load_best_model_at_end can pick a checkpoint.
        save_strategy="steps",
        save_steps=50,
        load_best_model_at_end=True,
        push_to_hub=False,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    trainer.train()

    print("\nSaving model...")
    trainer.save_model("./final_model")

    return trainer


# Run training
trainer = train_model()


In [None]:
# Test the model with a sample document
def test_model():
    """Summarize a short built-in sample bill with the fine-tuned model and print it.

    Uses the module-level ``model`` and ``tokenizer`` from earlier cells. Prints
    the input text and the generated summary; returns nothing.
    """
    test_text = """
    SECTION 1. SHORT TITLE.
    This Act may be cited as the "Sample Legal Document Act of 2024".
    
    SECTION 2. PURPOSE.
    The purpose of this Act is to demonstrate the capabilities of the LexiBrief model
    in summarizing legal documents effectively and accurately.
    """

    # Prepare input
    prompt = f"Instruction: Summarize the following legal document concisely.\n\nDocument: {test_text}\n\nSummary:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    # Generate summary
    print("Generating summary...")
    # The model may still be in training mode after trainer.train(); switch to
    # inference mode (no dropout, no gradient checkpointing).
    model.eval()
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,  # temperature has no effect unless sampling is enabled
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,  # the training cell may have disabled the cache in the config
    )

    # generate() returns prompt + continuation; decode only the new tokens.
    summary = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    print("\nTest Results:")
    print("Original Text:")
    print(test_text)
    print("\nGenerated Summary:")
    print(summary)


# Run test
test_model()


In [None]:
# Launch the Gradio interface for interactive testing
def create_demo(max_prompt_tokens=1024):
    """Build a Gradio UI that summarizes pasted legal text with the fine-tuned model.

    Uses the module-level ``model`` and ``tokenizer`` from earlier cells.

    Args:
        max_prompt_tokens: Token budget for the full prompt. Longer documents are
            truncated before formatting, so the instruction and the trailing
            "Summary:" cue are always kept.

    Returns:
        gr.Interface: The demo, not yet launched.
    """
    prompt_template = "Instruction: Summarize the following legal document concisely.\n\nDocument: {}\n\nSummary:"
    # Reserve the template's own tokens, plus a small margin for token merges at the
    # document boundary, so that truncation only ever shortens the document.
    template_tokens = len(tokenizer(prompt_template.format(""))["input_ids"])
    document_budget = max(max_prompt_tokens - template_tokens - 8, 0)

    def summarize(text):
        """Return the generated summary for ``text`` (empty string for blank input)."""
        if not text or not text.strip():
            return ""

        document_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(document_ids) > document_budget:
            text = tokenizer.decode(document_ids[:document_budget], skip_special_tokens=True)

        prompt = prompt_template.format(text)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        # Inference mode: the model may still be in training mode after fine-tuning.
        model.eval()
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,  # temperature has no effect unless sampling is enabled
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

        # generate() returns prompt + continuation; return only the new tokens.
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # Create Gradio interface
    demo = gr.Interface(
        fn=summarize,
        inputs=gr.Textbox(lines=10, label="Input Legal Document"),
        outputs=gr.Textbox(label="Generated Summary"),
        title="LexiBrief: Legal Document Summarizer",
        description="Enter a legal document and get a concise summary.",
        examples=[
            ["SECTION 1. SHORT TITLE.\nThis Act may be cited as the 'Sample Legal Document Act of 2024'.\n\nSECTION 2. PURPOSE.\nThe purpose of this Act is to demonstrate the capabilities of the LexiBrief model in summarizing legal documents effectively and accurately."]
        ]
    )
    return demo


# Launch the interface (share=True creates a public link, needed to open it from Colab)
demo = create_demo()
demo.launch(share=True)
