In [3]:
# Importing necessary library
import numpy as np
import matplotlib.pyplot as plt
import pywt
import pandas as pd
import os
from scipy import signal
from stingray import lightcurve
import sys
from stingray import Bispectrum
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline



### Menghitung Matriks Cumulant orde ke-3

In [8]:
def calcCumulantOrde3(df_data, t, lag):
    """Build the stingray bispectrum (third-order cumulant) of one EEG channel.

    Parameters
    ----------
    df_data : array-like or pandas.Series
        Samples of a single channel; its transpose is used as the counts of a
        stingray ``Lightcurve``.
    t : array-like
        Time stamps matching ``df_data`` (this notebook passes
        ``np.arange(0, 1, 1/fs)``).
    lag : int
        Forwarded to ``Bispectrum`` as ``maxlag``.

    Returns
    -------
    stingray.Bispectrum
        The bispectrum object. Downstream code reads its ``cum3`` attribute.
    """
    light_curve = lightcurve.Lightcurve(t, df_data.T)
    return Bispectrum(light_curve, maxlag=lag)

### Melakukan dekomposisi wavelet

In [7]:
def calcWaveletDec(bs, wavelet='db4', level=5):
    """Multilevel discrete wavelet decomposition of a bispectrum's ``cum3`` matrix.

    Parameters
    ----------
    bs : object
        Anything exposing a ``cum3`` array, e.g. the ``Bispectrum`` returned by
        ``calcCumulantOrde3``.
    wavelet : str or pywt.Wavelet, default 'db4'
        Mother wavelet passed to ``pywt.wavedec``.
    level : int, default 5
        Number of decomposition levels. If it exceeds the maximum useful level
        for the data length, pywt emits a warning. Note that this notebook
        silences warnings globally.

    Returns
    -------
    list of numpy.ndarray
        ``pywt.wavedec`` output, ordered ``[cA_level, cD_level, ..., cD1]``
        (``level + 1`` arrays). For a 2-D ``cum3`` the transform runs along the
        last axis, which is pywt's default ``axis=-1``.
    """
    return pywt.wavedec(bs.cum3, wavelet, level=level)

### Menghitung energi relatif

In [6]:
def calcRelativeEnergy(coeffs, df_data):
    """Energy of each wavelet sub-band, absolute and relative to the raw signal energy.

    Parameters
    ----------
    coeffs : sequence of array-like
        ``pywt.wavedec`` output, ordered ``[cA_n, cD_n, ..., cD1]``.
    df_data : array-like or pandas.Series
        Raw samples of the channel. Their sum of squares is the normaliser.

    Returns
    -------
    energies : list
        Sum of squared coefficients per band, reordered to
        ``[A_n, D1, D2, ..., D_n]`` (for n = 5: A5, D1, D2, D3, D4, D5).
        Works for any number of bands.
    relative_energies : list
        ``energies`` divided by the raw-signal energy. If the raw signal has
        zero energy, every entry is 0.0 instead of NaN/inf.
    """
    band_energies = [np.sum(np.square(c)) for c in coeffs]

    # wavedec yields [A_n, D_n, ..., D1]. Keep the approximation first and
    # reverse the details so they run D1..D_n, whatever the number of levels.
    energies = band_energies[:1] + band_energies[:0:-1]

    # NOTE(review): the numerators are energies of wavelet coefficients of the
    # third-order cumulant, while the denominator is the energy of the raw
    # signal. The ratios therefore do not sum to 1 and presumably depend on
    # signal amplitude. The textbook relative wavelet energy divides by
    # sum(energies) instead. Confirm which normalisation the feature needs.
    total_energy = np.sum(np.square(df_data.T))
    if total_energy == 0:
        # Flat (all-zero) channel: avoid 0/0 -> NaN in the saved features.
        return energies, [0.0 for _ in energies]

    relative_energies = [e / total_energy for e in energies]
    return energies, relative_energies

### Persiapan data

In [5]:
# Sampling frequency in Hz.
fs = 256
# Time axis for one trial: fs instants spaced 1/fs apart, covering [0, 1) s.
t = np.arange(fs) / fs

