# Legal Document Classification with BERT - V2 (Full Dataset)

## Part 5: Model Evaluation

In this notebook, we'll perform a comprehensive evaluation of our trained model, analyzing its performance across different metrics and exploring any misclassifications.

In [None]:
import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer

In [None]:
# Restore the fine-tuned classifier, its tokenizer and the label encoder from Drive.
model_dir = '/content/drive/MyDrive/legal_bert_classification_v2/final_model'

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The tokenizer is read from the same directory as the model weights.
tokenizer = BertTokenizer.from_pretrained(model_dir)

# Load the weights, move them to the chosen device and switch to evaluation mode.
# Both .to() and .eval() return the module itself, so the calls can be chained.
model = BertForSequenceClassification.from_pretrained(model_dir).to(device).eval()

print(f"Model loaded from {model_dir}")
print(f"Using device: {device}")

# The label encoder maps between class names and the integer ids used by the model.
# NOTE: pickle.load can execute arbitrary code; only unpickle files you created yourself.
with open('/content/drive/MyDrive/legal_bert_classification_v2/label_encoder.pkl', 'rb') as encoder_file:
    label_encoder = pickle.load(encoder_file)

# Show the id -> class-name mapping so the reports below are easy to read.
print(f"Label mapping:")
for class_id, class_name in enumerate(label_encoder.classes_):
    print(f"  {class_id} -> {class_name}")

In [None]:
# Redefine the Dataset class to be consistent with the training
class LegalDocumentDataset(Dataset):
    """Map-style dataset that tokenizes one document per item.

    Args:
        texts: Indexable collection of document texts.
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, padding='max_length',
            max_length=..., return_tensors='pt')`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Number of tokens every example is padded/truncated to.

    Each item is a dict with ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'`` (a ``torch.long`` scalar tensor).
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of documents."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the document at ``idx`` and return it with its label."""
        text = self.texts[idx]
        label = self.labels[idx]

        # A missing cell (None, or the float NaN pandas uses for empty values)
        # would make the slicing below fail; treat it as an empty document.
        if not isinstance(text, str):
            is_missing = text is None or (isinstance(text, float) and text != text)
            text = "" if is_missing else str(text)

        # Pre-trim very long documents (10 characters per token of budget) before
        # tokenizing; the tokenizer still truncates to max_length tokens afterwards.
        text = text[:self.max_length * 10]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the leading batch dimension added by return_tensors='pt'.
        # A bare squeeze() would also remove the sequence dimension when it is 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# Load the full labelled dataset; the held-out test split is re-derived from it below.
test_df_path = '/content/drive/MyDrive/legal_bert_classification_v2/full_bert_dataset.csv'
df = pd.read_csv(test_df_path)

# Fail early with a clear message if the columns used below are absent.
missing_columns = {'text', 'label'} - set(df.columns)
if missing_columns:
    raise KeyError(f"{test_df_path} is missing required columns: {sorted(missing_columns)}")

# If the test split was saved in Part 3, load that file instead of re-splitting.
# NOTE(review): this only reproduces Part 3's test set if Part 3 called
# train_test_split on the same CSV with the same arguments (test_size=0.1,
# random_state=42, no stratify) — confirm, otherwise test rows may overlap the
# training data.
_, test_df = train_test_split(
    df, test_size=0.1, random_state=42
)

# Work on an independent copy so that adding columns below cannot trigger
# SettingWithCopyWarning or write through to `df`.
test_df = test_df.copy()

print(f"Test dataset size: {len(test_df)}")

# Encode labels with the encoder fitted during training so ids match the model's outputs.
test_df['encoded_label'] = label_encoder.transform(test_df['label'])

# Create test dataset and dataloader
test_dataset = LegalDocumentDataset(
    test_df['text'].values,
    test_df['encoded_label'].values,
    tokenizer,
    max_length=512
)

# shuffle=False keeps predictions aligned row-for-row with test_df.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,  # Adjust based on GPU memory
    shuffle=False
)

