In [1]:
import numpy as np
from scipy.signal import convolve

In [2]:
def afb(x, h0, h1):
    """
    Analysis Filter Bank (AFB)
    
    Parameters:
    x : input signal
    h0, h1 : analysis filters
    
    Returns:
    y : output -> [lowpass_channel, highpass_channel]
    """
    # Low-pass filter
    temp = convolve(x, h0, mode='full')
    temp[:len(temp) - len(x)] += temp[len(x):]
    y0 = temp[:len(x):2]
    
    # High-pass filter
    temp = convolve(x, h1, mode='full')
    temp[:len(temp) - len(x)] += temp[len(x):]
    y1 = temp[:len(x):2]
    
    return np.vstack((y0, y1))


In [3]:
def DTWPT(x, h_first, h, f, max_level):
    """
    Dual-Tree Wavelet Packet Transform (DTWPT)
    
    Parameters:
    x : input signal
    h_first : first stage filters ([h0_first, h1_first])
    h : dual-tree filters ([h0, h1])
    f : the 'same' filters ([f0, f1])
    max_level : maximum level
    
    Returns:
    y : output list containing all of the branches
    """
    y = {}

    # First stage
    fil0 = h_first[0, :]
    fil1 = h_first[1, :]
    yy = afb(x, fil0, fil1)
    y[(1, 1)] = yy[0, :]
    y[(1, 2)] = yy[1, :]

    # Second stage
    fil0 = h[0, :]
    fil1 = h[1, :]
    yy = afb(y[(1, 1)], fil0, fil1)
    y[(2, 1)] = yy[0, :]
    y[(2, 2)] = yy[1, :]

    yy = afb(y[(1, 2)], fil0, fil1)
    y[(2, 3)] = yy[0, :]
    y[(2, 4)] = yy[1, :]

    for n in range(3, max_level + 1):
        for k in range(1, 2**(n - 1) + 1):
            if k % 2**(n - 2) == 1:
                fil0 = h[0, :]
                fil1 = h[1, :]
            else:
                fil0 = f[0, :]
                fil1 = f[1, :]
            yy = afb(y[(n - 1, k)], fil0, fil1)
            y[(n, 2 * k - 1)] = yy[0, :]
            y[(n, 2 * k)] = yy[1, :]

    return [y[(max_level, k)] for k in range(1, 2**max_level + 1)]


In [4]:
# Example input data and filter coefficients
x = np.random.rand(128)
h_first = np.array([[0.5, 0.5], [-0.5, 0.5]])
h = np.array([[0.5, 0.5], [-0.5, 0.5]])
f = np.array([[0.5, 0.5], [-0.5, 0.5]])
max_level = 3

# Perform DTWPT
result = DTWPT(x, h_first, h, f, max_level)
print(result)


[array([0.66438168, 0.58345751, 0.73893422, 0.4990531 , 0.40901034,
       0.62397363, 0.29541942, 0.60041081, 0.61208978, 0.66057216,
       0.6377963 , 0.6273698 , 0.42406311, 0.28356758, 0.30998267,
       0.39598139]), array([ 0.01796235,  0.15777683,  0.13815212, -0.05874042,  0.10688973,
        0.04094656,  0.07217607,  0.04278388, -0.1028896 , -0.05639747,
        0.17705365, -0.0005387 , -0.03790552,  0.14344741, -0.04153116,
       -0.0792124 ]), array([-0.06189324,  0.06919202, -0.02936579, -0.10047036,  0.0182186 ,
       -0.01417357, -0.01850767, -0.02720516, -0.02880154,  0.11490034,
        0.01439954,  0.1592502 ,  0.0675191 , -0.09452313, -0.00109411,
       -0.12134137]), array([ 0.1278711 , -0.04762781,  0.01711607,  0.05291191, -0.13329258,
        0.12912719,  0.17278079,  0.20473933,  0.22998093,  0.04525778,
       -0.06615363,  0.13016601,  0.14413715,  0.00111229,  0.04716777,
       -0.09603217]), array([ 0.04020203,  0.07783519, -0.05228785,  0.15206311, -0.3