In [None]:
# 1. Import the libraries used throughout the notebook
%matplotlib inline
%matplotlib widget
import os
import mne
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# 2. Locate the recording folders for this session.
root_path = r"C:\Users\xtwf7586\OneDrive - University of Leeds\fnirs\AuditoryRecording13032025"
# Folder suffixes run from 001 to 013; number 002 is left out on purpose.
all_folders = ["2025-03-13_%03d" % number for number in range(1, 14) if number != 2]
print("Folders to be loaded:")
print(all_folders)

In [None]:
# 3. Load one SNIRF recording per folder and group the recordings by participant.
# Folders are assigned in list order, four consecutive folders per participant
# (idx // 4), which maps the 12 folders onto participant_1 .. participant_3.
participant_data = {f"participant_{i+1}": [] for i in range(3)}

for idx, folder_name in enumerate(all_folders):
    folder_path = os.path.join(root_path, folder_name)
    if not os.path.exists(folder_path):
        print(f"Folder not found: {folder_path}")
        continue
    # os.listdir() returns entries in arbitrary order; sort so that the file
    # chosen below is the same on every run if a folder holds several .snirf files.
    snirf_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".snirf"))
    if not snirf_files:
        print(f"No .snirf file found in {folder_path}")
        continue
    snirf_path = os.path.join(folder_path, snirf_files[0])
    print(f"Loading {snirf_path}")
    raw = mne.io.read_raw_snirf(snirf_path, preload=True)
    # A skipped folder still consumes its idx, so the participant it belongs
    # to simply ends up with fewer recordings.
    participant_idx = idx // 4
    participant_key = f"participant_{participant_idx + 1}"
    participant_data[participant_key].append(raw)

In [None]:
# Summary of what was loaded: the number of recordings per participant, then
# channel count, duration and sampling frequency of every recording.
for key in participant_data:
    print(f"{key}: {len(participant_data[key])} SNIRF files loaded")

for participant, raws in participant_data.items():
    print(f"\n=== {participant} ===")
    for file_no, raw in enumerate(raws, start=1):
        print(
            f"File {file_no}: Channels={len(raw.ch_names)}, "
            f"Duration={raw.times[-1]:.2f}s, "
            f"Sampling Frequency={raw.info['sfreq']:.2f} Hz"
        )

In [None]:
#10
from collections import defaultdict
from mne.preprocessing.nirs import optical_density

# Coefficient-of-variation (CV) channel screening.
# For every recording the CV (100 * std / mean) of each channel is computed;
# an S-D pair is kept only if BOTH of its channels are below the threshold.
cv_threshold = 7.5  # CV threshold, in percent
participant_raw_cv_filtered = {}  # participant -> list of Raw objects with the kept pairs
participant_raw_cv_rejected = {}  # participant -> list of Raw objects with the rejected pairs

for participant, raws in participant_data.items():
    participant_raw_cv_filtered[participant] = []
    participant_raw_cv_rejected[participant] = []

    for i, raw in enumerate(raws):
        data = raw.get_data()
        mean = np.mean(data, axis=1)
        std = np.std(data, axis=1)
        # NOTE(review): a channel whose mean is zero or negative gives an
        # infinite/negative CV here — confirm this cannot happen for this data.
        cv = 100 * std / mean

        # Group channel indices by S-D pair: everything in the channel name
        # except the last whitespace-separated token (the wavelength).
        pair_map = defaultdict(list)
        for idx, ch_name in enumerate(raw.ch_names):
            pair_id = " ".join(ch_name.split()[:-1])  # e.g., "S1_D1"
            pair_map[pair_id].append(idx)

        good_idx, bad_idx = [], []
        for pair_id, indices in pair_map.items():
            # Pairs that do not have exactly two channels end up in neither
            # list and are therefore absent from both output recordings.
            if len(indices) == 2:
                if all(cv[idx] < cv_threshold for idx in indices):
                    good_idx.extend(indices)
                else:
                    bad_idx.extend(indices)

        # Store filtered (good) and rejected (bad) Raw objects
        good_chs = [raw.ch_names[idx] for idx in good_idx]
        bad_chs = [raw.ch_names[idx] for idx in bad_idx]

        # NOTE(review): if either list is empty, pick() is presumably called
        # with no channels — verify how MNE behaves in that case.
        raw_good = raw.copy().pick(good_chs)
        raw_bad = raw.copy().pick(bad_chs)

        participant_raw_cv_filtered[participant].append(raw_good)
        participant_raw_cv_rejected[participant].append(raw_bad)

        print(f"{participant} File {i+1}: Kept {len(good_chs)} | Dropped {len(bad_chs)} channels")

In [None]:
# 11. Convert the CV-kept and CV-rejected recordings to optical density and
# plot both versions of every recording for visual inspection.
participant_od_good = {}
participant_od_bad = {}

for participant, good_raws in participant_raw_cv_filtered.items():
    bad_raws = participant_raw_cv_rejected[participant]
    good_ods, bad_ods = [], []
    participant_od_good[participant] = good_ods
    participant_od_bad[participant] = bad_ods

    # The two lists were filled in lockstep, so they can be walked together.
    for i, (raw_good, raw_bad) in enumerate(zip(good_raws, bad_raws)):
        od_good = optical_density(raw_good)
        od_bad = optical_density(raw_bad)

        good_ods.append(od_good)
        bad_ods.append(od_bad)

        print(f"Plotting GOOD: {participant} - File {i+1}")
        od_good.plot(title=f"{participant} - File {i+1} (GOOD Channels)")

        print(f"Plotting BAD: {participant} - File {i+1}")
        od_bad.plot(title=f"{participant} - File {i+1} (BAD Channels)")
