# 01 - Explore Dataset

This notebook explores the DFG classification dataset:
- Dataset statistics
- Class distribution visualization
- Data quality checks
- Sample data inspection


In [2]:
# Import libraries
import json
import os
import random
import sys
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Add src to path.
# Path().absolute() is the current working directory, so this resolves to
# "<cwd>/../src". The membership guard keeps re-runs of this cell from
# stacking duplicate entries at the front of sys.path.
SRC_DIR = str(Path().absolute().parent / 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# Set style: seaborn grid theme plus a wider default figure size
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Libraries imported")


✓ Libraries imported


In [3]:
# Configuration
# Both paths are relative to the current working directory.
DATA_PATH = '../dfg-classifier/data/processed'  # Change to '../dfg-classifier/data/small' for small dataset
DFG_MAPPING_PATH = '../data/dfg_mapping.json'

# Check if data exists; only report here, later cells guard on the loaded data
if not os.path.exists(DATA_PATH):
    print(f"⚠️  Data path not found: {DATA_PATH}")
    print("💡 Try changing DATA_PATH to '../dfg-classifier/data/small' for small dataset")
else:
    print(f"✓ Data path found: {DATA_PATH}")

    # List available files. os.listdir returns entries in arbitrary order,
    # so sort them to keep the printed listing stable across runs/platforms.
    if os.path.isdir(DATA_PATH):
        files = sorted(os.listdir(DATA_PATH))
        print(f"  Available files: {', '.join(files)}")


✓ Data path found: ../dfg-classifier/data/processed
  Available files: dataset_stats.json, test.json, train.json, val.json


In [4]:
# Load DFG mapping (JSON file read as UTF-8)
with open(DFG_MAPPING_PATH, 'r', encoding='utf-8') as f:
    dfg_mapping = json.load(f)

# Display DFG structure: headline numbers come from the file's 'metadata' section
metadata = dfg_mapping['metadata']
print("📚 DFG Classification Structure:")
print(f"  Total Review Boards: {metadata['total_review_boards']}")
print(f"  Classification Levels: {metadata['total_levels']}")

# Level 1 classes are stored as a {code: name} mapping
print("\nMain Disciplines (Level 1):")
for code, name in dfg_mapping['level_1']['classes'].items():
    print(f"  {code}: {name}")

# Get all level 2 classes
level_2_classes = dfg_mapping['level_2']['classes']
print(f"\nSubject Areas (Level 2): {len(level_2_classes)} classes")


📚 DFG Classification Structure:
  Total Review Boards: 218
  Classification Levels: 4

Main Disciplines (Level 1):
  1: Humanities & Social Sciences
  2: Life Sciences
  3: Natural Sciences
  4: Engineering Sciences

Subject Areas (Level 2): 30 classes


In [5]:
def load_dataset_split(split='train', data_path=None):
    """Load one dataset split from ``<data_path>/<split>.json``.

    Args:
        split: Name of the split; used as the file stem (e.g. 'train',
            'val', 'test').
        data_path: Directory that contains the split files. When omitted,
            the notebook-level ``DATA_PATH`` is used, so existing
            single-argument calls behave as before.

    Returns:
        The parsed JSON content of the split file, or ``None`` (after
        printing a warning) when the file does not exist.
    """
    # Resolve the directory lazily so an explicit data_path never touches the global.
    base_dir = DATA_PATH if data_path is None else data_path
    file_path = os.path.join(base_dir, f'{split}.json')

    if not os.path.exists(file_path):
        print(f"⚠️  File not found: {file_path}")
        return None

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"✓ Loaded {split} set: {len(data)} samples")
    return data

# Load the three splits in order: train, then val, then test.
# Each name is None when its file is missing (see load_dataset_split).
train_data, val_data, test_data = (
    load_dataset_split(split_name) for split_name in ('train', 'val', 'test')
)


✓ Loaded train set: 4800 samples
✓ Loaded val set: 600 samples
✓ Loaded test set: 600 samples


In [None]:
# Dataset Statistics
# Summarise split sizes and their share of the total. Skipped entirely when
# the train split is missing or empty.
if train_data:
    # A split that failed to load is None; count it as zero samples.
    split_sizes = {
        'Train': len(train_data),
        'Val': len(val_data) if val_data else 0,
        'Test': len(test_data) if test_data else 0,
    }
    total_samples = sum(split_sizes.values())

    # Insertion order drives the printed order: total, absolute counts, then shares.
    stats = {'Total Samples': total_samples}
    stats.update({
        f'{split_name} Samples': size for split_name, size in split_sizes.items()
    })
    # Percentages are only reported for splits that actually contain samples.
    stats.update({
        f'{split_name} %': f"{size / total_samples * 100:.1f}%"
        for split_name, size in split_sizes.items()
        if size
    })

    print("📊 Dataset Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")


