# Plot the mean amplitude of each electrode, comparing cue data with read data

In [7]:
import numpy as np
from ecog_band.utils import *
import pandas as pd
import numpy as np
import scipy.io as scio
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import cm
from matplotlib.cm import ScalarMappable
from ecog_band.datasetAllband import SVMDataset
from ecog_band.models import SVMBinClassifier
import os
from sklearn.model_selection import GridSearchCV,train_test_split
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
from ecog_band.utils import *
from ecog_band.solver import Nfold_solver
import pandas as pd
# from ecog_band.datasetExcludeBand import CustomDatasetExcband
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
from scipy import stats
import numpy as np

# Notebook-wide settings. HS and freq are read by load_data() and by the
# save/load paths in later cells.
freq = 500  # directory level in the data path; also the 'else2' band upper limit in cal_avg_specified_band (presumably a sampling rate in Hz — TODO confirm)
HS = 69  # used in the HS{HS} directory name (presumably a subject id)
elec = 74  # electrode used only for the path check in this cell

# NOTE(review): in the visible notebook, path_elec / num_samples are not used
# by later cells (load_data builds its own); os.listdir here only fails fast
# if the directory is missing.
path_elec = f'/public/DATA/overt_reading/dataset_/HS{HS}/{freq}/{elec}'
num_samples = len(os.listdir(path_elec))

In [8]:
# z-score before t-test
def z_score_standardize(X, baseline):
    """Z-score ``X`` against per-row statistics of ``baseline``.

    The mean and standard deviation are pooled over axes 0 and 2 of
    ``baseline``, leaving one value per index of axis 1 (the frequency axis
    in this notebook). Those statistics are then broadcast over axes 0 and 2
    of ``X``.

    Args:
        X: 3-D array to normalise; its axis 1 must match ``baseline``'s axis 1.
        baseline: 3-D array that supplies the reference statistics.

    Returns:
        ``(X - mean) / std``. No guard is applied for a zero standard deviation.
    """
    pooled_axes = (0, 2)
    # keepdims=True yields shape (1, n_freq, 1), ready to broadcast against X.
    center = np.mean(baseline, axis=pooled_axes, keepdims=True)
    scale = np.std(baseline, axis=pooled_axes, keepdims=True)
    return (X - center) / scale

In [9]:
# Load the dataset
def load_data(elec, baseline_points=100, hs=None, sample_freq=None):
    """Load, take the magnitude of, and baseline-normalise all blocks of one electrode.

    For each block index ``num``, the following files are read from
    ``/public/DATA/overt_reading/dataset_/HS{hs}/{sample_freq}/{elec}``:

    - ``{num}_data_block_cue.npy``
    - ``{num}_data_block_read.npy``
    - ``{num}_baseline_block_cue.npy``

    Blocks are stacked along axis 0 and ``np.abs`` is applied. Cue and read
    data are then z-scored against the first ``baseline_points`` entries (last
    axis) of the baseline blocks.

    Args:
        elec: Electrode index (the directory name).
        baseline_points: Number of leading samples of each baseline block to
            use. Defaults to 100, the original hard-coded value.
        hs: Value for the ``HS{hs}`` path level. Defaults to the notebook
            global ``HS``, read at call time.
        sample_freq: Value for the frequency path level. Defaults to the
            notebook global ``freq``, read at call time.

    Returns:
        ``(data_cue_norm, data_read_norm)``, the z-scored cue and read arrays.

    Raises:
        ValueError: If no block has all three files present.
    """
    hs = HS if hs is None else hs
    sample_freq = freq if sample_freq is None else sample_freq
    path_elec = f'/public/DATA/overt_reading/dataset_/HS{hs}/{sample_freq}/{elec}'
    # The directory holds several files per block, so the file count is only
    # an upper bound on block indices; absent indices are skipped below.
    num_samples = len(os.listdir(path_elec))

    data_cue = []
    data_read = []
    baseline_data = []
    for num in range(num_samples):
        cue_path = os.path.join(path_elec, f'{num}_data_block_cue.npy')
        read_path = os.path.join(path_elec, f'{num}_data_block_read.npy')
        baseline_path = os.path.join(path_elec, f'{num}_baseline_block_cue.npy')
        # Require all three files. Checking only cue/read let np.load crash
        # when a block's baseline file was missing.
        if not all(os.path.exists(p) for p in (cue_path, read_path, baseline_path)):
            continue
        # Shape comment from the original: (n_task, n_freq, n_timePoint),
        # e.g. (60, 501, 375).
        data_cue.append(np.load(cue_path))
        data_read.append(np.load(read_path))
        baseline_data.append(np.load(baseline_path)[:, :, :baseline_points])

    if not data_cue:
        # np.vstack([]) would raise an opaque "need at least one array" error.
        raise ValueError(f'no complete cue/read/baseline blocks found in {path_elec}')

    data_cue = np.abs(np.vstack(data_cue))
    data_read = np.abs(np.vstack(data_read))
    baseline_data = np.abs(np.vstack(baseline_data))

    data_cue_norm = z_score_standardize(data_cue, baseline_data)
    data_read_norm = z_score_standardize(data_read, baseline_data)
    return data_cue_norm, data_read_norm

