# 02 - Prepare Dataset

This notebook provides an interactive interface for dataset preparation:
- Synthetic data generation
- Data augmentation with visualization
- Interactive tuning of augmentation parameters
- Class balancing
- Train/val/test splitting


In [None]:
# Import libraries
import json
import os
import random
import sys
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Add src to path
# The location is resolved relative to the current working directory, so the
# project modules imported below are found only when the kernel's working
# directory is a sibling of `src`.
src_dir = str(Path.cwd().parent / 'src')
if src_dir not in sys.path:  # re-running this cell must not stack duplicate entries
    sys.path.insert(0, src_dir)

from data_processor import DFGDatasetProcessor, load_config, load_dfg_mapping
from synthetic_data_generator import SyntheticPaperGenerator, create_training_ready_dataset
from data_augmentation import TextAugmenter, DatasetBalancer, create_augmented_dataset

# Shared plot appearance: seaborn grid theme plus a default figure size.
# Cells that pass their own `figsize` to plt.subplots override the default.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

print("‚úì Libraries imported")


In [None]:
# Configuration
# All paths are relative to the notebook's working directory.
CONFIG_PATH = '../config.yaml'
DFG_MAPPING_PATH = '../data/dfg_mapping.json'
# NOTE(review): unlike the two paths above, this one re-enters a `dfg-classifier`
# folder. Confirm that is intended and not a nested-directory slip for '../data/processed'.
OUTPUT_DIR = '../dfg-classifier/data/processed'

# Seed the global Python and NumPy RNGs so repeated "Restart & Run All" runs start
# from the same random state. Libraries that keep their own generator are not
# covered by this and would need to be seeded separately.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Load configuration and DFG mapping
config = load_config(CONFIG_PATH)
dfg_mapping = load_dfg_mapping(DFG_MAPPING_PATH)

print("‚úì Configuration loaded")
print(f"  Model: {config.get('model', {}).get('name', 'N/A')}")
print(f"  Output directory: {OUTPUT_DIR}")

# Create output directory up front so the save step at the end cannot fail on a missing folder
os.makedirs(OUTPUT_DIR, exist_ok=True)


## Step 1: Generate Synthetic Data

Adjust `SAMPLES_PER_CATEGORY` to control the size of the generated dataset.


In [None]:
# Synthetic data generation settings
SAMPLES_PER_CATEGORY = 50  # value passed as samples_per_category; try 50, 100 or 200

print("üîÑ Generating synthetic papers...")
generator = SyntheticPaperGenerator(dfg_mapping)

# output_dir=None: intermediate results are not saved
synthetic_papers = generator.generate_dataset(
    samples_per_category=SAMPLES_PER_CATEGORY,
    output_dir=None,
)
print(f"‚úì Generated {len(synthetic_papers)} synthetic papers")

# Preview the first generated paper (abstract cut to 200 characters)
if synthetic_papers:
    print("\nüìù Sample synthetic paper:")
    sample = synthetic_papers[0]
    abstract_preview = sample['abstract'][:200] if False else None  # placeholder removed below
    print(f"  Category: {sample['category']}")
    print(f"  Title: {sample['title']}")
    print(f"  Abstract: {sample['abstract'][:200]}...")


In [None]:
# Bar chart of how many synthetic papers each category received
if synthetic_papers:
    category_counts = Counter(paper['category'] for paper in synthetic_papers)

    # most_common() yields (category, count) pairs from most to least frequent
    labels, counts = zip(*category_counts.most_common())
    positions = range(len(labels))
    total_samples = len(synthetic_papers)

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(positions, counts, color='steelblue')

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
    ax.set_xlabel('Category', fontsize=12)
    ax.set_ylabel('Number of Samples', fontsize=12)
    ax.set_title(
        f'Synthetic Data Distribution ({total_samples} total samples)',
        fontsize=14,
        fontweight='bold',
    )
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Numeric summary to go with the chart
    print(f"  Categories: {len(category_counts)}")
    print(f"  Min samples per category: {min(counts)}")
    print(f"  Max samples per category: {max(counts)}")
    print(f"  Mean samples per category: {np.mean(counts):.1f}")


## Step 2: Process and Tokenize Data

Convert synthetic papers to tokenized format for training.


In [None]:
# Initialize data processor
processor = DFGDatasetProcessor(config, dfg_mapping)

# Turn every synthetic paper into one training record holding the token ids,
# the numeric label and the raw text fields.
print("üîÑ Processing and tokenizing data...")
dataset = []
for paper in synthetic_papers:
    # Resolve the label first: papers whose category has no id are skipped,
    # so no tokenization work is spent on them.
    label_id = processor.label_to_id.get(paper['category'], -1)
    if label_id == -1:
        print(f"‚ö†Ô∏è  Warning: Unknown category: {paper['category']}")
        continue

    # Title and abstract are joined into a single input string
    combined_text = f"{paper['title']} [SEP] {paper['abstract']}"
    tokenized = processor.tokenize_text(combined_text)

    dataset.append({
        'input_ids': tokenized['input_ids'],
        'attention_mask': tokenized['attention_mask'],
        'labels': label_id,
        'filename': paper['id'],
        'title': paper['title'],
        'abstract': paper['abstract'],
        'label': paper['category'],
        'combined_text': combined_text
    })

