In [None]:
import numpy as np
from importlib import reload
import pandas as pd
import fdc
import tqdm
import matplotlib.pyplot as plt
import sleep_stage_function as ssf
from fooof import FOOOF
# Re-execute the project-local helper modules (importlib.reload) so edits to their
# source files are picked up without restarting the kernel.
fdc, ssf = reload(fdc), reload(ssf)




# REAL DATA

In [None]:
# Sleep-stage annotations for recording EPCTL06 (expanded below at 30 s per row —
# i.e. one row per scoring epoch; TODO confirm against the export).
sleep_tag = pd.read_csv(filepath_or_buffer='../Data_for_analysis/EPCTL06.csv')


In [None]:
# Expose the stage labels (CSV column '0') under the column name 'L' read below.
# Note that 'L' is also a label *value* (latency) inside this column.
sleep_tag['L'] = sleep_tag.loc[:, '0']

In [None]:
# Preprocessed EEG for the same recording; rows are presumably channels and
# columns samples (it is indexed as sleep_data[0, start:end] below) — TODO confirm.
sleep_data = np.load(file='../Data_for_analysis/EPCTL06-prep-001.npy')

In [None]:
# Electrode positions: X/Y/Z columns plus an unnamed column holding channel names.
channel_positions = pd.read_csv(filepath_or_buffer='../Data_for_analysis/ch_pos.csv')

In [None]:
# Quick look at the first five electrode rows.
channel_positions.head(n=5)

In [None]:
# Distinct stage labels present in the annotation, in order of first appearance.
pd.unique(sleep_tag['L'])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Sampling parameters: 128 Hz, labels scored in 30-second epochs.
fs = 128
samples_per_epoch = 30 * fs  # 3840 samples per epoch

# Repeat each epoch label once per sample so labels line up with the columns of
# sleep_data, then truncate both views to the shorter of the two.
stages = sleep_tag['L'].values
expanded_labels = np.repeat(stages, samples_per_epoch)
n_samples = min(len(expanded_labels), sleep_data.shape[1])
expanded_labels = expanded_labels[:n_samples]

# Find where the leading 'L' (latency) run ends: the first sample labelled otherwise.
first_non_L = np.flatnonzero(expanded_labels != 'L')
if len(first_non_L) == 0:
    start_idx = 0
else:
    start_idx = first_non_L[0]
    print(f"Primo sample non-'L' è al sample {start_idx}")
    print(f"Questo corrisponde a {start_idx/fs:.1f} secondi = {start_idx/fs/60:.1f} minuti")

# Plot window: up to 500,000 samples starting at the first non-'L' sample.
n_plot = min(500000, n_samples - start_idx)
plot_start = start_idx
plot_end = start_idx + n_plot

# Time axis relative to the start of the plotted window.
time_seconds = np.arange(n_plot) / fs

fig, ax = plt.subplots(figsize=(16, 5))

# EEG trace of the first row of sleep_data.
ax.plot(time_seconds, sleep_data[0, plot_start:plot_end],
        'k-', linewidth=0.3, alpha=0.7)

# Background colour per sleep stage.
stage_colors = {
    'W': '#FFD700',    # Wake - yellow
    'N1': '#87CEEB',   # N1 - light blue
    'N2': '#4682B4',   # N2 - blue
    'N3': '#191970',   # N3 - dark blue
    'R': '#FF6347',    # REM - red
    'L': '#D3D3D3',    # Latency - grey
}

