# Legal Text Classification with BERT

In this notebook, we will build a legal text classification model using BERT. We will preprocess the data, create a custom dataset class with a sliding window approach, and train the model.

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
import torch.optim as optim
from tqdm import tqdm

# Load and Preprocess the Data

In [2]:
dataset_path = '/kaggle/input/legal-text-classification-dataset/legal_text_classification.csv'
df = pd.read_csv(dataset_path)

df.dropna(inplace=True)

label_encoder = LabelEncoder()
df['case_outcome_encoded'] = label_encoder.fit_transform(df['case_outcome'])

class_counts = df['case_outcome_encoded'].value_counts()
print("Class distribution in training data:\n", class_counts)

Class distribution in training data:
 case_outcome_encoded
3    12110
8     4363
1     2438
7     2252
4     1699
5     1018
6      603
9      112
2      108
0      106
Name: count, dtype: int64


# Train-Validation Split

In [3]:
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['case_outcome_encoded'], random_state=42)

majority_class_size = train_df['case_outcome_encoded'].value_counts().max()
train_df_balanced = train_df.groupby('case_outcome_encoded', group_keys=False) \
                            .apply(lambda x: x.sample(majority_class_size, replace=True)).reset_index(drop=True)

balanced_class_counts = train_df_balanced['case_outcome_encoded'].value_counts()
print("Class distribution after oversampling:\n", balanced_class_counts)

Class distribution after oversampling:
 case_outcome_encoded
0    9688
1    9688
2    9688
3    9688
4    9688
5    9688
6    9688
7    9688
8    9688
9    9688
Name: count, dtype: int64


  .apply(lambda x: x.sample(majority_class_size, replace=True)).reset_index(drop=True)


# Custom Dataset Class with Sliding Window

In [4]:
class LegalDataset(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        case_text = str(self.data.iloc[index]['case_text'])
        case_outcome = self.data.iloc[index]['case_outcome_encoded']

        encoding = self.tokenizer.encode_plus(
            case_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(case_outcome, dtype=torch.long)
        }

# Model and Tokenizer Initialization

In [5]:
legalbert_model_name = 'nlpaueb/legal-bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(legalbert_model_name)
model = BertForSequenceClassification.from_pretrained(legalbert_model_name, num_labels=len(label_encoder.classes_))

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Training Setup

In [6]:
max_length = 256
stride = 64
batch_size = 16
epochs = 6
learning_rate = 2e-5
weight_decay = 0.01

# Create Datasets and DataLoaders

In [7]:
train_dataset = LegalDataset(train_df_balanced, tokenizer, max_length, stride)
val_dataset = LegalDataset(val_df, tokenizer, max_length, stride)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Compute Class Weights

In [8]:
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_df_balanced['case_outcome_encoded']), y=train_df_balanced['case_outcome_encoded'])
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Define Loss Function and Optimizer

In [9]:
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Training Scheduler

In [10]:
total_steps = len(train_loader) * epochs
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=total_steps)

# Training Function

In [11]:
def train_epoch(model, data_loader, optimizer, device, scheduler):
    model = model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        labels = d['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        preds = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()

    return correct_predictions / len(data_loader.dataset), np.mean(losses)

# Evaluation Function

In [12]:
def eval_model(model, data_loader, device):
    model = model.eval()
    losses = []
    correct_predictions = 0
    y_preds = []
    y_true = []

    with torch.no_grad():
        for d in tqdm(data_loader):
            input_ids = d['input_ids'].to(device)
            attention_mask = d['attention_mask'].to(device)
            labels = d['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            losses.append(loss.item())
            preds = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(preds == labels).item()
            y_preds.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    return correct_predictions / len(data_loader.dataset), np.mean(losses), y_true, y_preds

# Training Loop

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for epoch in range(epochs):
    print(f'\nEpoch {epoch + 1}/{epochs}')
    train_acc, train_loss = train_epoch(model, train_loader, optimizer, device, scheduler)
    print(f'Train loss: {train_loss:.3f}, accuracy: {train_acc:.3f}')
    val_acc, val_loss, y_true, y_preds = eval_model(model, val_loader, device)
    print(f'Validation loss: {val_loss:.3f}, accuracy: {val_acc:.3f}')


Epoch 1/6


100%|██████████| 6055/6055 [1:14:27<00:00,  1.36it/s]


Train loss: 1.037, accuracy: 0.623


100%|██████████| 311/311 [02:00<00:00,  2.59it/s]


Validation loss: 1.754, accuracy: 0.356

Epoch 2/6


100%|██████████| 6055/6055 [1:14:24<00:00,  1.36it/s]


Train loss: 0.430, accuracy: 0.855


100%|██████████| 311/311 [01:59<00:00,  2.60it/s]


Validation loss: 1.853, accuracy: 0.454

Epoch 3/6


100%|██████████| 6055/6055 [1:14:15<00:00,  1.36it/s]


Train loss: 0.265, accuracy: 0.916


100%|██████████| 311/311 [01:59<00:00,  2.60it/s]


Validation loss: 1.959, accuracy: 0.504

Epoch 4/6


100%|██████████| 6055/6055 [1:14:16<00:00,  1.36it/s]


Train loss: 0.207, accuracy: 0.933


100%|██████████| 311/311 [01:59<00:00,  2.60it/s]


Validation loss: 1.978, accuracy: 0.527

Epoch 5/6


100%|██████████| 6055/6055 [1:14:03<00:00,  1.36it/s]


Train loss: 0.178, accuracy: 0.941


100%|██████████| 311/311 [01:58<00:00,  2.62it/s]


Validation loss: 2.123, accuracy: 0.542

Epoch 6/6


100%|██████████| 6055/6055 [1:14:37<00:00,  1.35it/s]


Train loss: 0.161, accuracy: 0.945


100%|██████████| 311/311 [01:59<00:00,  2.59it/s]

Validation loss: 2.340, accuracy: 0.516





# Classification Report

In [14]:
print(classification_report(y_true, y_preds, target_names=label_encoder.classes_))

               precision    recall  f1-score   support

     affirmed       0.40      0.76      0.52        21
      applied       0.34      0.33      0.34       488
     approved       0.18      0.19      0.19        21
        cited       0.76      0.56      0.65      2422
   considered       0.26      0.42      0.32       340
    discussed       0.36      0.28      0.31       204
distinguished       0.45      0.39      0.42       121
     followed       0.44      0.40      0.42       450
  referred to       0.42      0.66      0.51       873
      related       0.53      0.45      0.49        22

     accuracy                           0.52      4962
    macro avg       0.41      0.44      0.42      4962
 weighted avg       0.56      0.52      0.53      4962

