Skip to content

Commit

Permalink
Separate emg_data by dig-in to allow uneven taste deliveries
Browse files Browse the repository at this point in the history
  • Loading branch information
abuzarmahmood committed Feb 6, 2024
1 parent 07b5580 commit e9a5857
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 81 deletions.
72 changes: 35 additions & 37 deletions blech_make_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ def get_dig_in_data(hf5):
return dig_in_pathname, dig_in_basename, dig_in_data

def create_spike_trains_for_digin(
taste_starts_cutoff,
this_dig_in,
this_starts,
durations,
sampling_rate_ms,
units,
hf5,
):
spike_train = []
for this_start in this_dig_in:
for this_start in this_starts:
spikes = np.zeros((len(units), durations[0] + durations[1]))
for k in range(len(units)):
# Get the spike times around the end of taste delivery
Expand Down Expand Up @@ -61,27 +60,32 @@ def create_spike_trains_for_digin(
hf5.flush()

def create_emg_trials_for_digin(
taste_starts_cutoff,
this_dig_in,
this_starts,
durations,
sampling_rate_ms,
emg_nodes,
hf5,
):
emg_data = []
for this_start in this_dig_in:
for this_emg in emg_nodes:
emg_data.append(this_emg[this_start - durations[0]*sampling_rate_ms:\
this_start + durations[1]*sampling_rate_ms])
emg_data = np.stack(emg_data, axis=0)*0.195
emg_data = [[this_emg[this_start - durations[0]*sampling_rate_ms:\
this_start + durations[1]*sampling_rate_ms] \
for this_start in this_starts]
for this_emg in emg_nodes]
emg_data = np.stack(emg_data)*0.195

emg_data = np.mean(
emg_data.reshape((len(emg_data),-1, int(sampling_rate_ms))),
emg_data.reshape((*emg_data.shape[:2],-1, int(sampling_rate_ms))),
axis = -1)

# Write out ind:name map for each node
ind_name_map = {i:node._v_name for i,node in enumerate(emg_nodes)}
str_dict = str(ind_name_map)
if '/emg_data/ind_electrode_map' in hf5:
hf5.remove_node('/emg_data','ind_electrode_map')
hf5.create_array('/emg_data', 'ind_electrode_map', np.array(str_dict))

# And add emg_data to the hdf5 file
hf5.create_group('/emg_data', dig_in_basename[i])
spike_array = hf5.create_array(
hf5.create_array(
f'/emg_data/{dig_in_basename[i]}',
'emg_array', np.array(emg_data))
hf5.flush()
Expand All @@ -95,8 +99,6 @@ def create_emg_trials_for_digin(
# Ask for the directory where the hdf5 file sits, and change to that directory
# Get name of directory with the data files
metadata_handler = imp_metadata(sys.argv)
dir_name = '/media/bigdata/NM43_2500ms_160515_104159_copy'
metadata_handler = imp_metadata([[],dir_name])
os.chdir(metadata_handler.dir_name)
print(f'Processing: {metadata_handler.dir_name}')

Expand Down Expand Up @@ -232,7 +234,7 @@ def create_emg_trials_for_digin(
hf5.close()
trial_info_frame.to_hdf(metadata_handler.hdf5_name, 'trial_info_frame', mode='a')
hf5 = tables.open_file(metadata_handler.hdf5_name, 'r+')
csv_path = os.path.join(dir_name, 'trial_info_frame.csv')
csv_path = os.path.join(metadata_handler.dir_name, 'trial_info_frame.csv')
trial_info_frame.to_csv(csv_path, index=False)

# Get list of units under the sorted_units group.
Expand All @@ -244,26 +246,24 @@ def create_emg_trials_for_digin(
#============================================================#
# NOTE: Calculate headstage falling off same way for all not "none" channels
# Pull out raw_electrode and raw_emg data
if '/raw' in hf5:
raw_electrodes = [x for x in hf5.get_node('/','raw')]
else:
raw_electrodes = []

# If sorting hasn't been done, use only emg channels
# to calculate cutoff...don't need to go through all channels

if '/raw_emg' in hf5:
raw_emg_electrodes = [x for x in hf5.get_node('/','raw_emg')]
else:
raw_emg_electrodes = []

all_electrodes = [raw_electrodes, raw_emg_electrodes]
all_electrodes = [x for y in all_electrodes for x in y]
# If raw channel data is present, use that to calcualte cutoff
# This would explicitly be the case if only EMG was recorded
if len(all_electrodes) > 0:
all_electrode_names = [x._v_pathname for x in all_electrodes]
electrode_names = list(zip(*[x.split('/')[1:] for x in all_electrode_names]))
if len(raw_emg_electrodes) > 0:
emg_electrode_names = [x._v_pathname for x in raw_emg_electrodes]
electrode_names = list(zip(*[x.split('/')[1:] for x in emg_electrode_names]))

print('Calculating cutoff times')
print('Calculating cutoff times using following EMG electrodes...')
print(emg_electrode_names)
print('===============================================')
cutoff_data = []
for this_el in tqdm(all_electrodes):
for this_el in tqdm(raw_emg_electrodes):
raw_el = this_el[:]
# High bandpass filter the raw electrode recordings
filt_el = get_filtered_electrode(
Expand Down Expand Up @@ -302,8 +302,8 @@ def create_emg_trials_for_digin(
'recording_cutoff'
],
)
elec_cutoff_frame['electrode_type'] = all_electrode_names[0]
elec_cutoff_frame['electrode_name'] = all_electrode_names[1]
elec_cutoff_frame['electrode_type'] = electrode_names[0]
elec_cutoff_frame['electrode_name'] = electrode_names[1]

# Write out to HDF5
hf5.close()
Expand Down Expand Up @@ -357,11 +357,10 @@ def create_emg_trials_for_digin(
hf5.create_group('/', 'spike_trains')

# Pull out spike trains
for i, this_dig_in in zip(taste_digin_inds, taste_starts_cutoff):
for i, this_starts in zip(taste_digin_inds, taste_starts_cutoff):
print(f'Creating spike-trains for {dig_in_basename[i]}')
create_spike_trains_for_digin(
taste_starts_cutoff,
this_dig_in,
this_starts,
durations,
sampling_rate_ms,
units,
Expand All @@ -387,11 +386,10 @@ def create_emg_trials_for_digin(
hf5.create_group('/', 'emg_data')

# Pull out emg trials
for i, this_dig_in in zip(taste_digin_inds, taste_starts_cutoff):
for i, this_starts in zip(taste_digin_inds, taste_starts_cutoff):
print(f'Creating emg-trials for {dig_in_basename[i]}')
create_emg_trials_for_digin(
taste_starts_cutoff,
this_dig_in,
this_starts,
durations,
sampling_rate_ms,
emg_nodes,
Expand Down
145 changes: 101 additions & 44 deletions emg/emg_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,33 @@
import shutil
import glob
import pandas as pd
import tables
import ast

sys.path.append('..')
from utils.blech_utils import imp_metadata

# Get name of directory with the data files
metadata_handler = imp_metadata(sys.argv)
# metadata_handler = imp_metadata(sys.argv)
dir_name = '/media/bigdata/NM43_2500ms_160515_104159_copy'
metadata_handler = imp_metadata([[],dir_name])
dir_name = metadata_handler.dir_name
os.chdir(dir_name)
print(f'Processing : {dir_name}')

############################################################
# Load the data
# shape : channels x tastes x trials x time
emg_data = np.load('emg_output/emg_data.npy')
# emg_data = np.load('emg_output/emg_data.npy')
with tables.open_file(metadata_handler.hdf5_name, 'r') as hf5:
emg_digins = hf5.list_nodes('/emg_data')
emg_digins = [x for x in emg_digins if 'dig_in' in x._v_name]
emg_digin_names = [x._v_name for x in emg_digins]
emg_data = [x.emg_array[:] for x in emg_digins]
map_array = hf5.get_node('/emg_data/ind_electrode_map').read()
ind_electrode_map = ast.literal_eval(str(map_array)[2:-1])
# key = electrode_ind, value = index in array
inverse_map = {int(v.split('emg')[1]): k for k, v in ind_electrode_map.items()}

info_dict = metadata_handler.info_dict
params_dict = metadata_handler.params_dict
Expand Down Expand Up @@ -49,16 +63,13 @@
wanted_rows = wanted_rows.sort_values('electrode_ind')
wanted_rows.reset_index(inplace=True, drop=True)

print('Using electrodes :')
print(wanted_rows)
print()

# TODO: Ask about differencing pairs
# Difference by CAR emg labels
# If only 1 channel per emg CAR, do not difference if asked
emg_car_groups = [x[1] for x in wanted_rows.groupby('CAR_group')]
emg_car_names = [x.CAR_group.unique()[0] for x in emg_car_groups]
emg_car_inds = [x.index.values for x in emg_car_groups]
# emg_car_inds = [x.index.values for x in emg_car_groups]
emg_car_inds = [[inverse_map[x] for x in y.electrode_ind.values] for y in emg_car_groups]

print('EMG CAR Groups with more than 1 channel will be differenced')
print('EMG CAR groups as follows:')
Expand All @@ -67,45 +78,76 @@
print()

# TODO: This question can go into an EMG params file
# todo: Rename to diff_data at this stage
# Bandpass filter the emg signals, and store them in a numpy array.
# Low pass filter the bandpassed signals, and store them in another array
# Take difference between pairs of channels
# Shape : Channels x Tastes x Trials x Time
emg_data_grouped = [emg_data[x] for x in emg_car_inds]
emg_diff_data = []
# emg_data = List of arrays (per dig-in) of shape : channels x trials x time
# emg_data_grouped = list of lists
# outer list : emg CAR groups
# inner list : dig-ins
# element_array : channels x trials x time
emg_data_grouped = [[dat[x] for dat in emg_data] for x in emg_car_inds]
# Make sure all element arrays are 3D with shape: channels x trials x time
for x in emg_data_grouped:
if len(x) > 1:
emg_diff_data.append(np.squeeze(np.diff(x,axis=0)))
elif len(x) > 2:
raise Exception("More than 2 per EMG CAR currently not supported")
else:
emg_diff_data.append(np.squeeze(x))
for y in x:
if len(y.shape) < 3:
y = np.expand_dims(y, axis=0)

emg_diff_data = []
for this_car in emg_data_grouped:
this_car_diff = []
for this_dig in this_car:
if len(this_dig) > 1:
this_car_diff.append(np.squeeze(np.diff(this_dig,axis=0)))
elif len(this_dig) > 2:
raise Exception("More than 2 per EMG CAR currently not supported")
else:
this_car_diff.append(np.squeeze(this_dig))
emg_diff_data.append(this_car_diff)

# Iterate over trials and apply frequency filter
iters = list(np.ndindex(emg_diff_data[0].shape[:-1]))
emg_filt_list = []
emg_env_list = []
for x in emg_diff_data:
emg_filt = np.zeros(x.shape)
emg_env = np.zeros(x.shape)
for this_iter in iters:
temp_filt = filtfilt(m, n, x[this_iter[0], this_iter[1]])
emg_filt[this_iter[0], this_iter[1]] = temp_filt
emg_env[this_iter[0], this_iter[1]] = filtfilt(c, d, np.abs(temp_filt))
emg_filt_list.append(emg_filt)
emg_env_list.append(emg_env)
for car_group in emg_diff_data:
this_car_filt = []
this_car_env = []
for dig_in in car_group:
emg_filt = np.zeros(dig_in.shape)
emg_env = np.zeros(dig_in.shape)
temp_filt = filtfilt(m, n, dig_in)
temp_env = filtfilt(c, d, np.abs(temp_filt))
this_car_filt.append(temp_filt)
this_car_env.append(temp_env)
emg_filt_list.append(this_car_filt)
emg_env_list.append(this_car_env)

# Iterate and check for signficant changes in activity
n_cars = len(emg_diff_data)
n_dig = len(emg_diff_data[0])
trial_lens = [[len(x) for x in y] for y in emg_diff_data]

ind_frame = pd.DataFrame(
dict(
car_group = [x for x in range(n_cars) for y in range(n_dig)],
dig_in = [y for x in range(n_cars) for y in range(n_dig)],
trial_len = [np.arange(trial_lens[x][y]) for x in range(n_cars) for y in range(n_dig)]
)
)
# Explode the trial_len column
ind_frame = ind_frame.explode('trial_len')

sig_trials_list = []
for i in range(len(emg_diff_data)):
for i, this_row in ind_frame.iterrows():
this_ind = this_row.values
this_dat = emg_filt_list[this_ind[0]][this_ind[1]][this_ind[2]]
## Get mean and std of baseline emg activity,
## and use it to select trials that have significant post stimulus activity
# sig_trials (assumed shape) : tastes x trials
pre_m = np.mean(np.abs(emg_filt_list[i][...,:pre_stim]), axis = (2))
pre_s = np.std(np.abs(emg_filt_list[i][...,:pre_stim]), axis = (2))
pre_m = np.mean(np.abs(this_dat[pre_stim]))
pre_s = np.std(np.abs(this_dat[pre_stim]))

post_m = np.mean(np.abs(emg_filt_list[i][...,pre_stim:]), axis = (2))
post_max = np.max(np.abs(emg_filt_list[i][...,pre_stim:]), axis = (2))
post_m = np.mean(np.abs(this_dat[pre_stim:]))
post_max = np.max(np.abs(this_dat[pre_stim:]))

# If any of the channels passes the criteria, select that trial as significant
# 1) mean post-stim activity > mean pre-stim activity
Expand All @@ -119,19 +161,34 @@
sig_trials = mean_bool * std_bool
sig_trials_list.append(sig_trials)

ind_frame['sig_trials'] = sig_trials_list

# NOTE: Currently DIFFERENT sig_trials for each channel
# Save the highpass filtered signal,
# the envelope and the indicator of significant trials as a np array
# Iterate over channels and save them in different directories
for num,this_name in enumerate(emg_car_names):
#dir_path = f'emg_output/emg_channel{num}'
dir_path = f'emg_output/{this_name}'
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
os.makedirs(dir_path)
# emg_filt (output shape): tastes x trials x time
np.save(os.path.join(dir_path, f'emg_filt.npy'), emg_filt_list[num])
# env (output shape): tastes x trials x time
np.save(os.path.join(dir_path, f'emg_env.npy'), emg_env_list[num])
# sig_trials (output shape): tastes x trials
np.save(os.path.join(dir_path, 'sig_trials.npy'), sig_trials_list[num])
ind_frame.to_hdf(metadata_handler.hdf5_name, '/emd_data/emg_sig_trials')

with tables.open_file(metadata_handler.hdf5_name, 'r+') as hf5:
for digin_ind, digin_name in enumerate(emg_digin_names):
digin_path = f'/emg_data/{digin_name}'
if f'{digin_path}/processed_emg' in hf5:
hf5.remove_node(f'{digin_path}/processed_emg')
hf5.create_group(f'{digin_path}', 'processed_emg')
for car_ind , this_car_name in enumerate(emg_car_names):
# emg_filt (output shape): tastes x trials x time
# np.save(os.path.join(dir_path, f'emg_filt.npy'), emg_filt_list[num])
hf5.create_array(
f'{digin_path}/processed_emg',
f'{this_car_name}_emg_filt',
emg_filt_list[car_ind][digin_ind]
)
# env (output shape): tastes x trials x time
# np.save(os.path.join(dir_path, f'emg_env.npy'), emg_env_list[num])
hf5.create_array(
f'{digin_path}/processed_emg',
f'{this_car_name}_emg_env',
emg_env_list[car_ind][digin_ind]
)
# sig_trials (output shape): tastes x trials
# np.save(os.path.join(dir_path, 'sig_trials.npy'), sig_trials_list[num])

0 comments on commit e9a5857

Please sign in to comment.