# Quick introduction to the dataloading module

This section is intended to provide a brief introduction to the dataloading module and its main functionalities.

In short, all functions and custom classes are designed to help you create an efficient PyTorch DataLoader to use during training. The main objective is to avoid loading the entire dataset into memory at once. A typical pipeline is based on the following steps:

1) Call of the **GetEEGPartitionNumber** function to extract the dataset length, i.e. the number 

In [1]:
# `os` is required by os.getcwd() below; without this import the cell fails with NameError.
import os
import sys
# Make the repository root (the part of the working directory that precedes
# '/Notebooks') importable, so that the selfeeg package can be found.
sys.path.append(os.getcwd().split('/Notebooks')[0])
from selfeeg import dataloading as dl


# Operative augmentations

## Shifts and Flips

In [None]:
def shift_vertical(x: "N-D Tensor of numpy Array", 
                   value: float):
    """
    shift_vertical offsets every element of x by a constant.

    Parameters
    ----------
    x: N-D Tensor or numpy array
        The data to shift.
    value: scalar
        The offset added to each element of x.

    Returns
    -------
    The result of ``x + value``; x itself is not modified by this function.
    """
    return x + value
    

In [None]:
def flip_vertical(x: "N-D Tensor of numpy Array"):
    """
    flip_vertical inverts the sign of every element of x.

    Parameters
    ----------
    x: N-D Tensor or numpy array
        The data whose sign must be inverted.

    Returns
    -------
    The result of multiplying x by -1.
    """
    return x * (-1)

In [None]:
def flip_horizontal(x: "N-D Tensor of numpy Array"):
    """
    flip_horizontal reverses the order of the elements along the last dimension of x.

    Parameters
    ----------
    x: N-D Tensor or numpy array
        Array to flip. Last dimension must have the EEG recordings.

    Returns
    -------
    x reversed along its last axis, computed with np.flip for numpy arrays
    and with torch.flip for any other input.
    """
    last_axis = len(x.shape) - 1
    if not isinstance(x, np.ndarray):
        return torch.flip(x, [last_axis])
    return np.flip(x, last_axis)

## Noise adder

In [None]:
def add_noise(x: "N-D Tensor of numpy Array", 
              mean: float=0., 
              std: float=1.,
              get_noise: bool=False
             ):
    """
    add_noise sums gaussian noise with the desired mean and standard deviation to x.

    Parameters
    ----------
    x: N-D Tensor or numpy array
        array to add noise
    mean: scalar, optional
        the mean of the gaussian distribution
        Default: 0
    std: scalar, optional
        the std of the gaussian distribution
        Default: 1
    get_noise: bool, optional
        whether to return the generated noise or not
        Default: False

    Returns
    -------
    The noisy version of x or, if get_noise is True, the tuple (noisy x, noise).
    """
    # draw standard normal samples with the same shape as x
    # (torch samples are created on the same device as x)
    if isinstance(x, np.ndarray):
        gaussian = np.random.randn(*x.shape)
    else:
        gaussian = torch.randn(*x.shape, device=x.device)

    # rescale the samples to the requested distribution
    noise = mean + std * gaussian
    x_noise = x + noise
    return (x_noise, noise) if get_noise else x_noise
    

In [None]:
def add_noise_SNR(x: "N-D Tensor of numpy Array", 
                  target_snr: float=5, 
                  get_noise: bool=False
                 ):
    """
    add_noise_SNR adds gaussian noise such that the SNR (Signal to Noise Ratio) is the desired one.

    The power of x is taken as the signal power. Since the signal is supposed to be already noisy,
    it makes more sense to say that this function reduces the SNR by adding zero-mean noise
    whose average power is ``target_snr`` dB below the average power of x.

    Parameters
    ----------
    x: N-D Tensor or numpy array
        array to add noise
    target_snr: scalar, optional
        the target SNR, in decibel
        Default: 5
    get_noise: bool, optional
        whether to return the generated noise or not
        Default: False

    Returns
    -------
    The noisy version of x or, if get_noise is True, the tuple (noisy x, noise).

    created using the following reference: 
        https://stackoverflow.com/questions/14058340/adding-noise-to-a-signal-in-python
    """

    # Signal power. Not exactly true since we have an already noised signal.
    # Only the AVERAGE power is needed: taking log10 of each sample's power (as was
    # done before, for an unused variable) emits divide-by-zero warnings on zero samples.
    x_pow = (x ** 2)

    if isinstance(x, np.ndarray):
        x_db_avg = 10 * np.log10(np.mean(x_pow))
        noise_pow_avg = 10 ** ((x_db_avg - target_snr) / 10)
        noise = np.random.normal(0, noise_pow_avg**0.5, size=x.shape)
    else:
        x_db_avg = 10 * torch.log10(torch.mean(x_pow))
        noise_pow_avg = 10 ** ((x_db_avg - target_snr) / 10)
        # the noise must live on the same device as x, otherwise the sum below
        # fails for tensors which are not stored on the CPU
        noise = (noise_pow_avg**0.5) * torch.randn(*x.shape, device=x.device)

    x_noise = x + noise
    if get_noise:
        return x_noise, noise
    return x_noise

In [None]:
def add_band_noise(x: "numpy array or tensor",
                   bandwidth: list[tuple[float,float], str, float], 
                   samplerate: float=256, 
                   std: float=0.5,
                   get_noise: bool=False
                  ):
    
    """
    add_band_noise adds random noise filtered at specific bandwidths.
    
    Given a set of bandwidths or a set of specific frequencies, add_band_noise creates a noise whose 
    spectrum is bigger than zero only on those bands. It can be used to alter only specific frequency
    components of the original signal. A single 1-D noise (as long as the last dimension of x) is 
    generated and summed to x.
    
    Parameters
    ----------
    x: N-D Tensor or numpy array
        array to add noise
    bandwidth: list
        The frequency components which the noise must have. Must be a LIST with the following values:
        strings: add noise to specific EEG components. Can be any of "delta", "theta", "alpha", "beta", 
                 "gamma", "gamma_low", "gamma_high"
        scalar: add noise to a specific component
        tuple with 2 scalar: add noise to a specific band set with the tuple (start_component, end_component)
        A single value not wrapped in a list is also accepted. The given list is never modified.
    samplerate : float, optional
        The sampling rate, given in Hz. Remember to change this value according to the signal sampling rate
        Default: 256
    std: float, optional
        The desired standard deviation of the noise. Use to change the magnitude of the noise
        Default: 0.5
    get_noise: bool, optional
        whether to return the generated noise or not
        Default: False
        
    Returns
    -------
    The noisy version of x or, if get_noise is True, the tuple (noisy x, noise).
    
    Raises
    ------
    ValueError
        if a string in bandwidth is not one of the supported EEG band names.
        
    """
    
    eeg_bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
        'gamma_low': (30, 70),
        'gamma_high': (70, 150),
        'gamma': (30, 150),
    }
    
    # converting to list if single string, scalar or tuple is given
    requested = bandwidth if isinstance(bandwidth, list) else [bandwidth]
    
    # transform all elements in (start, end) tuples. The result is stored in a new
    # list so that the list owned by the caller is left untouched
    bands = []
    for band in requested:
        if isinstance(band, str):
            # change string bandwidth call to frequency slice
            if band.lower() not in eeg_bands:
                raise ValueError(
                    f'Brainwave "{band}" not exist. \n'
                    'Choose between delta, theta, alpha, beta, gamma, gamma_low, gamma_high'
                )
            band = eeg_bands[band.lower()]
        elif np.isscalar(band):
            # change single frequency call
            band = (band, band)
        bands.append(band)

    samples = x.shape[-1]
    # convert each band from Hz to a [start, end) slice of spectrum bins
    bins = [(int(band[0]*samples/samplerate), int(band[1]*samples/samplerate + 1))
            for band in bands]
    # number of positive-frequency bins which have a negative-frequency mirror
    Np = (samples - 1) // 2
    
    if isinstance(x, np.ndarray):
        f = np.zeros(samples, dtype='complex')
        for start, end in bins:
            f[start:end] = 1
        # assign a random phase to each component, then mirror the complex conjugate
        # on the negative frequencies to make the spectrum the one of a real signal
        phases = np.random.rand(Np) * 2 * np.pi
        phases = np.cos(phases) + 1j * np.sin(phases)
        f[1:Np+1] *= phases
        f[-1:-1-Np:-1] = np.conj(f[1:Np+1])
        noise = np.fft.ifft(f).real
        # rescale to the desired standard deviation
        noise *= std/np.std(noise)
    else:
        f = torch.zeros(samples, dtype=torch.complex64, device=x.device)
        for start, end in bins:
            f[start:end] = 1
        phases = torch.rand(Np, device=x.device) * 2 * math.pi
        phases = torch.cos(phases) + 1j * torch.sin(phases)
        f[1:Np+1] *= phases
        f[-Np:] = torch.flip(torch.conj(f[1:Np+1]), [0])
        noise = torch.fft.ifft(f).real
        # rescale to the desired standard deviation
        noise *= std/torch.std(noise)

    x_noise = x + noise
    if get_noise:
        return x_noise, noise
    return x_noise

