### Feature Extraction

#### Modules

In [None]:
import mne
import numpy as np
from scipy import signal
from collections import defaultdict

#### PSD

In [None]:
# PSD

def PSD(signals, fs=500, nperseg=500, overlap=250, n_bins=40):
    """Estimate the Welch power spectral density of every trial.

    Parameters
    ----------
    signals : array-like, shape (trials, samples)
        One signal per row.
    fs : float
        Sampling frequency passed to ``scipy.signal.welch``.
    nperseg : int
        Segment length used by Welch's method.
    overlap : int
        Number of overlapping samples between segments (``noverlap``).
    n_bins : int
        Number of leading frequency bins kept per trial (default 40, the
        value that used to be hard-coded).

    Returns
    -------
    psd_values : ndarray, shape (trials, n_bins)
        The first ``n_bins`` PSD values of each trial (fewer if the
        spectrum is shorter).
    freqs : ndarray
        The full frequency axis returned by ``welch``. NOTE: it is NOT
        truncated to ``n_bins``; use ``freqs[:n_bins]`` to align it with
        ``psd_values``.

    Raises
    ------
    ValueError
        If ``signals`` is not two-dimensional.
    """
    signals = np.asarray(signals)
    if signals.ndim != 2:
        raise ValueError("signals must be a 2-D array of shape (trials, samples)")

    # welch works along an axis, so all trials are processed in one call
    # instead of one Python-level call per row.
    freqs, psd = signal.welch(
        signals,
        fs,
        nperseg=nperseg,
        noverlap=overlap,
        window='hamming',
        scaling='density',
        axis=-1,
    )
    psd_values = np.array(psd[:, :n_bins])
    return psd_values, freqs

#### DDWT

In [None]:
# DDWT

def Wavelet(signals, fs=500, nperseg=500, overlap=250):
    """Build a DWT feature vector for every trial.

    Each trial is decomposed with a 5-level ``bior6.8`` discrete wavelet
    transform (periodization mode) and the level-5 and level-4 detail
    coefficients are concatenated into one feature vector.

    Parameters
    ----------
    signals : ndarray, shape (trials, samples)
        One signal per row.
    fs, nperseg, overlap
        Unused; kept so the signature stays parallel to ``PSD``.

    Returns
    -------
    ndarray, shape (trials, len(cD5) + len(cD4))
        Concatenated ``cD5`` and ``cD4`` coefficients per trial.

    Notes
    -----
    NOTE(review): at fs = 500 Hz the cD5/cD4 bands nominally cover about
    7.8-31.25 Hz, which presumably is why only these two are kept -
    confirm against the intended band of interest.
    """
    # PyWavelets is not part of the notebook's import cell, so import it
    # here (same approach as dictionary_WPT).
    import pywt

    trials, samples = signals.shape
    wavelet_values = []
    for trial in range(trials):
        # wavedec returns [cA5, cD5, cD4, cD3, cD2, cD1] for level=5.
        _, cD5, cD4, *_ = pywt.wavedec(signals[trial], 'bior6.8', mode='periodization', level=5)
        wavelet_values.append(np.concatenate((cD5, cD4)))

    wavelet_values = np.array(wavelet_values)
    # The original return line was fused with a stray filter(...) call,
    # producing the undefined name `wavelet_valuesfilter`.
    return wavelet_values

#### WPT+LDB