# Runs once, after all loops have finished.
plt.close()

In [None]:
# Report how many optical-density channels survived in each recording.
for participant in participant_od_good:
    print(f"\n=== {participant} ===")
    for file_no, raw_od in enumerate(participant_od_good[participant], start=1):
        print(f"File {file_no}: {len(raw_od.ch_names)} OD channels")

In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict

def plot_sd_pairs_with_task_overlay(raw_od, attention_start=34.5, attention_end=76.5, max_pairs=None):
    """
    Plot optical density signals for S-D pairs that have both a 760 and an 850 nm channel.

    One figure is drawn per pair, with the task window shaded for easy visual
    inspection. Each figure is closed right after it has been shown so that
    looping over many recordings does not accumulate open figures.

    Parameters:
    - raw_od: object exposing ``get_data(return_times=True)`` and ``ch_names``
      (an MNE Raw object holding optical density). Channel names must have the
      form "<pair> <wavelength>", e.g. "S3_D4 760".
    - attention_start: start time of task/stimulus (in seconds)
    - attention_end: end time of task/stimulus (in seconds)
    - max_pairs: stop once this many pairs have been plotted (None plots all).
      The check runs after a figure is drawn, so at least one pair is plotted.

    Returns:
    - None; the figures are the only output.
    """
    data, times = raw_od.get_data(return_times=True)

    # pair id -> {wavelength token -> row index into `data`}
    sd_pair_map = defaultdict(dict)
    for idx, ch_name in enumerate(raw_od.ch_names):
        parts = ch_name.split()
        sd_pair_map[parts[0]][parts[1]] = idx

    plotted = 0
    for sd_id, wl_map in sd_pair_map.items():
        # Only pairs with both wavelengths are drawn.
        if "760" not in wl_map or "850" not in wl_map:
            continue

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(times, data[wl_map["760"]], label=f"{sd_id} 760 nm", color='blue')
        ax.plot(times, data[wl_map["850"]], label=f"{sd_id} 850 nm", color='green')
        ax.axvspan(attention_start, attention_end, color='red', alpha=0.2, label='Attention Period')
        ax.set_title(f"S-D Pair: {sd_id}")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Optical Density")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()
        # Release the figure to prevent buildup across many calls.
        plt.close(fig)

        plotted += 1
        if max_pairs is not None and plotted >= max_pairs:
            break
# Visual inspection of every recording: first the channels that passed the
# CV screening, then the ones that were rejected.
for participant, good_recordings in participant_od_good.items():
    for i, od_good in enumerate(good_recordings):
        print(f"\nInspecting GOOD channels: {participant} - File {i+1}")
        plot_sd_pairs_with_task_overlay(od_good, attention_start=34.5, attention_end=76.5)

        print(f"\nInspecting BAD channels: {participant} - File {i+1}")
        od_bad = participant_od_bad[participant][i]
        plot_sd_pairs_with_task_overlay(od_bad, attention_start=34.5, attention_end=76.5)


In [None]:
from collections import defaultdict
import pandas as pd

# Final hand-picked S-D pairs per participant
final_channels = {
    "participant_1": [
        "S3_D3", "S3_D4", "S4_D4", "S5_D3", "S6_D4", "S6_D5", "S6_D6", "S7_D5", "S8_D6"
    ],
    "participant_2": [
        "S2_D1", "S3_D3", "S3_D4", "S4_D4", "S5_D3", "S6_D4", "S6_D5", "S6_D6", "S7_D5", "S8_D6"
    ],
    "participant_3": [
        "S1_D1", "S2_D1", "S3_D3", "S3_D4", "S4_D2", "S4_D4", "S5_D3", "S6_D4", "S6_D5", "S6_D6",
        "S7_D5", "S7_D7", "S8_D6"
    ],
}

# participant -> list of optical-density recordings restricted to the
# hand-picked pairs, in the order given in final_channels
participant_od_final = {}

for participant in participant_od_good.keys():
    final_list = []
    for i in range(len(participant_od_good[participant])):

        # Re-join the CV-kept and CV-rejected channels of this recording, so a
        # hand-picked pair can be selected regardless of its CV verdict.
        combined_raw = participant_od_good[participant][i].copy().add_channels(
            [participant_od_bad[participant][i].copy()],
            force_update_info=True
        )

        # Map each S-D pair id (first token of the channel name) to the
        # indices of its channels.
        channel_map = defaultdict(list)
        for idx, name in enumerate(combined_raw.ch_names):
            channel_map[name.split()[0]].append(idx)

        # Walk the hand-picked pairs in their listed order; pairs that are not
        # present in this recording are skipped.
        indices_to_keep = []
        for sd_pair in final_channels[participant]:
            if sd_pair in channel_map:
                # Sort by the wavelength token so 760 always precedes 850.
                wl_sorted = sorted(
                    channel_map[sd_pair],
                    key=lambda idx: combined_raw.ch_names[idx].split()[1],
                )
                indices_to_keep.extend(wl_sorted)

        # Pick final channels
        raw_selected = combined_raw.copy().pick(indices_to_keep)
        final_list.append(raw_selected)

    participant_od_final[participant] = final_list

# Show the selected channels of every recording as a table.
rows = []
for participant, recordings in participant_od_final.items():
    for i, raw in enumerate(recordings):
        rows.append({
            "Participant": participant,
            "File": i + 1,
            "Channels": ", ".join(raw.ch_names)
        })

