# Data Augmentation Utilities for Sleep Signal Classification
This notebook contains various data augmentation techniques for time-series sleep signal data.

## Augmentation Techniques Implemented:
1. **Time Warping** - Non-linear time stretching/compression
2. **Magnitude Warping** - Smooth magnitude variations
3. **Window Slicing** - Random cropping, then resizing back to the original length
4. **Jittering** - Gaussian noise injection
5. **Scaling** - Random amplitude scaling
6. **Time Shifting** - Circular temporal shifts
7. **Signal Rotation** - Random sign inversion (amplitude flip about zero)
8. **Mixup** - Linear interpolation between samples

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
import random

%matplotlib inline

## 1. Time Warping
Applies smooth random time warping to the signal

In [None]:
def time_warp(x, sigma=0.2, knot=4):
    """
    Apply smooth random time warping to the signal
    
    A random "local speed" is drawn at ``knot + 2`` evenly spaced anchor
    points, linearly interpolated to every time step and integrated into a
    warped time axis. The signal is then resampled from that warped axis
    back onto the regular grid. The warped axis is anchored so that it
    starts at 0 and ends at ``n_samples - 1``; consequently the first and
    last samples are preserved and ``sigma=0`` is the identity transform.
    
    Parameters:
    -----------
    x : array, shape (n_samples,) or (n_samples, n_features)
        Input signal
    sigma : float
        Standard deviation of warping magnitude (speeds are drawn around 1.0)
    knot : int
        Number of interior anchor points; ``knot + 2`` speeds are drawn in total
    
    Returns:
    --------
    warped : array
        Time-warped signal (float), same shape as ``x``
    """
    n_steps = x.shape[0]
    if n_steps < 2:
        # Fewer than two samples leaves no time axis to warp.
        return np.array(x, dtype=float)
    
    orig_steps = np.arange(n_steps)
    
    random_warps = np.random.normal(loc=1.0, scale=sigma, size=(knot+2,))
    # Speeds must stay strictly positive: otherwise the cumulative time axis
    # is not increasing and np.interp's result is undefined.
    random_warps = np.maximum(random_warps, 1e-3)
    warp_steps = np.linspace(0, n_steps - 1, num=knot+2)
    
    speed = np.interp(orig_steps, warp_steps, random_warps)
    
    # Integrate the speed into a time axis that starts at exactly 0 ...
    warped_time = np.cumsum(speed) - speed[0]
    # ... and rescale it so it ends at exactly n_steps - 1. Without this
    # anchoring the axis runs 1..n and the output is shifted by one sample.
    warped_time *= (n_steps - 1) / warped_time[-1]
    
    if len(x.shape) == 1:
        return np.interp(orig_steps, warped_time, x)
    
    # Multi-channel input: every feature column shares the same time warp.
    return np.array([np.interp(orig_steps, warped_time, x[:, i])
                     for i in range(x.shape[1])]).T

## 2. Magnitude Warping
Applies smooth random magnitude scaling

In [None]:
def magnitude_warp(x, sigma=0.2, knot=4):
    """
    Multiply the signal by a smooth, randomly drawn gain envelope
    
    ``knot + 2`` gains are drawn from a normal distribution centred on 1.0
    at evenly spaced positions along the time axis and linearly
    interpolated to one gain per time step.
    
    Parameters:
    -----------
    x : array
        Input signal; time runs along the first axis
    sigma : float
        Standard deviation of magnitude variation
    knot : int
        Number of interior anchor points (``knot + 2`` gains are drawn)
    
    Returns:
    --------
    warped : array
        Magnitude-warped signal
    """
    n_steps = x.shape[0]
    n_anchors = knot + 2
    
    gains = np.random.normal(loc=1.0, scale=sigma, size=(n_anchors,))
    anchor_positions = np.linspace(0, n_steps - 1, num=n_anchors)
    envelope = np.interp(np.arange(n_steps), anchor_positions, gains)
    
    # For multi-channel input, add a trailing axis so the same envelope
    # broadcasts across every feature column.
    if len(x.shape) != 1:
        envelope = envelope[:, np.newaxis]
    return x * envelope