In [None]:
class WP_LDB():
    """Wavelet-packet node energies and per-node class-discrimination scores.

    Parameters
    ----------
    signals_array : iterable of 1-D arrays
        One signal per trial.
    classes_array : iterable
        One label per trial; a label equal to 1 is class 1, every other
        label is treated as class 0.
    wavelet : str
        Wavelet name handed to ``pywt.WaveletPacket``.
    levels_decomposition : int
        Maximum wavelet-packet level.
    """

    def __init__(self, signals_array, classes_array, wavelet='bior2.8', levels_decomposition=5):
        self.signals_array = signals_array
        self.classes_array = classes_array
        self.wavelet = wavelet
        self.levels_decomposition = levels_decomposition

    def WP_energy(self):
        """Compute the normalized energy of every wavelet-packet node per trial.

        For each trial the energy (sum of squared coefficients) of every
        node on levels 0..levels_decomposition is divided by the sum of
        those energies over all nodes of that trial.

        Returns
        -------
        Enode_class0, Enode_class1 : dict[str, list[float]]
            Node path -> list of normalized energies, one per trial of
            the class.
        all_node_names : list[str]
            Node paths in first-seen order (level by level; the level-0
            node has the empty path ``''``).
        """
        # PyWavelets is not part of the notebook's import cell.
        import pywt

        def normalized_node_energies(trial_signal):
            """Return {node path: energy / total energy} for one trial."""
            wpt = pywt.WaveletPacket(data=trial_signal, wavelet=self.wavelet,
                                     mode='periodization', maxlevel=self.levels_decomposition)
            raw = {}
            for level in range(self.levels_decomposition + 1):
                for node in wpt.get_level(level):
                    raw[node.path] = np.sum(node.data ** 2)
            total = sum(raw.values())
            if total == 0:
                # All-zero trial: report zero energy instead of 0/0 = NaN.
                return {path: 0.0 for path in raw}
            return {path: energy / total for path, energy in raw.items()}

        WP_class_0 = []
        WP_class_1 = []
        seen_names = {}  # dict used as an insertion-ordered set
        for trial_signal, cls in zip(self.signals_array, self.classes_array):
            energies = normalized_node_energies(trial_signal)
            for path in energies:
                seen_names.setdefault(path, None)
            if cls == 1:
                WP_class_1.append(energies)
            else:
                WP_class_0.append(energies)
        all_node_names = list(seen_names)

        # A node missing from a trial counts as zero energy.
        Enode_class0 = {node: [trial.get(node, 0.0) for trial in WP_class_0] for node in all_node_names}
        Enode_class1 = {node: [trial.get(node, 0.0) for trial in WP_class_1] for node in all_node_names}

        return Enode_class0, Enode_class1, all_node_names

    def WP_discrimination(self, Enode_class0, Enode_class1, nodes_names, dis_method = 'KL'):
        """Score how well each node separates the two classes.

        Parameters
        ----------
        Enode_class0, Enode_class1 : dict[str, sequence of float]
            Per-node energies as returned by ``WP_energy``.
        nodes_names : list[str]
            Node paths to score, in order. NOTE: its first entry is
            renamed to ``'root'`` IN PLACE (kept for backward
            compatibility); a list that was already renamed by an earlier
            call is accepted and resolved back to the ``''`` key.
        dis_method : {'KL', 'FC', 'DE', 'RY'}
            Kullback-Leibler divergence, Fisher criterion, differential
            energy, or Renyi-entropy difference.

        Returns
        -------
        nodes_names : list[str]
            The same list object, with entry 0 set to ``'root'``.
        Discriminant : list[float]
            One score per node, aligned with ``nodes_names``.

        Raises
        ------
        ValueError
            If ``dis_method`` is not one of the supported codes.
        KeyError
            If a node name is missing from the energy dictionaries.
        """

        def Fisher_Criteria(Class_0, Class_1):
            # The larger FD is, the more separated the classes are.
            mu_c0 = np.mean(Class_0)
            mu_c1 = np.mean(Class_1)
            sigma_c0 = np.var(Class_0, ddof=1)
            sigma_c1 = np.var(Class_1, ddof=1)

            denominator = sigma_c0 + sigma_c1
            if denominator == 0:
                return 0.0
            return (mu_c0 - mu_c1)**2 / denominator

        def kullback_Leibler(Class_0, Class_1, bins=20, epsilon=1e-10):
            # Both histograms share the same bin edges so they are comparable.
            min_val = min(np.min(Class_0), np.min(Class_1))
            max_val = max(np.max(Class_0), np.max(Class_1))

            hist_c0, _ = np.histogram(Class_0, bins=bins, range=(min_val, max_val), density=True)
            hist_c1, _ = np.histogram(Class_1, bins=bins, range=(min_val, max_val), density=True)

            # epsilon keeps empty bins from producing log(0) or division by 0.
            hist_c0 += epsilon
            hist_c1 += epsilon

            p = hist_c0 / np.sum(hist_c0)
            q = hist_c1 / np.sum(hist_c1)

            return np.sum(p * np.log(p / q))

        def diferential_energy(Class_0, Class_1):
            energy_A = np.mean(np.square(Class_0))
            energy_B = np.mean(np.square(Class_1))
            return np.abs(energy_A - energy_B)

        def compute_renyi_entropy(coeffs, alpha=2, bins=20, epsilon=1e-10):
            hist, _ = np.histogram(coeffs, bins=bins, density=True)
            hist += epsilon
            hist /= np.sum(hist)
            if alpha == 1:
                # alpha == 1 is handled with the Shannon entropy formula.
                return -np.sum(hist * np.log(hist))
            return (1 / (1 - alpha)) * np.log(np.sum(hist ** alpha))

        def renyi_diff(Class_0, Class_1, alpha=2, bins=20):
            H0 = compute_renyi_entropy(Class_0, alpha, bins)
            H1 = compute_renyi_entropy(Class_1, alpha, bins)
            return np.abs(H0 - H1)

        measures = {
            'KL': kullback_Leibler,
            'FC': Fisher_Criteria,
            'DE': diferential_energy,
            'RY': renyi_diff,
        }
        if dis_method not in measures:
            raise ValueError(
                f"Unknown dis_method {dis_method!r}; expected one of {sorted(measures)}"
            )
        measure = measures[dis_method]

        def energies_for(table, position, node):
            """Look up a node, accepting the 'root' alias for the '' path."""
            if node in table:
                return table[node]
            if position == 0 and node == 'root' and '' in table:
                return table['']
            raise KeyError(node)

        Discriminant = [
            measure(energies_for(Enode_class0, position, node),
                    energies_for(Enode_class1, position, node))
            for position, node in enumerate(nodes_names)
        ]

        if len(nodes_names) > 0:
            nodes_names[0] = 'root'
        return nodes_names, Discriminant