df_final_channels = pd.DataFrame(rows)
# Do not truncate the (long) channel-list column when displaying.
pd.set_option("display.max_colwidth", None)

display(df_final_channels)

In [None]:
import scipy.io
from mne.io import RawArray
from mne import create_info
import numpy as np

# === Locations of the two synthetic NIRS .mat files
file1_path = r"C:\Users\xtwf7586\OneDrive - University of Leeds\New folder\nirs-resources-main\syntheticNIRS\RESOD2_struct_noiselevel2.mat"
file2_path = r"C:\Users\xtwf7586\OneDrive - University of Leeds\New folder\nirs-resources-main\syntheticNIRS\RESOD2_struct_noiselevel2_file2.mat"

# === Read both files and list the variables each one contains
mat1, mat2 = scipy.io.loadmat(file1_path), scipy.io.loadmat(file2_path)

print("File 1 keys:", mat1.keys())
print("File 2 keys:", mat2.keys())

# The synthetic data reuses the sampling frequency of the first two
# participant_1 recordings.
real_sfreq_file1, real_sfreq_file2 = (
    participant_od_final["participant_1"][k].info['sfreq'] for k in (0, 1)
)

def build_raw_from_struct(mat, sfreq, title):
    """
    Build an MNE RawArray of optical-density channels from a loaded .mat dict.

    Parameters:
    - mat: dict returned by ``scipy.io.loadmat``; must contain ``'data'`` and a
      ``'probe'`` struct whose ``link`` field has ``source``, ``detector`` and
      ``type`` columns.
    - sfreq: sampling frequency (Hz) assigned to the new recording.
    - title: window title for the plot.

    Returns:
    - the RawArray (channel type ``fnirs_od``); it is also printed and plotted.
    """
    # assumes mat['data'] is stored as (n_times, n_channels), hence the
    # transpose to MNE's (n_channels, n_times) — TODO confirm
    od_data = mat['data'].T

    # loadmat wraps MATLAB structs in 1x1 object arrays, hence the [0, 0] hops.
    link = mat['probe'][0, 0]['link'][0, 0]
    src = link['source'].flatten()
    det = link['detector'].flatten()
    wl = link['type'].flatten()

    # One name per measurement row: "S<source>_D<detector> <type>".
    ch_names = [f"S{s}_D{d} {w}" for s, d, w in zip(src, det, wl)]

    info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='fnirs_od')
    raw = RawArray(od_data, info)

    print(raw)
    raw.plot(n_channels=10, title=title)
    return raw


# === Turn both synthetic files into Raw objects, each at the sampling
# frequency of the real recording it will be paired with
synthetic_raw_file1 = build_raw_from_struct(
    mat1,
    title="Synthetic OD – File 1",
    sfreq=real_sfreq_file1,
)
synthetic_raw_file2 = build_raw_from_struct(
    mat2,
    title="Synthetic OD – File 2",
    sfreq=real_sfreq_file2,
)



In [None]:
# Print the real and the synthetic channel lists, one after the other, for comparison.
print("🔹 Real File 1 channels:")
print(participant_od_final["participant_1"][0].ch_names)
print("\n🔸 Synthetic File 1 channels:")
print(synthetic_raw_file1.ch_names)

In [None]:
# Give a copy of synthetic file 1 the channel order of the real File 1 recording.
synthetic_raw_file1 = synthetic_raw_file1.copy()
synthetic_raw_file1.reorder_channels(participant_od_final["participant_1"][0].ch_names)


In [None]:
# Give a copy of synthetic file 2 the channel order of the real File 2 recording.
synthetic_raw_file2 = synthetic_raw_file2.copy()
synthetic_raw_file2.reorder_channels(participant_od_final["participant_1"][1].ch_names)


In [None]:
# Real and synthetic recordings must list exactly the same channels, in the same order.
assert synthetic_raw_file1.ch_names == participant_od_final["participant_1"][0].ch_names
assert synthetic_raw_file2.ch_names == participant_od_final["participant_1"][1].ch_names


In [None]:
# Print the sampling frequency of each real recording next to its synthetic counterpart.
for label, recording in [
    ("Real File 1", participant_od_final["participant_1"][0]),
    ("Synthetic File 1", synthetic_raw_file1),
    ("Real File 2", participant_od_final["participant_1"][1]),
    ("Synthetic File 2", synthetic_raw_file2),
]:
    print(f"{label} sfreq:", recording.info['sfreq'])


In [None]:
# Re-wrap the first two participant_1 recordings as RawArray objects built
# from their data matrix and a copy of their measurement info.
real_raw1 = participant_od_final["participant_1"][0]
real_data1 = real_raw1.get_data()
info1 = real_raw1.info.copy()
real_raw_array1 = RawArray(real_data1, info1)

real_raw2 = participant_od_final["participant_1"][1]
real_data2 = real_raw2.get_data()
info2 = real_raw2.info.copy()
real_raw_array2 = RawArray(real_data2, info2)

In [None]:
from mne import concatenate_raws

# Append the synthetic recording to the end of the matching real recording.
# concatenate_raws() extends the FIRST raw in place, so copies are passed in:
# re-running this cell then yields the same result instead of appending the
# synthetic segment a second time.
extended_file1 = concatenate_raws([real_raw_array1.copy(), synthetic_raw_file1.copy()])
extended_file2 = concatenate_raws([real_raw_array2.copy(), synthetic_raw_file2.copy()])


