In [1]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from scipy import signal
import scipy.io as sio

import re
import io

!pip install mne
import mne

Defaulting to user installation because normal site-packages is not writeable


In [23]:
from typing import Optional, Tuple, Dict, Any

## Load the dataset


In [24]:
def _todict(matobj):
    """
    Convert MATLAB structs to nested Python dicts.
    Handles simple nested structs returned by scipy.io.loadmat with
    struct_as_record=False and squeeze_me=True
    """
    result = {}
    for fieldname in matobj._fieldnames:
        elem = getattr(matobj, fieldname)
        if hasattr(elem, "_fieldnames"):
            result[fieldname] = _todict(elem)
        else:
            result[fieldname] = elem
    return result


def _check_and_unpack(value):
    """
    If value is a numpy void/record or matlab object, unpack into dict/list/ndarray.
    """
    # matlab object (mat_struct) instance
    if hasattr(value, "_fieldnames"):
        return _todict(value)
    # numpy array of objects, maybe from struct array
    if isinstance(value, np.ndarray) and value.dtype == np.object_:
        try:
            # try to squeeze and take the first element if it looks like a struct wrapper
            squeezed = np.squeeze(value)
            if hasattr(squeezed, "_fieldnames"):
                return _todict(squeezed)
            return squeezed
        except Exception:
            return value
    return value


def _is_numeric_array(value) -> bool:
    """Return True if ``value`` is a numpy array with a bool/int/uint/float/complex dtype."""
    return isinstance(value, np.ndarray) and value.dtype.kind in "biufc"


def _is_vector(value) -> bool:
    """Return True if ``value`` is a 1-D array or a 2-D array with a singleton axis."""
    return isinstance(value, np.ndarray) and (
        value.ndim == 1 or (value.ndim == 2 and 1 in value.shape)
    )


def _select_signal(candidates, y_name, expected_n_channels, verbose):
    """Return ``(name, value)`` of the variable holding the signal; raise ValueError if none fits."""
    if y_name:
        if y_name in candidates:
            return y_name, candidates[y_name]
        if verbose:
            print(f"Variable {y_name!r} not found in .mat, auto-detecting the signal array.")

    # Names that commonly hold the signal. ndim >= 2 (not == 2) so a preferred
    # name with the wrong rank is reported by the caller instead of being
    # silently skipped in favour of some unrelated matrix.
    for name in ("y", "data", "signals", "eeg", "eeg_data", "X", "val"):
        value = candidates.get(name)
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            return name, value

    # Fallback: numeric 2-D arrays only (cell/struct arrays have object dtype).
    two_d_arrays = {k: v for k, v in candidates.items() if _is_numeric_array(v) and v.ndim == 2}
    if verbose:
        print("2D arrays found:", {k: v.shape for k, v in two_d_arrays.items()})
    if expected_n_channels is not None:
        # prefer an array with one dimension equal to the expected channel count
        for name, value in two_d_arrays.items():
            if expected_n_channels in value.shape:
                return name, value
    if two_d_arrays:
        # choose largest by element count
        return max(two_d_arrays.items(), key=lambda kv: kv[1].size)

    raise ValueError("Could not find a 2D signal array in the .mat file, provide y_name or check the file.")


def _select_trigger(candidates, trig_name, y_name, n_samples, verbose):
    """Return ``(name, value)`` of the trigger vector, or ``(None, zeros)`` when none is found."""
    if trig_name:
        if trig_name in candidates:
            return trig_name, candidates[trig_name]
        if verbose:
            print(f"Variable {trig_name!r} not found in .mat, auto-detecting the trigger vector.")

    # try common trigger names first
    for name in ("trig", "trigger", "triggers", "event", "events", "markers", "marker"):
        value = candidates.get(name)
        if _is_vector(value):
            return name, value

    # Then any other numeric vector whose length matches the signal. The signal
    # variable itself and other 2-D matrices are excluded: they also have a
    # dimension equal to n_samples and would otherwise be mistaken for triggers.
    for name, value in candidates.items():
        if name != y_name and _is_numeric_array(value) and _is_vector(value) and value.size == n_samples:
            return name, value

    if verbose:
        print("No trigger vector detected, creating a zero 'trig' column.")
    return None, np.zeros(n_samples, dtype=int)


