In [1]:
import kagglehub

# Download the latest dataset version; the files sit in the 'UCI-HAR Dataset'
# sub-folder of the directory that kagglehub returns.
path = kagglehub.dataset_download("drsaeedmohsen/ucihar-dataset") + '/UCI-HAR Dataset'
print("Path to dataset files:", path)

Path to dataset files: /home/tibless/.cache/kagglehub/datasets/drsaeedmohsen/ucihar-dataset/versions/1/UCI-HAR Dataset


In [2]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
from jax import random, vmap, jit, jax

# PRNG key seeded with 0.
key = random.PRNGKey(0)

# Folders holding the per-channel inertial signal files of each split.
TRAIN_PATH = path + '/train/Inertial Signals/'
TEST_PATH = path + '/test/Inertial Signals/'

# File-name prefixes of the nine signal channels; the order of this list is
# the order of the channel axis in X_train / X_test.
PREFIXS = [
    'body_acc_x_',
    'body_acc_y_',
    'body_acc_z_',
    'body_gyro_x_',
    'body_gyro_y_',
    'body_gyro_z_',
    'total_acc_x_',
    'total_acc_y_',
    'total_acc_z_',
]

# Each file is a whitespace-separated table with one row per sample. Stacking
# the nine tables along axis 1 yields (samples, channels, columns-per-row).
X_train = jnp.array(
    np.stack(
        [
            pd.read_csv(TRAIN_PATH + prefix + 'train.txt', header=None, sep=r'\s+').to_numpy()
            for prefix in PREFIXS
        ],
        axis=1,
    )
)

X_test = jnp.array(
    np.stack(
        [
            pd.read_csv(TEST_PATH + prefix + 'test.txt', header=None, sep=r'\s+').to_numpy()
            for prefix in PREFIXS
        ],
        axis=1,
    )
)

# Stored labels are shifted by -1 before use (presumably they are 1-based on
# disk, so this makes them zero-based class ids -- confirm against the data).
y_train = jnp.array(
    pd.read_csv(path + '/train/y_train.txt', header=None).to_numpy().squeeze() - 1
)
y_test = jnp.array(
    pd.read_csv(path + '/test/y_test.txt', header=None).to_numpy().squeeze() - 1
)

# Convert integer labels to one-hot encoding
def one_hot(y: jnp.ndarray, num_class: int):
    '''Return a (len(y), num_class) array with a 1 at column y[i] of row i.

    y is used as an integer column index, so it must hold integer class ids.
    All other entries of the result are 0.
    '''
    n_samples = y.shape[0]
    row_ids = jnp.arange(n_samples)
    # Start from all zeros and scatter a single 1 into each row.
    encoded = jnp.zeros((n_samples, num_class)).at[row_ids, y].set(1)
    return encoded

# Only the training labels are one-hot encoded (6 classes); y_test keeps its
# integer class ids.
y_train = one_hot(y_train, 6)

# Print the shape of every array next to its caption.
for _caption, _array in (
    ('X_train 形状:', X_train),  # expected (7352, 9, 128)
    ('y_train 形状:', y_train),  # expected (7352, 6)
    ('X_test  形状:', X_test),
    ('y_test  形状:', y_test),
):
    print(_caption, _array.shape)

X_train 形状: (7352, 9, 128)
y_train 形状: (7352, 6)
X_test  形状: (2947, 9, 128)
y_test  形状: (2947,)


