In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np

from SpectraDataset import SpectraDataset
from PipelineRunner import PipelineRunner

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, ShuffleSplit
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.svm import SVC
from sklearn.cluster import KMeans

from nirs4all.presets.ref_models import decon
from nirs4all.transformations import (
    Gaussian as GS,
    Rotate_Translate as RT,
    SavitzkyGolay as SG,
    StandardNormalVariate as SNV,
)

print("="*80)
print("TESTING ENHANCED DATASET OPERATIONS")
print("="*80)

# Create initial dataset with two sources
np.random.seed(42)
source_1 = np.random.rand(100, 1000) * 4 + 4
source_2 = np.random.rand(100, 500) * 40 + 6
targets = np.random.randint(0, 8, size=100)

dataset = SpectraDataset(task_type="classification")
dataset.add_data(features=[source_1, source_2], targets=targets)
print("Initial Dataset:", dataset)

# Test 1: Sample Augmentation
print("\n1. SAMPLE AUGMENTATION")
print("-" * 40)
new_samples = dataset.sample_augmentation(n_copies=1, processing_tag="sample_aug")
print(f"Created {len(new_samples)} new samples")
print("After Sample Augmentation:", dataset)

# Test 2: Feature Augmentation
print("\n2. FEATURE AUGMENTATION")
print("-" * 40)
rows_before = len(dataset.indices)
dataset.feature_augmentation(processing_tag="feature_aug")
rows_after = len(dataset.indices)
print(f"Rows: {rows_before} -> {rows_after} (ratio: {rows_after/rows_before:.1f})")
print("After Feature Augmentation:", dataset)

# Test 3: Branching
print("\n3. BRANCHING")
print("-" * 40)
train_rows_before = len(dataset.indices.filter(dataset.indices['partition'] == 'train'))
dataset.branch_dataset(n_branches=3)
train_rows_after = len(dataset.indices.filter(dataset.indices['partition'] == 'train'))
print(f"Train rows: {train_rows_before} -> {train_rows_after} (ratio: {train_rows_after/train_rows_before:.1f})")
print("After Branching:", dataset)

# Test 4: Feature extraction in different formats
print("\n4. FEATURE EXTRACTION")
print("-" * 40)

# 2D features for specific branch and processing
features_branch0 = dataset.get_features_2d(filters={'branch': 0, 'processing': 'raw'})
print(f"Branch 0, raw processing: {features_branch0.shape}")

features_all_branches = dataset.get_features_2d(filters={'partition': 'train'})
print(f"All branches, all processing: {features_all_branches.shape}")

# 3D features
features_3d = dataset.get_features_3d(filters={'branch': 1, 'processing': 'feature_aug'})
print(f"3D features (branch 1, feature_aug): {features_3d.shape}")

print("\n" + "="*80)
print("ENHANCED OPERATIONS TEST COMPLETED!")
print("="*80)

# Now test with a simple pipeline
pipeline_config = {
    "pipeline": [
        MinMaxScaler,
        {"feature_augmentation": [None, SG]},
        {"sample_augmentation": [RT]},
        ShuffleSplit,
        {"cluster": KMeans(n_clusters=5, random_state=42)},
        RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=42),
        {
            "branch": [
                [
                    RobustScaler(),
                    {
                        "model": RandomForestClassifier(random_state=42, max_depth=10),
                        "y_pipeline": StandardScaler,
                    },
                ],
                {
                    "model": decon,
                    "y_pipeline": StandardScaler(),
                },
            ]
        },
    ]
}

print("\nRunning simple pipeline with enhanced dataset operations...")
runner = PipelineRunner()
result_dataset, fitted_pipeline, history, fitted_tree = runner.run(pipeline_config, dataset)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
TESTING ENHANCED DATASET OPERATIONS
Initial Dataset: 
Source 0: 100x1000 Mean: 6.00, Std: 0.12
Source 1: 100x500 Mean: 26.07, Std: 1.14

Samples: 100, Rows: 100, Features: 2
Partitions: ['train']
  train: 100 samples