def get_csv_EEG(filename, n_samples=256, first_sample_col=3,
                drop_channels=('X', 'Y', 'nd')):
    """Load one EEG trial CSV into a samples-by-channel DataFrame.

    The expected layout (as read below) is:

    - one header row;
    - then one row per channel, with the channel name in column 1;
    - samples in columns ``first_sample_col`` .. ``first_sample_col + n_samples - 1``.

    Parameters
    ----------
    filename : str or path-like
        Path of the trial CSV.
    n_samples : int, default 256
        Number of sample columns per channel row.
    first_sample_col : int, default 3
        Column index of the first sample.
    drop_channels : iterable of str, default ('X', 'Y', 'nd')
        Channel rows removed from the result. Every name must be present,
        otherwise ``DataFrame.drop`` raises ``KeyError``.

    Returns
    -------
    df_data : pandas.DataFrame
        Shape ``(n_samples, n_channels)``, with one column per kept channel.
    channels : pandas.Index
        ``df_data.columns``.
    """
    sample_cols = range(first_sample_col, first_sample_col + n_samples)
    # ndmin keeps the expected 2-D / 1-D shapes even for a single channel row.
    data = np.loadtxt(filename, delimiter=",", skiprows=1, usecols=sample_cols, ndmin=2)
    channel_name = np.loadtxt(filename, delimiter=",", skiprows=1, usecols=1,
                              dtype='str', encoding='utf-8', ndmin=1)

    # Transpose so rows are time samples and columns are channels.
    df_data = pd.DataFrame(data.T, columns=channel_name)
    df_data = df_data.drop(columns=list(drop_channels))

    return df_data, df_data.columns


### Perhitungan RWB (Relative Wavelet Bispectrum Energy)

In [18]:
def extract_feature(directory, lag, skip_existing=False):
    """Compute RWB features for every trial CSV under ``directory`` and save them.

    For each sub-folder of ``directory``, every regular file without
    'metadata' in its name is treated as a trial CSV.

    - Processing: each channel goes through ``calcCumulantOrde3`` ->
      ``calcWaveletDec`` -> ``calcRelativeEnergy``. The per-channel relative
      energies are concatenated in channel order.
    - Output location: the values are written, one per line, to
      ``<directory with 'CSV'->'FEATURE'>_<lag>/<foldername>``. That whole
      directory path is lower-cased.
    - Output file name: ``<foldername>_<trial>_feature.csv``.

    Parameters
    ----------
    directory : str
        Root folder containing the trial sub-folders.
    lag : int
        ``maxlag`` for the cumulant. It is also appended to the output folder name.
    skip_existing : bool, default False
        When True, trials whose feature file already exists are not recomputed,
        so a long interrupted run can be resumed. The default always
        recomputes, as before.

    Notes
    -----
    The trial number is the second '_'-separated token of the file stem.
    Names without '_' raise ``IndexError``.
    """
    feature_root = directory.replace('CSV', 'FEATURE') + "_" + str(lag)
    for foldername in os.listdir(directory):
        folder = os.path.join(directory, foldername)
        if not os.path.isdir(folder):
            continue
        des_dir = os.path.join(feature_root, foldername).lower()
        for filename in os.listdir(folder):
            if 'metadata' in filename.lower():
                continue
            rel_path = os.path.join(folder, filename)
            # Nested directories (e.g. .ipynb_checkpoints) cannot be parsed as CSV.
            if not os.path.isfile(rel_path):
                continue
            trial_number = filename.split('.')[0].split('_')[1]
            des_file = foldername + '_' + str(trial_number) + '_feature' + '.csv'
            des_path = os.path.join(des_dir, des_file)
            if skip_existing and os.path.exists(des_path):
                continue
            RWB = _extract_trial_rwb(rel_path, lag)
            os.makedirs(des_dir, exist_ok=True)
            np.savetxt(des_path, RWB, delimiter=", ", fmt='% s')


def _extract_trial_rwb(path, lag):
    """Return the concatenated per-channel relative wavelet energies of one trial CSV."""
    df_data, channel_name = get_csv_EEG(path)
    per_channel = []
    for channel in channel_name:
        samples = df_data[channel]
        coeffs = calcWaveletDec(calcCumulantOrde3(samples, t, lag))
        _, relative_energies = calcRelativeEnergy(coeffs, samples)
        per_channel.append(np.ravel(relative_energies))
    # A single concatenation avoids the copy-per-channel cost of np.append in a loop.
    # An empty trial yields an empty array, so savetxt writes an empty file.
    return np.concatenate(per_channel) if per_channel else np.empty(0)
        


In [19]:
# Extract RWB features for the whole test set, with lag (Bispectrum maxlag) = 256.
# This is slow: the recorded run of this cell ended in KeyboardInterrupt.
extract_feature(directory='../SMNI_CMI_TEST_CSV', lag=256)

KeyboardInterrupt: 