# Data Exploration and Validation
## Reasoning Distillation Project

This notebook:
1. Downloads e-SNLI and Alpaca datasets
2. Validates data quality and structure
3. Computes and visualizes summary statistics
4. Displays sample examples

In [None]:
# Setup
import sys
import os
from pathlib import Path

# Add project root to path so the `src` package imported below resolves.
# NOTE(review): this takes the parent of the current working directory as the
# project root, i.e. it presumes the kernel was started one level below the
# root (e.g. in a notebooks/ folder) — confirm for your setup.
project_root = Path.cwd().parent
# Guard the insert: re-running this cell must not stack duplicate entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Imports
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pprint import pprint

from src.data.data_loader import (
    TeacherDataLoader,
    DatasetConfig,
    quick_load_esnli,
    quick_load_alpaca
)

# Styling
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

%load_ext autoreload
%autoreload 2

## 1. Initialize Data Loader

In [None]:
# Directory settings for the loader; all three are relative paths that
# point one level above the current working directory.
data_dirs = {
    "raw_data_dir": "../data/raw",
    "processed_data_dir": "../data/processed",
    "cache_dir": "../data/cache",
}
config = DatasetConfig(**data_dirs)

# Build the loader from that configuration; later cells reuse `loader`
loader = TeacherDataLoader(config)

# Echo the effective settings back from the config object
print("Data loader initialized successfully!")
print(f"Raw data directory: {config.raw_data_dir}")
print(f"Cache directory: {config.cache_dir}")

## 2. Load and Validate e-SNLI Dataset

In [None]:
# Fetch e-SNLI (all splits) through the loader
print("Loading e-SNLI dataset...")
esnli_dataset = loader.load_esnli()

# Report which splits came back and how many samples each one holds
available_splits = list(esnli_dataset.keys())
print(f"\nAvailable splits: {available_splits}")
for name, rows in esnli_dataset.items():
    print(f"  {name}: {len(rows):,} samples")

# Feature description, read from the 'train' split
train_features = esnli_dataset['train'].features
print(f"\nFeatures: {train_features}")

In [None]:
# Validate e-SNLI
# Runs the loader's validation over the loaded splits and keeps the returned
# statistics in `esnli_stats`. Later cells in this notebook read its
# 'label_distribution', 'explanation_stats' and 'splits' entries.
print("Validating e-SNLI dataset...")
esnli_stats = loader.validate_esnli(esnli_dataset)

# Pretty-print the whole statistics structure for inspection
print("\n=== e-SNLI Validation Statistics ===")
pprint(esnli_stats)

In [None]:
# Visualize label distribution: one bar chart per split found in the stats.

# Integer class id -> display name. Also used by the sample-printing cell below.
# NOTE(review): assumes the loader keeps the 0/1/2 label encoding — confirm
# against the data loader.
label_map = {0: 'Entailment', 1: 'Neutral', 2: 'Contradiction'}
# Colour is tied to the label id (not the bar position) so a split that lacks
# a class still shows the remaining classes in their usual colours.
label_colors = {0: '#2ecc71', 1: '#3498db', 2: '#e74c3c'}
unknown_label_color = '#95a5a6'

label_distributions = esnli_stats['label_distribution']

# Size the grid from the data instead of assuming exactly three splits.
# squeeze=False keeps `axes` two-dimensional even when there is one panel.
n_panels = max(len(label_distributions), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)

for ax, (split_name, label_dist) in zip(axes[0], label_distributions.items()):
    label_ids = sorted(label_dist.keys())
    # Fall back to the raw id for labels outside label_map instead of raising.
    labels = [label_map.get(k, str(k)) for k in label_ids]
    counts = [label_dist[k] for k in label_ids]
    colors = [label_colors.get(k, unknown_label_color) for k in label_ids]

    ax.bar(labels, counts, color=colors)
    ax.set_title(f'{split_name.capitalize()} Split')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', rotation=45)