Groups: [0] - Branches: [0] - Processing: ['raw']
Targets: {'classes': [0, 1, 2, 3, 4, 5, 6, 7], 'n_samples': 100}
Results: {'n_predictions': 0, 'models': [], 'partitions': [], 'folds': []}


1. SAMPLE AUGMENTATION
----------------------------------------
Created 100 new samples
After Sample Augmentation: 
Source 0: 200x1000 Mean: 6.00, Std: 0.12
Source 1: 200x500 Mean: 26.07, Std: 1.14

Samples: 200, Rows: 200, Features: 2
Partitions: ['train']
  train: 200 samples
Groups: [0] - Branches: [0] - Processing: ['raw', 'sample_aug']
Targets: {'classes': [0, 1, 2, 3, 4, 5, 6, 7], 'n_samples': 200}
Results: {'n_predictions': 0, 'models': [], 'partitions': [], 'folds': []}


2. FEATURE AUGMENT

In [None]:
class DatasetMonitor:
    def __init__(self):
        self.snapshots = []

    def capture(self, dataset, stage_name):
        # Updated to work with new SpectraDataset structure
        snapshot = {
            'stage': stage_name,
            'features_shape': f"{len(dataset.features.sources)} sources" if dataset.features else "No features",
            'total_rows': len(dataset.indices),
            'samples': dataset._next_sample,
            'partitions': dataset.indices['partition'].unique().to_list(),
            'branches': dataset.indices['branch'].unique().to_list(),
            'processing': dataset.indices['processing'].unique().to_list(),
            'target_info': dataset.target_manager.get_info()
        }
        self.snapshots.append(snapshot)

        print(f"{stage_name:25} | Rows: {snapshot['total_rows']:4d} | Samples: {snapshot['samples']:3d} | " +
              f"Features: {snapshot['features_shape']:15} | Partitions: {snapshot['partitions']} | " +
              f"Branches: {snapshot['branches']} | Processing: {snapshot['processing']}")

    def summary(self):
        print("\n" + "="*120)
        print("DATASET EVOLUTION SUMMARY")
        print("="*120)
        for snap in self.snapshots:
            print(f"{snap['stage']:25} | Rows: {snap['total_rows']:4d} | Samples: {snap['samples']:3d} | " +
                  f"Features: {snap['features_shape']:15} | Partitions: {snap['partitions']} | " +
                  f"Branches: {snap['branches']} | Processing: {snap['processing']}")

monitor = DatasetMonitor()

# Test our enhanced operations step by step
print("\n" + "="*120)
print("MONITORING DATASET OPERATIONS")
print("="*120)

# Start with a fresh dataset
test_dataset = SpectraDataset(task_type="classification")
test_dataset.add_data(
    features=[np.random.rand(50, 200), np.random.rand(50, 100)],
    targets=np.random.randint(0, 3, size=50),
    partition="train"
)
monitor.capture(test_dataset, "Initial Dataset")

# Add some test data
test_dataset.add_data(
    features=[np.random.rand(20, 200), np.random.rand(20, 100)],
    targets=np.random.randint(0, 3, size=20),
    partition="test"
)
monitor.capture(test_dataset, "Added Test Set")

# Sample augmentation
test_dataset.sample_augmentation(n_copies=1, processing_tag="augmented")
monitor.capture(test_dataset, "Sample Augmentation")

# Feature augmentation
test_dataset.feature_augmentation(processing_tag="feat_aug")
monitor.capture(test_dataset, "Feature Augmentation")

# Branching
test_dataset.branch_dataset(n_branches=2)
monitor.capture(test_dataset, "Branching (2 branches)")

monitor.summary()


