In [None]:
# Facets of the transition summary grid: one row per condition (in the key
# order of `dfs`), one column per (group, region) pair built from the two
# lists below.
conditions = [*dfs.keys()]
groups = ['RGS', 'Control']
regions = ['HPC', 'PFC']


def rename_transition(col):
    """Return a human-readable name for a transition column.

    A column such as ``"1→3"`` or ``"1->3"`` is split on the first separator
    found (``'→'`` is checked before ``'->'``), both halves are parsed as
    integers, and each integer is replaced by its ``STATE_MAP`` entry (codes
    missing from ``STATE_MAP`` are kept as the integer). The result always uses
    ``'→'``. Columns without either separator are returned unchanged.

    Raises ValueError if a half is not an integer or there are not exactly
    two halves.
    """
    separator = next((s for s in ('→', '->') if s in col), None)
    if separator is None:
        return col
    src, dst = (int(part) for part in col.split(separator))
    return f"{STATE_MAP.get(src, src)}→{STATE_MAP.get(dst, dst)}"

# Transition columns (names containing '→' or '->') are taken from the first
# condition's frame and assumed to apply to every condition; a frame that
# lacks one of them gets NaN for it (via reindex) instead of raising KeyError.
first_df = next(iter(dfs.values()))
transition_cols = [
    col for col in first_df.columns
    if isinstance(col, str) and ('→' in col or '->' in col)
]
renamed_cols = [rename_transition(col) for col in transition_cols]

# One panel per (group, region) pair; derive the column count from the
# lists instead of hard-coding 4.
panels = [(g, r) for g in groups for r in regions]
# Explicit numeric x positions so the data and tick labels line up the same
# way on panels with and without data.
x_pos = np.arange(len(renamed_cols))

# squeeze=False always returns a 2-D axes array, even for a single condition.
fig, axes = plt.subplots(
    nrows=len(conditions), ncols=len(panels),
    figsize=(5 * len(panels), 5 * len(conditions)),
    sharey='row', squeeze=False,
)

for row_idx, condition in enumerate(conditions):
    df = dfs[condition]
    for col_idx, (group, region) in enumerate(panels):
        ax = axes[row_idx, col_idx]
        panel_title = f"{group} {region}"
        df_sub = df.loc[(df['group'] == group) & (df['region'] == region)]
        if df_sub.empty:
            ax.set_title(f"{panel_title}\n(no data)")
        else:
            values = df_sub.reindex(columns=transition_cols)
            # mean ± SEM across the rows of this group/region subset
            ax.errorbar(
                x_pos, values.mean(axis=0).to_numpy(),
                yerr=values.sem(axis=0).to_numpy(),
                fmt='-o', capsize=5, label=panel_title,
            )
            ax.set_title(panel_title)

        ax.set_xticks(x_pos)
        ax.set_xticklabels(renamed_cols, rotation=45, ha='right')
        ax.set_ylabel(condition if col_idx == 0 else "")
        ax.set_xlabel("Transition")
        ax.grid(True, axis='y')

fig.tight_layout()
plt.show()