In [None]:
# Replace the first two participant_1 recordings with their extended versions,
# then print duration, sample count and sampling frequency as a sanity check.
# NOTE(review): this overwrites entries of participant_od_final, so earlier
# cells that read these entries see the extended data if re-run afterwards.
participant_od_final["participant_1"][:2] = [extended_file1, extended_file2]
print("Extended File 1 duration (sec):", participant_od_final["participant_1"][0].times[-1])
print("Extended File 2 duration (sec):", participant_od_final["participant_1"][1].times[-1])
print("File 1: n_times =", participant_od_final["participant_1"][0].n_times)
print("sfreq =", participant_od_final["participant_1"][0].info['sfreq'])
print(
    "duration (computed) =",
    participant_od_final["participant_1"][0].n_times / participant_od_final["participant_1"][0].info['sfreq'],
)


In [None]:
# Write both extended recordings to FIF files in the working directory,
# replacing any existing files of the same name.
for file_no in (1, 2):
    participant_od_final["participant_1"][file_no - 1].save(
        f"participant1_file{file_no}_extended_raw.fif", overwrite=True
    )


In [None]:
from mne.preprocessing.nirs import beer_lambert_law
# Check that every final optical-density recording can be converted to
# haemoglobin concentration. The converted recordings are not stored here;
# a later cell repeats the conversion and keeps the results.
for participant, recordings in participant_od_final.items():
    for raw_od in recordings:
        raw_hb = beer_lambert_law(raw_od)

In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict

def plot_sd_pairs_hbo_hbr(raw_hb, attention_start=34.5, attention_end=76.5, max_pairs=None):
    """
    Plot HbO and HbR traces for every S-D pair that has both chromophores.

    One figure is drawn per pair, with the task window shaded. Each figure is
    closed after it has been shown so repeated calls do not pile up figures.

    Parameters:
    - raw_hb: object exposing ``get_data(return_times=True)`` and ``ch_names``
      (an MNE Raw object holding haemoglobin data). Channel names must have
      the form "<pair> <chromophore>", e.g. "S3_D4 hbo".
    - attention_start: start time of task/stimulus (in seconds)
    - attention_end: end time of task/stimulus (in seconds)
    - max_pairs: stop once this many pairs have been plotted (None plots all).
      The check runs after a figure is drawn, so at least one pair is plotted.

    Returns:
    - None; the figures are the only output.
    """
    data, times = raw_hb.get_data(return_times=True)
    # MNE keeps haemoglobin concentrations in mol/L; convert to µmol/L so the
    # plotted values match the y-axis label.
    data_umol = data * 1e6

    # pair id -> {chromophore token -> row index into `data`}
    sd_pair_map = defaultdict(dict)
    for idx, name in enumerate(raw_hb.ch_names):
        parts = name.split()
        sd_pair_map[parts[0]][parts[1]] = idx

    plotted = 0
    for sd_id, chrom_map in sd_pair_map.items():
        # Only pairs with both chromophores are drawn.
        if "hbo" not in chrom_map or "hbr" not in chrom_map:
            continue

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(times, data_umol[chrom_map["hbo"]], label=f"{sd_id} HbO", color='red')
        ax.plot(times, data_umol[chrom_map["hbr"]], label=f"{sd_id} HbR", color='blue')
        ax.axvspan(attention_start, attention_end, color='pink', alpha=0.2, label='Attention Period')
        ax.set_title(f"Hemoglobin: {sd_id}")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("µmol/L")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()
        plt.close(fig)

        plotted += 1
        if max_pairs is not None and plotted >= max_pairs:
            break


In [None]:
# Convert each participant_1 recording to haemoglobin and plot it pair by pair.
for file_no, raw_od in enumerate(participant_od_final["participant_1"], start=1):
    print(f"Participant 1 – File {file_no}")
    plot_sd_pairs_hbo_hbr(beer_lambert_law(raw_od))


In [None]:
from scipy.signal import butter, filtfilt

def butter_bandpass_filter(data, sfreq, lowcut=0.001, highcut=0.2, order=3):
    nyq = 0.5 * sfreq
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    return filtfilt(b, a, data, axis=1)

In [None]:
# Convert every final optical-density recording to haemoglobin, band-pass
# filter it and resample it to 5.2 Hz.
participant_hb_final = {}

for participant, recordings in participant_od_final.items():
    hb_list = []

    for raw_od in recordings:
        # 1. Convert to hemoglobin
        raw_hb = beer_lambert_law(raw_od)

        # 2. Filter all channels in place through MNE's public API rather
        #    than assigning to the private ``_data`` attribute. With
        #    channel_wise=False the function receives the whole
        #    (n_channels, n_times) array in one call.
        sfreq = raw_hb.info['sfreq']
        raw_hb.apply_function(
            butter_bandpass_filter,
            picks='all',
            channel_wise=False,
            sfreq=sfreq,
        )

        # 3. Downsample to 5.2 Hz (the rate the window sizes below assume)
        raw_hb.resample(5.2, npad="auto")

        hb_list.append(raw_hb)

    participant_hb_final[participant] = hb_list


In [None]:
# Report sampling frequency and duration of every haemoglobin recording.
for participant in participant_hb_final:
    print(f"\n=== {participant} ===")
    for file_no, raw in enumerate(participant_hb_final[participant], start=1):
        print(f"File {file_no}: sfreq = {raw.info['sfreq']:.2f} Hz | Duration = {raw.times[-1]:.2f} s")

