diff --git a/neo/rawio/neuralynxrawio.py b/neo/rawio/neuralynxrawio.py
index 6f6f12217..8eff4ace6 100644
--- a/neo/rawio/neuralynxrawio.py
+++ b/neo/rawio/neuralynxrawio.py
@@ -37,6 +37,8 @@ class NeuralynxRawIO(BaseRawIO):
     """
     Class for reading dataset recorded by Neuralynx.
 
+    Data recorded with 'Invert' enabled is automatically inverted back in 'read_ncs_files'.
+
     Examples:
         >>> reader = NeuralynxRawIO(dirname='Cheetah_v5.5.1/original_data')
         >>> reader.parse_header()
@@ -216,7 +218,7 @@ def _parse_header(self):
                 ts0 = ts[0]
                 ts1 = ts[-1]
             ts0 = min(ts0, ts[0])
-            ts1 = max(ts0, ts[-1])
+            ts1 = max(ts1, ts[-1])
 
         if self._timestamp_limits is None:
             # case NO ncs but HAVE nev or nse
@@ -411,6 +413,242 @@ def _rescale_event_timestamp(self, event_timestamps, dtype):
         event_times -= self.global_t_start
         return event_times
 
+    @staticmethod
+    def patch_ncs_files(dirname, patchdir=None, interpolate=False, verbose=True, scan_only=False):
+        """
+        Scan a directory containing .ncs files and detect misaligned data, then patch the data
+        and save the fixed .ncs files to another directory.
+
+        Parameters
+        ----------
+        dirname: str
+            full path of the directory containing the ncs files to scan
+        patchdir: str (default: None)
+            when patching data, patched ncs files are written to patchdir
+        interpolate: bool (default: False)
+            not implemented yet; reserved for interpolating the data inside gaps
+            (see the remark in the patching loop)
+        verbose: bool (default: True)
+            if True, print extra info while scanning
+        scan_only: bool (default: False)
+            if True, only scan the data and list faulty files (do not patch);
+            does not require 'patchdir'.
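+
+        Examples
+        --------
+        A minimal usage sketch; the 'patched_data' directory name is only an
+        illustration:
+
+        >>> # list faulty files without writing anything
+        >>> NeuralynxRawIO.patch_ncs_files('Cheetah_v5.5.1/original_data', scan_only=True)
+        >>> # patch and write the fixed ncs files to a separate directory
+        >>> NeuralynxRawIO.patch_ncs_files('Cheetah_v5.5.1/original_data',
+        ...                                patchdir='Cheetah_v5.5.1/patched_data')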
+        """
+
+        time_margin = 1000  # [us] allowed deviation (1 ms) of start/stop timestamps across ncs files
+        max_gap_duration_to_fill = 60  # [s] maximum gap size in data to fill
+
+        # --------------------------- SCAN NCS FOR MISALIGNMENTS ------------------------------------
+        # Scan all files in the directory and list:
+        #  - gaps due to ringbuffer errors
+        #  - lag introduced by dsp filters
+        #  - dsp delay compensation
+
+        # Summarize in 'ncs_info' (dict of lists, one entry per file)
+        ncs_info = OrderedDict(
+            filenames=[],
+            fullnames=[],
+            chan_uids=[],
+            chan_names=[],
+            DspFilterDelay_us=[],
+            DspDelayCompensation=[],
+            t0=[],
+            t1=[],
+            gap_pairs=[],
+            gap_pair_times=[],
+            data_size=[],
+            timestamp=[],  # only filled if not align_and_patch
+            sampling_rate=[],
+        )
+
+        if verbose:
+            print('Scanning for ncs files in:\n ', dirname)
+            print('Detected files:')
+
+        found_missing_dsp_compensation = False
+        for filename in sorted(os.listdir(dirname)):
+
+            fullname = os.path.join(dirname, filename)
+            if (os.path.getsize(fullname) <= HEADER_SIZE) or not filename.endswith('ncs'):
+                continue
+
+            ncs_info['filenames'].append(filename)
+            ncs_info['fullnames'].append(fullname)
+
+            if verbose:
+                print(f'\t{filename}:')
+
+            info = read_txt_header(fullname)
+            chan_names = info['channel_names']
+            chan_ids = info['channel_ids']
+
+            for idx, chan_id in enumerate(chan_ids):
+                chan_name = chan_names[idx]
+                chan_uid = (chan_name, chan_id)
+
+                ncs_info['chan_uids'].append(chan_uid)
+                ncs_info['chan_names'].append(chan_name)
+                ncs_info['sampling_rate'].append(info['sampling_rate'])
+
+                if 'DspFilterDelay_µs' in info.keys():
+                    ncs_info['DspFilterDelay_us'].append(info['DspFilterDelay_µs'])
+                else:
+                    ncs_info['DspFilterDelay_us'].append(None)
+
+                if 'DspDelayCompensation' in info.keys():
+                    ncs_info['DspDelayCompensation'].append(info['DspDelayCompensation'])
+                else:
+                    ncs_info['DspDelayCompensation'].append(None)
+
+            data = read_ncs_memmap(fullname)
+            ncs_info['t0'].append(data['timestamp'][0])
+            ncs_info['t1'].append(data['timestamp'][-1])
+            ncs_info['data_size'].append(data.size)
+
+            # check that good_delta is indeed the inter-block timestamp delta for this channel
+            good_delta = int(BLOCK_SIZE * 1e6 / info['sampling_rate'])
+            assert good_delta in np.unique(np.diff(data['timestamp'])), \
+                'good delta does not match for channel'
+
+            if ncs_info['DspFilterDelay_us'][-1] is not None and int(ncs_info['DspFilterDelay_us'][-1]) > 0 and \
+                    ncs_info['DspDelayCompensation'][-1] != 'Enabled':
+                print(f'WARNING: {filename} has dsp lag ({ncs_info["DspFilterDelay_us"][-1]}) '
+                      f'and no delay compensation!')
+                found_missing_dsp_compensation = True
+
+            if verbose and scan_only:
+                print(f'\t\tchan id: {chan_uid[1]}')
+                print(f'\t\tdsp delay: {ncs_info["DspFilterDelay_us"][-1]}, '
+                      f'dsp compensation: {ncs_info["DspDelayCompensation"][-1]}')
+                print(f'\t\tt start: {ncs_info["t0"][-1]}, t stop: {ncs_info["t1"][-1]}')
+
+        # --------------------------- CHECK IF DATA NEEDS PATCHING ------------------------------------
+        t0_all = np.min(ncs_info['t0'])
+        t1_all = np.max(ncs_info['t1'])
+
+        print('\nDetected errors:')
+
+        needs_patching = False
+        if np.any((ncs_info['t0'] - t0_all) > time_margin) or \
+                np.any((t1_all - ncs_info['t1']) > time_margin):
+            print('\t-misalignment of timestamps')
+            needs_patching = True
+
+        if len(np.unique(ncs_info['data_size'])) > 1:
+            print('\t-records do not have the same number of samples')
+            needs_patching = True
+
+        if scan_only:
+            return
+
+        if not needs_patching and not found_missing_dsp_compensation:
+            print('\nNo errors found, no need to patch this data!')
+            return
+        elif not needs_patching and found_missing_dsp_compensation:
+            print('\nNo broken data, but some channels have no dsp delay compensation')
+            return
+
+        # --------------------------- PATCH DATA ------------------------------------
+        assert patchdir is not None, 'To patch data, provide the patchdir keyword'
+        if not os.path.isdir(patchdir):
+            print(f'Patchdir did not exist, making directory:\n{patchdir}')
+            os.makedirs(patchdir)
+
+        # Check that there is no misalignment due to dsp filters without delay compensation
+        assert len(np.unique(ncs_info['DspFilterDelay_us'])) == 1 or \
+            np.all([ddc == 'Enabled' for ddc in ncs_info['DspDelayCompensation']]), \
+            'some channels have dsp filtering, but lag compensation is disabled'
+
+        time_range = t1_all - t0_all
+
+        # helper: convert a block (record) index to a linear sample index
+        def sub2lin(block_idx):
+            return block_idx * BLOCK_SIZE
+
+        print(f'Saving patched data in: {patchdir}')
+        print('Patching:')
+        for i, chan_uid in enumerate(ncs_info['chan_uids']):
+            print(f'\t-{ncs_info["filenames"][i]}')
+
+            # Infile and outfile
+            ncs_in_name = ncs_info['fullnames'][i]
+            ncs_out_name = os.path.join(patchdir, ncs_info['filenames'][i])
+
+            # Load broken data
+            data = read_ncs_memmap(ncs_info['fullnames'][i])
+
+            good_delta = int(BLOCK_SIZE * 1e6 / ncs_info['sampling_rate'][i])
+            ts_interpolated = np.arange(0, time_range, good_delta, dtype='uint64')
+            chan_id = np.ones(ts_interpolated.shape, dtype='uint32')  # read per ncs file
+            sample_rate = np.ones(ts_interpolated.shape, dtype='uint32')  # read per ncs file
+            nb_valid = np.ones(ts_interpolated.shape, dtype='uint32') * 512  # hardcoded to full blocks, as we patch here
+
+            # Fill constants
+            chan_ids_in_ncs = np.unique(data['channel_id'])
+            assert len(chan_ids_in_ncs) == 1, f'found multiple channel ids in ncs: {chan_ids_in_ncs}'
+            chan_id[:] = chan_ids_in_ncs[0]
+
+            sample_rates_in_ncs = np.unique(data['sample_rate'])
+            assert len(sample_rates_in_ncs) == 1, f'found multiple sample rates in ncs: {sample_rates_in_ncs}'
+            sample_rate[:] = sample_rates_in_ncs[0]
+
+            sample_dt = 1e6 / sample_rates_in_ncs[0]  # [us] time between two sample points
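+            # worked example of the constants above (illustrative, assuming a 32 kHz channel):
+            #   good_delta = int(512 * 1e6 / 32000) = 16000 us between block timestamps
+            #   sample_dt  = 1e6 / 32000 = 31.25 us between samples within a block
+            # so a block starting 47 us after its interpolated bin start is shifted by
+            # n_idx_in_block = int(47 / 31.25) = 1 sample when it is written below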
+
+            # Patch data
+            # Make a flat array of size (nsamples,) to fill the data in
+            # (more convenient due to the shifting of blocks).
+            # The fill value defaults to the signal mean.
+            records = np.full(len(ts_interpolated) * 512, np.mean(data['samples']), dtype='int16')
+
+            # Get times of this ncs file, normalized to the common t0
+            ts_norm = data['timestamp'] - t0_all
+
+            assert len(np.where(np.diff(data['timestamp']) < good_delta)[0]) == 0, \
+                'found blocks that are closer together than one block duration'
+
+            # gap indices are detected by dt > good_delta (thus pointing to rows)
+            gap_indices = np.where(np.diff(data['timestamp']) > good_delta)[0]  # row (block) indices
+            for g in gap_indices:
+                if data['timestamp'][g + 1] - data['timestamp'][g] > max_gap_duration_to_fill * 1e6:
+                    print(f'WARNING - big gap filled {data["timestamp"][g+1] - data["timestamp"][g]} '
+                          f'- {ncs_info["filenames"][i]}')
+
+            # Get the intervals of continuous data (pointing to 'ts_norm', i.e. the original (broken) data)
+            data_starts = np.concatenate([np.zeros(1), gap_indices + 1]).astype('int')
+            data_stops = np.concatenate([gap_indices, [data['timestamp'].size - 1]]).astype('int')
+
+            # Construct bins for the interpolated time axis
+            ts_bins = np.vstack([ts_interpolated, np.roll(ts_interpolated, -1)])
+            ts_bins[1, -1] = ts_bins[0, -1] + good_delta
+
+            for i0, i1 in zip(data_starts, data_stops):
+
+                # Find in which row of the interpolated time axis the timepoint of the gap fits
+                i0_itp = np.where((ts_bins[0, :] <= ts_norm[i0]) & (ts_bins[1, :] > ts_norm[i0]))[0]
+                i1_itp = np.where((ts_bins[0, :] <= ts_norm[i1]) & (ts_bins[1, :] > ts_norm[i1]))[0]
+
+                # each timepoint must fall in exactly one bin
+                assert len(i0_itp) == 1 and len(i1_itp) == 1
+                i0_itp = i0_itp[0]
+                i1_itp = i1_itp[0]
+                assert ts_interpolated[i0_itp] <= ts_norm[i0] < ts_interpolated[i0_itp] + good_delta
+                assert ts_interpolated[i1_itp] <= ts_norm[i1] < ts_interpolated[i1_itp] + good_delta
+
+                # account for misalignment between the interpolated time axis and the original time points
+                n_idx_in_block = int((ts_norm[i0] - ts_interpolated[i0_itp]) / sample_dt)
+                lin_i0_itp = sub2lin(i0_itp) + n_idx_in_block
+                lin_i1_itp = sub2lin(i1_itp) + n_idx_in_block
+
+                # write the data to the correct points!
+                records[lin_i0_itp:lin_i1_itp] = data['samples'].flatten()[sub2lin(i0):sub2lin(i1)]
+
+            # remark: we could mark the gap intervals and interpolate the data in the gaps
+
+            # Prepare the data for writing to ncs format
+            records = np.resize(records, (int(records.size / 512), 512))
+            patched_data = np.core.records.fromarrays(
+                [ts_interpolated + t0_all, chan_id, sample_rate, nb_valid, records],
+                dtype=np.dtype(ncs_dtype)
+            )
+
+            header = read_ncs_header(ncs_in_name)
+            with open(ncs_out_name, 'wb') as f:
+                f.write(header)
+                f.write(patched_data.tobytes())
+
+        print('Done!')
+
     def read_ncs_files(self, ncs_filenames):
         """
         Given a list of ncs files construct:
@@ -516,7 +754,7 @@ def read_ncs_files(self, ncs_filenames):
                 if chan_uid == chan_uid0:
                     ts0 = subdata[0]['timestamp']
                     ts1 = subdata[-1]['timestamp'] \
-                        + np.uint64(BLOCK_SIZE / self._sigs_sampling_rate * 1e6)
+                        + np.uint64(BLOCK_SIZE / self._sigs_sampling_rate * 1e6)
                 self._timestamp_limits.append((ts0, ts1))
                 t_start = ts0 / 1e6
                 self._sigs_t_start.append(t_start)
@@ -672,23 +910,33 @@ def read_txt_header(filename):
+    if old_date_format:
+        datetime1_regex = r'## Time Opened \(m/d/y\): (?P<date>\S+) \(h:m:s\.ms\) (?P<time>\S+)'
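+        # e.g. this regex matches a header line of the following shape
+        # (the values here are made up for illustration):
+        #   ## Time Opened (m/d/y): 12/21/2011 (h:m:s.ms) 14:57:49.195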