## Filtering

In [None]:
def moving_avg(x, order: int=5, pad_mode: str='same'):
    """
    moving_avg apply a moving average filter to the signal x.
    
    moving_avg apply a moving average filter to the last dimension of the array or Tensor x. The filter order
    and padding strategy can be given as function argument. Every other dimension is preserved, whatever
    the number of dimensions of x is.
    
    Parameters
    ----------
    x: N-D Tensor or numpy array
        The element to filter. Signals must be on the last dimension.
    order: int, optional
        The order of the filter, i.e. the number of samples to average.
        Default: 5
    pad_mode: str or int or tuple of int
        The padding strategy. 
        For numpy arrays it is the ``mode`` argument of np.convolve ('same', 'valid' or 'full').
        For tensors it is the ``padding`` argument of torch's conv2d, applied on signals arranged 
        as images of height 1: it can be 'same', 'valid', an int (number of samples to pad on both 
        sides of the last dimension) or a tuple (0, samples_to_pad).
        Default: 'same'
        
    Returns
    -------
    The filtered version of x. With pad_mode='same' the output has the same shape of x; other 
    padding strategies can change the length of the last dimension. 
    Floating point (or complex) numpy arrays keep their dtype, other numpy arrays are returned as
    float64 to avoid truncating the averages.
    """
    
    if isinstance(x, np.ndarray):
        # never store averages in an integer array: values would be silently truncated
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
        filt = np.ones(order, dtype=out_dtype)/order
        # filter each 1-D signal stored along the last axis
        x_avg = np.apply_along_axis(np.convolve, -1, x, filt, pad_mode)
        return x_avg.astype(out_dtype, copy=False)
    
    # remember the shape before adapting x to conv2d, it is needed to restore the output
    orig_shape = tuple(x.shape)
    # an int padding must act only on the sample dimension, not on the dummy height
    padding = (0, pad_mode) if isinstance(pad_mode, int) else pad_mode
    # kernel with the same dtype and device as x (conv2d does not accept mixed dtypes)
    filt = torch.ones((1, 1, 1, order), dtype=x.dtype, device=x.device)/order
    
    # treat every signal as an independent single-channel image of height 1, so that
    # tensors with any number of dimensions are handled by a single conv2d call
    x_avg = F.conv2d(x.reshape(-1, 1, 1, orig_shape[-1]), filt, padding=padding)
    return x_avg.reshape(*orig_shape[:-1], -1)

In [None]:
def get_filter_coeff(Wp: float, 
                     Ws: float,
                     rp: float=-20*np.log10(.95), 
                     rs: float=-20*np.log10(.15), 
                     btype: str='low', 
                     filter_type: str='butter', 
                     order: int=None, 
                     Wn: Union[float,List[float]]=None, 
                     eeg_band: str=None, 
                     Fs: float=None
                    ):
    """
    get_filter_coeff returns the filter coefficients a and b needed to call the scipy's or torchaudio's 
    filtfilt function.
    
    get_filter_coeff is internally called by other filtering function when a and b coefficients are not given
    as input argument. It works following this priority pipeline:
    1) if btype is 'bandpass' and a specific EEG band is given, set Wp, Ws, rp, rs (and, for bands which 
       reduce to a lowpass or highpass filter, btype) according to the given band
    2) if order or Wn is not given, use previous parameters to design the filter
    3) Use Wn and order to get a and b coefficient to return
    
    In other words (Wp,Ws,rp,rs) ----> (Wn, order) -----> (a,b) 
    
    Parameters
    ----------
    
    Wp: float
        bandpass normalized from 0 to 1
    Ws: float
        stopband normalized from 0 to 1
    rp: float, optional
        ripple at bandpass in decibel. 
        Default: -20*log10(0.95)
    rs: float, optional
        ripple at stopband in decibel. 
        Default: -20*log10(0.15)
    btype: str, optional
        filter type. Can be any of the scipy's btype argument (e.g. 'lowpass', 'highpass', 'bandpass')
        Default: 'low'
    filter_type: str, optional
        which filter design. Accepted values are 'butter', 'ellip', 'cheby1', 'cheby2'
        Default: 'butter'
    order: int, optional
        the order of the filter. Used only if Wn is also given
        Default: None
    Wn: array_like, optional
        the critical frequency or frequencies. Used only if order is also given
        Default: None
    eeg_band: str, optional
        any of the possible EEG bands. Accepted values are "delta", "theta", "alpha", "beta", 
        "gamma", "gamma_low", "gamma_high". It is considered only if btype is 'bandpass'
        Default: None
    Fs: float, optional
        the sampling frequency. Must be given if eeg_band is also given
        Default: None
        
    Returns
    -------
    b, a: the numerator and denominator coefficients of the filter
    
    Raises
    ------
    ValueError
        if filter_type or eeg_band is not one of the accepted values, or if eeg_band is used 
        without giving Fs.
    """
    
    design = filter_type.lower()
    # functions which estimate (order, Wn) from the design constraints
    order_estimators = {
        'butter': signal.buttord,
        'ellip': signal.ellipord,
        'cheby1': signal.cheb1ord,
        'cheby2': signal.cheb2ord,
    }
    if design not in order_estimators:
        raise ValueError('filter type not supported. Choose between butter, elliptic, cheby1, cheby2')
    
    if btype.lower() == 'bandpass' and eeg_band is not None:
        if Fs is None:
            raise ValueError('Fs must be given when eeg_band is given')
        band = eeg_band.lower()
        # every band uses these ripples, except beta which has a looser stopband
        rp, rs = -20*np.log10(.95), -20*np.log10(.1)
        if band == 'delta':
            # delta starts near 0 Hz, so it reduces to a lowpass filter
            Wp, Ws, btype = 4, 8, 'lowpass'
        elif band == 'theta':
            Wp, Ws = [4, 8], [0, 15]
        elif band == 'alpha':
            Wp, Ws = [8, 13], [4, 22]
        elif band == 'beta':
            Wp, Ws, rs = [13, 30], [8, 40], -20*np.log10(.15)
        elif band == 'gamma_low':
            Wp, Ws = [30, 70], [22, 78]
        elif band == 'gamma_high':
            # the upper stopband (158 Hz) must be below the Nyquist frequency,
            # otherwise fall back on a highpass filter
            if Fs >= 158*2:
                Wp, Ws = [70, 150], [62, 158]
            else:
                Wp, Ws, btype = 70, 62, 'highpass'
        elif band == 'gamma':
            if Fs >= 158*2:
                Wp, Ws = [30, 150], [22, 158]
            else:
                Wp, Ws, btype = 30, 22, 'highpass'
        else:
            raise ValueError(
                f'Brainwave "{eeg_band}" not exist. \n'
                'Choose between delta, theta, alpha, beta, gamma, gamma_low, gamma_high'
            )
        # band edges are in Hz: normalize them with respect to the Nyquist frequency
        Wp, Ws = np.array(Wp)/(Fs/2), np.array(Ws)/(Fs/2)
    
    if (order is None) or (Wn is None):
        order, Wn = order_estimators[design](Wp, Ws, rp, rs)
    
    if design == 'butter':
        b, a = signal.butter(order, Wn, btype)
    elif design == 'ellip':
        b, a = signal.ellip(order, rp, rs, Wn, btype)
    elif design == 'cheby1':
        b, a = signal.cheby1(order, rp, Wn, btype)
    else:
        b, a = signal.cheby2(order, rs, Wn, btype)
    
    return b, a

