# Preprocessing and Dataset Testing
## Reasoning Distillation Project

This notebook tests:
1. TaskFormatter for prompt creation
2. ReasoningPreprocessor for tokenization
3. PyTorch Dataset classes (e-SNLI, Alpaca, MultiTask)
4. DataLoader creation and batching
5. End-to-end pipeline validation

In [None]:
# Setup
import sys
from pathlib import Path

# Add src to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pprint import pprint
from tqdm import tqdm

from src.data.data_loader import TeacherDataLoader, DatasetConfig
from src.data.preprocessor import (
    ReasoningPreprocessor,
    PreprocessConfig,
    TaskFormatter,
    quick_preprocess_sample
)
from src.data.dataset import (
    ESNLIDataset,
    AlpacaDataset,
    MultiTaskDataset,
    create_dataloaders,
    load_datasets_from_config
)

# Styling
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

%load_ext autoreload
%autoreload 2

print("âœ“ Imports successful!")

## 1. Test TaskFormatter
Verify that prompts are correctly formatted for both NLI and instruction tasks.

In [None]:
# Build an NLI prompt/target pair from one hand-written e-SNLI example and
# report the length of each side in whitespace-separated words.
formatter = TaskFormatter()

banner = "=" * 70
print(banner)
print("NLI TASK FORMATTING")
print(banner)

premise = "A person on a horse jumps over a broken down airplane."
hypothesis = "A person is training his horse for a competition."
label = 1  # neutral
explanation = "The person is not necessarily training his horse."

source, target = formatter.format_nli(premise, hypothesis, label, explanation)

print(f"\nSource (Input):\n{source}")
print(f"\nTarget (Output):\n{target}")
print(f"\nSource length: {len(source.split())} words")
# format_nli may return an empty/None target; report that case explicitly.
target_summary = (
    f"Target length: {len(target.split())} words"
    if target
    else "Target length: 0 words (None)"
)
print(target_summary)

In [None]:
# Format two instruction-style examples, one without and one with an input
# context, and print the resulting source/target pair for each.
def _show_pair(src, tgt):
    """Print a (source, target) pair under fixed headings."""
    print(f"\nSource (Input):\n{src}")
    print(f"\nTarget (Output):\n{tgt}")


print("\n" + "=" * 70)
print("INSTRUCTION TASK FORMATTING")
print("=" * 70)

# Case 1: no input context
instruction = "Write a poem about spring."
input_text = ""
output = "Blossoms dance in gentle breeze,\nNature wakes from winter's freeze."
source, target = formatter.format_instruction(instruction, input_text, output)
_show_pair(source, target)

# Case 2: with an input context
instruction2 = "Summarize the following text."
input_text2 = "Artificial intelligence has revolutionized many industries."
output2 = "AI has transformed various sectors."
source2, target2 = formatter.format_instruction(instruction2, input_text2, output2)

print("\n--- With Input Context ---")
_show_pair(source2, target2)

## 2. Test ReasoningPreprocessor
Test tokenization and encoding for FLAN-T5.

In [None]:
# Build the FLAN-T5 preprocessor with fixed-length ("max_length") padding and
# print the tokenizer metadata it exposes.
config = PreprocessConfig(
    model_name="google/flan-t5-base",
    padding="max_length",
    max_source_length=256,
    max_target_length=128,
)
preprocessor = ReasoningPreprocessor(config)

separator = "=" * 70
print(separator, "TOKENIZER INFORMATION", separator, sep="\n")
pprint(preprocessor.get_tokenizer_info())

In [None]:
# Tokenize one hand-built e-SNLI record and inspect the encoded tensors, both
# as raw token ids and decoded back to text.
print("\n" + "=" * 70)
print("TOKENIZING e-SNLI SAMPLE")
print("=" * 70)

esnli_sample = dict(
    premise="A person on a horse jumps over a broken down airplane.",
    hypothesis="A person is training his horse for a competition.",
    label=1,
    explanation_1="The person is not necessarily training his horse.",
)

tokenized = preprocessor.preprocess_esnli_sample(esnli_sample)

print(f"\nInput IDs shape: {tokenized['input_ids'].shape}")
print(f"Attention mask shape: {tokenized['attention_mask'].shape}")
print(f"Labels shape: {tokenized['labels'].shape}")

print("\n--- First 20 Input Tokens ---")
print(tokenized['input_ids'][:20].tolist())

print("\n--- Decoded Input ---")
decoded_input = preprocessor.decode_prediction(tokenized['input_ids'])
print(decoded_input)

print("\n--- First 20 Label Tokens ---")
print(tokenized['labels'][:20].tolist())

