In [2]:
import torch
import scipy 
from torch import nn

In [41]:
class Delay(nn.Module):
    def __init__(self, delay):
        super().__init__()
        if not isinstance(delay, int):
            raise TypeError("delay must be int")
        self.delay = delay
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 2:
            raise TypeError("x must consists of Butch and Time dims")
        if not torch.is_complex(x):
            raise TypeError("x must be complex")
        return torch.cat((torch.zeros_like(x[:, :self.delay]), x[:, :-self.delay]))

In [50]:
x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
torch.cat((torch.zeros_like(x[:, :2]), x[:, :-2]))
# y = Delay(2)(x)
# y

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 2 but got size 3 for tensor number 1 in the list.

In [42]:
import torch


def test_delay_zero_is_identity():
    x = torch.randn(2, 8, dtype=torch.complex64)
    y = Delay(0)(x)
    assert torch.allclose(y, x)


def test_delay_basic():
    x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
    y = Delay(2)(x)
    expected = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.complex64)
    assert torch.allclose(y, expected)


def test_delay_longer_than_signal():
    x = torch.randn(1, 3, dtype=torch.complex64)
    y = Delay(10)(x)
    assert torch.allclose(y, torch.zeros_like(x))