In [None]:
def filter_lowpass(x: "array or tensor", 
                   Wp: float=50,
                   Ws: float=70,
                   rp: float=-20*np.log10(.95), 
                   rs: float=-20*np.log10(.15),
                   filter_type: str='butter',
                   order: int=None, 
                   Wn: float=None,
                   a: Union[np.ndarray,float]=None,
                   b: Union[np.ndarray,float]=None,
                   return_filter_coeff: bool=False
                  ):
    """
    filter_lowpass applies a lowpass filter on the last dimension of the given input x.
    
    If the a and b coefficients are not both given, get_filter_coeff is called with the other 
    arguments (and btype='lowpass') to get them. The filter design follows this order:
                            (Wp,Ws,rp,rs) ----> (Wn, order) -----> (a,b). 
    Therefore the arguments closer to a and b in the scheme are used to get the filter coefficient.
    Numpy arrays are filtered with scipy's filtfilt (padtype='constant'), any other input with the 
    filtfilt function available in this module's scope (expected to accept tensors).
    
    Parameters
    ----------
    x: N-D array or Tensor
        The element to filter
    Wp: float, optional
        passband edge, forwarded unchanged to get_filter_coeff.
        NOTE(review): no sampling rate is passed along, and get_filter_coeff documents this 
        value as normalized between 0 and 1. The default does not respect that range, so 
        confirm the expected unit and give an explicit value.
        Default: 50
    Ws: float, optional
        stopband edge, forwarded unchanged to get_filter_coeff (see the note on Wp)
        Default: 70
    rp: float, optional
        ripple at bandpass in decibel. 
        Default: -20*log10(0.95)
    rs: float, optional
        ripple at stopband in decibel. 
        Default: -20*log10(0.15)
    filter_type: str, optional
        which filter design. Accepted values are 'butter', 'ellip', 'cheby1', 'cheby2'
        Default: 'butter'
    order: int, optional
        the order of the filter
        Default: None
    Wn: array_like, optional
        the critical frequency or frequencies.
        Default: None
    a: array_like, optional
        the denominator coefficient of the filter
        Default: None
    b: array_like, optional
        the numerator coefficient of the filter
        Default: None
    return_filter_coeff: bool, optional
        whether to return the filter coefficient or not
        Default: False
        
    Returns
    -------
    The filtered x or, if return_filter_coeff is True, the tuple (filtered x, b, a). When x is not 
    a numpy array the returned coefficients are the tensors used for the filtering.
    
    Raises
    ------
    ValueError
        if filter_type is not one of the accepted values.
        
    NOTE: pytorch filtfilt works differently on edges and is pretty unstable with high order filters, so avoid 
    restrictive condition which can increase the order of the filter.
    """
    
    supported_designs = ('butter', 'ellip', 'cheby1', 'cheby2')
    if filter_type not in supported_designs:
        raise ValueError('filter type not supported. Choose between butter, elliptic, cheby1, cheby2')
    
    # design the filter only when the caller did not provide both coefficient sets
    if a is None or b is None:
        b, a = get_filter_coeff(Wp=Wp, Ws=Ws, rp=rp, rs=rs, btype='lowpass',
                                filter_type=filter_type, order=order, Wn=Wn,
                                eeg_band=None, Fs=None)
    
    if isinstance(x, np.ndarray):
        x_filt = signal.filtfilt(b, a, x, padtype='constant')
    else:
        # coefficients must share dtype and device with the tensor to filter
        a = torch.from_numpy(a).to(dtype=x.dtype, device=x.device)
        b = torch.from_numpy(b).to(dtype=x.dtype, device=x.device)
        x_filt = filtfilt(x, a, b, clamp=False)
    
    return (x_filt, b, a) if return_filter_coeff else x_filt

In [None]:
def filter_highpass(x: "array or tensor", 
                    Wp: float=30,
                    Ws: float=13,
                    rp: float=-20*np.log10(.95), 
                    rs: float=-20*np.log10(.15),
                    filter_type: str='butter',
                    order: int=None, 
                    Wn: float=None,
                    a: Union[np.ndarray,float]=None,
                    b: Union[np.ndarray,float]=None,
                    return_filter_coeff: bool=False
                   ):
    
    """
    filter_highpass applies a highpass filter on the last dimension of the given input x.
    
    If the a and b coefficients are not both given, get_filter_coeff is called with the other 
    arguments (and btype='highpass') to get them. The filter design follows this order:
                            (Wp,Ws,rp,rs) ----> (Wn, order) -----> (a,b). 
    Therefore the arguments closer to a and b in the scheme are used to get the filter coefficient.
    Numpy arrays are filtered with scipy's filtfilt (padtype='constant'), any other input with the 
    filtfilt function available in this module's scope (expected to accept tensors).
    
    Parameters
    ----------
    x: N-D array or Tensor
        The element to filter
    Wp: float, optional
        passband edge, forwarded unchanged to get_filter_coeff.
        NOTE(review): no sampling rate is passed along, and get_filter_coeff documents this 
        value as normalized between 0 and 1. The default does not respect that range, so 
        confirm the expected unit and give an explicit value.
        Default: 30
    Ws: float, optional
        stopband edge, forwarded unchanged to get_filter_coeff (see the note on Wp)
        Default: 13
    rp: float, optional
        ripple at bandpass in decibel. 
        Default: -20*log10(0.95)
    rs: float, optional
        ripple at stopband in decibel. 
        Default: -20*log10(0.15)
    filter_type: str, optional
        which filter design. Accepted values are 'butter', 'ellip', 'cheby1', 'cheby2'
        Default: 'butter'
    order: int, optional
        the order of the filter
        Default: None
    Wn: array_like, optional
        the critical frequency or frequencies.
        Default: None
    a: array_like, optional
        the denominator coefficient of the filter
        Default: None
    b: array_like, optional
        the numerator coefficient of the filter
        Default: None
    return_filter_coeff: bool, optional
        whether to return the filter coefficient or not
        Default: False
        
    Returns
    -------
    The filtered x or, if return_filter_coeff is True, the tuple (filtered x, b, a). When x is not 
    a numpy array the returned coefficients are the tensors used for the filtering.
    
    Raises
    ------
    ValueError
        if filter_type is not one of the accepted values.
        
    NOTE: pytorch filtfilt works differently on edges and is pretty unstable with high order filters, so avoid 
    restrictive condition which can increase the order of the filter.
    """
    
    supported_designs = ('butter', 'ellip', 'cheby1', 'cheby2')
    if filter_type not in supported_designs:
        raise ValueError('filter type not supported. Choose between butter, elliptic, cheby1, cheby2')
    
    # design the filter only when the caller did not provide both coefficient sets
    if a is None or b is None:
        b, a = get_filter_coeff(Wp=Wp, Ws=Ws, rp=rp, rs=rs, btype='highpass',
                                filter_type=filter_type, order=order, Wn=Wn,
                                eeg_band=None, Fs=None)
    
    if isinstance(x, np.ndarray):
        x_filt = signal.filtfilt(b, a, x, padtype='constant')
    else:
        # coefficients must share dtype and device with the tensor to filter
        a = torch.from_numpy(a).to(dtype=x.dtype, device=x.device)
        b = torch.from_numpy(b).to(dtype=x.dtype, device=x.device)
        x_filt = filtfilt(x, a, b, clamp=False)
    
    return (x_filt, b, a) if return_filter_coeff else x_filt

