# Example Usage: Removing Stimulation-Pulse Artifacts with Zappy

# Main Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import zappy

# Load Data

In [None]:
def get_stim_data(url="https://www.dropbox.com/sh/yqf36798y1yz9e8/AACcqZdu09Chlvkr1NOFFsiQa?dl=1",
                  out='./stim_24k.zip',
                  fname='./data/EC206_B18_stim_example.npz',
                  data_dir='./data'):
    """Download (if needed), extract, and load the sample stimulation recording.

    The archive at ``url`` is downloaded to ``out`` only when neither the
    extracted file ``fname`` nor the archive itself is already present. The
    archive is extracted into ``data_dir`` and deleted after a successful
    extraction.

    Parameters
    ----------
    url : str
        Download location of the zipped sample data.
    out : str
        Local path for the downloaded zip archive.
    fname : str
        Path of the data file expected after extraction.
    data_dir : str
        Directory the archive is extracted into.

    Returns
    -------
    The object returned by ``zappy.io.sigdict.load_data_dict(fname)``. This
    notebook indexes it with the keys 'stim', 'signal' and 'Fs'.

    Raises
    ------
    zipfile.BadZipFile
        If the cached or downloaded archive is corrupt. The archive is deleted
        first, so the next call downloads it again.
    FileNotFoundError
        If extraction did not produce ``fname``.
    """
    import os
    import zipfile

    if not os.path.exists(fname):
        if not os.path.exists(out):
            # wget is only required when the archive actually has to be fetched.
            import wget
            print('Downloading sample data...')
            wget.download(url, out=out, bar=wget.bar_thermometer)

        print('Extracting sample data...')
        try:
            with zipfile.ZipFile(out, 'r') as zip_obj:
                zip_obj.extractall(data_dir)
        except zipfile.BadZipFile:
            # A partial or corrupt archive would otherwise fail on every
            # later run, because the download step is skipped while it exists.
            os.remove(out)
            raise

        if not os.path.exists(fname):
            # Keep the archive so its contents can be inspected.
            raise FileNotFoundError(
                f"Extracted {out!r} into {data_dir!r} but {fname!r} was not found."
            )

        os.remove(out)

    npz = zappy.io.sigdict.load_data_dict(fname)
    print('Loaded sample data.')

    return npz

# Download (on first run) and load the sample data.
npz = get_stim_data()

# Get data sizes: axis 0 of npz['stim'] indexes time samples; each column
# along axis 1 is plotted as a separate stimulation trace below.
n_time = npz['stim'].shape[0]
n_stim = npz['stim'].shape[1]

# Time axis in seconds, built from integer sample indices. np.arange with a
# floating-point step (1/Fs) has no guaranteed length and can yield
# n_time + 1 points, which would break plotting tstamp against a stim column.
tstamp = np.arange(n_time) / npz['Fs']

# Visualize Stimulation Pulses

## Full Stim Sequence

In [None]:
# One stacked subplot per stimulation column, sharing a fixed current range.
plt.figure(figsize=(24, 12))
bottom_row = n_stim - 1
for i in range(n_stim):
    ax = plt.subplot(n_stim, 1, i + 1)
    ax.plot(tstamp, npz['stim'][:, i])
    ax.set_ylim([-2500, 2500])

    if i == bottom_row:
        # Only the bottom trace carries the time axis label.
        ax.set_xlabel('Time (sec)')
    else:
        # Suppress x ticks on the upper traces to reduce clutter.
        ax.set_xticks([])
        ax.set_xticklabels([])
    ax.set_ylabel('Stim Current (uA)')

## Zoomed-In Plot (45–52 s)

In [None]:
# Same stacked layout as the full sequence, restricted to a short time window.
zoom_window_sec = [45, 52]
plt.figure(figsize=(24, 12))
bottom_row = n_stim - 1
for i in range(n_stim):
    ax = plt.subplot(n_stim, 1, i + 1)
    ax.plot(tstamp, npz['stim'][:, i])
    ax.set_ylim([-2500, 2500])

    if i == bottom_row:
        ax.set_xlabel('Time (sec)')
    else:
        # Upper traces drop their x ticks; the bottom one keeps the time axis.
        ax.set_xticks([])
        ax.set_xticklabels([])
    ax.set_ylabel('Stim Current (uA)')

    # Applied last, after ticks are configured, to restrict the view window.
    ax.set_xlim(zoom_window_sec)

# Run Zappy (ICA Pulse-Artifact Reconstruction)

In [None]:
# npz['stim'] is labelled in uA in this notebook's plots; scaling by 1e-6
# presumably converts it to amperes for zappy (confirm against zappy docs).
stim_scaled = npz['stim'] * 1e-6
# 10 seconds expressed in samples (Fs * 10.0). Its exact meaning for zappy
# (presumably the gap between pulse trains) should be confirmed in zappy's docs.
inter_train_samples = npz['Fs'] * 10.0

signal_cleaned = zappy.elstim.artifact.ica_pulse_reconstruction(
    npz['signal'][:, :],
    stim_scaled,
    inter_train_len=inter_train_samples,
    padding=[50, 50],
    krt_pct=[1, 99],
    plot=True,
)