In [None]:
# Class Distribution Analysis
# Builds the per-split label lists and counters as notebook-level names.
if train_data:
    # Count labels in each split; a split that failed to load (None) counts as empty
    train_labels = [item['label'] for item in train_data]
    val_labels = [item['label'] for item in (val_data or [])]
    test_labels = [item['label'] for item in (test_data or [])]

    # Counter of an empty list is already an empty Counter, no special-casing needed
    train_counter = Counter(train_labels)
    val_counter = Counter(val_labels)
    test_counter = Counter(test_labels)

    # Get all unique labels across the three splits
    all_labels = sorted(set(train_labels) | set(val_labels) | set(test_labels))

    print("📈 Class Distribution:")
    print(f"  Total unique classes: {len(all_labels)}")
    print(f"  Classes in train set: {len(train_counter)}")
    print(f"  Classes in val set: {len(val_counter)}")
    print(f"  Classes in test set: {len(test_counter)}")

    # Rank once (most frequent first) and slice both ends of the ranking
    ranked = train_counter.most_common()

    # Show top 10 most common classes in train set
    print("\n🏆 Top 10 Most Common Classes (Train Set):")
    for label, count in ranked[:10]:
        print(f"  {label}: {count} samples")

    # Show least common classes (still listed in descending order of count;
    # overlaps with the top 10 when there are fewer than 20 classes)
    print("\n📉 Least Common Classes (Train Set):")
    for label, count in ranked[-10:]:
        print(f"  {label}: {count} samples")


In [None]:
# Visualize Class Distribution
# NOTE: uses train_counter / val_counter / test_counter, which are created by
# the class-distribution cell; run that cell first.
if train_data:
    # Prepare data for visualization: one row per (split, label) pair.
    # Val/Test rows are only included when that split was loaded and non-empty.
    split_counters = [('Train', train_counter)]
    if val_data:
        split_counters.append(('Val', val_counter))
    if test_data:
        split_counters.append(('Test', test_counter))
    # label_counts is not referenced again within this cell.
    label_counts = pd.DataFrame([
        {'Split': split_name, 'Label': label, 'Count': count}
        for split_name, counter in split_counters
        for label, count in counter.items()
    ])

    # Plot class distribution: ranked bars on the left, histogram on the right
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    ax_top, ax_hist = axes

    # Bar plot of top classes (at most 20, most frequent first)
    top_labels = train_counter.most_common(20)
    labels, counts = zip(*top_labels)
    positions = range(len(labels))

    ax_top.barh(positions, counts, color='steelblue')
    ax_top.set_yticks(positions)
    ax_top.set_yticklabels(labels, fontsize=8)
    ax_top.set_xlabel('Number of Samples', fontsize=12)
    ax_top.set_title('Top 20 Classes by Sample Count (Train Set)', fontsize=14, fontweight='bold')
    ax_top.invert_yaxis()  # put the most frequent class at the top
    ax_top.grid(axis='x', alpha=0.3)

    # Distribution histogram of per-class sample counts
    counts_list = list(train_counter.values())
    ax_hist.hist(counts_list, bins=20, color='coral', edgecolor='black', alpha=0.7)
    ax_hist.set_xlabel('Samples per Class', fontsize=12)
    ax_hist.set_ylabel('Number of Classes', fontsize=12)
    ax_hist.set_title('Distribution of Samples per Class', fontsize=14, fontweight='bold')
    ax_hist.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Statistics over the train-set samples-per-class counts
    print("\n📊 Class Distribution Statistics:")
    print(f"  Mean samples per class: {np.mean(counts_list):.1f}")
    print(f"  Median samples per class: {np.median(counts_list):.1f}")
    print(f"  Std samples per class: {np.std(counts_list):.1f}")
    print(f"  Min samples per class: {np.min(counts_list)}")
    print(f"  Max samples per class: {np.max(counts_list)}")


In [None]:
# Data Quality Checks
def _text_of(item, field):
    """Return item[field] as a string; a missing key or a null value becomes ''."""
    value = item.get(field)
    return '' if value is None else str(value)


