# Legal Document Classification with BERT

## Part 5: Data Preprocessing

Prepare the text data for BERT by tokenizing and creating PyTorch datasets.

In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, AutoTokenizer
import pickle

# Set random seed for reproducibility
def set_seed(seed):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch's CPU generator; when CUDA is available, all CUDA device
    generators are seeded as well.

    Args:
        seed: Integer seed applied to each generator.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
# Define dataset class
class LegalDocumentDataset(Dataset):
    """Map-style dataset that tokenizes one document per ``__getitem__`` call.

    Args:
        texts: Indexable collection of raw document strings.
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=...,
            padding=..., max_length=..., return_tensors='pt')`` whose result
            exposes ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sequence is truncated/padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of documents."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize document ``idx`` and return its model inputs.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` (the tokenizer
            output with its leading dimension removed) and ``'labels'`` (a
            scalar ``torch.long`` tensor).
        """
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the leading (single-example) dimension. A bare squeeze()
        # would also collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# Stratified subsample of the dataset to reduce training time (optional).
# This cell is active as written: comment it out (or raise sample_size)
# to train on the full dataset.
sample_size = 20000
if len(df) > sample_size:
    total_rows = len(df)
    # Draw from each label group in proportion to its share of the data.
    # int() floors the per-class quota, so very rare classes may get 0 rows.
    # Groups are sampled one by one and concatenated instead of using
    # groupby(...).apply, which newer pandas deprecates when the applied
    # function operates on the grouping column. A fixed random_state keeps
    # the subset identical across re-runs of this cell.
    df = pd.concat(
        [
            group.sample(
                n=min(len(group), int(sample_size * len(group) / total_rows)),
                random_state=42,
            )
            for _, group in df.groupby('label')
        ]
    )
    print(f"Subsampled dataset to {len(df)} examples")

In [None]:
# Turn the string labels into integer class ids, stored in a new column
label_encoder = LabelEncoder()
df['encoded_label'] = label_encoder.fit_transform(df['label'])

# Show which integer id each class name was assigned
print("Label mapping:")
for class_id, class_name in enumerate(label_encoder.classes_):
    print(f"  {class_name} -> {class_id}")

# Persist the fitted encoder so predictions can be mapped back to class names
encoder_path = '/content/drive/MyDrive/legal_bert_classification/label_encoder.pkl'
with open(encoder_path, 'wb') as encoder_file:
    pickle.dump(label_encoder, encoder_file)

In [None]:
# Hold out 10% of the rows for validation, stratified on the encoded label
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    stratify=df['encoded_label'],
    random_state=42,
)

# Report the size of each split
for split_name, split_df in (("Training", train_df), ("Validation", val_df)):
    print(f"{split_name} set size: {len(split_df)}")

In [None]:
# Load the tokenizer that matches the pretrained checkpoint
print("Loading BERT tokenizer...")
pretrained_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
# Alternative option: use a legal domain BERT model
# tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')

In [None]:
# Wrap each split in a LegalDocumentDataset (same tokenizer, same max length)
train_dataset, val_dataset = (
    LegalDocumentDataset(
        split_df['text'].values,
        split_df['encoded_label'].values,
        tokenizer,
        max_length=512,
    )
    for split_df in (train_df, val_df)
)

In [None]:
# Build the batch iterators
batch_size = 16  # Adjust based on your GPU memory

# Training batches are drawn in random order, validation batches in dataset order
train_sampler = RandomSampler(train_dataset)
val_sampler = SequentialSampler(val_dataset)

train_loader = DataLoader(
    train_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=2
)
val_loader = DataLoader(
    val_dataset, sampler=val_sampler, batch_size=batch_size, num_workers=2
)

In [None]:
# Pull one training batch and report the shape of each tensor in it
batch = next(iter(train_loader))
print("Sample batch inspection:")
for field_title, field_key in (
    ("Input IDs", 'input_ids'),
    ("Attention mask", 'attention_mask'),
    ("Labels", 'labels'),
):
    print(f"{field_title} shape: {batch[field_key].shape}")

# Decode one example back to text as a sanity check (first 100 characters only)
sample_idx = 0
decoded_text = tokenizer.decode(batch['input_ids'][sample_idx])
print("\nDecoded sample text (truncated):")
print(decoded_text[:100] + "...")