In [None]:
def filter_bandpass(x: "array or tensor", 
                    Wp: list[float]=None,
                    Ws: list[float]=None,
                    rp: float=-20*np.log10(.95), 
                    rs: float=-20*np.log10(.05),
                    filter_type: str='butter',
                    order: int=None, 
                    Wn: float=None,
                    a: Union[np.ndarray,float]=None,
                    b: Union[np.ndarray,float]=None,
                    eeg_band: str=None,
                    Fs: float=None,
                    return_filter_coeff: bool=False
                   ):
    """
    filter_bandpass applies a bandpass filter on the last dimension of the given input x.
    
    If the a and b coefficients are not both given, get_filter_coeff is called with the other 
    arguments (and btype='bandpass') to get them. The filter design follows this order:
                            (Wp,Ws,rp,rs) ----> (Wn, order) -----> (a,b). 
    Therefore the arguments closer to a and b in the scheme are used to get the filter coefficient.
    Numpy arrays are filtered with scipy's filtfilt (padtype='constant'), any other input with the 
    filtfilt function available in this module's scope (expected to accept tensors).
    
    Parameters
    ----------
    x: N-D array or Tensor
        The element to filter
    Wp: list of float, optional
        passband edges, forwarded unchanged to get_filter_coeff (which documents them as 
        normalized between 0 and 1). Replaced by the band definition if eeg_band is given.
        Default: None
    Ws: list of float, optional
        stopband edges, forwarded unchanged to get_filter_coeff (which documents them as 
        normalized between 0 and 1). Replaced by the band definition if eeg_band is given.
        Default: None
    rp: float, optional
        ripple at bandpass in decibel. 
        Default: -20*log10(0.95)
    rs: float, optional
        ripple at stopband in decibel. 
        Default: -20*log10(0.05)
    filter_type: str, optional
        which filter design. Accepted values are 'butter', 'ellip', 'cheby1', 'cheby2'
        Default: 'butter'
    order: int, optional
        the order of the filter
        Default: None
    Wn: array_like, optional
        the critical frequency or frequencies.
        Default: None
    a: array_like, optional
        the denominator coefficient of the filter
        Default: None
    b: array_like, optional
        the numerator coefficient of the filter
        Default: None
    eeg_band: str, optional
        any of the possible EEG bands. Accepted values are "delta", "theta", "alpha", "beta", 
        "gamma", "gamma_low", "gamma_high".
        Default: None
    Fs: float, optional
        the sampling frequency. Must be given if eeg_band is also given
        Default: None
    return_filter_coeff: bool, optional
        whether to return the filter coefficient or not
        Default: False
        
    Returns
    -------
    The filtered x or, if return_filter_coeff is True, the tuple (filtered x, b, a). When x is not 
    a numpy array the returned coefficients are the tensors used for the filtering.
    
    Raises
    ------
    ValueError
        if filter_type is not one of the accepted values.
        
    NOTE: pytorch filtfilt works differently on edges and is pretty unstable with high order filters, so avoid 
    restrictive condition which can increase the order of the filter.
    """
    
    supported_designs = ('butter', 'ellip', 'cheby1', 'cheby2')
    if filter_type not in supported_designs:
        raise ValueError('filter type not supported. Choose between butter, elliptic, cheby1, cheby2')
    
    # design the filter only when the caller did not provide both coefficient sets
    if a is None or b is None:
        b, a = get_filter_coeff(Wp=Wp, Ws=Ws, rp=rp, rs=rs, btype='bandpass',
                                filter_type=filter_type, order=order, Wn=Wn,
                                eeg_band=eeg_band, Fs=Fs)
    
    if isinstance(x, np.ndarray):
        x_filt = signal.filtfilt(b, a, x, padtype='constant')
    else:
        # coefficients must share dtype and device with the tensor to filter
        a = torch.from_numpy(a).to(dtype=x.dtype, device=x.device)
        b = torch.from_numpy(b).to(dtype=x.dtype, device=x.device)
        x_filt = filtfilt(x, a, b, clamp=False)
    
    return (x_filt, b, a) if return_filter_coeff else x_filt

## Permutation

In [None]:
def get_eeg_channel_network_names():
    """
    get_eeg_channel_network_names prints the channel names of each predefined brain network.
    
    For each of the six networks (DMN, DAN, VAN, SMN, VFN, FPN) the function prints the network 
    name, the numpy array (dtype '<U4') with the names of its channels, and an empty line. 
    Nothing is returned.
    """
    
    # (printed title, channel names) for each network, in printing order
    network_table = (
        ('Default Mode Network - DMN',
         ['AF4', 'AF7', 'AF8', 'AFZ', 'CP3', 'CP4', 'CP5', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6',
          'F7', 'F8', 'FC1', 'FC3', 'FC4', 'FC5', 'FP1', 'FP2', 'FPZ', 'FT10', 'FT8', 'FT9',
          'FZ', 'P3', 'P4', 'P5', 'T7', 'T8', 'TP7', 'TP8']),
        ('Dorsal Attention Network - DAN',
         ['C5', 'C6', 'CP1', 'CP2', 'CPZ', 'FC1', 'FC2', 'FC5', 'P1', 'P2',
          'P7', 'P8', 'PO3', 'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8']),
        ('Ventral Attention Network - VAN',
         ['AF3', 'AF4', 'AF8', 'C5', 'C6', 'CP1', 'CP2', 'CP4', 'CP5', 'CP6',
          'CPZ', 'F7', 'F8', 'FC1', 'FC2', 'FC5', 'FC6', 'FT7', 'P1', 'P2',
          'P7', 'P8', 'PO3', 'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8']),
        ('SomatoMotor functional Network - SMN',
         ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP5', 'CPZ',
          'CZ', 'FC5', 'FC6', 'FCZ', 'FT8', 'FTZ', 'P3', 'P5', 'P6', 'P7', 'P8',
          'PO4', 'PO7', 'PO8', 'T7', 'T8', 'TP7']),
        ('Visual Functional Network - VFN',
         ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CPZ', 'FC1',
          'FC5', 'FC6', 'FT8', 'FTZ', 'O1', 'O2', 'OZ', 'P7', 'P8', 'PO3',
          'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8']),
        ('FrontoParietal Network - FPN',
         ['AF3', 'AF4', 'AF7', 'AF8', 'AFZ', 'C6', 'CP3', 'CP4', 'CP5',
          'CP6', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'FC1',
          'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FP1', 'FP2', 'FPZ', 'FT10',
          'FT7', 'FT9', 'FZ', 'P3', 'P4', 'P5', 'T7', 'T8', 'TP7', 'TP8']),
    )
    for title, channels in network_table:
        print(title)
        print(np.array(channels, dtype='<U4'))
        print('')

In [None]:
def get_channel_map_and_networks(channel_map: list=None,
                                 chan_net: list[str]='all',
                                ):
    """
    get_channel_map_and_networks simply returns the channel_map and chan_net argument for permute_channels.
    Run help(permute_channels) to get more information.
    
    Parameters
    ----------
    channel_map: list of str, optional
        The channel names of the EEG acquisition. A list is converted to a numpy array with 
        dtype '<U4'; if None, a default map of 61 channel names is used; any other value is 
        returned as it is.
        Default: None
    chan_net: str or list of str, optional
        The brain networks to select. Can be 'all' or any of 'DMN', 'DAN', 'VAN', 'SMN', 'VFN', 
        'FPN' (case insensitive).
        Default: 'all'
        
    Returns
    -------
    channel_map: the channel map
    networks: list with the channel name arrays of the selected networks, in random order 
              (shuffled with the random module)
              
    Raises
    ------
    ValueError
        if a name in chan_net is not a supported network.
    """
    
    if channel_map is None:
        channel_map = np.array(['FP1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5',
                                'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3',
                                'CP1', 'P1', 'P3', 'P5', 'P7', 'PO7', 'PO3', 'O1', 'OZ',
                                'POZ', 'PZ', 'CPZ', 'FPZ', 'FP2', 'AF8', 'AF4', 'AFZ', 'FZ',
                                'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCZ',
                                'CZ', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
                                'P2', 'P4', 'P6', 'P8', 'PO8', 'PO4', 'O2'], dtype='<U4')
    elif isinstance(channel_map, list):
        channel_map = np.array(channel_map, dtype='<U4')
    
    # networks definition (according to rojas et al. 2018), listed in the order
    # DMN, DAN, VAN, SMN, VFN, FPN
    network_channels = [
        # DMN
        ['AF4', 'AF7', 'AF8', 'AFZ', 'CP3', 'CP4', 'CP5', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6',
         'F7', 'F8', 'FC1', 'FC3', 'FC4', 'FC5', 'FP1', 'FP2', 'FPZ', 'FT10', 'FT8', 'FT9',
         'FZ', 'P3', 'P4', 'P5', 'T7', 'T8', 'TP7', 'TP8'],
        # DAN
        ['C5', 'C6', 'CP1', 'CP2', 'CPZ', 'FC1', 'FC2', 'FC5', 'P1', 'P2',
         'P7', 'P8', 'PO3', 'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8'],
        # VAN
        ['AF3', 'AF4', 'AF8', 'C5', 'C6', 'CP1', 'CP2', 'CP4', 'CP5', 'CP6',
         'CPZ', 'F7', 'F8', 'FC1', 'FC2', 'FC5', 'FC6', 'FT7', 'P1', 'P2',
         'P7', 'P8', 'PO3', 'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8'],
        # SMN
        ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP5', 'CPZ',
         'CZ', 'FC5', 'FC6', 'FCZ', 'FT8', 'FTZ', 'P3', 'P5', 'P6', 'P7', 'P8',
         'PO4', 'PO7', 'PO8', 'T7', 'T8', 'TP7'],
        # VFN
        ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CPZ', 'FC1',
         'FC5', 'FC6', 'FT8', 'FTZ', 'O1', 'O2', 'OZ', 'P7', 'P8', 'PO3',
         'PO4', 'PO7', 'PO8', 'POZ', 'PZ', 'T7', 'TP8'],
        # FPN
        ['AF3', 'AF4', 'AF7', 'AF8', 'AFZ', 'C6', 'CP3', 'CP4', 'CP5',
         'CP6', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'FC1',
         'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FP1', 'FP2', 'FPZ', 'FT10',
         'FT7', 'FT9', 'FZ', 'P3', 'P4', 'P5', 'T7', 'T8', 'TP7', 'TP8'],
    ]
    position_of = {'dmn': 0, 'dan': 1, 'van': 2, 'smn': 3, 'vfn': 4, 'fpn': 5}

    if isinstance(chan_net, str):
        chan_net = [chan_net]

    # collect the positions of the requested networks; 'all' stops the scan
    selected = set()
    for requested in chan_net:
        key = requested.lower()
        if key == 'all':
            selected = set(position_of.values())
            break
        if key not in position_of:
            raise ValueError('brain network not supported. Can be any of DMN, DAN, VAN, SMN, VFN, FPN')
        selected.add(position_of[key])

    # keep the requested networks only (in definition order), then shuffle them
    networks = [np.array(names, dtype='<U4')
                for position, names in enumerate(network_channels)
                if position in selected]
    random.shuffle(networks)
    
    return channel_map, networks
    