# Report the longest sequence over all samples (0 for an empty dataset) rather
# than the length of the first sample only.
max_seq_len = max((len(item['input_ids']) for item in dataset), default=0)

print(f"‚úì Processed {len(dataset)} samples")
print(f"  Tokenizer: {config.get('model', {}).get('name', 'N/A')}")
print(f"  Max sequence length: {max_seq_len} tokens")


## Step 3: Data Augmentation

Adjust augmentation parameters:
- `USE_AUGMENTATION`: Enable/disable augmentation
- `AUGMENTATION_FACTOR`: Dataset size multiplier (1 = no augmentation, 2 = double the dataset, 3 = triple)
- `AUGMENTATION_PROB`: Probability of applying augmentation to each sample


In [None]:
# Augmentation parameters
USE_AUGMENTATION = True  # False skips this step entirely
AUGMENTATION_FACTOR = 2  # Try 1 (no augmentation), 2 (double), 3 (triple)
AUGMENTATION_PROB = 0.3  # Passed to TextAugmenter as augmentation_prob

# Recorded whether or not augmentation runs; the visualization cell slices with it
original_dataset_size = len(dataset)

if not USE_AUGMENTATION:
    print("‚è≠Ô∏è  Skipping augmentation")
else:
    print("üîÑ Applying data augmentation...")

    # Use the configured model name, falling back to SciBERT when none is set
    augmenter_model = config.get('model', {}).get('name', 'allenai/scibert_scivocab_uncased')
    augmenter = TextAugmenter(
        model_name=augmenter_model,
        augmentation_prob=AUGMENTATION_PROB,
    )

    # balance_classes=False: balancing is handled as its own step later on
    augmented_dataset = create_augmented_dataset(
        dataset,
        augmenter,
        augmentation_factor=AUGMENTATION_FACTOR,
        balance_classes=False,
    )

    # `dataset` is rebound to the augmented data, so re-running this cell without
    # re-running the tokenization cell first would augment the data a second time.
    dataset = augmented_dataset

    print(f"‚úì Augmented dataset: {original_dataset_size} ‚Üí {len(dataset)} samples")
    print(f"  Augmentation factor: {AUGMENTATION_FACTOR}x")


In [None]:
# Visualize augmentation results: per-class counts before and after, side by side
if USE_AUGMENTATION:
    # NOTE(review): assumes create_augmented_dataset keeps the original samples, in
    # order, as the first `original_dataset_size` items of its output. Confirm this;
    # if the output is shuffled, the "original" panel shows an arbitrary slice.
    original_labels = [item['label'] for item in dataset[:original_dataset_size]]
    augmented_labels = [item['label'] for item in dataset]

    original_counts = Counter(original_labels)
    augmented_counts = Counter(augmented_labels)

    # Both panels share one x-axis: the labels present in the original slice
    common_labels = sorted(set(original_labels))
    positions = range(len(common_labels))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    panels = [
        (axes[0], original_counts, 'steelblue',
         f'Original Dataset ({len(original_labels)} samples)'),
        (axes[1], augmented_counts, 'coral',
         f'Augmented Dataset ({len(augmented_labels)} samples)'),
    ]

    for panel_ax, label_counts, bar_color, panel_title in panels:
        # Index the counter per label rather than unpacking zip(*...), which
        # raises ValueError when there are no labels to plot.
        heights = [label_counts[label] for label in common_labels]
        panel_ax.bar(positions, heights, color=bar_color)
        panel_ax.set_xticks(positions)
        panel_ax.set_xticklabels(common_labels, rotation=45, ha='right', fontsize=8)
        panel_ax.set_title(panel_title, fontsize=13, fontweight='bold')
        panel_ax.set_ylabel('Count', fontsize=11)
        panel_ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Original: {len(original_labels)} samples, {len(original_counts)} classes")
    print(f"Augmented: {len(augmented_labels)} samples, {len(augmented_counts)} classes")


## Step 4: Class Balancing

Balance the dataset to ensure equal representation across classes.


In [None]:
# Balancing parameters
BALANCE_CLASSES = True  # Set to False to skip balancing
BALANCING_STRATEGY = 'oversample'  # Options: 'oversample', 'undersample'