from collections import deque

class Node:
    """Binary-tree node holding a discriminant score for a wavelet-packet node.

    Attributes are plain and public:
    ``score`` (discriminant value), ``name`` (node label; after a merge it
    is the '-'-joined names of the merged children), ``base`` (list of
    names produced by ``prune``; ``None`` until ``prune`` runs), and the
    children ``a`` (first/left) and ``d`` (second/right).
    """

    def __init__(self, score, name):
        self.score = score
        self.name = name
        self.base = None
        self.a = None
        self.d = None

    @staticmethod
    def build_binary_tree(scores, names):
        """Build a tree from level-order (breadth-first) score/name lists.

        Entry 0 is the root; the remaining entries are consumed two at a
        time as the ``a`` and ``d`` children of the nodes in creation
        order. A ``None`` score leaves that child slot empty.

        Returns the root ``Node``, or ``None`` when ``scores`` is empty
        or its first entry is ``None``.
        """
        # len() rather than truthiness so NumPy arrays are accepted too.
        if scores is None or len(scores) == 0 or scores[0] is None:
            return None

        root = Node(scores[0], names[0])
        queue = deque([root])
        index = 1
        total = len(scores)

        while queue and index < total:
            parent = queue.popleft()
            for side in ('a', 'd'):
                if index < total and scores[index] is not None:
                    child = Node(scores[index], names[index])
                    setattr(parent, side, child)
                    queue.append(child)
                # The slot is consumed even when it is empty or past the end.
                index += 1

        return root

    # Intentionally left without a decorator: it works both as
    # Node.print_tree(root) and as root.print_tree().
    def print_tree(root):
        """Print the tree one level per line as ``name(score)`` entries."""
        if not root:
            return

        queue = deque([(root, 0)])
        levels = {}

        while queue:
            node, level = queue.popleft()
            levels.setdefault(level, []).append(f"{node.name}({node.score})")

            if node.a:
                queue.append((node.a, level + 1))
            if node.d:
                queue.append((node.d, level + 1))

        for level in sorted(levels.keys()):
            # Shallower levels get a larger left indent.
            print("    " * (len(levels) - level), "  ".join(levels[level]))

    def prune(self, modified_nodes = None):
        """Collapse the subtree bottom-up, keeping the better of parent vs. children.

        After both children are pruned, their scores are summed. If the
        sum exceeds this node's score and both children exist, the node
        takes the summed score and the name ``'<a.name>-<d.name>'`` and
        the children are removed. If the sum does not exceed this node's
        score, the children are discarded. (When the sum is larger but
        only one child exists, nothing is changed and the child stays.)
        Finally ``self.base`` is set to ``self.name.split('-')``.

        Parameters
        ----------
        modified_nodes : list, optional
            If given, the names of every pair of children merged anywhere
            in the subtree are appended to it (children's merges first).

        Returns
        -------
        Node
            ``self``.
        """
        if modified_nodes is None:
            modified_nodes = []

        # Pass the same list down so merges in descendants are recorded too.
        if self.a is not None:
            self.a = self.a.prune(modified_nodes)
        if self.d is not None:
            self.d = self.d.prune(modified_nodes)

        children_sum = 0
        if self.a is not None:
            children_sum += self.a.score
        if self.d is not None:
            children_sum += self.d.score

        if children_sum > self.score:
            if self.a is not None and self.d is not None:
                self.score = self.a.score + self.d.score
                modified_nodes.append(self.a.name)
                modified_nodes.append(self.d.name)
                self.name = f'{self.a.name}-{self.d.name}'
                self.a = None
                self.d = None
        else:
            self.a = None
            self.d = None

        self.base = self.name.split('-')
        return self

    def discriminant_base(self, discriminant, nodes):
        """Return this node's base names ordered by descending score.

        ``self.base`` must have been set by ``prune``. Each base name is
        looked up in ``nodes`` and paired with the score at the same
        position in ``discriminant``.

        Returns
        -------
        nodes : tuple[str, ...]
            Base names, highest score first.
        scores : tuple
            The matching scores.

        Raises
        ------
        ValueError
            If a base name is not present in ``nodes``.
        """
        scores = [discriminant[nodes.index(name)] for name in self.base]
        combination = sorted(zip(scores, self.base), key=lambda pair: pair[0], reverse=True)
        scores, nodes = zip(*combination)
        return nodes, scores

