# LED Legal Summarization with Full Argument Labelling Integration

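This notebook implements a complete pipeline for fine-tuning an LED (Longformer Encoder-Decoder) model on legal summarization data. It integrates sentence-level argument labels (Issue, Reason, Conclusion) to filter each input document down to its argumentative sentences, aiming for better argumentative coherence in the generated summaries.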

## 1. Import Required Libraries

In [1]:
import os
import json
import torch
from transformers import (
    LEDTokenizer, LEDForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import load_dataset, Dataset
import evaluate
import numpy as np
from functools import partial
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to C:\Users\Atharva
[nltk_data]     Badgujar\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## 2. Load and Explore the Dataset

In [None]:
# Paths — every location below is relative to the notebook's working directory.
DATA_DIR, SRC_DIR, MODEL_DIR = ".", "src", "models"
# Optional JSON config; when absent, load_cfg falls back to built-in defaults.
CONFIG_FILE = os.path.join(MODEL_DIR, "config.json")

# Load config
def load_cfg(path=None):
    """Load the pipeline configuration, filling any gaps from built-in defaults.

    Args:
        path: JSON config file to read. Defaults to ``CONFIG_FILE``.

    Returns:
        A dict containing at least the ``model``, ``data_schema`` and ``data``
        sections. When the file exists, its values override the defaults key by
        key within each section, so a partial config file no longer leads to a
        ``KeyError`` later in the notebook. Extra sections in the file are kept.
    """
    defaults = {
        "model": {"checkpoint": "allenai/led-base-16384"},
        "data_schema": {"input_col": "Text", "summary_col": "Summary"},
        "data": {"max_input_length": 16384, "max_output_length": 512},
    }
    if path is None:
        path = CONFIG_FILE
    if not os.path.exists(path):
        return defaults
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        loaded = json.load(f)
    if not isinstance(loaded, dict):
        # Unexpected top-level JSON shape: return it untouched, as before.
        return loaded
    merged = dict(loaded)
    for section, section_defaults in defaults.items():
        if section not in loaded:
            merged[section] = section_defaults
        elif isinstance(loaded[section], dict):
            merged[section] = {**section_defaults, **loaded[section]}
    return merged

cfg = load_cfg()

# Load datasets. Resolve the CSV splits against DATA_DIR so the data location
# is configured in one place instead of implicitly relying on the CWD.
data_files = {
    "train": os.path.join(DATA_DIR, "train (1).csv"),
    "test": os.path.join(DATA_DIR, "test.csv"),
}
raw_datasets = load_dataset("csv", data_files=data_files)
print("Dataset loaded:", raw_datasets)
print("Train sample:", raw_datasets['train'][0])

## 3. Preprocess the Data (with Argument Labelling)

In [2]:
# Load argument predictions: one label per line, consumed in order (one label
# per sentence) by preprocess_with_arguments.
# NOTE(review): assumes the file was produced from the same sentence split of
# the train rows followed by the test rows — verify against the classifier.
preds_path = os.path.join(SRC_DIR, 'argument_classification', 'artifacts', 'legal_bert_predicts.txt')
# Explicit UTF-8: the default encoding is platform-dependent (e.g. cp1252 on Windows).
with open(preds_path, 'r', encoding='utf-8') as f:
    argument_labels = [line.strip() for line in f]

print(f"Loaded {len(argument_labels)} argument labels.")

# Function to preprocess with argument labelling
def preprocess_with_arguments(batch, labels, input_col, summary_col, start_idx=0,
                              keep_labels=('Issue', 'Reason', 'Conclusion'),
                              fallback_count=5, sentence_splitter=None):
    """Filter each document down to its argumentative sentences.

    Sentence labels are consumed sequentially from ``labels``: every sentence
    produced by the splitter advances the cursor by one, across documents.
    ``labels`` therefore has to be aligned with the sentence split of the rows
    in order (NOTE(review): presumably generated with the same
    ``sent_tokenize`` split — verify). Sentences beyond the end of ``labels``
    are treated as non-argumentative.

    Args:
        batch: mapping or Dataset exposing ``batch[input_col]`` and
            ``batch[summary_col]`` as parallel columns.
        labels: per-sentence argument labels.
        input_col: name of the source-text column.
        summary_col: name of the reference-summary column.
        start_idx: index in ``labels`` of the first sentence of ``batch``.
        keep_labels: labels whose sentences are kept.
        fallback_count: number of leading sentences kept when no sentence
            of a document carries a kept label.
        sentence_splitter: callable ``text -> list[str]``; defaults to
            ``nltk.sent_tokenize``.

    Returns:
        ``(records, next_idx)``: a list of ``{input_col: filtered_text,
        summary_col: summary}`` dicts, and the label index just past the last
        consumed sentence (pass it as ``start_idx`` for the next split).
    """
    split = sent_tokenize if sentence_splitter is None else sentence_splitter
    keep = set(keep_labels)
    processed_batch = []
    label_idx = start_idx
    for text, summary in zip(batch[input_col], batch[summary_col]):
        # Missing CSV cells arrive as None (or other non-str); treat them as
        # empty documents instead of crashing the tokenizer.
        sentences = split(text) if isinstance(text, str) else []
        filtered_sentences = []
        for sent in sentences:
            if label_idx < len(labels) and labels[label_idx] in keep:
                filtered_sentences.append(sent)
            label_idx += 1
        # No argumentative sentence found: fall back to the first
        # `fallback_count` sentences (not the whole document).
        if not filtered_sentences:
            filtered_sentences = sentences[:fallback_count]
        processed_batch.append({input_col: " ".join(filtered_sentences), summary_col: summary})
    return processed_batch, label_idx

# Apply to datasets. Labels are consumed sequentially: the train split first,
# then the test split continues from where the train split stopped.
input_col = cfg['data_schema']['input_col']
summary_col = cfg['data_schema']['summary_col']
train_data, train_end_idx = preprocess_with_arguments(raw_datasets['train'], argument_labels, input_col, summary_col)
test_data, test_end_idx = preprocess_with_arguments(raw_datasets['test'], argument_labels, input_col, summary_col, start_idx=train_end_idx)

# Sanity check: if the predictions file was built from the same sentence split
# (NOTE(review): assumed — verify), sentence and label counts should match.
# A mismatch means labels are silently shifted onto the wrong sentences.
if test_end_idx != len(argument_labels):
    print(f"WARNING: consumed {test_end_idx} sentence positions but loaded "
          f"{len(argument_labels)} argument labels; labels may be misaligned.")

train_dataset = Dataset.from_list(train_data)
test_dataset = Dataset.from_list(test_data)

print("Preprocessed train sample:", train_dataset[0])

NameError: name 'SRC_DIR' is not defined

## 4. Tokenize and Prepare Data for Training

In [None]:
# Load tokenizer and model from the same configured checkpoint
checkpoint = cfg['model']['checkpoint']
tokenizer = LEDTokenizer.from_pretrained(checkpoint)
model = LEDForConditionalGeneration.from_pretrained(checkpoint)

# Tokenize function
def tokenize_function(batch):
    """Tokenize a batch of (input, summary) pairs for LED fine-tuning.

    Inputs and targets are truncated and padded to the configured max lengths.
    Padding positions in ``labels`` are replaced with -100, the ignore index
    of the cross-entropy loss; otherwise the model is trained to emit pad
    tokens. A global-attention mask marks only the first token of each input.

    Returns:
        The tokenizer's encoding dict, extended with ``labels`` and
        ``global_attention_mask``.
    """
    inputs = tokenizer(batch[cfg['data_schema']['input_col']], truncation=True, padding='max_length', max_length=cfg['data']['max_input_length'])
    targets = tokenizer(batch[cfg['data_schema']['summary_col']], truncation=True, padding='max_length', max_length=cfg['data']['max_output_length'])
    pad_id = tokenizer.pad_token_id
    # Mask label padding so it does not contribute to the loss.
    inputs['labels'] = [[tok if tok != pad_id else -100 for tok in ids] for ids in targets['input_ids']]
    # Global attention for LED: first token global, all others local
    inputs['global_attention_mask'] = [[1 if i == 0 else 0 for i in range(len(ids))] for ids in inputs['input_ids']]
    return inputs

# Tokenize both splits, dropping the raw text columns so only model inputs remain
raw_text_columns = [cfg['data_schema']['input_col'], cfg['data_schema']['summary_col']]
tokenized_train, tokenized_test = (
    split.map(tokenize_function, batched=True, remove_columns=raw_text_columns)
    for split in (train_dataset, test_dataset)
)

# Data collator: batches the tokenized features for the seq2seq model
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

## 5. Build and Train the Machine Learning Model

In [None]:
# Metrics
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores (scaled to 0-100, 4 decimals) for generated summaries.

    Both arrays may contain -100. Labels contain it from the loss mask, and
    predictions contain it because the Trainer pads generated sequences of
    different lengths across eval batches with -100. Since -100 is not a valid
    token id, it is mapped back to the pad token before decoding.

    Args:
        eval_pred: ``(predictions, labels)`` token-id arrays from the Trainer.

    Returns:
        dict mapping each ROUGE variant to its score.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):  # some setups return (generated_ids, ...)
        predictions = predictions[0]
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    return {k: round(v * 100, 4) for k, v in result.items()}

# Training args
training_args = Seq2SeqTrainingArguments(
    output_dir="./led_argument_model",
    evaluation_strategy="epoch",  # newer transformers releases name this `eval_strategy`
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    predict_with_generate=True,
    # Let summaries generated during evaluation reach the configured target
    # length instead of the model config's default generation length.
    generation_max_length=cfg['data']['max_output_length'],
    # fp16 mixed precision requires a GPU; enabling it unconditionally makes the
    # arguments fail to construct on CPU-only machines.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train
trainer.train()

## 6. Evaluate the Model Performance

In [None]:
# Evaluate on the test split (the trainer's eval_dataset). Because
# load_best_model_at_end=True, the trainer holds the best checkpoint selected
# during training at this point.
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

## 7. Save and Load the Trained Model

In [None]:
# Persist the fine-tuned weights and the tokenizer to one directory so they reload together.
FINAL_MODEL_DIR = "./led_argument_model_final"
trainer.save_model(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

# Reload both from disk for inference
loaded_model = LEDForConditionalGeneration.from_pretrained(FINAL_MODEL_DIR)
loaded_tokenizer = LEDTokenizer.from_pretrained(FINAL_MODEL_DIR)

## 8. Check if Retraining is Needed and Retrain if Necessary

In [None]:
# Check performance threshold
rouge_l_score = eval_results.get('eval_rougeL', 0)
threshold = 50.0  # Example threshold, on the 0-100 scale produced by compute_metrics

if 'eval_rougeL' not in eval_results:
    # Without the metric the comparison below is meaningless; make that visible.
    print("WARNING: 'eval_rougeL' missing from evaluation results; treating ROUGE-L as 0.")

if rouge_l_score < threshold:
    print(f"ROUGE-L score {rouge_l_score} below threshold {threshold}. Retraining...")
    # Retrain with more epochs. Checkpoints go to a separate directory: reusing
    # the first run's output_dir would overwrite its checkpoint-<step> folders
    # and let save_total_limit rotate them out.
    training_args.num_train_epochs = 5
    training_args.output_dir = "./led_argument_model_retrain_checkpoints"
    trainer = Seq2SeqTrainer(
        model=loaded_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        tokenizer=loaded_tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.save_model("./led_argument_model_retrained")
    # Report the retrained model's scores so the retraining can be judged.
    retrain_results = trainer.evaluate()
    print("Retrained Evaluation Results:", retrain_results)
else:
    print(f"Model performance satisfactory: ROUGE-L {rouge_l_score}")