In [3]:
class OriginalVersion:
    '''Step-by-step (Python loop) implementations of three recurrent layers.

    Every method is static and consumes a time-major input ``x`` of shape
    (S, B, I): S time steps, B batch entries, I input features.
    '''

    @staticmethod
    def _sigmoid(z):
        '''Element-wise logistic function 1 / (1 + exp(-z)).'''
        return 1 / (1 + jnp.exp(-z))

    @staticmethod
    def normal_cell(x, h0,
                    w_hh, w_xh, b_h,
                    w_hy, b_y):
        '''
        Plain tanh RNN unrolled over time, with a linear read-out per step.

        For t = 0 .. S-1 (with h_{-1} = h0):
            h_t = tanh(h_{t-1} @ w_hh + x_t @ w_xh + b_h)
            y_t = h_t @ w_hy + b_y

        Input
        -----
        x: (S, B, I)
        h0: (B, H)
        w_hh: (H, H)
        w_xh: (I, H)
        b_h: (H)
        w_hy: (H, O)
        b_y: (O)

        Output
        ------
        res: (S, B, O)  read-out y_t of every step
        h: (S, B, H)    hidden state h_t of every step
        '''
        n_steps, n_batch, _ = x.shape  # S, B, I
        _, n_hidden = w_hh.shape  # H, H
        _, n_out = w_hy.shape  # H, O

        outputs = jnp.zeros((n_steps, n_batch, n_out))  # S, B, O
        # h0 is parked in the last row: at t == 0 the index t - 1 == -1 wraps
        # around to it. That row is overwritten by the real state on the
        # final step.
        states = jnp.zeros((n_steps, n_batch, n_hidden)).at[-1].set(h0)  # S, B, H

        for t in range(n_steps):
            new_state = jnp.tanh(states[t - 1] @ w_hh + x[t] @ w_xh + b_h)
            states = states.at[t].set(new_state)
            outputs = outputs.at[t].set(states[t] @ w_hy + b_y)

        return outputs, states

    @staticmethod
    def lstm_cell(x, h0, c0,
                    Ws, Us, Bs):
        '''
        LSTM unrolled over time.

        For t = 0 .. S-1 (with h_{-1} = h0, c_{-1} = c0):
            i_t = sigmoid(x_t @ w_i + h_{t-1} @ u_i + b_i)
            f_t = sigmoid(x_t @ w_f + h_{t-1} @ u_f + b_f)
            g_t = tanh(x_t @ w_c + h_{t-1} @ u_c + b_c)
            o_t = sigmoid(x_t @ w_o + h_{t-1} @ u_o + b_o)
            c_t = f_t * c_{t-1} + i_t * g_t
            h_t = o_t * tanh(c_t)

        Input
        -----
        x: (S, B, I)
        h0: (B, H)
        c0: (B, H)
        Ws: 4 * (I, H), unpacked as (w_i, w_f, w_c, w_o)
        Us: 4 * (H, H), unpacked as (u_i, u_f, u_c, u_o)
        Bs: 4 * (H), unpacked as (b_i, b_f, b_c, b_o)

        Output
        ------
        res: (S, B, H)  output-gate activations o_t of every step
        h: (S, B, H)    hidden state h_t of every step
        c: (S, B, H)    cell state c_t of every step
        '''
        w_i, w_f, w_c, w_o = Ws  # (I, H)
        u_i, u_f, u_c, u_o = Us  # (H, H)
        b_i, b_f, b_c, b_o = Bs  # (H)
        sigmoid = OriginalVersion._sigmoid

        n_steps, n_batch, _ = x.shape  # S, B, I
        _, n_hidden = w_i.shape

        res = jnp.zeros((n_steps, n_batch, n_hidden))  # S, B, H
        # The initial states are parked in the last row so that index
        # t - 1 == -1 reads them at t == 0; the final step overwrites them.
        h = jnp.zeros((n_steps, n_batch, n_hidden)).at[-1].set(h0)  # S, B, H
        c = jnp.zeros((n_steps, n_batch, n_hidden)).at[-1].set(c0)  # S, B, H

        for t in range(n_steps):
            h_prev = h[t - 1]
            in_gate = sigmoid(x[t] @ w_i + h_prev @ u_i + b_i)
            forget_gate = sigmoid(x[t] @ w_f + h_prev @ u_f + b_f)
            candidate = jnp.tanh(x[t] @ w_c + h_prev @ u_c + b_c)
            out_gate = sigmoid(x[t] @ w_o + h_prev @ u_o + b_o)

            c = c.at[t].set(forget_gate * c[t - 1] + in_gate * candidate)
            # The hidden state is driven by the *updated cell state* c_t.
            # Using the raw candidate here would leave c (and the forget
            # gate) without any effect on h.
            h = h.at[t].set(out_gate * jnp.tanh(c[t]))
            # NOTE(review): res stores the output-gate activations, exactly
            # as before. LSTM layers usually expose h as their output --
            # confirm that callers really want the gate values here.
            res = res.at[t].set(out_gate)

        return res, h, c

    @staticmethod
    def gru_cell(x, h0,
                    Ws, Us, Bs):
        '''
        GRU unrolled over time.

        For t = 0 .. S-1 (with h_{-1} = h0):
            r_t = sigmoid(x_t @ w_r + h_{t-1} @ u_r + b_r)
            z_t = sigmoid(x_t @ w_z + h_{t-1} @ u_z + b_z)
            n_t = tanh(x_t @ w_h + (r_t * h_{t-1}) @ u_h + b_h)
            h_t = (1 - z_t) * h_{t-1} + z_t * n_t

        Input
        -----
        x: (S, B, I)
        h0: (B, H)
        Ws: 3 * (I, H), unpacked as (w_z, w_r, w_h)
        Us: 3 * (H, H), unpacked as (u_z, u_r, u_h)
        Bs: 3 * (H), unpacked as (b_z, b_r, b_h)

        Output
        ------
        h: (S, B, H)  hidden state h_t of every step (the only return value)
        '''
        w_z, w_r, w_h = Ws  # (I, H)
        u_z, u_r, u_h = Us  # (H, H)
        b_z, b_r, b_h = Bs  # (H)
        sigmoid = OriginalVersion._sigmoid

        n_steps, n_batch, _ = x.shape  # S, B, I
        _, n_hidden = w_z.shape

        # h0 is parked in the last row so that index t - 1 == -1 reads it at
        # t == 0; the final step overwrites it.
        h = jnp.zeros((n_steps, n_batch, n_hidden)).at[-1].set(h0)  # S, B, H

        for t in range(n_steps):
            h_prev = h[t - 1]
            reset_gate = sigmoid(x[t] @ w_r + h_prev @ u_r + b_r)
            update_gate = sigmoid(x[t] @ w_z + h_prev @ u_z + b_z)

            candidate = jnp.tanh(x[t] @ w_h + (reset_gate * h_prev) @ u_h + b_h)

            h = h.at[t].set((1 - update_gate) * h_prev + update_gate * candidate)

        return h