# Install required libraries (if not already installed)

In [2]:
!pip install transformers datasets torch nltk torchvision bitsandbytes torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install --upgrade accelerate>=0.26.0
!pip install transformers[torch]

Looking in indexes: https://download.pytorch.org/whl/cu121


# Import necessary libraries

In [3]:
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset
import nltk

# Download sentence tokenizer

In [4]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Shouv\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# Load the tokenizer and model

In [5]:
model_name = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name, device_map="auto", offload_folder="offload")

print("Model and Tokenizer Loaded Successfully!")

Model and Tokenizer Loaded Successfully!


# Define a Function for Abstractive Summarization

In [6]:
def generate_summary(text, max_input=1024, max_output=200):
    """
    Generate abstractive summary using BART model.
    
    Args:
        text (str): The legal text to summarize.
        max_input (int): Max token length for input text.
        max_output (int): Max token length for summary output.

    Returns:
        str: The generated summary.
    """
    # Tokenize input text
    inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=max_input, truncation=True).to(model.device)
    
    # Generate summary
    summary_ids = model.generate(inputs, max_length=max_output, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    
    # Decode and return summary
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Test the function

In [7]:
sample_text = "In a landmark case, the Supreme Court ruled that freedom of speech does not include the right to incite violence. \
This decision overturned previous rulings and set a new precedent in constitutional law."

print("Generated Summary:", generate_summary(sample_text))

Generated Summary: summarize: In a landmark case, the Supreme Court ruled that freedom of speech does not include the right to incite violence. This decision overturned previous rulings and set a new precedent in constitutional law. The Supreme Court affirmed the First Amendment's First Amendment rights.


# Load a legal dataset (example: 'legal_trec' from Hugging Face datasets)

In [8]:
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", "3.0.0") 
# Rename columns to match "text" and "labels"
dataset = dataset.rename_columns({"article": "text", "highlights": "labels"})
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 11490
    })
})


# Extract legal texts and their summaries

In [9]:
train_texts = dataset["train"]["text"][:1000]
train_summaries = dataset["test"]["text"][:1000]

print(f"Loaded {len(train_texts)} legal documents for training.")

Loaded 1000 legal documents for training.


In [10]:
print(f"CUDA version: {torch.version.cuda}")

CUDA version: 12.1


# Fine-Tune BART on Legal Documents

In [16]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, AutoModel
import torch
import time
from torch.utils.data import DataLoader
torch._dynamo.config.suppress_errors = True

def preprocess_data(examples):
    # Tokenize inputs
    model_inputs = tokenizer(
        examples["text"],
        max_length=1024,
        truncation=True,
        padding="max_length"  # Changed from False to max_length
    )
    
    # Tokenize targets with the tokenizer
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["labels"],
            max_length=200,
            truncation=True,
            padding="max_length"  # Changed from False to max_length
        )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
    
    
# Tokenize dataset
tokenized_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=["text", "labels"],
    load_from_cache_file=True
)

# Training Arguments Optimized for Your Setup
training_args = Seq2SeqTrainingArguments(
    output_dir="./bart_legal_summarizer",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=600,
    per_device_eval_batch_size=100,
    gradient_accumulation_steps=8,
    generation_max_length=128,
    generation_num_beams=1,
    num_train_epochs=50,
    learning_rate=3e-5,
    weight_decay=0.01,
    fp16=True,
    optim="adamw_bnb_8bit",
    gradient_checkpointing=True,
    dataloader_pin_memory=True,
    dataloader_num_workers=12,
    logging_steps=50,
    save_total_limit=2,
    predict_with_generate=True,
    remove_unused_columns=False  # Add this line
)

# Dynamic Padding Collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True  # Enable dynamic padding
)

# Enable Flash Attention & Memory-Efficient Training
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.cuda.empty_cache()

# Optionally, create a custom DataLoader for training if you want to experiment with prefetch_factor
train_dataloader = DataLoader(
    tokenized_dataset["train"],
    batch_size=training_args.per_device_train_batch_size,
    num_workers=12,
    prefetch_factor=4,
    pin_memory=True,
    collate_fn=data_collator  # Use the data collator for dynamic padding
)
# Create Trainer using the custom DataLoader for training (if needed)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer
)

# Start Training
trainer.train()

AttributeError: 'function' object has no attribute 'config'