print("\n--- Decoded Labels ---")
# Label positions set to -100 are swapped for the pad id (on a copy) so the
# tokenizer can decode the sequence.
pad_id = preprocessor.tokenizer.pad_token_id
raw_labels = tokenized['labels']
labels_for_decode = raw_labels.masked_fill(raw_labels == -100, pad_id)
decoded_labels = preprocessor.decode_prediction(labels_for_decode)
print(decoded_labels)

In [None]:
# Show which positions of the tokenized sample are real content (mask = 1)
# and which are padding (mask = 0).
print("\n" + "=" * 70)
print("ATTENTION MASK VISUALIZATION")
print("=" * 70)

attention_mask = tokenized['attention_mask'].numpy()

# Number of attended positions == number of non-padding tokens.
content_length = int(attention_mask.sum())
print(f"\nActual content tokens: {content_length} / {len(attention_mask)}")
print(f"Padding tokens: {len(attention_mask) - content_length}")

fig, ax = plt.subplots(figsize=(14, 3))
# Pin the colour scale to [0, 1]. Without this, a mask with no padding (all 1s)
# is auto-scaled to a constant and drawn in the low ("ignore") colour.
image = ax.imshow(attention_mask.reshape(1, -1), cmap='RdYlGn', aspect='auto',
                  vmin=0, vmax=1)
fig.colorbar(image, ax=ax, label='Attention (1=attend, 0=ignore)')
ax.set_xlabel('Token Position')
ax.set_yticks([])
ax.set_title('Attention Mask Pattern')
fig.tight_layout()
plt.show()

## 3. Load Real Data and Create Datasets

In [None]:
# Load small subsets of each corpus so the whole notebook runs quickly.
N_ESNLI_TRAIN = 100  # first N e-SNLI train rows
N_ESNLI_VAL = 50     # first N e-SNLI validation rows
N_ALPACA = 100       # max Alpaca rows

print("Loading data...")
loader = TeacherDataLoader()

# e-SNLI: clamp to the split size so select() never requests missing indices.
esnli_full = loader.load_esnli()
esnli_train_split = esnli_full['train']
esnli_val_split = esnli_full['validation']
esnli_train_small = esnli_train_split.select(range(min(N_ESNLI_TRAIN, len(esnli_train_split))))
esnli_val_small = esnli_val_split.select(range(min(N_ESNLI_VAL, len(esnli_val_split))))

# Alpaca (small subset)
alpaca_small = loader.load_alpaca(max_samples=N_ALPACA)

print(f"✓ Loaded {len(esnli_train_small)} e-SNLI train samples")
print(f"✓ Loaded {len(esnli_val_small)} e-SNLI val samples")
print(f"✓ Loaded {len(alpaca_small)} Alpaca samples")

In [None]:
# Wrap the raw subsets in PyTorch datasets. A fresh preprocessor is built here.
# It does not pass `padding`, so it uses PreprocessConfig's default rather than
# the "max_length" padding of the earlier tokenizer test. Later cells should
# read limits from `preprocessor.config`, not from the earlier `config`.
print("\n" + "=" * 70)
print("CREATING PYTORCH DATASETS")
print("=" * 70)

preprocess_config = PreprocessConfig(
    model_name="google/flan-t5-base",
    max_source_length=256,
    max_target_length=128
)
preprocessor = ReasoningPreprocessor(preprocess_config)

# NOTE(review): tokenized samples are cached on disk under these directories.
# Presumably the cache must be cleared when the subset sizes or the
# preprocessing config change; confirm how the dataset classes key their cache.
esnli_train_dataset = ESNLIDataset(
    esnli_train_small,
    preprocessor,
    cache_dir="../data/cache/esnli_train",
    use_cache=True
)

esnli_val_dataset = ESNLIDataset(
    esnli_val_small,
    preprocessor,
    cache_dir="../data/cache/esnli_val",
    use_cache=True
)

alpaca_dataset = AlpacaDataset(
    alpaca_small,
    preprocessor,
    cache_dir="../data/cache/alpaca",
    use_cache=True
)

print(f"\n✓ e-SNLI train dataset: {len(esnli_train_dataset)} samples")
print(f"✓ e-SNLI val dataset: {len(esnli_val_dataset)} samples")
print(f"✓ Alpaca dataset: {len(alpaca_dataset)} samples")

In [None]:
# Pull one item through ESNLIDataset indexing, check its tensor shapes, decode
# it back to text, and compare it with the untokenized record.
print("\n" + "=" * 70)
print("TESTING DATASET INDEXING")
print("=" * 70)

