Skip to content

Commit

Permalink
Update dft2.backward, remove dead code
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Dec 5, 2023
1 parent ff3ccd8 commit 74d841c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 39 deletions.
43 changes: 4 additions & 39 deletions loupe/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,11 @@ def backward(self, grad):
# always multiply grad by norm_coeff here.
grad_clean = grad * norm_coeff

shape, strides = _as_strided_args(grad_clean, input.shape[-2:])
input_grad = np.empty(shape, dtype=complex)
if grad_clean.ndim == 2:
grad_clean = grad_clean[np.newaxis,:]

for n, x in enumerate(as_strided(grad_clean, shape=shape, strides=strides, writeable=False)):
input_grad = np.empty(input.shape, dtype=complex)
for n, x in enumerate(grad_clean):
h_grad = np.dot(W_row_conj[n], x)
input_grad[n] = np.dot(h_grad, W_col_conj[n].T)

Expand Down Expand Up @@ -221,39 +222,3 @@ def _dft2_coords(m, n, M, N):
V = np.arange(N) - np.floor(N/2.0)

return R, S, U, V



@functools.lru_cache(maxsize=32)
def _dft2_matrix(n, N, alpha, shift):
w = -2.0 * np.pi * np.outer(_coords(n)-shift, _coords(N)-shift)
E = np.exp(1j * w * alpha)
return E


@functools.lru_cache(maxsize=32)
def _coords(n):
#TODO: ensure n is a python primitive type
# It needs to be hashable for LRU cache to work
return np.arange(n) - np.floor(n/2.0)

def _sanitize_ordered_pair(x, dtype):
x = np.asarray(x)

if x.shape != (2,):
raise ValueError(f"can't interpret x with shape {x.shape} as ordered pair")

return [dtype(x[0]), dtype(x[1])]


def _as_strided_args(input, output_shape):
depth = 1 if input.ndim ==2 else input.shape[0]

shape = output_shape if output_shape is not None else input.shape[-2:]
if len(input.shape) == 2:
shape = (1, *shape)
else:
shape = (input.shape[0], *shape)

strides = input.strides if depth > 1 else (0, *input.strides)
return shape, strides
13 changes: 13 additions & 0 deletions tests/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def test_dft2_backward(n):
assert np.allclose(f.grad, g.real * n**2)


@pytest.mark.parametrize("n", [10, 11])
def test_dft2_3d_backward(n):

f = loupe.rand(size=(3, n, n), requires_grad=True)
F = loupe.dft2(f, (1/n, 1/n), unitary=False)

grad = np.ones(shape=(3, n, n))
F.backward(grad)

g = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(grad)))

assert np.allclose(f.grad, g.real * n**2)

def test_dft2_unitary():
n = 10
f = np.random.rand(n, n) + 1j * np.random.rand(n, n)
Expand Down

0 comments on commit 74d841c

Please sign in to comment.