# In [None]:
import os
from datetime import timedelta

import mne
import numpy as np
import pandas as pd
import pyedflib
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import load_model

def _get_frames(df, frame_size, hop_size):
    """Cut the FrL/FrR/OcR columns of *df* into overlapping windows.

    Windows start every ``hop_size`` samples and span ``frame_size`` samples;
    every window that fits completely inside *df* is returned.

    Returns:
        ``np.ndarray`` of shape ``(n_frames, frame_size, 3)``; ``n_frames`` is
        0 when *df* is shorter than one frame.
    """
    n_features = 3
    channels = ['FrL', 'FrR', 'OcR']
    columns = [df[name].values for name in channels]
    # ``+ 1`` so the last window, starting at len(df) - frame_size, is kept.
    starts = range(0, len(df) - frame_size + 1, hop_size)
    frames = [[col[i:i + frame_size] for col in columns] for i in starts]
    # NOTE(review): ``frames`` has shape (n, 3, frame_size) and ``reshape``
    # does NOT transpose it, so each output frame interleaves samples of
    # different channels instead of holding one column per channel (that
    # would be ``np.stack(..., axis=-1)``). Kept as-is because the pre-trained
    # model presumably saw this same layout during training -- confirm against
    # the training pipeline before changing it.
    return np.asarray(frames).reshape(-1, frame_size, n_features)


def _index_to_time(index, hop_size, fs):
    """Return the start time of frame *index* formatted with ``str(timedelta)``."""
    return str(timedelta(seconds=(index * hop_size) / fs))


def _merge_sequences(markers, min_length=3):
    """Collapse runs of identical, non-``'none'`` markers into start/end events.

    Args:
        markers: Sequence of per-frame labels.
        min_length: Minimum run length, in frames, for a run to be reported.

    Returns:
        List of ``(frame_index, label)`` tuples. Each kept run contributes the
        index of its first frame labelled ``"<marker>1"``, followed by the
        index of its last frame labelled ``"<marker>2"``.
    """
    events = []
    n = len(markers)
    i = 0
    while i < n:
        marker = markers[i]
        if marker == 'none':
            i += 1
            continue
        start = i
        while i < n and markers[i] == marker:
            i += 1
        if i - start >= min_length:
            events.append((start, f"{marker}1"))
            events.append((i - 1, f"{marker}2"))
    return events


def _write_edf_with_annotations(raw, output_path, onsets, durations, descriptions):
    """Write every channel of *raw* plus EDF+ annotations using pyedflib.

    Signals are written in microvolts. Each channel gets its physical range
    from its own data, so the 16-bit EDF samples keep their resolution.
    """
    ch_names = raw.info['ch_names']
    sfreq = raw.info['sfreq']
    # The original passed an int rate; keep ints for integral rates in case
    # the installed pyedflib requires an integer.
    sample_rate = int(sfreq) if float(sfreq).is_integer() else sfreq
    # NOTE(review): assumes MNE returns every channel in volts (as for EEG/ECoG
    # channels read from EDF); non-voltage channels would be mislabelled.
    signals = raw.get_data() * 1e6  # V -> uV

    headers = []
    for name, signal in zip(ch_names, signals):
        phys_min = float(np.floor(signal.min()))
        phys_max = float(np.ceil(signal.max()))
        if phys_max <= phys_min:
            # EDF requires physical_min != physical_max (flat channel).
            phys_max = phys_min + 1.0
        headers.append({
            'label': name[:16],  # EDF signal labels hold at most 16 characters
            'dimension': 'uV',
            # Newer pyedflib reads 'sample_frequency', older versions only
            # 'sample_rate'; setting both to the same value works with either.
            'sample_rate': sample_rate,
            'sample_frequency': sample_rate,
            'physical_min': phys_min,
            'physical_max': phys_max,
            'digital_min': -32768,
            'digital_max': 32767,
            'transducer': '',
            'prefilter': '',
        })

    with pyedflib.EdfWriter(output_path, len(ch_names)) as writer:
        writer.setSignalHeaders(headers)
        # writeSamples splits each whole signal into data records;
        # writePhysicalSamples would store only one record per call.
        writer.writeSamples(list(signals))
        for onset, duration, description in zip(onsets, durations, descriptions):
            writer.writeAnnotation(onset, duration, description)


def process_ecog_data(data_path, model_path='model.keras'):
    """Classify an ECoG recording with a pre-trained model and annotate it.

    The recording is z-scored per channel and cut into 10-s windows with a
    3-s hop. Each window is classified by the Keras model. Runs of at least
    three identical non-``'none'`` predictions become start/end markers,
    which are written, together with all signals, to a new EDF+ file next to
    the input.

    Args:
        data_path: Path to the input EDF file (str or path-like).
        model_path: Keras model to load. Defaults to ``'model.keras'`` in the
            working directory, as before.

    Returns:
        ``(output_df, output_edf_path)``. ``output_df`` has columns ``'time'``
        (frame start as ``str(timedelta)``) and ``'marker'`` (``"<label>1"``
        for the first and ``"<label>2"`` for the last frame of each run); it is
        empty, with those columns, when nothing is detected.
        ``output_edf_path`` is the path of the annotated file that was written.

    Raises:
        ValueError: if the recording is not sampled at 400 Hz or is shorter
            than one 10-s frame.
    """
    fs = 400               # sampling rate the model's frame sizes assume
    frame_size = fs * 10   # 10-s windows
    hop_size = fs * 3      # 3-s hop between window starts
    min_length = 3         # minimum run length (frames) to report an event
    annotation_duration = 1.0  # seconds; adjust to your requirements
    label_map = {0: 'ds', 1: 'is', 2: 'none', 3: 'swd'}

    data = mne.io.read_raw_edf(data_path, preload=True)
    if not np.isclose(data.info['sfreq'], fs):
        # Frame sizes, timestamps and the written header all assume fs.
        raise ValueError(
            f"Expected a recording sampled at {fs} Hz, got {data.info['sfreq']} Hz"
        )

    # Standardise each channel over the whole recording. The model expects
    # exactly three channels, renamed positionally to FrL, FrR, OcR
    # (presumably in that order in the file -- confirm).
    df = data.to_data_frame()
    scaled = StandardScaler().fit_transform(df.drop(['time'], axis=1))
    scaled_df = pd.DataFrame(data=scaled, columns=['FrL', 'FrR', 'OcR'])

    frames = _get_frames(scaled_df, frame_size, hop_size)
    if len(frames) == 0:
        raise ValueError(
            f"Recording is shorter than one {frame_size / fs:g}-s frame"
        )

    model = load_model(model_path)
    y_pred = np.argmax(model.predict(frames), axis=1)
    markers = [label_map.get(int(pred), 'unknown') for pred in y_pred]

    events = _merge_sequences(markers, min_length=min_length)
    output_df = pd.DataFrame(
        [{'time': _index_to_time(idx, hop_size, fs), 'marker': label}
         for idx, label in events],
        columns=['time', 'marker'],
    )

    # Onsets come straight from the frame index instead of re-parsing the
    # formatted time strings.
    onsets = [(idx * hop_size) / fs for idx, _ in events]
    durations = [annotation_duration] * len(events)
    descriptions = [label for _, label in events]

    # splitext avoids clobbering the input when it lacks a lowercase '.edf'
    # suffix, and leaves '.edf' inside directory names alone.
    root, _ = os.path.splitext(os.fspath(data_path))
    output_edf_path = root + '_processed_with_annotations.edf'
    _write_edf_with_annotations(data, output_edf_path, onsets, durations, descriptions)

    return output_df, output_edf_path