MONITORING DATASET OPERATIONS
Initial Dataset           | Rows:   50 | Samples:  50 | Features: 2 sources       | Partitions: ['train'] | Branches: [0] | Processing: ['raw']
Added Test Set            | Rows:   70 | Samples:  70 | Features: 2 sources       | Partitions: ['test', 'train'] | Branches: [0] | Processing: ['raw']
Sample Augmentation       | Rows:  120 | Samples: 120 | Features: 2 sources       | Partitions: ['test', 'train'] | Branches: [0] | Processing: ['augmented', 'raw']
Feature Augmentation      | Rows:  240 | Samples: 120 | Features: 2 sources       | Partitions: ['train', 'test'] | Branches: [0] | Processing: ['raw', 'feat_aug', 'augmented']
Branching (2 branches)    | Rows:  440 | Samples: 120 | Features: 2 sources       | Partitions: ['test', 'train'] | Branches: [0, 1] | Processing: ['augmented', 'feat_aug', 'raw']

DATASET EVOLUTION SUMMARY
Initial Dataset           | Rows:   50 | Samples:  50 | Features: 2 sources       | Partitions: ['train'] | Branches: [0] | 

In [None]:
# Test 3: Real-world Scenario Simulation with Fresh Dataset
print("="*80)
print("REAL-WORLD SCENARIO SIMULATION")
print("="*80)

def create_fresh_test_dataset():
    """Create a fresh dataset for testing"""
    fresh_dataset = SpectraDataset()

    # Add some fresh training data
    features = [
        np.random.randn(50, 100),  # 50 samples, 100 features (source 1)
        np.random.randn(50, 80)    # 50 samples, 80 features (source 2)
    ]
    targets = np.random.randint(0, 3, 50)  # 3-class classification

    fresh_dataset.add_data(
        features=features,
        targets=targets,
        partition="train",
        group=0,
        branch=0,
        processing="raw"
    )

    return fresh_dataset

def get_dataset_info(dataset):
    """Get basic info about a dataset"""
    unique_samples = len(dataset.indices['sample'].unique())
    total_rows = len(dataset.indices)
    n_features = len(dataset.features.sources) if dataset.features else 0
    return unique_samples, total_rows, n_features

def test_augmentation_operations():
    """Test all three augmentation operations on fresh data"""
    # Start with fresh dataset
    test_dataset = create_fresh_test_dataset()
    samples, rows, features = get_dataset_info(test_dataset)
    print(f"Initial dataset: {samples} samples, {rows} rows, {features} feature sources")
    monitor.capture(test_dataset, "Fresh Dataset")

    # 1. Sample augmentation
    print("\n1. Testing Sample Augmentation...")
    sample_ids = test_dataset.sample_augmentation(
        partition='train',
        n_copies=2,
        processing_tag='augmented'
    )
    print(f"Created {len(sample_ids)} new samples")
    samples, rows, features = get_dataset_info(test_dataset)
    print(f"After augmentation: {samples} samples, {rows} rows")
    monitor.capture(test_dataset, "After Sample Augmentation")

    # 2. Feature augmentation (single processing tag)
    print("\n2. Testing Feature Augmentation...")
    test_dataset.feature_augmentation(
        processing_tag='variant1'
    )
    samples, rows, features = get_dataset_info(test_dataset)
    print(f"After feature augmentation: {samples} samples, {rows} rows")
    monitor.capture(test_dataset, "After Feature Augmentation")

    # 3. Feature extraction
    print("\n3. Testing Feature Extraction...")
    try:
        features_2d = test_dataset.get_features_2d(
            groups=[0],
            branches=[0],
            processing_tags=['raw']
        )
        print(f"2D Features shape: {features_2d.shape}")
    except Exception as e:
        print(f"2D extraction error: {e}")

    try:
        features_3d = test_dataset.get_features_3d(
            groups=[0],
            branches=[0],
            processing_tags=['raw', 'augmented']
        )
        print(f"3D Features shape: {features_3d.shape}")
    except Exception as e:
        print(f"3D extraction error: {e}")

    return test_dataset

fresh_dataset = test_augmentation_operations()
samples, rows, features = get_dataset_info(fresh_dataset)
print(f"\nFinal fresh dataset: {samples} samples, {rows} rows, {features} feature sources")