#### DAS-OMP

In [None]:
### Dictionary
def dictionary_WPT(samples=1024, max_level=5, mother_wavelet='db4'):
    """Build a dictionary of wavelet-packet atoms by reconstructing unit impulses.

    For every node on levels 1..max_level and every coefficient position
    in that node, a fresh wavelet-packet tree is filled with zeros on that
    level, a single coefficient is set to 1, and the tree is reconstructed
    into a time-domain signal that becomes one dictionary row.

    Parameters
    ----------
    samples : int
        Length of the all-zero template signal used to size the tree.
    max_level : int
        Deepest wavelet-packet level to include.
    mother_wavelet : str
        Wavelet name passed to ``pywt.WaveletPacket``.

    Returns
    -------
    dictionary_columns : list of [str, int]
        ``[node path, number of coefficients]`` for each node, in the
        order its atoms were appended.
    dictionary : ndarray
        One atom per row (presumably shape (n_atoms, samples) - TODO
        confirm the reconstructed length for the chosen wavelet).
    """
    import pywt
    from sklearn.preprocessing import normalize
    # Template signal: only its length matters, it fixes the node sizes.
    signal = np.zeros(samples)

    # Reference tree used to enumerate node paths and coefficient counts.
    wp_original = pywt.WaveletPacket(data=signal, wavelet=mother_wavelet, mode='periodization', maxlevel=max_level)

    dictionary = []
    dictionary_columns = []

    for level in range(1, max_level + 1):
        nodes = wp_original.get_level(level, order='natural')
        for node in nodes:
            dictionary_columns.append([node.path, len(node.data)])
            for i in range(len(node.data)):
                # Fresh, empty tree for this single atom.
                wp_temp = pywt.WaveletPacket(data=None, wavelet=mother_wavelet, mode='periodization', maxlevel=max_level)

                # Zero every node of the current level so the tree is complete.
                # NOTE(review): this re-queries the level for every atom; it
                # yields the same nodes as `nodes` above.
                for other_node in wp_original.get_level(level, order='natural'):
                    wp_temp[other_node.path] = np.zeros_like(other_node.data)

                # Unit impulse at coefficient i of the current node.
                impulse = np.zeros_like(node.data)
                impulse[i] = 1
                wp_temp[node.path] = impulse

                # Reconstruct the time-domain atom without altering wp_temp.
                signal_rec = wp_temp.reconstruct(update=False)

                # Scale the atom to unit L2 norm (skipped for an all-zero atom).
                norm = np.linalg.norm(signal_rec)
                if norm > 0:
                    signal_rec /= norm

                dictionary.append(signal_rec)

    dictionary = np.array(dictionary)
    # Row-wise normalization via scikit-learn.
    # NOTE(review): rows were already scaled to unit L2 norm above, so with
    # sklearn's default norm this second pass looks redundant - confirm.
    dictionary = normalize(dictionary, axis=1) 
    return dictionary_columns, dictionary

