/
stim.py
130 lines (114 loc) · 4.31 KB
/
stim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Authors: Daniel Strohmeier <daniel.strohmeier@tu-ilmenau.de>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal.windows import hann
from .._fiff.pick import _picks_to_idx
from ..epochs import BaseEpochs
from ..event import find_events
from ..evoked import Evoked
from ..io import BaseRaw
from ..utils import _check_option, _check_preload, fill_doc
def _get_window(start, end):
    """Build a (1 - Hann) taper spanning ``abs(end - start)`` samples.

    The taper is 1 on the first and last samples, about 0.25 on the second
    and second-to-last samples, and 0 everywhere in between. Multiplying a
    signal segment by it therefore zeroes the interior of the segment while
    ramping smoothly at both edges.

    Parameters
    ----------
    start, end : int
        Sample bounds of the span; only their absolute difference is used.
        The difference must be at least 4 (a smaller value makes the
        ``np.ones`` call receive a negative length and raise).

    Returns
    -------
    window : ndarray, shape (abs(end - start),)
        The taper.
    """
    taper = hann(4)
    # Two samples of ramp on each side; the remainder is a flat plateau.
    n_flat = np.abs(end - start) - 4
    ramps = np.concatenate([taper[:2], np.ones(n_flat), taper[-2:]])
    return 1 - ramps
def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
    """Overwrite ``data[picks, first_samp:last_samp]`` in place.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_times)
        Array modified in place.
    window : ndarray | None
        1-D taper used when ``mode == "window"``; its length must equal
        ``last_samp - first_samp``. Unused in linear mode.
    picks : array-like of int
        Rows (channels) of ``data`` to modify.
    first_samp, last_samp : int
        Start (inclusive) and stop (exclusive) column of the span.
    mode : str
        ``"linear"`` replaces the span with a straight line between the
        samples at ``first_samp`` and ``last_samp``; ``"window"`` multiplies
        the span by ``window``. Any other value leaves ``data`` untouched.
    """
    span = slice(first_samp, last_samp)
    if mode == "linear":
        # The right anchor (last_samp) lies just past the span: it is read
        # as an interpolation endpoint but never overwritten.
        anchors = [first_samp, last_samp]
        endpoints = data[:, anchors][picks]
        line = interp1d(np.array(anchors), endpoints)
        data[picks, span] = line(np.arange(first_samp, last_samp))
    elif mode == "window":
        data[picks, span] = data[picks, span] * window[np.newaxis, :]
@fill_doc
def fix_stim_artifact(
    inst,
    events=None,
    event_id=None,
    tmin=0.0,
    tmax=0.01,
    mode="linear",
    stim_channel=None,
    picks=None,
):
    """Eliminate stimulation's artifacts from instance.

    .. note:: This function operates in-place, consider passing
              ``inst.copy()`` if this is not desired.

    Parameters
    ----------
    inst : instance of Raw or Epochs or Evoked
        The data. Raw and Epochs data must be preloaded.
    events : array, shape (n_events, 3) | None
        The events marking the stimulation onsets (sample in the first
        column, event id in the third). Only used when ``inst`` is Raw; if
        None, events are read from ``stim_channel`` with ``find_events``.
    event_id : int | None
        The id of the events generating the stimulation artifacts.
        If None, all events are used. Only used when ``inst`` is Raw.
    tmin : float
        Start time of the interpolation window in seconds, relative to each
        event (Raw) or to time zero (Epochs and Evoked).
    tmax : float
        End time of the interpolation window in seconds, with the same
        reference as ``tmin``. The sample at ``tmax`` itself is not
        modified; in linear mode it is the right interpolation anchor.
    mode : 'linear' | 'window'
        Way to fill the artifacted time interval.
        'linear' does linear interpolation between the samples at ``tmin``
        and ``tmax``.
        'window' applies a (1 - hanning) window; the interval must then span
        at least 4 samples.
    stim_channel : str | None
        Stim channel to read events from when ``inst`` is Raw and
        ``events`` is None.
    %(picks_all_data)s

    Returns
    -------
    inst : instance of Raw or Evoked or Epochs
        Instance with modified data.

    Raises
    ------
    TypeError
        If ``inst`` is not a Raw, Epochs or Evoked instance.
    ValueError
        If ``mode='window'`` and the interval is shorter than 4 samples, or
        if ``inst`` is Raw and no events are available.
    RuntimeError
        If ``inst`` is Epochs whose ``reject`` attribute is not None.
    """
    _check_option("mode", mode, ["linear", "window"])
    # Validate the type before anything reads ``inst.info``; otherwise an
    # unsupported object fails with an AttributeError and this informative
    # TypeError is never reached.
    if not isinstance(inst, (BaseRaw, BaseEpochs, Evoked)):
        raise TypeError("Not a Raw or Epochs or Evoked (got %s)." % type(inst))

    sfreq = inst.info["sfreq"]
    s_start = int(np.ceil(sfreq * tmin))
    s_end = int(np.ceil(sfreq * tmax))
    if mode == "window" and (s_end - s_start) < 4:
        raise ValueError(
            'Time range is too short. Use a larger interval or set mode to "linear".'
        )
    window = _get_window(s_start, s_end) if mode == "window" else None

    picks = _picks_to_idx(inst.info, picks, "data", exclude=())
    _check_preload(inst, "fix_stim_artifact")

    if isinstance(inst, BaseRaw):
        if events is None:
            events = find_events(inst, stim_channel=stim_channel)
        # Accept any array-like (e.g. a list of [sample, prev, id] rows).
        events = np.asarray(events)
        if len(events) == 0:
            raise ValueError("No events are found")
        if event_id is None:
            event_start = events[:, 0]
        else:
            event_start = events[events[:, 2] == event_id, 0]
        data = inst._data
        for event_idx in event_start:
            # Event samples are absolute; shift them into ``_data`` indices.
            onset = int(event_idx) - inst.first_samp
            _fix_artifact(data, window, picks, onset + s_start, onset + s_end, mode)
    elif isinstance(inst, BaseEpochs):
        if inst.reject is not None:
            raise RuntimeError(
                "Reject is already applied. Use reject=None in the constructor."
            )
        # ``inst.tmin`` lies on the sample grid, so round rather than ceil:
        # ceil would turn a float error like -199.9999999 into -199 and
        # shift the corrected span by one sample.
        e_start = int(np.round(sfreq * inst.tmin))
        first_samp = s_start - e_start
        last_samp = s_end - e_start
        for epoch in inst._data:
            # Each ``epoch`` is a view, so edits land in ``inst._data``.
            _fix_artifact(epoch, window, picks, first_samp, last_samp, mode)
    else:  # Evoked (guaranteed by the type check above)
        first_samp = s_start - inst.first
        last_samp = s_end - inst.first
        _fix_artifact(inst.data, window, picks, first_samp, last_samp, mode)
    return inst