In [None]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
from brian2 import *
from brian2hears import *

In [None]:
def shape_padding(matrix, target_length=1000):
    """Force the second axis of a 2-D array to exactly ``target_length`` columns.

    Args:
        matrix: array-like with two dimensions (rows, columns).
        target_length: number of columns the result must have.

    Returns:
        The input (as an ndarray) unchanged when it already has the right
        width, a left-aligned slice when it is too wide, or a zero-padded
        copy (zeros appended on the right, same dtype) when it is too narrow.
    """
    arr = np.asarray(matrix)
    n_rows, n_cols = arr.shape
    missing = target_length - n_cols

    # Already the right width: hand the array back untouched.
    if missing == 0:
        return arr

    # Too wide: keep only the leading columns.
    if missing < 0:
        return arr[:, :target_length]

    # Too narrow: append zero columns of the same dtype on the right.
    filler = np.zeros((n_rows, missing), dtype=arr.dtype)
    return np.concatenate((arr, filler), axis=1)

def load_audio(file_path, sr=8000):
    """Read an audio file with ``librosa.load`` at the requested sample rate.

    Args:
        file_path: path of the audio file to read.
        sr: sample rate forwarded to ``librosa.load``.

    Returns:
        The ``(samples, sample_rate)`` pair produced by ``librosa.load``.
    """
    samples, sample_rate = librosa.load(file_path, sr=sr)
    return samples, sample_rate

def cochleagram_from_audio(
    y,
    sr,
    n_channels=8,
    lowpass_freq=10):
    """Turn a waveform into a peak-normalised, fixed-width cochleagram.

    Pipeline: gammatone filterbank on an ERB-spaced frequency grid ->
    half-wave rectification with cube-root compression -> low-pass
    smoothing -> division by the global absolute peak -> transpose to
    (channels, samples) -> pad/crop via ``shape_padding`` (default width).

    Args:
        y: 1-D waveform samples.
        sr: sample rate of ``y`` in Hz.
        n_channels: number of gammatone channels.
        lowpass_freq: cut-off handed to ``LowPass``.

    Returns:
        Array of shape (n_channels, 1000).

    Raises:
        ValueError: if ``sr`` is so low that the Nyquist frequency does not
            exceed the 100 Hz lower edge of the filterbank.
    """
    sound = Sound(y * Hz, sr * Hz)

    # A digital filter cannot represent centre frequencies above sr/2; asking
    # for them makes the upper channels alias onto lower frequencies. Keep the
    # 8 kHz ceiling only when the sample rate can actually support it.
    nyquist_hz = 0.5 * float(sr)
    if nyquist_hz <= 100.0:
        raise ValueError(
            f"sample rate {sr} Hz is too low for a filterbank starting at 100 Hz"
        )
    upper_hz = 8000.0 if nyquist_hz > 8000.0 else nyquist_hz

    cf = erbspace(100*Hz, upper_hz*Hz, n_channels)
    gammatone = Gammatone(
        sound,
        cf
    )
    # Half-wave rectify, then compress with a cube root.
    envelope = FunctionFilterbank(
        gammatone,
        lambda x: np.maximum(x, 0)**(1.0/3.0)
    )
    lowpassed = LowPass(envelope, lowpass_freq)

    cochleagram = np.array(lowpassed.process())

    # Normalise by the global peak, but leave silent/empty clips untouched
    # instead of dividing by zero and filling the result with NaN.
    peak = np.max(np.abs(cochleagram)) if cochleagram.size else 0.0
    if peak > 0:
        cochleagram = cochleagram / peak

    cochleagram = cochleagram.T  # Transpose to have channels as rows

    return shape_padding(cochleagram)