sample_idx = 0
esnli_sample = esnli_train_dataset[sample_idx]

print(f"\nSample keys: {esnli_sample.keys()}")
print("\nShapes:")
for field_name, tensor in esnli_sample.items():
    print(f"  {field_name}: {tensor.shape}")

print("\n--- Decoded Sample ---")
decoded_input = preprocessor.decode_prediction(esnli_sample['input_ids'])
print(f"Input: {decoded_input}")

# Label positions set to -100 become pad ids (on a copy) so they can be decoded.
pad_id = preprocessor.tokenizer.pad_token_id
sample_labels = esnli_sample['labels']
labels_for_decode = sample_labels.masked_fill(sample_labels == -100, pad_id)
decoded_target = preprocessor.decode_prediction(labels_for_decode)
print(f"Target: {decoded_target}")

# Untokenized record for side-by-side comparison
raw = esnli_train_dataset.get_raw_sample(sample_idx)
print("\n--- Raw Sample ---")
for field in ('premise', 'hypothesis', 'label'):
    print(f"{field.capitalize()}: {raw[field]}")

## 4. Test Multi-Task Dataset

In [None]:
# For each sampling strategy, draw a few items and infer which source dataset
# each came from by its decoded prompt text.
print("=" * 70)
print("MULTI-TASK DATASET TESTING")
print("=" * 70)

strategies = ['balanced', 'proportional', 'esnli_only', 'alpaca_only']
n_probe = 20  # items drawn per strategy

for strategy in strategies:
    print(f"\n--- Strategy: {strategy} ---")

    multitask_dataset = MultiTaskDataset(
        esnli_dataset=esnli_train_dataset,
        alpaca_dataset=alpaca_dataset,
        sampling_strategy=strategy
    )

    sample_counts = {'esnli': 0, 'alpaca': 0}
    for i in range(n_probe):
        decoded = preprocessor.decode_prediction(multitask_dataset[i]['input_ids'])
        # Assumes NLI prompts contain "nli premise:"; anything else is counted
        # as Alpaca. Update this check if the prompt template changes.
        source_key = 'esnli' if 'nli premise:' in decoded else 'alpaca'
        sample_counts[source_key] += 1

    for key, display_name in (('esnli', 'e-SNLI'), ('alpaca', 'Alpaca')):
        count = sample_counts[key]
        print(f"  {display_name} samples: {count}/{n_probe} ({count / n_probe * 100:.0f}%)")

In [None]:
from collections import Counter

# Draw many items from a 'balanced' MultiTaskDataset and plot how often each
# source appears. A balanced strategy is expected to give roughly equal counts;
# confirm that against MultiTaskDataset's implementation.
print("\n" + "=" * 70)
print("SAMPLING DISTRIBUTION ANALYSIS")
print("=" * 70)

multitask_balanced = MultiTaskDataset(
    esnli_dataset=esnli_train_dataset,
    alpaca_dataset=alpaca_dataset,
    sampling_strategy='balanced'
)

n_samples = 500
sample_sources = []
for i in tqdm(range(n_samples), desc="Sampling"):
    decoded = preprocessor.decode_prediction(multitask_balanced[i]['input_ids'])
    # Same prompt-prefix heuristic as the strategy test above.
    sample_sources.append('e-SNLI' if 'nli premise:' in decoded else 'Alpaca')

counts = Counter(sample_sources)
# Fixed category order keeps each source's colour stable across runs. It also
# still draws a zero-height bar when one source never appears.
categories = ['e-SNLI', 'Alpaca']
heights = [counts.get(name, 0) for name in categories]

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(categories, heights, color=['#3498db', '#e67e22'])
ax.set_xlabel('Dataset Source')
ax.set_ylabel('Number of Samples')
ax.set_title(f'Multi-Task Sampling Distribution (n={n_samples}, strategy=balanced)')

label_offset = 0.01 * n_samples  # gap between bar top and its value label
for x, height in enumerate(heights):
    ax.text(x, height + label_offset, f'{height}\n({height / n_samples * 100:.1f}%)',
            ha='center', va='bottom', fontsize=12, fontweight='bold')
# Leave headroom so the labels above the tallest bar are not clipped.
ax.set_ylim(0, max(heights) * 1.15 + label_offset)
fig.tight_layout()
plt.show()

## 5. Test DataLoader Creation

In [None]:
# Wrap the e-SNLI datasets in DataLoaders and report their sizes.
print("=" * 70)
print("CREATING DATALOADERS")
print("=" * 70)

batch_size = 8