# Shade each run of consecutive identical labels. Run boundaries come from one
# vectorised comparison (no Python loop over every sample), and each span covers
# [first sample of run, first sample of next run) so adjacent spans touch without
# gaps. An empty window simply draws no spans.
window_labels = expanded_labels[plot_start:plot_end]
if window_labels.size > 0:
    change_points = np.flatnonzero(window_labels[1:] != window_labels[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_ends = np.concatenate((change_points, [window_labels.size]))
    for run_start, run_end in zip(run_starts, run_ends):
        color = stage_colors.get(window_labels[run_start], '#CCCCCC')
        ax.axvspan(run_start / fs, run_end / fs, alpha=0.35, color=color)

# Legend: only stages that actually occur in the plotted window.
stage_order = ['W', 'N1', 'N2', 'N3', 'R', 'L']
stage_names = {
    'W': 'Wake',
    'N1': 'NREM Stage 1',
    'N2': 'NREM Stage 2',
    'N3': 'NREM Stage 3',
    'R': 'REM',
    'L': 'Latency/Unknown'
}

labels_in_window = np.unique(window_labels)  # computed once, not per stage
unique_stages = [s for s in stage_order if s in labels_in_window]
legend_elements = [
    Patch(facecolor=stage_colors[stage],
          label=stage_names[stage],
          alpha=0.35)
    for stage in unique_stages
]
ax.legend(handles=legend_elements, title='Sleep Stage',
          loc='upper right', framealpha=0.95, fontsize=10)

ax.set_xlabel('Time (seconds)', fontsize=11)
ax.set_ylabel('EEG amplitude (μV)', fontsize=11)
ax.set_title(f'Sleep EEG (starting from sample {start_idx}, t={start_idx/fs/60:.1f} min)',
             fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()

# Stage distribution over the whole (truncated) recording, not just the plotted window.
print("\n=== STATISTICHE COMPLETE ===")
print(f"Total recording: {n_samples/fs/60:.1f} minutes")
print("\nStage distribution (tutta la registrazione):")
total_samples = len(expanded_labels)
for stage in stage_order:
    count = np.sum(expanded_labels == stage)
    percentage = count / total_samples * 100
    minutes = count / fs / 60
    if not count > 0:
        continue  # stage absent from the recording: nothing to report
    print(f"  {stage_names[stage]:20s}: {minutes:6.1f} min ({percentage:5.1f}%)")

In [None]:
# Extract an N1 segment; segment_length=-1 is passed straight to
# ssf.extract_sleep_stage_segment (presumably "whole segment" — TODO confirm).
# NOTE(review): whether segment_number=2 is the second or third N1 segment depends
# on whether ssf numbers segments from 0 or 1 — confirm.
n1_timeseries, start, end, info = ssf.extract_sleep_stage_segment(
    sleep_data,
    sleep_tag,
    'N1',
    segment_number=2,
    segment_length=-1,
    skip_first_epoch=True,
)
    

In [None]:
# Extract an N2 segment (this is the segment analysed in the FDC cells below);
# segment_length=-1 presumably means "whole segment" — TODO confirm in ssf.
# NOTE(review): confirm whether segment_number is 0- or 1-based.
n2_timeseries, start, end, info = ssf.extract_sleep_stage_segment(
    sleep_data,
    sleep_tag,
    'N2',
    segment_number=1,
    segment_length=-1,
    skip_first_epoch=True,
)

In [None]:
# Extract an N3 segment; segment_number=1 is intended as the second N3 segment,
# which assumes 0-based numbering in ssf — TODO confirm.
# Pass the notebook-wide sampling rate instead of repeating the literal 128.
n3_timeseries, start, end, info = ssf.extract_sleep_stage_segment(
    sleep_data, sleep_tag, 'N3', segment_number=1, skip_first_epoch=True, fs=fs)


In [None]:
# Frequency grid for the N2 FDC analysis: 500 evenly spaced frequencies, 0.5-45 Hz.
freqs = np.linspace(0.5, 45, 500)
omegas = 2 * np.pi * freqs  # corresponding angular frequencies (rad/s)

# Analyse the N2 segment extracted above.
timeseries = n2_timeseries

# C_freq[:, :, k] holds the channel x channel covariance-type FDC matrix at freqs[k];
# S_ii[k] is the mean of its diagonal (channel-averaged auto-spectrum), plotted below
# as the power spectrum.
# NOTE(review): n_chunks=50 is passed straight to fdc.correlation_freq (presumably the
# number of windows averaged per estimate) — confirm in fdc.
C_freq = np.zeros((timeseries.shape[0], timeseries.shape[0], len(freqs)), dtype=complex)
S_ii = []

for i, f in enumerate(tqdm.tqdm(freqs, desc="Computing FDC")):
    cif = fdc.correlation_freq(
        timeseries,
        time_step=1/fs,
        frequency=f,  # in hertz
        n_chunks=50,
        corr_type="covariance"
    )
    C_freq[:, :, i] = cif
    S_ii.append(np.mean(np.diag(cif)))


In [None]:
# Surrogate 1: fdc.reshuffling applied to the N2 segment, followed by the same FDC
# sweep as the original data.
# NOTE(review): the surrogate is presumably random and no seed is set here, so the
# baseline may change between runs — check how fdc draws its randomness.
reshuffled_ts = fdc.reshuffling(timeseries)
S_ii_shuff = []
C_freq_shuff = np.zeros(
    (reshuffled_ts.shape[0], reshuffled_ts.shape[0], len(freqs)), dtype=complex
)

for i, f in enumerate(tqdm.tqdm(freqs, desc="Computing FDC")):
    cif = fdc.correlation_freq(reshuffled_ts, time_step=1/fs, frequency=f,
                               n_chunks=50, corr_type="covariance")
    C_freq_shuff[..., i] = cif
    S_ii_shuff.append(np.mean(np.diag(cif)))


In [None]:
# Surrogate 2: fdc.phase_randomization applied to the N2 segment, then the same FDC
# sweep. NOTE(review): presumably random with no seed set here — confirm.
phaserand_ts = fdc.phase_randomization(timeseries)
S_ii_phaserand = []
C_freq_phaserand = np.zeros(
    (phaserand_ts.shape[0], phaserand_ts.shape[0], len(freqs)), dtype=complex
)

for i, f in enumerate(tqdm.tqdm(freqs, desc="Computing FDC")):
    cif = fdc.correlation_freq(phaserand_ts, time_step=1/fs, frequency=f,
                               n_chunks=50, corr_type="covariance")
    C_freq_phaserand[..., i] = cif
    S_ii_phaserand.append(np.mean(np.diag(cif)))

In [None]:
# Surrogate 3: fdc.block_permutation of the N2 segment (block_length=20, presumably
# in samples — confirm in fdc), then the same FDC sweep.
# NOTE(review): presumably random with no seed set here — confirm.
blockperm_ts = fdc.block_permutation(timeseries, block_length=20)
S_ii_blockperm = []
C_freq_blockperm = np.zeros(
    (blockperm_ts.shape[0], blockperm_ts.shape[0], len(freqs)), dtype=complex
)

for i, f in enumerate(tqdm.tqdm(freqs, desc="Computing FDC")):
    cif = fdc.correlation_freq(blockperm_ts, time_step=1/fs, frequency=f,
                               n_chunks=50, corr_type="covariance")
    C_freq_blockperm[..., i] = cif
    S_ii_blockperm.append(np.mean(np.diag(cif)))

In [None]:
# Fit a FOOOF model (aperiodic component + peaks) to the real part of the
# channel-averaged N2 spectrum S_ii.
fm = FOOOF(peak_width_limits=[1, 8],  # minimum/maximum peak width
           max_n_peaks=4,              # maximum number of peaks to fit
           min_peak_height=0.1)        # minimum peak height

# Fit over exactly the computed frequency grid, so the model components returned by
# get_model line up point-for-point with `freqs`. A hard-coded range narrower than the
# grid would make them shorter than `freqs` and break plotting against it.
fit_range = [float(freqs[0]), float(freqs[-1])]
fm.fit(np.asarray(freqs), np.real(S_ii), freq_range=fit_range)

# To inspect the fit visually: fm.plot()

# Model components on the fitted frequency axis (fm.freqs), in linear (not log) space.
aperiodic_fit = fm.get_model('aperiodic', 'linear')
peaks_fit = fm.get_model('peak', 'linear')
full_fit = fm.get_model('full', 'linear')

In [None]:
# Table of detected peaks, one row per peak; column 0 is the peak centre frequency
# (see the FOOOF docs for the remaining columns).
fm.peak_params_

In [None]:
# Channel-averaged N2 power spectrum with the FOOOF aperiodic fit, the detected peak
# frequencies and the conventional EEG bands shaded.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(freqs, np.real(S_ii), linewidth=2, label='Power spectrum')
# FOOOF components live on the fitted axis fm.freqs, which is shorter than `freqs`
# whenever the fit range is narrower than the grid.
ax.plot(fm.freqs, aperiodic_fit, linewidth=1.5, linestyle=':', color='black', label='Aperiodic fit')

# Dashed vertical line at each FOOOF peak centre frequency.
for peak_freq in fm.peak_params_[:, 0]:
    ax.axvline(peak_freq, color='gray', linestyle='--', linewidth=0.8)

# (start Hz, end Hz, colour, legend label). The delta span starts at 0.3 Hz although its
# label says 0.5 Hz, and the sigma span overlaps alpha and beta.
bands = [
    (0.3, 5, 'blue', 'Delta (0.5-5 Hz)'),
    (5, 8, 'orange', 'Theta (5-8 Hz)'),
    (8, 13, 'green', 'Alpha (8-13 Hz)'),
    (12, 15, 'red', 'Sigma (12-15 Hz)'),
    (13, 30, 'purple', 'Beta (13-30 Hz)'),
    (30, 45, 'magenta', 'Gamma (30-100 Hz)'),
]
for band_start, band_end, band_color, band_label in bands:
    ax.axvspan(band_start, band_end, alpha=0.2, color=band_color, label=band_label)

ax.set_yscale('log')
ax.set_xlabel('Frequency (Hz)', fontsize=12)
ax.set_ylabel('Power', fontsize=12)
ax.set_title('Power Spectrum N2', fontsize=13, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3, which='both')
fig.tight_layout()
fig.savefig('power_spectrum_baseline_N2.png')
plt.show()

# Compute the eigenvalue spectrum at the peak frequencies

In [None]:
# Pick two FOOOF peaks (rows 2 and 3 of peak_params_; column 0 is the centre frequency)
# as the frequencies at which to inspect the channel network. To choose frequencies by
# hand instead, set e.g. peak1_selected = 7 and peak2_selected = 13.
PEAK_ROWS = (2, 3)
n_peaks_found = fm.peak_params_.shape[0]
if n_peaks_found <= max(PEAK_ROWS):
    # Fail with an actionable message instead of a bare IndexError.
    raise ValueError(
        f"FOOOF found only {n_peaks_found} peak(s) but rows {PEAK_ROWS} are requested; "
        "adjust the FOOOF settings or set peak1_selected/peak2_selected manually."
    )
peak1_selected = fm.peak_params_[PEAK_ROWS[0], 0]
peak2_selected = fm.peak_params_[PEAK_ROWS[1], 0]

# FDC matrices of the N2 segment at the two selected frequencies (in hertz).
cif1 = fdc.correlation_freq(timeseries, time_step=1/fs, frequency=peak1_selected,
                            n_chunks=50, corr_type="covariance")

cif2 = fdc.correlation_freq(timeseries, time_step=1/fs, frequency=peak2_selected,
                            n_chunks=50, corr_type="covariance")

In [None]:
# Side-by-side heatmaps (real part) of the FDC matrices at the two selected peaks,
# in the original channel order.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, matrix, peak in zip(axes, (cif1, cif2), (peak1_selected, peak2_selected)):
    image = ax.imshow(np.real(matrix), cmap='viridis')
    fig.colorbar(image, ax=ax)  # without a colour scale the values cannot be read
    ax.set_title(f'Correlation at {peak:.2f} Hz')
plt.show()

In [None]:
# Eigen-decomposition of the FDC matrix at the first selected peak.
# NOTE(review): np.linalg.eigh assumes a Hermitian matrix and reads only its lower
# triangle; this assumes fdc.correlation_freq returns a Hermitian matrix — confirm.
eigval1, eigvec1 = np.linalg.eigh(cif1)

# eigh returns eigenvalues in ascending order, so the last pair is the leading one.
max_eigenvalue1 = eigval1[-1]
max_eigenvector1 = eigvec1[:, -1]

# An eigenvector is only defined up to a global phase (a sign for real matrices, any
# unit complex factor for complex ones), and argsort on complex values orders by real
# part, then imaginary part. First rotate the largest-magnitude component onto the
# positive real axis, then sort by real part. This makes the channel order independent
# of the arbitrary phase LAPACK returns.
anchor1 = max_eigenvector1[np.argmax(np.abs(max_eigenvector1))]
max_eigenvector1 = max_eigenvector1 * (np.conj(anchor1) / np.abs(anchor1))
sorted_indices1 = np.argsort(np.real(max_eigenvector1))[::-1]  # descending loading

# Reorder rows/columns of cif1, and the eigenvector itself, by that channel order.
cif1_sorted = cif1[np.ix_(sorted_indices1, sorted_indices1)]
max_eigenvector1_sorted = max_eigenvector1[sorted_indices1]

In [None]:
# Eigen-decomposition of the FDC matrix at the second selected peak.
# NOTE(review): np.linalg.eigh assumes a Hermitian matrix and reads only its lower
# triangle; this assumes fdc.correlation_freq returns a Hermitian matrix — confirm.
eigval2, eigvec2 = np.linalg.eigh(cif2)

# eigh returns eigenvalues in ascending order, so the last pair is the leading one.
max_eigenvalue2 = eigval2[-1]
max_eigenvector2 = eigvec2[:, -1]

# Remove the arbitrary global phase of the eigenvector, so the ordering does not
# depend on LAPACK's choice. Rotate its largest-magnitude component onto the positive
# real axis, then sort channels by real part; argsort on raw complex values would
# order lexicographically by (real, imag).
anchor2 = max_eigenvector2[np.argmax(np.abs(max_eigenvector2))]
max_eigenvector2 = max_eigenvector2 * (np.conj(anchor2) / np.abs(anchor2))
sorted_indices2 = np.argsort(np.real(max_eigenvector2))[::-1]  # descending loading

# Reorder rows/columns of cif2, and the eigenvector itself, by that channel order.
cif2_sorted = cif2[np.ix_(sorted_indices2, sorted_indices2)]
max_eigenvector2_sorted = max_eigenvector2[sorted_indices2]


In [None]:
# Heatmaps (real part) with channels reordered by the leading eigenvector, so
# channels with similar loadings on it sit next to each other.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, matrix, peak in zip(axes, (cif1_sorted, cif2_sorted), (peak1_selected, peak2_selected)):
    image = ax.imshow(np.real(matrix), cmap='viridis')
    fig.colorbar(image, ax=ax)  # colour scale makes the saved figure interpretable
    ax.set_title(f'Correlation at {peak:.2f} Hz')

fig.savefig('eeg_network_heatmap_N2.png')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx

# Electrode coordinates and names from ch_pos.csv; the names sit in the CSV's
# unnamed column, which pandas exposes as 'Unnamed: 0'.
x, y, z = (channel_positions[axis_name].values for axis_name in ('X', 'Y', 'Z'))
channel_names = channel_positions['Unnamed: 0'].values

In [None]:
# 2-D (top-view) network plots of the N2 FDC matrices at the two selected peaks.
# Edges: channel pairs whose real coupling Re C_ij exceeds the 95th percentile of
# |Re C_ij| over all pairs (the strongest positive couplings; compare np.abs(weight)
# instead to also keep strong negative ones). Working on the real part throughout
# avoids ordering comparisons between complex numbers and matches the 3-D plots.

n_channels = len(channel_names)
# Electrode layout in the x/y plane, keyed by node index.
pos = {i: (x[i], y[i]) for i in range(n_channels)}


def _minmax_normalize(values):
    """Scale |values| to [0, 1]; empty input stays empty, equal values map to 1."""
    magnitudes = np.abs(np.asarray(values, dtype=float))
    if magnitudes.size == 0:
        return magnitudes
    value_range = magnitudes.max() - magnitudes.min()
    if value_range == 0:
        # A single edge (or equal weights) would otherwise give 0/0 = NaN styling.
        return np.ones_like(magnitudes)
    return (magnitudes - magnitudes.min()) / value_range


def _threshold_graph(cif, percentile=95):
    """Build a channel graph keeping pairs with Re(cif[i, j]) above the threshold.

    Returns (graph, threshold, edges, edge_weights), where threshold is the given
    percentile of |Re(cif)| over the strict upper triangle.
    """
    coupling = np.real(cif)
    threshold = np.percentile(np.abs(coupling[np.triu_indices_from(coupling, k=1)]), percentile)
    graph = nx.Graph()
    for i, name in enumerate(channel_names):
        graph.add_node(i, label=name)
    edges, edge_weights = [], []
    for i in range(n_channels):
        for j in range(i + 1, n_channels):  # upper triangle only: each pair once
            weight = coupling[i, j]
            if weight > threshold:
                graph.add_edge(i, j, weight=weight)
                edges.append((i, j))
                edge_weights.append(weight)
    return graph, threshold, edges, np.array(edge_weights, dtype=float)


def _plot_topview(edges, edge_weights, title, filename):
    """Draw electrodes at their x/y positions with the given edges; save to filename."""
    weights_norm = _minmax_normalize(edge_weights)
    fig, ax = plt.subplots(figsize=(12, 10))
    for (i, j), w_norm, w_real in zip(edges, weights_norm, edge_weights):
        ax.plot([pos[i][0], pos[j][0]], [pos[i][1], pos[j][1]],
                color='red' if w_real > 0 else 'blue',  # sign of the coupling
                alpha=0.3 + 0.6 * w_norm,    # opacity grows with |weight|
                linewidth=0.5 + 3 * w_norm,  # thickness grows with |weight|
                zorder=1)
    ax.scatter(x, y, s=200, c='white', edgecolors='black', linewidths=2, zorder=2)
    for i, name in enumerate(channel_names):
        ax.text(x[i], y[i], name, fontsize=8, ha='center', va='center', zorder=3)
    ax.set_aspect('equal')
    ax.set_xlabel('X (cm)', fontsize=12)
    ax.set_ylabel('Y (cm)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()


G1, threshold1, edges1, weights1 = _threshold_graph(cif1)
_plot_topview(edges1, weights1,
              f'EEG Network - Top View\n(Threshold: {threshold1:.2e}, N edges: {len(edges1)})',
              'eeg_network_topview_N2_freq1.png')

# The data is the N2 segment (the previous title wrongly said N3).
G2, threshold2, edges2, weights2 = _threshold_graph(cif2)
_plot_topview(edges2, weights2,
              f'EEG Network - Top View\n (Sleep Phase: N2 - Frequency: {peak2_selected:.2f} Hz)',
              'eeg_network_topview_N2_freq2.png')

In [None]:
import plotly.graph_objects as go
# Interactive 3-D network at the first selected peak frequency.
# NOTE: this rebinds the global cif1 to its real part, so cells run after this one
# (e.g. the regularised inverse) see the real-valued matrix.
cif1 = np.real(cif1)

# Edge threshold: 95th percentile of |C_ij| over all channel pairs (top 5%).
threshold = np.percentile(np.abs(cif1[np.triu_indices_from(cif1, k=1)]), 95)

# Keep pairs whose coupling exceeds the threshold (strong positive couplings only;
# compare np.abs(weight) instead to also keep strong negative ones).
edge_list = []
edge_weights = []

n_channels = len(channel_names)

for i in range(n_channels):
    for j in range(i+1, n_channels):  # upper triangle: each pair once
        weight = cif1[i, j]
        if weight > threshold:
            edge_list.append((i, j))
            edge_weights.append(weight)

# Normalise |weight| to [0, 1] for line width. Guard against no edges (min/max of an
# empty array raise) and a single edge or equal weights (0/0 -> NaN widths).
edge_weights = np.array(edge_weights, dtype=float)
magnitudes = np.abs(edge_weights)
weight_range = magnitudes.max() - magnitudes.min() if magnitudes.size else 0.0
if weight_range > 0:
    weights_norm = (magnitudes - magnitudes.min()) / weight_range
else:
    weights_norm = np.ones_like(magnitudes)

# One line trace per edge, so each can carry its own width and hover text.
edge_traces = []

for idx, (i, j) in enumerate(edge_list):
    weight = edge_weights[idx]

    # Red for positive, blue for negative coupling.
    color = 'red' if weight > 0 else 'blue'

    edge_trace = go.Scatter3d(
        x=[x[i], x[j]],
        y=[y[i], y[j]],
        z=[z[i], z[j]],
        mode='lines',
        line=dict(
            color=color,
            width=1 + 4*weights_norm[idx]
        ),
        hovertemplate=f'<b>{channel_names[i]} → {channel_names[j]}</b><br>Weight: {weight:.3e}<extra></extra>',
        showlegend=False,
        opacity=0.6
    )
    edge_traces.append(edge_trace)

# Electrodes as labelled markers at their 3-D positions.
node_trace = go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers+text',
    marker=dict(
        size=10,
        color='lightblue',
        line=dict(color='black', width=2),
        opacity=1
    ),
    text=channel_names,
    textposition='top center',
    textfont=dict(size=8),
    hovertemplate='<b>%{text}</b><br>X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}<extra></extra>',
    showlegend=False
)

fig = go.Figure(data=edge_traces + [node_trace])

# The title includes the frequency so the two 3-D figures can be told apart.
fig.update_layout(
    title=f'EEG Network 3D - Interactive ({peak1_selected:.2f} Hz)<br>(Threshold: {threshold:.2e}, N edges: {len(edge_list)})',
    width=1000,
    height=800,
    scene=dict(
        xaxis=dict(title='X (cm)', backgroundcolor="white", gridcolor="lightgray"),
        yaxis=dict(title='Y (cm)', backgroundcolor="white", gridcolor="lightgray"),
        zaxis=dict(title='Z (cm)', backgroundcolor="white", gridcolor="lightgray"),
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        ),
        aspectmode='data'
    ),
    showlegend=False,
    hovermode='closest'
)

fig.show()

In [None]:
# Interactive 3-D network at the second selected peak frequency.
# NOTE: this rebinds the global cif2 to its real part, so cells run after this one
# (e.g. the regularised inverse) see the real-valued matrix.
cif2 = np.real(cif2)

# Edge threshold: 95th percentile of |C_ij| over all channel pairs (top 5%).
threshold = np.percentile(np.abs(cif2[np.triu_indices_from(cif2, k=1)]), 95)

# Keep pairs whose coupling exceeds the threshold (strong positive couplings only;
# compare np.abs(weight) instead to also keep strong negative ones).
edge_list = []
edge_weights = []

n_channels = len(channel_names)

for i in range(n_channels):
    for j in range(i+1, n_channels):  # upper triangle: each pair once
        weight = cif2[i, j]
        if weight > threshold:
            edge_list.append((i, j))
            edge_weights.append(weight)

# Normalise |weight| to [0, 1] for line width. Guard against no edges (min/max of an
# empty array raise) and a single edge or equal weights (0/0 -> NaN widths).
edge_weights = np.array(edge_weights, dtype=float)
magnitudes = np.abs(edge_weights)
weight_range = magnitudes.max() - magnitudes.min() if magnitudes.size else 0.0
if weight_range > 0:
    weights_norm = (magnitudes - magnitudes.min()) / weight_range
else:
    weights_norm = np.ones_like(magnitudes)

# One line trace per edge, so each can carry its own width and hover text.
edge_traces = []

for idx, (i, j) in enumerate(edge_list):
    weight = edge_weights[idx]

    # Red for positive, blue for negative coupling.
    color = 'red' if weight > 0 else 'blue'

    edge_trace = go.Scatter3d(
        x=[x[i], x[j]],
        y=[y[i], y[j]],
        z=[z[i], z[j]],
        mode='lines',
        line=dict(
            color=color,
            width=1 + 4*weights_norm[idx]
        ),
        hovertemplate=f'<b>{channel_names[i]} → {channel_names[j]}</b><br>Weight: {weight:.3e}<extra></extra>',
        showlegend=False,
        opacity=0.6
    )
    edge_traces.append(edge_trace)

# Electrodes as labelled markers at their 3-D positions.
node_trace = go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers+text',
    marker=dict(
        size=10,
        color='lightblue',
        line=dict(color='black', width=2),
        opacity=1
    ),
    text=channel_names,
    textposition='top center',
    textfont=dict(size=8),
    hovertemplate='<b>%{text}</b><br>X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}<extra></extra>',
    showlegend=False
)

fig = go.Figure(data=edge_traces + [node_trace])

# The title includes the frequency so the two 3-D figures can be told apart.
fig.update_layout(
    title=f'EEG Network 3D - Interactive ({peak2_selected:.2f} Hz)<br>(Threshold: {threshold:.2e}, N edges: {len(edge_list)})',
    width=1000,
    height=800,
    scene=dict(
        xaxis=dict(title='X (cm)', backgroundcolor="white", gridcolor="lightgray"),
        yaxis=dict(title='Y (cm)', backgroundcolor="white", gridcolor="lightgray"),
        zaxis=dict(title='Z (cm)', backgroundcolor="white", gridcolor="lightgray"),
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        ),
        aspectmode='data'
    ),
    showlegend=False,
    hovermode='closest'
)