In [None]:
# Run inference over the whole test set and collect per-example outputs.
all_preds, all_labels, all_probs = [], [], []

with torch.no_grad():  # inference only, no gradients needed
    for batch in tqdm(test_loader, desc="Evaluating"):
        outputs = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        )

        logits = outputs.logits
        # Class probabilities (softmax over the class dimension) and the
        # arg-max class id for each example in the batch.
        probs = torch.softmax(logits, dim=1)
        preds = logits.argmax(dim=1)

        # Accumulate on the CPU as NumPy values, one entry per example.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch['labels'].cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Map integer class ids back to the original label names.
pred_labels = label_encoder.inverse_transform(all_preds)
true_labels = label_encoder.inverse_transform(all_labels)

In [None]:
# Headline metric: fraction of test documents whose predicted label is correct.
accuracy = accuracy_score(true_labels, pred_labels)
print(f"Overall Accuracy: {accuracy:.4f}")

# Build the per-class report twice: as text for display and as a dict for later cells.
report_text = classification_report(true_labels, pred_labels)
class_report = classification_report(true_labels, pred_labels, output_dict=True)
print("\nClassification Report:")
print(report_text)

# Tabular form: one row per class/summary entry, largest support first.
# The 'accuracy' row is dropped because it is a single scalar repeated across columns.
report_df = (
    pd.DataFrame(class_report)
    .T
    .sort_values(by='support', ascending=False)
    .drop(index='accuracy')
)

# Persist the table next to the other evaluation artefacts.
report_df.to_csv('/content/drive/MyDrive/legal_bert_classification_v2/classification_report.csv')

In [None]:
# Raw confusion matrix: rows are true labels, columns are predicted labels,
# both ordered like label_encoder.classes_.
cm = confusion_matrix(true_labels, pred_labels, labels=label_encoder.classes_)

# Normalize each row so it shows the share of that true class.
# A class with no examples in the test split has a zero row total; leave that
# row at 0 instead of dividing by zero and filling it with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Plot the normalized matrix as an annotated heatmap.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_,
            ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix (Normalized)')
fig.tight_layout()
fig.savefig('/content/drive/MyDrive/legal_bert_classification_v2/confusion_matrix.png', dpi=300)
plt.show()

In [None]:
# Analyze misclassifications
# Attach predictions to a fresh copy of the test frame (assign returns a new
# DataFrame), so no column is written into a slice of the original data.
# pred_labels is positional: row i of test_df corresponds to prediction i
# because the DataLoader was created with shuffle=False.
test_df = test_df.assign(predicted_label=pred_labels)
test_df['correct'] = test_df['label'] == test_df['predicted_label']

# Find incorrect predictions
incorrect_df = test_df[~test_df['correct']].copy()

print(f"Number of misclassifications: {len(incorrect_df)} out of {len(test_df)} ({len(incorrect_df)/len(test_df)*100:.2f}%)")

# Save misclassifications for further analysis
if len(incorrect_df) > 0:
    incorrect_df.to_csv('/content/drive/MyDrive/legal_bert_classification_v2/misclassifications.csv', index=False)

    # Show a small sample; the fixed seed makes the same examples appear on every run.
    print("\nSample misclassifications:")
    misclass_sample = incorrect_df.sample(min(5, len(incorrect_df)), random_state=42)
    for _, row in misclass_sample.iterrows():
        print(f"\nTrue label: {row['label']}")
        print(f"Predicted label: {row['predicted_label']}")
        # str() guards against non-string cells (e.g. NaN), which cannot be sliced.
        print(f"Text excerpt: {str(row['text'])[:200]}...")

In [None]:
# Confidence analysis
# One probability column per class, in the same order as label_encoder.classes_.
prob_columns = [f"prob_{label}" for label in label_encoder.classes_]
prob_matrix = np.asarray(all_probs)
prob_df = pd.DataFrame(prob_matrix, columns=prob_columns)

