In [1]:
import numpy as np
import scipy
# import scipy._lib.array_api_extra as xpx

In [45]:
def atleast_nd(x, ndim):
    """Return a fresh ndarray copy of `x` with at least `ndim` dimensions.

    Missing dimensions are inserted as length-1 axes at the front; arrays that
    already have `ndim` or more dimensions keep their shape.
    """
    arr = np.array(x)  # always a new ndarray (copy of the input)
    missing = ndim - arr.ndim
    if missing > 0:
        # Prepend `missing` new axes in one indexing operation.
        arr = arr[(np.newaxis,) * missing + (Ellipsis,)]
    return arr

def xp_size(x):
    """Return ``len(x)``: the length of the FIRST axis, not the total element count.

    Note: unlike SciPy's ``xp_size`` (which returns ``x.size``), this helper is
    length-of-first-axis; ``solve`` relies on that to get the matrix order.
    """
    length = len(x)
    return length

def xp_copy(x):
    """Return an independent ndarray copy of `x` (base-class ndarray, order 'K')."""
    return np.copy(x)

def companion(a):
    """Build the companion matrix of polynomial coefficients `a` (batched over leading axes).

    For ``a`` of shape ``(..., n)`` the result has shape ``(..., n-1, n-1)``:
    the first row is ``-a[..., 1:] / a[..., 0]`` and the sub-diagonal is ones.

    Raises
    ------
    ValueError
        If the last axis of `a` is shorter than 2, or any ``a[..., 0]`` is zero.
    """
    coeffs = np.atleast_1d(a)
    order = coeffs.shape[-1]

    if order < 2:
        raise ValueError("The length of `a` along the last axis must be at least 2.")

    leading = coeffs[..., 0]
    if np.any(leading == 0):
        raise ValueError("The first coefficient(s) of `a` (i.e. elements "
                         "of `a[..., 0]`) must not be zero.")

    # Multiplying by 1.0 forces a floating-point division result.
    top_row = -coeffs[..., 1:] / (1.0 * coeffs[..., 0:1])
    size = order - 1
    result = np.zeros(coeffs.shape[:-1] + (size, size), dtype=top_row.dtype)
    result[..., 0, :] = top_row
    rows = np.arange(1, size)
    result[..., rows, rows - 1] = 1  # ones on the first sub-diagonal
    return result

def solve(A, b):
    """Solve the linear system ``A @ x = b`` by Gaussian elimination with partial pivoting.

    Pure-NumPy stand-in for ``np.linalg.solve`` for a single (non-batched) system.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Square coefficient matrix. Never modified.
    b : array_like, shape (n,) or (n, k)
        Right-hand side(s). Never modified.

    Returns
    -------
    x : ndarray
        Solution with the same shape as `b`. Integer/bool inputs are promoted
        to float64; floating/complex inputs keep their common dtype.

    Raises
    ------
    ValueError
        If `A` is not a square 2-D matrix or `b` does not have ``n`` rows.
    numpy.linalg.LinAlgError
        If `A` is singular (an exactly-zero pivot is encountered).
    """
    A = np.asarray(A)
    b = np.asarray(b)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("A must be a square 2-D matrix.")
    n = A.shape[0]
    if b.ndim < 1 or b.shape[0] != n:
        raise ValueError("b must have the same number of rows as A.")

    # Work on private copies: slicing an ndarray row with [:] only makes a
    # view, so the previous version silently overwrote the caller's arrays.
    # Integer dtypes are promoted so the divisions below are well defined.
    dtype = np.result_type(A, b)
    if dtype.kind in 'biu':
        dtype = np.dtype(np.float64)
    M = np.array(A, dtype=dtype)
    B = np.array(b, dtype=dtype)
    # Reshape row multipliers so they broadcast over any trailing axes of B.
    trailing = (1,) * (B.ndim - 1)

    # Forward elimination.
    for i in range(n):
        # Partial pivoting: bring the largest-magnitude entry of column i up.
        p = i + int(np.argmax(np.abs(M[i:, i])))
        if M[p, i] == 0:
            raise np.linalg.LinAlgError("Singular matrix")
        if p != i:
            M[[i, p]] = M[[p, i]]
            B[[i, p]] = B[[p, i]]
        factors = M[i + 1:, i] / M[i, i]
        M[i + 1:, i:] -= np.outer(factors, M[i, i:])
        B[i + 1:] -= factors.reshape((-1,) + trailing) * B[i]

    # Back substitution on the resulting upper-triangular system.
    x = np.zeros_like(B)
    for i in range(n - 1, -1, -1):
        x[i] = (B[i] - M[i, i + 1:] @ x[i + 1:]) / M[i, i]
    return x