def lif_encoding(
    cochleagram,
    dt=1*ms,
    duration=1000*ms,
    tau=10*ms,
    threshold=1.1,
    refractory_period=5*ms,
):
    """Encode a cochleagram into spikes with one leaky integrator per channel.

    At every step the membrane value decays by ``exp(-dt / tau)`` and the
    current input sample is added. When the value reaches ``threshold`` it is
    reset to 0. A spike is recorded for that crossing unless the channel is
    still refractory; a recorded spike starts a refractory window of
    ``round(refractory_period / dt)`` time steps.

    Args:
        cochleagram: 2-D array, (channels, samples); one sample per step.
        dt: step size.
        duration: total encoded time; the output has ``round(duration / dt)``
            columns.
        tau: membrane time constant.
        threshold: firing threshold on the membrane value.
        refractory_period: time after a spike during which no further spike
            is emitted.

    Returns:
        Float array of shape (channels, round(duration / dt)) holding 0/1.
        If the input has fewer samples than steps, the remaining steps
        receive no input and therefore contain no spikes.
    """
    # round() instead of truncation so a ratio such as 999.9999 is not cut
    # to 999 steps.
    n_steps = int(round(float(duration / dt)))
    channels_num, n_samples = cochleagram.shape
    spike_matrix = np.zeros((channels_num, n_steps))

    decay = float(np.exp(-dt / tau))
    refractory_steps = int(round(float(refractory_period / dt)))

    # Never index past the end of a cochleagram shorter than the duration.
    usable_steps = n_steps if n_samples >= n_steps else n_samples

    for channel in range(channels_num):
        v = 0.0
        refractory_left = 0  # time steps still blocked after the last spike

        for t in range(usable_steps):
            v = decay * v + cochleagram[channel, t]
            crossed = v >= threshold
            if crossed:
                v = 0.0

            if refractory_left > 0:
                # The refractory clock advances on every time step, not only
                # on threshold crossings; crossings in this window are reset
                # but emit nothing.
                refractory_left -= 1
            elif crossed:
                spike_matrix[channel, t] = 1
                refractory_left = refractory_steps

    return spike_matrix

In [None]:
import os

# Encode the first 100 "yes" clips at sr=1000: clips 0-79 are written to the
# train folder and clips 80-99 to the test folder.
yes_src_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/speech_command_dataset_v2/dataset/yes"
yes_out_root = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/sr1000"