In [10]:
# Extract the specified frequency band from the raw spectral data
def extract_band(stft_block, freq, band):
    """Keep only the frequency rows (axis 1) of ``stft_block`` inside a named band.

    Axis-1 indices are compared directly against the band limits in Hz. This
    assumes 1 Hz bins starting at 0 Hz — TODO confirm against the STFT
    settings used to build the blocks.

    Args:
        stft_block: 3-D array whose axis 1 is the frequency axis.
        freq: Upper limit of the open-ended ``'else2'`` band, which keeps rows
            ``150 <= f < freq + 1``.
        band: A key of the band table below, or ``None`` to return
            ``stft_block`` unchanged.

    Returns:
        ``stft_block[:, idx, :]`` for the indices in the half-open range ``[low, high)``.

    Raises:
        KeyError: If ``band`` is not a known band name.
    """
    if band is None:
        return stft_block

    bands = {
        'else1': (0, 1),
        'delta': (1, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 30),
        'gamma': (30, 70),
        'high gamma': (70, 150),
        'else2': (150, freq + 1),
    }
    low, high = bands[band]

    # numpy instead of torch.arange: torch is not imported in this notebook's
    # visible imports, and the result is fed to np.where anyway.
    f = np.arange(stft_block.shape[1])
    # Select (keep) the rows belonging to the requested band.
    indices = np.where((f >= low) & (f < high))[0]
    return stft_block[:, indices, :]

# Compute the mean (and standard deviation) over all trials and all frequencies of the specified band
def cal_avg_specified_band(data, band=None):
    """Return per-time-point mean and std pooled over axes 0 and 1.

    When ``band`` is given, ``data`` is first restricted to that band with
    ``extract_band``, using the notebook-global ``freq``. Otherwise all rows
    are used.

    Args:
        data: 3-D array, reduced over its first two axes.
        band: Band name understood by ``extract_band``, or ``None``.

    Returns:
        ``(mean, std)``, each with the length of ``data``'s last axis.
    """
    selected = data if band is None else extract_band(data, freq, band)
    reduce_axes = (0, 1)
    return np.mean(selected, axis=reduce_axes), np.std(selected, axis=reduce_axes)

