From b04c75e0c25e1bb4e46c18770b1a6c456b9e56f1 Mon Sep 17 00:00:00 2001 From: JonathanCrabbe Date: Thu, 30 Nov 2023 14:28:26 +0000 Subject: [PATCH] Add DFT and inverse --- src/fdiff/utils/fourier.py | 81 ++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 22 ++++++++++- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/fdiff/utils/fourier.py b/src/fdiff/utils/fourier.py index e69de29..352f097 100644 --- a/src/fdiff/utils/fourier.py +++ b/src/fdiff/utils/fourier.py @@ -0,0 +1,81 @@ +import math + +import torch +from torch.fft import irfft, rfft + + +def dft(x: torch.Tensor) -> torch.Tensor: + """Compute the DFT of the input time series by keeping only the non-redundant components. + + Args: + x (torch.Tensor): Time series of shape (batch_size, max_len, n_channels). + + Returns: + torch.Tensor: DFT of x with the same size (batch_size, max_len, n_channels). + """ + + max_len = x.size(1) + + # Compute the FFT until the Nyquist frequency + dft_full = rfft(x, dim=1, norm="ortho") + dft_re = torch.real(dft_full) + dft_im = torch.imag(dft_full) + + # The first harmonic corresponds to the mean, which is always real + zero_padding = torch.zeros_like(dft_im[:, 0, :], device=x.device) + assert torch.allclose( + dft_im[:, 0, :], zero_padding + ), f"The first harmonic of a real time series should be real, yet got imaginary part {dft_im[:, 0, :]}." + dft_im = dft_im[:, 1:] + + # If max_len is even, the last component is always zero + if max_len % 2 == 0: + assert torch.allclose( + dft_im[:, -1, :], zero_padding + ), f"Got an even {max_len=}, which should be real at the Nyquist frequency, yet got imaginary part {dft_im[:, -1, :]}." + dft_im = dft_im[:, :-1] + + # Concatenate real and imaginary parts + x_tilde = torch.cat((dft_re, dft_im), dim=1) + assert ( + x_tilde.size() == x.size() + ), f"The DFT and the input should have the same size. Got {x_tilde.size()} and {x.size()} instead." + + return x_tilde + + +def idft(x: torch.Tensor) -> torch.Tensor: + """Compute the inverse DFT of the input DFT that only contains non-redundant components. + + Args: + x (torch.Tensor): DFT of shape (batch_size, max_len, n_channels). + + Returns: + torch.Tensor: Inverse DFT of x with the same size (batch_size, max_len, n_channels). + """ + + max_len = x.size(1) + n_real = math.ceil(max_len / 2) + 1 + + # Extract real and imaginary parts + x_re = x[:, :n_real, :] + x_im = x[:, n_real:, :] + + # Create imaginary tensor + zero_padding = torch.zeros(size=(x.size(0), 1, x.size(2))) + x_im = torch.cat((zero_padding, x_im), dim=1) + + # If number of time steps is even, put the null imaginary part + if max_len % 2 == 0: + x_im = torch.cat((x_im, zero_padding), dim=1) + + assert ( + x_im.size() == x_re.size() + ), "The real and imaginary parts should have the same shape" + + x_freq = torch.complex(x_re, x_im) + + # Apply IFFT + x_time = irfft(x_freq, dim=1, norm="ortho") + + return x_time diff --git a/tests/test_utils.py b/tests/test_utils.py index eca55d5..370f8ab 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,15 @@ +import torch from omegaconf import DictConfig from fdiff.utils.extraction import flatten_config +from fdiff.utils.fourier import dft, idft +max_len = 100 +n_channels = 3 +batch_size = 100 -def test_flatten_config(): + +def test_flatten_config() -> None: cfg_dict = { "Option1": "Value1", "Option2": { @@ -25,3 +31,17 @@ def test_flatten_config(): "Option5": ["Value5_0", "Value5_1"], "Option6": "Value6", } + + +def test_dft() -> None: + # Create a random real time series + x = torch.randn(batch_size, max_len, n_channels) + + # Compute the DFT + x_tilde = dft(x) + + # Compute the inverse DFT + x_hat = idft(x_tilde) + + # Check that the inverse DFT is the original time series + assert torch.allclose(x, x_hat, atol=1e-5)