In [1]:
import json
import numpy as np
import torch
import torchaudio
from tqdm import tqdm

In [2]:
channel_based_annot_path_dev = "../bipolar_eeg_dataset/train_channel_based_annotation.json"
term_based_annot_path_dev = "../bipolar_eeg_dataset/trian_term_based_annotation.json"

Read `channel-based` annotation to a python dictionary

In [3]:
with open(channel_based_annot_path_dev,"r") as channel_json:
    channel_annotation= json.load(channel_json)

In [11]:

def create_window(eeg_sample,channel_annot,s_freq,w=20):
    context_length = w * s_freq
    sample_length = eeg_sample.shape[-1]

    padding_size = int(context_length * torch.ceil(torch.tensor(sample_length/context_length)).item())

    padded_zero = torch.zeros(eeg_sample.shape[0],padding_size)

    padded_zero[...,0:sample_length] = eeg_sample

    padded_zero = padded_zero.view(-1,padded_zero.shape[0],context_length)


    target = torch.zeros(padded_zero.shape[0],padded_zero.shape[1])

    for idx in range(target.shape[0]):
        channel_labels_tensor = torch.zeros(target.shape[1])
        channel_labels = []

        for i,(channel,labels) in enumerate(channel_annot.items()):
            for label in labels:
                start_time,stop_time,c = label

                sample_start_time = idx * w
                sample_stop_time = (idx + 1) * w
                if sample_start_time >= start_time and sample_stop_time <= stop_time:

                    channel_labels.append(0 if c =="bckg" else 1)
        channel_labels_tensor[0:len(channel_labels)] = torch.tensor(channel_labels,dtype=torch.float32)

        target[idx,...] =channel_labels_tensor
    return padded_zero,target 

In [15]:
def sample_and_load(annotation,new_s_freq = 256,window=20,batch_size=4):
    default_channel_nums = 22
    sample_freq = annotation["s_freq"]  # sampleing freqeuncy
    montage = annotation["montage"]


    resampler = torchaudio.transforms.Resample(sample_freq,new_s_freq)

    with np.load(annotation["npz_filepath"]) as npz_file:
        raw_eeg = npz_file["arr_0"]

    raw_eeg = torch.from_numpy(raw_eeg).to(torch.float32)
    # resample
    raw_eeg_resample = resampler(raw_eeg)

    if montage not in ["01_tcp_ar","02_tcp_le"]:
        zero_eeg = torch.zeros(default_channel_nums,raw_eeg_resample.shape[-1])
        zero_eeg[0:raw_eeg_resample.shape[0],...] = raw_eeg_resample

        raw_eeg_resample = zero_eeg

    x,y =  create_window(raw_eeg_resample,annotation["channel_annot"],new_s_freq,window)

    return x,y

    #return get_batch(x,y,batch_size)
    

Filter reading with only `bckg` annotation for the whole sample duration.

In [17]:
filtered_annotations = []
for annots in tqdm(channel_annotation,total=len(channel_annotation)):
    x,y = sample_and_load(annots)
    if y.sum() > 0:
        filtered_annotations.append(annots)

100%|██████████| 4599/4599 [10:26<00:00,  7.34it/s]


In [18]:
print(f"Initial num. of EEG readings = {len(channel_annotation)}")
print(f"EEG readings after filtering = {len(filtered_annotations)}")


Initial num. of EEG readings = 4599
EEG readings after filtering = 723


Find the total EEG reading duration in the filtered dataset. In `minutes` or `hours`

In [19]:
def total_reading_duration_per_class(channel_annotation,cond=None):
   duration_for_all = []
   for annots in channel_annotation:
      durations = []
      for channel_values in annots["channel_annot"].values():
         ch_values = []
         for v in channel_values:
            if v[-1] == cond:
               ch_values.append(v[1])   

         if len(ch_values) == 0:
            durations.append(0)

         if len(ch_values) == 1:
            durations.append(ch_values[0])

         if len(ch_values) > 2:
            ch_values = sorted(ch_values)
            s = ch_values[0]

            for _ in ch_values:
               s += ch_values[1] - ch_values[0]
               ch_values = ch_values[1:]
               if len(ch_values) < 2:
                  break
            durations.append(s)

      if len(durations) == 0:
         duration_for_all.append(0)
      else:      
         duration_for_all.append(sum(durations)/len(durations))
      
   return sum(duration_for_all)

In [20]:
class_dict = {0: '(null)', 1: 'spsw', 2: 'gped', 3: 'pled', 4: 'eyem', 5: 'artf', 6: 'bckg', 7: 'seiz', 8: 'fnsz', 9: 'gnsz', 10: 'spsz', 11: 'cpsz', 12: 'absz', 13: 'tnsz', 14: 'cnsz', 15: 'tcsz', 16: 'atsz', 17: 'mysz', 18: 'nesz', 19: 'intr', 20: 'slow', 21: 'eyem', 22: 'chew', 23: 'shiv', 24: 'musc', 25: 'elpp', 26: 'elst', 27: 'calb'}

In [23]:
total_hours = 0
for class_name in class_dict.values(): 
# of those how many hours are the background class
  hours = total_reading_duration_per_class(filtered_annotations,class_name) / 3600
  total_hours += hours
  if hours == 0.:
    continue
  print(f"Training EEG duration for `{class_name}` = ~ {hours :.3f} hours")

print(f"Total Training EEG duration = ~ {total_hours:.3f} hours")

Training EEG duration for `bckg` = ~ 139.106 hours
Training EEG duration for `fnsz` = ~ 19.317 hours
Training EEG duration for `gnsz` = ~ 15.241 hours
Training EEG duration for `spsz` = ~ 1.004 hours
Training EEG duration for `cpsz` = ~ 8.434 hours
Training EEG duration for `tnsz` = ~ 0.718 hours
Training EEG duration for `tcsz` = ~ 1.466 hours
Training EEG duration for `mysz` = ~ 0.362 hours
Total Training EEG duration = ~ 185.649 hours


Write filtered annotations to file.

In [80]:
with open("../bipolar_eeg_dataset/dev_filtred_channel_based.json","w") as f:
    json.dump(filtered_annotations,f)