# Installing the Dependencies

In [2]:
# If you are running this notebook on Google Colab, you need to install the following packages
# !pip install pycrostates
# !pip install torch
# !pip install mne-microstates


# Imports

In [3]:
import mne
from pycrostates.cluster import ModKMeans
from pycrostates.preprocessing import extract_gfp_peaks
from mne.io import read_raw_edf
import os
import numpy as np
import pandas as pd

# Initiate Microstates Generation

#### Reading EDF files into an MNE `Raw` object

In [7]:
def read_file_return_raw(filename: str):
    """Load an EDF recording and prepare it for microstate analysis.

    The recording is read with ``preload=True``, restricted to its EEG
    channels, re-referenced to the average reference and assigned the
    ``standard_1020`` montage. All three steps are applied to the same
    object, which is then returned.

    Parameters
    ----------
    filename : str
        Path of the EDF file to read.

    Returns
    -------
    The object produced by ``read_raw_edf`` after the steps above.
    """
    # Building the montage does not depend on the file, so do it up front.
    montage = mne.channels.make_standard_montage('standard_1020')

    recording = read_raw_edf(filename, preload=True)
    recording.pick('eeg')  # drop every non-EEG channel
    recording.set_eeg_reference('average')
    recording.set_montage(montage)

    return recording

## Clustering the GFP peaks into microstates

In [8]:
def generate_microstates(
    gfp_data,
    name: str,
    n_clusters: int = 90,
    random_state: int = 45,
    n_jobs: int = 24,
    output_dir: str = r"D:\Codes\NEWTCC\generate_microstates\normalized_data\sch",
):
    """Fit a ModKMeans clustering on GFP-peak data and save its plot as a PNG.

    Parameters
    ----------
    gfp_data
        Data passed unchanged to ``ModKMeans.fit`` (in this notebook, the
        output of ``extract_gfp_peaks``).
    name : str
        Name used for the output image. Its final extension is stripped and
        replaced by ``.png`` (e.g. ``"subject.csv"`` -> ``"subject.png"``).
    n_clusters : int
        Number of clusters given to ``ModKMeans``.
        NOTE(review): the default of 90 is kept from the original code;
        confirm it is the intended number of microstates.
    random_state : int
        Seed given to ``ModKMeans``.
    n_jobs : int
        Number of parallel jobs given to ``ModKMeans.fit``.
    output_dir : str
        Directory the PNG is written to; created if it does not exist.

    Returns
    -------
    None
    """
    # splitext removes only the last extension, so "s1.v2.csv" keeps "s1.v2"
    # instead of collapsing to "s1" and overwriting another subject's image.
    stem = os.path.splitext(name)[0]

    # Build the path outside any f-string: nesting the same quote character
    # inside an f-string is a SyntaxError before Python 3.12.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{stem}.png")

    clusterer = ModKMeans(n_clusters=n_clusters, random_state=random_state)
    clusterer.fit(gfp_data, n_jobs=n_jobs, verbose="WARNING")

    # NOTE(review): the figure is not closed here; when this is called for
    # many files, consider closing it after saving to limit memory use.
    figure = clusterer.plot()
    figure.savefig(output_path)

## Start the Microstates Generation

## MAIN

In [None]:
def get_gfp(raw):
    """Extract the global field power (GFP) peaks of a recording.

    Thin wrapper around ``extract_gfp_peaks`` that fixes
    ``min_peak_distance=3``.

    Parameters
    ----------
    raw
        Recording passed unchanged to ``extract_gfp_peaks``.

    Returns
    -------
    Whatever ``extract_gfp_peaks`` returns for ``raw``.
    """
    return extract_gfp_peaks(raw, min_peak_distance=3)

def _column_to_samples(column):
    """Return one CSV column as a flat 1-D float array of samples."""
    if column.isna().any():
        raise ValueError(f"Column {column.name!r} contains missing values.")

    if pd.api.types.is_numeric_dtype(column):
        # pandas already parsed one number per cell; no string handling needed.
        return column.to_numpy(dtype=float)

    # Cells hold one or more whitespace-separated numbers. Joining the cells
    # and splitting once yields the same token sequence as splitting each
    # cell and flattening.
    tokens = " ".join(column.astype(str)).split()
    return np.array([float(token) for token in tokens], dtype=float)


def csv_to_raw_array(csv_file_path, sfreq=250, ch_names=None):
    """
    Convert a CSV file with EEG channel data into an mne.io.RawArray object.

    Every CSV column is taken as one channel, in file order. A cell may hold
    a single number or several whitespace-separated numbers; the values of a
    column are concatenated top to bottom. The CSV header is not used for
    naming: channels are named from ``ch_names``.

    The resulting recording is restricted to EEG channels, re-referenced to
    the average reference and assigned the ``standard_1020`` montage.

    Parameters:
    csv_file_path (str): Path to the CSV file.
    sfreq (float): Sampling frequency of the data. Default is 250 Hz.
    ch_names (list of str, optional): Channel names, one per CSV column.
        Defaults to the 16 names used by the original dataset.

    Returns:
    raw (mne.io.RawArray): The raw EEG data as an MNE RawArray object.

    Raises:
    ValueError: If the number of CSV columns differs from the number of
        channel names, if a column has missing values, if a value cannot be
        parsed as a float, or if the columns yield different sample counts.
    """
    if ch_names is None:
        ch_names = ['F7', 'F3', 'F4', 'F8', 'T3', 'C3', 'Cz', 'C4', 'T4', 'T5',
                    'P3', 'Pz', 'P4', 'T6', 'O1', 'O2']
    ch_names = list(ch_names)

    df = pd.read_csv(csv_file_path)

    if len(df.columns) != len(ch_names):
        raise ValueError(
            f"{csv_file_path}: found {len(df.columns)} columns but "
            f"{len(ch_names)} channel names were given."
        )

    channels = [_column_to_samples(column) for _, column in df.items()]

    # Every channel must have the same number of samples to form a 2-D array.
    sample_counts = {len(samples) for samples in channels}
    if len(sample_counts) > 1:
        raise ValueError(
            f"{csv_file_path}: columns have differing sample counts "
            f"{sorted(sample_counts)}."
        )

    # One row per channel: shape (n_channels, n_samples). No transpose is
    # needed because the list was built column by column.
    data = np.array(channels, dtype=float)

    # NOTE(review): values are passed through unscaled. MNE documents EEG data
    # as volts; confirm the unit used in the CSV files.
    info = mne.create_info(
        ch_names=ch_names,
        sfreq=sfreq,
        ch_types=['eeg'] * len(ch_names),
    )

    raw = mne.io.RawArray(data, info)
    raw.pick('eeg')
    raw.set_eeg_reference('average')
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))

    return raw

# Directory holding the CSV recordings to process (one file per recording).
CSV_DIR = r'D:\Codes\NEWTCC\database\csv\schz'

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and therefore any log output) reproducible between runs.
for filename in sorted(os.listdir(CSV_DIR)):
    if not filename.endswith(".csv"):
        continue

    csv_file_path = os.path.join(CSV_DIR, filename)
    raw = csv_to_raw_array(csv_file_path)
    gfp_data = get_gfp(raw)
    # The file name doubles as the base name of the saved microstate image.
    generate_microstates(gfp_data, filename)