In [2]:
import torch
import scipy 
from torch import nn

In [106]:
class Delay(nn.Module):
    """Signed time shift with zero padding for complex signals.

    delay > 0: shift right (zero-pad left)
    delay < 0: shift left  (zero-pad right)

    Input:  x of shape (B, T), complex dtype
    Output: y of shape (B, T), complex dtype

    If |delay| >= T every sample is shifted out and the output is all zeros.
    For delay == 0 the input tensor itself is returned (no copy).
    """

    def __init__(self, delay: int):
        """Store the shift amount.

        Args:
            delay: Signed number of samples to shift along the time axis (dim 1).

        Raises:
            TypeError: If ``delay`` is not an ``int``.
        """
        super().__init__()
        # NOTE(review): isinstance(True, int) is True, so bools are accepted
        # and act as delays of 1 / 0.
        if not isinstance(delay, int):
            raise TypeError("delay must be int")
        self.delay = delay

    def extra_repr(self) -> str:
        # Picked up by nn.Module.__repr__, e.g. ``Delay(delay=3)``.
        return f"delay={self.delay}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Shift ``x`` along dim 1 by ``self.delay`` samples, zero-filling the gap.

        Raises:
            ValueError: If ``x`` is not 2-D.
            TypeError: If ``x`` is not complex.
        """
        if x.ndim != 2:
            raise ValueError("x must have shape (B, T)")
        if not torch.is_complex(x):
            raise TypeError("x must be complex")

        shift = abs(self.delay)
        if shift == 0:
            return x
        batch, n_samples = x.shape
        if shift >= n_samples:
            # Every sample is shifted out of the window.
            return torch.zeros_like(x)

        # Allocate only the (B, |delay|) padding strip instead of a full
        # (B, T) zero tensor; new_zeros keeps x's dtype and device.
        pad = x.new_zeros((batch, shift))
        if self.delay > 0:
            return torch.cat((pad, x[:, :-shift]), dim=1)
        return torch.cat((x[:, shift:], pad), dim=1)


In [107]:
# Demo: a delay of 1 shifts every sample one step right and zero-fills the start.
x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
y = Delay(1)(x)
y

tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j]])

In [91]:
# Scratch check: every sample except the last two, i.e. the part of x that a delay of 2 keeps.
x[..., :-2]

tensor([[1.+0.j, 2.+0.j, 3.+0.j]])

In [None]:
import torch
def test_delay_zero_is_identity():
    x = torch.randn(2, 8, dtype=torch.complex64)
    y = Delay(0)(x)
    assert torch.allclose(y, x)

def test_delay_basic():
    x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
    y = Delay(2)(x)
    expected = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.complex64)
    assert torch.allclose(y, expected)


def test_delay_negative_shifts_left():
    # Negative delays advance the signal and zero-pad on the right.
    x = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.complex64)
    y = Delay(-2)(x)
    expected = torch.tensor([[3, 4, 5, 0, 0]], dtype=torch.complex64)
    assert torch.allclose(y, expected)


def test_delay_longer_than_signal():
    x = torch.randn(1, 3, dtype=torch.complex64)
    y = Delay(10)(x)
    assert torch.allclose(y, torch.zeros_like(x))


def test_delay_boundary_lengths():
    # |delay| == T and |delay| > T must both zero the whole signal, in either direction.
    x = torch.randn(2, 4, dtype=torch.complex64)
    for d in (4, -4, 10, -10):
        y = Delay(d)(x)
        assert y.shape == x.shape
        assert torch.allclose(y, torch.zeros_like(x))


def test_delay_input_validation():
    for bad_delay in (1.5, "2"):
        try:
            Delay(bad_delay)
        except TypeError:
            pass
        else:
            raise AssertionError(f"Delay({bad_delay!r}) should raise TypeError")
    try:
        Delay(1)(torch.ones(1, 4))  # real dtype
    except TypeError:
        pass
    else:
        raise AssertionError("real input should raise TypeError")
    try:
        Delay(1)(torch.ones(4, dtype=torch.complex64))  # 1-D input
    except ValueError:
        pass
    else:
        raise AssertionError("1-D input should raise ValueError")


In [95]:
# Run every `test_*` function defined in the notebook so Restart & Run All exercises the full suite.
test_names = sorted(
    name for name, obj in globals().items()
    if name.startswith("test_") and callable(obj)
)
for name in test_names:
    globals()[name]()
test_names