def load_mat_to_dataframe(
    path: str,
    y_name: Optional[str] = None,
    trig_name: Optional[str] = None,
    expected_n_channels: Optional[int] = None,
    channel_names: Optional[list] = None,
    return_meta: bool = False,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, Optional[Dict[str, Any]]]:
    """
    Load a .mat file and return a DataFrame with one column per channel plus 'trig'.

    Parameters
    ----------
    path : str
        Path to the .mat file.
    y_name : str, optional
        Name of the variable containing the signal array. If None (or not found),
        the signal is auto-detected: first by common names ("y", "data", "eeg", ...),
        then as a numeric 2-D array matching ``expected_n_channels``, then as the
        largest numeric 2-D array.
    trig_name : str, optional
        Name of the trigger/event vector. If None (or not found), common names
        ("trig", "trigger", "events", ...) are tried, then any other numeric vector
        whose length equals the number of samples. If nothing matches, an all-zero
        'trig' column is created.
    expected_n_channels : int, optional
        How many channels the signal has, e.g. 16. Used to pick between several
        2-D arrays and to transpose a (channels, samples) array. Defaults to
        ``len(channel_names)`` when channel_names is given.
    channel_names : list, optional
        Channel column names, one per channel. If None, channels are named
        ch1, ch2, ...
    return_meta : bool, optional
        Currently has no effect: meta is always returned as the second element.
        Kept for backward compatibility.
    verbose : bool, optional
        Print debug info.

    Returns
    -------
    df : pandas.DataFrame
        Shape (n_samples, n_channels + 1); the last column is named 'trig'.
    meta : dict
        {'y_name', 'y_shape', 'trig_name', 'trig_shape', 'n_samples', 'n_channels'}.

    Raises
    ------
    ValueError
        If no 2-D signal can be found, the selected signal is not 2-D, the trigger
        length does not match the number of samples, or channel_names has the
        wrong length.
    """
    mat = sio.loadmat(path, struct_as_record=False, squeeze_me=True)

    # drop loadmat's bookkeeping entries such as __header__ / __version__
    keys = [k for k in mat.keys() if not k.startswith("__")]
    if verbose:
        print("Variables found in .mat:", keys)

    # unpack MATLAB structs / object arrays into Python containers
    candidates = {k: _check_and_unpack(mat[k]) for k in keys}

    # channel_names fixes the channel count, so use it as a hint when none was given
    if expected_n_channels is None and channel_names is not None:
        expected_n_channels = len(channel_names)

    chosen_y_name, y = _select_signal(candidates, y_name, expected_n_channels, verbose)

    # asarray avoids copying what is usually the largest array in the file
    y = np.asarray(y)
    if y.ndim != 2:
        raise ValueError(f"Selected signal array {chosen_y_name} is not 2D, got shape {y.shape}")

    # if it's stored as (channels, samples), swap to (samples, channels)
    if (
        expected_n_channels is not None
        and y.shape[1] != expected_n_channels
        and y.shape[0] == expected_n_channels
    ):
        if verbose:
            print("Transposing y, channels are in rows, not columns.")
        y = y.T
    n_samples, n_channels = y.shape

    chosen_trig_name, trig = _select_trigger(candidates, trig_name, chosen_y_name, n_samples, verbose)

    # Normalise trig to shape (n_samples,). (n,), (n, 1) and (1, n) are accepted.
    # reshape (rather than relying on squeeze alone) also covers n_samples == 1.
    trig = np.asarray(trig).squeeze()
    if trig.ndim > 1 or trig.size != n_samples:
        raise ValueError(f"Trigger vector shape {trig.shape} does not match number of samples {n_samples}")
    trig = trig.reshape(n_samples)

    # build DataFrame
    if channel_names is not None:
        if len(channel_names) != n_channels:
            raise ValueError("channel_names length does not match number of channels")
        cols = list(channel_names)
    else:
        cols = [f"ch{i+1}" for i in range(n_channels)]

    df = pd.DataFrame(y, columns=cols)
    df["trig"] = trig

    meta = {
        "y_name": chosen_y_name,
        "y_shape": y.shape,
        "trig_name": chosen_trig_name,
        "trig_shape": trig.shape,
        "n_samples": n_samples,
        "n_channels": n_channels,
    }

    if verbose:
        print("Loaded signal", meta)

    # meta is returned regardless of return_meta; existing callers unpack it.
    return df, meta

In [25]:
# Variable names are known for this recording, so pass them explicitly rather
# than relying on auto-detection. A forward-slash path works on every OS and
# avoids the invalid "\P" escape sequence of a regular string literal.
DATA_PATH = "data/P1_pre_training.mat"
N_CHANNELS = 16

df, meta = load_mat_to_dataframe(
    DATA_PATH,
    y_name="y",
    trig_name="trig",
    expected_n_channels=N_CHANNELS,
    channel_names=[f"ch{i+1}" for i in range(N_CHANNELS)],
    return_meta=True,
)
meta


{'y_name': 'y', 'y_shape': (271816, 16), 'trig_name': 'trig', 'trig_shape': (271816,), 'n_samples': 271816, 'n_channels': 16}


In [38]:
# first five samples (head() defaults to n=5)
df.head()