plt.suptitle('e-SNLI Label Distribution', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Grouped bar chart of explanation-length statistics (mean / min / max) per split
exp_stats = esnli_stats['explanation_stats']
split_names = list(exp_stats.keys())

# (stats key, legend label, bar colour, horizontal shift in bar widths)
series = [
    ('mean_length', 'Mean', '#3498db', -1),
    ('min_length', 'Min', '#2ecc71', 0),
    ('max_length', 'Max', '#e74c3c', 1),
]

# Collect every series up front, before any figure is created
heights = {
    key: [exp_stats[name][key] for name in split_names]
    for key, _, _, _ in series
}

positions = range(len(split_names))
bar_width = 0.25

plt.figure(figsize=(10, 6))
for key, legend_label, colour, shift in series:
    # Each series sits one bar width to the left of / on / to the right of the tick
    bar_centres = [p + shift * bar_width for p in positions]
    plt.bar(bar_centres, heights[key], bar_width, label=legend_label, color=colour)

plt.title('e-SNLI Explanation Length Statistics')
plt.xlabel('Split')
plt.ylabel('Explanation Length (words)')
plt.xticks(positions, split_names)
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Print a few parsed training examples for a quick visual check.
# Relies on `label_map` from the label-distribution cell above.
print("\n=== Sample e-SNLI Examples ===")
train_split = esnli_dataset['train']
samples = loader.get_sample_examples(train_split, n_samples=3)

for number, raw_sample in enumerate(samples, start=1):
    fields = loader.parse_esnli_sample(raw_sample)
    print(f"\n--- Example {number} ---")
    print(f"Premise: {fields['premise']}")
    print(f"Hypothesis: {fields['hypothesis']}")
    # Translate the stored label id into its display name
    print(f"Label: {label_map[fields['label']]}")
    print(f"Explanation: {fields['explanation']}")

## 3. Load and Validate Alpaca Dataset

In [None]:
# Load a capped subset of Alpaca so the notebook stays quick to run
print("Loading Alpaca dataset...")
alpaca_load_options = {
    "dataset_name": "tatsu-lab/alpaca",
    "max_samples": 5000,  # subset size for faster testing
}
alpaca_dataset = loader.load_alpaca(**alpaca_load_options)

# Report the size and feature description of what was loaded
n_loaded = len(alpaca_dataset)
print(f"\nLoaded {n_loaded:,} samples")
print(f"Features: {alpaca_dataset.features}")

In [None]:
# Validate Alpaca
# Runs the loader's validation over the loaded subset and keeps the returned
# statistics in `alpaca_stats`. Later cells in this notebook read its
# 'instruction_length', 'output_length', 'samples_with_input' and
# 'total_samples' entries.
print("Validating Alpaca dataset...")
alpaca_stats = loader.validate_alpaca(alpaca_dataset)

# Pretty-print the whole statistics structure for inspection
print("\n=== Alpaca Validation Statistics ===")
pprint(alpaca_stats)

In [None]:
# Two-panel Alpaca overview: text-length summary (left), input-context share (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
length_ax, context_ax = axes

# --- Left panel: instruction vs. output length ---
inst_stats = alpaca_stats['instruction_length']
out_stats = alpaca_stats['output_length']

categories = ['Mean', 'Min', 'Max']
stat_keys = ('mean', 'min', 'max')  # same order as `categories`
inst_values = [inst_stats[key] for key in stat_keys]
out_values = [out_stats[key] for key in stat_keys]

x = range(len(categories))
width = 0.35
half_width = width / 2  # the two bars of a group sit on either side of the tick

length_ax.bar([pos - half_width for pos in x], inst_values, width, label='Instruction', color='#9b59b6')
length_ax.bar([pos + half_width for pos in x], out_values, width, label='Output', color='#e67e22')
length_ax.set_title('Alpaca Text Length Statistics')
length_ax.set_ylabel('Length (words)')
length_ax.set_xticks(x)
length_ax.set_xticklabels(categories)
length_ax.legend()

# --- Right panel: share of samples that carry an input context ---
with_input = alpaca_stats['samples_with_input']
without_input = alpaca_stats['total_samples'] - with_input
slice_sizes = [with_input, without_input]
slice_labels = ['With Input Context', 'Without Input Context']

context_ax.pie(
    slice_sizes,
    labels=slice_labels,
    autopct='%1.1f%%',
    colors=['#3498db', '#95a5a6'],
    startangle=90,
)
context_ax.set_title('Alpaca Samples: Input Context Distribution')

plt.tight_layout()
plt.show()

In [None]:
# Display sample Alpaca examples: instruction, optional input, and an output preview.
print("\n=== Sample Alpaca Examples ===")
alpaca_samples = loader.get_sample_examples(alpaca_dataset, n_samples=3)

MAX_OUTPUT_CHARS = 200  # preview length for long outputs

for i, sample in enumerate(alpaca_samples, 1):
    parsed = loader.parse_alpaca_sample(sample)
    print(f"\n--- Example {i} ---")
    print(f"Instruction: {parsed['instruction']}")
    # The input line is skipped when the sample has no (or an empty) input
    if parsed['input']:
        print(f"Input: {parsed['input']}")

    # Add the ellipsis only when text was actually cut, so a short output is
    # not displayed as if it had been truncated.
    output_text = parsed['output']
    if len(output_text) > MAX_OUTPUT_CHARS:
        output_text = output_text[:MAX_OUTPUT_CHARS] + "..."
    print(f"Output: {output_text}")

## 4. Save Processed Data (Optional)

In [None]:
# Optional: persist the loaded datasets through the loader for later use.
# Nothing is written unless the flag below is switched on by hand.
save_processed = False  # Set to True to save

if save_processed:
    # One save call per e-SNLI split
    print("Saving processed e-SNLI...")
    for split_name, split_data in esnli_dataset.items():
        loader.save_processed_data(split_data, 'esnli', split_name)

    # The Alpaca subset is saved under the 'train' split name
    print("Saving processed Alpaca...")
    loader.save_processed_data(alpaca_dataset, 'alpaca', 'train')

    # Check mark restored from its mis-decoded (mojibake) form
    print("✓ All data saved successfully!")
else:
    print("Skipping save step (set save_processed=True to save)")

## 5. Summary and Next Steps

In [None]:
# Final text summary built from the validation statistics computed above.
# The emoji and bullet characters are restored from their mis-decoded
# (mojibake) form so the report prints as intended.
print("\n" + "="*60)
print("DATA LOADING & VALIDATION SUMMARY")
print("="*60)

# e-SNLI: per-split sample count and how many samples carry an explanation
print("\n📊 e-SNLI Dataset:")
for split_name, n in esnli_stats['splits'].items():
    exp = esnli_stats['explanation_stats'][split_name]['samples_with_explanation']
    print(f"  • {split_name}: {n:,} samples ({exp:,} with explanations)")

# Alpaca: overall size and mean text lengths
print("\n📚 Alpaca Dataset:")
print(f"  • Total samples: {alpaca_stats['total_samples']:,}")
print(f"  • Mean instruction length: {alpaca_stats['instruction_length']['mean']:.1f} words")
print(f"  • Mean output length: {alpaca_stats['output_length']['mean']:.1f} words")
print("\n" + "="*60)