In [None]:
def plot_transition_foof(
    states, exponents, times, 
    from_state, to_state, 
    sample_freq=1000, window_sec=15, noverlap=None, 
    margin_sec=15, 
    color='C0', label=None, 
    plot_sd=True
):
    """Plot the mean exponent time-locked to ``from_state → to_state`` transitions.

    Each spectral window ``i`` has an exponent ``exponents[i]`` and a time
    ``times[i]``; its state is ``states[int(times[i])]`` (clamped to the last
    entry). NOTE(review): this assumes ``times`` is in the same units as the
    index of ``states`` (presumably seconds with one state per second) — confirm.

    A transition is a window whose state is ``to_state`` while the previous
    window's state is ``from_state``. Around each transition, ``half_win``
    windows on each side are taken, where ``half_win = int(margin_sec / step)``
    and ``step`` is the hop between windows in seconds. Windows that run past
    either end or contain NaN are dropped.

    Parameters
    ----------
    states, exponents, times : array_like
        State labels, per-window exponents, per-window times.
    from_state, to_state : state codes that define the transition.
    sample_freq, window_sec, noverlap : window geometry. ``noverlap`` is in
        samples and defaults to half a window.
    margin_sec : seconds shown on each side of the transition.
    color, label : line colour and legend/title label.
    plot_sd : shade ±SD if True, otherwise ±SEM (SD / sqrt(n)).

    Returns None. It prints a message and plots nothing when no valid
    transition window exists.

    Raises
    ------
    ValueError
        If ``noverlap`` is not smaller than the window length in samples.
    """
    states = np.asarray(states)
    exponents = np.asarray(exponents, dtype=float)
    times = np.asarray(times)

    if noverlap is None:
        noverlap = window_sec * sample_freq // 2

    step_samples = window_sec * sample_freq - noverlap
    if step_samples <= 0:
        raise ValueError(
            f"noverlap ({noverlap}) must be smaller than the window length "
            f"({window_sec * sample_freq} samples)"
        )
    step_sec = step_samples / sample_freq
    half_win = int(margin_sec / step_sec)

    state_bins = states[np.minimum(times.astype(int), len(states) - 1)]
    is_from = state_bins == from_state
    is_to = state_bins == to_state

    # "Previous window was from_state". Shift without np.roll so the last
    # window cannot wrap around and pose as the predecessor of window 0.
    prev_is_from = np.zeros_like(is_from)
    prev_is_from[1:] = is_from[:-1]
    trans_idxs = np.flatnonzero(is_to & prev_is_from)

    # Collect the full, NaN-free windows around each transition.
    windows = []
    for idx in trans_idxs:
        start, end = idx - half_win, idx + half_win + 1
        if start < 0 or end > len(exponents):
            continue
        win = exponents[start:end]
        if not np.isnan(win).any():
            windows.append(win)

    if not windows:
        print("No valid transitions found.")
        return

    all_windows = np.stack(windows)
    n_trans = all_windows.shape[0]
    mean = all_windows.mean(axis=0)
    spread = all_windows.std(axis=0)
    err = spread if plot_sd else spread / np.sqrt(n_trans)
    rel_time = np.arange(-half_win, half_win + 1) * step_sec

    title_label = label if label is not None else f"{from_state}→{to_state} transitions"

    # Draw on the current axes (like the original pyplot calls), but through
    # the explicit Axes API.
    ax = plt.gca()
    ax.plot(rel_time, mean, color=color, label=label)
    ax.fill_between(rel_time, mean - err, mean + err, color=color, alpha=0.3)
    ax.axvline(0, color='k', linestyle='--', linewidth=1)
    ax.set_xlabel('time relative to transition (s)')
    ax.set_ylabel('exponent (a.u.)')
    ax.set_title(f"{title_label} (n={n_trans})")
    plt.tight_layout()
    plt.show()


In [None]:
# One figure per transition type: (from code, to code, line colour, label).
for from_code, to_code, line_color, panel_label in (
    (1, 3, 'C0', 'Wake-NREM Transitions'),
    (3, 1, 'C1', 'NREM-Wake Transitions'),
    (5, 1, 'C2', 'REM-Wake Transitions'),
):
    plot_transition_foof(
        states, exponents, times, from_code, to_code,
        color=line_color, label=panel_label,
    )

In [None]:
import os
import numpy as np
import scipy.io as sio
from scipy.signal import decimate, resample
from scipy.stats import zscore
from scipy.interpolate import interp1d

# --- Helper Functions for Artefact Removal ---