In [None]:
def extract_sliding_windows(raw_data, window_size=93, step_size=3):
    """
    Extract overlapping sliding windows from a (channels x time) array.

    Parameters
    ----------
    raw_data : ndarray, shape (n_channels, n_times)
    window_size : int
        Number of time points per window (93 for 18s @ 5.2Hz)
    step_size : int
        Number of samples to shift between windows (3 for 0.6s stride)

    Returns
    -------
    windows : ndarray, shape (n_windows, window_size, n_channels)
        Window ``k`` covers samples ``k * step_size`` up to (but excluding)
        ``k * step_size + window_size``. When the recording is shorter than
        one window, an empty array of shape (0, window_size, n_channels) is
        returned instead of raising.
    """
    n_channels, n_times = raw_data.shape
    starts = range(0, n_times - window_size + 1, step_size)
    if len(starts) == 0:
        # np.stack() cannot handle an empty list; return a well-shaped empty
        # result so callers can still stack and count windows.
        return np.empty((0, window_size, n_channels), dtype=raw_data.dtype)
    # Each slice is (n_channels, window_size); transpose to time-major.
    return np.stack([raw_data[:, start:start + window_size].T for start in starts])


In [None]:
# Window length and stride in samples for 18 s windows / 0.6 s stride at
# 5.2 Hz. int() truncates, giving 93 and 3 samples respectively.
WINDOW_SIZE = int(18 * 5.2)
STRIDE = int(0.6 * 5.2)

# participant -> list (one entry per recording) of window arrays shaped
# (n_windows, WINDOW_SIZE, n_channels)
all_windows = {
    participant: [
        extract_sliding_windows(raw_hb.get_data(), window_size=WINDOW_SIZE, step_size=STRIDE)
        for raw_hb in recordings
    ]
    for participant, recordings in participant_hb_final.items()
}


In [None]:
# Print the shape of the window array of every recording.
for participant in all_windows:
    print(f"\n=== {participant} ===")
    for file_no, windows in enumerate(all_windows[participant], start=1):
        print(f"File {file_no}: shape = {windows.shape}")

In [None]:
def label_windows(recording_durations, window_size_sec=18, stride_sec=0.6,
                  attention_windows=None, default_attention=(34.5, 76.5)):
    """
    Label sliding windows as attention (1) or rest (0) from recording durations.

    A window is labelled 1 when its END time falls inside the attention
    interval of its recording (both bounds inclusive), otherwise 0.

    Parameters
    ----------
    recording_durations : dict
        participant name -> list of recording durations in seconds.
    window_size_sec : float
        Window length in seconds.
    stride_sec : float
        Shift between consecutive windows in seconds.
    attention_windows : dict or None
        Optional mapping ``(participant, file_index) -> (start, end)`` in
        seconds. When None, the two participant_1 overrides below are used.
    default_attention : tuple
        ``(start, end)`` used for every recording not in ``attention_windows``.

    Returns
    -------
    dict
        participant name -> list (one entry per recording) of label lists.
        A recording shorter than one window gets an empty label list.
    """
    if attention_windows is None:
        attention_windows = {
            ("participant_1", 0): (34.5, 52.22),
            ("participant_1", 1): (34.5, 47.81),
        }

    labels_by_participant = {}

    for participant, file_durations in recording_durations.items():
        labels_by_file = []

        for i, duration in enumerate(file_durations):
            # The attention interval depends only on the recording, so it is
            # looked up once per file rather than once per window.
            attention_start, attention_end = attention_windows.get(
                (participant, i), default_attention
            )

            if duration < window_size_sec:
                # int() truncates toward zero, which would otherwise report
                # one window for a recording shorter than a single window.
                n_windows = 0
            else:
                n_windows = int((duration - window_size_sec) / stride_sec) + 1

            # Only the window end point decides the label.
            labels = [
                1 if attention_start <= w * stride_sec + window_size_sec <= attention_end else 0
                for w in range(n_windows)
            ]

            labels_by_file.append(labels)
        labels_by_participant[participant] = labels_by_file

    return labels_by_participant

In [None]:
# Recording durations in seconds, one list per participant.
# NOTE(review): these are hard-coded rather than read from the recordings,
# and labels use a 0.6 s stride while windows use a 3-sample stride —
# confirm that labels and windows stay aligned.
recording_durations = {
    "participant_1": [76.35, 76.35, 119.23, 92.12],
    "participant_2": [129.62, 91.73, 96.92, 80.38],
    "participant_3": [86.73, 97.50, 87.88, 92.69]
}

labels_all = label_windows(recording_durations)

# Recording to inspect
participant, file_idx = "participant_1", 0
windows = all_windows[participant][file_idx]
labels = labels_all[participant][file_idx]

# Draw the first channel of the last 5 labelled windows, one figure each.
for idx in range(len(labels) - 5, len(labels)):
    plt.figure(figsize=(8, 3))
    first_channel = windows[idx][:, 0]
    plt.plot(first_channel, color='purple')
    plt.title(f"Window {idx} – Label: {labels[idx]}")
    plt.xlabel("Time points (at 5.2 Hz)")
    plt.ylabel("Signal")
    plt.grid()
    plt.show()