if BALANCE_CLASSES:
    print("üîÑ Balancing classes...")

    # Per-class counts before balancing
    before_counts = Counter(item['label'] for item in dataset)
    print(f"  Before: {len(dataset)} samples, {len(before_counts)} classes")

    # Balance dataset; `dataset` is rebound to the balanced result
    balancer = DatasetBalancer(strategy=BALANCING_STRATEGY)
    balanced_dataset = balancer.balance_dataset(dataset)
    dataset = balanced_dataset

    # Per-class counts after balancing
    after_counts = Counter(item['label'] for item in dataset)
    print(f"  After: {len(dataset)} samples, {len(after_counts)} classes")

    # Plot over the union of labels so a class present on only one side still gets a bar
    common_labels = sorted(set(before_counts) | set(after_counts))

    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(common_labels))
    width = 0.35

    before_vals = [before_counts.get(l, 0) for l in common_labels]
    after_vals = [after_counts.get(l, 0) for l in common_labels]

    ax.bar(x - width/2, before_vals, width, label='Before Balancing', color='steelblue', alpha=0.7)
    ax.bar(x + width/2, after_vals, width, label='After Balancing', color='coral', alpha=0.7)

    ax.set_xlabel('Category', fontsize=12)
    ax.set_ylabel('Number of Samples', fontsize=12)
    ax.set_title('Class Distribution: Before vs After Balancing', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(common_labels, rotation=45, ha='right', fontsize=8)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Strategy-neutral summary: the smallest class does not grow under 'undersample'
    # or when the data is already balanced, so report the range instead of claiming
    # an increase. default=0 keeps an empty dataset from raising.
    before_min, before_max = min(before_vals, default=0), max(before_vals, default=0)
    after_min, after_max = min(after_vals, default=0), max(after_vals, default=0)
    print(f"  Samples per class (min / max): before {before_min} / {before_max}, "
          f"after {after_min} / {after_max}")
else:
    print("‚è≠Ô∏è  Skipping class balancing")


## Step 5: Split Dataset

Split the dataset into train/validation/test sets.


In [None]:
# Split dataset
print("üîÑ Splitting dataset...")
train_data, val_data, test_data = processor.split_dataset(dataset)

# Shares are relative to the full dataset; an empty dataset reports 0.0%
# instead of raising ZeroDivisionError.
total_samples = len(dataset)

print("‚úì Dataset split:")
for split_name, split_items in (('Train', train_data), ('Validation', val_data), ('Test', test_data)):
    share = len(split_items) / total_samples * 100 if total_samples else 0.0
    print(f"  {split_name}: {len(split_items)} samples ({share:.1f}%)")

# Visualize split distribution
# `train_labels` is also read by the save cell for the stored class distribution.
train_labels = Counter(item['label'] for item in train_data)
val_labels = Counter(item['label'] for item in val_data)
test_labels = Counter(item['label'] for item in test_data)

# Union of labels, so a class missing from one split shows up as a zero-height bar
common_labels = sorted(set(train_labels) | set(val_labels) | set(test_labels))

fig, ax = plt.subplots(figsize=(14, 6))
x = np.arange(len(common_labels))
width = 0.25

train_vals = [train_labels.get(l, 0) for l in common_labels]
val_vals = [val_labels.get(l, 0) for l in common_labels]
test_vals = [test_labels.get(l, 0) for l in common_labels]

# Three bars per class, placed left / centre / right of each tick
ax.bar(x - width, train_vals, width, label='Train', color='steelblue', alpha=0.8)
ax.bar(x, val_vals, width, label='Validation', color='orange', alpha=0.8)
ax.bar(x + width, test_vals, width, label='Test', color='green', alpha=0.8)

ax.set_xlabel('Category', fontsize=12)
ax.set_ylabel('Number of Samples', fontsize=12)
ax.set_title('Class Distribution Across Splits', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(common_labels, rotation=45, ha='right', fontsize=8)
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()


## Step 6: Save Processed Datasets

Save the processed datasets to disk for training.


In [None]:
# Write the three splits to OUTPUT_DIR, then a JSON summary of this run
print("üíæ Saving processed datasets...")

# One file per split, saved in train / val / test order
for split_file, split_data in (
    ('train.json', train_data),
    ('val.json', val_data),
    ('test.json', test_data),
):
    processor.save_processed_dataset(split_data, os.path.join(OUTPUT_DIR, split_file))

print(f"‚úì Datasets saved to: {OUTPUT_DIR}")

# Summary of sizes and of the settings chosen in the earlier cells
num_classes = len(processor.label_to_id)
effective_factor = AUGMENTATION_FACTOR if USE_AUGMENTATION else 1  # 1 when augmentation was skipped
stats = {
    'total_samples': len(dataset),
    'train_samples': len(train_data),
    'val_samples': len(val_data),
    'test_samples': len(test_data),
    'num_classes': num_classes,
    'samples_per_category': SAMPLES_PER_CATEGORY,
    'augmentation_factor': effective_factor,
    'balanced': BALANCE_CLASSES,
    # per-class counts of the training split only
    'class_distribution': dict(train_labels),
}

stats_file = os.path.join(OUTPUT_DIR, 'dataset_stats.json')
with open(stats_file, 'w', encoding='utf-8') as f:
    # ensure_ascii=False writes non-ASCII class names as-is instead of \u escapes
    json.dump(stats, f, indent=2, ensure_ascii=False)

print(f"‚úì Statistics saved to: {stats_file}")
print("\n‚úÖ Dataset preparation complete!")
print("\nüìä Final Statistics:")
print(f"  Total samples: {stats['total_samples']}")
print(f"  Train: {stats['train_samples']}")
print(f"  Validation: {stats['val_samples']}")
print(f"  Test: {stats['test_samples']}")
print(f"  Classes: {num_classes}")