### OMP
def omp(x, D, T):
    """Sparse-code a signal with Orthogonal Matching Pursuit.

    Parameters
    ----------
    x : array-like, shape (n_samples,)
        Signal to approximate.
    D : ndarray, shape (n_atoms, n_samples)
        Dictionary with one atom per ROW (the code projects with
        ``D @ residual`` and reconstructs with columns of ``D.T``).
    T : int
        Maximum number of atoms to select. It is capped at ``n_atoms``
        because each atom can be selected only once.

    Returns
    -------
    alpha : ndarray, shape (n_atoms,)
        Coefficient vector; zero outside the selected support.
    support : list[int]
        Selected atom indices in selection order (empty when ``T <= 0``).
    """
    # Float copy: the -inf masking below is not representable in an int array.
    x = np.asarray(x, dtype=float)
    n_atoms = D.shape[0]
    atoms = D.T  # columns are atoms

    alpha = np.zeros(n_atoms)
    support = []
    alpha_sub = np.zeros(0)  # defined up front so T <= 0 returns cleanly
    residual = x.copy()

    for _ in range(min(T, n_atoms)):
        scores = np.abs(D @ residual)
        # Never pick the same atom twice.
        scores[support] = -np.inf

        j_star = int(np.argmax(scores))
        support.append(j_star)

        # Least-squares refit of x on all atoms selected so far.
        D_sub = atoms[:, support]
        alpha_sub = np.linalg.pinv(D_sub) @ x

        residual = x - D_sub @ alpha_sub

    alpha[support] = alpha_sub

    return alpha, support

### AFM
def afm_function(D, support_c1, support_c2):
    """Score each atom by how unevenly the two classes select it.

    For every atom the fraction of trials whose support contains it is
    computed per class; the score is ``(larger - smaller) / larger`` of
    those two fractions, or 0 when the larger one is not positive.

    Parameters
    ----------
    D : ndarray
        Dictionary; only ``D.shape[0]`` (the number of atoms) is used.
    support_c1, support_c2 : sequence of index lists
        One list of selected atom indices per trial, for each class.

    Returns
    -------
    ndarray, shape (D.shape[0],)
        The per-atom score.
    """
    n_atoms = D.shape[0]

    def selection_rate(supports):
        # Fraction of trials in which each atom appears in the support.
        hits = np.zeros(n_atoms)
        for picked in supports:
            hits[picked] += 1
        hits /= len(supports)
        return hits

    rate_c1 = selection_rate(support_c1)
    rate_c2 = selection_rate(support_c2)

    dominant = np.maximum(rate_c1, rate_c2)
    weaker = np.minimum(rate_c1, rate_c2)

    # Divide only where the dominant rate is positive; other entries stay 0.
    afm = np.zeros(n_atoms)
    np.divide(dominant - weaker, dominant, out=afm, where=dominant > 0)
    return afm

