In [1]:
import numpy as np
import scipy.fft as sp_fft

class OFtrigger:

    def __init__(self, template, noise_psd, sampling_frequency):
        self._sampling_frequency = sampling_frequency
        self._length = len(template)
        self.set_template(template)
        self.set_noise_psd(noise_psd)

    def set_template(self, template):
        self._template = template
        self._template_fft = sp_fft.rfft(template) / self._sampling_frequency
        self._update_kernel_fft()

    def set_noise_psd(self, noise_psd):
        self._noise_psd = noise_psd.copy()  # defensive copy if needed elsewhere

        # Invert all PSD bins except DC
        self._inv_psd = np.zeros_like(noise_psd)
        self._inv_psd[1:] = 1.0 / (noise_psd[1:] + 1e-30)

        # If even-length trace, Nyquist bin should not be doubled — fix it
        if self._length % 2 == 0:
            self._inv_psd[-1] = 1.0 / (noise_psd[-1] + 1e-30)  # overwrite to avoid doubling

        self._inv_psd[0] = 0.0  # explicitly set DC to 0 after everything

        self._update_kernel_fft()


    def _update_kernel_fft(self):
        if hasattr(self, '_template_fft') and hasattr(self, '_inv_psd'):
            self._kernel_fft = self._template_fft.conjugate() * self._inv_psd
            self._kernel_normalization = np.real(
                np.dot(self._kernel_fft, self._template_fft)
            ) * self._sampling_frequency / self._length

    def fit(self, trace):
        trace_fft = sp_fft.rfft(trace) / self._sampling_frequency
        trace_filtered = self._kernel_fft * trace_fft / self._kernel_normalization

        chisq0 = np.real(np.vdot(trace_fft, trace_fft * self._inv_psd)) * self._sampling_frequency / self._length

        amp0 = np.real(np.sum(trace_filtered)) * self._sampling_frequency / self._length
        chisq = (chisq0 - amp0**2 * self._kernel_normalization) / (self._length - 2)

        return amp0, chisq

    def fit_with_shift(self, trace, allowed_shift_range=None):
        trace_fft = sp_fft.rfft(trace) / self._sampling_frequency
        trace_filtered = self._kernel_fft * trace_fft / self._kernel_normalization
        trace_filtered_td = sp_fft.irfft(trace_filtered) * self._sampling_frequency
        trace_filtered_td *= 0.5

        chisq0 = np.real(np.vdot(trace_fft, trace_fft * self._inv_psd)) * self._sampling_frequency / self._length

        chit_withdelay = (trace_filtered_td ** 2) * self._kernel_normalization
        chi = chisq0 - chit_withdelay

        if allowed_shift_range is None:
            ind = np.arange(len(chi))
        else:
            ind = np.concatenate((
                np.arange(self._length + allowed_shift_range[0], self._length),
                np.arange(allowed_shift_range[1] + 1)
            ))

        best_ind = ind[np.argmin(chi[ind])]
        amp = trace_filtered_td[best_ind]
        chisq = chi[best_ind] / (self._length - 3)
        t0 = best_ind if best_ind < self._length // 2 else best_ind - self._length

        return amp, chisq, t0


In [1]:
import numpy as np
import scipy.fft as sp_fft