In [None]:
def permute_channels(x, 
                     chan2shuf: int=-1,
                     mode: str="random",
                     channel_map: list=None,
                     chan_net: list[str]='all',
                     batch_equal: bool=False
                    ):
    """
    permute_channels permutes the EEG signals of x along the channel dimension (second to last).
    
    Given an input x whose last two dimensions are (EEG_channels x EEG_samples), a subset of chan2shuf
    channels is selected at random and the selected channels exchange their positions. In random mode every
    selected channel is moved; in network mode channels are exchanged only with other selected channels of
    the same network, so a selected channel without a partner stays where it is. Channels that are not
    selected are returned unchanged.
    If batch_equal is False and x has more than two dimensions, the function calls itself on each element of
    the first dimension, so every (channels x samples) record receives its own permutation.
    
    Parameters
    ----------
    x: N-D Tensor or numpy array
        The element to shuffle. The last two dimensions must be (EEG_channel x EEG_samples), which means that the 
        permutation is applied on the second to last dimension. 
    chan2shuf: int, optional
        The number of channels to shuffle. Must be greater than 1 and not bigger than the number of channels.
        -1 is the only accepted negative number and means permute all the channels.
        Default: -1
    mode: str, optional
        How to permute the channels. Can be any of:
            'random': shuffle channels at random
            'network': shuffle channels which belongs to the same network. A network is a subset of channels whose
                       activity is (with a minumum degree) between each other. This mode support only a subset of
                       61 channels of the 10-10 system
        Default: 'random'
    channel_map: list of str, optional
        The channel map of EEG acquisitions. Must be a list of string or a numpy array of dtype='<U4' with channel 
        names as elements. Channel name must be defined with capital letters (e.g. 'P04', 'FC5').
        It is forwarded to get_channel_map_and_networks, which is expected to supply the default map when
        None is given. Only used in network mode.
        Default: None
    chan_net: str or list of str, optional
        The list of networks to use if network mode is selected. Must be a list of string or a single string.
        Supported networks are DMN, DAN, VAN, SMN, VFN, FPN. Use 'all' to select all networks. 
        Default: 'all'
    batch_equal: bool, optional
        Whether to apply the same permutation to all EEG records or not. If False, permute_channels is called 
        recursively for each dimension of the batch until the last two are reached (e.g. given a tensor x
        of dimension (16,8,64,512) the function permutes each of the 16*8 EEG signals (64 channels of 512
        samples) individually).
        Default: False
    
    Returns
    -------
    x2: N-D Tensor or numpy array
        A new array/tensor with the same shape and type as x and the selected channels permuted.
    
    Raises
    ------
    ValueError
        If x has less than two dimensions, if chan2shuf is not -1 or a value in [2, number of channels], if mode
        is not supported, or (network mode) if the channel map length differs from the number of channels.
    """
    
    Ndim = len(x.shape)
    # this must be checked before reading shape[-2], otherwise 1-D inputs fail with an IndexError
    if Ndim < 2:
        raise ValueError('x must be an array or tensor with the last two dimensions [channel]*[time window]')
    Nchan = x.shape[-2]
    
    # -1 is the only accepted negative value; 0 and 1 cannot define a permutation
    if (chan2shuf != -1) and ((chan2shuf < 2) or (chan2shuf > Nchan)):
        msgLog='chan2shuf must be bigger than 1 and smaller than the number of channels in the recorded EEG. \n '
        msgLog += 'Default value is -1, which means all EEG channels are shuffled'
        raise ValueError(msgLog)
    n_shuf = Nchan if chan2shuf == -1 else chan2shuf
    
    mode_low = mode.lower()
    if mode_low not in ('random', 'network'):
        raise ValueError("mode not supported. Can be 'random' or 'network'")
    
    is_numpy = isinstance(x, np.ndarray)
    
    if (Ndim >= 3) and not batch_equal:
        # call recursively for each dimension until last 2 are reached
        x2 = np.empty_like(x) if is_numpy else torch.empty_like(x)
        for i in range(x.shape[0]):
            x2[i] = permute_channels(x[i], chan2shuf=chan2shuf, mode=mode, 
                                     channel_map=channel_map, chan_net=chan_net)
        return x2
    
    if mode_low == 'network':
        # Define or check channel map and channel networks
        channel_map, networks = get_channel_map_and_networks(channel_map, chan_net)
        if channel_map.shape[0] != Nchan:
            raise ValueError('channel map does not match the number of channels in eeg recording')
        
        # randomly select a number of channels equals to chan2shuf
        idxor = np.sort(np.random.permutation(Nchan)[:n_shuf])
        idx = np.full(n_shuf, -1, dtype=int)
        
        # shuffle according to the selected networks
        for net in networks:
            pos = np.where(np.isin(channel_map[idxor], net))[0]  # selected channels inside this network
            pos = pos[idx[pos] == -1]                            # keep only non shuffled channels
            src = idxor[pos]                                     # fancy indexing: this is a copy
            if src.shape[0] > 1:
                # reject permutations which leave any channel in its place
                while np.any(src == idxor[pos]):
                    np.random.shuffle(src)
            idx[pos] = src
        
        # selected channels which belong to none of the networks stay in place
        # (leaving -1 here would silently copy the last channel over them)
        untouched = idx == -1
        idx[untouched] = idxor[untouched]
        
        if not is_numpy:
            idxor = torch.from_numpy(idxor).to(device=x.device)
            idx = torch.from_numpy(idx).to(device=x.device)
    
    # random mode shuffle channels at random
    elif is_numpy:
        idx = np.random.permutation(Nchan)[:n_shuf]
        idxor = np.sort(idx)
        if n_shuf > 1:
            while np.any(idx == idxor):
                np.random.shuffle(idx)
    else:
        idx = torch.randperm(Nchan, device=x.device)[:n_shuf]
        idxor, _ = torch.sort(idx)
        if n_shuf > 1:
            while torch.any(idx == idxor):
                idx = idx[torch.randperm(n_shuf, device=x.device)]
    
    # start from a copy so that channels which are not selected keep their values
    x2 = x.copy() if is_numpy else x.clone()
    x2[..., idxor, :] = x[..., idx, :]
    return x2