REAL-WORLD SCENARIO SIMULATION
Initial dataset: 50 samples, 50 rows, 2 feature sources
Fresh Dataset             | Rows:   50 | Samples:  50 | Features: 2 sources       | Partitions: ['train'] | Branches: [0] | Processing: ['raw']

1. Testing Sample Augmentation...
Created 100 new samples
After augmentation: 150 samples, 150 rows
After Sample Augmentation | Rows:  150 | Samples: 150 | Features: 2 sources       | Partitions: ['train'] | Branches: [0] | Processing: ['raw', 'augmented']

2. Testing Feature Augmentation...
After feature augmentation: 150 samples, 300 rows
After Feature Augmentation | Rows:  300 | Samples: 150 | Features: 2 sources       | Partitions: ['train'] | Branches: [0] | Processing: ['raw', 'variant1', 'augmented']

3. Testing Feature Extraction...
2D extraction error: SpectraDataset.get_features_2d() got an unexpected keyword argument 'groups'
3D extraction error: SpectraDataset.get_features_3d() got an unexpected keyword argument 'groups'

Final fresh dataset: 150

In [None]:
# Test 4: Performance Benchmarking
print("="*80)
print("PERFORMANCE BENCHMARKING")
print("="*80)

def get_memory_usage_mb():
    """Get current memory usage in MB"""
    import psutil
    import os
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

def benchmark_operations():
    """Benchmark all three operations and feature extraction"""
    import time
    import gc

    # Create a larger dataset for benchmarking
    base_dataset = dataset.copy()
    samples, rows, features = get_dataset_info(base_dataset)
    memory_mb = get_memory_usage_mb()

    print(f"Benchmarking with dataset:")
    print(f"  Samples: {samples}")
    print(f"  Rows: {rows}")
    print(f"  Feature sources: {features}")
    print(f"  Memory usage: ~{memory_mb:.1f} MB")
    print()

    # Initialize variables
    features_2d = None
    features_3d = None

    # Sample Augmentation Benchmark
    start_time = time.time()
    sample_ids = base_dataset.sample_augmentation(partition='train', n_copies=2, processing_tag='bench_aug')
    sample_time = time.time() - start_time
    print(f"Sample Augmentation (2x): {sample_time:.3f}s")

    # Feature Augmentation Benchmark
    start_time = time.time()
    base_dataset.feature_augmentation(processing_tag='bench_feat_aug')
    feature_time = time.time() - start_time
    print(f"Feature Augmentation: {feature_time:.3f}s")

    # 2D Feature Extraction Benchmark
    try:
        start_time = time.time()
        features_2d = base_dataset.get_features_2d(partition='train')
        extract_2d_time = time.time() - start_time
        print(f"2D Feature Extraction: {extract_2d_time:.3f}s, Shape: {features_2d.shape}")
    except Exception as e:
        print(f"2D extraction error: {e}")
        extract_2d_time = 0

    # 3D Feature Extraction Benchmark
    try:
        start_time = time.time()
        features_3d = base_dataset.get_features_3d(partition='train')
        extract_3d_time = time.time() - start_time
        print(f"3D Feature Extraction: {extract_3d_time:.3f}s, Shape: {features_3d.shape}")
    except Exception as e:
        print(f"3D extraction error: {e}")
        extract_3d_time = 0

    # Memory cleanup - safe deletion
    if 'features_2d' in locals() and features_2d is not None:
        del features_2d
    if 'features_3d' in locals() and features_3d is not None:
        del features_3d
    gc.collect()

    return {
        'sample_aug': sample_time,
        'feature_aug': feature_time,
        'extract_2d': extract_2d_time,
        'extract_3d': extract_3d_time
    }

benchmark_results = benchmark_operations()
print(f"\nTotal benchmark time: {sum(benchmark_results.values()):.3f}s")

# Show final memory usage
final_memory = get_memory_usage_mb()
print(f"Final memory usage: ~{final_memory:.1f} MB")