fig.show()



### Take the (regularized) inverse of the correlation matrices

In [None]:
# Ridge-regularised inverses of the two FDC matrices.
# NOTE: run top to bottom, cif1/cif2 hold their real parts here, because the 3-D
# network cells rebind them with np.real.
theta = 0.001  # regularisation strength; absolute, so its effect depends on the
               # scale of the covariance entries — TODO tune relative to the data
# Add theta*I. For a Hermitian positive semi-definite matrix this lifts every eigenvalue
# by theta and keeps the inversion well conditioned; subtracting theta*I would do the
# opposite and can push small eigenvalues to zero or below.
inverse_cif1 = np.linalg.pinv(cif1 + theta * np.eye(cif1.shape[0]))
inverse_cif2 = np.linalg.pinv(cif2 + theta * np.eye(cif2.shape[0]))

In [None]:
# Reorder rows and columns of each inverse with the same leading-eigenvector channel
# order used for the sorted correlation heatmaps.
inverse_cif1_ordered = inverse_cif1[np.ix_(sorted_indices1, sorted_indices1)]
inverse_cif2_ordered = inverse_cif2[np.ix_(sorted_indices2, sorted_indices2)]

In [None]:
# Heatmaps of the (reordered) inverse matrices, real part, with the diagonal zeroed
# so the colour scale reflects the off-diagonal entries. np.fill_diagonal on a copy
# zeroes exactly the diagonal. Multiplying the diagonal by the identity would turn an
# inf/NaN diagonal entry into NaNs along its whole column.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, inverse_ordered, peak in zip(axes,
                                     (inverse_cif1_ordered, inverse_cif2_ordered),
                                     (peak1_selected, peak2_selected)):
    off_diagonal = np.array(np.real(inverse_ordered), copy=True)  # never mutate the input
    np.fill_diagonal(off_diagonal, 0.0)
    image = ax.imshow(off_diagonal, cmap='viridis')
    fig.colorbar(image, ax=ax)
    ax.set_title(f'Inverse Correlation at {peak:.2f} Hz')