In [None]:
def permutation_signal(x, 
                       segments: int=10, 
                       seg_to_per: int=-1,
                       batch_equal: bool=False
                      ):
    """
    permutation_signal permutes some portions of the last dimension of the input N-D array_like x.
    
    Given an input x where the last two dimension must be (EEG_channels x EEG_samples), permutation_signal 
    divides the elements of the last dimension of x into N segments of equal length, then chooses M<=N
    segments and exchanges their positions so that none of the chosen segments stays in place. 
    The same permutation is applied to all the channels of the same EEG. Samples left over when the signal
    length is not a multiple of the number of segments are not moved.
    If batch_equal is set to False, call the function recursively along each of the (N_{1}*N_{2}...*N{-3}) 
    dimensions of the tensor.
    
    Parameters
    ----------
    x: N-D Tensor or numpy array
        The element to shuffle. The last two dimensions must be (EEG_channel x EEG_samples), which means that the 
        same permutation is applied to all the channels of the EEG signal.
    segments: int, optional
        The number of segments in which the last dimension of x must be divided. Must be greater than 1
        Default: 10
    seg_to_per: int, optional
        The number of segments to permute. Must be greater than 1 and not bigger than segments. -1 is the only
        accepted negative number and means permute all the segments.
        Default: -1
    batch_equal: bool, optional
        Whether to apply the same permutation to all EEG records or not. If False, permutation_signal is called 
        recursively for each dimension of the batch until the last two are reached (e.g. given a tensor x
        of dimension (16,8,64,512) the function permutes each of the 16*8 EEG signals (64 channels of 512
        samples) individually).
        Default: False
    
    Returns
    -------
    x2: N-D Tensor or numpy array
        A new array/tensor with the same shape as x and the chosen segments permuted.
    
    Raises
    ------
    ValueError
        If segments is lower than 2 or seg_to_per is not -1 or a value in [2, segments].
    """
    
    # a single segment (or a single segment to permute) has no permutation without fixed points:
    # letting it through would never leave the rejection loop below
    if segments < 2:
        raise ValueError('segments cannot be less than 2')
    if seg_to_per == -1:
        seg_to_per = segments
    elif seg_to_per < -1:
        msgError='got a negative number of segments to permute. Only -1 to permute all segments is allowed'
        raise ValueError(msgError)
    elif seg_to_per < 2:
        raise ValueError('seg_to_per must be bigger than 1 (put -1 to permute all segments)')
    elif seg_to_per > segments:
        raise ValueError('number of segment to permute is bigger than the number of segment')
    
    Ndim = len(x.shape)
    if (Ndim > 2) and not batch_equal:
        x2 = np.empty_like(x) if isinstance(x, np.ndarray) else torch.empty_like(x)
        for i in range(x.shape[0]):
            x2[i] = permutation_signal(x[i], segments=segments, seg_to_per=seg_to_per, batch_equal=batch_equal)
        return x2
    
    segment_len = x.shape[-1] // segments
    
    # choose the segments to move and draw a permutation of them without fixed points
    chosen = np.sort(np.random.permutation(segments)[:seg_to_per])
    shuffled = chosen.copy()
    while np.any(shuffled == chosen):
        np.random.shuffle(shuffled)
    
    # seg_order[k] is the segment of x which ends up in position k
    seg_order = np.arange(segments)
    seg_order[chosen] = shuffled
    
    # expand the segment order to a sample index; trailing samples outside the segments keep their place
    full_idx = np.arange(x.shape[-1])
    body = (seg_order[:, None]*segment_len + np.arange(segment_len)).ravel()
    full_idx[:segments*segment_len] = body
    
    if not isinstance(x, np.ndarray):
        full_idx = torch.from_numpy(full_idx).to(device=x.device)
    return x[..., full_idx]
    
    

## Crop and Resize

In [None]:
def _pchip_edge_slope(h0, h1, m0, m1):
    """
    Return the shape preserving three-point estimate of the derivative at one end of the grid.
    
    h0 and m0 are the width and the secant slope of the interval touching the end point, h1 and m1
    the ones of the neighbouring interval.
    """
    d = ((2*h0 + h1)*m0 - h0*m1) / (h0 + h1)
    # a slope pointing against the first secant is set to zero
    d = torch.where(torch.sign(d) != torch.sign(m0), torch.zeros_like(d), d)
    # when the two secants have different sign the slope is limited to three times the first secant
    limit = torch.logical_and(torch.sign(m0) != torch.sign(m1), torch.abs(d) > 3*torch.abs(m0))
    return torch.where(limit, 3*m0, d)


def torch_pchip(x: "1D Tensor", 
                y: "ND Tensor", 
                xv: "1D Tensor",
                save_memory: bool=True,
                new_y_max_numel: int=4194304
               ):
    """
    torch_pchip perform pchip interpolation on the last dimension of the input tensor y.
    
    This function is a pytorch adaptation of the scipy's pchip_interpolate. It performs sp-pchip interpolation
    (Shape Preserving Piecewise Cubic Hermite Interpolating Polynomial) on the last dimension of the y tensor.
    x is the original time grid and xv the new virtual grid. So, the new values of y are given by the 
    polynomials fitted on the time grid x and evaluated at the time points xv. Points of xv outside the
    range of x are evaluated with the polynomial of the closest interval.
    
    Parameters
    ----------
    x: 1D Tensor
        Tensor with the original time grid. Must be the same length as the last dimension of y and have at
        least two points. It is passed to torch.bucketize as boundaries, so it must be increasing
    y: ND Tensor
        Tensor to interpolate. The last dimension must have the signals to interpolate
    xv: 1D Tensor
        Tensor with the new virtual grid, i.e. the time points where to interpolate
    save_memory: bool, optional
        whether to perform the interpolation on subsets of the y tensor by recursive function calls or not.
        Does not apply if y is a 1-D tensor. If set to False memory usage can drastically increase 
        (for example with a 128 MB tensor, the memory usage of the function is 1.2 GB), but in some devices 
        it can speed up the process. However, this is not the case for all devices and performance may
        decrease (see example below run on an old computer).
        Default: True
    new_y_max_numel: int, optional
        The number of elements which the interpolated tensor needs to surpass to make the function start
        recursive calls. It can be considered as an indicator of the maximum allowed memory usage since the
        lower the number, the lower the memory used. 
        Default: 256*1024*16 (approximately 16s of recording of a 256 Channel EEG sampled at 1024 Hz)
    
    Returns
    -------
    new_y: ND Tensor
        Tensor with the shape of y except for the last dimension, which has the length of xv.
    
    Raises
    ------
    ValueError
        If x or xv are not 1D, if x has less than two points or if the length of x differs from the last
        dimension of y.
    
    Some technical information and difference with other interpolation:
        https://blogs.mathworks.com/cleve/2012/07/16/splines-and-pchips/
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.PchipInterpolator.html
    Some parts of the code are inspired from: 
        https://github.com/scipy/scipy/blob/v1.10.1/scipy/interpolate/_cubic.py#L157-L302
    
    Example
    -------
    # result on my (old) computer with a 4D tensor/array with dim (1024,1,96,512) and xv with dim (1024,)
    #         torch_pchip no save memory | torch_pchip save memory | scipy pchip_interpolate
    #                   15.48s                     5.33s                     21.96s
    
    """
    
    if len(x.shape) != 1:
        raise ValueError(f'Expected 1D Tensor for x but received a {len(x.shape)}-D Tensor') 
    if len(xv.shape) != 1:
        raise ValueError(f'Expected 1D Tensor for xv but received a {len(xv.shape)}-D Tensor')
    if x.shape[0] != y.shape[-1]:
        raise ValueError('x must have the same length than the last dimension of y')
    if x.shape[0] < 2:
        raise ValueError('x must have at least two points')
    
    # grids are tiny compared to y: keep them where the data is to avoid device mismatch
    x = x.to(device=y.device)
    xv = xv.to(device=y.device)
    Ndim = len(y.shape)
    
    # If save_memory and the new Tensor size is huge, call recursively for each element in the first dimension 
    if save_memory and Ndim > 1:
        if ((torch.numel(y)/y.shape[-1])*xv.shape[0]) > new_y_max_numel:
            # the first result fixes dtype and device of the output
            first = torch_pchip(x, y[0], xv, save_memory, new_y_max_numel)
            new_y = first.new_empty((y.shape[0], *first.shape))
            new_y[0] = first
            for i in range(1, y.shape[0]):
                new_y[i] = torch_pchip(x, y[i], xv, save_memory, new_y_max_numel)
            return new_y
    
    # This is a common part for every channel: interval of each new point and powers of the offset
    bucket = torch.bucketize(xv, x) - 1
    bucket = torch.clamp(bucket, 0, x.shape[0]-2)
    tv_minus = (xv - x[bucket]).unsqueeze(1)
    infer_tv = torch.cat((tv_minus**3, tv_minus**2, tv_minus, torch.ones_like(tv_minus)), 1) 
    
    h = (x[1:] - x[:-1])
    Delta = (y[..., 1:] - y[..., :-1]) / h
    slope = Delta.new_zeros(y.shape)
    
    if x.shape[0] == 2:
        # only one interval: the interpolant is the line through the two points
        slope[..., 0] = Delta[..., 0]
        slope[..., 1] = Delta[..., 0]
    else:
        # interior points: weighted harmonic mean of the adjacent secants where they share the sign,
        # zero (local extremum) otherwise
        k = (torch.sign(Delta[..., :-1]*Delta[..., 1:]) > 0)
        w1 = 2*h[1:] + h[:-1]
        w2 = h[1:] + 2*h[:-1]
        whmean = (w1/Delta[..., :-1] + w2/Delta[..., 1:]) / (w1 + w2)
        slope[..., 1:-1][k] = whmean[k].reciprocal()
        
        # end points: the right end mirrors the left one (last interval first, second to last after)
        slope[..., 0] = _pchip_edge_slope(h[0], h[1], Delta[..., 0], Delta[..., 1])
        slope[..., -1] = _pchip_edge_slope(h[-1], h[-2], Delta[..., -1], Delta[..., -2])
    
    # coefficients of the cubic in (xv - x[bucket]) for each interval
    t = (slope[..., :-1] + slope[..., 1:] - Delta - Delta) / h 
    a = t / h
    b = (Delta - slope[..., :-1]) / h - t
    
    py_coef = torch.stack((a, b, slope[..., :-1], y[..., :-1].to(Delta.dtype)), -1)
    new_y = (py_coef[..., bucket, :] * infer_tv).sum(axis=-1)
    
    return new_y

