Skip to content

Commit

Permalink
Add DFT and inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanCrabbe committed Nov 30, 2023
1 parent defc64c commit b04c75e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
81 changes: 81 additions & 0 deletions src/fdiff/utils/fourier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import math

import torch
from torch.fft import irfft, rfft


def dft(x: torch.Tensor) -> torch.Tensor:
"""Compute the DFT of the input time series by keeping only the non-redundant components.
Args:
x (torch.Tensor): Time series of shape (batch_size, max_len, n_channels).
Returns:
torch.Tensor: DFT of x with the same size (batch_size, max_len, n_channels).
"""

max_len = x.size(1)

# Compute the FFT until the Nyquist frequency
dft_full = rfft(x, dim=1, norm="ortho")
dft_re = torch.real(dft_full)
dft_im = torch.imag(dft_full)

# The first harmonic corresponds to the mean, which is always real
zero_padding = torch.zeros_like(dft_im[:, 0, :], device=x.device)
assert torch.allclose(
dft_im[:, 0, :], zero_padding
), f"The first harmonic of a real time series should be real, yet got imaginary part {dft_im[:, 0, :]}."
dft_im = dft_im[:, 1:]

# If max_len is even, the last component is always zero
if max_len % 2 == 0:
assert torch.allclose(
dft_im[:, -1, :], zero_padding
), f"Got an even {max_len=}, which should be real at the Nyquist frequency, yet got imaginary part {dft_im[:, -1, :]}."
dft_im = dft_im[:, :-1]

# Concatenate real and imaginary parts
x_tilde = torch.cat((dft_re, dft_im), dim=1)
assert (
x_tilde.size() == x.size()
), f"The DFT and the input should have the same size. Got {x_tilde.size()} and {x.size()} instead."

return x_tilde


def idft(x: torch.Tensor) -> torch.Tensor:
"""Compute the inverse DFT of the input DFT that only contains non-redundant components.
Args:
x (torch.Tensor): DFT of shape (batch_size, max_len, n_channels).
Returns:
torch.Tensor: Inverse DFT of x with the same size (batch_size, max_len, n_channels).
"""

max_len = x.size(1)
n_real = math.ceil(max_len / 2) + 1

# Extract real and imaginary parts
x_re = x[:, :n_real, :]
x_im = x[:, n_real:, :]

# Create imaginary tensor
zero_padding = torch.zeros(size=(x.size(0), 1, x.size(2)))
x_im = torch.cat((zero_padding, x_im), dim=1)

# If number of time steps is even, put the null imaginary part
if max_len % 2 == 0:
x_im = torch.cat((x_im, zero_padding), dim=1)

assert (
x_im.size() == x_re.size()
), "The real and imaginary parts should have the same shape"

x_freq = torch.complex(x_re, x_im)

# Apply IFFT
x_time = irfft(x_freq, dim=1, norm="ortho")

return x_time
22 changes: 21 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import torch
from omegaconf import DictConfig

from fdiff.utils.extraction import flatten_config
from fdiff.utils.fourier import dft, idft

max_len = 100
n_channels = 3
batch_size = 100

def test_flatten_config():

def test_flatten_config() -> None:
cfg_dict = {
"Option1": "Value1",
"Option2": {
Expand All @@ -25,3 +31,17 @@ def test_flatten_config():
"Option5": ["Value5_0", "Value5_1"],
"Option6": "Value6",
}


def test_dft() -> None:
# Create a random real time series
x = torch.randn(batch_size, max_len, n_channels)

# Compute the DFT
x_tilde = dft(x)

# Compute the inverse DFT
x_hat = idft(x_tilde)

# Check that the inverse DFT is the original time series
assert torch.allclose(x, x_hat, atol=1e-5)

0 comments on commit b04c75e

Please sign in to comment.