# np.savez_compressed does not create missing folders.
for split_name in ("train", "test"):
    os.makedirs(os.path.join(yes_out_root, split_name), exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting makes the choice of
# the 100 clips and the train/test assignment reproducible across runs.
yes_wavs = sorted(name for name in os.listdir(yes_src_dir) if name.endswith(".wav"))

for clip_idx, filepath in enumerate(yes_wavs[:100]):
    file_path = os.path.join(yes_src_dir, filepath)
    y, sr = load_audio(file_path, sr=1000)
    cochleagram = cochleagram_from_audio(y, sr)
    spike_matrix = lif_encoding(cochleagram, refractory_period=1*ms)
    print(f"Processed {filepath} into spike matrix with shape {spike_matrix.shape}")

    split_name = "train" if clip_idx < 80 else "test"
    np.savez_compressed(f"{yes_out_root}/{split_name}/yes_{filepath[:-4]}.npz", spike_matrix=spike_matrix, label=1)


In [None]:
# Encode the first 100 "no" clips at sr=1000: clips 0-79 are written to the
# train folder and clips 80-99 to the test folder.
no_src_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/speech_command_dataset_v2/dataset/no"
no_out_root = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/sr1000"

# np.savez_compressed does not create missing folders.
for split_name in ("train", "test"):
    os.makedirs(os.path.join(no_out_root, split_name), exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting makes the choice of
# the 100 clips and the train/test assignment reproducible across runs.
no_wavs = sorted(name for name in os.listdir(no_src_dir) if name.endswith(".wav"))

for clip_idx, filepath in enumerate(no_wavs[:100]):
    file_path = os.path.join(no_src_dir, filepath)
    y, sr = load_audio(file_path, sr=1000)
    cochleagram = cochleagram_from_audio(y, sr)
    spike_matrix = lif_encoding(cochleagram, refractory_period=1*ms)
    print(f"Processed {filepath} into spike matrix with shape {spike_matrix.shape}")

    split_name = "train" if clip_idx < 80 else "test"
    np.savez_compressed(f"{no_out_root}/{split_name}/no_{filepath[:-4]}.npz", spike_matrix=spike_matrix, label=0)

In [None]:
# Sanity check: reload one saved example and report its shape and label.
# The context manager closes the underlying .npz archive once both arrays
# have been read into memory.
with np.load("sr1000/train/no_0f7205ef_nohash_0.npz") as data:
    spike_matrix = data['spike_matrix']
    label = data['label'].item()

print(f"Spike matrix shape: {spike_matrix.shape}, Label: {label}")

In [9]:
# A different approach 
# === Stable audio→cochlea→LIF spikes with dataset-wide norm + target-rate tuning ===
import numpy as np
import librosa, librosa.display
import matplotlib.pyplot as plt
import os
import random

# --------------------------
# 0) Utilities
# --------------------------
def vad_trim(y, sr, top_db=30, pad_s=0.05):
    """Cut leading/trailing low-energy audio, keeping a small margin.

    Non-silent intervals come from ``librosa.effects.split``. The returned
    slice runs from the start of the first interval to the end of the last
    one, widened by ``pad_s`` seconds on each side and clamped to the signal
    bounds. If no interval is found the input is returned as is.

    Args:
        y: 1-D waveform.
        sr: sample rate of ``y`` in Hz.
        top_db: threshold forwarded to ``librosa.effects.split``.
        pad_s: margin, in seconds, kept on both sides.
    """
    voiced = librosa.effects.split(y, top_db=top_db)
    if len(voiced) == 0:
        return y  # nothing detected

    margin = int(pad_s*sr)
    first_sample = max(0, voiced[0, 0] - margin)
    last_sample = min(len(y), voiced[-1, 1] + margin)
    return y[first_sample:last_sample]

def compute_cochleagram(y, sr, n_ch=32, win_ms=25, hop_ms=10, fmin=80, fmax=None):
    """Mel-filterbank stand-in for a cochleagram.

    Computes a magnitude (``power=1.0``) mel spectrogram with a Hann window
    and HTK mel scale, then applies square-root amplitude compression. No
    per-file rescaling is performed here; values are non-negative but not
    confined to [0, 1]. A gammatone front end could be dropped in instead.

    Args:
        y: 1-D waveform.
        sr: sample rate of ``y`` in Hz.
        n_ch: number of mel bands.
        win_ms: analysis window length in milliseconds.
        hop_ms: hop between frames in milliseconds.
        fmin: lowest mel-band frequency in Hz.
        fmax: highest mel-band frequency in Hz (forwarded to librosa).

    Returns:
        float32 array, expected shape (n_ch, T) with T the number of frames.
    """
    win_length = int(sr*win_ms/1000)
    hop_length = int(sr*hop_ms/1000)
    # FFT size: smallest power of two that holds one analysis window.
    n_fft = 2**int(np.ceil(np.log2(win_length)))

    mel_mag = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window="hann",
        n_mels=n_ch,
        fmin=fmin,
        fmax=fmax,
        power=1.0,
        center=True,
        htk=True,
    )

    # Square-root compression; the epsilon keeps the argument strictly positive.
    compressed = np.sqrt(mel_mag + 1e-12)
    return compressed.astype(np.float32)

# --------------------------
# 1) Dataset-wide normalisation (fit on TRAIN only)
# --------------------------
def fit_cochlea_norm(train_files, sr=16000, **coch_kwargs):
    """Fit per-channel mean/std over every cochleagram frame of a file set.

    Each file is loaded at ``sr``, trimmed with ``vad_trim`` and converted
    with ``compute_cochleagram(**coch_kwargs)``. Running sums of the values
    and of their squares are pooled over all frames of all files, so the
    resulting std includes the variation between files (averaging per-file
    variances would leave that out) and longer files contribute
    proportionally more frames.

    Args:
        train_files: iterable of audio file paths.
        sr: sample rate used for loading.
        **coch_kwargs: forwarded to ``compute_cochleagram`` and stored in the
            result so later encoding can reuse the same settings.

    Returns:
        dict with ``"mean"`` and ``"std"`` (float32, one entry per channel),
        ``"sr"`` and ``"coch_kwargs"``.

    Raises:
        ValueError: if no cochleagram frame could be collected (for example
            when ``train_files`` is empty).
    """
    sum_x, sum_x2 = None, None
    n_frames = 0

    for p in train_files:
        y, _sr = librosa.load(p, sr=sr, mono=True)
        y = vad_trim(y, sr)
        # float64 accumulation keeps the sum-of-squares formula accurate.
        C = compute_cochleagram(y, sr, **coch_kwargs).astype(np.float64)  # (n_ch, T)
        if sum_x is None:
            sum_x = np.zeros(C.shape[0], dtype=np.float64)
            sum_x2 = np.zeros(C.shape[0], dtype=np.float64)
        sum_x += C.sum(axis=1)
        sum_x2 += np.square(C).sum(axis=1)
        n_frames += C.shape[1]

    if n_frames == 0:
        raise ValueError("fit_cochlea_norm: no cochleagram frames found in train_files")

    mean = sum_x / n_frames
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
    var = np.maximum(sum_x2 / n_frames - np.square(mean), 0.0)
    std = np.sqrt(var + 1e-8)
    return {
        "mean": mean.astype(np.float32),
        "std": std.astype(np.float32),
        "sr": sr,
        "coch_kwargs": coch_kwargs,
    }

def apply_norm(C, stats, per_channel=True):
    """Z-score a cochleagram, then squash it into (0, 1) with a logistic sigmoid.

    Args:
        C: 2-D array (channels, frames).
        stats: dict holding per-channel ``"mean"`` and ``"std"`` arrays; only
            read when ``per_channel`` is true.
        per_channel: if true, standardise each channel with the supplied
            statistics; otherwise standardise with the global mean and std
            of ``C`` itself.

    Returns:
        float32 array with the same shape as ``C``.
    """
    if per_channel:
        center = stats["mean"][:, None]
        spread = stats["std"][:, None] + 1e-8
    else:
        center = C.mean()
        spread = C.std() + 1e-8

    z = (C - center) / spread
    # Logistic sigmoid maps the z-scores into the open interval (0, 1).
    squashed = 1/(1+np.exp(-z))
    return squashed.astype(np.float32)

# --------------------------
# 2) LIF with refractory + simple adaptation
# --------------------------
def lif_encode(C01, dt, tau_m=0.02, v0=0.15, gain=1.0,
               refrac_s=0.003, a_jump=0.12, tau_a=0.08):
    """Vectorised LIF encoder with refractory period and adaptive threshold.

    Per step and channel: the membrane value relaxes towards ``gain * input``
    with time constant ``tau_m`` (only when the channel is not refractory),
    the adaptation variable decays with ``tau_a``, and a spike is emitted
    when the membrane exceeds ``v0 * (1 + adaptation)``. A spike resets the
    membrane to 0, raises the adaptation by ``a_jump`` and blocks the channel
    for the next ``round(refrac_s / dt)`` steps.

    Args:
        C01: array (n_ch, T), intended to lie in [0, 1].
        dt: step size in seconds.
        tau_m: membrane time constant in seconds.
        v0: base threshold.
        gain: multiplier applied to the input drive.
        refrac_s: refractory period in seconds.
        a_jump: adaptation increment per spike.
        tau_a: adaptation time constant in seconds.

    Returns:
        uint8 array (n_ch, T) of 0/1 spikes.
    """
    n_ch, T = C01.shape
    dt = float(dt)
    v = np.zeros(n_ch, dtype=np.float32)
    a = np.zeros(n_ch, dtype=np.float32)
    refrac = np.zeros(n_ch, dtype=np.float32)
    spikes = np.zeros((n_ch, T), dtype=np.uint8)

    # precompute constants
    decay_m = np.exp(-dt/tau_m)
    decay_a = np.exp(-dt/tau_a)
    refrac_steps = int(round(refrac_s/dt))

    for t in range(T):
        drive = gain * C01[:, t]  # [0, gain]
        # integrate only if not refractory
        active = (refrac <= 0.5)
        v[active] = v[active]*decay_m + (1-decay_m)*drive[active]
        a = a*decay_a
        thr = v0 * (1.0 + a)  # adaptive threshold

        # spike
        s = (v > thr) & active
        spikes[s, t] = 1
        # reset + adaptation
        v[s] = 0.0
        a[s] += a_jump

        # Tick only the timers of channels that were already refractory this
        # step, then arm the channels that just fired. Arming first and
        # decrementing everything afterwards would consume one refractory
        # step immediately and shorten the period by one.
        refrac[~active] -= 1.0
        refrac[s] = refrac_steps
    return spikes

# --------------------------
# 3) Target firing-rate tuning loop (per file) – keeps params within bounds
# --------------------------
def encode_with_rate_target(Cn, dt, target_hz=8.0, tol=1.0, max_iter=6,
                            init_pct=0.80, init_gain=1.0,
                            pct_bounds=(0.60,0.95), gain_bounds=(0.6, 3.5),
                            **lif_kw):
    """Scale the cochleagram and run ``lif_encode`` until the rate is near target.

    Each iteration divides ``Cn`` by one of its percentiles, clips to
    [0, 1], encodes with the current gain and measures the mean firing rate
    (spikes per channel per second). If the rate is outside
    ``target_hz +/- tol`` the percentile and gain are nudged within their
    bounds and the encoding is retried.

    Args:
        Cn: normalised cochleagram (n_ch, T).
        dt: step size in seconds.
        target_hz: desired mean firing rate.
        tol: accepted deviation from ``target_hz``.
        max_iter: maximum number of attempts; at least one is always made.
        init_pct: initial percentile expressed as a fraction.
        init_gain: initial gain.
        pct_bounds: (min, max) allowed percentile fraction.
        gain_bounds: (min, max) allowed gain.
        **lif_kw: forwarded to ``lif_encode``.

    Returns:
        ``(S, info)`` where ``S`` is the spike matrix of the last attempt and
        ``info`` holds the ``"pct"``, ``"gain"`` and ``"rate"`` that belong to
        exactly that attempt.
    """
    pct, gain = init_pct, init_gain
    # scale by percentile per FILE to reduce outliers (mild)
    flat = Cn.ravel()

    S, info = None, None
    # Always encode at least once so S is defined even for max_iter <= 0.
    for _ in range(max(1, int(max_iter))):
        scale = np.percentile(flat, pct*100)
        C01 = np.clip(Cn/ (scale + 1e-8), 0, 1)
        S = lif_encode(C01, dt, gain=gain, **lif_kw)
        rate = S.sum() / (S.shape[0]*S.shape[1]*dt)  # Hz

        # Snapshot the parameters that produced this S *before* adjusting
        # them, so the reported values always describe the returned spikes.
        info = {"pct":pct, "gain":gain, "rate":rate}

        if rate < target_hz - tol:
            pct = max(pct_bounds[0], pct - 0.03)
            gain = min(gain_bounds[1], gain * 1.15)
        elif rate > target_hz + tol:
            pct = min(pct_bounds[1], pct + 0.03)
            gain = max(gain_bounds[0], gain * 0.85)
        else:
            break

    return S, info

# --------------------------
# 4) End-to-end: encode one file with pre-fitted stats, then pad/crop the spikes
# --------------------------
def encode_file(path, stats,
                target_hz=8.0, tol=1.0,
                win_ms=25, hop_ms=10,
                init_pct=0.95, init_gain=0.4,
                v0=0.15, tau_m=0.02, refrac_s=0.003, a_jump=0.12, tau_a=0.08):
    """Encode one audio file into spikes using previously fitted statistics.

    Steps: load at ``stats["sr"]`` -> ``vad_trim`` -> ``compute_cochleagram``
    -> ``apply_norm`` (per channel) -> ``encode_with_rate_target``.

    The window/hop settings stored in ``stats["coch_kwargs"]`` take
    precedence; ``win_ms`` and ``hop_ms`` are used only for settings that are
    missing there. The LIF step size is derived from the hop that was
    actually used for the cochleagram, so the frame spacing and ``dt``
    cannot disagree.

    Args:
        path: audio file path.
        stats: dict returned by ``fit_cochlea_norm``.
        target_hz, tol: firing-rate target and tolerance.
        win_ms, hop_ms: fallback window/hop in milliseconds.
        init_pct, init_gain: starting point of the rate-tuning loop.
        v0, tau_m, refrac_s, a_jump, tau_a: forwarded to the LIF encoder.

    Returns:
        ``(C, Cn, S, info, dt)``: raw cochleagram, normalised cochleagram,
        spike matrix, tuning info and step size in seconds.
    """
    sr = stats["sr"]
    y, _ = librosa.load(path, sr=sr, mono=True)
    y = vad_trim(y, sr)

    # Arguments act as defaults; values recorded at fit time override them so
    # the features match the ones the statistics were computed on.
    coch_kwargs = {"win_ms": win_ms, "hop_ms": hop_ms, **stats["coch_kwargs"]}
    C = compute_cochleagram(y, sr, **coch_kwargs)  # (n_ch, T)
    Cn = apply_norm(C, stats, per_channel=True)

    # One cochleagram frame per LIF step.
    dt = coch_kwargs["hop_ms"]/1000.0
    S, info = encode_with_rate_target(
        Cn, dt, target_hz=target_hz, tol=tol,
        init_pct=init_pct, init_gain=init_gain,
        tau_m=tau_m, v0=v0, refrac_s=refrac_s, a_jump=a_jump, tau_a=tau_a
    )
    return C, Cn, S, info, dt

def pad_or_crop_spike(S, T_target=100, align="left"):
    """Pad with zeros or crop a spike matrix to ``T_target`` columns.

    Args:
        S: 2-D array (n_ch, T).
        T_target: required number of columns.
        align: ``"left"`` or ``"right"`` keeps the data flush with that edge;
            any other value centres it.

    Returns:
        ``S`` itself when it already has ``T_target`` columns, a slice of
        ``S`` when cropping, or a new zero-padded array (same dtype) when
        padding.
    """
    n_ch, T = S.shape
    if T == T_target:
        return S

    # Work with an explicit start offset instead of negative slices: `-0:`
    # means "everything", which breaks right alignment when T or T_target
    # is 0.
    slack = abs(T_target - T)
    if align == "left":
        start = 0
    elif align == "right":
        start = slack
    else:
        start = slack // 2

    if T < T_target:
        out = np.zeros((n_ch, T_target), dtype=S.dtype)
        out[:, start:start+T] = S
        return out

    return S[:, start:start+T_target]

In [16]:
# Collect every .wav clip of both classes and fit the normalisation on them.
# Non-audio directory entries are skipped so librosa.load is never handed
# one, and the listing is sorted so the file order is the same on every run.
# NOTE(review): the statistics are fitted on all clips, including any that
# are later held out for testing — confirm this is intended.
dataset_root = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/speech_command_dataset_v2/dataset"

trainfiles = []
for class_name in ("yes", "no"):
    class_dir = os.path.join(dataset_root, class_name)
    trainfiles.extend(
        os.path.join(class_dir, name)
        for name in sorted(os.listdir(class_dir))
        if name.endswith(".wav")
    )

stats = fit_cochlea_norm(trainfiles, sr=16000, n_ch=64, win_ms=25, hop_ms=10, fmin=80, fmax=7600)

In [17]:
# Encode up to 200 randomly chosen "yes" clips into 64-channel spike matrices.
yes_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/speech_command_dataset_v2/dataset/yes"
yes_out_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/ch64"
os.makedirs(yes_out_dir, exist_ok=True)  # np.savez_compressed will not create it

# Filter to .wav files *before* sampling so exactly 200 clips are drawn
# whenever that many exist; sort + a locally seeded generator make the
# selection reproducible without touching the global random state.
yes_wavs = sorted(name for name in os.listdir(yes_dir) if name.endswith(".wav"))
yes_rng = random.Random(0)

for filepath in yes_rng.sample(yes_wavs, min(200, len(yes_wavs))):
    file_path = os.path.join(yes_dir, filepath)
    C, Cn, S, info, dt = encode_file(file_path, stats, target_hz=8.0, tol=1.0,
                                     init_gain=0.6, init_pct=0.95,
                                     v0=0.29, tau_m=0.015, refrac_s=0.003, a_jump=0.42, tau_a=0.07)
    spike_matrix = pad_or_crop_spike(S, T_target=100, align="left")

    np.savez_compressed(f"{yes_out_dir}/yes_{filepath[:-4]}.npz", spike_matrix=spike_matrix, label=1)

In [18]:
# Encode up to 200 randomly chosen "no" clips into 64-channel spike matrices.
no_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/speech_command_dataset_v2/dataset/no"
no_out_dir = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/ch64"
os.makedirs(no_out_dir, exist_ok=True)  # np.savez_compressed will not create it

# Filter to .wav files *before* sampling so exactly 200 clips are drawn
# whenever that many exist; sort + a locally seeded generator make the
# selection reproducible without touching the global random state.
no_wavs = sorted(name for name in os.listdir(no_dir) if name.endswith(".wav"))
no_rng = random.Random(0)

for filepath in no_rng.sample(no_wavs, min(200, len(no_wavs))):
    file_path = os.path.join(no_dir, filepath)
    C, Cn, S, info, dt = encode_file(file_path, stats, target_hz=8.0, tol=1.0,
                                     init_gain=0.6, init_pct=0.95,
                                     v0=0.29, tau_m=0.015, refrac_s=0.003, a_jump=0.42, tau_a=0.07)
    spike_matrix = pad_or_crop_spike(S, T_target=100, align="left")

    np.savez_compressed(f"{no_out_dir}/no_{filepath[:-4]}.npz", spike_matrix=spike_matrix, label=0)