# Drop probability columns left over from an earlier run of this cell, so
# re-running it does not create duplicate prob_* columns.
test_df = pd.concat(
    [test_df.reset_index(drop=True).drop(columns=prob_columns, errors='ignore'), prob_df],
    axis=1,
)

# Confidence = probability assigned to the predicted class. all_preds holds the
# predicted class id for every row, so a vectorised lookup replaces the
# row-by-row apply.
test_df['confidence'] = prob_matrix[np.arange(len(prob_matrix)), np.asarray(all_preds)]

# Analyze confidence by correctness
is_correct = test_df['correct']
print("Confidence analysis:")
print(f"Mean confidence for correct predictions: {test_df.loc[is_correct, 'confidence'].mean():.4f}")
if (~is_correct).any():
    print(f"Mean confidence for incorrect predictions: {test_df.loc[~is_correct, 'confidence'].mean():.4f}")

# Plot confidence distribution
fig, ax = plt.subplots(figsize=(12, 6))
sns.histplot(data=test_df, x='confidence', hue='correct', element='step', bins=30, ax=ax)
ax.set_title('Prediction Confidence Distribution')
ax.set_xlabel('Confidence Score')
ax.set_ylabel('Count')
# Retitle the legend seaborn already drew. Calling plt.legend() here would
# replace it with an empty one, because the hue artists carry no labels.
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.set_title('Correct Prediction')
fig.savefig('/content/drive/MyDrive/legal_bert_classification_v2/confidence_distribution.png', dpi=300)
plt.show()

## Performance Comparison with Previous Model

Compare performance with the previous model that achieved 100% accuracy on a smaller dataset.

In [None]:
# Create a comparison summary
# The "Previous Model" column holds hard-coded figures; only the accuracy and F1
# values in the "Current Model" column are computed in this notebook.

# Unweighted mean of the per-class F1 scores, taken from the report dict.
# Slicing report_df positionally (iloc[:-3]) is wrong here: report_df is sorted
# by support in descending order, which puts the 'macro avg' and 'weighted avg'
# summary rows first. The slice would average those in and drop the three
# smallest classes.
macro_f1 = class_report['macro avg']['f1-score']

print("Performance Comparison:")
print("-" * 60)
print("                         Previous Model       Current Model")
print("-" * 60)
print(f"Training Data Size:         ~20,000               ~45,000")
print(f"Text Content:            Header+Recitals    Header+Recitals+Main Body")
print(f"Text Preprocessing:          Basic                 Advanced")
print(f"Overall Accuracy:           100.00%                {accuracy*100:.2f}%")
print(f"Average F1-Score:           100.00%                {macro_f1*100:.2f}%")
# Rough heuristic only: test accuracy at or above 98% is labelled 'Moderate' risk.
print(f"Overfitting Risk:            High                  {'Low' if accuracy < 0.98 else 'Moderate'}")
print("-" * 60)

## Final Model Export

Save the evaluation metrics to Drive. Once evaluation is complete, the trained model can be downloaded for local inference.

In [None]:
# Save evaluation results
eval_results = {
    'accuracy': accuracy,
    'classification_report': class_report,
    'confusion_matrix': cm.tolist(),
    # Keys are the integer class ids; json writes them out as strings.
    'label_mapping': {i: label for i, label in enumerate(label_encoder.classes_)}
}


def _to_builtin(value):
    """Fallback for json: convert NumPy scalars/arrays to plain Python objects."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Serialize to a string first, so a serialization error cannot leave a
# truncated JSON file behind on Drive.
eval_results_json = json.dumps(eval_results, indent=2, default=_to_builtin)
with open('/content/drive/MyDrive/legal_bert_classification_v2/evaluation_results.json', 'w') as f:
    f.write(eval_results_json)

print("Evaluation results saved to Drive.")
print("\nModel evaluation complete! The model can now be downloaded for local inference.")