In [30]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
import os

In [31]:
# Set analysis type and load data
analysis_type = 'N170'
curr_dir = os.getcwd()
save_dir = os.path.join(curr_dir, analysis_type, "original_reconstructed_analysis")
os.makedirs(save_dir, exist_ok=True)

ori_betas_mat = loadmat(os.path.join(curr_dir, analysis_type + '_data.mat'))
learned_betas_mat = loadmat(os.path.join(curr_dir, analysis_type + '_learned_betas.mat'))

ori_betas = ori_betas_mat['data']
learned_betas = learned_betas_mat['reconstructed']
ori_betas = np.array(ori_betas)
learned_betas = np.array(learned_betas)

In [32]:
# shape [nBeta, nSubject, nChan, nTime]
ori_betas = np.transpose(ori_betas, (0, 3, 2, 1))
mae = np.abs((ori_betas - learned_betas) ** 2)
nBeta, nSubj, nChan, nTime = mae.shape
print("MAE mean:", np.mean(mae))

MAE mean: 0.09211923491895037


In [33]:
yticks = np.arange(1, nChan+1)
time = np.arange(nTime)
chan_time_mae = mae.mean(axis=(0,1))

# heat-map of all beta and subject
plt.figure(figsize=(16,13))
plt.imshow(chan_time_mae, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='MAE')
plt.yticks(np.arange(nChan), yticks)
plt.xlabel('Time')
plt.ylabel('Channel')
plt.title(analysis_type+'Channel_Time MAE')
plt.tight_layout()
plt.savefig(os.path.join(save_dir, analysis_type+'_Channel_Time MAE.png'))
#plt.show()
plt.close()

In [34]:
save_sub_dir1 = os.path.join(save_dir, 'Each subject MAE')
os.makedirs(save_sub_dir1, exist_ok=True)

# heat-map per subject
for iSubj in range(nSubj):
    chan_time_mae = mae[:, iSubj].mean(axis=0)
    plt.figure(figsize=(16,13))
    plt.imshow(chan_time_mae, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='MAE')
    plt.yticks(np.arange(nChan), yticks)
    plt.xlabel('Time')
    plt.ylabel('Channel')
    plt.title(f'{analysis_type}_Subject_{iSubj}_Channel_Time MAE')
    plt.tight_layout()
    plt.savefig(os.path.join(save_sub_dir1, f'Subject_{iSubj}_Channel_Time_MAE.png'))
    #plt.show()
    plt.close()

In [35]:
save_sub_dir2 = os.path.join(save_dir, 'Each beta MAE')
os.makedirs(save_sub_dir2, exist_ok=True)

# heat-map per beta
for iBeta in range(nBeta):
    chan_time_mae = mae[iBeta].mean(axis=0)
    plt.figure(figsize=(16,13))
    plt.imshow(chan_time_mae, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='MAE')
    plt.yticks(np.arange(nChan), yticks)
    plt.xlabel('Time')
    plt.ylabel('Channel')
    plt.title(f'{analysis_type}_Beta_{iBeta}_Channel * Time MAE')
    plt.tight_layout()
    plt.savefig(os.path.join(save_sub_dir2, f'Beta_{iBeta}_Channel_Time_MAE.png'))
    #plt.show()
    plt.close()

In [36]:
print("ori_betas mean:", np.mean(ori_betas))
print("learned_betas mean:", np.mean(learned_betas))
save_sub_dir3 = os.path.join(save_dir, 'Reconstruction vs original value comparison')
os.makedirs(save_sub_dir3, exist_ok=True)

# reconstruction and original value comparison per iBeta
for iBeta in range(nBeta):
    ori_mean = ori_betas[iBeta].mean(axis=(0, 1))
    recon_mean = learned_betas[iBeta].mean(axis=(0, 1))

    plt.figure(figsize=(16,13))
    plt.plot(time, ori_mean, label='original', color='blue')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.title(f'{analysis_type} | Beta {iBeta+1}')
    plt.tight_layout()
    plt.savefig(os.path.join(save_sub_dir3, f'Beta{iBeta+1}_original_TimeSeries.png'))
    plt.close()
    
    plt.figure(figsize=(16,13))
    plt.plot(time, recon_mean, label='reconstructed', color='red')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.title(f'{analysis_type} | Beta {iBeta+1}')
    plt.tight_layout()
    plt.savefig(os.path.join(save_sub_dir3, f'Beta{iBeta+1}_reconstructed_TimeSeries.png'))
    plt.close()

ori_betas mean: 1.9059665826934697e-08
learned_betas mean: 0.010202589


In [37]:
# reconstruction and original value comparison per iBeta and channel

for iBeta in range(nBeta):
    save_sub_dir4 = os.path.join(save_sub_dir3, f'Beta_{iBeta+1}_per_channel')
    os.makedirs(save_sub_dir4, exist_ok=True)
    for iChan in range(nChan):
        ori_mean = ori_betas[iBeta,:,iChan,:].mean(axis=0)
        recon_mean = learned_betas[iBeta,:,iChan,:].mean(axis=0)
        plt.figure(figsize=(16,13))
        plt.plot(time, ori_mean, label='original', color='blue')
        plt.xlabel('Time')
        plt.ylabel('Amplitude')
        plt.title(f'Beta {iBeta+1} | Channel {iChan+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(save_sub_dir4, f'Beta{iBeta+1}_Channel{iChan+1}_original_TimeSeries.png'))
        plt.close()
        
        plt.figure(figsize=(16,13))
        plt.plot(time, recon_mean, label='reconstructed', color='red')
        plt.xlabel('Time')
        plt.ylabel('Amplitude')
        plt.title(f'Beta {iBeta+1} | Channel {iChan+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(save_sub_dir4, f'Beta{iBeta+1}_Channel{iChan+1}_reconstructed_TimeSeries.png'))
        plt.close()

In [58]:
# Top 5 channels with highest MAE per beta
print("top5 highest MAE channels per beta:")

for iBeta in range(nBeta):
    chan_mae = mae[iBeta].mean(axis=(0, 2))   # (nChan,)
    top5_idx = np.flip(np.argsort(chan_mae)[-5:])
    top5_vals = chan_mae[top5_idx]
    
    
    print(f"\n========== Beta {iBeta+1} ==========")
    for i in range(len(top5_idx)):
        print(f"#{i+1}: Channel {top5_idx[i]+1:<2}  MAE={top5_vals[i]:.6f}")


top5 highest MAE channels per beta:

#1: Channel 16  MAE=0.259325
#2: Channel 1   MAE=0.160351
#3: Channel 19  MAE=0.154413
#4: Channel 3   MAE=0.143896
#5: Channel 9   MAE=0.141844

#1: Channel 16  MAE=0.314234
#2: Channel 1   MAE=0.253345
#3: Channel 19  MAE=0.189382
#4: Channel 9   MAE=0.173023
#5: Channel 3   MAE=0.171709

#1: Channel 16  MAE=0.174049
#2: Channel 1   MAE=0.145700
#3: Channel 28  MAE=0.122823
#4: Channel 30  MAE=0.114244
#5: Channel 12  MAE=0.108447

#1: Channel 16  MAE=0.314245
#2: Channel 1   MAE=0.236666
#3: Channel 19  MAE=0.187009
#4: Channel 9   MAE=0.175548
#5: Channel 3   MAE=0.166458

#1: Channel 16  MAE=0.075434
#2: Channel 9   MAE=0.047310
#3: Channel 1   MAE=0.044223
#4: Channel 13  MAE=0.043838
#5: Channel 12  MAE=0.043309