def _sosfilt_float(sos, x, zi):
    n_signals, n_samples = x.shape
    n_sections = sos.shape[0]

    const_1 = 1.0

    for i in range(n_signals):
        zi_slice = zi[i, :, :]
        for n in range(n_samples):
            # 确保是拷贝
            x_cur = const_1 * x[i, n]

            for s in range(n_sections):
                x_new = sos[s, 0] * x_cur + zi_slice[s, 0]
                zi_slice[s, 0] = (sos[s, 1] * x_cur - sos[s, 4] * x_new
                                  + zi_slice[s, 1])
                zi_slice[s, 1] = sos[s, 2] * x_cur - sos[s, 5] * x_new
                x_cur = x_new

            x[i, n] = x_cur

def sosfilt(sos, x, axis=-1, zi=None):
    """Filter `x` along `axis` with a cascade of second-order sections.

    Pure-NumPy counterpart of ``scipy.signal.sosfilt`` built on ``_sosfilt_float``.

    Parameters
    ----------
    sos : array_like, shape (n_sections, 6)
        Section coefficients; ``sos[:, 3]`` must be all ones.
    x : array_like
        Input signal(s), at least 1-D. Not modified.
    axis : int, optional
        Axis of `x` along which to filter.
    zi : array_like, optional
        Initial section states with shape ``(n_sections,) + x.shape`` where
        ``x.shape[axis]`` is replaced by 2.

    Returns
    -------
    y : ndarray
        Filtered output with the same shape as `x`.
    zf : ndarray
        Final states, same shape as `zi`; only returned when `zi` is given
        (the result is then the tuple ``(y, zf)``).

    Raises
    ------
    ValueError
        If `zi` has the wrong shape, or `x` / `sos` fail validation.
    """
    x = _validate_x(x)
    sos, n_sections = _validate_sos(sos)
    x_zi_shape = list(x.shape)
    x_zi_shape[axis] = 2
    x_zi_shape = tuple([n_sections] + x_zi_shape)
    inputs = [sos, x]
    if zi is not None:
        inputs.append(np.asarray(zi))
    dtype = np.result_type(*inputs)
    if dtype.kind in 'biu':
        # Integer/bool state would overflow (wrap around) silently inside the
        # recursion; filter in double precision instead.
        dtype = np.dtype(np.float64)
    if zi is not None:
        zi = np.asarray(zi, dtype=dtype)
        zi = xp_copy(zi)
        if zi.shape != x_zi_shape:
            raise ValueError(
                f"Invalid zi shape. With axis={axis!r}, "
                f"an input with shape {x.shape!r}, "
                f"and an sos array with {n_sections} sections, zi must have "
                f"shape {x_zi_shape!r}, got {zi.shape!r}."
            )
        return_zi = True
    else:
        zi = np.zeros(x_zi_shape, dtype=dtype)
        return_zi = False
    axis = axis % x.ndim  # make positive
    # Bring the filter axis last and flatten everything else into rows so the
    # kernel sees (n_signals, n_samples) and (n_signals, n_sections, 2).
    x = np.moveaxis(x, axis, -1)
    zi = np.moveaxis(zi, (0, axis + 1), (-2, -1))
    x_shape, zi_shape = x.shape, zi.shape
    x = np.reshape(x, (-1, x.shape[-1]))
    x = np.array(x, dtype, order='C')  # make a copy, can modify in place
    zi = np.ascontiguousarray(np.reshape(zi, (-1, n_sections, 2)))
    sos = sos.astype(dtype, copy=False)
    _sosfilt_float(sos, x, zi)
    # Undo the reshaping / axis moves.
    x = x.reshape(x_shape)
    x = np.moveaxis(x, -1, axis)
    if return_zi:
        zi = zi.reshape(zi_shape)
        zi = np.moveaxis(zi, (-2, -1), (0, axis + 1))
        out = (x, zi)
    else:
        out = x
    return out