plt.show()

## Select one segment of each of the five sleep stages (W, N1, N2, N3, R) and compare their power spectra

In [None]:
# FDC spectrum of one segment per sleep stage, on a finer frequency grid.
# NOTE: this rebinds the global `freqs` (500 -> 1000 points). Re-run the N2 FDC cell
# before re-running any earlier cell that pairs `freqs` with the N2 results.
selected_timeseries = {}
freqs = np.linspace(0.3, 45, 1000)  # 1000 frequencies from 0.3 to 45 Hz

C_freqs_dict = {}
S_ii_dict = {}

for stage in ['W', 'N1', 'N2', 'N3', 'R']:
    # segment_number=0: presumably the first segment of the stage — TODO confirm in ssf.
    stage_ts, stage_start, stage_end, stage_info = ssf.extract_sleep_stage_segment(
        sleep_data, sleep_tag, stage, segment_number=0, segment_length=-1, skip_first_epoch=True, fs=fs
        )
    selected_timeseries[stage] = stage_ts

    # Per-stage results get their own names, so the N2 globals (timeseries, C_freq,
    # S_ii, cif) used by the FOOOF/peak cells are not silently overwritten.
    stage_C = np.zeros((stage_ts.shape[0], stage_ts.shape[0], len(freqs)), dtype=complex)
    stage_S = []

    for i, f in enumerate(tqdm.tqdm(freqs, desc=f"Computing FDC for {stage}")):
        stage_cif = fdc.correlation_freq(
            stage_ts,
            time_step=1/fs,
            frequency=f,  # in hertz
            n_chunks=50,
            corr_type="covariance"
        )
        stage_C[:, :, i] = stage_cif
        stage_S.append(np.mean(np.diag(stage_cif)))

    C_freqs_dict[stage] = stage_C
    S_ii_dict[stage] = stage_S
    

