In [4]:
import numpy as np
import scipy.io
from scipy import signal
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from scipy.interpolate import CubicSpline
from sklearn.linear_model import Ridge
from scipy.ndimage import gaussian_filter1d
from tqdm import tqdm
import joblib

# ========== Parameters ==========
# Sampling rate in Hz (also used as the target rate when upsampling predictions).
fs = 1000
# Sliding-window geometry, in samples: 100-sample windows overlapping by 50.
window_len, overlap = 100, 50
# Hop between consecutive window starts (samples).
step = window_len - overlap
# Number of consecutive windows stacked into each row of the R matrix.
N_wind = 3

# ========== Gaussian Smoothing ==========
def gaussian_smoothing(pred, sigma=2.25):
    """Smooth predictions along axis 0 with a 1-D Gaussian kernel.

    For a 2-D ``pred`` every column is filtered independently along the row
    axis, so rows are treated as successive time steps. In this script there
    is one row per feature window. A 1-D ``pred`` is filtered directly.

    Args:
        pred: Array of predictions with time along axis 0.
        sigma: Standard deviation of the Gaussian kernel, measured in rows.

    Returns:
        A new array with the same shape and dtype as ``pred``. The input
        array is not modified.
    """
    # One vectorised call along axis 0 replaces the per-column Python loop.
    # It gives the same result for 2-D input, and it also accepts 1-D input.
    return gaussian_filter1d(np.asarray(pred), sigma=sigma, axis=0)

# ========== Outlier Suppression ==========
def suppress_low_outliers(pred, threshold_multiplier=2):
    """Replace unusually low values with 0, one column at a time.

    For each column ``c`` of ``pred`` the cutoff is
    ``mean(c) - threshold_multiplier * std(c)``, where std is the population
    standard deviation (numpy's default ``ddof=0``). Entries strictly below
    the cutoff become 0; every other entry is kept as-is.

    Args:
        pred: 2-D array, one column per predicted output.
        threshold_multiplier: How many standard deviations below the column
            mean a value must fall to be zeroed.

    Returns:
        A modified copy of ``pred``; the input array is left untouched.
    """
    result = pred.copy()
    n_cols = pred.shape[1]
    for j in range(n_cols):
        column = pred[:, j]
        cutoff = np.mean(column) - threshold_multiplier * np.std(column)
        too_low = column < cutoff
        result[:, j] = np.where(too_low, 0, column)
    return result