class OFtrigger:

    def __init__(self, template, noise_psd, sampling_frequency):
        self._sampling_frequency = sampling_frequency
        self._length = len(template)
        self.set_template(template)
        self.set_noise_psd(noise_psd)

    def set_template(self, template):
        self._template = template
        self._template_fft = sp_fft.rfft(template) / self._sampling_frequency
        self._update_kernel_fft()

    def set_noise_psd(self, noise_psd):
        self._noise_psd = noise_psd.copy()

        self._inv_psd = np.zeros_like(noise_psd)
        self._inv_psd[1:] = 1.0 / (noise_psd[1:] + 1e-30)
        if self._length % 2 == 0:
            self._inv_psd[-1] = 1.0 / (noise_psd[-1] + 1e-30)
        self._inv_psd[0] = 0.0

        self._update_kernel_fft()

    def _update_kernel_fft(self):
        if hasattr(self, '_template_fft') and hasattr(self, '_inv_psd'):
            self._kernel_fft = self._template_fft.conjugate() * self._inv_psd
            self._kernel_normalization = np.real(
                np.dot(self._kernel_fft, self._template_fft)
            ) * self._sampling_frequency / self._length

    def fit(self, trace):
        trace_fft = sp_fft.rfft(trace) / self._sampling_frequency
        trace_filtered = self._kernel_fft * trace_fft / self._kernel_normalization

        chisq0 = np.real(np.vdot(trace_fft, trace_fft * self._inv_psd)) * self._sampling_frequency / self._length
        amp0 = np.real(np.sum(trace_filtered)) * self._sampling_frequency / self._length
        chisq = (chisq0 - amp0**2 * self._kernel_normalization) / (self._length - 2)

        return amp0, chisq

    def fit_with_shift(self, trace, allowed_shift_range=None):
        trace_fft = sp_fft.rfft(trace) / self._sampling_frequency
        trace_filtered = self._kernel_fft * trace_fft / self._kernel_normalization

        # A(t0) is the inverse FFT of the filtered signal
        trace_filtered_td = sp_fft.irfft(trace_filtered) * self._sampling_frequency  # Removed *0.5

        # Compute chi^2_0 (independent of shift)
        chisq0 = np.real(np.vdot(trace_fft, trace_fft * self._inv_psd)) * self._sampling_frequency / self._length

        # Compute chi^2(t0) = chisq0 - A(t0)^2 * norm
        amp_series = trace_filtered_td
        chisq_series = chisq0 - amp_series**2 * self._kernel_normalization

        if allowed_shift_range is None:
            ind = np.arange(len(chisq_series))
        else:
            start = (self._length + allowed_shift_range[0]) % self._length
            stop = (allowed_shift_range[1] + 1) % self._length
            if start < stop:
                ind = np.arange(start, stop)
            else:
                ind = np.concatenate((np.arange(start, self._length), np.arange(0, stop)))

        best_ind = ind[np.argmin(chisq_series[ind])]
        amp = amp_series[best_ind]
        chisq = chisq_series[best_ind] / (self._length - 3)
        t0 = best_ind if best_ind < self._length // 2 else best_ind - self._length

        return amp, chisq, t0


In [5]:
import numpy as np
import time
from OF_trigger import OptimumFilter  # Replace with actual module path


sampling_frequency = 3906250

data = np.load("/ceph/dwong/delight/Ka_traces_1.npz")

template = np.load("../templates/template_K_alpha_tight.npy")
noise_psd = np.load("../templates/noise_psd_from_MMC.npy")
loaded_traces = data['data']
data.close()

# Ensure correct shape

assert len(template) == loaded_traces.shape[1], "Trace length must match template length"

# Benchmark function using fit_with_shift
def benchmark_filter_with_shift(FilterClass, template, psd, sampling_frequency, traces):
    of = FilterClass(template, psd, sampling_frequency)
    start = time.time()
    results = []
    for i in range(traces.shape[0]):
        results.append(of.fit(traces[i]))
    elapsed = time.time() - start
    return elapsed, results

# Run benchmarks
t_old, results_old = benchmark_filter_with_shift(OptimumFilter, template, noise_psd, sampling_frequency, loaded_traces)
t_new, results_new = benchmark_filter_with_shift(OFtrigger, template, noise_psd, sampling_frequency, loaded_traces)
print(f"OptimumFilter time (fit_with_shift): {t_old:.4f}s")
print(f"OFtrigger time (fit_with_shift): {t_new:.4f}s")
print(f"Speedup: {t_old / t_new:.2f}x")

OptimumFilter time (fit_with_shift): 4.7367s
OFtrigger time (fit_with_shift): 2.5996s
Speedup: 1.82x


In [6]:
results_old[10:20] 

[(10414.42191904015, 1.030801645592369),
 (10405.665579890223, 0.9785365274571997),
 (10416.386487652877, 1.0108269016491889),
 (10411.39540220257, 0.9890808943602902),
 (10395.465460254754, 1.0254867881971248),
 (10414.140305288478, 0.9932377290421555),
 (10411.829155000516, 1.032540640745023),
 (10416.330759369554, 1.0021761522403765),
 (10415.423674732101, 1.0033656299552631),
 (10416.520771292313, 1.0048662332395768)]

In [7]:
results_new[10:20]

[(10414.42109024857, 0.25776443091116386),
 (10405.666255017106, 0.24467661264363585),
 (10416.386732229108, 0.2527123004807928),
 (10411.39521288639, 0.24727356398077607),
 (10395.465671545378, 0.256375857898393),
 (10414.14057650647, 0.2483162880687405),
 (10411.829917958335, 0.2581894130600697),
 (10416.331222260224, 0.25056400808347623),
 (10415.422702388236, 0.2509295248461117),
 (10416.520472719629, 0.2512248667948544)]