# List the nominal end time and label of every window.
for idx, window_label in enumerate(labels):
    end_time = idx * 0.6 + 18
    print(f"Window {idx} | Ends at ~{end_time:.1f}s | Label: {window_label}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

participant = "participant_1"
file_idx = 0

windows = all_windows[participant][file_idx]  # shape: (n_windows, window_len, n_channels)
labels = labels_all[participant][file_idx]     # list of 0s and 1s
stride = 3  # samples
fs = 5.2    # Hz
# Read the window length from the data instead of repeating the literal 93,
# so this cell stays correct if the window size changes.
window_len = windows.shape[1]

# Approximate the continuous signal by taking the centre sample of each window.
signal = windows[:, :, 0]  # channel 0, shape: (n_windows, window_len)
window_centers = np.arange(len(windows)) * stride + (window_len // 2)
time_axis = window_centers / fs  # in seconds
center_values = signal[:, window_len // 2]

# Plot
plt.figure(figsize=(14, 4))
plt.plot(time_axis, center_values, label='Channel 0 (center of each window)', color='purple')

# Shade the time span of every labelled window (red = attention, gray = rest).
# Windows overlap, so the shading accumulates where several windows agree.
for i in range(len(labels)):
    start_time = (i * stride) / fs
    end_time = (i * stride + window_len) / fs
    color = 'red' if labels[i] == 1 else 'gray'
    plt.axvspan(start_time, end_time, alpha=0.15, color=color)

plt.title("Signal from Channel 0 with Window Labels (Gray = Rest, Red = Attention)")
plt.xlabel("Time (s)")
plt.ylabel("Signal (HbO/HbR)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
def select_8_channels(raw, selected_pairs):
    """
    Return a copy of ``raw`` restricted to the HbO/HbR channels of the given S-D pairs.

    Parameters:
    - raw: recording exposing ``ch_names``, ``copy()`` and ``pick()``; channel
      names must have the form "<pair> <chromophore>", e.g. "S3_D3 hbo".
    - selected_pairs: iterable of pair ids such as "S3_D3".

    Returns:
    - a copy of ``raw`` with only the matching channels, in the order they
      appear in ``raw``. The input recording is left untouched.
    """
    wanted_pairs = set(selected_pairs)
    selected_chs = []
    for ch_name in raw.ch_names:
        parts = ch_name.split()
        # Compare the pair id exactly: a substring test would let "S3_D3"
        # also match "S13_D3" and could list a channel more than once.
        if len(parts) >= 2 and parts[0] in wanted_pairs and parts[1] in ('hbo', 'hbr'):
            selected_chs.append(ch_name)
    # pick() is used here as it is for channel selection elsewhere in this notebook.
    return raw.copy().pick(selected_chs)


In [None]:
# The four S-D pairs fed to the classifier (two chromophores each).
selected_pairs = [
    'S3_D3',
    'S4_D4',
    'S6_D6',
    'S8_D6',
]


In [None]:
# Restrict every haemoglobin recording to the selected S-D pairs.
participant_hb_selected = {
    participant: [select_8_channels(raw, selected_pairs) for raw in recordings]
    for participant, recordings in participant_hb_final.items()
}

In [None]:
# List the channels that remain in every recording after the selection.
for participant in participant_hb_selected:
    print(f"\n{participant}")
    for file_no, raw in enumerate(participant_hb_selected[participant], start=1):
        print(f"File {file_no}: {len(raw.ch_names)} channels → {raw.ch_names}")

In [None]:
def extract_sliding_windows(raw_data, window_size=93, step_size=3):
    """
    Extract overlapping sliding windows from a (channels x time) array.

    Parameters
    ----------
    raw_data : ndarray, shape (n_channels, n_times)
    window_size : int
        Number of time points per window.
    step_size : int
        Number of samples to shift between windows.

    Returns
    -------
    windows : ndarray, shape (n_windows, window_size, n_channels)
        An empty array of shape (0, window_size, n_channels) is returned when
        the recording is shorter than one window.
    """
    n_channels, n_times = raw_data.shape
    starts = range(0, n_times - window_size + 1, step_size)
    if len(starts) == 0:
        # np.stack() raises on an empty list; return a well-shaped empty array.
        return np.empty((0, window_size, n_channels), dtype=raw_data.dtype)
    # Each slice is (n_channels, window_size); transpose to time-major.
    return np.stack([raw_data[:, start:start + window_size].T for start in starts])

In [None]:
# Window length and stride in samples for 18 s windows / 0.6 s stride at
# 5.2 Hz (int() truncates to 93 and 3).
WINDOW_SIZE = int(18 * 5.2)
STRIDE = int(0.6 * 5.2)

# Rebuild the windows, this time from the channel-selected recordings.
all_windows = {}
for participant in participant_hb_selected:
    all_windows[participant] = [
        extract_sliding_windows(raw_hb.get_data(), window_size=WINDOW_SIZE, step_size=STRIDE)
        for raw_hb in participant_hb_selected[participant]
    ]


In [None]:
# Recompute the window labels for every recording from its listed duration.
labels_all = label_windows(recording_durations=recording_durations)

In [None]:
# Compare the number of windows with the number of labels per recording.
# The loop follows the recordings actually present instead of assuming four
# files per participant, so a participant with fewer files does not crash it.
for participant, windows_list in all_windows.items():
    print(f"\n{participant}")
    for i, windows in enumerate(windows_list):
        print(f"File {i+1}: {windows.shape[0]} windows | {len(labels_all[participant][i])} labels")


In [None]:
import os
import numpy as np
from sklearn.model_selection import train_test_split

# Step 1: Gather the windows and labels of every recording.
X_list, y_list = [], []

for participant, windows_per_file in all_windows.items():
    for i, windows in enumerate(windows_per_file):
        labels = labels_all[participant][i]

        # Window and label counts can differ; keep only the common prefix.
        n_common = min(len(windows), len(labels))
        X_list.append(windows[:n_common])
        y_list.append(labels[:n_common])

# Step 2: Stack into single arrays.
X = np.vstack(X_list)  # Shape: (N, window_size, n_channels)
y = np.hstack(y_list)  # Shape: (N,)
print("X shape:", X.shape)
print("y shape:", y.shape)
print("Label counts:", np.bincount(y))

In [None]:
from tsai.data.core import TSTensor

# Swap the last two axes to (samples, channels, time) and wrap as a float32 TSTensor.
X_test = TSTensor(X.transpose(0, 2, 1).astype(np.float32))
print(X_test.shape)  # (N, n_channels, window_size)


In [None]:
import pathlib
import sys

# On Windows, alias PosixPath to WindowsPath so PosixPath references resolve
# there (avoids the PosixPath error mentioned for this platform).
if sys.platform == "win32":
    setattr(pathlib, "PosixPath", pathlib.WindowsPath)


In [None]:
from tsai.all import *

# Raw string keeps the backslashes of the Windows path literal.
# NOTE(review): load_learner presumably unpickles the file — only load
# exports from a trusted source.
model_path = r"C:\Users\xtwf7586\OneDrive - University of Leeds\fnirs\AuditoryRecording13032025\models\models\LSTM_Tufts_Ccclean18.pkl"
learn = load_learner(model_path)

loaded_model_name = learn.model.__class__.__name__
print("✅ Loaded:", loaded_model_name)

In [None]:
from tsai.all import *
import numpy as np
from tsai.data.core import TSTensor
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# 3) Predict with the loaded learner. The result length varies across tsai
#    versions: (probs, targs), (probs, targs, decoded) or
#    (probs, targs, decoded, losses) — so index instead of unpacking.
res = learn.get_X_preds(X_test, y=y, bs=128)
probs, targs = res[0], (res[1] if len(res) > 1 else None)

# 4) Metrics: fall back to y itself when no targets were returned.
preds_class = probs.argmax(1).cpu().numpy()
if targs is None:
    true_y = y
else:
    true_y = targs.cpu().numpy()

print("✅ Accuracy:", accuracy_score(true_y, preds_class))
print("\n📊 Classification Report:")
print(classification_report(true_y, preds_class))
print("\n🔍 Confusion Matrix:")
print(confusion_matrix(true_y, preds_class))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Confusion matrix of the predictions above, drawn as an annotated heatmap.
cm = confusion_matrix(true_y, preds_class)
class_names = ['Class 0', 'Class 1']

plt.figure(figsize=(5,4))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()


In [None]:
# Quick look at the stacked arrays: X is (samples, window_size, n_channels),
# y is (samples,). X is cast to float32 in the next cell.
for caption, value in (("X shape:", X.shape), ("y shape:", y.shape), ("X dtype:", X.dtype)):
    print(caption, value)


In [None]:
# Cast the windows to 32-bit floats (always produces a new array).
X = X.astype(np.float32, copy=True)



In [None]:
# 2. Standardize each channel with its mean and standard deviation taken over
#    all windows and time points (axes 0 and 1).
X_mean = np.mean(X, axis=(0, 1), keepdims=True)
X_stddev = np.std(X, axis=(0, 1), keepdims=True)
X_std = (X - X_mean) / X_stddev

In [None]:
# 3. Predict on the standardized windows.
# X_std is (samples, time, channels); every other prediction and training
# cell in this notebook feeds the learner (samples, channels, time), so the
# last two axes are swapped here as well.
X_std_cf = np.transpose(X_std, (0, 2, 1)).astype(np.float32)
# Index the result instead of unpacking a fixed number of values: the length
# of the returned tuple differs between tsai versions.
res = learn.get_X_preds(X_std_cf)
preds_raw = res[0]
preds_class = preds_raw.argmax(1).cpu().numpy()

In [None]:
# 4. Evaluate
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Score the predictions against the true labels.
overall_accuracy = accuracy_score(y, preds_class)
print("✅ Accuracy:", overall_accuracy)

print("\n📊 Classification Report:")
print(classification_report(y, preds_class, target_names=["Rest", "Attention"]))

print("\n🔍 Confusion Matrix:")
print(confusion_matrix(y, preds_class))


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# X_std = your standardized input (shape: n_samples, 93, 8)
# y = true labels (binary)
# preds_class = predicted labels (from get_X_preds)

# Just in case
# One prediction per label is required before the two series can be overlaid.
assert len(y) == len(preds_class), "Mismatch in prediction and label lengths"

# True and predicted labels against the window index (windows are stacked
# participant by participant, recording by recording).
plt.figure(figsize=(14, 4))
plt.plot(y, label='True Label', linewidth=2, alpha=0.7)
plt.plot(preds_class, label='Predicted Label', linestyle='dashed', alpha=0.7)

plt.title("🧠 Model Predictions vs True Labels Over Time")
plt.xlabel("Window Index")
plt.ylabel("Label (0 = Rest, 1 = Attention)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Step 3: Random, stratified 80/20 train/validation split over single windows.
# NOTE(review): windows overlap heavily (stride 3, length 93 samples), so a
# random split puts near-identical windows on both sides — confirm this is intended.
X_cf = X.transpose(0, 2, 1).astype(np.float32)  # (samples, channels, time)
train_idx, val_idx = train_test_split(
    np.arange(len(y)),
    test_size=0.2,
    stratify=y,
    random_state=42,
)
splits = (list(train_idx), list(val_idx))

# Index the channel-first array with the split indices.
X_train, y_train = X_cf[splits[0]], y[splits[0]]
X_val, y_val = X_cf[splits[1]], y[splits[1]]

print("Train data shape:", X_train.shape)
print("Validation data shape:", X_val.shape)
print("Train label counts:", np.bincount(y_train))
print("Validation label counts:", np.bincount(y_val))


In [None]:
from tsai.all import *
from fastai.callback.all import SaveModelCallback

# Define transforms: no item transform on X, categorical encoding of y, and
# per-sample standardization applied batch by batch.
tfms = [None, [Categorize()]]
batch_tfms = TSStandardize(by_sample=True)

# Create datasets with splits
dsets = TSDatasets(X_cf, y, tfms=tfms, splits=splits, inplace=True)

# Create DataLoaders. The batch transform defined above is reused here, so
# these loaders standardize the data the same way as the classifier that is
# built with `batch_tfms`.
dls = TSDataLoaders.from_dsets(dsets.train, dsets.valid, bs=[64, 128], batch_tfms=[batch_tfms], num_workers=0)

In [None]:
from fastai.callback.all import SaveModelCallback

# Build a 3-layer bidirectional LSTM classifier on the channel-first windows,
# using the train/validation split and the transforms defined above.
mv_clf = TSClassifier(
    X_cf, y, splits=splits,
    arch=LSTM, arch_config={'n_layers': 3, 'bidirectional': True},
    tfms=tfms, batch_tfms=batch_tfms,
    metrics=accuracy,
    path='models',
    cbs=ShowGraph()
)

# Train for 20 epochs with the one-cycle policy and a maximum learning rate of 1e-3.
# NOTE(review): no random seed is set in this cell, so results presumably vary between runs.
mv_clf.fit_one_cycle(20, 1e-3)  # 1e-3 is safer for LSTM, but you can match Tufts exactly

# Save the trained model under the name 'own_final18' (a later cell loads it
# by that name from 'models/models').
mv_clf.save('own_final18')  


In [None]:
# Accuracy on the training windows: element 0 of the result holds the class
# probabilities, element 1 the targets.
all_preds = mv_clf.get_X_preds(X_train, y_train, with_input=False)
preds, targets = all_preds[0], all_preds[1]

from sklearn.metrics import accuracy_score

predicted_classes = preds.argmax(axis=1)
train_acc = accuracy_score(targets, predicted_classes)
print(f"✅ Training accuracy: {train_acc:.4f}")


In [None]:
# Accuracy on the validation windows (element 0: probabilities, element 1: targets).
val_preds = mv_clf.get_X_preds(X_val, y_val, with_input=False)
val_probs, val_targets = val_preds[0], val_preds[1]
val_acc = accuracy_score(val_targets, val_probs.argmax(axis=1))
print(f"✅ Validation accuracy: {val_acc:.4f}")


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, accuracy_score
import matplotlib.pyplot as plt

# Predictions on the learner's own validation split (ds_idx=1), as torch tensors.
preds, targs = mv_clf.get_preds(ds_idx=1)

y_val_true = targs.cpu().numpy()
y_val_pred = preds.argmax(1).cpu().numpy()

print("Val accuracy:", accuracy_score(y_val_true, y_val_pred))
print(classification_report(y_val_true, y_val_pred, digits=3))

# Confusion matrix of the validation predictions.
cm = confusion_matrix(y_val_true, y_val_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(xticks_rotation=0, cmap='Blues')
plt.title("Validation Confusion Matrix")
plt.tight_layout()
plt.show()


In [None]:
from tsai.all import *

# Step 1: Recreate the architecture used for training (3-layer bidirectional
# LSTM, 8 input channels, 2 classes).
# NOTE(review): seq_len=150 differs from the 93-sample windows used above —
# confirm that this architecture does not depend on seq_len.
model = create_model(
    LSTM,
    c_in=8,
    c_out=2,
    seq_len=150,
    arch_config={'n_layers': 3, 'bidirectional': True},
)

# Step 2: Create dummy dataloaders (for export only)
def get_dummy_dl(seq_len=150, n_vars=8, n_classes=2):
    """
    Build minimal DataLoaders from random data, used only to construct a Learner.

    Parameters:
    - seq_len: number of time points per dummy sample.
    - n_vars: number of channels per dummy sample.
    - n_classes: number of distinct labels to include.

    Returns:
    - TSDataLoaders built from the dummy dataset. With the default arguments
      this is two random samples of shape (8, 150) labelled 0 and 1.
    """
    # At least two samples, and at least one sample per class.
    n_samples = max(2, n_classes)
    X_dummy = np.random.randn(n_samples, n_vars, seq_len).astype(np.float32)
    y_dummy = np.arange(n_samples) % n_classes
    tfms = [None, [Categorize()]]
    dsets = TSDatasets(X_dummy, y_dummy, tfms=tfms)
    return TSDataLoaders.from_dsets(dsets.train, dsets.valid, bs=n_samples)

# The dummy loaders exist only so that a Learner can be constructed.
dls = get_dummy_dl()

# Step 3: Wrap the model in a Learner that looks for weights directly inside
# 'models/models' (model_dir='.' means no extra sub-folder).
learn = Learner(dls, model, metrics=accuracy, path='models/models', model_dir='.')

# Step 4: Load the saved weights by name, then export the learner to a single file.
learn.load('own_final18')
learn.export('LSTM_final18_clean_export.pkl')

In [None]:
from pathlib import Path

# Build the expected location of the exported file and check that it exists.
export_path = Path(learn.path).joinpath(learn.model_dir, "LSTM_final18_clean_export.pkl")
print("✅ Model exported to:", export_path)
print("✅ File exists?", export_path.exists())