def lfilter_zi(b, a):
    """Compute the initial state vector ``zi`` for the direct-form IIR filter (b, a).

    Mirrors ``scipy.signal.lfilter_zi``. The result solves ``zi = A @ zi + B``,
    i.e. ``(I - A) @ zi = B``. Here ``A`` is the transposed companion matrix of
    the normalized denominator and ``B = b[1:] - a[1:] * b[0]``.

    Parameters
    ----------
    b : array_like
        Numerator coefficients (1-D).
    a : array_like
        Denominator coefficients (1-D); leading zeros are ignored.

    Returns
    -------
    zi : ndarray, 1-D
        Floating-point state vector of length ``n - 1``, where ``n`` is the
        padded filter length. It is empty for a pure-gain filter.

    Raises
    ------
    ValueError
        If `b` or `a` is not 1-D, or `a` has no nonzero coefficient.
    """
    b = atleast_nd(np.asarray(b), ndim=1)
    if b.ndim != 1:
        raise ValueError("Numerator b must be 1-D.")
    a = atleast_nd(np.asarray(a), ndim=1)
    if a.ndim != 1:
        raise ValueError("Denominator a must be 1-D.")

    # Drop leading zeros of the denominator, keeping at least one element.
    while a.shape[0] > 1 and a[0] == 0.0:
        a = a[1:]
    # An empty `a` or a single remaining zero has no usable leading
    # coefficient; normalizing by it below would produce inf/nan.
    if xp_size(a) < 1 or a[0] == 0.0:
        raise ValueError("There must be at least one nonzero `a` coefficient.")

    if a[0] != 1.0:
        # Normalize the coefficients so a[0] == 1.
        b = b / a[0]
        a = a / a[0]

    n = max(a.shape[0], b.shape[0])

    # Pad a or b with zeros so they are the same length.
    # (np.concatenate works on every NumPy version; np.concat needs >= 2.0.)
    if a.shape[0] < n:
        a = np.concatenate((a, np.zeros(n - a.shape[0], dtype=a.dtype)))
    elif b.shape[0] < n:
        b = np.concatenate((b, np.zeros(n - b.shape[0], dtype=b.dtype)))

    # Always solve in floating point: integer coefficients would otherwise
    # break the in-place elimination inside `solve`.
    dt = np.result_type(a, b)
    if dt.kind in 'biu':
        dt = np.dtype(np.float64)

    if n == 1:
        # Pure gain (no delay elements): there is no state to initialize.
        return np.zeros(0, dtype=dt)

    IminusA = np.asarray(np.eye(n - 1) - companion(a).T, dtype=dt)
    B = np.asarray(b[1:] - a[1:] * b[0], dtype=dt)
    # Solve zi = A*zi + B, i.e. (I - A) zi = B.
    zi = solve(IminusA, B)
    return zi