Unnamed: 0,ch1,ch2,ch3,ch4,ch5,ch6,ch7,ch8,ch9,ch10,ch11,ch12,ch13,ch14,ch15,ch16,trig
0,38001.839844,36892.0,46397.214844,50350.613281,37726.980469,57036.867188,46999.265625,47667.355469,47781.109375,39447.503906,49432.898438,48651.296875,49362.050781,47384.273438,38979.992188,61820.179688,0
1,54828.550781,54868.15625,52444.15625,52408.507812,54521.210938,50932.191406,52738.488281,52781.171875,52272.847656,54376.527344,52604.574219,52560.785156,52507.796875,52978.863281,54474.457031,50255.472656,0
2,-24629.253906,-24865.400391,-22964.140625,-22914.605469,-24873.259766,-22113.113281,-23516.539062,-23355.142578,-22624.40625,-24587.474609,-23058.496094,-23198.757812,-23129.423828,-23339.222656,-24563.939453,-21289.884766,0
3,-77839.5625,-78191.210938,-76574.710938,-76846.210938,-78150.195312,-76362.351562,-77222.789062,-77171.90625,-76289.273438,-78103.726562,-77080.8125,-77045.546875,-76983.515625,-77237.015625,-77972.484375,-75860.140625,0
4,-48999.753906,-49155.601562,-49435.792969,-49648.988281,-49283.152344,-50093.390625,-49608.496094,-49621.535156,-49217.925781,-49437.320312,-49688.796875,-50009.597656,-49660.007812,-49592.742188,-49244.277344,-50159.144531,0


In [52]:
# Sanity check: one row per sample, one column per channel plus the trigger column.
assert df.shape == (meta["n_samples"], meta["n_channels"] + 1), "unexpected DataFrame shape"
df.shape

(271816, 17)

In [55]:
# DataFrame.keys() returns the column Index
df.keys()

Index(['ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'ch6', 'ch7', 'ch8', 'ch9', 'ch10',
       'ch11', 'ch12', 'ch13', 'ch14', 'ch15', 'ch16', 'trig'],
      dtype='object')

In [51]:
# summary statistics per column (quartiles spelled out explicitly)
df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,ch1,ch2,ch3,ch4,ch5,ch6,ch7,ch8,ch9,ch10,ch11,ch12,ch13,ch14,ch15,ch16,trig
count,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0,271816.0
mean,-158.634849,-62.401942,-393.462122,-1867.455619,-451.789993,-510.357172,-424.919882,-445.915095,-398.677113,-434.922062,-385.399846,-2563.940475,-277.20857,-317.877191,-379.092799,-358.594691,0.0
std,34555.736787,38064.502172,34636.156238,43374.740135,34809.260035,34773.400111,34697.77947,34865.721134,34670.438628,34757.789417,34809.551029,43789.335316,33992.150741,33911.691735,34098.53245,34048.760973,0.776378
min,-403234.1875,-383352.09375,-404016.25,-455808.28125,-406625.59375,-405575.15625,-406884.75,-406483.875,-402868.875,-403808.78125,-403821.9375,-457645.71875,-395735.65625,-395571.3125,-394307.9375,-394534.65625,-1.0
25%,232.64167,64.461695,-37.606057,37.391302,-18.758564,-70.169296,18.684913,-22.009736,25.188777,-84.171661,-3.021342,-42.622512,60.759815,0.082577,-33.289404,-17.288617,-1.0
50%,256.634674,88.861897,-10.964761,59.308496,1.687507,-47.533743,43.577663,0.426885,50.607193,-60.249186,15.840919,-24.089332,87.887077,33.465153,-15.59593,0.223157,0.0
75%,370.388687,239.138687,136.366196,89.040525,108.721699,57.051946,177.577412,89.039349,145.125107,79.729914,105.662832,38.486605,227.555447,198.336864,43.761347,105.226824,1.0
max,358119.8125,345198.8125,358821.8125,376052.4375,359421.9375,358761.125,359550.25,359233.46875,357844.65625,357395.71875,357450.28125,373367.25,355197.5,354565.71875,354573.9375,353794.15625,1.0


In [32]:
# number of distinct (non-NaN) trigger values
df["trig"].dropna().unique().size

3

In [56]:
# samples labelled with trigger value -1
df.loc[df["trig"].eq(-1)]

