In [4]:
import glob
import os
import pandas
import tqdm


# Directory that is scanned for "*.edf" files and their same-named ".csv" files.
path = '/path/to/edf/files'
# Annotation label whose total duration is accumulated.
target = "gnsz"
# Channel name matched against the '-'-separated parts of each annotation's channel.
channel = "CZ"
# Length of one sample in seconds; only used to estimate the number of samples.
sample_duration = 4


# Fail early when the data directory is missing. An explicit raise is used
# instead of `assert` so the check is not stripped when running with `python -O`.
if not os.path.isdir(path):
    raise NotADirectoryError(f"EDF directory does not exist: {path}")


def _has_interesting_channel(channel_spec: str, channels: list[str]):
    channels_in_spec = channel_spec.split('-')
    for channel in channels_in_spec:
        if channel in channels:
            return True
    return False


def _merge_ranges(ranges: list[tuple[int, int]]):
    for range in ranges:
        assert range[0] < range[1]
        if range[0] >= range[1]:
            raise ValueError()
        
    sorted_ranges = sorted(ranges, key = lambda x: x[0])
    i = 0

    while i < len(sorted_ranges) - 1:
        if sorted_ranges[i][1] >= sorted_ranges[i + 1][0]:
            sorted_ranges[i] = (sorted_ranges[i][0], max(sorted_ranges[i][1], sorted_ranges[i + 1][1]))
            del sorted_ranges[i + 1]
        else:
            i += 1
    
    return sorted_ranges


def _get_ranges_for_labels(csv_file_path: str, labels: list[str], channels: list[str]):
    """Collect the merged time ranges of selected labels from one annotation CSV.

    The file is read with the first five lines skipped and must provide the
    columns 'label', 'channel', 'start_time' and 'stop_time'. Only rows whose
    label is in ``labels`` and whose channel matches one of ``channels`` (see
    ``_has_interesting_channel``) are considered.

    Args:
        csv_file_path: Path of the CSV file to read.
        labels: Labels to collect ranges for.
        channels: Channel names a row has to mention to be considered.

    Returns:
        A dict mapping each label that has matching rows to its merged list of
        (start_time, stop_time) ranges. Labels without matching rows, or whose
        ranges are rejected by ``_merge_ranges``, are absent.
    """
    annotations = pandas.read_csv(csv_file_path, delimiter=",", skiprows=5)

    label_matches = annotations['label'].isin(labels)
    channel_matches = annotations['channel'].apply(
        lambda spec: _has_interesting_channel(spec, channels)
    )
    selected = annotations[label_matches & channel_matches]

    result = {}

    for label in labels:
        rows = selected[selected['label'] == label]
        if rows.empty:
            continue

        ranges = list(zip(rows['start_time'], rows['stop_time']))
        try:
            result[label] = _merge_ranges(ranges)
        except ValueError:
            # Best effort: report the unusable ranges and leave the label out.
            print(f"{csv_file_path}: Could not determine ranges for label {label}: {ranges}")

    return result


def _filenames_of_edf_csv_pairs(edf_dir: str):
    edf_files = glob.glob(os.path.join(edf_dir, "*.edf"))
    filenames = [os.path.splitext(os.path.basename(x))[0] for x in edf_files]
    filenames_with_csv_and_edf = [x for x in filenames if os.path.isfile(os.path.join(edf_dir, x + ".csv")) ]# and os.path.isfile(os.path.join(edf_dir, ".edf"))] # not necessary since the initial list is based on edf files existing.
    return sorted(filenames_with_csv_and_edf)


# Collect every recording that has both an .edf file and a .csv file.
filenames = _filenames_of_edf_csv_pairs(path)
csv_files = [os.path.join(path, name + ".csv") for name in filenames]
# Not used below; kept so the name stays available to code that runs afterwards.
edf_files = [os.path.join(path, name + ".edf") for name in filenames]

# Accumulated duration of `target` on `channel` over all CSV files.
# Deliberately not named `sum` (and the loop below does not bind `range`), so
# the builtins stay usable for everything executed after this code.
total_duration = 0

for csv_file in tqdm.tqdm(csv_files):
    ranges = _get_ranges_for_labels(csv_file, [target], [channel])
    # Only `target` was requested, so it is the only key that can be present.
    # Plain `+=` accumulation keeps the floating-point result independent of
    # how the builtin sum() adds floats.
    for start, stop in ranges.get(target, []):
        total_duration += stop - start


print(f"Duration of {target} on {channel} over {len(csv_files)} files:")
print(f"{total_duration} s")
print(f"{total_duration / 60} min")
print(f"{total_duration / 3600} h")
print(f"Using sample duration of {sample_duration} seconds this gives about {int(total_duration // sample_duration)} samples.")


100%|██████████| 4441/4441 [00:04<00:00, 1074.99it/s]

Duration of gnsz on CZ over 4441 files:
40031.90290000002 s
667.198381666667 min
11.119973027777784 h
Using sample duration of 4 seconds this gives about 10007 samples.



