In [2]:
import numpy as np

def tprod(A, B):
    """
    Tensor-tensor product of two 3-way tensors: C = A * B.
    
    Parameters:
        A (ndarray): Tensor of shape (n1, n2, n3)
        B (ndarray): Tensor of shape (n2, m2, n3)
        
    Returns:
        C (ndarray): Tensor of shape (n1, m2, n3)
        
    The product is computed in the Fourier domain along the third axis.
    """
    n1, n2, n3 = A.shape
    m1, m2, m3 = B.shape

    if n2 != m1 or n3 != m3:
        raise ValueError("Inner tensor dimensions must agree.")
    
    # Compute the FFT along the third dimension (axis=2)
    A_fft = np.fft.fft(A, axis=2)
    B_fft = np.fft.fft(B, axis=2)
    
    # Prepare an array for the result in the Fourier domain
    C_fft = np.zeros((n1, m2, n3), dtype=complex)
    
    # Determine the half-point along the third dimension
    halfn3 = int(np.ceil((n3 + 1) / 2))
    
    # For the first half (including the center if n3 is odd)
    for i in range(halfn3):
        C_fft[:, :, i] = A_fft[:, :, i] @ B_fft[:, :, i]
    
    # For the remaining frequencies, use the conjugate symmetry property
    for i in range(halfn3, n3):
        C_fft[:, :, i] = np.conjugate(C_fft[:, :, n3 - i])
    
    # Inverse FFT along the third axis to return to the spatial domain
    C = np.fft.ifft(C_fft, axis=2)
    
    # If A and B are real, the imaginary part should be negligible.
    # Optionally, one can return C.real
    return C

# Example usage:
if __name__ == '__main__':
    # Create example tensors with compatible dimensions
    n1, n2, n3 = 4, 5, 6
    m2 = 3
    A = np.random.rand(n1, n2, n3)
    B = np.random.rand(n2, m2, n3)
    
    # Compute the tensor-tensor product
    C = tprod(A, B)
    print("Result tensor shape:", C)


Result tensor shape: [[[7.62680541+2.05822252e-17j 7.53771123-2.05822252e-17j
   8.62248407+2.05822252e-17j 7.20503119-2.05822252e-17j
   7.33114659+2.05822252e-17j 9.24683991-2.05822252e-17j]
  [7.60230144-2.29872056e-17j 7.55863014+2.29872056e-17j
   7.35123473-2.29872056e-17j 7.11664692+2.29872056e-17j
   7.29506546-2.29872056e-17j 6.80923001+2.29872056e-17j]
  [8.91428633-2.71810478e-17j 8.26916149+2.71810478e-17j
   7.94212941-2.71810478e-17j 8.33875027+2.71810478e-17j
   9.34330606-2.71810478e-17j 7.64699012+2.71810478e-17j]]

 [[6.94992943-1.61586729e-17j 8.08800577+1.61586729e-17j
   7.27765556-1.61586729e-17j 7.22393974+1.61586729e-17j
   8.28841464-1.61586729e-17j 7.55326421+1.61586729e-17j]
  [6.24269169-4.14919286e-18j 7.00057851+4.14919286e-18j
   6.21454886-4.14919286e-18j 6.50950625+4.14919286e-18j
   6.29351782-4.14919286e-18j 6.664531  +4.14919286e-18j]
  [7.39263101+2.36177565e-18j 7.54316976-2.36177565e-18j
   8.07336588+2.36177565e-18j 8.02424858-2.36177565e-18j
   