### MCM
def mcm_fuction(alpha_c1, alpha_c2):
    """Score each atom by the relative gap in mean coefficient magnitude.

    The absolute coefficients are averaged over the trials of each class;
    the score per atom is ``(larger - smaller) / larger`` of the two class
    averages, or 0 when the larger one is not positive.

    Parameters
    ----------
    alpha_c1, alpha_c2 : sequence of 1-D ndarrays
        One coefficient vector per trial, for each class.

    Returns
    -------
    ndarray, shape (len(alpha_c1[0]),)
        The per-atom score.
    """
    n_atoms = alpha_c1[0].shape[0]

    magnitude_c1 = np.mean([np.abs(code) for code in alpha_c1], axis=0)
    magnitude_c2 = np.mean([np.abs(code) for code in alpha_c2], axis=0)

    dominant = np.maximum(magnitude_c1, magnitude_c2)
    weaker = np.minimum(magnitude_c1, magnitude_c2)

    # Divide only where the dominant magnitude is positive; the rest stay 0.
    mcm = np.zeros(n_atoms)
    np.divide(dominant - weaker, dominant, out=mcm, where=dominant > 0)
    return mcm

### REM
def group_by_class(X, y, num_classes):
    """Split samples into per-class lists.

    A sample whose label equals 0 goes to bucket 0; every other label
    goes to bucket 1. ``num_classes`` only sets how many (possibly
    empty) buckets are created.

    Returns
    -------
    list of lists
        ``num_classes`` buckets holding the samples in input order.
    """
    buckets = [[] for _ in range(num_classes)]

    for sample, label in zip(X, y):
        target = 0 if label == 0 else 1
        buckets[target].append(sample)

    return buckets

def rem_fast(X_classes, coeffs, D, num_classes):
    """Score each atom by how much more one class's reconstruction depends on it.

    For every atom j and class c, the mean increase in squared
    reconstruction error caused by removing atom j's contribution is
    computed. The class with the largest increase is the atom's "plus"
    class, and the score is ``(r_plus - r_star) / r_plus`` where
    ``r_star`` is the largest increase among the other classes (0 when
    ``r_plus`` is not positive).

    Parameters
    ----------
    X_classes : sequence of sequences of 1-D arrays
        Signals grouped by class.
    coeffs : sequence of sequences of 1-D arrays
        Sparse codes grouped the same way, one code per signal.
    D : ndarray, shape (K, N)
        Dictionary with one atom per row (reconstruction is ``codes @ D``).
    num_classes : int
        Number of classes; an empty class contributes an error of 0.

    Returns
    -------
    mre : ndarray, shape (K,)
        Per-atom score.
    plus_class : ndarray of int, shape (K,)
        Index of the class with the largest error increase per atom.

    Raises
    ------
    ValueError
        From ``np.max`` when ``num_classes`` is 1 (there is no "other"
        class to compare against).
    """
    K, _ = D.shape
    mre = np.zeros(K)
    plus_class = np.zeros(K, dtype=int)

    # Everything that does not depend on the atom index is computed once
    # per class instead of once per (atom, class) pair.
    per_class = []
    for c in range(num_classes):
        if len(X_classes[c]) == 0:
            per_class.append(None)
            continue

        Xc = np.stack(X_classes[c])
        Ac = np.stack(coeffs[c])
        recon_full = Ac @ D
        err_full = np.sum((Xc - recon_full)**2, axis=1)
        per_class.append((Xc, Ac, recon_full, err_full))

    for j in range(K):
        errors = []
        for entry in per_class:
            if entry is None:
                errors.append(0)
                continue

            Xc, Ac, recon_full, err_full = entry
            # Reconstruction with atom j's contribution removed.
            recon_j = recon_full - np.outer(Ac[:, j], D[j])
            err_drop = np.sum((Xc - recon_j)**2, axis=1)

            errors.append(np.mean(err_drop - err_full))

        c_plus = np.argmax(errors)
        r_plus = errors[c_plus]
        r_star = np.max([e for k, e in enumerate(errors) if k != c_plus])

        mre[j] = (r_plus - r_star) / r_plus if r_plus > 0 else 0
        plus_class[j] = c_plus

    return mre, plus_class

