In [None]:
import numpy as np
from scipy.signal import convolve

In [None]:
def _afb_branch(x, h):
    """Circularly filter 1-D ``x`` with ``h`` and keep the even-indexed samples.

    The output is the length-``len(x)`` periodic (circular) convolution of
    ``x`` and ``h``, downsampled by 2 starting at index 0.
    """
    n = len(x)
    # Full linear convolution has length n + len(h) - 1.
    temp = convolve(x, h, mode='full')
    # Fold every length-n segment of the linear convolution back onto the
    # first one. This is the circular (n-periodic) convolution. Unlike a
    # single wrap of the tail, it stays correct when len(h) - 1 > n, which
    # happens at deep levels where the signal has been halved many times.
    pad = (-len(temp)) % n
    circ = np.pad(temp, (0, pad)).reshape(-1, n).sum(axis=0)
    # Downsample by 2, keeping even indices (0, 2, 4, ...).
    return circ[::2]


def afb(x, h0, h1):
    """
    Analysis Filter Bank (AFB)

    Filters ``x`` with ``h0`` and ``h1`` using circular (periodic)
    convolution, then downsamples each result by 2 (even-indexed samples).

    Parameters:
    x : input signal (1-D, non-empty)
    h0, h1 : analysis filters (1-D); h0 is the lowpass, h1 the highpass

    Returns:
    y : 2-row array -> [lowpass_channel, highpass_channel], each of length
        ceil(len(x) / 2)

    Raises:
    ValueError : if ``x`` is empty.
    """
    x = np.asarray(x)
    if len(x) == 0:
        raise ValueError("afb: input signal x must be non-empty")

    y0 = _afb_branch(x, h0)  # low-pass channel
    y1 = _afb_branch(x, h1)  # high-pass channel
    return np.vstack((y0, y1))


In [None]:
def DTWPT(x, h_first, h, f, max_level):
    """
    Dual-Tree Wavelet Packet Transform (DTWPT) -- computes a single tree.

    Builds a full binary wavelet-packet tree down to ``max_level`` with the
    two-channel analysis filter bank ``afb``. It returns the subbands at the
    final level. NOTE(review): the full dual-tree transform presumably calls
    this once per tree with that tree's filter sets -- confirm against caller.

    Filter selection:
      * stage 1 splits the input with ``h_first``;
      * stage 2 splits both stage-1 outputs with ``h``;
      * at stage n >= 3, node k (1-indexed) of level n-1 is split with ``h``
        when ``k % 2**(n-2) == 1`` and with ``f`` otherwise. These nodes are
        the all-lowpass descendants of the two stage-1 outputs.

    Parameters:
    x : input signal (1-D)
    h_first : first stage filters, 2-row array-like [h0_first, h1_first]
    h : dual-tree filters, 2-row array-like [h0, h1]
    f : the 'same' filters, 2-row array-like [f0, f1]
    max_level : maximum level, an integer >= 1

    Returns:
    y : list of the 2**max_level subband signals at level ``max_level``,
        in tree order (each node's lowpass child precedes its highpass child)

    Raises:
    ValueError : if max_level < 1.
    """
    # Accept nested lists as well as arrays; rows are indexed [0, :] / [1, :].
    h_first = np.asarray(h_first)
    h = np.asarray(h)
    f = np.asarray(f)
    if max_level < 1:
        raise ValueError(f"max_level must be >= 1, got {max_level!r}")

    # Stage 1: split the input with the first-stage filters.
    level = list(afb(x, h_first[0, :], h_first[1, :]))
    if max_level == 1:
        return level

    # Stage 2: both stage-1 outputs are split with the dual-tree filters.
    level = [band for node in level for band in afb(node, h[0, :], h[1, :])]

    # Stages 3..max_level: only the previous level is needed to build the next,
    # so earlier levels are not retained.
    for n in range(3, max_level + 1):
        period = 2 ** (n - 2)
        next_level = []
        for k, node in enumerate(level, start=1):
            fil = h if k % period == 1 else f
            # afb returns [lowpass, highpass]; children land at 2k-1, 2k.
            next_level.extend(afb(node, fil[0, :], fil[1, :]))
        level = next_level

    return level


In [None]:
# Example input data and filter coefficients.
# A seeded generator makes the demo reproducible from run to run.
rng = np.random.default_rng(0)
x = rng.random(128)

# Orthonormal Haar analysis pair (rows: [lowpass, highpass]). The 1/sqrt(2)
# scale gives each filter unit energy; a 0.5 scale would halve the signal
# energy at every stage.
# NOTE: reusing the same filters for h_first, h and f only smoke-tests the
# code path; a real dual-tree transform needs the designed filter sets.
haar = np.array([[1.0, 1.0], [-1.0, 1.0]]) / np.sqrt(2.0)
h_first = haar.copy()
h = haar.copy()
f = haar.copy()
max_level = 3

# Perform DTWPT
result = DTWPT(x, h_first, h, f, max_level)
print(result)

# Sanity check: for these Haar filters and a signal whose length stays even
# at every level (128 -> 64 -> 32 -> 16), each stage is an orthogonal
# transform, so the total subband energy should match the input energy.
energy_in = np.sum(x ** 2)
energy_out = sum(np.sum(band ** 2) for band in result)
print(f"energy in={energy_in:.6f} out={energy_out:.6f}")