In [None]:
def warp_signal(x,
                segments: int=10,
                stretch_strength: float=2.,
                squeeze_strength: float=0.5,
                batch_equal: bool=False,
               ):
    
    """
    warp_signal stretches and squeezes portions of the last dimension of x, keeping its original length.
    
    Given x a N-D Tensor where the last two dim are EEG_Channel x EEG_Signal, warp_signal:
    1) divides the last dimension of x into N segments
    2) selects at random a subset of segments to stretch (between 1 and N//2); the others are squeezed
    3) resamples each segment with pchip interpolation to ceil(length*stretch_strength) or
       ceil(length*squeeze_strength) samples
    4) concatenates the resampled segments into a warped signal
    5) resamples the warped signal to the original dimension. For this part pchip interpolation 
       with a uniform virtual grid is used
    If batch_equal is set to False, call the function recursively and repeat step 1 to 5 for each 
    (N_{1}*N_{2}...*N{-3}) elements of the tensor.
    
    Parameters
    ----------
    x : N-D Tensor or numpy array
        The array to warp. It accepts any dimensions but the last 2 must be EEG_Channel x EEG_Signal.
        For example, if x is 4-D, it must be in the form (N1 x N2 x Channels x Signal), where N1 and N2 are usually
        the batch size and channels
    segments : int, optional
        The number of segments to consider when dividing the last dimension of x. Must be at least 2.
        Default: 10
    stretch_strength : float, optional
        The multiplier applied to the length of the segments to stretch.
        Default: 2.
    squeeze_strength : float, optional
        The multiplier applied to the length of the segments to squeeze.
        Default: 0.5
    batch_equal: bool, optional
        whether to apply the same warp to all EEG record or not. If False the warp is drawn separately for
        each record, which means more variability but one interpolation per record.
        Default: False
    
    Returns
    -------
    x_warped_final: N-D Tensor or numpy array
        The warped signal, with the same shape as x.
    
    Raises
    ------
    ValueError
        If segments is lower than 2.
    """
    
    if segments < 2:
        raise ValueError('segments must be at least 2')
    
    Ndim = len(x.shape)
    is_numpy = isinstance(x, np.ndarray)
    
    if not batch_equal and Ndim >= 3:
        # Recursively call until second to last dim is reached
        x_warped_final = np.empty_like(x) if is_numpy else torch.empty_like(x, device=x.device)
        for i in range(x.shape[0]):
            x_warped_final[i] = warp_signal(x[i], segments, stretch_strength, squeeze_strength, batch_equal)
        return x_warped_final
    
    # set segments to stretch; all the others are squeezed
    seglen = x.shape[-1] / segments
    seg_range = np.arange(segments)
    stretch = np.random.choice(seg_range, random.randint(1, segments//2), replace=False)
    
    # segment boundaries (seglen can be fractional, so segments may differ by one sample)
    edges = (np.arange(segments + 1)*seglen).astype(int)
    strength = np.where(np.isin(seg_range, stretch), stretch_strength, squeeze_strength)
    new_len = np.ceil(np.diff(edges)*strength).astype(int).tolist()
    edges = edges.tolist()
    warped_len = int(sum(new_len))
    
    # initialize warped array (i.e. the array where to allocate stretched and squeezed segments)
    x_size = [int(i) for i in x.shape]
    x_size[-1] = warped_len
    x_warped = np.empty(x_size) if is_numpy else torch.empty(x_size, device=x.device)
    
    # iterate over segments and stretch or squeeze each segment, then allocate to x_warped
    idx_cnt = 0
    for i in range(segments):
        piece = x[..., edges[i]:edges[i+1]]
        new_piece_dim = new_len[i]
        if is_numpy:
            warped_piece = interpolate.pchip_interpolate(np.linspace(0, seglen-1, piece.shape[-1]), piece, 
                                                         np.linspace(0, seglen-1, new_piece_dim), axis=-1)
        else:
            # grids are created on the device of x to avoid mixing devices during the interpolation
            warped_piece = torch_pchip(torch.linspace(0, seglen-1, piece.shape[-1], device=x.device), piece, 
                                       torch.linspace(0, seglen-1, new_piece_dim, device=x.device))
        x_warped[..., idx_cnt : idx_cnt+new_piece_dim] = warped_piece
        idx_cnt += new_piece_dim
    
    # resample x_warped to fit original size; both grids end at warped_len-1 so that the last
    # output sample is the last warped sample and nothing is extrapolated
    if is_numpy:
        x_warped_final = interpolate.pchip_interpolate(np.linspace(0, warped_len-1, warped_len), x_warped, 
                                                       np.linspace(0, warped_len-1, x.shape[-1]), axis=-1)
    else:
        x_warped_final = torch_pchip(torch.linspace(0, warped_len-1, warped_len, device=x.device), x_warped, 
                                     torch.linspace(0, warped_len-1, x.shape[-1], device=x.device))
    
    return x_warped_final

In [None]:
def crop_and_resize(x: 'N-D Tensor or array',
                    segments: int=10,
                    N_cut: int=1,
                    batch_equal: bool=False,
                   ):
    """
    crop_and_resize crop some segments of the last dimension of the input x and resize to the original dimension.
    
    Given x a N-D Tensor where the last two dim are EEG_Channel x EEG_Signal, crop_and_resize:
    1) divide the last dimension of x into N segments
    2) select at random N_cut distinct segments
    3) remove the selected segments from x
    4) create a new cropped version of x
    5) resample the new cropped version to the original dimension. For this part pchip interpolation 
       with a uniform virtual grid is used
    If batch_equal is set to False, call the function recursively and repeat step 1 to 5 for each 
    (N_{1}*N_{2}...*N{-3}) elements of the tensor.
    
    Parameters
    ----------
    x : N-D Tensor or numpy array
        The array to crop and resize. It accepts any dimensions but the last 2 must be EEG_Channel x EEG_Signal.
        For example, if x is 4-D, it must be in the form (N1 x N2 x Channels x Signal), where N1 and N2 are usually
        the batch size and channels
    segments : int, optional
        The number of segments to consider when dividing the last dimension of x.
        Default: 10
    N_cut : int, optional
        The number of segments to cut. Must be lower than segments.
        Default: 1
    batch_equal: bool, optional
        whether to apply the same crop to all EEG record or not. If False the segments to cut are drawn
        separately for each record, which means more variability but one interpolation per record.
        Default: False
    
    Returns
    -------
    x_crop: N-D Tensor or numpy array
        The cropped and resized signal, with the same shape as x.
    
    Raises
    ------
    ValueError
        If segments is lower than 1 or N_cut is negative or not lower than segments.
        
    Example
    -------
    dim = (16,1,64,512)
    segments=15
    N_cut=6
    
    x = torch.sin(torch.linspace(0,20*math.pi, dim[-1]))
    zero_tensor = torch.zeros(dim)
    x = zero_tensor + x
    # x = x.numpy() #the result won't change if x is a numpy array
    
    x_crop = crop_and_resize(x, segments= segments, N_cut= N_cut, batch_equal=True)
    print(torch.equal(x_crop[1], x_crop[2])) # True
    x_crop = crop_and_resize(x, segments= segments, N_cut= N_cut, batch_equal=False)
    print(torch.equal(x_crop[1], x_crop[2])) # False
    
    # plot the results
    plt.plot(x[0,0,0,:])
    plt.show()
    plt.plot(x_crop[0,0,0,:])
    plt.plot(x_crop[2,0,0,:])
    plt.show()
    
    """
    
    if segments < 1:
        raise ValueError('segments must be a positive integer')
    if (N_cut < 0) or (N_cut >= segments):
        raise ValueError('N_cut must be non negative and lower than the number of segments')
    
    Ndim = len(x.shape)
    is_numpy = isinstance(x, np.ndarray)
    
    if not batch_equal and Ndim >= 3:
        x_crop = np.empty_like(x) if is_numpy else torch.empty_like(x, device=x.device)
        for i in range(x.shape[0]):
            x_crop[i] = crop_and_resize(x[i], segments, N_cut, batch_equal)
        return x_crop
    
    n_samples = x.shape[-1]
    segment_len = n_samples // segments
    
    # draw N_cut distinct segments (without replacement) and mark their samples as removed
    if is_numpy:
        seg_to_rem = np.random.permutation(segments)[:N_cut].tolist()
        keep = np.ones(n_samples, dtype=bool)
    else:
        seg_to_rem = torch.randperm(segments, device=x.device)[:N_cut].tolist()
        keep = torch.ones(n_samples, dtype=torch.bool, device=x.device)
    for seg in seg_to_rem:
        keep[segment_len*seg : segment_len*(seg + 1)] = False
    new_x = x[..., keep]
    
    # the cropped signal is spread over the original time span and evaluated on the original grid
    if is_numpy:
        x_crop = interpolate.pchip_interpolate(np.linspace(0, n_samples-1, new_x.shape[-1]), 
                                               new_x, np.linspace(0, n_samples-1, n_samples), axis=-1)
    else:
        x_crop = torch_pchip(torch.linspace(0, n_samples-1, new_x.shape[-1], device=x.device), 
                             new_x, torch.linspace(0, n_samples-1, n_samples, device=x.device))
    
    return x_crop

# Augmentation Classes

In [None]:
class StaticSingleAug():
    """
    StaticSingleAug wraps an augmentation function together with a fixed set of arguments.
    
    Parameters
    ----------
    augmentation: function
        The function to call. The element to augment is passed as first positional argument.
    arguments: list, dict or list of lists/dicts, optional
        The arguments to pass to the augmentation after the element to augment. Can be:
            None: the augmentation is called with the element only
            list: the elements are passed as positional arguments
            dict: the items are passed as keyword arguments
            non empty list of lists or dicts: each element is a set of arguments; every call uses the
                                              next set, starting again from the first after the last
        Default: None
    
    Raises
    ------
    ValueError
        If augmentation is not a function.
    """
    
    def __init__(self, augmentation, arguments: list or dict or list[list or dict]=None):
        
        if not(inspect.isfunction(augmentation) or inspect.isbuiltin(augmentation)):
            raise ValueError('augmentation must be a function to call')
        self.augmentation = augmentation
        
        self.arguments = arguments
        self.counter = 0         # index of the argument set used by the next call
        self.maxcounter = 0      # number of argument sets (0 if a single set is given)
        self.multipleStaticArguments = False
        
        # an empty container holds no argument set: it must not enable the multiple sets mode,
        # otherwise the first call would index an empty sequence
        if (arguments is not None) and not isinstance(arguments, dict) and len(arguments) > 0:
            if all(isinstance(i, (list, dict)) for i in arguments):
                self.multipleStaticArguments = True
                self.maxcounter = len(arguments)
        
    def PerformAugmentation(self, X):
        """
        Apply the augmentation to X with the stored arguments and return the result.
        """
        if self.multipleStaticArguments:
            argument = self.arguments[self.counter]
            # cycle through the argument sets
            self.counter = (self.counter + 1) % self.maxcounter
        else:
            argument = self.arguments
        
        if argument is None:
            return self.augmentation(X)
        if isinstance(argument, (list, tuple)):
            return self.augmentation(X, *argument)
        return self.augmentation(X, **argument)
        
    def __call__(self, X):
        return self.PerformAugmentation(X)
        
        
        

In [None]:
class DynamicSingleAug():
    """
    DynamicSingleAug wraps an augmentation function whose arguments are drawn at random at each call.
    
    Parameters
    ----------
    augmentation: function
        The function to call. The element to augment is passed as first positional argument.
    discrete_arg: dict, optional
        Dictionary with argument names as keys. If a value is a list, one of its elements is chosen at
        random at each call; any other value is passed as it is.
        Default: None
    range_arg: dict, optional
        Dictionary with argument names as keys and two element lists [min, max] as values. At each call
        the argument is drawn with random.uniform(min, max).
        Default: None
    range_type: dict or list, optional
        Whether the values drawn for range_arg must be converted to int. Can be a dictionary with the same
        keys as range_arg or a list with the same length as range_arg (same order). An argument is converted
        with int() if its entry is 'int' or True. If None no value is converted.
        Default: None
    
    Raises
    ------
    ValueError
        If augmentation is not a function, if the keys of discrete_arg or range_arg are not arguments of the
        augmentation, if a value of range_arg is not a two element list, or if range_type does not match
        range_arg.
    """
    
    def __init__(self, 
                 augmentation, 
                 discrete_arg: dict[str, list]=None, 
                 range_arg: dict[str, list]=None,
                 range_type: "dict or list"=None
                ):
        
        # set augmentation function
        if not(inspect.isfunction(augmentation) or inspect.isbuiltin(augmentation)):
            raise ValueError('augmentation must be a function to call')
        self.augmentation = augmentation
        
        # get function argument names; the first one is the element to augment
        spec = inspect.getfullargspec(augmentation)
        self.argnames = spec.args[1:] + spec.kwonlyargs
        
        # check if given discrete_arg keys are actually augmentation arguments
        self.discrete_arg = None
        if discrete_arg is not None:
            if not isinstance(discrete_arg, dict):
                raise ValueError('discrete_arg must be a dictionary')
            if not all(i in self.argnames for i in discrete_arg):
                raise ValueError('keys of discrete_arg argument must be the argument of the augmentation fun')
            self.discrete_arg = discrete_arg
        
        # check if given range_arg keys are actually augmentation arguments 
        # also check if values are two element list
        self.range_arg = None
        if range_arg is not None:
            if not isinstance(range_arg, dict):
                raise ValueError('range_arg must be a dictionary')
            if not all(i in self.argnames for i in range_arg):
                raise ValueError('keys of range_arg argument must be the argument of the augmentation fun')
            if not all((isinstance(i, (list, tuple)) and len(i)==2) for i in range_arg.values()):
                raise ValueError('range_arg values must be a len 2 list with min and max range')
            self.range_arg = range_arg
        
        # check if range_type matches range_arg
        self.range_type = None
        if range_type is not None:
            if self.range_arg is None:
                raise ValueError('range_type cannot be given without range_arg')
            if isinstance(range_type, dict):
                if range_type.keys() != self.range_arg.keys():
                    raise ValueError('keys of range_type must be the same as range_arg')
            elif isinstance(range_type, list):
                if len(range_type) != len(self.range_arg):
                    raise ValueError('range_type must have the same length as range_args')
            else:
                raise ValueError('range_type must be a dictionary or a list')
            self.range_type = range_type
        self.is_range_type_dict = isinstance(self.range_type, dict)
        
        self.given_arg = list(self.discrete_arg) if self.discrete_arg is not None else []
        self.given_arg += list(self.range_arg) if self.range_arg is not None else []
        
    
    def PerformAugmentation(self, X):
        """
        Draw a new set of arguments, apply the augmentation to X and return the result.
        """
        arguments = {i: None for i in self.given_arg}
        
        if self.discrete_arg is not None:
            for name, values in self.discrete_arg.items():
                arguments[name] = random.choice(values) if isinstance(values, list) else values
        
        if self.range_arg is not None:
            for cnt, (name, (low, high)) in enumerate(self.range_arg.items()):
                value = random.uniform(low, high)
                # without range_type the drawn value is used as it is
                if self.range_type is not None:
                    kind = self.range_type[name] if self.is_range_type_dict else self.range_type[cnt]
                    if kind in ['int', True]:
                        value = int(value)
                arguments[name] = value
        
        return self.augmentation(X, **arguments)
        
    def __call__(self, X):
        return self.PerformAugmentation(X)

In [None]:
class SequentialAug():
    """
    SequentialAug chains augmentations: each one is called on the output of the previous one.
    
    Parameters
    ----------
    *augmentations: callable objects
        The augmentations to apply, in the given order. Each one is called with a single argument.
        If no augmentation is given, the input is returned unchanged.
    """
    
    def __init__(self,*augmentations):
        
        self.augs = list(augmentations)
     
    def PerformAugmentation(self, X): 
        """
        Apply all the augmentations to X in sequence and return the final result.
        """
        Xaugs = X
        for aug in self.augs:
            Xaugs = aug(Xaugs)
        return Xaugs
            
    def __call__(self, X):
        return self.PerformAugmentation(X)