# Legal Document Classification with BERT - V2 (Full Dataset)

## Part 3: Data Preprocessing and Model Setup

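Prepare the text data for BERT: encode the labels, split the data into training, validation, and test sets, tokenize the documents, and wrap everything in PyTorch datasets and data loaders.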

In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, AutoTokenizer
import pickle

# Set random seed for reproducibility
def set_seed(seed):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator, and
    PyTorch (CPU, plus every CUDA device when CUDA is available).

    Args:
        seed: Integer seed shared by all generators.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
# Read the complete dataset CSV (same Drive location used in Part 1)
dataset_path = '/content/drive/MyDrive/legal_bert_classification_v2/full_bert_dataset.csv'
df = pd.read_csv(dataset_path)

# Quick sanity check: overall size and number of distinct classes
n_rows_cols = df.shape
n_label_classes = df['label'].nunique()
print(f"Dataset shape: {n_rows_cols}")
print(f"Number of unique labels: {n_label_classes}")

In [None]:
# Define dataset class
class LegalDocumentDataset(Dataset):
    """Map-style dataset that tokenizes one document per item access.

    Args:
        texts: Indexable, sized collection of raw document texts.
        labels: Indexable, sized collection of integer class ids, aligned
            with ``texts`` position by position.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')``.
            It must return a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors that carry a leading batch
            dimension of size 1.
        max_length: Token length every example is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    # Rough number of characters kept per token of budget before tokenizing.
    _CHARS_PER_TOKEN = 10

    def __init__(self, texts, labels, tokenizer, max_length=512):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of documents."""
        return len(self.texts)

    @staticmethod
    def _clean_text(text):
        """Return ``text`` as a ``str``; missing values become ``""``."""
        if isinstance(text, str):
            return text
        # Empty CSV cells are read back as float NaN (NaN != NaN); treat
        # them and None as empty documents instead of crashing on len().
        if text is None or (isinstance(text, float) and text != text):
            return ""
        return str(text)

    def __getitem__(self, idx):
        """Tokenize document ``idx`` and return its model inputs.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer output with the batch dimension removed) and
            ``'labels'`` (a scalar ``torch.long`` tensor).
        """
        text = self._clean_text(self.texts[idx])
        label = self.labels[idx]

        # Trim very long texts on a rough character estimate before
        # tokenizing; the tokenizer still truncates to max_length tokens.
        char_budget = self.max_length * self._CHARS_PER_TOKEN
        if len(text) > char_budget:
            text = text[:char_budget]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # squeeze(0) removes only the batch dimension; a bare squeeze()
        # would also collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# Turn the string labels into integer class ids
label_encoder = LabelEncoder()
encoded_values = label_encoder.fit_transform(df['label'])
df['encoded_label'] = encoded_values

# Show which integer id each class name received
print("Label mapping:")
for class_id, class_name in enumerate(label_encoder.classes_):
    print(f"  {class_name} -> {class_id}")

# Persist the fitted encoder to Drive so ids can be mapped back to names later
label_encoder_path = '/content/drive/MyDrive/legal_bert_classification_v2/label_encoder.pkl'
with open(label_encoder_path, 'wb') as encoder_file:
    pickle.dump(label_encoder, encoder_file)

In [None]:
# 80/10/10 split, stratified on the encoded label so each subset keeps the
# class proportions of the full dataset.
# Step 1: hold out 20% of the rows.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['encoded_label'],
)

# Step 2: cut the held-out rows in half -> validation and test.
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['encoded_label'],
)

# Report the size of each subset
for split_name, split_df in (
    ("Training", train_df),
    ("Validation", val_df),
    ("Test", test_df),
):
    print(f"{split_name} set size: {len(split_df)}")

# Report the per-class percentage in each subset to confirm stratification
print("\nLabel distribution in splits:")
for split_header, split_df in (
    ("Training:", train_df),
    ("\nValidation:", val_df),
    ("\nTest:", test_df),
):
    print(split_header)
    label_percentages = split_df['label'].value_counts(normalize=True).sort_index() * 100
    print(label_percentages)

In [None]:
# Fetch the pretrained uncased BERT vocabulary/tokenizer
print("Loading BERT tokenizer...")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def _make_dataset(frame):
    """Wrap one split's text and encoded-label columns in a LegalDocumentDataset."""
    return LegalDocumentDataset(
        frame['text'].values,
        frame['encoded_label'].values,
        tokenizer,
        max_length=512,
    )


# One dataset per split, all padded/truncated to 512 tokens
train_dataset = _make_dataset(train_df)
val_dataset = _make_dataset(val_df)
test_dataset = _make_dataset(test_df)

In [None]:
# Build one DataLoader per split
batch_size = 8  # Smaller batch size for the full dataset

# Training batches are drawn in random order; evaluation batches keep
# the dataset order.
train_sampler = RandomSampler(train_dataset)
val_sampler = SequentialSampler(val_dataset)
test_sampler = SequentialSampler(test_dataset)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler)

# Report how many batches each loader yields per pass
for split_name, loader in (
    ("training", train_loader),
    ("validation", val_loader),
    ("test", test_loader),
):
    print(f"Number of {split_name} batches: {len(loader)}")

## Model Setup

Initialize the BERT model for sequence classification, then set up the AdamW optimizer and a linear learning-rate schedule with warmup.

In [None]:
import torch
from torch.optim import AdamW
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
import os

In [None]:
# Use the GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Make sure the Drive folder for model outputs exists
save_dir = '/content/drive/MyDrive/legal_bert_classification_v2/model'
os.makedirs(save_dir, exist_ok=True)

# One output logit per class known to the label encoder
num_labels = len(label_encoder.classes_)

# Both dropout rates are raised to 0.3 for extra regularization
dropout_overrides = {
    'hidden_dropout_prob': 0.3,
    'attention_probs_dropout_prob': 0.3,
}
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=num_labels,
    **dropout_overrides,
)

# Place the model on the selected device
model.to(device)

# Count only the parameters that will receive gradients
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Model loaded with {num_labels} output classes")
print(f"Model parameters: {trainable_param_count:,}")

In [None]:
# Set up optimizer with weight decay for regularization.
# Biases and LayerNorm weights are put in a separate parameter group with no
# weight decay, the usual practice when fine-tuning BERT: decaying them
# shrinks normalization scales and offsets without a regularization benefit.
no_decay_markers = ("bias", "LayerNorm.weight")
decay_params = []
no_decay_params = []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay_markers):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": 0.01},  # Weight decay for regularization
        {"params": no_decay_params, "weight_decay": 0.0},  # Exempt from decay
    ],
    lr=2e-5,  # Learning rate
    eps=1e-8,  # Epsilon for numerical stability
)

In [None]:
# Linear schedule: the learning rate ramps up during warmup, then decays
# linearly over the remaining steps.
epochs = 4
steps_per_epoch = len(train_loader)  # one optimizer step per batch
total_steps = steps_per_epoch * epochs

# Reserve the first 10% of all steps for warmup
warmup_fraction = 0.1
warmup_steps = int(warmup_fraction * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

print(f"Training for {epochs} epochs with {total_steps} total steps")
print(f"Using {warmup_steps} warmup steps")