# In [None]:
import numpy as np

# Application of the generalized transform-based tensor-tensor product A *_M B,
# here with M chosen as the DFT along the third mode (the classical t-product).
def tprod_fourier(A, B):
    """
    Tensor-tensor product (t-product) of two 3-way tensors via the Fourier domain: C = A * B.

    Each Fourier-domain frontal slice of the result is the matrix product of the
    corresponding Fourier-domain frontal slices of A and B:
    ``C_hat[:, :, k] = A_hat[:, :, k] @ B_hat[:, :, k]`` with ``X_hat = fft(X, axis=2)``.
    Equivalently, in the spatial domain,
    ``C[:, :, k] = sum_j A[:, :, j] @ B[:, :, (k - j) % n3]``.

    Parameters:
        A (array_like): Tensor of shape (n1, n2, n3), real or complex.
        B (array_like): Tensor of shape (n2, m2, n3), real or complex.

    Returns:
        C (ndarray): Complex tensor of shape (n1, m2, n3). When A and B are both
        real, the imaginary part is floating-point round-off only; use ``C.real``
        if a real-valued result is wanted.

    Raises:
        ValueError: If A or B is not 3-way, or if the inner dimensions (n2) or
        the tube lengths (n3) do not agree.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 3 or B.ndim != 3:
        raise ValueError("A and B must both be 3-way tensors.")

    n1, n2, n3 = A.shape
    m1, m2, m3 = B.shape
    if n2 != m1 or n3 != m3:
        raise ValueError("Inner tensor dimensions must agree.")

    # Move each tube (third-mode fiber) to the Fourier domain.
    A_fft = np.fft.fft(A, axis=2)
    B_fft = np.fft.fft(B, axis=2)

    # For real input the spectrum is Hermitian, X_hat[..., k] = conj(X_hat[..., n3 - k]),
    # so only frequencies 0..n3//2 need an explicit product. This does NOT hold for
    # complex input, where every frequency must be computed directly.
    real_input = np.isrealobj(A) and np.isrealobj(B)
    n_direct = n3 // 2 + 1 if real_input else n3

    C_fft = np.empty((n1, m2, n3), dtype=complex)
    # Batched facewise matrix products over the directly computed frequencies.
    C_fft[:, :, :n_direct] = np.einsum(
        "ijk,jlk->ilk", A_fft[:, :, :n_direct], B_fft[:, :, :n_direct]
    )

    # Fill frequencies n_direct..n3-1 from their mirrors n3 - k (non-empty only
    # for real input, since n_direct == n3 otherwise).
    for k in range(n_direct, n3):
        C_fft[:, :, k] = np.conjugate(C_fft[:, :, n3 - k])

    # Back to the spatial domain; kept complex for backward compatibility.
    return np.fft.ifft(C_fft, axis=2)