def _is_missing_label(label):
    """Return True for None, a blank string, or an empty container.

    A numeric label such as 0 is treated as present.
    """
    if label is None:
        return True
    if isinstance(label, str):
        return not label.strip()
    if isinstance(label, (list, tuple, set, dict)):
        return len(label) == 0
    return False


if train_data:
    # Check for missing data
    print("🔍 Data Quality Checks:")

    # Normalise text fields once so null values are counted instead of crashing len()
    titles = [_text_of(item, 'title') for item in train_data]
    abstracts = [_text_of(item, 'abstract') for item in train_data]

    # Check text lengths (in characters)
    title_lengths = [len(text) for text in titles]
    abstract_lengths = [len(text) for text in abstracts]
    combined_lengths = [
        title_len + abstract_len
        for title_len, abstract_len in zip(title_lengths, abstract_lengths)
    ]

    # Check for empty fields (whitespace-only text counts as empty)
    empty_titles = sum(1 for text in titles if not text.strip())
    empty_abstracts = sum(1 for text in abstracts if not text.strip())
    missing_labels = sum(1 for item in train_data if _is_missing_label(item.get('label')))

    print(f"  Empty titles: {empty_titles}")
    print(f"  Empty abstracts: {empty_abstracts}")
    print(f"  Missing labels: {missing_labels}")

    # Text length statistics
    print("\n📝 Text Length Statistics:")
    print(f"  Title length - Mean: {np.mean(title_lengths):.1f}, Median: {np.median(title_lengths):.1f}")
    print(f"  Abstract length - Mean: {np.mean(abstract_lengths):.1f}, Median: {np.median(abstract_lengths):.1f}")
    print(f"  Combined length - Mean: {np.mean(combined_lengths):.1f}, Median: {np.median(combined_lengths):.1f}")

    # Visualize text lengths: one histogram per series, same styling for all
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panels = [
        (title_lengths, 'skyblue', 'Title Length (characters)', 'Title Length Distribution'),
        (abstract_lengths, 'lightgreen', 'Abstract Length (characters)', 'Abstract Length Distribution'),
        (combined_lengths, 'coral', 'Combined Length (characters)', 'Combined Text Length Distribution'),
    ]
    for ax, (lengths, color, xlabel, panel_title) in zip(axes, panels):
        ax.hist(lengths, bins=50, color=color, edgecolor='black', alpha=0.7)
        ax.set_xlabel(xlabel, fontsize=11)
        ax.set_ylabel('Frequency', fontsize=11)
        ax.set_title(panel_title, fontsize=13, fontweight='bold')
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
# Sample Data Inspection
SAMPLE_SEED = 42  # fixed seed so a fresh "Restart & Run All" shows the same examples


def _preview(text, limit):
    """Return text cut to `limit` characters plus '...', or 'N/A' when it is None."""
    if text is None:
        return 'N/A'
    text = str(text)
    return f"{text[:limit]}..." if len(text) > limit else text


if train_data:
    print("📖 Sample Data Examples:")
    print("=" * 80)

    # Show a few random samples (at most 3), drawn from a private seeded
    # generator so the global random state is left untouched.
    rng = random.Random(SAMPLE_SEED)
    sample_indices = rng.sample(range(len(train_data)), min(3, len(train_data)))

    for idx, sample_idx in enumerate(sample_indices, 1):
        sample = train_data[sample_idx]
        print(f"\n[Sample {idx}]")
        print(f"  Filename: {sample.get('filename', 'N/A')}")
        print(f"  Label: {sample.get('label', 'N/A')}")
        # Long text fields are truncated to keep the output readable
        print(f"  Title: {_preview(sample.get('title'), 100)}")
        print(f"  Abstract: {_preview(sample.get('abstract'), 200)}")
        # 'input_ids' may be absent or null; report 0 tokens in that case
        print(f"  Input IDs shape: {len(sample.get('input_ids') or [])} tokens")
        print("-" * 80)


In [None]:
# Check dataset statistics file if available
stats_file = os.path.join(DATA_PATH, 'dataset_stats.json')
if os.path.exists(stats_file):
    with open(stats_file, 'r', encoding='utf-8') as f:
        dataset_stats = json.load(f)

    print("📋 Saved Dataset Statistics:")
    # ensure_ascii=False prints non-ASCII text as-is instead of as \uXXXX escapes
    print(json.dumps(dataset_stats, indent=2, ensure_ascii=False))
else:
    print("ℹ️  No dataset_stats.json file found")
