Skip to content

Commit

Permalink
Small fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Oct 15, 2023
1 parent 517ee58 commit 99e1450
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions pywavesurfer/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,18 @@ def __init__(self, filename, format_string="double"):

self.analog_channel_scales, self.analog_scaling_coefficients, self.n_a_i_channels = self.get_scaling_coefficients()

def close_file(self):
if not self.file.closed():
self.file.close()

def __enter__(self):
""" This and `__exit__` ensure the class can be
used in a `with` statement.
"""
return self

def __exit__(self):
def __exit__(self, exception_type, exception_value, traceback):
self.close_file()

def close_file(self):
self.file.close()

# ----------------------------------------------------------------------------------
# Fill Metadata Dict
# ----------------------------------------------------------------------------------
Expand Down Expand Up @@ -247,7 +246,7 @@ def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True):
the `format_string` argument passed during class construction.
"""
ordered_sweep_names = self.get_ordered_sweep_names()
sweep_name = ordered_field_names[segment_index]
sweep_name = ordered_sweep_names[segment_index]

# Index out the data and scale if required.
if sweep_name[0:5] == "sweep":
Expand Down Expand Up @@ -276,7 +275,8 @@ def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True):

def get_ordered_sweep_names(self):
""" Take the data field names (e.g. sweep_0001, sweep_0002), ensure they
are in the correct order and index according to `segment_index`.
are in the correct order and index according to `segment_index`. Note
this function will treat 'sweep' or 'trial' as the same.
"""
field_names = [name for name in self.file if name[0:5] in ["sweep", "trial"]]
sweep_nums = [int(ele[6:]) for ele in field_names]
Expand All @@ -289,6 +289,7 @@ def load_all_data(self):
A convenience function to load into the `data_file_as_dict`
all data in the file.
"""
return_scaled = False if self.format_string == "raw" else True
idx = 0
for field_name in self.file:

Expand All @@ -299,7 +300,7 @@ def load_all_data(self):
else:
num_samples = self.file[field_name].size

scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples)
scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples, return_scaled=return_scaled)

if field_name[0:5] == "sweep":
self.data_file_as_dict[field_name]["analogScans"] = scaled_analog_data
Expand Down

0 comments on commit 99e1450

Please sign in to comment.