In [1]:
import numpy as np
import torch
from loaders.augmented_loader import AugmentedLoader, Loader_aug_batch
from loaders.dataset import DataEntity

In [3]:
def test_augmented_loader_basic():
    """Build an AugmentedLoader from random data and sanity-check its window attributes."""
    # Random stand-in series: 3 features by 500 timesteps.
    n_channels, n_steps = 3, 500
    entity = DataEntity(Y=np.random.randn(n_channels, n_steps), name="test_entity")

    # Loader settings collected in one place, then splatted into the constructor.
    loader_kwargs = {
        "batch_size": 16,
        "window_size": 100,
        "window_step": 50,
        "anomaly_types": ["normal", "spike", "noise"],
    }
    loader = AugmentedLoader(dataset=entity, **loader_kwargs)

    # At least one window must exist, and Z / the anomaly mask must line up with Y.
    y_shape = loader.Y_windows.shape
    assert y_shape[0] > 0, "No windows created"
    assert loader.Z_windows.shape[0] == y_shape[0], "Y and Z size mismatch"
    assert loader.anomaly_mask.shape == y_shape, "Mask shape mismatch"
    print("✓ Basic initialization test passed")


def test_window_dimensions():
    """Verify the feature and time axes of the windows built by AugmentedLoader."""
    num_features = 5
    series = np.random.randn(num_features, 1000)

    loader = AugmentedLoader(
        dataset=DataEntity(Y=series, name="test_entity"),
        batch_size=8,
        window_size=128,
        window_step=64,
        anomaly_types=["normal", "spike"],
    )

    # Axis 1 is expected to hold the features and axis 2 the window samples,
    # i.e. each individual window is (n_features, window_size).
    window_shape = loader.Y_windows.shape
    assert window_shape[1] == num_features, f"Expected {num_features} features"
    assert window_shape[2] == 128, "Expected window_size=128"
    print("✓ Window dimension test passed")


def test_iteration():
    """Iterate the loader once, checking each batch's keys and the total batch count."""
    series = np.random.randn(3, 500)
    loader = AugmentedLoader(
        dataset=DataEntity(Y=series, name="test_entity"),
        batch_size=16,
        window_size=100,
        window_step=50,
        anomaly_types=["normal", "spike"],
        shuffle=False,
    )

    # Every batch must be a mapping that exposes all of these keys.
    required_keys = (
        ("Y", "Missing 'Y' key"),
        ("Z", "Missing 'Z' key"),
        ("anomaly_mask", "Missing 'anomaly_mask' key"),
        ("label", "Missing 'label' key"),
    )

    # Pre-set to 0 so the final comparison is still defined if the loader yields nothing.
    batch_count = 0
    for batch_count, batch in enumerate(loader, start=1):
        for key, failure_message in required_keys:
            assert key in batch, failure_message
        print(f"  Batch {batch_count}: Y shape = {batch['Y'].shape}")

    assert batch_count == loader.num_batches_per_epoch, "Batch count mismatch"
    print(f"✓ Iteration test passed ({batch_count} batches)")


def test_anomaly_types():
    """Check that the loader's labels cover no more classes than were requested."""
    anomaly_types = ["normal", "spike", "noise", "flip", "scale"]
    series = np.random.randn(3, 500)

    loader = AugmentedLoader(
        dataset=DataEntity(Y=series, name="test_entity"),
        batch_size=16,
        window_size=100,
        window_step=50,
        anomaly_types=anomaly_types,
    )

    # Reduce each label row to its arg-max index (labels are presumably one-hot
    # rows -- confirm against the loader) and count the distinct indices seen.
    class_ids = torch.argmax(loader.label, dim=1)
    n_observed = len(class_ids.unique())
    assert n_observed <= len(anomaly_types), "More labels than anomaly types"
    print(f"✓ Anomaly types test passed (created {len(anomaly_types)} types)")


def test_loader_aug_batch():
    """Construct a Loader_aug_batch from a pre-batched tensor and confirm it has windows."""
    # Input layout: (batch_size, n_features, n_time).
    raw_batch = np.random.randn(4, 3, 256)

    loader = Loader_aug_batch(
        data=torch.Tensor(raw_batch),
        batch_size=8,
        anomaly_types=["normal", "spike", "noise"],
    )

    n_windows = loader.Y_windows.shape[0]
    assert n_windows > 0, "No windows created"
    print("✓ Batch loader test passed")


# Execute every test in declaration order; the first failing assertion aborts the run.
if __name__ == "__main__":
    for run_test in (
        test_augmented_loader_basic,
        test_window_dimensions,
        test_iteration,
        test_anomaly_types,
        test_loader_aug_batch,
    ):
        run_test()
    print("\n✅ All tests passed!")

✓ Basic initialization test passed
✓ Window dimension test passed
  Batch 1: Y shape = torch.Size([16, 3, 100])
  Batch 2: Y shape = torch.Size([2, 3, 100])
✓ Iteration test passed (2 batches)
✓ Anomaly types test passed (created 5 types)
✓ Batch loader test passed

✅ All tests passed!