def sosfilt_zi(sos):
    """Compute initial states for each section of a second-order-sections filter.

    Each section's ``lfilter_zi`` is scaled by the product of the DC gains of
    all preceding sections (``sum(b) / sum(a)`` is H(z) evaluated at z = 1).

    Parameters
    ----------
    sos : array_like, shape (n_sections, 6)
        Section coefficients: ``sos[k, :3]`` is the numerator and
        ``sos[k, 3:]`` the denominator of section ``k``.

    Returns
    -------
    zi : ndarray, shape (n_sections, 2)
        Floating-point initial states. Integer/bool input is promoted to
        float64 so the states are not truncated.

    Raises
    ------
    ValueError
        If `sos` is not of shape ``(n_sections, 6)``.
    """
    sos = np.asarray(sos)
    if sos.ndim != 2 or sos.shape[1] != 6:
        raise ValueError('sos must be shape (n_sections, 6)')
    if sos.dtype.kind in 'biu':
        # An integer `zi` buffer would truncate the fractional states below.
        sos = sos.astype(np.float64)

    n_sections = sos.shape[0]
    zi = np.empty((n_sections, 2), dtype=sos.dtype)
    scale = 1.0
    for section in range(n_sections):
        b = sos[section, :3]
        a = sos[section, 3:]
        zi[section, ...] = scale * lfilter_zi(b, a)
        # Accumulate this section's DC gain for the sections that follow.
        scale *= np.sum(b) / np.sum(a)

    return zi

def axis_reverse(a, axis=-1):
    """Return `a` reversed along `axis` (i.e. ``a[..., ::-1, ...]``), leaving other axes whole."""
    index = [slice(None)] * a.ndim
    index[axis] = slice(None, None, -1)
    return a[tuple(index)]

def axis_slice(a, start=None, stop=None, step=None, axis=-1):
    """Apply ``start:stop:step`` along `axis` of `a`, taking every other axis whole.

    For an ndarray this is basic slicing, so the result is a view of `a`.
    """
    selector = [slice(None) for _ in range(a.ndim)]
    selector[axis] = slice(start, stop, step)
    return a[tuple(selector)]

def odd_ext(x, n, axis=-1):
    """Odd extension of `x` by `n` samples at both ends of `axis`.

    Each end is extended with a point-reflection about its edge sample,
    ``2 * x_edge - mirrored_interior``. If ``n < 1``, `x` is returned unchanged.

    Raises
    ------
    ValueError
        If ``n > x.shape[axis] - 1``.
    """
    if n < 1:
        return x
    limit = x.shape[axis] - 1
    if n > limit:
        raise ValueError(("The extension length n (%d) is too big. " +
                         "It must not exceed x.shape[axis]-1, which is %d.")
                         % (n, limit))
    first = axis_slice(x, start=0, stop=1, axis=axis)
    head = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
    last = axis_slice(x, start=-1, axis=axis)
    tail = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
    pieces = (2 * first - head, x, 2 * last - tail)
    return np.concatenate(pieces, axis=axis)

def _even_ext(x, n, axis=-1):
    """Even extension: mirror `x` about its edge samples (edges not repeated), `n` samples per end."""
    if n < 1:
        return x
    if n > x.shape[axis] - 1:
        raise ValueError(("The extension length n (%d) is too big. " +
                         "It must not exceed x.shape[axis]-1, which is %d.")
                         % (n, x.shape[axis] - 1))
    left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
    right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
    return np.concatenate((left_ext, x, right_ext), axis=axis)


def _const_ext(x, n, axis=-1):
    """Constant extension: repeat the first/last sample of `x` `n` times at each end of `axis`."""
    if n < 1:
        return x
    ones_shape = [1] * x.ndim
    ones_shape[axis] = n
    ones = np.ones(ones_shape, dtype=x.dtype)
    left_ext = ones * axis_slice(x, start=0, stop=1, axis=axis)
    right_ext = ones * axis_slice(x, start=-1, axis=axis)
    return np.concatenate((left_ext, x, right_ext), axis=axis)