## 3. Jittering (Noise Injection)

In [None]:
def jitter(x, sigma=0.03):
    """
    Perturb the signal with zero-mean Gaussian noise
    
    The noise standard deviation is ``sigma`` times the standard deviation
    of the whole input array.
    
    Parameters:
    -----------
    x : array
        Input signal
    sigma : float
        Noise level, expressed as a fraction of ``np.std(x)``
    
    Returns:
    --------
    noisy : array
        Signal with added noise, same shape as ``x``
    """
    noise_scale = sigma * np.std(x)
    return x + np.random.normal(loc=0., scale=noise_scale, size=x.shape)

## 4. Scaling

In [None]:
def scaling(x, sigma=0.1):
    """
    Multiply the whole signal by one random gain
    
    Parameters:
    -----------
    x : array
        Input signal
    sigma : float
        Standard deviation of the gain, which is drawn around 1.0
    
    Returns:
    --------
    scaled : array
        Scaled signal
    """
    # A single scalar is drawn, so every sample and channel gets the same gain.
    gain = np.random.normal(loc=1., scale=sigma)
    scaled = x * gain
    return scaled

## 5. Time Shifting

In [None]:
def time_shift(x, shift_range=0.1):
    """
    Circularly shift the signal along the time axis by a random offset
    
    Samples rolled past one end re-enter at the other end.
    
    Parameters:
    -----------
    x : array
        Input signal; time runs along the first axis
    shift_range : float
        Maximum shift in either direction, as a fraction of signal length
    
    Returns:
    --------
    shifted : array
        Time-shifted signal
    """
    n_steps = x.shape[0]
    fraction = np.random.uniform(-shift_range, shift_range)
    # int() truncates toward zero, so small fractions yield no shift at all.
    offset = int(fraction * n_steps)
    return np.roll(x, offset, axis=0)

## 6. Window Slicing

In [None]:
def window_slice(x, reduce_ratio=0.9):
    """
    Randomly crop a window from the signal and stretch it back to full length
    
    A contiguous window of ``int(reduce_ratio * n_samples)`` samples (at
    least one) is picked at a uniformly random position and linearly
    resampled so that the result has the same length as the input.
    
    Parameters:
    -----------
    x : array, shape (n_samples,) or (n_samples, n_features)
        Input signal
    reduce_ratio : float
        Target ratio of original length to keep
    
    Returns:
    --------
    sliced : array
        Cropped and resized signal. When the window would cover the whole
        signal (``reduce_ratio >= 1``), ``x`` itself is returned unchanged.
    """
    n_steps = x.shape[0]
    # Keep at least one sample so the interpolation below always has data.
    target_len = max(1, int(reduce_ratio * n_steps))
    if target_len >= n_steps:
        return x
    
    # randint's upper bound is exclusive; the +1 makes the last valid window
    # (the one ending on the final sample) selectable as well.
    start = np.random.randint(0, n_steps - target_len + 1)
    sliced = x[start:start + target_len]
    
    # Resize back to original length: sample the window at n_steps evenly
    # spaced fractional positions spanning its full extent.
    indices = np.linspace(0, target_len - 1, num=n_steps)
    window_steps = np.arange(target_len)
    
    if len(x.shape) == 1:
        return np.interp(indices, window_steps, sliced)
    
    # Multi-channel input: resample each feature column with the same window.
    return np.array([np.interp(indices, window_steps, sliced[:, i])
                     for i in range(x.shape[1])]).T

## 7. Signal Rotation

In [None]:
def rotation(x):
    """
    Randomly invert the sign of the signal
    
    Parameters:
    -----------
    x : array
        Input signal
    
    Returns:
    --------
    rotated : array
        Either ``x`` or ``-x``, each chosen with equal probability
    """
    sign = np.random.choice([-1, 1])
    rotated = sign * x
    return rotated