print("\nBenchmark Summary:")
for operation, time_taken in benchmark_results.items():
    print(f"  {operation}: {time_taken:.3f}s")

PERFORMANCE BENCHMARKING
Benchmarking with dataset:
  Samples: 200
  Rows: 1200
  Feature sources: 2
  Memory usage: ~1173.1 MB

Sample Augmentation (2x): 0.240s
Feature Augmentation: 0.129s
2D extraction error: SpectraDataset.get_features_2d() got an unexpected keyword argument 'partition'
3D extraction error: SpectraDataset.get_features_3d() got an unexpected keyword argument 'partition'

Total benchmark time: 0.369s
Final memory usage: ~1060.4 MB

Benchmark Summary:
  sample_aug: 0.240s
  feature_aug: 0.129s
  extract_2d: 0.000s
  extract_3d: 0.000s


In [None]:
import time
print("="*80)
print("PERFORMANCE TESTING")
print("="*80)

# Test performance with larger datasets
def test_performance():
    print("Testing performance with larger datasets...")

    # Create a larger dataset
    n_samples = 1000
    large_dataset = SpectraDataset(task_type="classification")

    # Add initial data with timing
    start_time = time.time()
    large_dataset.add_data(
        features=[np.random.rand(n_samples, 2000), np.random.rand(n_samples, 1000)],
        targets=np.random.randint(0, 10, size=n_samples),
        partition="train"
    )
    init_time = time.time() - start_time
    print(f"Initial data loading: {init_time:.3f}s for {n_samples} samples")

    # Test sample augmentation performance
    start_time = time.time()
    large_dataset.sample_augmentation(n_copies=2, processing_tag="perf_aug")
    aug_time = time.time() - start_time
    print(f"Sample augmentation (2x): {aug_time:.3f}s")

    # Test feature augmentation performance
    start_time = time.time()
    large_dataset.feature_augmentation(processing_tag="perf_feat")
    feat_time = time.time() - start_time
    print(f"Feature augmentation: {feat_time:.3f}s")

    # Test branching performance
    start_time = time.time()
    large_dataset.branch_dataset(n_branches=3)
    branch_time = time.time() - start_time
    print(f"Branching (3x): {branch_time:.3f}s")

    print(f"\nFinal dataset size: {len(large_dataset.indices)} rows")
    print(f"Memory usage estimate: ~{len(large_dataset.indices) * 3000 * 8 / 1024 / 1024:.1f} MB")

    # Test extraction performance
    start_time = time.time()
    features = large_dataset.get_features_2d(filters={'branch': 0, 'processing': 'raw'})
    extract_time = time.time() - start_time
    print(f"Feature extraction: {extract_time:.3f}s for {features.shape}")

    return large_dataset

perf_dataset = test_performance()

print("\n" + "="*80)
print("PERFORMANCE TESTING COMPLETED!")
print("="*80)

PERFORMANCE TESTING
Testing performance with larger datasets...
Initial data loading: 0.020s for 1000 samples
Sample augmentation (2x): 0.337s
Feature augmentation: 0.209s
Branching (3x): 2.212s

Final dataset size: 18000 rows
Memory usage estimate: ~412.0 MB
Feature extraction: 0.019s for (1000, 3000)

PERFORMANCE TESTING COMPLETED!


In [None]:
print("="*80)
print("MEMORY EFFICIENCY TESTING")
print("="*80)

