In [6]:
import torch
from torch import nn
from typing import Optional, Literal
class ComplexFIR(nn.Module):
    """Causal complex-valued FIR filter applied independently to each row of a batch.

    For input ``x`` of shape ``(B, T)`` the output has the same shape and is

        y[b, n] = sum_{k=0}^{m-1} h[k] * x[b, n - k],

    with samples before ``n = 0`` treated as zero (causal, same-length output).

    Args:
        m: Number of taps. Must be a positive ``int``; ``bool`` is rejected even
            though it subclasses ``int``.
        coeff: Optional explicit taps: a 1-D complex tensor with ``m`` elements.
            The values are detached and copied, so the module never shares
            storage with (or writes back into) the caller's tensor. When given,
            ``init`` is not used.
        init: Initialisation used only when ``coeff`` is None:
            ``"zeros"`` - all taps zero;
            ``"delta"`` - ``h[0] = 1`` (identity filter);
            ``"central_delta"`` - ``h[m // 2] = 1`` (pure delay of ``m // 2`` samples).
        trainable: If True the taps are an ``nn.Parameter``; otherwise they are
            registered as a buffer. Either way they are available as ``self.h``.
        dtype: dtype the taps are stored in.

    Raises:
        TypeError: ``m`` is not an int, or ``coeff`` is not a complex ``torch.Tensor``.
        ValueError: ``m < 1``, ``coeff`` is not 1-D or has the wrong length, or
            ``init`` is unknown (checked only when ``coeff`` is None).
    """

    def __init__(
        self,
        m: int,
        coeff: Optional[torch.Tensor] = None,
        init: Literal["zeros", "delta", "central_delta"] = "delta",
        trainable: bool = True,
        dtype: torch.dtype = torch.complex64,
    ):
        super().__init__()

        # bool subclasses int (isinstance(True, int) is True); reject it explicitly
        # so that m=True is not silently accepted as a 1-tap filter.
        if isinstance(m, bool) or not isinstance(m, int):
            raise TypeError("m must be int")
        if m <= 0:
            raise ValueError("m must be >= 1")

        if coeff is not None:
            if not isinstance(coeff, torch.Tensor):
                raise TypeError("coeff must be a torch.Tensor")
            if coeff.ndim != 1:
                raise ValueError("coeff must be 1D tensor")
            if coeff.numel() != m:
                raise ValueError("coeff length must be m")
            if not torch.is_complex(coeff):
                raise TypeError("coeff must be complex")
            # `.to` returns `coeff` itself when the dtype already matches, which
            # would make self.h alias the caller's tensor (optimizer steps would
            # then mutate it in place). detach() drops any autograd history and
            # clone() gives the module its own storage.
            h = coeff.detach().clone().to(dtype=dtype)
        else:
            h = torch.zeros(m, dtype=dtype)
            if init == "zeros":
                pass
            elif init == "delta":
                h[0] = 1
            elif init == "central_delta":
                h[m // 2] = 1
            else:
                raise ValueError("unknown init")

        if trainable:
            self.h = nn.Parameter(h)
        else:
            self.register_buffer("h", h)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter each row of ``x`` causally with the taps ``self.h``.

        Args:
            x: Complex tensor of shape ``(B, T)``.

        Returns:
            Tensor of shape ``(B, T)`` with ``y[b, n] = sum_k h[k] * x[b, n - k]``.
            Its dtype follows PyTorch type promotion of ``x`` and ``h``, except
            when ``T == 0``, where ``x`` itself is returned.

        Raises:
            ValueError: ``x`` is not 2-D.
            TypeError: ``x`` is not complex.
        """
        if x.ndim != 2:
            raise ValueError("x must have shape (B, T)")
        if not torch.is_complex(x):
            raise TypeError("x must be a complex tensor")

        B, T = x.shape
        h = self.h
        m = h.numel()

        # Nothing to filter; the unfold below also cannot form a size-m window
        # from only the m-1 padding columns.
        if T == 0:
            return x
        if m == 1:
            return x * h[0]

        # Left-pad with m-1 zeros so every output sample sees a full window and
        # samples before n = 0 contribute zero (causal filtering).
        pad = m - 1
        x_pad = torch.cat([x.new_zeros((B, pad)), x], dim=1)
        # frames[b, n, j] = x_pad[b, n + j] = x[b, n + j - (m - 1)]; shape (B, T, m).
        frames = x_pad.unfold(dimension=1, size=m, step=1)
        # Reversing the taps pairs h[k] with x[n - k] (j = m - 1 - k).
        h_rev = h.flip(0)
        return (frames * h_rev).sum(dim=-1)

In [4]:
# Middle row of a 5x5 identity: a unit impulse at index 5 // 2 == 2, the tap
# position that init="central_delta" uses for m = 5.
torch.eye(5)[5 // 2]

tensor([0., 0., 1., 0., 0.])

In [22]:
import torch
import pytest



def test_fir_delta_is_identity():
    """A unit tap at index 0 passes the signal through unchanged: y[n] == x[n]."""
    signal = torch.randn(2, 8, dtype=torch.complex64)
    identity_filter = ComplexFIR(
        m=3, init="delta", trainable=False, dtype=torch.complex64
    )
    output = identity_filter(signal)
    assert output.shape == signal.shape
    assert torch.allclose(output, signal)


def test_fir_central_delta_is_delay():
    """With m=5, "central_delta" puts the unit tap at index 2, so y[n] == x[n - 2]."""
    signal = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
    delay_filter = ComplexFIR(
        m=5, init="central_delta", trainable=False, dtype=torch.complex64
    )
    output = delay_filter(signal)
    want = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.complex64)
    assert torch.allclose(output, want)


def test_fir_known_small_example():
    """Two unit taps form a two-sample running sum: y[n] == x[n] + x[n - 1]."""
    signal = torch.tensor([[1, 2, 3, 4]], dtype=torch.complex64)
    taps = torch.tensor([1, 1], dtype=torch.complex64)
    summing_filter = ComplexFIR(
        m=2, coeff=taps, trainable=False, dtype=torch.complex64
    )
    output = summing_filter(signal)
    want = torch.tensor([[1, 3, 5, 7]], dtype=torch.complex64)
    assert torch.allclose(output, want)


def test_fir_gradient_exists_when_trainable():
    """Backprop through a trainable filter yields a nonzero gradient shaped like h.

    Uses the identity ("delta") init rather than "zeros": with h == 0 the output
    is identically zero, so the gradient of mean(|y|^2) is exactly zero and the
    test could not tell a working backward pass from a broken one.
    """
    # Local generator: deterministic input without touching the global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(3, 16, dtype=torch.complex64, generator=gen)
    fir = ComplexFIR(m=4, init="delta", trainable=True, dtype=torch.complex64)

    y = fir(x)
    # real-valued loss so backward is well-defined
    loss = (y.abs() ** 2).mean()
    loss.backward()

    assert fir.h.grad is not None
    assert fir.h.grad.shape == fir.h.shape
    # At the identity filter (y == x) the h[0] component is driven by
    # sum_n |x[n]|^2 > 0, so a correct backward pass cannot be all zeros.
    assert fir.h.grad.abs().sum() > 0


def test_fir_raises_on_non_complex_input():
    """A real-valued (float32) input is rejected with TypeError."""
    fir = ComplexFIR(m=3, init="delta", trainable=False)
    real_signal = torch.randn(2, 8, dtype=torch.float32)
    with pytest.raises(TypeError):
        fir(real_signal)


def test_fir_raises_on_wrong_shape():
    """A 1-D input (no batch axis) is rejected with ValueError."""
    fir = ComplexFIR(m=3, init="delta", trainable=False)
    unbatched = torch.randn(8, dtype=torch.complex64)
    with pytest.raises(ValueError):
        fir(unbatched)


In [21]:
# Run every test defined above directly in the kernel (no pytest runner needed);
# any failing assertion raises and stops the cell.
for test_fn in (
    test_fir_raises_on_wrong_shape,
    test_fir_raises_on_non_complex_input,
    test_fir_delta_is_identity,
    test_fir_gradient_exists_when_trainable,
    test_fir_known_small_example,
    test_fir_central_delta_is_delay,
):
    test_fn()