In [None]:
# LexiBrief: AI-Powered Legal Document Summarizer

# This Colab notebook sets up and runs the LexiBrief project, which fine-tunes the google/flan-t5-base model with LoRA for legal document summarization. The notebook will:

# 1. Set up the environment and dependencies
# 2. Clone the project from GitHub
# 3. Configure the model for GPU training (a CUDA GPU, e.g. a Colab T4, is required)
# 4. Train and evaluate the model
# 5. Launch the Gradio interface for testing


In [None]:
# Set up project structure and paths
import os
import sys
from pathlib import Path

# Check if we're in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository if it doesn't exist
    if not os.path.exists('LexiBrief'):
        !git clone https://github.com/yourusername/LexiBrief.git
    
    # Change to the project directory
    %cd LexiBrief

# Define project paths
PROJECT_ROOT = Path(os.getcwd())
PATHS = {
    'configs': PROJECT_ROOT / 'configs',
    'data': PROJECT_ROOT / 'data',
    'models': PROJECT_ROOT / 'models',
    'outputs': PROJECT_ROOT / 'outputs',
    'logs': PROJECT_ROOT / 'logs',
    'final_model': PROJECT_ROOT / 'models' / 'final_model',
    'results': PROJECT_ROOT / 'outputs' / 'results'
}

# Create all necessary directories
for path in PATHS.values():
    path.mkdir(parents=True, exist_ok=True)

print("\nProject structure created:")
for name, path in PATHS.items():
    print(f"{name}: {path}")

if IN_COLAB:
    # Install project requirements
    if os.path.exists('requirements.txt'):
        !pip install -r requirements.txt
    print("\nProject dependencies installed")


In [None]:
# Clone and set up the project repository
import os

# Check if we're in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    !git clone https://github.com/yourusername/LexiBrief.git
    
    # Change to the project directory
    %cd LexiBrief
    
    # Create necessary directories if they don't exist
    os.makedirs("configs", exist_ok=True)
    os.makedirs("data", exist_ok=True)
    os.makedirs("models", exist_ok=True)
    os.makedirs("outputs", exist_ok=True)
    os.makedirs("logs", exist_ok=True)
    
    print("\nProject structure created:")
    !ls -la

    # Install project requirements
    !pip install -r requirements.txt


