In [None]:

import numpy as np
if not hasattr(np, 'PINF'):
    np.PINF = np.inf
if not hasattr(np, 'NINF'):
    np.NINF = -np.inf


try:
  import jax
except ModuleNotFoundError:
  %pip install --upgrade "jax[cpu]"
# install commplax if not found
try:
  import commplax
except ModuleNotFoundError:
  %pip install https://github.com/ChenHongBo0420/Comm/archive/master.zip
# install data api if not found
try:
  import labptptm2
except ModuleNotFoundError:
  %pip install https://github.com/remifan/LabPtPTm2/archive/master.zip


# install GDBP if not found
try:
  import gdbp
except ModuleNotFoundError:
  %pip install https://github.com/ChenHongBo0420/Q/archive/main.zip

%pip install https://github.com/remifan/LabPtPTm2/archive/master.zip
%pip uninstall numcodecs zarr -y
%pip install "zarr==2.10.3" "numcodecs==0.10.2"

Collecting https://github.com/ChenHongBo0420/Comm/archive/master.zip
  Downloading https://github.com/ChenHongBo0420/Comm/archive/master.zip
[2K     [32m-[0m [32m60.7 kB[0m [31m51.5 MB/s[0m [33m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting quantumrandom (from commplax==0.1.1)
  Downloading quantumrandom-1.9.0.tar.gz (7.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: commplax, quantumrandom
  Building wheel for commplax (setup.py) ... [?25l[?25hdone
  Created wheel for commplax: filename=commplax-0.1.1-py3-none-any.whl size=60061 sha256=0b7047f83deb4db26cd0af2b29ea4c8237c44aae6f5f8065c9c4350cc4675e8d
  Stored in directory: /tmp/pip-ephem-wheel-cache-8_d3gf6a/wheels/65/62/1c/eb4ebe204ced9b6e7bdf5d15e81d29ef1cb3e01a881ddc1a50
  Building wheel for quantumrandom (setup.py) ... [?25l[?25hdone
  Created wheel for quantumrandom: filename=quantumrandom-1.9.0-py3-none-any.whl size=9109 sha256=9c5d6

In [None]:
import jax, jax.numpy as jnp
from commplax.module import core


def _next_pow2(n: int) -> int:
    """返回 ≥ n 的最小 2^k"""
    return 1 << (n - 1).bit_length()


def conv1d_fft(scope, signal, *, taps=261, seglen=None,
               kernel_init=core.delta, debug=False):
    """FFT-based valid-mode 1-D complex convolution (no downsampling, no streaming).

    Steps:
      1. Time-reverse the kernel before convolving, so the result is
         ``y[n] = sum_j h[j] * x[n + j]`` for ``n`` in ``[0, N - taps]``.
      2. Zero-pad signal and kernel to a common FFT length.
      3. Multiply spectra, inverse-FFT, and keep ``[taps-1 : taps-1+N_out]``,
         which are exactly the valid-mode outputs.

    Args:
      scope: object providing ``param(name, init, shape, dtype)``; the kernel
        is requested as ``"kernel"`` with shape ``(taps,)`` and complex64 dtype.
      signal: ``core.Signal`` unpacked as ``(x, t_in)``. ``x`` has time along
        axis 0, e.g. ``(N,)`` or ``(N, C)``; any trailing axes are filtered
        independently with the same kernel.
      taps: kernel length, ``1 <= taps <= N``.
      seglen: optional FFT length. Defaults to the next power of two
        >= ``N + taps - 1``. Must be >= ``N``; a shorter FFT would truncate
        the input.
      kernel_init: initializer handed to ``scope.param``.
      debug: print FFT sizes and the first output sample.

    Returns:
      ``core.Signal`` with ``N - taps + 1`` samples along axis 0 and the
      correspondingly trimmed ``SigTime``.

    Raises:
      ValueError: if ``taps < 1``, ``taps > N`` or ``seglen < N``.
    """
    x, t_in = signal

    N_in = x.shape[0]
    if taps < 1:
        raise ValueError(f"taps must be >= 1, got {taps}")
    if taps > N_in:
        raise ValueError(
            f"taps ({taps}) exceeds signal length ({N_in}); "
            "valid-mode output would be empty")
    N_out = N_in - taps + 1                  # valid-mode output length

    h_time = scope.param("kernel", kernel_init, (taps,), jnp.complex64)

    # ---------- FFT length ----------
    # A circular convolution of length L reproduces every valid-mode output
    # exactly whenever L >= N_in: for output index m in [taps-1, N_in-1] the
    # kernel never wraps around. The default also covers the full
    # linear-convolution length N_in + taps - 1.
    fft_len = seglen if seglen is not None else _next_pow2(N_in + taps - 1)
    if fft_len < N_in:
        raise ValueError(
            f"seglen ({fft_len}) must be >= signal length ({N_in}); "
            "a shorter FFT truncates the input")

    if debug:
        print(f"◆ N_in={N_in}  taps={taps}  fft_len={fft_len}")

    # ---------- frequency-domain product ----------
    Xk = jnp.fft.fft(x, fft_len, axis=0)
    Hk = jnp.fft.fft(jnp.flip(h_time), fft_len)
    # Broadcast the kernel spectrum over any trailing (channel) axes of x.
    Hk = Hk.reshape((fft_len,) + (1,) * (x.ndim - 1))

    Y = jnp.fft.ifft(Xk * Hk, fft_len, axis=0)

    # ---------- keep the valid part ----------
    y_val = Y[taps - 1 : taps - 1 + N_out]

    # ---------- SigTime ----------
    # Valid mode drops taps-1 samples in total (the array above is exactly
    # taps-1 shorter than x), so the time window must shrink by taps-1 as
    # well, not 2*(taps-1). For a kernel whose unit tap sits at
    # rtap = (taps-1)//2 the output satisfies y[n] = x[n + rtap], so rtap
    # samples come off the front and the remainder off the back.
    # NOTE(review): assumes SigTime.start/stop are front/back trim offsets and
    # core.delta is centred at (taps-1)//2, as in commplax's conv1d_t — confirm.
    rtap = (taps - 1) // 2
    t_out = core.SigTime(t_in.start + rtap,
                         t_in.stop - (taps - 1 - rtap),
                         t_in.sps)
    if debug:
        print("◆ y[0] =", y_val[0])

    return core.Signal(y_val, t_out)
import jax, jax.numpy as jnp
from jax import random
from commplax.module import core

# ---- Stub scope: param() hands back a fixed kernel -------------------------
class FakeScope:
    """Test stand-in for the ``scope`` argument of ``conv1d_fft``.

    ``param`` ignores whatever name / initializer / shape / dtype it is given
    and always returns the kernel supplied at construction, so a test can
    inject a known kernel.
    """

    def __init__(self, h):
        # The kernel every param() call will return.
        self.h = h

    def param(self, *args, **kwargs):
        return self.h

# ---- Inputs: a unit impulse δ and a random complex kernel -----------------
N, TAPS = 4097, 261
δ   = (jnp.arange(N) == 0).astype(jnp.complex64)
ker = random.normal(random.PRNGKey(0), (TAPS,), jnp.complex64)

scope = FakeScope(ker)
sig   = core.Signal(δ, core.SigTime(0, 0, 1))


sig_out = conv1d_fft(scope, sig, taps=TAPS, debug=True)


def valid_ref(x, h):
    """Time-domain valid-mode reference for a 1-D ``x``.

    Convolves with the time-reversed kernel (the convention ``conv1d_fft``
    uses) and keeps the ``len(x) - len(h) + 1`` valid samples.
    """
    n_taps = h.shape[0]
    full = jnp.convolve(x, jnp.flip(h), mode='full',
                        precision=jax.lax.Precision.HIGHEST)
    return full[n_taps - 1 : n_taps - 1 + (x.shape[0] - n_taps + 1)]


def max_rel_err(ref, got):
    """Largest absolute deviation, normalised by the largest |ref| value."""
    return float(jnp.max(jnp.abs(ref - got)) / jnp.max(jnp.abs(ref)))


# ---- Impulse test ----------------------------------------------------------
ref = valid_ref(δ, ker)

err = jnp.max(jnp.abs(ref - sig_out.val))
print("max |TD − FFT| =", float(err))
assert err < 1e-6, "❌ 仍有误差，请检查！"
print("✓ unit-test passed  (误差 < 1e-6)")

# ---- Dense random input ----------------------------------------------------
# With the impulse at index 0, only the first valid sample is non-zero
# (y[0] = ker[0]), so the check above cannot see wrap-around or offset errors
# in the rest of the output. A dense input makes every valid sample depend on
# TAPS input samples.
x_rand   = random.normal(random.PRNGKey(1), (N,), jnp.complex64)
out_rand = conv1d_fft(scope, core.Signal(x_rand, core.SigTime(0, 0, 1)),
                      taps=TAPS)
ref_rand = valid_ref(x_rand, ker)
assert out_rand.val.shape == ref_rand.shape == (N - TAPS + 1,)
rel_rand = max_rel_err(ref_rand, out_rand.val)
print("dense input: max rel |TD − FFT| =", rel_rand)
assert rel_rand < 1e-4, "dense-input mismatch between FFT and time-domain conv"

# ---- Multi-channel input (N, 2): each column is filtered independently ------
x_2ch   = jnp.stack([x_rand, δ], axis=1)
out_2ch = conv1d_fft(scope, core.Signal(x_2ch, core.SigTime(0, 0, 1)),
                     taps=TAPS)
assert out_2ch.val.shape == (N - TAPS + 1, 2)
for ch, ref_ch in enumerate((ref_rand, ref)):
    rel_ch = max_rel_err(ref_ch, out_2ch.val[:, ch])
    assert rel_ch < 1e-4, f"channel {ch}: mismatch ({rel_ch:.2e})"
print("✓ dense and 2-channel checks passed")


◆ N_in=4097  taps=261  fft_len=8192
◆ y[0] = (0.7099452-1.7270769j)
max |TD − FFT| = 1.3328003944934608e-07
✓ unit-test passed  (误差 < 1e-6)