Unnamed: 0,ch1,ch2,ch3,ch4,ch5,ch6,ch7,ch8,ch9,ch10,ch11,ch12,ch13,ch14,ch15,ch16,trig
65838,599.143921,372.026001,281.206818,145.661758,175.493774,291.693695,356.972168,230.760513,234.908661,256.582001,192.171509,132.183365,448.610291,523.696716,158.340958,197.838867,-1
65839,590.737305,361.050690,268.419586,130.823792,150.245514,268.407043,332.262146,216.259171,222.842331,244.881943,177.966324,125.428322,435.948364,501.930084,150.129852,192.365173,-1
65840,611.666809,376.292084,288.996887,151.781876,167.275955,289.666931,357.749603,237.573914,243.092819,262.182739,200.732285,146.401184,453.051086,522.689636,162.566895,192.647720,-1
65841,595.159241,361.700928,268.992371,130.855682,155.956116,275.761078,342.734894,218.474091,221.565277,246.658127,178.025574,109.876862,432.919128,486.806580,154.137604,190.529251,-1
65842,600.869507,374.406647,283.435883,150.169312,166.620850,277.407806,348.804474,234.009842,241.137817,262.942566,199.974701,141.113037,454.170380,511.270294,165.405884,190.998077,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
271795,242.781525,67.049637,-19.068010,70.790031,-19.983019,-56.141350,12.221759,-15.530634,35.991051,-80.896461,7.147117,-39.770390,64.709167,-7.892024,-22.693344,-13.952981,-1
271796,215.742676,53.115944,-36.236279,55.332569,-31.691547,-86.276329,-3.875594,-34.907341,18.115387,-94.078728,-6.348837,-48.067036,49.096142,-6.562933,-33.956944,-22.609144,-1
271797,227.186951,65.063400,-25.860865,65.729813,-30.707947,-76.613670,2.796699,-21.934059,29.927046,-83.115623,4.164769,-50.817680,62.357845,-5.930853,-26.676371,-19.922352,-1
271798,224.690292,57.086212,-30.261307,60.743656,-20.344366,-78.378143,2.439494,-28.684645,21.184464,-91.093040,-6.866218,-42.286804,53.596783,-6.648259,-28.449129,-19.770422,-1


In [59]:
# Count trigger == -1 samples directly. DataFrame.value_counts() groups identical
# rows over all columns and, with its default dropna=True, skips rows containing
# NaN, so summing it is slower and not guaranteed to equal the matching-row count.
int(df["trig"].eq(-1).sum())

81920

In [60]:
# samples labelled with trigger value 1
df.loc[df["trig"].eq(1)]

Unnamed: 0,ch1,ch2,ch3,ch4,ch5,ch6,ch7,ch8,ch9,ch10,ch11,ch12,ch13,ch14,ch15,ch16,trig
78019,538.407776,284.577942,217.008270,144.951645,102.956032,160.371689,266.511169,135.629776,176.695282,122.448891,121.219856,52.823376,281.868927,304.229675,63.344437,95.076012,1
78020,550.695740,295.105438,225.427490,153.456604,99.815903,172.223312,273.248962,148.608765,185.736984,130.632568,122.309441,47.292412,292.681396,307.453430,64.446022,81.637375,1
78021,563.211975,303.441254,235.341415,162.252563,95.773018,200.389191,287.629272,156.012405,195.468185,138.800293,133.540527,62.812519,301.962921,332.096161,71.698067,100.078918,1
78022,540.694275,282.758179,214.209747,141.492157,92.394684,181.384933,265.352020,133.570923,168.979721,119.109192,118.248795,63.884270,274.851990,311.895203,55.679478,103.864632,1
78023,564.948181,310.085175,243.840195,170.754883,107.687920,210.079819,292.598755,161.604614,201.584869,145.799149,142.130280,67.299446,307.015381,348.632050,74.334740,106.812218,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
259019,233.126755,63.924778,-33.429794,60.262287,-31.576845,-69.694893,21.729639,-15.273890,31.474510,-88.887428,7.060484,-34.103172,60.574394,-3.637934,-16.947046,-6.880708,1
259020,235.115479,64.153412,-39.183231,59.895462,-28.405546,-69.599243,22.252895,-13.396914,30.261585,-90.566498,2.993686,-41.077267,61.183144,-5.229909,-16.334063,-5.503847,1
259021,234.167313,62.850479,-39.086857,63.098255,-34.388950,-57.915405,23.725485,-11.671236,33.615913,-90.613136,4.748469,-44.786137,61.070824,-1.697847,-22.236673,-12.567877,1
259022,223.164215,51.272194,-48.733780,51.004906,-38.531132,-68.939285,12.301775,-27.540352,17.515860,-101.280464,-4.824706,-46.925003,49.078785,-10.604307,-24.255993,-7.278707,1


In [61]:
# Count trigger == 1 samples directly. DataFrame.value_counts() groups identical
# rows over all columns and, with its default dropna=True, skips rows containing
# NaN, so summing it is slower and not guaranteed to equal the matching-row count.
int(df["trig"].eq(1).sum())

81920