In [2]:
import pyedflib as edflib
import numpy as np
import datetime
import os
import pandas as pd


In [None]:
Default_text = 'UTF-8'  # encoding EdfWriter.tb uses to turn text into bytes


class ChannelLabelExists(ValueError):
    """Raised when two entries of ``channel_info`` share the same label."""


class ChannelDoesNotExist(KeyError):
    """Raised when writing a sample to a label that was not configured."""


class EdfWriter(object):
    """Buffered, label-addressed writer for EDF/EDF+ files via pyedflib.

    ``channel_info`` is a sequence of dicts, one per signal, in signal order.
    Required keys: ``label``, ``sample_rate``, ``physical_max``,
    ``physical_min``, ``digital_max``, ``digital_min``, ``dimension``.
    Optional keys: ``transducer``, ``prefilter``.

    Samples are buffered per channel by :meth:`write_sample`. As soon as every
    channel holds at least ``sample_rate`` samples, one record's worth from
    each channel is written, in channel order.

    Instances can be used as context managers; the file is closed on exit.
    """

    def tb(self, x):
        """Return ``x`` encoded with ``self.Text_Encoding`` if it is text.

        Objects without an ``encode`` method (numbers, bytes, dates) are
        returned unchanged.
        """
        if hasattr(x, 'encode'):
            return x.encode(self.Text_Encoding)
        return x

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, ex_tb):
        self.close()

    def __init__(self, file_name, channel_info, file_type=edflib.FILETYPE_EDFPLUS, **kwargs):
        """Open ``file_name`` for writing and configure header and channels.

        Args:
            file_name: path of the file to create (str or os.PathLike).
            channel_info: iterable of per-channel dicts (see class docstring).
            file_type: pyedflib file-type constant; defaults to EDF+.
            **kwargs: optional header fields, see ``_init_constants``.

        Raises:
            ChannelLabelExists: if two channels share a label.
            OSError: if pyedflib returns a negative (error) handle.
            TypeError: if unrecognised keyword arguments are given.
        """
        # Materialise once: the channel list is walked several times below.
        channel_info = list(channel_info)
        self.Text_Encoding = Default_text
        self.path = file_name
        self.file_type = file_type
        self.n_channels = len(channel_info)
        self.channels = {}
        for c in channel_info:
            if c['label'] in self.channels:
                raise ChannelLabelExists(c['label'])
            self.channels[c['label']] = c
        self.sample_buffer = {c['label']: [] for c in channel_info}
        # NOTE(review): the path is passed to pyedflib as encoded bytes;
        # confirm the installed pyedflib version accepts a bytes path here.
        self.handle = edflib.open_file_writeonly(
            self.tb(os.fspath(file_name)), file_type, self.n_channels)
        if self.handle < 0:
            raise OSError('Could not open %r for writing (pyedflib error code %d)'
                          % (file_name, self.handle))
        try:
            self._init_constants(**kwargs)
            self._init_channels(channel_info)
        except BaseException:
            # Do not leak the open file handle when configuration fails.
            self.close()
            raise

    def write_sample(self, channel_label, sample):
        """Buffer one digital ``sample`` for ``channel_label``.

        Once every channel has at least ``sample_rate`` samples buffered, one
        data record is written to the file.

        Raises:
            ChannelDoesNotExist: if ``channel_label`` was not configured.
        """
        if channel_label not in self.channels:
            raise ChannelDoesNotExist(channel_label)
        self.sample_buffer[channel_label].append(sample)
        # Each record needs a full block from every signal, so wait until all
        # channels can contribute; flushing earlier would write partial buffers.
        if self._record_ready():
            self._flush_samples()

    def close(self):
        """Close the underlying file. It is safe to call this more than once.

        Buffered samples that do not yet form a complete record across all
        channels are discarded.
        """
        if self.handle >= 0:
            edflib.close_file(self.handle)
            self.handle = -1

    def _init_constants(self, **kwargs):
        """Write optional header fields given as keyword arguments.

        Recognised keywords:
            technician, recording_additional, patient_name, patient_additional,
            equipment, admincode, gender, duration,
            recording_start_time (needs year/month/day/hour/minute/second
            attributes, e.g. a ``datetime``),
            patient_birthdate (needs year/month/day attributes).

        ``None`` values are skipped. Text values are encoded with ``tb``, the
        same way channel fields are.

        Raises:
            TypeError: for unrecognised keywords. This check happens before
                anything is written to the header.
        """
        def set_start(hdl, dt):
            edflib.set_startdatetime(hdl, dt.year, dt.month, dt.day,
                                     dt.hour, dt.minute, dt.second)

        def set_birthdate(hdl, dt):
            edflib.set_birthdate(hdl, dt.year, dt.month, dt.day)

        setters = (
            ('technician', edflib.set_technician),
            ('recording_additional', edflib.set_recording_additional),
            ('patient_name', edflib.set_patientname),
            ('patient_additional', edflib.set_patient_additional),
            ('equipment', edflib.set_equipment),
            ('admincode', edflib.set_admincode),
            ('gender', edflib.set_gender),
            ('duration', edflib.set_datarecord_duration),
            ('recording_start_time', set_start),
            ('patient_birthdate', set_birthdate),
        )
        known = {name for name, _ in setters}
        unknown = [name for name in kwargs if name not in known]
        if unknown:
            raise TypeError('Unhandled argument(s) given: %r' % unknown)
        for name, setter in setters:
            value = kwargs.get(name)
            if value is not None:
                setter(self.handle, self.tb(value))

    def _init_channels(self, channels):
        """Configure signal ``i`` from ``channels[i]``.

        Text values are encoded with ``tb``. ``transducer`` and ``prefilter``
        are optional and skipped when absent. Any other missing key raises
        ``KeyError``.
        """
        hdl = self.handle

        def call_per_channel(fn, name, optional=False):
            for i, c in enumerate(channels):
                if optional and name not in c:
                    continue
                fn(hdl, i, self.tb(c[name]))

        call_per_channel(edflib.set_samplefrequency, 'sample_rate')
        call_per_channel(edflib.set_physical_maximum, 'physical_max')
        call_per_channel(edflib.set_digital_maximum, 'digital_max')
        call_per_channel(edflib.set_digital_minimum, 'digital_min')
        call_per_channel(edflib.set_physical_minimum, 'physical_min')
        call_per_channel(edflib.set_label, 'label')
        call_per_channel(edflib.set_physical_dimension, 'dimension')
        call_per_channel(edflib.set_transducer, 'transducer', optional=True)
        call_per_channel(edflib.set_prefilter, 'prefilter', optional=True)

    def _record_ready(self):
        """Return True when every channel has at least one record buffered."""
        return all(len(self.sample_buffer[label]) >= int(info['sample_rate'])
                   for label, info in self.channels.items())

    def _flush_samples(self):
        """Write one data record: ``sample_rate`` int32 samples per channel.

        Channels are written in configuration order. Samples beyond one record
        stay buffered for the next record.

        Raises:
            ValueError: if some channel lacks a full record's worth of samples.
        """
        if not self._record_ready():
            raise ValueError('every channel needs a full data record buffered before flushing')
        for label, info in self.channels.items():
            n = int(info['sample_rate'])
            buf = self.sample_buffer[label]
            edflib.write_digital_samples(self.handle, np.array(buf[:n], dtype=np.int32))
            self.sample_buffer[label] = buf[n:]