## Part 6: Model Training

Train the BERT model on our legal document dataset.

In [None]:
import torch
from torch.optim import AdamW
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import os

In [None]:
# Define training function
def _evaluate(model, data_loader, device, desc):
    """Run one gradient-free pass over ``data_loader``; return (mean loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            total_loss += outputs.loss.item()

            # Predicted class = index of the largest logit
            preds = outputs.logits.argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)

    # Guard against an empty loader instead of dividing by zero
    n_batches = len(data_loader)
    avg_loss = total_loss / n_batches if n_batches else float('nan')
    accuracy = n_correct / n_seen if n_seen else 0.0
    return avg_loss, accuracy


def train_model(model, train_loader, val_loader, device, epochs=4,
                learning_rate=2e-5, warmup_steps=0, weight_decay=0.01,
                save_dir='/content/drive/MyDrive/legal_bert_classification'):
    """Fine-tune ``model`` and validate it after every epoch.

    Each epoch trains over ``train_loader`` with AdamW, a linear
    warmup/decay schedule, gradient clipping (max norm 1.0) and, on CUDA
    devices, mixed precision. It then evaluates on ``val_loader``, writes
    ``checkpoint_epoch_<n>.pt`` to ``save_dir`` and overwrites
    ``best_model.pt`` whenever validation accuracy improves.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        train_loader: Iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        val_loader: Same batch format, used for validation.
        device: Device (string or ``torch.device``) batches are moved to.
        epochs: Number of passes over ``train_loader``.
        learning_rate: AdamW learning rate.
        warmup_steps: Number of scheduler warmup steps.
        weight_decay: Decay applied to all parameters except biases and
            LayerNorm weights.
        save_dir: Directory for checkpoints; created if missing.

    Returns:
        Tuple ``(model, history)``. ``model`` holds the weights from the
        final epoch (the best weights are only written to disk). ``history``
        is a dict with per-epoch lists ``'train_losses'``, ``'val_losses'``
        and ``'val_accuracies'``.
    """
    # Create the output directory up front so a missing folder fails fast
    # rather than after a full epoch of training.
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pt")

    # Prepare optimizer: no weight decay for biases and LayerNorm weights
    no_decay = ['bias', 'LayerNorm.weight']
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # The scheduler is stepped once per batch, so it spans every batch of every epoch
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Mixed precision only makes sense on CUDA; on other devices autocast and
    # the scaler are explicitly disabled and behave as pass-throughs.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    # Track training metrics
    train_losses = []
    val_losses = []
    val_accuracies = []
    # Start below any achievable accuracy so the first epoch always writes
    # best_model.pt, even if its accuracy is exactly 0.
    best_val_accuracy = -1.0

    # Training loop
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        epoch_start_time = time.time()

        # Training phase
        model.train()
        total_train_loss = 0

        progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")

        for batch in progress_bar:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass (mixed precision when enabled)
            with autocast(enabled=use_amp):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss

            # Backward pass with gradient scaling; unscale before clipping so
            # the norm threshold applies to the true gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # The scaler skips optimizer.step() when it finds inf/NaN
            # gradients and then lowers its scale. Only advance the LR
            # schedule when the optimizer actually stepped.
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            if scaler.get_scale() >= scale_before:
                scheduler.step()

            # Read the loss once per step (each .item() forces a device sync)
            loss_value = loss.item()
            total_train_loss += loss_value
            progress_bar.set_postfix({'loss': loss_value})

        # Calculate average loss
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        avg_val_loss, accuracy = _evaluate(
            model, val_loader, device, desc=f"Validating Epoch {epoch+1}"
        )
        val_losses.append(avg_val_loss)
        val_accuracies.append(accuracy)

        # Print epoch summary
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.4f}")

        # Save a full checkpoint; scaler state is included so a resumed run
        # can restore it along with the optimizer and scheduler.
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_accuracy': accuracy
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Save best model (weights only)
        if accuracy > best_val_accuracy:
            best_val_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with accuracy: {accuracy:.4f}")

    # Collect training history for the caller (nothing is written to disk here)
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }

    return model, history