def test_memory_efficiency():
    """Test memory efficiency of dataset operations"""
    print("Testing memory efficiency and data sharing...")

    # Use a copy of the existing dataset
    base_data = dataset.copy()

    print(f"Base dataset: {len(base_data.indices)} rows")
    initial_samples = len(base_data.indices["sample"].unique())
    print(f"Initial samples: {initial_samples}")

    # Check if features are properly shared/copied
    print("\nTesting feature sharing/copying...")

    # Sample augmentation
    sample_ids = base_data.sample_augmentation(partition='train', n_copies=2)
    print(f"After sample augmentation: {len(base_data.indices)} rows")
    print(f"New sample IDs created: {len(sample_ids)}")
    augmented_samples = len(base_data.indices["sample"].unique())
    print(f"Total samples after augmentation: {augmented_samples}")

    # Feature augmentation
    rows_before = len(base_data.indices)
    base_data.feature_augmentation('memory_test')
    print(f"After feature augmentation: {len(base_data.indices)} rows (increased by {len(base_data.indices) - rows_before})")

    # Branch dataset (use integer, not list)
    rows_before = len(base_data.indices)
    base_data.branch_dataset(3)  # Create 3 branches total (0, 1, 2)
    print(f"After branching: {len(base_data.indices)} rows (increased by {len(base_data.indices) - rows_before})")

    # Test data extraction
    print("\nTesting data extraction...")

    # Extract 2D features (using filters)
    features_2d = base_data.get_features_2d(filters={'partition': ['train']})
    print(f"2D features shape: {features_2d.shape}")

    # Extract 3D features (using filters)
    features_3d = base_data.get_features_3d(filters={'partition': ['train']})
    print(f"3D features shape: {features_3d.shape}")

    # Branch-specific extraction
    branch_data = base_data.get_features_2d(filters={'partition': ['train'], 'branch': [0]})
    print(f"Branch 0 data shape: {branch_data.shape}")

    # Test processing-specific extraction
    raw_data = base_data.get_features_2d(filters={'partition': ['train'], 'processing': ['raw']})
    print(f"Raw processing data shape: {raw_data.shape}")

    return base_data

memory_dataset = test_memory_efficiency()

print("\n" + "="*80)
print("MEMORY EFFICIENCY TESTING COMPLETED!")
print("="*80)

MEMORY EFFICIENCY TESTING
Testing memory efficiency and data sharing...
Base dataset: 1200 rows
Initial samples: 200

Testing feature sharing/copying...
After sample augmentation: 2400 rows
New sample IDs created: 1200
Total samples after augmentation: 1400
After feature augmentation: 4800 rows (increased by 2400)
After branching: 14400 rows (increased by 9600)

Testing data extraction...
2D features shape: (14400, 1500)
3D features shape: (14400, 1500, 1)
Branch 0 data shape: (1600, 1500)
Raw processing data shape: (900, 1500)

MEMORY EFFICIENCY TESTING COMPLETED!


In [None]:
class BranchManager:
    @staticmethod
    def create_branches(dataset, branch_configs):
        branches = {}

        for i, config in enumerate(branch_configs):
            branch_name = f"branch_{i}"

            branch_indices = dataset.indices.with_columns(
                pl.lit(branch_name).alias('branch_id')
            )

            branch_dataset = SpectraDataset(
                features=dataset.features.copy(),
                targets=dataset.targets.copy(),
                indices=branch_indices,
                feature_sources=dataset.feature_sources.copy()
            )

            branches[branch_name] = branch_dataset

        return branches

    @staticmethod
    def dispatch_to_branch(dataset, condition_func, branch_name):
        mask = condition_func(dataset)

        filtered_indices = dataset.indices.filter(mask)
        valid_sample_ids = filtered_indices['sample_id'].to_list()

        mask_array = np.isin(range(len(dataset.features)), valid_sample_ids)

        branch_indices = filtered_indices.with_columns(
            pl.lit(branch_name).alias('branch_id')
        )

        return SpectraDataset(
            features=dataset.features[mask_array],
            targets=dataset.targets[mask_array],
            indices=branch_indices,
            feature_sources=dataset.feature_sources
        )

branch_configs = ['model_rf', 'model_svm', 'model_decon']
branches = BranchManager.create_branches(clustered, branch_configs)

for name, branch in branches.items():
    monitor.capture(branch, f"Branch: {name}")

condition = lambda ds: ds.indices['cluster_id'] < 2
filtered_branch = BranchManager.dispatch_to_branch(clustered, condition, 'low_cluster')
monitor.capture(filtered_branch, "Filtered Branch")

