# Legal Document Classification with BERT

## Part 3: Data Preparation

If your raw data is stored as JSON files (one document per file), this section converts them into the CSV dataset used in the rest of the notebook. The code can be run directly in Colab.

In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm

def prepare_bert_dataset(json_folder_path, output_file_path, min_text_length=10, *, show_progress=True):
    """Build a text-classification CSV from a folder of JSON documents.

    Each ``*.json`` file directly inside ``json_folder_path`` is loaded as one
    document. Its ``header`` and ``recitals`` fields are joined with a newline
    to form the input text, and its ``type`` field becomes the label. The
    ``celex_id`` field (empty string if absent) is kept for reference.

    A document is skipped when it has no ``type``, when it has neither a
    ``header`` nor ``recitals``, or when the combined text (after stripping
    surrounding whitespace) is shorter than ``min_text_length`` characters.
    Files that raise while being read or processed are reported, counted and
    skipped.

    Args:
        json_folder_path: Directory containing the JSON files (not searched
            recursively).
        output_file_path: Destination CSV path. Missing parent directories
            are created.
        min_text_length: Minimum number of characters the stripped text must
            have for the document to be kept.
        show_progress: If True (default), wrap the file loop in a ``tqdm``
            progress bar; pass False for non-interactive runs.

    Returns:
        A DataFrame with ``text``, ``label`` and ``celex_id`` columns (the
        rows written to the CSV), or ``None`` when no JSON files were found
        or no valid documents were extracted (no CSV is written then).
    """
    print("Extracting header + recitals and labels from JSON files...")

    # Sort so the row order of the output CSV is reproducible; os.listdir
    # returns entries in arbitrary order.
    json_files = sorted(
        f for f in os.listdir(json_folder_path) if f.lower().endswith('.json')
    )

    if not json_files:
        print("No JSON files found.")
        return None

    print(f"Found {len(json_files)} JSON files. Processing...")

    data = []
    skipped_count = 0      # missing type, or missing both header and recitals
    short_text_count = 0   # stripped text shorter than min_text_length
    error_count = 0        # files that raised while being read/processed

    file_iter = tqdm(json_files) if show_progress else json_files
    for filename in file_iter:
        file_path = os.path.join(json_folder_path, filename)

        # Best-effort per file: one malformed document must not abort the
        # whole extraction, so any error is reported, counted and skipped.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                article = json.load(f)

            doc_type = article.get('type', None)
            header = article.get('header', '')
            recitals = article.get('recitals', '')

            if not doc_type or (not header and not recitals):
                skipped_count += 1
                continue

            # Combine whichever of header / recitals are present
            if header and recitals:
                text = f"{header}\n{recitals}"
            elif header:
                text = header
            else:
                text = recitals

            # Strip *before* the length check: otherwise whitespace-only
            # content passes the filter and is saved as an empty text.
            text = text.strip()
            if len(text) < min_text_length:
                short_text_count += 1
                continue

            data.append({
                'text': text,
                'label': doc_type,
                'celex_id': article.get('celex_id', ''),  # kept for reference
            })
        except Exception as e:
            error_count += 1
            print(f"\nWarning: Error processing file {filename}: {e}. Skipping.")

    if not data:
        print("No valid data extracted. Exiting.")
        return None

    df = pd.DataFrame(data)

    print(f"\nExtracted {len(df)} documents with valid text and labels.")
    print(f"Skipped {skipped_count} documents missing type or text.")
    print(f"Skipped {short_text_count} documents with text shorter than {min_text_length} characters.")
    if error_count:
        print(f"Skipped {error_count} files that could not be read or processed.")

    label_counts = df['label'].value_counts()
    print("\nLabel distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count} ({count/len(df)*100:.2f}%)")

    # Make sure the destination folder exists (e.g. a fresh Drive subfolder)
    out_dir = os.path.dirname(output_file_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"\nSaving to {output_file_path}...")
    df.to_csv(output_file_path, index=False)
    print(f"Data successfully saved to {output_file_path}")

    return df

In [None]:
# Uncomment and run this cell if you have JSON files uploaded to a folder in Drive
# json_folder_path = '/content/drive/MyDrive/legal_bert_classification/dataset_folder'
# output_file_path = '/content/drive/MyDrive/legal_bert_classification/bert_classification_dataset.csv'
# df = prepare_bert_dataset(json_folder_path, output_file_path)

## Part 4: Data Exploration

Explore and visualize the dataset to understand its characteristics: the label distribution, text lengths per document type, and a sample document from each class.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the prepared dataset. Adjust the path to wherever the CSV lives:
#   - file uploaded directly to this Colab session (Part 2): keep the default
#   - copy saved to Drive: use
#     '/content/drive/MyDrive/legal_bert_classification/bert_classification_dataset.csv'
dataset_csv_path = 'bert_classification_dataset.csv'
df = pd.read_csv(dataset_csv_path)

# Number of documents per document type (also used by the bar plot cell)
label_counts = df['label'].value_counts()

for summary in (
    f"Dataset shape: {df.shape}",
    f"Number of unique labels: {df['label'].nunique()}",
    "Label distribution:",
):
    print(summary)
print(label_counts)

In [None]:
# Bar chart of how many documents fall under each document type
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=label_counts.index, y=label_counts.values, ax=ax)
ax.set(title='Document Type Distribution', xlabel='Document Type', ylabel='Count')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
plt.show()

In [None]:
# Text length analysis (in characters). `.str.len()` is vectorized and yields
# NaN instead of raising if a text cell was read back from the CSV as missing.
df['text_length'] = df['text'].str.len()
print("\nText length statistics:")
print(df['text_length'].describe())

# Plot text length distribution, one step histogram per document type
fig, ax = plt.subplots(figsize=(12, 6))
sns.histplot(data=df, x='text_length', hue='label', bins=50, element='step', ax=ax)
ax.set_title('Text Length Distribution by Document Type')
ax.set_xlabel('Text Length (characters)')
# Limit x-axis to the 99th percentile so a few very long documents don't
# compress the rest of the distribution.
ax.set_xlim(0, df['text_length'].quantile(0.99))
# seaborn already draws a legend for `hue`. Calling plt.legend() here would
# replace it with an empty one (seaborn's handles are not labeled axes
# artists), so only retitle the existing legend.
legend = ax.get_legend()
if legend is not None:
    legend.set_title('Document Type')
fig.tight_layout()
plt.show()

In [None]:
# Show the first document of every class (in order of first appearance),
# truncated to its first 300 characters.
print("\nSample document from each class:")
for doc_type in df['label'].unique():
    first_row = df.loc[df['label'] == doc_type].iloc[0]
    preview = first_row['text'][:300]
    print(f"\n--- {doc_type} Example ---")
    print(f"Text (first 300 chars): {preview}...")