In [1]:
import os
import re
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import (
    EncoderDecoderModel,
    BertTokenizer,
    TrainingArguments,
    Trainer,
)
from rouge_score import rouge_scorer
from bert_score import score as bert_score

In [2]:
# Step 1: Load Data
# Paths to the judgment and summary folders
train_judgment_path = '/Users/praneethvarma/Documents/dataset/IN-Abs/train-data/judgement'
train_summary_path = '/Users/praneethvarma/Documents/dataset/IN-Abs/train-data/summary'
test_judgment_path = '/Users/praneethvarma/Documents/dataset/IN-Abs/test-data/judgement'
test_summary_path = '/Users/praneethvarma/Documents/dataset/IN-Abs/test-data/summary'

# Helper function to sort files numerically
def sort_numerically(file_list):
    return sorted(file_list, key=lambda x: int(re.match(r'(\d+)', x).group()))

# Load and process data
def load_data(judgment_path, summary_path):
    judgment_files = sort_numerically([f for f in os.listdir(judgment_path) if f.endswith('.txt')])
    summary_files = sort_numerically([f for f in os.listdir(summary_path) if f.endswith('.txt')])

    judgment_ids = {re.match(r'(\d+)', f).group() for f in judgment_files}
    summary_ids = {re.match(r'(\d+)', f).group() for f in summary_files}
    common_ids = sorted(judgment_ids.intersection(summary_ids), key=int)

    case_ids, judgment_texts, summary_texts = [], [], []
    for case_id in common_ids:
        j_file = f"{case_id}.txt"
        s_file = f"{case_id}.txt"
        with open(os.path.join(judgment_path, j_file), 'r', encoding='utf-8') as j_f, \
             open(os.path.join(summary_path, s_file), 'r', encoding='utf-8') as s_f:
            case_ids.append(case_id)
            judgment_texts.append(j_f.read())
            summary_texts.append(s_f.read())

    return pd.DataFrame({'case_id': case_ids, 'judgment_text': judgment_texts, 'summary_text': summary_texts})

train_data = load_data(train_judgment_path, train_summary_path)
test_data = load_data(test_judgment_path, test_summary_path)

In [3]:
# Step 2: Preprocess Data
def preprocess_data(tokenizer, text, max_length):
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')

def prepare_data(data, tokenizer, input_max_length, output_max_length):
    input_ids = torch.cat([preprocess_data(tokenizer, txt, input_max_length)['input_ids'] for txt in data['judgment_text']], dim=0)
    attention_mask = torch.cat([preprocess_data(tokenizer, txt, input_max_length)['attention_mask'] for txt in data['judgment_text']], dim=0)
    labels = torch.cat([preprocess_data(tokenizer, txt, output_max_length)['input_ids'] for txt in data['summary_text']], dim=0)
    return input_ids, attention_mask, labels

input_ids, attention_mask, labels = prepare_data(train_data, tokenizer, 512, 128)

# Split training data
train_inputs, eval_inputs, train_masks, eval_masks, train_labels, eval_labels = train_test_split(
    input_ids, attention_mask, labels, test_size=0.2, random_state=42
)

# Define dataset class
class LegalDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, masks, labels):
        self.inputs = inputs
        self.masks = masks
        self.labels = labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {
            "input_ids": self.inputs[idx],
            "attention_mask": self.masks[idx],
            "labels": self.labels[idx],
        }

train_dataset = LegalDataset(train_inputs, train_masks, train_labels)
eval_dataset = LegalDataset(eval_inputs, eval_masks, eval_labels)

In [4]:
from transformers import EncoderDecoderModel, BertTokenizer, TrainingArguments, Trainer
import torch

# Load Legal BERT base model and tokenizer
tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'nlpaueb/legal-bert-base-uncased', 
    'nlpaueb/legal-bert-base-uncased'
)

# Set model configurations
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.no_repeat_ngram_size = 3
model.config.vocab_size = model.config.encoder.vocab_size

# Ensure MPS device compatibility
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model.to(device)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    report_to="none",  # Disable logging to external services
    no_cuda=True,  # Ensure no CUDA (GPU) is used
    bf16=False,    # Disable mixed precision explicitly
    fp16=False,    # Disable fp16 explicitly
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # Replace with actual training dataset
    eval_dataset=eval_dataset,   # Replace with actual evaluation dataset
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_legal_bert")
tokenizer.save_pretrained("./fine_tuned_legal_bert")


Some weights of BertLMHeadModel were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.

Epoch,Training Loss,Validation Loss
1,3.2182,3.003578
2,2.7543,2.825257
3,2.3871,2.786477
4,2.0721,2.808159
5,1.8632,2.879472


  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids

('./fine_tuned_legal_bert/tokenizer_config.json',
 './fine_tuned_legal_bert/special_tokens_map.json',
 './fine_tuned_legal_bert/vocab.txt',
 './fine_tuned_legal_bert/added_tokens.json')

In [6]:
# Step 4: Generate and Evaluate Summaries
def generate_summaries(model, tokenizer, data, max_input_length, max_output_length):
    summaries = []
    for text in data['judgment_text']:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,
            padding="max_length",
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_output_length,  # Fix applied here
            num_beams=10,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )
        summary = tokenizer.decode(output[0], skip_special_tokens=True)
        summaries.append(summary)
    return summaries

predicted_summaries = generate_summaries(model, tokenizer, test_data, 512, 128)

# Post-process summaries
def post_process(summary):
    sentences = summary.split(". ")
    unique_sentences = []
    seen = set()
    for sent in sentences:
        if sent not in seen:
            seen.add(sent)
            unique_sentences.append(sent)
    return ". ".join(unique_sentences)

cleaned_summaries = [post_process(summary) for summary in predicted_summaries]

# Evaluate using ROUGE and BERTScore
true_summaries = [
    summary for summary in test_data['summary_text'].tolist() if summary.strip()
]
filtered_predictions = [
    pred for pred, ref in zip(cleaned_summaries, test_data['summary_text'].tolist())
    if ref.strip()
]

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

def evaluate_rouge(predictions, references):
    scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    for pred, ref in zip(predictions, references):
        score = scorer.score(ref, pred)
        for key in scores:
            scores[key].append(score[key].fmeasure)
    return {key: sum(vals) / len(vals) for key, vals in scores.items()}

rouge_results = evaluate_rouge(filtered_predictions, true_summaries)
print("ROUGE Evaluation Results:", rouge_results)

P, R, F1 = bert_score(filtered_predictions, true_summaries, lang='en', verbose=True)
print("BERTScore Results:", F1.mean().item())

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


ROUGE Evaluation Results: {'rouge1': 0.20743771525383972, 'rouge2': 0.0778424054499269, 'rougeL': 0.13212417771769677}


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/4 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 60.85 seconds, 1.64 sentences/sec
BERTScore Results: 0.8197181820869446