print("="*80)
print("EDGE CASE TESTING")
print("="*80)

def test_edge_cases():
    """Test edge cases and error handling"""
    print("Testing edge cases and error handling...")

    # Test with empty dataset
    empty_dataset = dataset.copy()
    empty_dataset.indices = empty_dataset.indices.filter(pl.col("sample") == -1)  # No matches

    print("\nTesting with empty dataset:")
    try:
        result = empty_dataset.sample_augmentation(partition='train', n_copies=2)
        print(f"Empty dataset augmentation returned: {len(result)} samples")
    except Exception as e:
        print(f"Empty dataset error: {e}")

    # Test with missing partition
    test_dataset = dataset.copy()
    print("\nTesting with missing partition:")
    try:
        result = test_dataset.sample_augmentation(partition='test', n_copies=2)  # 'test' may not exist
        print(f"Missing partition augmentation returned: {len(result)} samples")
    except Exception as e:
        print(f"Missing partition error: {e}")

    # Test feature augmentation on empty dataset
    print("\nTesting feature augmentation on empty dataset:")
    try:
        empty_dataset.feature_augmentation('test_empty')
        print("Feature augmentation on empty dataset completed")
    except Exception as e:
        print(f"Feature augmentation error: {e}")

    # Test large copy numbers
    print("\nTesting with large copy numbers:")
    try:
        small_dataset = dataset.copy()
        # Get just first 10 samples
        first_samples = small_dataset.indices.head(10)
        small_dataset.indices = first_samples
        result = small_dataset.sample_augmentation(partition='train', n_copies=5)
        print(f"Large copy number returned: {len(result)} samples")
    except Exception as e:
        print(f"Large copy number error: {e}")

    # Test branch dataset with no train data
    print("\nTesting branch dataset edge cases:")
    try:
        no_train_dataset = dataset.copy()
        no_train_dataset.indices = no_train_dataset.indices.with_columns(
            pl.col("partition").str.replace("train", "test")
        )
        no_train_dataset.branch_dataset(2)
        print("Branch dataset with no train data completed")
    except Exception as e:
        print(f"Branch dataset error: {e}")

    print("\nEdge case testing completed!")

test_edge_cases()

print("="*80)
print("EDGE CASE TESTING COMPLETED!")
print("="*80)

NameError: name 'clustered' is not defined

In [None]:
print("\n" + "="*80)
print("COMPREHENSIVE INTEGRATION TEST")
print("="*80)