def find_intervals(condition_array):
    """Return the inclusive [start, end] index pairs of each run of truthy values.

    The result is a (k, 2) integer array ordered by start. If nothing is truthy,
    an empty (0, 2) int array is returned.
    """
    flags = np.asarray(condition_array).astype(bool)
    if not flags.any():
        return np.zeros((0, 2), dtype=int)
    # Pad with False on both sides so every run has a rising and a falling edge.
    edges = np.flatnonzero(np.diff(np.concatenate(([False], flags, [False])).astype(np.int8)))
    run_starts = edges[0::2]
    run_ends = edges[1::2] - 1  # falling edge sits one past the last True
    return np.column_stack((run_starts, run_ends))

def consolidate_intervals(intervals, buffer_samples, total_len):
    """Pad inclusive [start, end] intervals by ``buffer_samples`` and merge overlaps.

    Bounds are clipped to ``[0, total_len - 1]`` before and after padding.
    Intervals that share at least one sample after padding are merged. Merely
    adjacent ones (``end + 1 == next start``) are kept separate. Returns a
    (k, 2) array sorted by start; an empty input yields an empty (0, 2) int array.
    """
    if len(intervals) == 0:
        return np.zeros((0, 2), dtype=int)

    last_valid = total_len - 1
    # np.clip returns a new array, so the caller's intervals are not modified.
    padded = np.clip(intervals, 0, last_valid)
    padded[:, 0] = np.maximum(padded[:, 0] - buffer_samples, 0)
    padded[:, 1] = np.minimum(padded[:, 1] + buffer_samples, last_valid)
    padded = padded[np.argsort(padded[:, 0])]

    result = []
    run_start, run_end = padded[0]
    for next_start, next_end in padded[1:]:
        if next_start > run_end:
            # gap: close the current run and open a new one
            result.append([run_start, run_end])
            run_start, run_end = next_start, next_end
        else:
            run_end = max(run_end, next_end)
    result.append([run_start, run_end])
    return np.array(result)

def in_intervals(indices, intervals):
    """Boolean mask: True where an index value lies in any inclusive [start, end] interval.

    Membership is decided by the *values* in ``indices``, not their positions.
    The previous implementation set ``mask[start:end+1]`` and was therefore only
    correct when ``indices`` was ``arange(len(indices))``. For that input the
    result is unchanged.

    Intervals may be unsorted or overlapping. Cost is O((n + k) log k) for n
    indices and k intervals, so long recordings with many short intervals do
    not degrade to O(n * k).
    """
    indices = np.asarray(indices)
    bounds = np.asarray(intervals).reshape(-1, 2)
    mask = np.zeros(indices.shape, dtype=bool)
    if bounds.shape[0] == 0:
        return mask

    order = np.argsort(bounds[:, 0], kind='stable')
    starts = bounds[order, 0]
    # reach[j] = furthest end among the first j+1 intervals (by start). A value
    # x is covered iff some interval with start <= x reaches at least x.
    reach = np.maximum.accumulate(bounds[order, 1])

    pos = np.searchsorted(starts, indices, side='right') - 1
    has_start = pos >= 0
    mask[has_start] = indices[has_start] <= reach[pos[has_start]]
    return mask