In [None]:
# Set up the Colab environment
# Removes and reinstalls the core ML stack with pinned versions (numpy 1.26.4,
# transformers 4.40.2, accelerate 0.26.0, peft 0.10.0, gradio 3.40.1), then
# reports whether a CUDA GPU is visible. Only runs inside Colab.
# NOTE(review): this uninstalls packages that the project-setup cell may have
# installed from requirements.txt; versions installed here take precedence.
# NOTE(review): modules already imported in this kernel are not reloaded by pip;
# Colab typically needs "Restart runtime" after this cell before new versions
# take effect — confirm.
import sys
import os
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Setting up Colab environment...")

    # Clean environment first
    !pip uninstall -y numpy transformers torch torchvision torchaudio datasets accelerate -q
    
    # Install numpy 1.26.4 first
    !pip install -q numpy==1.26.4
    
    # Install PyTorch and related packages
    # NOTE(review): unpinned — the installed torch version varies between runs.
    !pip install -q torch torchvision torchaudio
    
    # Install transformers and its dependencies
    !pip install -q transformers==4.40.2
    !pip install -q accelerate==0.26.0
    !pip install -q datasets
    !pip install -q peft==0.10.0
    
    # Install other dependencies
    !pip install -q "websockets>=10.0,<12.0"
    !pip install -q evaluate rouge_score nltk
    !pip install -q wandb python-dotenv requests PyYAML scipy sentencepiece
    !pip install -q gradio==3.40.1
    
    # Force reinstall numpy to ensure version, in case any install above
    # pulled in a different numpy
    !pip install -q --force-reinstall numpy==1.26.4

    # Confirm GPU
    print("\nVerifying GPU setup...")
    import torch
    if torch.cuda.is_available():
        print(f"GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
    else:
        # NOTE(review): despite this message, setup_model() raises RuntimeError
        # without a GPU, so training cannot actually proceed on CPU.
        print("No GPU found. Using CPU.")

    torch.cuda.empty_cache()

In [None]:
# Import required libraries
import os
import sys
import yaml
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from peft import LoraConfig, get_peft_model
import gradio as gr

# Make project-local modules importable: append the current directory
# (os.curdir, i.e. '.') to the end of the module search path. The entry is
# resolved against whatever the working directory is at import time.
sys.path.append(os.curdir)

In [None]:
# Update configuration with project paths
def update_paths_in_config(cfg=None):
    """Add the project's output, log and model-save paths to the training config and save it.

    This cell runs before the cell that builds ``config`` (``setup_training_config``),
    so on a fresh kernel the global ``config`` does not exist yet. In that case the
    previously saved ``configs/training_config.yaml`` is loaded if present; otherwise
    the update is skipped with a message instead of raising ``NameError``.

    Args:
        cfg: Config dict to update in place. Defaults to the global ``config``.

    Returns:
        The updated config dict, or ``None`` when no configuration is available yet.
    """
    global config

    config_path = PATHS['configs'] / 'training_config.yaml'

    if cfg is None:
        cfg = globals().get('config')
    if cfg is None and config_path.exists():
        with open(config_path) as f:
            cfg = yaml.safe_load(f) or {}
    if cfg is None:
        print(f"No training configuration found (looked for {config_path}); "
              "skipping path update. Run the training-config cell first.")
        return None

    # Update training paths
    cfg.setdefault('training', {}).update({
        'output_dir': str(PATHS['results']),
        'logging_dir': str(PATHS['logs']),
        'model_save_dir': str(PATHS['final_model'])
    })

    # Expose a config loaded from disk to later cells
    if globals().get('config') is None:
        config = cfg

    # Save updated config
    with open(config_path, 'w') as f:
        yaml.dump(cfg, f, default_flow_style=False)

    print(f"Updated configuration saved to {config_path}")
    print("\nUpdated paths:")
    print(f"Output directory: {PATHS['results']}")
    print(f"Model save directory: {PATHS['final_model']}")
    print(f"Logs directory: {PATHS['logs']}")

    return cfg

# Update paths in configuration
update_paths_in_config()


In [None]:
# Configure training settings for T4 GPU
def setup_training_config():
    """Build the FLAN-T5 + LoRA training configuration and save it as YAML.

    The ``training`` section also carries the output/log/model-save paths from
    ``PATHS``, so saving this config does not drop the paths that
    ``update_paths_in_config`` writes. The file is written to
    ``PATHS['configs'] / 'training_config.yaml'``, the same location that
    ``update_paths_in_config`` uses.

    Returns:
        dict: configuration with ``model``, ``training``, ``lora`` and ``hardware`` sections.
    """
    # T4 GPU has ~15GB usable VRAM; we reserve ~1GB for the system
    batch_size = 8  # Increased batch size for faster training
    grad_accum = 8  # 8 x 8 keeps an effective batch size of 64

    # Sequence lengths (tokens)
    max_input_length = 384  # Reduced from 512 for faster processing
    max_output_length = 96  # Reduced summary length for faster training

    config_updates = {
        'model': {
            'name': 'google/flan-t5-base',
            'max_length': max_input_length,
            'use_flash_attention': False  # FLAN-T5 doesn't support flash attention
        },
        'training': {
            'output_dir': str(PATHS['results']),
            'logging_dir': str(PATHS['logs']),
            'model_save_dir': str(PATHS['final_model']),
            'per_device_train_batch_size': batch_size,
            'per_device_eval_batch_size': batch_size * 2,  # Larger eval batch size for speed
            'gradient_accumulation_steps': grad_accum,
            'num_train_epochs': 2,  # Two epochs for better learning while keeping training time reasonable
            'learning_rate': 2e-3,  # Higher learning rate for fast convergence but still stable learning
            'max_grad_norm': 2.0,  # Allow larger gradients
            'warmup_ratio': 0.01,  # Minimal warmup
            'optim': 'adamw_torch',
            # NOTE(review): T4 (compute capability 7.5) has no native bf16 tensor
            # cores; confirm bf16 is supported/performant on the target GPU.
            'bf16': True,
            'fp16': False,
            'gradient_checkpointing': True,
            'generation_max_length': max_output_length,
            'predict_with_generate': True,
            'gradient_checkpointing_kwargs': {"use_reentrant": False},
            # NOTE: train_model() hardcodes logging/eval/save steps (10/50/50)
            # and does not read the three values below.
            'logging_steps': 50,
            'eval_steps': 250,
            'save_steps': 250,
            'save_total_limit': 2,  # Keep only the last 2 checkpoints
            'dataloader_num_workers': 4,  # Parallel data loading
            'group_by_length': True,  # Reduces padding, increases speed
        },
        'lora': {
            'r': 8,  # Reduced rank to save memory
            'lora_alpha': 16,
            'target_modules': ["q", "k", "v", "o"],  # T5 attention module names
            'bias': "none",
            'task_type': "SEQ_2_SEQ_LM",  # T5 is a seq2seq model
            'inference_mode': False
        },
        'hardware': {
            'mixed_precision': 'bf16',  # Match model's bfloat16 setting
            'device_type': 'gpu',
            'device': 'cuda',
            'max_memory': {0: "14GB"},  # Reserve ~1GB for system
            'device_map': {"": 0},  # Map everything to GPU 0
            'pin_memory': True,  # Pin memory for faster GPU transfer
            'non_blocking': True  # Non-blocking GPU transfers
        }
    }

    # Save config next to the other project configs
    config_dir = PATHS['configs']
    config_dir.mkdir(parents=True, exist_ok=True)
    with open(config_dir / 'training_config.yaml', 'w') as f:
        yaml.dump(config_updates, f, default_flow_style=False)

    return config_updates

# Set up the configuration
config = setup_training_config()
print("\nTraining Configuration:")
print(yaml.dump(config, default_flow_style=False))


In [None]:
# Login to Hugging Face.
# login() prompts for an access token interactively. Never paste tokens into the
# notebook source (they leak via version control and shared copies); use the
# prompt, the HF_TOKEN environment variable, or Colab's secrets manager instead.
from huggingface_hub import login
login()

In [None]:
# NOTE(review): upgrades `datasets` to the latest release (unpinned) after the
# environment-setup cell already installed it; pin a version for reproducible runs.
!pip install -U datasets

In [None]:
# Load and prepare the dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import torch

# Initialize tokenizer globally (will be used by other functions)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

def prepare_training_data(max_train_samples=2000, max_eval_samples=400, max_words=1000):
    """Load BillSum, drop long documents, subsample and tokenize it for FLAN-T5.

    Inputs are prefixed with "Summarize the following legal document: " and
    padded/truncated to ``config['model']['max_length']`` tokens. Summaries are
    padded/truncated to ``config['training']['generation_max_length']`` tokens,
    with padding positions set to -100 so the loss ignores them.

    Args:
        max_train_samples: Maximum number of training examples kept after filtering.
        max_eval_samples: Maximum number of evaluation examples (taken from the ``test`` split).
        max_words: Documents with ``max_words`` or more whitespace-separated words are dropped.

    Returns:
        Tuple ``(train_dataset, eval_dataset, tokenizer)``.
    """
    print("Loading and trimming billsum dataset...")
    # force_redownload ignores any cached copy and downloads again on every call
    dataset = load_dataset("billsum", download_mode="force_redownload", keep_in_memory=True)  # Keep in memory for faster processing

    # Filter out very long documents to speed up training
    dataset = dataset.filter(
        lambda x: len(x['text'].split()) < max_words,  # Only shorter documents
        num_proc=4
    )

    # Take subset of dataset
    if len(dataset['train']) > max_train_samples:
        dataset['train'] = dataset['train'].select(range(max_train_samples))
    if len(dataset['test']) > max_eval_samples:
        dataset['test'] = dataset['test'].select(range(max_eval_samples))

    def preprocess_function(examples):
        """Tokenize a batch of documents/summaries into model inputs and -100-masked labels."""
        # Format inputs for FLAN-T5 with task prefix
        inputs = [
            f"Summarize the following legal document: {text}"
            for text in examples['text']
        ]

        # Tokenize inputs
        model_inputs = tokenizer(
            inputs,
            max_length=config['model']['max_length'],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Tokenize summaries
        labels = tokenizer(
            examples['summary'],
            max_length=config['training']['generation_max_length'],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Set up the labels
        model_inputs['labels'] = labels['input_ids']

        # Replace padding token id with -100 so padding is ignored by the loss
        model_inputs['labels'][labels['input_ids'] == tokenizer.pad_token_id] = -100

        return model_inputs

    # Process datasets
    train_dataset = dataset['train'].map(
        preprocess_function,
        batched=True,
        batch_size=16,
        remove_columns=dataset['train'].column_names,
        num_proc=4  # Use multiple processes for CPU preprocessing
    )

    eval_dataset = dataset['test'].map(
        preprocess_function,
        batched=True,
        batch_size=16,
        remove_columns=dataset['test'].column_names,
        num_proc=4
    )

    print("\nDataset Statistics:")
    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")
    print(f"Input sequence length: {config['model']['max_length']}")
    # Report the length summaries were actually tokenized to
    print(f"Output sequence length: {config['training']['generation_max_length']}")

    return train_dataset, eval_dataset, tokenizer

# Prepare the datasets
train_dataset, eval_dataset, tokenizer = prepare_training_data()

In [None]:
# Initialize and prepare the model for training
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model
import torch

def setup_model():
    """Load the seq2seq base model onto the GPU and wrap it with a LoRA adapter.

    The model name, GPU memory cap, device map and LoRA hyper-parameters are
    read from the global ``config`` (see ``setup_training_config``), so editing
    the config actually changes what is loaded.

    Returns:
        The PEFT-wrapped model; only the LoRA weights are trainable.

    Raises:
        RuntimeError: if no CUDA GPU is available or the model did not load on it.
    """
    print("Loading FLAN-T5 base model...")

    if not torch.cuda.is_available():
        raise RuntimeError("This code requires a GPU to run")

    hardware_cfg = config['hardware']
    lora_cfg = config['lora']

    # Load model in bfloat16 directly onto the configured device(s)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        config['model']['name'],
        torch_dtype=torch.bfloat16,  # Use bfloat16 instead of float16
        device_map=hardware_cfg.get('device_map', {"": 0}),  # Default: everything on GPU 0, no offloading
        max_memory=hardware_cfg['max_memory'],
        use_cache=False  # Disable KV cache for training
    )

    # Enable gradient checkpointing for memory efficiency
    model.gradient_checkpointing_enable()

    # Apply LoRA configuration
    peft_config = LoraConfig(
        r=lora_cfg['r'],
        lora_alpha=lora_cfg['lora_alpha'],
        target_modules=lora_cfg['target_modules'],
        bias=lora_cfg['bias'],
        task_type=lora_cfg['task_type'],
        inference_mode=lora_cfg.get('inference_mode', False)
    )

    model = get_peft_model(model, peft_config)

    # Verify model is on GPU
    if next(model.parameters()).device.type != 'cuda':
        raise RuntimeError("Model failed to load on GPU")

    # Print model statistics
    model.print_trainable_parameters()
    print(f"\nModel loaded on GPU with {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB total VRAM")
    print(f"Current GPU memory used: {torch.cuda.memory_allocated()/1e9:.1f}GB")

    return model

In [None]:
# Set up the model
# setup_model() (previous cell) loads FLAN-T5, wraps it with LoRA and raises
# RuntimeError when no CUDA GPU is available.
model = setup_model()

In [None]:
# Set up and run training
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.trainer_utils import set_seed
import evaluate
import numpy as np
import nltk

# Download the NLTK sentence-tokenizer data used by compute_metrics below.
# 'punkt_tab' is required by newer NLTK releases, 'punkt' by older ones.
nltk.download('punkt')
nltk.download('punkt_tab')

# Verify NLTK data is downloaded
try:
    nltk.data.find('tokenizers/punkt')
    print("NLTK punkt tokenizer data successfully loaded")
except LookupError as err:
    raise RuntimeError("Failed to download NLTK data. Please try running the cell again.") from err

def train_model():
    """Fine-tune the global LoRA ``model`` with ROUGE-based checkpoint selection and save it.

    Reads the notebook globals ``model``, ``tokenizer``, ``train_dataset``,
    ``eval_dataset``, ``config`` and ``PATHS``. Evaluates every 50 steps, keeps
    the checkpoint with the best ROUGE-2, and then writes the adapter, tokenizer
    and base-model config to ``PATHS['final_model']``.

    Returns:
        The ``Seq2SeqTrainer`` used for training (``trainer.state`` holds the best metric).
    """
    # Set random seed for reproducibility
    set_seed(42)

    # Load ROUGE metric for evaluation
    rouge = evaluate.load('rouge')

    def compute_metrics(eval_pred):
        """Decode generated and reference ids and return ROUGE scores scaled to 0-100."""
        predictions, labels = eval_pred
        # Generated ids come first if the predictions arrive as a tuple
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # Predictions can contain -100 padding (the eval loop pads generations
        # of different lengths when concatenating batches); the tokenizer cannot
        # decode negative ids, so map them to the pad token first.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

        # Labels use -100 for padding (see preprocess_function); restore pad ids
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # ROUGE-Lsum expects newline-separated sentences
        decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
        decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

        result = rouge.compute(
            predictions=decoded_preds,
            references=decoded_labels,
            use_stemmer=True
        )
        return {key: round(value * 100, 4) for key, value in result.items()}

    # Set up training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(PATHS['results']),
        num_train_epochs=config['training']['num_train_epochs'],
        per_device_train_batch_size=config['training']['per_device_train_batch_size'],
        per_device_eval_batch_size=config['training']['per_device_eval_batch_size'],
        gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
        learning_rate=config['training']['learning_rate'],
        max_grad_norm=config['training']['max_grad_norm'],
        warmup_ratio=config['training']['warmup_ratio'],
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="rouge2",
        greater_is_better=True,
        push_to_hub=False,

        # GPU optimizations
        bf16=config['training']['bf16'],  # Use bfloat16
        fp16=False,  # Disable fp16 since we're using bf16
        gradient_checkpointing=config['training']['gradient_checkpointing'],
        dataloader_pin_memory=True,  # Pin memory for faster GPU transfer
        dataloader_num_workers=4,  # Parallel data loading
        group_by_length=True,

        # Generation settings
        predict_with_generate=True,
        generation_max_length=config['training']['generation_max_length'],
        generation_num_beams=4
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        # None -> Trainer's default collator (padding collator when a tokenizer
        # is given); inputs and labels are already padded to fixed lengths.
        data_collator=None
    )

    # Print GPU memory usage before training
    print(f"\nGPU memory before training: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    print("\nStarting training...")
    trainer.train()

    print("\nSaving model...")
    model_save_path = str(PATHS['final_model'])
    trainer.save_model(model_save_path)
    # Save tokenizer and base-model config alongside the adapter *before*
    # listing the directory, so the verification below includes them.
    tokenizer.save_pretrained(model_save_path)
    model.config.save_pretrained(model_save_path)

    # Verify files were saved
    print("\nVerifying saved files:")
    if os.path.exists(model_save_path):
        print(f"Files in {model_save_path}:")
        for file in os.listdir(model_save_path):
            file_size = os.path.getsize(os.path.join(model_save_path, file)) / (1024 * 1024)  # Size in MB
            print(f"- {file} ({file_size:.2f} MB)")
    else:
        print("Warning: Model directory not found!")

    # Print final GPU memory usage
    print(f"\nFinal GPU memory usage: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    return trainer

# Run training
trainer = train_model()

In [None]:
# Verify saved model files and structure
import os
import json

def verify_model_files():
    """Print the model/results/config directories and check the saved model artifacts.

    Artifacts are checked in ``PATHS['final_model']``, where ``train_model`` saves
    them (adapter config and weights, base-model config, tokenizer config).
    ``README.md`` and ``training_config.json`` are written later by
    ``prepare_model_for_publishing``, so they are reported as "not generated yet"
    rather than as missing until that cell has run.
    """
    model_dir = PATHS['final_model']
    mb = 1024 * 1024

    # List main directories (sub-directories such as checkpoints are marked, not sized)
    for dir_path in (PATHS['final_model'], PATHS['results'], PATHS['configs']):
        if dir_path.is_dir():
            print(f"\n{dir_path}/ directory:")
            for entry in sorted(dir_path.iterdir()):
                if entry.is_dir():
                    print(f"- {entry.name}/ (directory)")
                else:
                    print(f"- {entry.name} ({entry.stat().st_size / mb:.2f} MB)")
        else:
            print(f"\nWarning: {dir_path}/ directory not found!")

    # PEFT may save adapter weights as safetensors or as a .bin file; accept either.
    required_files = {
        'adapter config': ['adapter_config.json'],
        'adapter weights': ['adapter_model.safetensors', 'adapter_model.bin'],
        'model config': ['config.json'],
        'tokenizer config': ['tokenizer_config.json'],
    }
    publishing_files = ['training_config.json', 'README.md']

    print("\nChecking required model files:")
    for label, candidates in required_files.items():
        found = next((model_dir / name for name in candidates if (model_dir / name).exists()), None)
        if found is not None:
            print(f"✓ {found} ({found.stat().st_size / mb:.2f} MB)")
        else:
            print(f"✗ Missing {label}: expected one of {candidates} in {model_dir}")

    print("\nChecking publishing artifacts (created by the publishing cell):")
    for name in publishing_files:
        file_path = model_dir / name
        if file_path.exists():
            print(f"✓ {file_path} ({file_path.stat().st_size / mb:.2f} MB)")
        else:
            print(f"- Not generated yet: {file_path}")

    # Try to load and verify training config
    training_config_path = model_dir / 'training_config.json'
    try:
        with open(training_config_path, 'r') as f:
            json.load(f)
        print("\nTraining config verified ✓")
    except FileNotFoundError:
        print("\nNote: training_config.json not found (it is written by the publishing cell).")
    except json.JSONDecodeError:
        print("\nWarning: training_config.json is not valid JSON!")

# Run verification
verify_model_files()


In [None]:
# Prepare model for publishing
def prepare_model_for_publishing():
    """Write a Hugging Face model card, the training config and eval results next to the model.

    Reads the notebook globals ``trainer``, ``config``, ``train_dataset``,
    ``eval_dataset`` and ``PATHS``. It writes ``README.md``,
    ``training_config.json`` and ``eval_results.json`` into
    ``PATHS['final_model']``.
    """
    import json
    from datetime import datetime

    # Sample counts come from the datasets actually used; the caps inside
    # prepare_training_data are local to that function, and filtering can
    # leave fewer examples than the caps.
    train_samples = len(train_dataset)
    eval_samples = len(eval_dataset)

    # best_metric is None if no evaluation ran during training
    best_metric = trainer.state.best_metric
    best_metric_str = f"{best_metric:.4f}" if best_metric is not None else "N/A"
    training_date = datetime.now().strftime("%Y-%m-%d")

    # 1. Create model card. Literal braces (code sample, BibTeX) are doubled
    # because this is an f-string.
    model_card = f"""---
language:
- en
tags:
- legal-documents
- summarization
- flan-t5
- lora
license: apache-2.0
datasets:
- billsum
metrics:
- rouge
model-index:
- name: LexiBrief-FLAN-T5-Legal-Summarizer
  results:
  - task:
      type: summarization
      name: Legal Document Summarization
    dataset:
      type: billsum
      name: BillSum
      split: test
      revision: None
    metrics:
    - type: rouge
      value: {best_metric_str}
      name: ROUGE-2
---

# LexiBrief: FLAN-T5 Legal Document Summarizer

This model is a fine-tuned version of [google/flan-t5-base](https://huggingface.co/google/flan-t5-base) optimized for legal document summarization using LoRA.

## Model Details

- **Base Model**: FLAN-T5-base
- **Task**: Legal Document Summarization
- **Training Data**: BillSum dataset (filtered to {train_samples} samples)
- **Training Date**: {training_date}
- **Framework**: PyTorch with 🤗 Transformers and PEFT

## Training Procedure

- **LoRA Configuration**:
  - Rank: {config['lora']['r']}
  - Alpha: {config['lora']['lora_alpha']}
  - Target Modules: {config['lora']['target_modules']}

- **Training Hyperparameters**:
  - Learning Rate: {config['training']['learning_rate']}
  - Batch Size: {config['training']['per_device_train_batch_size']}
  - Gradient Accumulation: {config['training']['gradient_accumulation_steps']}
  - Mixed Precision: bfloat16

## Evaluation Results

The model was evaluated on the BillSum test set ({eval_samples} samples):

- ROUGE-2 Score: {best_metric_str}

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("your-username/lexibrief-legal-summarizer")
model = AutoModelForSeq2SeqLM.from_pretrained("your-username/lexibrief-legal-summarizer")

# Example usage
text = "Your legal document here..."
inputs = tokenizer(f"Summarize the following legal document: {{text}}", return_tensors="pt", max_length=384, truncation=True)
outputs = model.generate(**inputs, max_length=96, min_length=30, num_beams=4)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
```

## Limitations

- The model was trained on a subset of the BillSum dataset
- Best suited for documents under 1000 words
- May not perform as well on non-legislative legal documents

## Citation

If you use this model, please cite:

```bibtex
@misc{{lexibrief2024,
  title={{LexiBrief: FLAN-T5 Legal Document Summarizer}},
  author={{Your Name}},
  year={{2024}},
  publisher={{Hugging Face}},
  url={{https://huggingface.co/your-username/lexibrief-legal-summarizer}}
}}
```
"""

    model_dir = PATHS['final_model']
    model_dir.mkdir(parents=True, exist_ok=True)

    # Save model card (explicit UTF-8: the card contains non-ASCII characters)
    with open(model_dir / 'README.md', 'w', encoding='utf-8') as f:
        f.write(model_card)

    # Save training config
    with open(model_dir / 'training_config.json', 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    # Save evaluation results
    eval_results = {
        'best_rouge2': best_metric,
        'training_steps': trainer.state.global_step,
        'eval_samples': eval_samples,
        'train_samples': train_samples
    }
    with open(model_dir / 'eval_results.json', 'w', encoding='utf-8') as f:
        json.dump(eval_results, f, indent=2)

    print(f"\nModel artifacts saved in {model_dir}:")
    print("- Model weights and configuration")
    print("- README.md (model card)")
    print("- training_config.json")
    print("- eval_results.json")

# Generate publishing materials
prepare_model_for_publishing()


In [None]:
# Test the model with a sample document
def test_model():
    """Summarize a short sample bill with the fine-tuned global ``model`` and print the result.

    Uses the same instruction prefix as training (``preprocess_function``) and
    the training-time summary length (``config['training']['generation_max_length']``).
    """
    test_text = """
    SECTION 1. SHORT TITLE.
    This Act may be cited as the "Sample Legal Document Act of 2024".

    SECTION 2. PURPOSE.
    The purpose of this Act is to demonstrate the capabilities of the LexiBrief model
    in summarizing legal documents effectively and accurately.
    """

    # Must match the prompt used during training; a different instruction
    # at inference time does not match what the adapter was trained on.
    input_text = f"Summarize the following legal document: {test_text}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=config['model']['max_length'],
        padding=True,
        truncation=True
    )
    # Move inputs to the device the model lives on
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # generate() does not switch modes, and the model may still be in training
    # mode after trainer.train(); disable dropout for inference.
    model.eval()

    print("Generating summary...")
    # Autocast to bfloat16, matching the dtype the model was loaded with, instead
    # of torch.cuda.amp.autocast's float16 default. temperature is omitted: it
    # only applies when sampling, not to beam search.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model.generate(
            **inputs,
            max_length=config['training']['generation_max_length'],
            min_length=30,  # Ensure summary isn't too short
            num_beams=4,
            length_penalty=2.0,  # Encourage longer summaries
            no_repeat_ngram_size=3,  # Avoid repetition
            early_stopping=True
        )

    # Decode and print summary
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\nTest Results:")
    print("Original Text:")
    print(test_text)
    print("\nGenerated Summary:")
    print(summary)

    # Print memory usage
    print(f"\nCurrent GPU memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    # Clear GPU cache
    torch.cuda.empty_cache()

# Run test
test_model()


In [None]:
# Launch the Gradio interface for interactive testing
def create_demo():
    """Build a Gradio interface that summarizes pasted legal text with the fine-tuned model.

    Uses the notebook globals ``model``, ``tokenizer`` and ``config``. The prompt
    prefix and the maximum summary length match those used during training.

    Returns:
        gr.Interface: the demo (not yet launched).
    """
    def summarize(text):
        """Return a generated summary for ``text``, or a hint when the input is empty."""
        if not text or not text.strip():
            return "Please enter a legal document to summarize."

        # Same instruction prefix as preprocess_function used for training
        input_text = f"Summarize the following legal document: {text}"
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=config['model']['max_length'],
            padding=True,
            truncation=True
        )
        # Move inputs to the device the model lives on
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Disable dropout for inference
        model.eval()

        # Beam search under bfloat16 autocast (the dtype the model was loaded
        # with); temperature is omitted because it only applies when sampling.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model.generate(
                **inputs,
                max_length=config['training']['generation_max_length'],
                min_length=30,
                num_beams=4,
                length_penalty=2.0,
                no_repeat_ngram_size=3,
                early_stopping=True
            )

        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clear GPU cache after generation
        torch.cuda.empty_cache()

        return summary

    demo = gr.Interface(
        fn=summarize,
        inputs=gr.Textbox(
            lines=10,
            label="Input Legal Document",
            placeholder="Paste your legal document here...",
            elem_id="input-box"
        ),
        outputs=gr.Textbox(
            label="Generated Summary",
            elem_id="output-box"
        ),
        title="LexiBrief: Legal Document Summarizer (FLAN-T5)",
        description="""
        This tool uses the FLAN-T5 model fine-tuned on legal documents to generate concise summaries.
        Enter a legal document and get a clear, accurate summary optimized for legal text.
        """,
        examples=[
            ["SECTION 1. SHORT TITLE.\nThis Act may be cited as the 'Sample Legal Document Act of 2024'.\n\nSECTION 2. PURPOSE.\nThe purpose of this Act is to demonstrate the capabilities of the LexiBrief model in summarizing legal documents effectively and accurately."]
        ],
        theme="default",
        css="""
        #input-box { min-height: 200px; }
        #output-box { min-height: 100px; }
        """
    )
    return demo

# Launch the interface; queue() replaces the deprecated launch(enable_queue=...) flag
demo = create_demo()
demo.queue().launch(share=True)