# ========== Bandpass Filter ==========
def bandpass_filter(data, fs, lowcut, highcut, order=4):
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    The filter is designed with ``scipy.signal.butter`` (``btype='band'``)
    and run forward and backward with ``scipy.signal.filtfilt``, so the
    output has no phase shift.

    Args:
        data: Signal array; filtering is applied along axis 0.
        fs: Sampling rate in Hz.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz (must be below ``fs / 2``).
        order: Butterworth design order passed to ``butter``.

    Returns:
        The filtered array, with the same shape as ``data``.

    Note:
        With its default padding, ``filtfilt`` needs more than
        ``3 * max(len(a), len(b))`` samples along axis 0.
    """
    # NOTE(review): the (b, a) form can be numerically sensitive for narrow,
    # low-frequency bands. SciPy recommends output='sos' with sosfiltfilt.
    # Switching would alter features relative to the trained models, so any
    # change must be applied to training too.
    nyquist = fs / 2.0
    wn = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = signal.butter(order, wn, btype='band')
    return signal.filtfilt(numerator, denominator, data, axis=0)

# ========== Feature Extraction ==========
def get_features(window, fs=1000):
    """Compute per-channel features for one window of multichannel data.

    ``window`` is treated as ``(n_samples, n_channels)``: channels are read
    from ``window.shape[1]``, and every statistic is taken along axis 0.

    Returned columns:
      * 0: mean of the raw window per channel.
      * 1..5: mean absolute value of the band-passed window, one column per
        band in 5-15, 20-25, 75-115, 125-160 and 160-175 Hz, in that order.

    Args:
        window: 2-D array of samples by channels.
        fs: Sampling rate in Hz, forwarded to ``bandpass_filter``.

    Returns:
        Float array of shape ``(n_channels, 6)``.
    """
    bands = ((5, 15), (20, 25), (75, 115), (125, 160), (160, 175))
    out = np.zeros((window.shape[1], 1 + len(bands)))
    out[:, 0] = window.mean(axis=0)
    for col, (lo, hi) in enumerate(bands, start=1):
        filtered = bandpass_filter(window, fs, lo, hi)
        out[:, col] = np.abs(filtered).mean(axis=0)
    return out

# ========== Sliding Window ==========
def get_windowed_feats(ecog, fs, win_len, overlap):
    """Slide a window over ``ecog`` and extract features from each position.

    Window ``w`` covers rows ``[w * hop, w * hop + win_len)``, where
    ``hop = win_len - overlap``. The number of windows is
    ``(n_samples - overlap) // hop``. A window that would run past the end
    of the data stops the scan.

    Args:
        ecog: 2-D array of samples by channels.
        fs: Sampling rate in Hz, forwarded to ``get_features``.
        win_len: Window length in samples.
        overlap: Number of samples shared by consecutive windows.

    Returns:
        Array with one row per window. Each row is the row-major flattening
        of ``get_features`` for that window, i.e. all features of channel 0,
        then channel 1, and so on. If no window fits, the result is an empty
        1-D array.
    """
    hop = win_len - overlap
    n_samples = ecog.shape[0]
    n_windows = (n_samples - overlap) // hop
    rows = []
    for w in tqdm(range(n_windows), desc="Extracting features"):
        start = w * hop
        stop = start + win_len
        if stop > n_samples:
            break
        rows.append(get_features(ecog[start:stop, :], fs).flatten())
    return np.array(rows)

# ========== Create R Matrix ==========
def create_R_matrix(features, N_wind):
    """Build the lagged design matrix ``R`` used for linear decoding.

    Row ``i`` of ``R`` is ``[1, f[i-N_wind+1], ..., f[i-1], f[i]]``: a
    constant 1 followed by the feature vectors of the current window and the
    ``N_wind - 1`` windows before it, oldest first. History positions before
    the first window are filled with copies of ``features[0]``.

    Args:
        features: 2-D array of shape ``(n_windows, n_features)``.
        N_wind: Number of consecutive windows stacked into each row.

    Returns:
        float64 array of shape ``(n_windows, 1 + N_wind * n_features)``.
    """
    n_rows, _ = features.shape
    # Pad the front with the first window so early rows have a full history.
    history = np.vstack([np.tile(features[0], (N_wind - 1, 1)), features])
    # Slice k is history shifted by k rows. Placing the N_wind slices side by
    # side equals flattening history[i:i + N_wind] for every row i.
    lagged_blocks = [history[k:k + n_rows] for k in range(N_wind)]
    bias = np.ones((n_rows, 1))
    return np.hstack([bias] + lagged_blocks).astype(float, copy=False)


# ========== Load model and dataset ==========
test_data        = scipy.io.loadmat('truetest_data.mat')
test_ecogs       = test_data['truetest_data']
# NOTE: joblib.load unpickles arbitrary Python objects. Only load model
# files that come from a trusted source.
models           = joblib.load('ridge_models.joblib')

# One subject per row of the loaded cell array. The count is derived from the
# data rather than hard-coded, so any number of subjects is handled.
n_subjects       = test_ecogs.shape[0]

# ========== Get predictions ==========
# (n_subjects, 1) object array; savemat stores it as a MATLAB cell column.
predicted_dg     = np.empty((n_subjects, 1), dtype=object)

for subj_idx in range(n_subjects):
    print(f"\n=== Subject {subj_idx+1} ===")
    # load: each cell holds one subject's samples-by-channels recording
    ecog_te = test_ecogs[subj_idx].item()
    pipeline = models[subj_idx]

    # predict
    feats_te = get_windowed_feats(ecog_te, fs, window_len, overlap)
    R_te = create_R_matrix(feats_te, N_wind)
    # Column 0 of R is the constant bias column and is dropped here. This
    # presumably relies on the pipeline fitting its own intercept; confirm
    # against the training code.
    pred = pipeline.predict(R_te[:, 1:])

    # post-process
    pred = suppress_low_outliers(pred, threshold_multiplier=2)
    pred = gaussian_smoothing(pred, sigma=2.25)

    # upsample back to 1kHz (one value per original sample)
    # NOTE(review): prediction k is placed at window k's *start* sample
    # (k * step). Confirm this matches the alignment used for training targets.
    x_old = np.arange(pred.shape[0]) * step
    x_new = np.arange(ecog_te.shape[0])
    # A single spline fits every output column along axis 0. Samples after
    # the last window start are extrapolated from the final spline piece,
    # since CubicSpline extrapolates by default.
    interp = CubicSpline(x_old, pred, axis=0, bc_type='natural')(x_new)

    predicted_dg[subj_idx, 0] = interp


# save
scipy.io.savemat('predictions.mat',
                 {'predicted_dg': predicted_dg})
print("\n✅ Predictions saved to 'predictions.mat'")




=== Subject 1 ===


Extracting features: 100%|██████████| 2949/2949 [00:14<00:00, 208.50it/s]



=== Subject 2 ===


Extracting features: 100%|██████████| 2949/2949 [00:12<00:00, 230.96it/s]



=== Subject 3 ===


Extracting features: 100%|██████████| 2949/2949 [00:20<00:00, 142.07it/s]



✅ Predictions saved to 'predictions.mat'