## 8. Mixup

In [None]:
def mixup(x1, x2, alpha=0.2):
    """
    Blend two samples with a random convex combination
    
    Parameters:
    -----------
    x1, x2 : array
        Two input signals
    alpha : float
        Parameter of the symmetric Beta(alpha, alpha) distribution the
        mixing coefficient is drawn from; for ``alpha <= 0`` no mixing is
        done and the coefficient is fixed to 1
    
    Returns:
    --------
    mixed : array
        ``lam * x1 + (1 - lam) * x2``
    lam : float
        Mixing coefficient (for label mixing)
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    return lam * x1 + (1 - lam) * x2, lam

## 9. Combined Augmentation Pipeline

In [None]:
def augment_signal(x, augmentation_list=('jitter', 'scaling', 'time_warp', 'magnitude_warp'),
                   n_augmentations=2):
    """
    Apply a random subset of augmentations to a signal, one after another
    
    Parameters:
    -----------
    x : array
        Input signal (left unmodified; a copy is augmented)
    augmentation_list : sequence of str
        Augmentation techniques to choose from. Recognised names are
        'jitter', 'scaling', 'time_warp', 'magnitude_warp', 'time_shift',
        'window_slice' and 'rotation'.
    n_augmentations : int
        Number of distinct augmentations to apply. Values larger than
        ``len(augmentation_list)`` are capped at that length.
    
    Returns:
    --------
    augmented : array
        Augmented signal
    
    Raises:
    -------
    ValueError
        If ``augmentation_list`` contains a name that is not recognised.
    """
    # Name -> transform lookup. Built inside the function so the names are
    # resolved at call time, after all notebook cells have been executed.
    transforms = {
        'jitter': jitter,
        'scaling': scaling,
        'time_warp': time_warp,
        'magnitude_warp': magnitude_warp,
        'time_shift': time_shift,
        'window_slice': window_slice,
        'rotation': rotation,
    }
    
    names = [str(name) for name in augmentation_list]
    unknown = [name for name in names if name not in transforms]
    if unknown:
        # A misspelled name used to be skipped silently, yielding fewer
        # augmentations than requested without any warning.
        raise ValueError(
            f"Unknown augmentation(s) {unknown}; available: {sorted(transforms)}"
        )
    
    augmented = x.copy()
    
    # Sampling without replacement cannot draw more items than exist.
    n_selected = min(n_augmentations, len(names))
    if n_selected <= 0:
        return augmented
    
    # Randomly select augmentations
    selected = np.random.choice(names, size=n_selected, replace=False)
    
    for name in selected:
        augmented = transforms[str(name)](augmented)
    
    return augmented

## 10. Batch Augmentation Generator

In [None]:
def augment_dataset(X, y, augmentation_factor=2, augmentation_methods=('jitter', 'scaling', 'time_warp')):
    """
    Augment entire dataset
    
    For each of ``augmentation_factor`` rounds, every sample in ``X`` is
    passed through ``augment_signal`` and the results are appended to the
    original data together with a copy of the labels. The combined dataset
    is shuffled (samples and labels with the same permutation) and a short
    size summary is printed.
    
    Parameters:
    -----------
    X : array, shape (n_samples, n_timesteps) or (n_samples, n_timesteps, n_features)
        Input signals
    y : array, shape (n_samples,)
        Labels
    augmentation_factor : int
        How many augmented copies to create per sample
    augmentation_methods : sequence of str
        Which augmentation methods to use
    
    Returns:
    --------
    X_aug : array
        Original + augmented data
    y_aug : array
        Corresponding labels
    """
    methods = list(augmentation_methods)
    # Two augmentations per sample, unless fewer methods were supplied:
    # drawing 2 distinct methods out of 1 would raise in augment_signal.
    n_per_sample = min(2, len(methods))
    
    X_aug_list = [X]
    y_aug_list = [y]
    
    for _ in range(augmentation_factor):
        X_new = np.array([augment_signal(sample, methods, n_augmentations=n_per_sample)
                          for sample in X])
        X_aug_list.append(X_new)
        y_aug_list.append(y)
    
    X_aug = np.concatenate(X_aug_list, axis=0)
    y_aug = np.concatenate(y_aug_list, axis=0)
    
    # Shuffle samples and labels with one shared permutation so they stay aligned
    indices = np.random.permutation(len(X_aug))
    X_aug = X_aug[indices]
    y_aug = y_aug[indices]
    
    print(f"Original dataset size: {len(X)}")
    print(f"Augmented dataset size: {len(X_aug)}")
    print(f"Augmentation factor: {len(X_aug) / len(X):.2f}x")
    
    return X_aug, y_aug

## 11. Visualization Functions

In [None]:
def visualize_augmentations(x, num_augmentations=5):
    """
    Plot the original signal and one augmented version per augmentation type
    
    The figure has one row for the original signal followed by one row for
    each of: jitter, scaling, time warp, magnitude warp, time shift, window
    slice and rotation. Each augmented row overlays the faint original for
    comparison.
    
    Parameters:
    -----------
    x : array
        Input signal to visualize
    num_augmentations : int
        Accepted for API compatibility; the function body does not currently
        read it (every augmentation type is always shown once)
    """
    augmentation_types = ['jitter', 'scaling', 'time_warp', 'magnitude_warp', 
                          'time_shift', 'window_slice', 'rotation']
    n_rows = len(augmentation_types) + 1
    
    fig, axes = plt.subplots(n_rows, 1, figsize=(15, 12))
    
    # Top row: the untouched reference signal
    reference_ax = axes[0]
    reference_ax.plot(x, 'b-', linewidth=1)
    reference_ax.set_title('Original Signal', fontsize=12, fontweight='bold')
    reference_ax.set_ylabel('Amplitude')
    reference_ax.grid(True, alpha=0.3)
    
    # Remaining rows: one augmentation each, drawn over the faint original
    for ax, aug_type in zip(axes[1:], augmentation_types):
        x_aug = augment_signal(x, augmentation_list=[aug_type], n_augmentations=1)
        pretty_name = aug_type.replace("_", " ").title()
        
        ax.plot(x_aug, 'r-', linewidth=1, alpha=0.7)
        ax.plot(x, 'b-', linewidth=0.5, alpha=0.3, label='Original')
        ax.set_title(f'{pretty_name} Augmentation', fontsize=12, fontweight='bold')
        ax.set_ylabel('Amplitude')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right')
    
    axes[-1].set_xlabel('Time Steps')
    plt.tight_layout()
    plt.show()

# Load banner: confirms the cell ran and lists the public helpers with the
# parameter names used in their definitions.
print("✅ Data Augmentation utilities loaded successfully!")
print("\nAvailable functions:")
print("  - time_warp(x, sigma, knot)")
print("  - magnitude_warp(x, sigma, knot)")
print("  - jitter(x, sigma)")
print("  - scaling(x, sigma)")
print("  - time_shift(x, shift_range)")
print("  - window_slice(x, reduce_ratio)")
print("  - rotation(x)")
print("  - mixup(x1, x2, alpha)")
print("  - augment_signal(x, augmentation_list, n_augmentations)")
print("  - augment_dataset(X, y, augmentation_factor, augmentation_methods)")
print("  - visualize_augmentations(x, num_augmentations)")

## Test the Augmentation Functions
Uncomment below to test with sample data

In [None]:
# # Create a sample signal
# t = np.linspace(0, 2, 1024)
# sample_signal = 50 * np.sin(2 * np.pi * 5 * t) + 30 * np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 5, 1024)
# 
# # Visualize all augmentations
# visualize_augmentations(sample_signal)