In [20]:
def plt_(elec_data, elec_list, band=None, duration=0.75, n_cols=5):
    """Plot cue vs. read mean ± std amplitude over time, one subplot per electrode.

    Args:
        elec_data: Mapping ``elec -> {'cue': array, 'read': array}``. The
            arrays are 3-D with time on the last axis, as returned by
            ``load_data``.
        elec_list: Electrodes to plot, in subplot order.
        band: Band name passed to ``cal_avg_specified_band``; ``None`` means
            all frequencies.
        duration: Seconds spanned by the time axis. Defaults to the original
            hard-coded 0.75.
        n_cols: Subplots per row. Defaults to the original hard-coded 5.
    """
    n_electrodes = len(elec_list)
    n_rows = max(int(np.ceil(n_electrodes / n_cols)), 1)

    # squeeze=False keeps axs 2-D even when n_rows == 1; otherwise the
    # axs[row, col] indexing below raises IndexError for <= n_cols electrodes.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows), squeeze=False)

    for i, elec in enumerate(elec_list):
        cue_data = elec_data[elec]['cue']
        read_data = elec_data[elec]['read']
        # Mean and std over trials and frequencies, per time point.
        cue_mean, cue_std = cal_avg_specified_band(cue_data, band)
        read_mean, read_std = cal_avg_specified_band(read_data, band)
        # Both curves share the cue time axis (assumes equal lengths, as before).
        time_steps = np.linspace(0, duration, cue_data.shape[-1])

        row, col = divmod(i, n_cols)
        ax = axs[row, col]

        ax.plot(time_steps, cue_mean, label='Cue Mean Amplitude', color='blue')
        ax.fill_between(time_steps,
                        cue_mean - cue_std,
                        cue_mean + cue_std,
                        color='lightblue', alpha=0.5, label='Cue Std Dev')

        ax.plot(time_steps, read_mean, label='Read Mean Amplitude', color='orange')
        ax.fill_between(time_steps,
                        read_mean - read_std,
                        read_mean + read_std,
                        color='lightcoral', alpha=0.5, label='Read Std Dev')

        ax.set_title(f'Electrode {elec}')

        # Axis labels only on the bottom-left subplot; label_outer() hides
        # inner labels elsewhere.
        if row == n_rows - 1 and col == 0:
            ax.set_xlabel('Time(s)')
            ax.set_ylabel('Amplitude')
        else:
            ax.label_outer()

        ax.legend(fontsize=6)
        ax.grid()

    # Hide unused trailing grid cells instead of drawing empty frames.
    for ax in axs.flat[n_electrodes:]:
        ax.set_visible(False)

    if band is None:
        fig.suptitle('AllBands Mean Amplitude', fontsize=16)
    else:
        fig.suptitle(f'{band} Mean Amplitude', fontsize=16)
    # Lay out after the suptitle exists, reserving the top 5% for it.
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    plt.show()

    # # 绘制频谱差异
    # plt.figure(figsize=(12, 6))
    # plt.plot(time_steps, spectrum_diff, label='Spectrum Difference', color='purple')
    # plt.title(f'{band} Spectrum Difference')
    # plt.xlabel('Time')
    # plt.ylabel('Difference Amplitude')
    # plt.axhline(0, color='black', linewidth=0.5, linestyle='--')
    # plt.legend()
    # plt.grid()
    # plt.show()
  

In [12]:
# Electrodes to analyse, and the bands to plot (None = all frequencies).
elec_list = [11, 22, 25, 35, 37, 48, 59, 64, 74, 87, 98, 104, 111, 124, 138, 147, 156, 166, 173, 183, 195, 208, 213, 220, 231, 246, 250]
# elec_list = list(range(256))
band_list = [None, 'else1', 'delta', 'theta', 'alpha', 'beta', 'gamma', 'high gamma', 'else2']
# Load and normalise every electrode once, then cache the whole dict to disk
# so the plotting cell can reload it without recomputing.
elec_data = {}
for i in elec_list:
    cue_data, read_data = load_data(i)
    elec_data[i] = {'cue': cue_data, 'read': read_data}
# np.save stores the dict as a pickled object array; it must be reloaded
# with allow_pickle=True and .item().
np.save(f'/public/DATA/overt_reading/dataset_/HS{HS}/{freq}/elec_data.npy', elec_data)

In [21]:
# Reload the cached per-electrode dict and draw one figure per band.
# NOTE(review): allow_pickle=True unpickles the file; only load files this
# notebook wrote itself.
elec_data = np.load(f'/public/DATA/overt_reading/dataset_/HS{HS}/{freq}/elec_data.npy', allow_pickle=True).item()
print(elec_data.keys())
for band in band_list:
    plt_(elec_data, elec_list, band)

dict_keys([11, 22, 25, 35, 37, 48, 59, 64, 74, 87, 98, 104, 111, 124, 138, 147, 156, 166, 173, 183, 195, 208, 213, 220, 231, 246, 250])