### Baseline: reshuffled surrogate spectra for each stage

In [None]:

# Baseline: reshuffled surrogate (fdc.reshuffling) of each stage's segment, followed
# by the same FDC sweep as the original data.
# NOTE: S_ii_shuff is rebound here from the N2 surrogate list to a dict keyed by stage.
# NOTE(review): the surrogate is presumably random with no seed set — confirm in fdc.
C_freqs_shuff = {}
S_ii_shuff = {}

for stage in ['W', 'N1', 'N2', 'N3', 'R']:
    # Per-stage intermediates get their own names, so the N2 globals (reshuffled_ts,
    # C_freq, S_ii, cif) computed earlier are not overwritten.
    stage_shuffled = fdc.reshuffling(selected_timeseries[stage])
    stage_C = np.zeros((stage_shuffled.shape[0], stage_shuffled.shape[0], len(freqs)), dtype=complex)
    stage_S = []

    for i, f in enumerate(tqdm.tqdm(freqs, desc=f"Computing FDC for {stage}")):
        stage_cif = fdc.correlation_freq(
            stage_shuffled,
            time_step=1/fs,
            frequency=f,  # in hertz
            n_chunks=50,
            corr_type="covariance"
        )
        stage_C[:, :, i] = stage_cif
        stage_S.append(np.mean(np.diag(stage_cif)))

    C_freqs_shuff[stage] = stage_C
    S_ii_shuff[stage] = stage_S

