In [8]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset class to load the data
class LegalDataset(Dataset):
    def __init__(self, texts, summaries, tokenizer, max_input_length=512, max_target_length=128):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        input_text = self.texts[idx]
        target_text = self.summaries[idx]

        # Tokenize input and target text
        input_tokens = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        )
        target_tokens = self.tokenizer.encode_plus(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        )

        return {
            'input_ids': input_tokens['input_ids'].flatten(),
            'attention_mask': input_tokens['attention_mask'].flatten(),
            'labels': target_tokens['input_ids'].flatten(),
            'labels_attention_mask': target_tokens['attention_mask'].flatten()
        }
# Load the pre-trained DistilBERT tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

# Load sample input, preprocessed input, and summary from files
sample_input_path = "/Users/abhignanarcot/Downloads/11.txt"
sample_preproc_input_path = "/Users/abhignanarcot/Downloads/11.txt"
sample_summary_path =  "/Users/abhignanarcot/Downloads/11.txt"

with open(sample_input_path, 'r') as file:
    sample_input = file.read()

with open(sample_preproc_input_path, 'r') as file:
    sample_preproc_input = file.read()

with open(sample_summary_path, 'r') as file:
    actual_summary = file.read()

# Create a list of texts and summaries
texts = [sample_preproc_input]
summaries = [actual_summary]

# Create an instance of the custom dataset class
dataset = LegalDataset(texts, summaries, tokenizer)

# Create a data loader with a batch size of 1
batch_size = 1
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define training arguments with reduced batch size and sequence length
training_args = TrainingArguments(
    output_dir='./output',  # Specify the directory where model checkpoints will be saved
    per_device_train_batch_size=1,  # Reduced batch size
    per_device_eval_batch_size=1,   # Reduced batch size
    evaluation_strategy='epoch',
    num_train_epochs=3,
    logging_dir='./logs',
    save_steps=100,
    save_total_limit=2,
)

# Define the trainer with the modified training arguments
trainer = Trainer(
    model=model,
   args=training_args,
    train_dataset=data_loader,
)

# Fine-tune the DistilBERT model on the legal text summarization task
try:
    trainer.train()

    example_legal_document = "/Users/abhignanarcot/Downloads/11.txt"
    encoded_input = tokenizer(example_legal_document, padding=True, truncation=True, return_tensors='pt')
    output = model.generate(**encoded_input)
    summary = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Summary:", summary)
except Exception as e:
    print("An error occurred:", str(e))

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.21.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`