In [1]:
import math
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import librosa
import numpy as np
from sklearn import metrics
import pickle
import numpy
from jax import jit

matplotlib.rcParams.update({"font.size": 30})

In [2]:
#读取数据集
with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session1_w2v2.pkl', 'rb') as f:
    wav2vec_last1 = pickle.load(f)
    print('wav2vec_last1',wav2vec_last1.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session1_label.pkl', 'rb') as f:
    label_last1 = pickle.load(f)
    print('label_last1',label_last1.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session2_w2v2.pkl', 'rb') as f:
    wav2vec_last2 = pickle.load(f)
    print('wav2vec_last2',wav2vec_last2.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session2_label.pkl', 'rb') as f:
    label_last2 = pickle.load(f)
    print('label_last2',label_last2.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session3_w2v2.pkl', 'rb') as f:
    wav2vec_last3 = pickle.load(f)
    print('wav2vec_last3',wav2vec_last3.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session3_label.pkl', 'rb') as f:
    label_last3 = pickle.load(f)
    print('label_last3',label_last3.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session4_w2v2.pkl', 'rb') as f:
    wav2vec_last4 = pickle.load(f)
    print('wav2vec_last4',wav2vec_last4.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session4_label.pkl', 'rb') as f:
    label_last4 = pickle.load(f)
    print('label_last4',label_last4.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session5_w2v2.pkl', 'rb') as f:
    wav2vec_last5 = pickle.load(f)
    print('wav2vec_last5',wav2vec_last5.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session2_pt_特征/data_Session5_label.pkl', 'rb') as f:
    label_last5 = pickle.load(f)
    print('label_last5',label_last5.shape)

wav2vec_last1 (1085, 256, 768)
label_last1 (1085,)
wav2vec_last2 (1023, 256, 768)
label_last2 (1023,)
wav2vec_last3 (1151, 256, 768)
label_last3 (1151,)
wav2vec_last4 (1031, 256, 768)
label_last4 (1031,)
wav2vec_last5 (1241, 256, 768)
label_last5 (1241,)


In [3]:
wav2vec_last = np.concatenate((wav2vec_last1, wav2vec_last3, wav2vec_last4, wav2vec_last5),axis=0)
label_last = np.concatenate((label_last1,label_last3,label_last4,label_last5))
print(wav2vec_last.shape,label_last.shape)

(4508, 256, 768) (4508,)


In [4]:
class Func(eqx.Module):
    data_size: int
    hidden_size: int
    hidden_hidden_channels: int
    num_hidden_layers: int
    linear_in: eqx.nn.Linear
    linear_a: eqx.nn.Linear
    linear_b: eqx.nn.Linear
    linear_c: eqx.nn.Linear
    linear_out: eqx.nn.Linear
    dropout: eqx.nn.Dropout
    
    def __init__(self, data_size, hidden_size, hidden_hidden_channels, num_hidden_layers, dropout_rate, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, akey, bkey, ckey, okey = jrandom.split(key, 5)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        self.linear_in = eqx.nn.Linear(hidden_size, hidden_hidden_channels, key=ikey)
        self.linear_a = eqx.nn.Linear(hidden_hidden_channels, hidden_hidden_channels, key=akey)
        self.linear_b = eqx.nn.Linear(hidden_hidden_channels, hidden_hidden_channels, key=bkey)
        self.linear_c = eqx.nn.Linear(hidden_hidden_channels, hidden_hidden_channels, key=ckey)
        self.linear_out = eqx.nn.Linear(hidden_hidden_channels, hidden_size * data_size, key=okey)
        self.dropout = eqx.nn.Dropout(dropout_rate)
        

    def __call__(self, t, y, training, args, subkey):
        y = self.linear_in(y)
        y = jnn.relu(y)
        y = self.dropout(y, inference=not training, key=subkey)
        y = self.linear_a(y)
        y = jnn.relu(y)
        y = self.dropout(y, inference=not training, key=subkey)
        y = self.linear_b(y)
        y = jnn.relu(y)
        y = self.dropout(y, inference=not training, key=subkey)
        y = self.linear_c(y)
        y = jnn.relu(y)
        y = self.dropout(y, inference=not training, key=subkey)
        y = self.linear_out(y).reshape(self.hidden_size, self.data_size)
        y = jnn.tanh(y)  
        return y

In [5]:
# 定义函数来对每一列进行累加平均的操作
def cumulative_average(arr):
    cumulative_sum = jnp.cumsum(arr, axis=0)
    divisor = jnp.arange(1, arr.shape[0] + 1).reshape((-1, 1))
    return cumulative_sum / divisor

# 将函数编译为JIT加速版本
cumulative_average_jit = jit(cumulative_average)

In [6]:
class NeuralCDE(eqx.Module):
    Conv: eqx.nn.Conv
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear

    def __init__(self, data_size, hidden_size, width_size, depth, hidden_hidden_channels, num_hidden_layers, dropout_rate, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey, ckey = jrandom.split(key, 4)
        self.Conv = eqx.nn.ConvTranspose(1, data_size, 5, 1, key=ckey)
        self.initial = eqx.nn.MLP(5, hidden_size, width_size, depth, key=ikey)
        self.func = Func(5, hidden_size, hidden_hidden_channels, num_hidden_layers, dropout_rate, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 4, key=lkey)

    def __call__(self, ts, coeffs, training, subkey, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.

        #先将数据流降维再放入模型中训练
        Lengh = len(coeffs)
        coeffs_pad = []
        for i in range(Lengh):
            coeffs_last = coeffs[i].T
            coeffs_right = self.Conv(coeffs_last)
            coeffs_i = coeffs_right.T
            yn_array = cumulative_average_jit(coeffs_i)
            coeffs_pad.append(yn_array)

        ##########
        control = diffrax.CubicInterpolation(ts, coeffs_pad)
        
        term = diffrax.ControlTerm(lambda t, y, args: self.func(t, y, training, args, subkey), control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = None
        y0 = self.initial(control.evaluate(ts[0]))
        if evolving_out:
            saveat = diffrax.SaveAt(ts=ts)
        else:
            saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
        )
        if evolving_out:
            prediction = jax.vmap(lambda y: jnn.sigmoid(self.linear(y))[0])(solution.ys)
        else:
            (prediction,) = jax.vmap(lambda y:self.linear(solution.ys[-1]))(solution.ys)
            pred_mean=prediction.mean(axis=0)  
            pred_var=prediction.var(axis=0)  
            pred_normalized=(prediction-pred_mean)/jnp.sqrt(pred_var+1e-5)     
            prediction_last = jnn.softmax(pred_normalized)
        return prediction_last

In [7]:
def get_data(dataset_size, *, key):
    ts = jnp.broadcast_to(jnp.linspace(0,255, 256), (dataset_size, 256))
    ys = jnp.concatenate([ts[:, :, None], wav2vec_last], axis=-1)
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = label_last
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size

In [8]:
def get_test_data(dataset_test_size, *, key):
    ts = jnp.broadcast_to(jnp.linspace(0,255, 256), (dataset_test_size, 256))
    ys = jnp.concatenate([ts[:, :, None], wav2vec_last2], axis=-1)
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = label_last2
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size

In [9]:
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

In [10]:
    @eqx.filter_jit
    class CrossEntropyLoss():

        def __init__(self, weight=None, size_average=True):

            self.weight = weight
            self.size_average = size_average


        def __call__(self, input, target):
            batch_loss = 0.
            for i in range(input.shape[0]):

                numerator = jnp.exp(input[i, target[i]])     # 分子
                denominator = jnp.sum(jnp.exp(input[i, :]))   # 分母

                # 计算单个损失
                loss = -jnp.log(numerator / denominator)
                if self.weight:
                    loss = self.weight[target[i]] * loss
            #    print("单个损失： ",loss)

                # 损失累加
                batch_loss += loss

            # 整个 batch 的总损失是否要求平均
            if self.size_average == True:
                batch_loss /= input.shape[0]

            return batch_loss

In [11]:
def main(
    dataset_size=4508,
    dataset_test_size=1023,
    batch_size=32,
    lr=0.001,
    hidden_hidden_channels=40,
    num_hidden_layers=4,
    steps=2085,
    hidden_size=220,
    width_size=128,
    depth=1,
    seed=3234,
    dropout_rate=0.3,
):
    
    key = jrandom.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jrandom.split(key, 4)

    ts, coeffs, labels, data_size = get_data(
        dataset_size, key=train_data_key
    )

    model = NeuralCDE(data_size, hidden_size, width_size, depth, hidden_hidden_channels, num_hidden_layers, dropout_rate, key=model_key)

    # Training loop like normal.

    @eqx.filter_jit
    def accuracy(total_size, pred, label_i):
        total_acc = 0
        total_num = total_size
        predicted_class = jnp.argmax(pred, axis=1)
        total_acc += jnp.sum(predicted_class == label_i)
        return total_acc / total_num

 
    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i, subkey):
        training = True
        pred = jax.vmap(model, in_axes=(0, 0, None, None))(ti, coeff_i, training, subkey)
        criterion = CrossEntropyLoss()
        bxe = criterion(pred, label_i)
        y_pred = jnp.array(pred)
        y_true = jnp.array(label_i)
        acc = accuracy(batch_size, y_pred, y_true)
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)


    @eqx.filter_jit
    def test_loss(model, ti, label_i, coeff_i, subkey):
        training = False
        pred = jax.vmap(model, in_axes=(0, 0, None, None))(ti, coeff_i, training, subkey)
        criterion = CrossEntropyLoss()
        bxe = criterion(pred, label_i)
        y_pred = jnp.array(pred)
        y_true = jnp.array(label_i)
        acc = accuracy(dataset_test_size, y_pred, y_true)
        return bxe, acc



    @eqx.filter_jit
    def make_step(model, data_i, opt_state, subkey):
        ti, label_i, *coeff_i = data_i
        (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i, subkey)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    for step, data_i in zip(
        range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
    ):
        start = time.time()
        key, subkey = jax.random.split(key)
        bxe, acc, model, opt_state = make_step(model, data_i, opt_state, subkey)
        end = time.time()
        print(
            f"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )
        if step == 139:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch1: {acc_test}")
            print('########################')
            
        if step == 278:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch2: {acc_test}")
            print('########################')

        if step == 417:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch3: {acc_test}")
            print('########################')

        if step == 556:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch4: {acc_test}")
            print('########################')

        if step == 695:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch5: {acc_test}")
            print('########################')
            
        if step == 834:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch6: {acc_test}")
            print('########################')

        if step == 973:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch7: {acc_test}")
            print('########################')

        if step == 1112:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch8: {acc_test}")
            print('########################')

        if step == 1251:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch9: {acc_test}")
            print('########################')

        if step == 1390:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch10: {acc_test}")
            print('########################')

        if step == 1529:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch11: {acc_test}")
            print('########################')

        if step == 1668:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch12: {acc_test}")
            print('########################')

        if step == 1807:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch13: {acc_test}")
            print('########################')

        if step == 1946:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch14: {acc_test}")
            print('########################')

        if step == 2085:
            ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
            bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
            print('########################')
            print(f"Test loss: {bxe_test}, Test Accuracy_epoch15: {acc_test}")
            print('########################')

        
    ts_test, coeffs_test, labels_test, _ = get_test_data(dataset_test_size, key=test_data_key)
    bxe_test, acc_test = test_loss(model, ts_test, labels_test, coeffs_test, test_data_key)
    print(f"Test loss: {bxe_test}, Test Accuracy: {acc_test}")

In [12]:
main()

Step: 0, Loss: 1.3735839128494263, Accuracy: 0.34375, Computation time: 12.146655082702637
Step: 1, Loss: 1.3841294050216675, Accuracy: 0.28125, Computation time: 1.9635345935821533
Step: 2, Loss: 1.3726595640182495, Accuracy: 0.34375, Computation time: 1.64892578125
Step: 3, Loss: 1.3718698024749756, Accuracy: 0.34375, Computation time: 1.6442804336547852
Step: 4, Loss: 1.4508860111236572, Accuracy: 0.1875, Computation time: 1.5641417503356934
Step: 5, Loss: 1.3574392795562744, Accuracy: 0.3125, Computation time: 1.727430820465088
Step: 6, Loss: 1.4151957035064697, Accuracy: 0.21875, Computation time: 1.7483091354370117
Step: 7, Loss: 1.3437983989715576, Accuracy: 0.375, Computation time: 1.8177804946899414
Step: 8, Loss: 1.3317105770111084, Accuracy: 0.375, Computation time: 1.633110761642456
Step: 9, Loss: 1.3745570182800293, Accuracy: 0.21875, Computation time: 2.0104403495788574
Step: 10, Loss: 1.3631460666656494, Accuracy: 0.4375, Computation time: 1.7557618618011475
Step: 11, Lo