def comprehensive_test():
    """
    Test all three operations in sequence to ensure they work together properly.
    """
    print("Running comprehensive integration test...")

    # Create a realistic dataset
    integration_dataset = SpectraDataset(task_type="classification")

    # Add multi-source spectral data
    nir_data = np.random.rand(200, 1024) * 0.5 + 1.0
    raman_data = np.random.rand(200, 512) * 0.3 + 0.8
    labels = np.random.randint(0, 4, size=200)

    # Split data
    train_nir, train_raman, train_labels = nir_data[:140], raman_data[:140], labels[:140]
    val_nir, val_raman, val_labels = nir_data[140:170], raman_data[140:170], labels[140:170]
    test_nir, test_raman, test_labels = nir_data[170:], raman_data[170:], labels[170:]

    # Add to dataset
    integration_dataset.add_data([train_nir, train_raman], train_labels, partition="train")
    integration_dataset.add_data([val_nir, val_raman], val_labels, partition="val")
    integration_dataset.add_data([test_nir, test_raman], test_labels, partition="test")

    monitor = DatasetMonitor()
    monitor.capture(integration_dataset, "Initial Multi-Source")

    # Step 1: Sample augmentation for data balancing
    integration_dataset.sample_augmentation(partition="train", n_copies=1, processing_tag="balanced")
    monitor.capture(integration_dataset, "1. Sample Augmentation")

    # Step 2: Feature augmentation for preprocessing variants
    integration_dataset.feature_augmentation(processing_tag="preprocessed_v1")
    monitor.capture(integration_dataset, "2. Feature Augmentation #1")

    integration_dataset.feature_augmentation(processing_tag="preprocessed_v2")
    monitor.capture(integration_dataset, "3. Feature Augmentation #2")

    # Step 3: Branching for ensemble methods
    integration_dataset.branch_dataset(n_branches=4)
    monitor.capture(integration_dataset, "4. Branching (4 models)")

    # Validate data integrity
    print("\n" + "-"*60)
    print("DATA INTEGRITY VALIDATION")
    print("-"*60)

    # Check that different extraction patterns work
    test_cases = [
        ("Original train, branch 0", {'partition': 'train', 'branch': 0, 'processing': 'raw'}),
        ("Balanced train, branch 1", {'partition': 'train', 'branch': 1, 'processing': 'balanced'}),
        ("Preprocessed v1, branch 2", {'partition': 'train', 'branch': 2, 'processing': 'preprocessed_v1'}),
        ("Preprocessed v2, branch 3", {'partition': 'train', 'branch': 3, 'processing': 'preprocessed_v2'}),
        ("Validation data", {'partition': 'val', 'processing': 'raw'}),
        ("Test data", {'partition': 'test', 'processing': 'raw'}),
    ]

    for description, filters in test_cases:
        try:
            data = integration_dataset.get_features_2d(filters=filters)
            print(f"✓ {description:25}: {data.shape}")
        except Exception as e:
            print(f"✗ {description:25}: Error - {e}")

    # Test 3D extraction
    try:
        data_3d = integration_dataset.get_features_3d(filters={'partition': 'train', 'branch': 0})
        print(f"✓ 3D extraction:             {data_3d.shape}")
    except Exception as e:
        print(f"✗ 3D extraction:             Error - {e}")

    print("\n" + "-"*60)
    print("FINAL STATISTICS")
    print("-"*60)
    print(f"Total rows: {len(integration_dataset.indices)}")
    print(f"Unique samples: {integration_dataset._next_sample}")
    print(f"Partitions: {integration_dataset.indices['partition'].unique().to_list()}")
    print(f"Branches: {integration_dataset.indices['branch'].unique().to_list()}")
    print(f"Processing types: {integration_dataset.indices['processing'].unique().to_list()}")

    monitor.summary()

    return integration_dataset

final_dataset = comprehensive_test()

print("\n" + "="*80)
print("🎉 ALL TESTS COMPLETED SUCCESSFULLY!")
print("✅ Sample Augmentation: Creates new samples with new IDs, preserves origins")
print("✅ Feature Augmentation: Creates new rows with same sample IDs, different processing")
print("✅ Branching: Copies train data across branches for ensemble methods")
print("✅ Feature Extraction: 2D/3D extraction with flexible filtering")
print("="*80)


COMPREHENSIVE INTEGRATION TEST
Running comprehensive integration test...
Initial Multi-Source      | Rows:  200 | Samples: 200 | Features: 2 sources       | Partitions: ['train', 'test', 'val'] | Branches: [0] | Processing: ['raw']
1. Sample Augmentation    | Rows:  340 | Samples: 340 | Features: 2 sources       | Partitions: ['val', 'test', 'train'] | Branches: [0] | Processing: ['balanced', 'raw']
2. Feature Augmentation #1 | Rows:  680 | Samples: 340 | Features: 2 sources       | Partitions: ['test', 'val', 'train'] | Branches: [0] | Processing: ['balanced', 'preprocessed_v1', 'raw']
3. Feature Augmentation #2 | Rows: 1360 | Samples: 340 | Features: 2 sources       | Partitions: ['train', 'test', 'val'] | Branches: [0] | Processing: ['preprocessed_v1', 'raw', 'balanced', 'preprocessed_v2']
4. Branching (4 models)   | Rows: 4720 | Samples: 340 | Features: 2 sources       | Partitions: ['test', 'train', 'val'] | Branches: [0, 1, 2, 3] | Processing: ['preprocessed_v2', 'preprocessed_v