def _validate_pad(padtype, padlen, x, axis, ntaps):
    """Helper to validate padding for filtfilt.

    Parameters
    ----------
    padtype : {'even', 'odd', 'constant', None}
        Kind of edge extension; None disables padding.
    padlen : int or None
        Samples to add at each end; None means ``3 * ntaps``.
    x : ndarray
        Signal to be padded.
    axis : int
        Axis along which to pad.
    ntaps : int
        Effective filter length used for the default pad length.

    Returns
    -------
    edge : int
        Number of samples added at each end (0 when no padding).
    ext : ndarray
        The padded signal, or `x` itself when no padding is applied.

    Raises
    ------
    ValueError
        For an unknown `padtype`, or if ``x.shape[axis] <= edge``.
    """
    if padtype not in ['even', 'odd', 'constant', None]:
        raise ValueError(f"Unknown value '{padtype}' given to padtype. "
                         "padtype must be 'even', 'odd', 'constant', or None.")

    if padtype is None:
        padlen = 0

    if padlen is None:
        # Original padding; preserved for backwards compatibility.
        edge = ntaps * 3
    else:
        edge = padlen

    # x's 'axis' dimension must be bigger than edge.
    if x.shape[axis] <= edge:
        raise ValueError(
            f"The length of the input vector x must be greater than padlen, "
            f"which is {edge}."
        )

    if padtype is not None and edge > 0:
        # Make an extension of length `edge` at each
        # end of the input array. (The even/constant helpers were previously
        # referenced as undefined names, raising NameError.)
        if padtype == 'even':
            ext = _even_ext(x, edge, axis=axis)
        elif padtype == 'odd':
            ext = odd_ext(x, edge, axis=axis)
        else:
            ext = _const_ext(x, edge, axis=axis)
    else:
        ext = x
    return edge, ext

def _validate_x(x):
    x = np.asarray(x)
    if x.ndim == 0:
        raise ValueError('x must be at least 1-D')
    return x

def _validate_sos(sos):
    """Coerce `sos` to a 2-D array of second-order sections and sanity-check it.

    Returns
    -------
    sos : ndarray, shape (n_sections, 6)
        The validated coefficient array.
    n_sections : int
        Number of rows (sections).

    Raises
    ------
    ValueError
        If `sos` is not 2-D, does not have 6 columns, or ``sos[:, 3]`` is not all ones.
    """
    coeffs = atleast_nd(np.asarray(sos), ndim=2)
    if coeffs.ndim != 2:
        raise ValueError('sos array must be 2D')
    n_sections, width = coeffs.shape
    if width != 6:
        raise ValueError('sos array must be shape (n_sections, 6)')
    if np.any(coeffs[:, 3] != 1):
        raise ValueError('sos[:, 3] should be all ones')
    return coeffs, n_sections

def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
    """Zero-phase (forward-backward) filtering with cascaded second-order sections.

    Pure-NumPy counterpart of ``scipy.signal.sosfiltfilt``: pad `x`, filter
    forward and then backward with steady-state initial conditions, and trim
    the padding.

    Parameters
    ----------
    sos : array_like, shape (n_sections, 6)
    x : array_like
    axis : int, optional
    padtype : {'odd', 'even', 'constant', None}, optional
    padlen : int or None, optional

    Returns
    -------
    y : ndarray
        Filtered signal with the same shape as `x`.
    """
    sos, n_sections = _validate_sos(sos)
    x = _validate_x(x)

    # Effective filter length for the default pad length: reduce by the smaller
    # of the counts of zero sos[:, 2] and zero sos[:, 5] coefficients.
    zero_b2 = (sos[:, 2] == 0).sum()
    zero_a2 = (sos[:, 5] == 0).sum()
    ntaps = 2 * n_sections + 1 - min(zero_b2, zero_a2)
    edge, ext = _validate_pad(padtype, padlen, x, axis, ntaps=ntaps)

    # Reshape initial states (n_sections, 2) so they broadcast against `ext`:
    # (n_sections, 1, ..., 2, ..., 1) with the 2 on the filter axis.
    zi_shape = [1] * x.ndim
    zi_shape[axis] = 2
    zi = sosfilt_zi(sos).reshape([n_sections] + zi_shape)

    # Forward pass, states scaled by the first sample.
    first = axis_slice(ext, stop=1, axis=axis)
    forward, _ = sosfilt(sos, ext, axis=axis, zi=zi * first)

    # Backward pass over the time-reversed forward output.
    last = axis_slice(forward, start=-1, axis=axis)
    backward, _ = sosfilt(sos, axis_reverse(forward, axis=axis), axis=axis,
                          zi=zi * last)
    y = axis_reverse(backward, axis=axis)

    # Remove the padding added by _validate_pad.
    return axis_slice(y, start=edge, stop=-edge, axis=axis) if edge > 0 else y

