# Import

In [None]:
%matplotlib inline
import os
import yasa
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from src.util.process import split_raw_by_annotation

# Setup

In [None]:
# Template placeholders. Presumably these are substituted by an external
# templating step ("{{...}}" syntax) before the notebook runs — confirm.
# idx: identifier used as the file prefix, i.e. "<idx>_raw.fif.gz".
idx = '{{idx}}'
# fif_path: directory that contains the "<idx>_raw.fif.gz" file.
fif_path = '{{fif_path}}'

In [None]:
# Load the recording and split it into nights. On failure, report the error
# instead of hiding it; later cells guard themselves and will also report.
try:
    file = os.path.join(fif_path, idx + '_raw.fif.gz')
    raw = mne.io.read_raw_fif(file, verbose=False, preload=True)
    # Later cells iterate raw_splits.items(), so this is presumably a mapping
    # night label -> Raw segment (epoch_length in seconds) — confirm.
    raw_splits = split_raw_by_annotation(raw, ann_text=['night_1', 'night_2', 'night_3'], epoch_length=30)
    sf = raw.info['sfreq']
except Exception as err:
    # Narrower than a bare except: lets KeyboardInterrupt/SystemExit through.
    print(f'Loading failed: {err!r}')

# Basic info

In [None]:
# PSG
# Print recording info, then per-night stage counts (in 30-s epochs) and
# stage percentages computed from the per-sample hypnogram channel.
try:
    print(raw.info)
    print('')
    for key, raw_night in raw_splits.items():
        print(key)
        # Row 0 of the 'hypno' channel data; cast to int like the other cells
        # so stage labels print as 1, 2, ... rather than 1.0, 2.0, ...
        hypno = raw_night['hypno'][0][0].astype(int)
        counts = pd.Series(hypno, name='Stage').value_counts().sort_index()
        print('Hypnogram')
        print('Counts (N epochs)')
        # Samples per stage divided by samples per 30-s epoch.
        print(round(counts / (sf * 30), 1))
        print('Percents')
        print(round(counts / len(hypno) * 100, 2))
        print('')
        print('')
except Exception as err:
    print(f'Basic info failed: {err!r}')

# Power spectrum

In [None]:
# Stage codes (as stored in the 'hypno' channel) to plot, with their labels
# and colours. NOTE(review): code 0 is not plotted — confirm that is intended.
stage_labels = {1: "N1", 2: "N2", 3: "N3", 4: "REM"}
stage_colors = {1: "tab:blue", 2: "tab:orange", 3: "tab:green", 4: "tab:red"}

try:
    # 4-s Welch window. n_fft must be set as well: MNE silently clamps
    # n_per_seg to n_fft (default 256 samples), which would drop the 4-s window.
    win = int(sf * 4)
    for key, raw_night in raw_splits.items():
        # Row 0 of the 'EEG' pick, scaled by 1e6 (MNE stores volts -> µV).
        data = raw_night.get_data('EEG')[0] * 1e6
        hypno = raw_night['hypno'][0][0].astype(int)

        plt.figure()
        for stage, label in stage_labels.items():
            mask = hypno == stage
            if not mask.any():
                continue

            # All samples of this stage are concatenated before Welch.
            # psd_array_welch returns (psds, freqs) in that order.
            psd, freqs = mne.time_frequency.psd_array_welch(
                data[mask],
                sfreq=sf,
                n_fft=win,
                n_per_seg=win,
            )

            plt.semilogy(freqs, psd, color=stage_colors[stage], label=label)

        plt.title(key)
        plt.xlabel("Frequency (Hz)")
        # Data was scaled to µV above; mathtext keeps the source ASCII-only.
        plt.ylabel(r"PSD ($\mu$V$^2$/Hz)")
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()
except Exception as err:
    print(f'Power spectrum failed: {err!r}')

# Spectrogram

In [None]:
# PSG
# Per night: plot the EEG spectrogram with the hypnogram via YASA.
try:
    for key, raw_night in raw_splits.items():
        print(key)
        # Row 0 of the 'EEG' pick, scaled by 1e6 (volts -> µV).
        data = raw_night.get_data('EEG')[0] * 1e6
        # Per-sample hypnogram, so it has the same length as data.
        hypno = raw_night['hypno'][0][0].astype(int)
        yasa.plot_spectrogram(data, sf=sf, hypno=hypno)
        plt.show()
        plt.close()
except Exception as err:
    # Narrower than a bare except: an interrupt during a long plot now stops the cell.
    print(f'Spectrogram failed: {err!r}')

# Sleep stats

In [None]:
# PSG
# Per night: print YASA sleep statistics computed from the hypnogram alone.
try:
    for key, raw_night in raw_splits.items():
        print(key)
        # The hypnogram is sampled at the recording rate, so its sampling
        # frequency (sf_hyp) is sf. The EEG data is not needed here.
        hypno = raw_night['hypno'][0][0].astype(int)
        ss = yasa.sleep_statistics(hypno, sf)
        print(ss)
        print('')
except Exception as err:
    print(f'Sleep stats failed: {err!r}')