def remove_artefacts(lfp_sig, downsample_freq, amp_thresh, time_win_thresh, original_fs):
    """Detect and interpolate over amplitude / fast-transient artefacts in an LFP trace.

    The signal is decimated from ``original_fs`` to ``downsample_freq`` and
    z-scored. Two kinds of artefact are flagged on the decimated signal:

    * global: ``|z| > amp_thresh[0]``, each run padded by ``time_win_thresh[0]`` s;
    * noisy:  ``|diff(z)| > amp_thresh[1]``, each run padded by ``time_win_thresh[1]`` s.

    Flagged samples are replaced by linear interpolation between neighbouring
    clean samples. Artefacts at the edges are held at the nearest clean value
    rather than linearly extrapolated, which could produce large ramps. The
    result is resampled back to ``len(lfp_sig)`` samples. If every sample is
    flagged, the decimated signal is left as is.

    Parameters
    ----------
    lfp_sig : array_like, 1-D raw signal sampled at ``original_fs``.
    downsample_freq : working rate in Hz; ``original_fs / downsample_freq``
        must be a positive integer.
    amp_thresh : (global z threshold, derivative threshold).
    time_win_thresh : (global padding s, noisy padding s).
    original_fs : sampling rate of ``lfp_sig`` in Hz.

    Returns
    -------
    clean_sig_upsampled : ndarray with ``len(lfp_sig)`` samples.
    artefact_inds : bool ndarray, artefact mask at the decimated rate.
    down_time : ndarray, sample times in seconds at the decimated rate.

    Raises
    ------
    ValueError
        If ``original_fs / downsample_freq`` is not a positive integer.
    """
    lfp_sig = np.asarray(lfp_sig)

    # Derive the decimation factor instead of hard-coding 2 (original_fs was unused).
    ratio = original_fs / downsample_freq
    factor = int(round(ratio))
    if factor < 1 or not np.isclose(ratio, factor):
        raise ValueError(
            f"original_fs ({original_fs}) must be a positive integer multiple "
            f"of downsample_freq ({downsample_freq})"
        )
    # scipy's FIR decimate cannot run with q=1, so just copy in that case.
    down_sig = decimate(lfp_sig, factor, ftype='fir') if factor > 1 else lfp_sig.astype(float)
    down_len = len(down_sig)
    # Exact sample times. linspace(0, N/fs, N) stretched the spacing to (N/fs)/(N-1).
    down_time = np.arange(down_len) / downsample_freq

    zsig = zscore(down_sig)
    # Pad with the last value so the final derivative is 0; append=0 made it
    # -zsig[-1] and could flag a spurious artefact at the end.
    diffsig = np.diff(zsig, append=zsig[-1])
    time_indices = np.arange(down_len)
    artefact_inds = np.zeros(down_len, dtype=bool)

    global_intervals = find_intervals(np.abs(zsig) > amp_thresh[0])
    global_intervals = consolidate_intervals(global_intervals, int(time_win_thresh[0] * downsample_freq), down_len)
    artefact_inds |= in_intervals(time_indices, global_intervals)

    noisy_intervals = find_intervals(np.abs(diffsig) > amp_thresh[1])
    noisy_intervals = consolidate_intervals(noisy_intervals, int(time_win_thresh[1] * downsample_freq), down_len)
    artefact_inds |= in_intervals(time_indices, noisy_intervals)

    clean_sig = down_sig.copy()
    good = ~artefact_inds
    if np.any(good):
        # np.interp holds the edge values outside the clean range (no extrapolation).
        clean_sig[artefact_inds] = np.interp(
            down_time[artefact_inds], down_time[good], clean_sig[good]
        )

    # Upsample back to the original length
    clean_sig_upsampled = resample(clean_sig, len(lfp_sig))
    return clean_sig_upsampled, artefact_inds, down_time

# --- Configuration & Rats Dataset ---

BASE_DATA_DIR = "./data"
CONDITIONS_DIR_NAMES = ['HomeCageHC', "OverlappingOR", "RandomCon", "StableCondOD"]
CONTROL_RATS = [1, 2, 6, 9]
RGS_RATS = [3, 4, 7, 8]

# rats_data group key -> rat numbers in that group (control rats are stored
# under 'rgs_negative', RGS rats under 'rgs_positive'). Order = load order.
GROUP_RATS = {
    'rgs_negative': CONTROL_RATS,
    'rgs_positive': RGS_RATS,
}