In [46]:
# Demo: a 4-section SOS filter (sos[:, 3] == 1) applied to a 30-sample test
# signal with padding disabled (padlen=0), checked against SciPy's reference.
sos = np.asarray([[ 2.13138727e-04, -4.26277454e-04,  2.13138727e-04,
         1.00000000e+00,  1.59908181e+00,  7.57867157e-01],
       [ 1.00000000e+00, -2.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.71742882e+00,  8.10236696e-01],
       [ 1.00000000e+00,  2.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.63790142e+00,  8.78904341e-01],
       [ 1.00000000e+00,  2.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.86376749e+00,  9.32707241e-01]])
signal = np.asarray([0.6946206 , 0.72065928, 0.71175556, 0.3165333 , 0.56357013,
       0.57616846, 0.3945364 , 0.95319502, 0.00880446, 0.7289753 ,
       0.0959869 , 0.87426247, 0.28919631, 0.20801535, 0.66681132,
       0.45294633, 0.31652167, 0.41018225, 0.25336426, 0.55678081,
       0.07227884, 0.13323194, 0.24559928, 0.99696976, 0.47162671,
       0.93892957, 0.03232417, 0.15617559, 0.00110923, 0.07243705])
filtered = sosfiltfilt(sos, signal, padlen=0)
# Explicit check instead of eyeballing a separate scipy cell; atol covers the
# final sample, which is ~1e-22 here versus exactly 0 from SciPy.
np.testing.assert_allclose(filtered, scipy.signal.sosfiltfilt(sos, signal, padlen=0),
                           atol=1e-12)
filtered

[slice(None, 1, None)]
[slice(-1, None, None)]
[slice(None, None, -1)]
[slice(None, None, -1)]


array([ 7.92542007e-02, -9.23843838e-02,  8.98606518e-02, -7.20492234e-02,
        4.24263512e-02, -6.86025347e-03, -2.76780827e-02,  5.46702347e-02,
       -6.94874068e-02,  7.03482645e-02, -5.85610461e-02,  3.80059696e-02,
       -1.40194717e-02, -7.99972085e-03,  2.38778159e-02, -3.15095672e-02,
        3.10672763e-02, -2.45813143e-02,  1.50823744e-02, -5.60800973e-03,
       -1.61551077e-03,  5.59630773e-03, -6.53942533e-03,  5.44381662e-03,
       -3.53962822e-03,  1.79734564e-03, -6.77762249e-04,  1.68773140e-04,
       -2.06456693e-05,  4.23516474e-22])

In [32]:
# SciPy reference result for the same inputs (defaults spelled out explicitly).
scipy.signal.sosfiltfilt(sos, signal, axis=-1, padtype='odd', padlen=0)

array([ 7.92542007e-02, -9.23843838e-02,  8.98606518e-02, -7.20492234e-02,
        4.24263512e-02, -6.86025347e-03, -2.76780827e-02,  5.46702347e-02,
       -6.94874068e-02,  7.03482645e-02, -5.85610461e-02,  3.80059696e-02,
       -1.40194717e-02, -7.99972085e-03,  2.38778159e-02, -3.15095672e-02,
        3.10672763e-02, -2.45813143e-02,  1.50823744e-02, -5.60800973e-03,
       -1.61551077e-03,  5.59630773e-03, -6.53942533e-03,  5.44381662e-03,
       -3.53962822e-03,  1.79734564e-03, -6.77762249e-04,  1.68773140e-04,
       -2.06456693e-05,  0.00000000e+00])

In [19]:
# Amount sosfiltfilt subtracts from ntaps: the smaller of the counts of zero
# sos[:, 2] and zero sos[:, 5] coefficients.
min(np.sum(sos[:, 2] == 0), np.sum(sos[:, 5] == 0))

np.int64(0)

In [24]:
# Number of sections whose sos[:, 2] coefficient is zero.
np.sum(sos[:, 2] == 0)

np.int64(0)