In [1]:
import numpy as np
from jaxtyping import Float, Int, jaxtyped
from typeguard import typechecked, check_type
from numpy.lib.stride_tricks import sliding_window_view

# Short alias for the array type used in the jaxtyping annotations below.
NDArray = np.ndarray

# Reproducible toy data: a 100 x 24 array of uniform [0, 1) samples,
# drawn after seeding NumPy's legacy global RNG.
np.random.seed(0)
xs = np.random.rand(100, 24)

# Window length along axis 0.
size = 10
# The window shape covers BOTH axes. Axis 1 has length 24 and the window is
# 24 wide, so there is exactly one window position along it. The result is
# therefore (100 - size + 1, 24 - 24 + 1, size, 24) == (91, 1, 10, 24).
X_slide = sliding_window_view(xs, window_shape=(size, xs.shape[1]))
display(X_slide.shape)

(91, 1, 10, 24)

In [2]:
# sliding_window_view(xs, (size, 24)) slides over both axes. Because the
# window spans all 24 columns, there is only one window position along
# axis 1, which shows up as the singleton axis in (91, 1, 10, 24).
# Drop exactly that axis to get (91, 10, 24). squeeze(axis=1) raises if the
# axis is not of length 1, whereas reshape(-1, ...) would silently merge it.
X_slide_r = X_slide.squeeze(axis=1)
display(X_slide_r.shape)

(91, 10, 24)

In [3]:
@jaxtyped(typechecker=typechecked)
def fill_slide_window(
        xs: Float[NDArray, "n 24"] | Float[NDArray, "{size} 24"],
        x: Float[NDArray, "1 24"],
        size: int) -> Float[NDArray, "n+1 24"] | Float[NDArray, "{size} 24"]:
    """
    Push one row into a sliding window that holds at most `size` rows.

    While the window holds fewer than `size` rows, `x` is appended. Once it
    holds exactly `size` rows, the oldest row (row 0) is dropped and `x` is
    appended, so the window never grows beyond `size` rows. A new array is
    returned (built with `np.vstack`); `xs` itself is not modified.

    @param xs: The current sliding window. The shape is (n, 24), where n is less than or equal to `size`.
    @param x: The new row to add to the sliding window. The shape is (1, 24).
    @param size: The maximum number of rows the window may hold. Must be at least 1.

    @return: The updated sliding window. The shape is (n+1, 24) if n < `size`, otherwise (`size`, 24).
    @raise ValueError: If `size` is less than 1, or if `xs` already has more than `size` rows.
    """
    if size < 1:
        # Without this guard, size == 0 with an empty `xs` would take the
        # "full window" branch and return a (1, 24) array, larger than `size`.
        raise ValueError(f"size must be a positive integer, got {size}")
    n = xs.shape[0]
    if n < size:
        # Window not full yet: grow by one row.
        return np.vstack([xs, x])
    if n == size:
        # Window full: evict the oldest row, then append the new one.
        return np.vstack([xs[1:], x])
    raise ValueError(
        f"Input xs has invalid dimension at 0 which is {xs.shape[0]}, the expected dimension is less or equal `{size}`"
    )

In [4]:
@jaxtyped(typechecker=typechecked)
def naive_slide_window(xs: Float[NDArray, "n 24"],
                       size: int) -> Float[NDArray, "n-{size}+1 {size} 24"]:
    """
    Generate sliding windows over the rows of a 2D array.

    This is a naive, row-by-row reference implementation built on
    `fill_slide_window`. It is likely slower than NumPy's
    `sliding_window_view` (a Python loop that copies every window) and exists
    to show how `fill_slide_window` is meant to be used.

    @param xs: The input 2D array. The shape is (n, 24).
    @param size: The window length (rows per window), with 1 <= size <= n.
    @return The window array. The shape is (n-size+1, size, 24); entry i
        equals xs[i:i+size]. The dtype matches `xs`.
    @raise ValueError: If `size` is not in the range [1, n].
    """
    n_rows, n_cols = xs.shape
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if not 1 <= size <= n_rows:
        raise ValueError(
            f"size should be between 1 and {n_rows} (inclusive), got {size}")
    # Keep the input dtype so e.g. float32 input is not promoted to
    # NumPy's default float64.
    res = np.empty((n_rows - size + 1, size, n_cols), dtype=xs.dtype)
    window = np.empty((0, n_cols), dtype=xs.dtype)
    for i, row in enumerate(xs):
        # row[np.newaxis, :] has shape (1, n_cols), as fill_slide_window expects.
        window = fill_slide_window(window, row[np.newaxis, :], size)
        if i >= size - 1:
            # From row size-1 onward the window is full and holds rows
            # i-size+1 .. i, which is output window number i-size+1.
            assert window.shape == (size, n_cols)
            res[i - size + 1] = window
    return res

In [5]:
# Build the same windows with the naive, row-by-row implementation.
X_naive_slide = naive_slide_window(xs=xs, size=size)

In [6]:
# Should match X_slide_r's shape: (n - size + 1, size, 24).
display(np.shape(X_naive_slide))

(91, 10, 24)

In [7]:
# assert_array_equal checks that the shapes match as well as the values; a
# plain `(a == b).all()` broadcasts compatible-but-different shapes and can
# pass spuriously. Unlike a bare `assert`, it is not stripped under `-O`.
np.testing.assert_array_equal(
    X_naive_slide, X_slide_r, err_msg="The result is not the same.")