train_loader, val_loader = create_dataloaders(
    train_dataset=esnli_train_dataset,
    val_dataset=esnli_val_dataset,
    batch_size=batch_size,
    num_workers=0,  # Use 0 for notebook compatibility
    pad_token_id=preprocessor.tokenizer.pad_token_id,
    shuffle_train=True
)

print(f"\n✓ Train DataLoader: {len(train_loader)} batches")
print(f"✓ Val DataLoader: {len(val_loader)} batches")
print(f"\nBatch size: {batch_size}")
# Report dataset sizes directly. batches * batch_size over-counts whenever the
# final batch is partial, e.g. 100 samples / 8 -> 13 batches but not 104
# samples (assuming create_dataloaders keeps the partial batch).
print(f"Total train samples: {len(train_loader.dataset)}")
print(f"Total val samples: {len(val_loader.dataset)}")

In [None]:
# Pull the first training batch and check that its tensor dimensions match the
# DataLoader and preprocessing configuration.
print("\n" + "=" * 70)
print("TESTING BATCH ITERATION")
print("=" * 70)

batch = next(iter(train_loader))

print(f"\nBatch keys: {batch.keys()}")
print("\nBatch shapes:")
for key, value in batch.items():
    print(f"  {key}: {value.shape}")

# The first batch is full unless the whole dataset is smaller than one batch.
expected_batch = min(batch_size, len(train_loader.dataset))
assert batch['input_ids'].shape[0] == expected_batch, "Batch size mismatch!"
assert batch['labels'].shape[0] == batch['input_ids'].shape[0], "Batch size mismatch!"

# Check against the preprocessor that actually built these datasets, not the
# earlier standalone `config`. Sequences never exceed the truncation limit, and
# they equal it exactly only under fixed "max_length" padding.
max_source_length = preprocessor.config.max_source_length
seq_len = batch['input_ids'].shape[1]
assert seq_len <= max_source_length, "Sequence length mismatch!"
if preprocessor.config.padding == "max_length":
    assert seq_len == max_source_length, "Sequence length mismatch!"
print("\n✓ Batch dimensions correct!")

# Check device and dtype
print(f"\nTensor device: {batch['input_ids'].device}")
print(f"Tensor dtype: {batch['input_ids'].dtype}")

In [None]:
# Decode the first few inputs and targets of the batch for a visual check.
print("\n" + "=" * 70)
print("BATCH SAMPLES PREVIEW")
print("=" * 70)

n_display = 3
# Bound by the batch's real length: a final or small batch can hold fewer than
# `batch_size` rows.
n_in_batch = batch['input_ids'].shape[0]
pad_id = preprocessor.tokenizer.pad_token_id

for i in range(min(n_display, n_in_batch)):
    print(f"\n--- Sample {i+1} ---")

    input_text = preprocessor.decode_prediction(batch['input_ids'][i])
    print(f"Input: {input_text}")

    # Label positions set to -100 become pad ids (on a copy) so they can be decoded.
    row_labels = batch['labels'][i]
    labels = row_labels.masked_fill(row_labels == -100, pad_id)
    target_text = preprocessor.decode_prediction(labels)
    print(f"Target: {target_text}")

## 6. Performance Analysis

In [None]:
# Time repeated passes over the same dataset items. The first pass is compared
# with later ones to estimate the benefit of caching.
import time

print("=" * 70)
print("PREPROCESSING PERFORMANCE")
print("=" * 70)

n_iterations = 3  # must be >= 2: pass 1 is compared against later passes
n_timed = min(50, len(esnli_train_dataset))  # items fetched per pass

print("\n--- With Caching ---")
times_cached = []

for iteration in range(n_iterations):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for i in range(n_timed):
        esnli_train_dataset[i]
    elapsed = time.perf_counter() - start
    times_cached.append(elapsed)
    rate = n_timed / elapsed if elapsed > 0 else float('inf')
    print(f"Iteration {iteration+1}: {elapsed:.3f}s ({rate:.1f} samples/sec)")

print(f"\nAverage: {np.mean(times_cached):.3f}s")
speedup = times_cached[0] / times_cached[1] if times_cached[1] > 0 else float('inf')
print(f"Speedup (iter 2 vs iter 1): {speedup:.2f}x")

In [None]:
# Measure the real (non-padding) input and target lengths of every e-SNLI
# training item and compare them with the configured truncation limits.
print("\n" + "=" * 70)
print("TOKEN LENGTH DISTRIBUTION")
print("=" * 70)


