# Legal Document Classification with BERT - V2 (Full Dataset)

## Part 4: Model Training

Train the BERT model on the full 45K-document dataset, using early stopping, weight decay, learning-rate warmup, and gradient clipping to limit overfitting and keep training stable.

In [None]:
import torch
from torch.optim import AdamW
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import numpy as np
import os
import datetime

In [None]:
# Define training function with early stopping
def _train_one_epoch(model, train_loader, optimizer, scheduler, scaler,
                     device, accumulation_steps, use_amp):
    """Run one training epoch and return the mean per-batch training loss."""
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    progress_bar = tqdm(train_loader, desc="Training")
    optimizer.zero_grad()

    for step, batch in enumerate(progress_bar):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (mixed precision only when use_amp is set)
        with autocast(enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Divide so that the gradients summed over one accumulation
            # window correspond to the mean loss of that window.
            loss = outputs.loss / accumulation_steps

        # Backward pass with gradient scaling (a pass-through when the
        # scaler is disabled)
        scaler.scale(loss).backward()

        # Undo the division above so the reported loss is per batch
        batch_loss = loss.item() * accumulation_steps
        running_loss += batch_loss

        # Update weights at the end of every accumulation window and on the
        # last batch, so a trailing partial window is not dropped.
        if (step + 1) % accumulation_steps == 0 or (step + 1) == num_batches:
            # Gradients must be unscaled before clipping by norm
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # The scaler lowers its scale exactly when it skipped the
            # optimizer step because of inf/NaN gradients; only advance the
            # LR schedule when the weights were really updated.
            if scaler.get_scale() >= scale_before:
                scheduler.step()
            optimizer.zero_grad()

        progress_bar.set_postfix({"loss": batch_loss})

    return running_loss / num_batches


def _evaluate(model, val_loader, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over ``val_loader``."""
    model.eval()
    running_loss = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            running_loss += outputs.loss.item()

            # Predicted class = index of the largest logit
            preds = outputs.logits.argmax(dim=1)
            # Flatten labels so the comparison is element-wise
            correct += (preds == labels.view(-1)).sum().item()
            seen += preds.size(0)

    return running_loss / len(val_loader), correct / seen


def train_model(model, train_loader, val_loader, device, epochs=4,
                learning_rate=2e-5, warmup_steps=0, weight_decay=0.01,
                early_stopping_patience=2, save_dir=None,
                accumulation_steps=2):
    """Train the model with early stopping and gradient accumulation.

    After every epoch the model is evaluated on ``val_loader``, a full
    checkpoint is written, and the weights are additionally stored as
    ``best_model.pt`` whenever validation accuracy improves. Training stops
    early once accuracy has failed to improve for
    ``early_stopping_patience`` consecutive epochs.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        train_loader: Loader yielding dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        val_loader: Loader with the same batch layout, used for validation.
        device: Device (``torch.device`` or string) the batches are moved to.
            Mixed precision is enabled only when this is a CUDA device.
        epochs: Maximum number of epochs.
        learning_rate: Peak learning rate for AdamW.
        warmup_steps: Number of scheduler steps (i.e. optimizer updates, not
            batches) spent on linear warmup.
        weight_decay: AdamW weight decay.
        early_stopping_patience: Epochs without improvement tolerated before
            stopping.
        save_dir: Output directory; created if missing. When falsy, a
            timestamped directory under the Drive project folder is used.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer update.

    Returns:
        ``(model, history)``. ``model`` holds the weights of the last epoch
        that ran (the best weights are in ``best_model.pt``); ``history`` is
        a dict with per-epoch ``'train_losses'``, ``'val_losses'`` and
        ``'val_accuracies'`` lists.

    Raises:
        ValueError: If ``accumulation_steps`` is below 1 or either loader
            yields no batches.
    """
    import json  # local import: only needed to persist the training history

    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    if not save_dir:
        save_dir = f"/content/drive/MyDrive/legal_bert_classification_v2/model_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
    # Create the directory for caller-supplied paths too, not only the default
    os.makedirs(save_dir, exist_ok=True)

    # Prepare optimizer
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)

    # The scheduler advances once per optimizer update, so the schedule
    # length is counted in updates (ceil division keeps the trailing partial
    # accumulation window of each epoch).
    updates_per_epoch = -(-len(train_loader) // accumulation_steps)
    total_steps = updates_per_epoch * epochs

    # Prepare scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Mixed precision training (disabled off-GPU, where the scaler and
    # autocast then act as pass-throughs)
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    # Track training metrics
    train_losses = []
    val_losses = []
    val_accuracies = []
    # -inf guarantees the first epoch always produces a best_model.pt
    best_val_accuracy = float("-inf")
    best_model_path = os.path.join(save_dir, "best_model.pt")

    # Early stopping variables
    no_improvement_count = 0

    # Read the configured batch size; only fall back to drawing a batch when
    # the loader does not expose one (e.g. it was built with a batch sampler).
    batch_size = getattr(train_loader, "batch_size", None)
    if batch_size is None:
        batch_size = next(iter(train_loader))['input_ids'].shape[0]

    # Save config for reproducibility
    with open(os.path.join(save_dir, "training_config.txt"), 'w', encoding="utf-8") as f:
        f.write(f"Epochs: {epochs}\n")
        f.write(f"Learning rate: {learning_rate}\n")
        f.write(f"Weight decay: {weight_decay}\n")
        f.write(f"Warmup steps: {warmup_steps}\n")
        f.write(f"Batch size: {batch_size}\n")
        f.write(f"Gradient accumulation steps: {accumulation_steps}\n")
        f.write(f"Total optimizer updates: {total_steps}\n")
        f.write(f"Early stopping patience: {early_stopping_patience}\n")
        f.write(f"Training samples: {len(train_loader.dataset)}\n")
        f.write(f"Validation samples: {len(val_loader.dataset)}\n")
        f.write(f"Device: {device}\n")

    # Training loop
    for epoch in range(epochs):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"{'='*50}")

        # Track time (covers both the training and the validation pass)
        epoch_start_time = time.time()

        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, scheduler, scaler,
            device, accumulation_steps, use_amp
        )
        train_losses.append(avg_train_loss)

        avg_val_loss, accuracy = _evaluate(model, val_loader, device)
        val_losses.append(avg_val_loss)
        val_accuracies.append(accuracy)

        # Print epoch summary
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.4f}")

        # Save checkpoint to Drive
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_accuracy': accuracy
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Save best model
        if accuracy > best_val_accuracy:
            best_val_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with accuracy: {accuracy:.4f}")
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            print(f"No improvement for {no_improvement_count} epochs. Best accuracy: {best_val_accuracy:.4f}")

        # Early stopping
        if no_improvement_count >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Final save of the model (weights of the last epoch that ran)
    model_path = os.path.join(save_dir, "final_model.pt")
    torch.save(model.state_dict(), model_path)
    print(f"Final model saved to {model_path}")

    # Save training history next to the checkpoints and return it
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }
    with open(os.path.join(save_dir, "training_history.json"), 'w', encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    return model, history

In [None]:
# Pick the compute device: use CUDA when PyTorch can see a GPU, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Drive folder handed to train_model below as its output directory
save_dir = '/content/drive/MyDrive/legal_bert_classification_v2/model'
os.makedirs(name=save_dir, exist_ok=True)  # no error if the folder already exists

# `model`, `train_loader` and `val_loader` are expected to exist already (built in Part 3)

In [None]:
# Train the model
#
# train_model accumulates gradients over 2 batches per optimizer update and
# steps the LR scheduler once per update (including one extra update for a
# trailing odd batch). Warmup is therefore sized in optimizer updates, not in
# batches; sizing it in batches would make the warmup twice as long as
# intended.
num_epochs = 4
batches_per_update = 2  # gradient-accumulation window used by train_model
updates_per_epoch = -(-len(train_loader) // batches_per_update)  # ceil division
total_updates = updates_per_epoch * num_epochs

trained_model, history = train_model(
    model=model,  # From Part 3
    train_loader=train_loader,  # From Part 3
    val_loader=val_loader,  # From Part 3
    device=device,
    epochs=num_epochs,
    learning_rate=2e-5,
    warmup_steps=int(0.1 * total_updates),  # 10% of all optimizer updates
    weight_decay=0.01,
    early_stopping_patience=2,
    save_dir=save_dir
)

In [None]:
# Plot training history
#
# Epochs are numbered from 1 so the x-axis matches the epoch numbers printed
# during training and used in the checkpoint file names. Early stopping may
# leave fewer entries than the requested number of epochs, so the range is
# derived from the recorded history itself.
epoch_numbers = list(range(1, len(history['train_losses']) + 1))

plt.figure(figsize=(16, 6))

# Left panel: training vs. validation loss
plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, history['train_losses'], label='Train Loss')
plt.plot(epoch_numbers, history['val_losses'], label='Validation Loss')
plt.title('Loss Curves', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.xticks(epoch_numbers)  # whole-epoch ticks only
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: validation accuracy
plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, history['val_accuracies'], marker='o', linestyle='-', color='green')
plt.title('Validation Accuracy', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.xticks(epoch_numbers)  # whole-epoch ticks only
plt.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): the figure is written while it is still the current one
plt.savefig('/content/drive/MyDrive/legal_bert_classification_v2/training_history.png')
plt.show()

In [None]:
# Export model for local deployment

# Save model to Drive in transformers format
output_dir = "/content/drive/MyDrive/legal_bert_classification_v2/final_model"
os.makedirs(output_dir, exist_ok=True)

# Load best model.
# map_location="cpu" lets the checkpoint be read on a runtime without a GPU
# even if it was written from one, and avoids holding a second copy of the
# weights on the GPU; load_state_dict then copies the values into the model's
# existing parameters on whatever device the model lives on.
best_model_path = os.path.join(save_dir, "best_model.pt")
best_state_dict = torch.load(best_model_path, map_location="cpu")
model.load_state_dict(best_state_dict)

# Save using Hugging Face's save_pretrained method
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model and tokenizer saved to {output_dir}")
print("Now you can download the model folder for local inference.")