# Legal Document Classification with BERT

## Part 7: Model Training Execution

Initialize and train the BERT model on our legal document dataset.

In [None]:
# Load the pre-trained BERT encoder with a sequence-classification head sized to
# the number of document classes known to the label encoder.
from transformers import BertForSequenceClassification
import torch

# Swap in 'nlpaueb/legal-bert-base-uncased' to start from Legal-BERT instead.
base_checkpoint = 'bert-base-uncased'

print("Loading pre-trained BERT model...")
model = BertForSequenceClassification.from_pretrained(
    base_checkpoint,
    num_labels=len(label_encoder.classes_),
    output_attentions=False,
    output_hidden_states=False,
)

# Place the weights on the compute device chosen earlier in the notebook.
model.to(device)

In [None]:
# Hyper-parameters passed to train_model below.
learning_rate = 2e-5
weight_decay = 0.01
warmup_steps = 0   # warm-up steps forwarded to train_model
epochs = 4         # raise this if validation loss is still improving

# Output directory for checkpoints, plots and history (a path under
# /content/drive, presumably a mounted Google Drive). Created if missing.
import os

save_dir = '/content/drive/MyDrive/legal_bert_classification'
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Fine-tune the model. train_model returns the trained model and a history dict
# whose 'train_losses', 'val_losses' and 'val_accuracies' entries are plotted below.
training_config = dict(
    epochs=epochs,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    save_dir=save_dir,
)
trained_model, history = train_model(
    model, train_loader, val_loader, device, **training_config
)

In [None]:
# Persist every artifact needed to reuse the classifier later.
import pickle

# Save the final model and tokenizer in Hugging Face format (reloadable with
# from_pretrained). NOTE(review): this saves `model`, not `trained_model`;
# assumes train_model updates `model` in place — TODO confirm.
model_save_path = os.path.join(save_dir, 'final_model')
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"Final model saved to {model_save_path}")

# Save the fitted label encoder next to the model. load_model_for_inference
# reads label_encoder.pkl from this directory to map predicted class ids back
# to document-type names, so it must exist alongside final_model/.
with open(os.path.join(save_dir, 'label_encoder.pkl'), 'wb') as f:
    pickle.dump(label_encoder, f)

# Save training history (per-epoch losses / accuracies) for later plotting.
with open(os.path.join(save_dir, 'training_history.pkl'), 'wb') as f:
    pickle.dump(history, f)

In [None]:
# Plot loss curves (train vs. validation) next to validation accuracy, one
# point per epoch, and save the figure alongside the other artifacts.
import matplotlib.pyplot as plt

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(history['train_losses'], label='Train Loss')
loss_ax.plot(history['val_losses'], label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Training and Validation Loss')

acc_ax.plot(history['val_accuracies'], label='Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.set_title('Validation Accuracy')

fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'training_history.png'))
plt.show()

## Part 8: Model Evaluation

Evaluate the trained model's performance on the validation set.

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
def evaluate_model(model, dataloader, device, label_names):
    """Run ``model`` over ``dataloader`` and compute classification metrics.

    Args:
        model: Sequence-classification model whose forward output exposes
            ``.logits`` (e.g. ``BertForSequenceClassification``).
        dataloader: Iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: torch device the model inputs are moved to.
        label_names: Class names indexed by encoded label id
            (e.g. ``label_encoder.classes_``).

    Returns:
        Tuple ``(all_preds, all_labels, report, cm)``:
        predicted ids, true ids (both plain Python lists), the text
        classification report, and a ``len(label_names) x len(label_names)``
        confusion matrix (rows = true label, columns = predicted label).
    """
    model.eval()
    all_preds = []
    all_labels = []

    # Give sklearn the full id range explicitly. Without it, a class missing
    # from this split makes classification_report raise (target_names length
    # mismatch) and shrinks the confusion matrix so it no longer lines up with
    # per-class tick labels.
    label_ids = list(range(len(label_names)))

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # Predicted class = index of the largest logit.
            _, preds = torch.max(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().tolist())
            # Labels are only compared on the host, so skip the device copy.
            all_labels.extend(batch['labels'].cpu().tolist())

    # Per-class precision/recall/F1; zero_division=0 keeps the default numeric
    # result for classes that were never predicted but silences the warning.
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=label_names,
        digits=4,
        zero_division=0,
    )

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    return all_preds, all_labels, report, cm

In [None]:
# Load the best checkpoint (a state dict written to save_dir during training)
# and evaluate it on the validation set.
# map_location remaps tensors onto the current device, so a checkpoint saved
# on a GPU can still be restored on a CPU-only runtime (and vice versa).
best_model_path = os.path.join(save_dir, 'best_model.pt')
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)

# Run evaluation
predictions, true_labels, report, cm = evaluate_model(
    model, val_loader, device, label_encoder.classes_
)

In [None]:
# Show the per-class precision / recall / F1 table under a heading.
print("Classification Report:", report, sep="\n")

In [None]:
# Heatmap of the confusion matrix: rows are true labels, columns are predicted
# labels, each cell annotated with its integer count.
tick_labels = label_encoder.classes_

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_labels,
    yticklabels=tick_labels,
    ax=ax,
)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
plt.show()

In [None]:
# Per-class accuracy: the fraction of each class's validation samples that were
# predicted correctly (i.e. that class's recall). A class with no validation
# samples is reported as 0.
class_accuracy = {}
for class_idx, class_name in enumerate(label_encoder.classes_):
    n_samples = sum(1 for t in true_labels if t == class_idx)
    n_correct = sum(
        1 for p, t in zip(predictions, true_labels) if t == class_idx and p == t
    )
    class_accuracy[class_name] = n_correct / n_samples if n_samples > 0 else 0

# Bar chart of the per-class scores, each bar annotated with its exact value.
class_names = list(class_accuracy.keys())
class_scores = list(class_accuracy.values())

plt.figure(figsize=(10, 6))
sns.barplot(x=class_names, y=class_scores)
plt.title('Accuracy by Document Type')
plt.xlabel('Document Type')
plt.ylabel('Accuracy')
plt.ylim(0, 1.0)
for bar_pos, score in enumerate(class_scores):
    plt.text(bar_pos, score + 0.01, f'{score:.4f}', ha='center')
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'class_accuracy.png'))
plt.show()

## Part 9: Inference with New Data

Create functions to use the trained model for classifying new legal documents.

In [None]:
def predict_document_type(text, model, tokenizer, label_encoder, device, max_length=512):
    """Classify a single document and report its per-class probabilities.

    Args:
        text: Raw document text.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``; called with truncation and
            fixed-length padding, returning PyTorch tensors.
        label_encoder: Fitted encoder providing ``classes_`` and
            ``inverse_transform`` to map class ids back to names.
        device: torch device to run inference on.
        max_length: Token length the input is truncated / padded to.

    Returns:
        dict with ``'predicted_label'`` (top class name), ``'confidence'``
        (its softmax probability) and ``'probabilities'`` (class name ->
        softmax probability, in ``label_encoder.classes_`` order).
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    model_inputs = {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }

    # Inference mode: no dropout, no autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        logits = model(**model_inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1)
    top_prob, top_idx = probs.max(dim=1)
    best_label = label_encoder.inverse_transform([top_idx.item()])[0]

    # Only one document was encoded, so row 0 holds its class distribution.
    per_class = {}
    for class_name, class_prob in zip(label_encoder.classes_, probs[0].cpu().numpy()):
        per_class[class_name] = class_prob.item()

    return {
        'predicted_label': best_label,
        'confidence': top_prob.item(),
        'probabilities': per_class,
    }

In [None]:
# Example usage: classify the opening of the GDPR (Regulation (EU) 2016/679).
sample_text = """REGULATION (EU) 2016/679 OF THE EUROPEAN PARLIAMENT AND OF THE COUNCIL
of 27 April 2016
on the protection of natural persons with regard to the processing of personal data and on the free movement of such data, and repealing Directive 95/46/EC (General Data Protection Regulation)
(Text with EEA relevance)
THE EUROPEAN PARLIAMENT AND THE COUNCIL OF THE EUROPEAN UNION,
Having regard to the Treaty on the Functioning of the European Union, and in particular Article 16 thereof,
Having regard to the proposal from the European Commission,
After transmission of the draft legislative act to the national parliaments,
Having regard to the opinion of the European Economic and Social Committee,
Having regard to the opinion of the Committee of the Regions,
Acting in accordance with the ordinary legislative procedure,"""

prediction = predict_document_type(sample_text, model, tokenizer, label_encoder, device)

# Assemble the whole summary first, then print it in one go.
summary_lines = [
    f"Predicted document type: {prediction['predicted_label']}",
    f"Confidence: {prediction['confidence']:.4f}",
    "\nProbability for each class:",
]
summary_lines.extend(
    f"  {label}: {prob:.4f}" for label, prob in prediction['probabilities'].items()
)
print("\n".join(summary_lines))

## Part 10: Save and Load the Model

Instructions for saving and loading the model for future use.

In [None]:
# Save the model and tokenizer in Hugging Face format under final_model/.
# NOTE: if the Part 8 evaluation cells were run, `model` now holds the weights
# loaded from best_model.pt, so this overwrites the earlier final_model save
# with the best checkpoint rather than repeating that save unchanged.
model_save_path = os.path.join(save_dir, 'final_model')
for component in (model, tokenizer):
    component.save_pretrained(model_save_path)
print(f"Final model saved to {model_save_path}")

In [None]:
# Example of how to load the model later
from transformers import BertForSequenceClassification, BertTokenizer
import pickle

def load_model_for_inference(save_dir='/content/drive/MyDrive/legal_bert_classification'):
    """Load the fine-tuned classifier, its tokenizer and label encoder from disk.

    Args:
        save_dir: Directory the training notebook wrote its artifacts to. It
            must contain ``label_encoder.pkl`` and a ``final_model/`` directory
            written by ``save_pretrained``. Defaults to the directory used
            during training.

    Returns:
        Tuple ``(model, tokenizer, label_encoder, device)``. The model is moved
        to ``device`` (CUDA when available, otherwise CPU) and set to eval mode.
    """
    # NOTE: pickle.load can execute arbitrary code. Only load encoder files
    # written by this notebook, never ones from an untrusted source.
    with open(os.path.join(save_dir, 'label_encoder.pkl'), 'rb') as f:
        label_encoder = pickle.load(f)

    # Load model and tokenizer
    model_path = os.path.join(save_dir, 'final_model')
    model = BertForSequenceClassification.from_pretrained(model_path)
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: disable dropout.
    model.eval()

    return model, tokenizer, label_encoder, device


# Uncomment to test loading and using the model
# loaded_model, loaded_tokenizer, loaded_encoder, loaded_device = load_model_for_inference()
# prediction = predict_document_type("Your text here", loaded_model, loaded_tokenizer, loaded_encoder, loaded_device)

## Comparison with Previous Approach

Finally, let's compare the BERT-based approach with the previous concept-code-based approach:

1. **Data Source**:
   - Previous: Used concept codes (e.g., "1086, 1196, 2002")
   - BERT: Uses header + recitals text

2. **Model Complexity**:
   - Previous: Simple bag-of-words with LogisticRegression
   - BERT: Deep transformer model with roughly 110 million parameters (BERT-base)

3. **Resource Requirements**:
   - Previous: Low (can run on CPU)
   - BERT: High (requires GPU for efficient training)

4. **Performance**:
   - Previous: 91.18% accuracy
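   - BERT: Not yet established. Compare the validation accuracy from Part 8 against the 91.18% baseline, ideally on the same data split. Gains are expected mainly for documents with similar concepts but different legal purposes.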

5. **Language Understanding**:
   - Previous: Does not read the legal text at all; relies only on the assigned concept codes
   - BERT: Models the wording and context of the legal text itself

Both approaches have their strengths - the concept-based model is simple and efficient, while the BERT model has the potential for deeper understanding of legal texts.