def discriminative_score_grid(M_af, M_cm, M_rem, step=0.1):
    """Grid-search the mixing weights that maximize the variance of the combined score.

    Candidate scores are ``alpha*M_af + beta*M_rem + (1-alpha-beta)*M_cm``
    for ``alpha`` and ``beta`` on ``np.arange(0, 1+step, step)`` with
    ``alpha + beta <= 1``. The first combination reaching the highest
    ``np.var`` of the scores wins.

    Returns
    -------
    best_alpha, best_beta : float
        Winning weights for ``M_af`` and ``M_rem``.
    best_scores : ndarray
        The combined score vector for those weights.
    """
    n_atoms = len(M_af)  # unused; retained so inputs without a length are rejected as before

    grid = np.arange(0, 1 + step, step)

    top_variance = -np.inf
    top_alpha = None
    top_beta = None
    top_scores = None

    for alpha in grid:
        for beta in grid:
            # Weights must stay inside the simplex alpha + beta <= 1.
            if alpha + beta > 1:
                continue

            gamma = 1 - alpha - beta
            candidate = alpha*M_af + beta*M_rem + gamma*M_cm
            spread = np.var(candidate)

            if spread > top_variance:
                top_variance = spread
                top_scores = candidate
                top_alpha, top_beta = alpha, beta

    return top_alpha, top_beta, top_scores

### Dictionary per class
def sub_dictionary(plus_class, best_scores, n_atoms_per_class = 20, num_classes = 2, D=None):
    """Assemble a reduced dictionary from the best-scoring atoms of each class.

    For every class, the atoms whose ``plus_class`` equals that class are
    ranked by descending ``best_scores`` and the top
    ``n_atoms_per_class`` are kept. The kept atoms are stacked as columns,
    class after class.

    Parameters
    ----------
    plus_class : array-like of int, shape (n_atoms,)
        Class assigned to each atom.
    best_scores : array-like of float, shape (n_atoms,)
        Discriminative score of each atom.
    n_atoms_per_class : int
        Maximum number of atoms kept per class.
    num_classes : int
        Number of classes.
    D : ndarray, shape (n_atoms, n_samples), optional
        Dictionary with one atom per row. When omitted, a module-level
        variable named ``D`` is used, which is what the function relied
        on implicitly before.

    Returns
    -------
    final_dict : ndarray, shape (n_samples, total kept atoms)
        Selected atoms as columns.
    atoms_index : list of ndarray
        Selected atom indices, one array per class, in ranking order.
        NOTE(review): the original returned an undefined name
        ``atoms_index``; the per-class index list is assumed to be what
        was intended - confirm against callers.

    Raises
    ------
    NameError
        If ``D`` is not passed and no module-level ``D`` exists.
    """
    if D is None:
        D = globals().get('D')
        if D is None:
            raise NameError("name 'D' is not defined; pass the dictionary via the D argument")

    plus_class = np.asarray(plus_class)
    best_scores = np.asarray(best_scores)
    atoms = D.T  # columns are atoms

    subdicts = []
    selected_indices_per_class = []

    for c in range(num_classes):
        idx_class_c = np.where(plus_class == c)[0]
        # Negating the scores makes argsort rank from highest to lowest.
        idx_sorted = idx_class_c[np.argsort(-best_scores[idx_class_c])]
        selected_idx = idx_sorted[:n_atoms_per_class]
        subdicts.append(atoms[:, selected_idx])
        selected_indices_per_class.append(selected_idx)

    final_dict = np.hstack(subdicts)
    atoms_index = selected_indices_per_class

    return final_dict, atoms_index