In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Lambda

def roll_tf(x, shift, axis):
    """
    TensorFlow stand-in for torch.roll(): cyclically shift `x` by `shift`
    positions along `axis`, wrapping elements around the edge.
    """
    rolled = tf.roll(x, shift, axis)
    return rolled

def fftshift_tf(x, axes=(-2, -1)):
    """
    Shift the zero-frequency component to the center of the spectrum.

    Each listed axis of length n is cyclically rolled forward by n // 2,
    matching numpy.fft.fftshift on those axes.

    Args:
        x: Tensor (or array-like accepted by tf.roll) to shift.
        axes: Axis index, or iterable of axis indices, to shift.
            Defaults to the last two axes.

    Returns:
        Tensor with the same shape and dtype as x.
    """
    if isinstance(axes, int):
        axes = (axes,)
    for axis in axes:
        size = tf.compat.dimension_value(x.shape[axis])
        if size is None:
            # Static length unknown (e.g. variable-size model inputs):
            # fall back to the runtime shape instead of failing on None // 2.
            size = tf.shape(x)[axis]
        x = roll_tf(x, shift=size // 2, axis=axis)
    return x

def ifftshift_tf(x, axes=(-2, -1)):
    """
    Inverse FFT shift (move center to corner).

    Each listed axis of length n is cyclically rolled backward by n // 2,
    matching numpy.fft.ifftshift on those axes. This undoes fftshift for
    both even and odd lengths.

    Args:
        x: Tensor (or array-like accepted by tf.roll) to shift.
        axes: Axis index, or iterable of axis indices, to shift.
            Defaults to the last two axes.

    Returns:
        Tensor with the same shape and dtype as x.
    """
    if isinstance(axes, int):
        axes = (axes,)
    for axis in axes:
        size = tf.compat.dimension_value(x.shape[axis])
        if size is None:
            # Static length unknown (e.g. variable-size model inputs):
            # fall back to the runtime shape instead of failing on None // 2.
            size = tf.shape(x)[axis]
        x = roll_tf(x, shift=-(size // 2), axis=axis)
    return x

def fft2c_tf(x, norm='ortho'):
    """
    Centered 2D Fast Fourier Transform (FFT).

    Applies ifftshift -> 2D FFT -> fftshift over the two spatial axes
    (H, W). The zero-frequency sample is therefore at the center of both
    the input and the output.

    Args:
        x: Real float32/float64 tensor of shape (..., H, W, 2), where the last
            dimension represents (real, imaginary). H and W may be unknown at
            graph-construction time.
        norm: Normalization mode, following the torch.fft convention:
            'ortho' scales the result by 1/sqrt(H*W);
            None or 'backward' applies no scaling;
            'forward' scales the result by 1/(H*W).

    Returns:
        Centered FFT of x, with the same shape and dtype as x.

    Raises:
        ValueError: If norm is not one of the supported modes.
    """
    if norm not in (None, 'backward', 'ortho', 'forward'):
        raise ValueError(f"Unsupported norm mode: {norm!r}")

    x = ifftshift_tf(x, axes=(-3, -2))  # Shift before FFT
    x_complex = tf.complex(x[..., 0], x[..., 1])  # Convert (real, imag) to complex
    x_fft = tf.signal.fft2d(x_complex)  # Unnormalized 2D FFT over the last two axes (H, W)

    if norm in ('ortho', 'forward'):
        # Use the runtime shape so H/W need not be static. Build the scale in
        # the input's precision and cast to the spectrum's complex dtype, so
        # float64 inputs (-> complex128) do not hit a dtype mismatch.
        num_pixels = tf.cast(tf.reduce_prod(tf.shape(x_complex)[-2:]), x.dtype)
        scale = tf.sqrt(num_pixels) if norm == 'ortho' else num_pixels
        x_fft /= tf.cast(scale, x_fft.dtype)

    x_fft = tf.stack([tf.math.real(x_fft), tf.math.imag(x_fft)], axis=-1)  # Convert back to (real, imag)
    x_fft = fftshift_tf(x_fft, axes=(-3, -2))  # Shift after FFT
    return x_fft

def ifft2c_tf(x, norm='ortho'):
    """
    Centered 2D Inverse Fast Fourier Transform (IFFT).

    Applies ifftshift -> 2D IFFT -> fftshift over the two spatial axes
    (H, W). This is the inverse of fft2c_tf when both use the same norm.

    Args:
        x: Real float32/float64 tensor of shape (..., H, W, 2), where the last
            dimension represents (real, imaginary). H and W may be unknown at
            graph-construction time.
        norm: Normalization mode, following the torch.fft convention:
            'ortho' scales the result by 1/sqrt(H*W);
            None or 'backward' scales the result by 1/(H*W);
            'forward' applies no scaling.

    Returns:
        Centered IFFT of x, with the same shape and dtype as x.

    Raises:
        ValueError: If norm is not one of the supported modes.
    """
    if norm not in (None, 'backward', 'ortho', 'forward'):
        raise ValueError(f"Unsupported norm mode: {norm!r}")

    x = ifftshift_tf(x, axes=(-3, -2))  # Shift before IFFT
    x_complex = tf.complex(x[..., 0], x[..., 1])  # Convert (real, imag) to complex
    x_ifft = tf.signal.ifft2d(x_complex)  # tf's IFFT already divides by H*W

    if norm in ('ortho', 'forward'):
        # Undo part ('ortho') or all ('forward') of ifft2d's built-in 1/(H*W).
        # The runtime shape handles unknown H/W. Casting to the spectrum's own
        # complex dtype keeps float64 inputs (-> complex128) working.
        num_pixels = tf.cast(tf.reduce_prod(tf.shape(x_complex)[-2:]), x.dtype)
        scale = tf.sqrt(num_pixels) if norm == 'ortho' else num_pixels
        x_ifft *= tf.cast(scale, x_ifft.dtype)

    x_ifft = tf.stack([tf.math.real(x_ifft), tf.math.imag(x_ifft)], axis=-1)  # Convert back to (real, imag)
    x_ifft = fftshift_tf(x_ifft, axes=(-3, -2))  # Shift after IFFT
    return x_ifft