def _plot_length_hist(ax, lengths, color, xlabel, title, limit):
    """Plot a token-length histogram with a dashed mean line and dotted max-length line."""
    mean_len = np.mean(lengths)
    ax.hist(lengths, bins=30, color=color, alpha=0.7, edgecolor='black')
    ax.axvline(mean_len, color='red', linestyle='--', label=f'Mean: {mean_len:.1f}')
    ax.axvline(limit, color='gray', linestyle=':', label=f'Max: {limit}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.legend()


# Limits of the preprocessor that built these datasets, not the earlier `config`.
max_source = preprocessor.config.max_source_length
max_target = preprocessor.config.max_target_length

input_lengths = []
label_lengths = []

for i in tqdm(range(len(esnli_train_dataset)), desc="Analyzing lengths"):
    sample = esnli_train_dataset[i]
    # Non-padding tokens: attended inputs, and labels not masked with -100.
    input_lengths.append(sample['attention_mask'].sum().item())
    label_lengths.append((sample['labels'] != -100).sum().item())

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
_plot_length_hist(axes[0], input_lengths, '#3498db', 'Input Length (tokens)',
                  'Input Token Length Distribution', max_source)
_plot_length_hist(axes[1], label_lengths, '#e74c3c', 'Target Length (tokens)',
                  'Target Token Length Distribution', max_target)
fig.tight_layout()
plt.show()

print(f"\nInput lengths - Mean: {np.mean(input_lengths):.1f}, Std: {np.std(input_lengths):.1f}")
print(f"Target lengths - Mean: {np.mean(label_lengths):.1f}, Std: {np.std(label_lengths):.1f}")
print(f"\nMax configured lengths: Input={max_source}, Target={max_target}")

# Items that reach the limit were probably truncated.
n_input_at_limit = sum(length >= max_source for length in input_lengths)
n_target_at_limit = sum(length >= max_target for length in label_lengths)
print(f"Samples at length limit (possibly truncated): "
      f"Input={n_input_at_limit}, Target={n_target_at_limit}")

## 7. Test Prediction Parsing (Label and Explanation Extraction)

In [None]:
# Check that labels and explanations can be parsed back out of generated text,
# including outputs that lack the "explanation:" marker.
print("=" * 70)
print("TESTING PREDICTION PARSING")
print("=" * 70)

test_predictions = [
    "entailment explanation: The person is definitely on a horse.",
    "neutral explanation: We cannot determine if they are training.",
    "contradiction",
    "entailment This clearly follows from the premise."
]

for pred in test_predictions:
    print(f"\nPrediction: {pred}")

    label = preprocessor.extract_label_from_prediction(pred)
    explanation = preprocessor.extract_explanation_from_prediction(pred)

    print(f"  → Label: {label}")
    print(f"  → Explanation: {explanation}")

## 8. Summary and Validation

In [None]:
# Summarize the pipeline configuration and measurements gathered above.
# Under Restart & Run All, execution stops at the first failing cell, so
# reaching this cell means every earlier cell (including its asserts) ran
# without raising.
print("\n" + "=" * 70)
print("PREPROCESSING PIPELINE SUMMARY")
print("=" * 70)

print("\n✅ TaskFormatter: PASSED")
print("  • NLI tasks formatted correctly")
print("  • Instruction tasks formatted correctly")

print("\n✅ ReasoningPreprocessor: PASSED")
print(f"  • Tokenizer loaded: {preprocessor.config.model_name}")
print(f"  • Max source length: {preprocessor.config.max_source_length}")
print(f"  • Max target length: {preprocessor.config.max_target_length}")
print("  • Tokenization working correctly")
print("  • Padding/truncation working")

print("\n✅ PyTorch Datasets: PASSED")
print(f"  • e-SNLI train: {len(esnli_train_dataset)} samples")
print(f"  • e-SNLI val: {len(esnli_val_dataset)} samples")
print(f"  • Alpaca: {len(alpaca_dataset)} samples")
print("  • Caching working correctly")
print("  • Multi-task dataset working")

print("\n✅ DataLoaders: PASSED")
print(f"  • Train batches: {len(train_loader)}")
print(f"  • Val batches: {len(val_loader)}")
print(f"  • Batch size: {batch_size}")
print("  • Collation working correctly")

print("\n✅ Performance:")
print(f"  • Average preprocessing time: {np.mean(times_cached[1:]):.3f}s per 50 samples")
print(f"  • Caching speedup: ~{times_cached[0]/np.mean(times_cached[1:]):.1f}x")
print(f"  • Average input tokens: {np.mean(input_lengths):.1f}")
print(f"  • Average target tokens: {np.mean(label_lengths):.1f}")

print("\n" + "=" * 70)
print("🎉 ALL TESTS PASSED - READY FOR MODEL TRAINING!")
print("=" * 70)
