In [None]:
import os
import sys

from math import cos, pi
import matplotlib.pyplot as plt
import numpy as np

sys.path.append(os.path.join(os.path.dirname(""), os.pardir))
import coherence_analysis.utils as f

In [None]:
sampling_rate = 500  # Hz
duration = 60  # seconds
t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)

# Seeded generator so the noise on signal1 (and everything derived from it) is reproducible.
rng = np.random.default_rng(0)
signal1 = np.cos(2 * np.pi * 50 * t) + 0.005 * rng.normal(size=t.shape)

# Impulse train: one unit impulse every `impulse_period` seconds, each delayed by
# `impulse_delay` seconds (1250 samples at 500 Hz). The 2.5 s delay presumably centres
# each impulse inside the 5 s analysis windows used below — revisit if those change.
impulse_period = 5  # seconds
impulse_delay = 2.5  # seconds
impulse_times = [impulse_period * i for i in range(int(duration // impulse_period))]
impulse_indices = (np.asarray(impulse_times) * sampling_rate).astype(int) + int(
    impulse_delay * sampling_rate
)
signal2 = np.zeros(t.shape)
signal2[impulse_indices] = 1

In [None]:
# Impulse train over the full record; the label is only shown once a legend is drawn.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(t, signal2, label="Signal 2 (Impulse train)", alpha=0.7)
ax.set(title="Signal 2", xlabel="Time (s)", ylabel="Amplitude")
ax.legend();

In [None]:
# Zoom on the first 50 samples (0.1 s at 500 Hz) to see how the 50 Hz cosine is sampled.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(t[:50], signal1[:50], "-o")
ax.set(title="Signal 1", xlabel="Time (s)", ylabel="Amplitude");

In [None]:
# One-sided spectrum of the full signal1 record; d is the sample spacing in seconds,
# so the frequency axis is in Hz.
spectra = np.fft.rfft(signal1)
frequencies = np.fft.rfftfreq(len(signal1), d=1 / sampling_rate)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(frequencies, np.abs(spectra))
ax.set(title="Amplitude spectrum of signal 1", xlabel="Frequency (Hz)", ylabel="|FFT|");

In [None]:
# One-sided spectrum of the impulse train (frequency axis in Hz).
spectra = np.fft.rfft(signal2)
frequencies = np.fft.rfftfreq(len(signal2), d=1 / sampling_rate)

# rfft returns complex values: plot the magnitude. Passing the complex array straight
# to plot makes matplotlib silently drop the imaginary part (ComplexWarning).
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(frequencies, np.abs(spectra))
ax.set(title="Amplitude spectrum of signal 2", xlabel="Frequency (Hz)", ylabel="|FFT|");

In [None]:
# Three identical channels built from the impulse train (signal2).
# To analyse the cosine instead, build this from signal1 with the same np.tile call.
data = np.tile(signal2, (3, 1))
data

In [None]:
# Windowed spectra of every row of `data`. The positional arguments 5 and 0 are
# presumably the window length (s) and overlap — TODO confirm against
# coherence_analysis.utils.windowed_spectra.
win_spectra, frequencies = f.windowed_spectra(
    data,
    5,
    0,
    sample_interval=1 / sampling_rate,
)

In [None]:
# (channels, samples) of the tiled input.
np.shape(data)

In [None]:
# Shape of the windowed spectra returned by f.windowed_spectra.
np.shape(win_spectra)

In [None]:
# Magnitude of one slice of win_spectra. NOTE(review): presumably indexed
# (window, channel, frequency) — TODO confirm against f.windowed_spectra.
# Frequencies assumed in Hz since sample_interval is given in seconds.
fig, ax = plt.subplots()
ax.plot(frequencies, np.abs(win_spectra[4, 0, :]))
ax.set(title="|win_spectra[4, 0, :]|", xlabel="Frequency (Hz)", ylabel="Magnitude");

In [None]:
# Normalised windowed spectra with the same (presumed) window length 5 and overlap 0
# as the windowed_spectra call — TODO confirm argument meaning in coherence_analysis.utils.
normalized_spectra, frequencies = f.normalised_windowed_spectra(
    data,
    5,
    0,
    sample_interval=1 / sampling_rate,
)

In [None]:
# Shape of the normalised spectra (axis order may differ from win_spectra; see plots).
np.shape(normalized_spectra)

In [None]:
# Magnitude of one slice of normalized_spectra. NOTE(review): the slice order
# [:, 0, 4] differs from win_spectra[4, 0, :], suggesting (frequency, channel, window)
# layout here — TODO confirm against f.normalised_windowed_spectra.
fig, ax = plt.subplots()
ax.plot(frequencies, np.abs(normalized_spectra[:, 0, 4]))
ax.set(title="|normalized_spectra[:, 0, 4]|", xlabel="Frequency (Hz)", ylabel="Magnitude");

In [None]:
# A unit impulse has a flat |DFT| of 1, so every windowed magnitude should be ~1
# (assuming windowed_spectra applies no taper — TODO confirm).
bool(np.isclose(np.absolute(win_spectra), 1).all())

In [None]:
# 12 = number of 5 s windows in the 60 s record; presumably the normalisation divides
# each unit-magnitude bin by sqrt(sum over windows of |X|^2) = sqrt(12) — TODO confirm.
bool(np.isclose(np.absolute(normalized_spectra), 1 / np.sqrt(12)).all())

In [None]:
# Number of frequency bins returned by f.normalised_windowed_spectra; compare with
# len(np.fft.rfftfreq(2500, d=1 / sampling_rate)) for a 5 s window at 500 Hz.
len(frequencies)

In [None]:
# Expected bin count for one window: 2500 samples = 5 s at 500 Hz -> 2500 // 2 + 1 bins.
np.fft.rfftfreq(2500, d=1 / sampling_rate).size

In [None]:
# Integer "seconds" 0..59, then every 5th one (the window start times).
tr = np.arange(60)
tr[::5]

In [None]:
# Boolean mask keeping values strictly between 5 and 55.
select = np.logical_and(tr > 5, tr < 55)
tr[select]

In [None]:
import numpy as np
from scipy.stats import norm


def sequential_ci(
    experiment_fn,
    epsilon=0.01,
    alpha=0.05,
    min_samples=20,
    max_samples=1_000_000,
    batch=10,
):
    """
    Sequential CI-width stopping for estimating the mean of a stochastic process.

    Samples are drawn from ``experiment_fn`` until the normal-approximation
    confidence interval ``mean +/- z * s / sqrt(n)`` (``s`` = sample standard
    deviation with ``ddof=1``) has half-width ``<= epsilon``, or until
    ``max_samples`` samples have been drawn.

    Parameters
    ----------
    experiment_fn : callable
        Zero-argument function that returns one scalar sample from the
        stochastic experiment.
    epsilon : float
        Desired CI half-width tolerance.
    alpha : float
        Significance level in (0, 1) (e.g., 0.05 -> 95% CI).
    min_samples : int
        Number of samples drawn before the stopping rule is first checked.
        These are always drawn, even if ``min_samples > max_samples``.
    max_samples : int
        Hard cap on total samples after the initial phase; the last batch is
        shortened so the cap is never overshot.
    batch : int
        Maximum number of new samples drawn between stopping-rule checks
        (must be >= 1).

    Returns
    -------
    mean_est : float
        Final estimate of the mean (NaN if no samples were drawn).
    ci_half_width : float
        Final half-width of the confidence interval (NaN if fewer than two
        samples were drawn).
    n : int
        Number of samples used.
    samples : ndarray
        The full sample history.

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1) or ``batch < 1``.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be in (0, 1), got {alpha!r}")
    if batch < 1:
        # With batch <= 0 no new samples are ever drawn, so the loop would never end.
        raise ValueError(f"batch must be >= 1, got {batch!r}")

    z = norm.ppf(1 - alpha / 2)
    samples = []
    # Welford's running mean and sum of squared deviations (m2): each stopping check
    # is O(1) instead of re-scanning the whole history (up to max_samples values).
    mean_est = 0.0
    m2 = 0.0

    def _draw(count):
        # Draw `count` samples, appending to the history and updating mean_est / m2.
        nonlocal mean_est, m2
        for _ in range(count):
            x = experiment_fn()
            samples.append(x)
            delta = x - mean_est
            mean_est += delta / len(samples)
            m2 += delta * (x - mean_est)

    # initial sampling
    _draw(min_samples)

    while True:
        n = len(samples)
        if n >= 2:
            ci_half_width = z * np.sqrt(m2 / (n - 1)) / np.sqrt(n)
        else:
            # The ddof=1 standard deviation is undefined for fewer than two samples.
            ci_half_width = np.nan

        if ci_half_width <= epsilon or n >= max_samples:
            return (mean_est if n else np.nan), ci_half_width, n, np.array(samples)

        # n < max_samples here, so at least one sample is drawn and the cap is respected.
        _draw(min(batch, max_samples - n))

In [None]:
import numpy as np


# Define stochastic experiment
def experiment():
    """Draw one sample of sin(2*pi*U) + 0.1*N, with U ~ Uniform[0, 1), N ~ Normal(0, 1).

    Uses NumPy's legacy global RNG (np.random.rand, then np.random.randn), so results
    depend on np.random.seed. The true mean of this process is 0.
    """
    phase = 2 * np.pi * np.random.rand()
    noise = np.random.randn()
    return np.sin(phase) + 0.1 * noise


# experiment() draws from NumPy's legacy global RNG, so seed it for a reproducible run.
np.random.seed(0)

mean_est, ci_hw, n, samples = sequential_ci(
    experiment_fn=experiment, epsilon=0.1, alpha=0.05, min_samples=30, batch=20
)

# Analytically E[sin(2*pi*U)] = 0 for U ~ Uniform[0, 1), so the CI should usually cover 0.
print("Estimated mean:", mean_est)
print("CI half-width:", ci_hw)
print("Samples used:", n)