In [None]:
# Frequency grid of the per-stage spectra: 1000 points from 0.3 to 45 Hz.
freqs = np.linspace(start=0.3, stop=45, num=1000)

In [None]:
# One log-log figure per stage: the stage spectrum vs its reshuffled baseline, with
# the EEG bands shaded.
for stage, stage_spectrum in S_ii_dict.items():
    fig, ax = plt.subplots(figsize=(12, 6))
    # The spectra are means of diagonals of complex-typed matrices. Plot their real
    # part explicitly rather than letting matplotlib drop the imaginary part with a
    # warning.
    ax.loglog(freqs, np.real(stage_spectrum), linewidth=2, label='Power spectrum')
    ax.loglog(freqs, np.real(S_ii_shuff[stage]), linewidth=1, linestyle='--', label='Baseline (reshuffled)')

    # Reference bands. The delta span starts at 0.3 Hz although its label says 0.5 Hz,
    # and the sigma span overlaps alpha and beta.
    ax.axvspan(0.3, 5, alpha=0.2, color='blue', label='Delta (0.5-5 Hz)')
    ax.axvspan(5, 8, alpha=0.2, color='orange', label='Theta (5-8 Hz)')
    ax.axvspan(8, 13, alpha=0.2, color='green', label='Alpha (8-13 Hz)')
    ax.axvspan(12, 15, alpha=0.2, color='red', label='Sigma (12-15 Hz)')
    ax.axvspan(13, 30, alpha=0.2, color='purple', label='Beta (13-30 Hz)')
    ax.axvspan(30, 45, alpha=0.2, color='magenta', label='Gamma (30-100 Hz)')

    ax.set_xlabel('Frequency (Hz)', fontsize=12)
    ax.set_ylabel('Power', fontsize=12)
    ax.set_title(f'Power Spectrum for {stage}', fontsize=13, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3, which='both')
    fig.tight_layout()
    fig.savefig(f'power_spectrum_with_bands_{stage}.png')
    plt.show()
    plt.close(fig)  # release each figure; this loop creates one per stage