def _load_rat_recording(base_dir, condition, rat_index):
    """Load one rat's Post-Trial5 states, HPC and PFC signals as flattened arrays.

    Reads ``states.mat`` (key ``states``), ``HPC_100.continuous.mat`` (key ``HPC``)
    and ``PFC_100.continuous.mat`` (key ``PFC``) from
    ``<base_dir>/<condition>/Rat<rat_index>/Post-Trial5``. Missing files propagate
    the error raised by ``scipy.io.loadmat``.
    """
    folder_dir = os.path.join(base_dir, condition, f"Rat{rat_index}", "Post-Trial5")
    states_mat = sio.loadmat(os.path.join(folder_dir, 'states.mat'))
    hpc_mat = sio.loadmat(os.path.join(folder_dir, 'HPC_100.continuous.mat'))
    pfc_mat = sio.loadmat(os.path.join(folder_dir, 'PFC_100.continuous.mat'))
    return {
        'number': rat_index,
        'states': states_mat['states'].ravel(),
        'hpc': hpc_mat['HPC'].ravel(),
        'pfc': pfc_mat['PFC'].ravel(),
    }


# rats_data[condition][group_key] -> list of per-rat recording dicts.
rats_data = {
    condition: {
        'rgs_positive': [],
        'rgs_negative': []
    }
    for condition in CONDITIONS_DIR_NAMES
}

for condition in CONDITIONS_DIR_NAMES:
    for group_key, rat_numbers in GROUP_RATS.items():
        for rat_index in rat_numbers:
            rats_data[condition][group_key].append(
                _load_rat_recording(BASE_DATA_DIR, condition, rat_index)
            )

# --- Batch Artefact Removal Script ---

def process_and_save_cleaned_signals(
    base_dir, rats_data,
    original_fs=1000, downsample_freq=None,
    amp_thresh=(5, 3), time_win_thresh=(2.0, 0.1),
    regions=('HPC', 'PFC'),
):
    """Artefact-clean every rat's regional LFP and save it next to the raw file.

    For each rat entry in ``rats_data[condition][group]``, only ``rat['number']``
    is read. For each region, the signal is re-read from
    ``<base_dir>/<condition>/Rat<number>/Post-Trial5/<region>_100.continuous.mat``
    (key ``<region>``) and cleaned with ``remove_artefacts``. The result is
    written to ``<region>_cleaned.mat`` in the same folder under key
    ``<region>_cleaned``. Missing input files are reported and skipped.

    Parameters
    ----------
    base_dir : root data directory.
    rats_data : nested mapping condition -> group -> list of rat dicts.
    original_fs : sampling rate of the raw files in Hz.
    downsample_freq : working rate for artefact detection; defaults to
        ``original_fs // 2``.
    amp_thresh, time_win_thresh : passed through to ``remove_artefacts``.
    regions : region names to process.
    """
    if downsample_freq is None:
        downsample_freq = original_fs // 2

    # `rat_groups` (not `groups`) avoids shadowing the notebook-level `groups` list.
    for condition, rat_groups in rats_data.items():
        for rats in rat_groups.values():
            for rat in rats:
                rat_number = rat['number']
                folder_dir = os.path.join(base_dir, condition, f"Rat{rat_number}", "Post-Trial5")

                for region in regions:
                    file_path = os.path.join(folder_dir, f"{region}_100.continuous.mat")
                    if not os.path.exists(file_path):
                        print(f"Missing file: {file_path}")
                        continue

                    print(f"Cleaning {region} for Rat{rat_number}, Condition: {condition}")
                    raw_signal = sio.loadmat(file_path)[region].ravel()
                    cleaned_signal, _, _ = remove_artefacts(
                        raw_signal, downsample_freq, amp_thresh, time_win_thresh, original_fs
                    )

                    cleaned_file_path = os.path.join(folder_dir, f"{region}_cleaned.mat")
                    sio.savemat(cleaned_file_path, {f"{region}_cleaned": cleaned_signal})

# Clean every loaded rat's regional recordings and write the *_cleaned.mat files.
process_and_save_cleaned_signals(base_dir=BASE_DATA_DIR, rats_data=rats_data)
