In [None]:
# generate_dataset_natural_same_freq_final.py
import numpy as np
import random
import os
from scipy import signal as sp_signal

# ---------------- Basic parameters ----------------
# Sampling rate in Hz; the Nyquist limit is fs / 2 = 5 MHz.
fs = 10e6
# Number of samples in every generated frame.
L = 1024
# Sampling period in seconds.
T = 1 / fs
# Shared time axis (seconds) sliced by every signal generator below.
t = np.arange(L) * T

jnr_levels = np.arange(-10, 35, 5)      # -10 dB ~ 30 dB in 5 dB steps
# Number of frames generated per (interference type, JNR level) pair.
num_samples = 1000
noise_power_dB = 10                     # fixed background-noise power in dB (linear 10 ** (dB / 10))

# Class list; the list index becomes the integer label (type2label in generate_dataset).
# "satellite_signal" is the interference-free class.
interference_types = [
    "satellite_signal",
    "single_tone", "comb_spectra", "sweeping", "pulse",
    "frequency_hopping", "noise_fm",
    "noise_am", "random_combination"
]

# Human-readable display names keyed by interference type; saved with the dataset.
interference_type_names = {
    "satellite_signal": "Satellite_Signal",
    "single_tone": "Single_Tone",
    "comb_spectra": "Comb_Spectra",
    "sweeping": "Sweeping-LFM",
    "pulse": "Pulse",
    "frequency_hopping": "Frequency_Hopping",
    "noise_fm": "Noise_FM",
    "noise_am": "Noise_AM",
    "random_combination": "Random_Combination"
}

# Satellite frequency-hopping parameters (original values restored).
# hop_sequence: carrier frequencies in Hz, one per satellite-signal segment.
# NOTE(review): hop_rate / samples_per_hop (100 Hz -> 100000 samples, far more
# than L) are not used by generate_satellite_signal, which hops every L // 5
# samples instead.
hop_sequence = [1e6, 2e6, 3e6, 4e6, 4.9e6]
hop_rate = 100
samples_per_hop = int(fs / hop_rate)

# Random active-interval parameters (original values restored): bounds, in
# samples, on the width of the interval drawn by _rand_interval.
MIN_WIDTH = 600
MAX_WIDTH = 1024

# ---------------- 工具函数（恢复原版实现） ----------------
def generate_bpsk_signal(length=L, samples_per_symbol=10):
    """Generate a rectangular-pulse BPSK sequence of +1/-1 values.

    Random bits are mapped to symbols (bit 0 -> +1, bit 1 -> -1) and each
    symbol is held for ``samples_per_symbol`` consecutive samples; the last
    symbol is truncated so the output has exactly ``length`` samples.

    Args:
        length: Number of output samples. A value <= 0 yields an empty array.
        samples_per_symbol: Samples per BPSK symbol (default 10, as before).

    Returns:
        Integer ndarray of shape ``(length,)`` with values in {+1, -1}.
    """
    if length <= 0:
        return np.zeros(0, dtype=int)
    # Ceil division so the symbols always cover the whole length. The old
    # floor division drew zero symbols for length < 10 and then crashed on
    # ``symbols[-1]`` while padding.
    num_symbols = -(-length // samples_per_symbol)
    bits = np.random.randint(0, 2, num_symbols)
    symbols = 1 - 2 * bits
    return np.repeat(symbols, samples_per_symbol)[:length]

def generate_satellite_signal():
    """Build the wanted satellite signal: BPSK hopped over five carriers.

    The L-sample frame is split into five consecutive segments of ``L // 5``
    samples, with the last segment absorbing the remainder. Segment k carries
    an independent BPSK sequence multiplied by a sine at ``hop_sequence[k]``,
    evaluated on the shared time axis ``t``.

    Returns:
        (signal, info): ``signal`` has shape (L,); ``info`` is
        ``{"has_interference": False, "jnr_db": -inf}``.
    """
    n_hops = 5
    base_len = L // n_hops
    out = np.zeros(L)
    for hop in range(n_hops):
        lo = hop * base_len
        # The final segment runs to the end of the frame.
        hi = lo + base_len if hop < n_hops - 1 else L
        chips = generate_bpsk_signal(hi - lo)
        out[lo:hi] = chips * np.sin(2 * np.pi * hop_sequence[hop] * t[lo:hi])
    # The clean satellite signal contains no interference by construction.
    info = {"has_interference": False, "jnr_db": -np.inf}
    return out, info

def add_noise(signal, noise_power_dB):
    """Add zero-mean white Gaussian noise of a fixed absolute power.

    Args:
        signal: Real-valued array; noise of the same shape is added to it.
        noise_power_dB: Noise power in dB (linear power
            ``10 ** (noise_power_dB / 10)``), independent of the signal power.

    Returns:
        ``signal + noise`` with the same shape as ``signal``.
    """
    noise_power_linear = 10 ** (noise_power_dB / 10)
    noise_std = np.sqrt(noise_power_linear)
    # Size the noise from the input rather than the global frame length L,
    # so signals of any length/shape are supported.
    return signal + noise_std * np.random.randn(*np.shape(signal))

def adjust_jnr(interference, target_jnr_db, noise_power_dB):
    """Rescale an interference waveform so its mean power hits a target JNR.

    The target power is ``10**(target_jnr_db/10)`` times the linear noise
    power ``10**(noise_power_dB/10)``. A 1e-12 epsilon keeps the division
    finite for an all-zero input, which therefore stays all-zero.

    Returns:
        ``interference`` multiplied by a single scalar gain.
    """
    jnr_lin = 10 ** (target_jnr_db / 10)
    noise_lin = 10 ** (noise_power_dB / 10)
    current = np.mean(interference ** 2) + 1e-12
    gain = np.sqrt(jnr_lin * noise_lin / current)
    return interference * gain

def _rand_interval():
    """Draw a random active interval [start, end) inside the L-sample frame.

    The width is uniform on [MIN_WIDTH, min(MAX_WIDTH, L)], and the start is
    uniform over every offset that keeps the interval inside the frame.

    Returns:
        (start, end) sample indices with ``end - start`` equal to the width.
    """
    upper = min(MAX_WIDTH, L)
    width = np.random.randint(MIN_WIDTH, upper + 1)
    begin = np.random.randint(0, L - width + 1)
    return begin, begin + width

# ---------------- 同频干扰检测函数（恢复原版完整实现） ----------------
def detect_same_frequency(satellite_signal, interference_signal, fs, tolerance=20e3):
    """Decide whether an interferer shares a dominant frequency with the satellite.

    Both signals are reduced to their dominant spectral lines with
    ``extract_main_frequencies``. The pair counts as "same frequency" when any
    satellite line lies within ``tolerance`` Hz of any interference line.

    Returns:
        True if at least one pair of dominant frequencies is within tolerance.
    """
    sat_lines = extract_main_frequencies(satellite_signal, fs)
    jam_lines = extract_main_frequencies(interference_signal, fs)
    return any(
        abs(sat_f - jam_f) <= tolerance
        for sat_f in sat_lines
        for jam_f in jam_lines
    )

def extract_main_frequencies(signal, fs, num_peaks=5, rel_threshold=0.1):
    """Return the dominant positive frequencies of a real signal.

    The magnitude spectrum is restricted to the positive-frequency bins
    1 .. len(signal)//2 - 1. The DC bin is excluded, and the upper limit is
    the same as before. Of those bins, the ``num_peaks`` strongest are kept
    when their magnitude exceeds ``rel_threshold`` times the strongest one.

    Args:
        signal: 1-D real-valued array.
        fs: Sampling rate in Hz.
        num_peaks: Maximum number of peaks returned (default 5).
        rel_threshold: Relative magnitude threshold (default 0.1, i.e. 10 %).

    Returns:
        1-D array of frequencies in Hz, ordered by increasing magnitude. It is
        empty for signals too short to have a non-DC bin, for an all-zero
        spectrum, or when ``num_peaks <= 0``.
    """
    n = len(signal)
    if num_peaks <= 0 or n // 2 <= 1:
        return np.empty(0)
    # rfft computes only the non-negative half, which is all that is inspected.
    # Slicing from 1 skips the DC bin; the old slice kept it, so a signal with
    # a DC offset reported 0 Hz as its main frequency.
    magnitude = np.abs(np.fft.rfft(signal))[1:n // 2]
    freqs = np.fft.rfftfreq(n, 1 / fs)[1:n // 2]
    peak_indices = np.argsort(magnitude)[-num_peaks:]
    keep = magnitude[peak_indices] > np.max(magnitude) * rel_threshold
    return freqs[peak_indices][keep]

# ---------------- 干扰信号生成函数（保留v3的API一致性） ----------------
def generate_single_tone_interference():
    """Single-tone (CW) jammer active on a random interval.

    A sine with random frequency (0.8-5.2 MHz), amplitude (0.5-1.5) and phase
    fills ``[start, end)`` from ``_rand_interval``; the rest of the frame is zero.

    NOTE(review): with fs = 10 MHz, tones above fs/2 = 5 MHz alias, so the
    recorded "frequency" can differ from the spectral line actually present.

    Returns:
        (interference, params): start/end times are normalised to [0, 1] by the
        frame duration, and ``jnr_db`` is a 0.0 placeholder.
    """
    s0, s1 = _rand_interval()
    tone_freq = np.random.uniform(0.8e6, 5.2e6)
    gain = np.random.uniform(0.5, 1.5)
    phi = np.random.rand() * 2 * np.pi

    out = np.zeros(L)
    out[s0:s1] = gain * np.sin(2 * np.pi * tone_freq * t[s0:s1] + phi)

    frame_ms = L / fs * 1e3
    meta = {
        "start_time": (s0 / fs * 1e3) / frame_ms,
        "end_time": (s1 / fs * 1e3) / frame_ms,
        "frequency": tone_freq,
        "amplitude": gain,
        "jnr_db": 0.0,
    }
    return out, meta

def generate_pulse_interference():
    """Single RF pulse with random width, position, frequency, amplitude and phase.

    The pulse lasts 25-100 samples at a uniformly random position inside the
    frame; every sample outside the pulse is zero.

    NOTE(review): carrier frequencies above fs/2 = 5 MHz alias.

    Returns:
        (interference, params): start/end are normalised to [0, 1] by the frame
        duration, ``pulse_width`` is in samples, and ``jnr_db`` is a 0.0 placeholder.
    """
    width = np.random.randint(25, 101)
    p0 = np.random.randint(0, L - width + 1)
    p1 = p0 + width
    f_c = np.random.uniform(0.8e6, 5.2e6)
    gain = np.random.uniform(0.5, 1.5)
    phi = np.random.rand() * 2 * np.pi

    burst = np.zeros(L)
    burst[p0:p1] = gain * np.sin(2 * np.pi * f_c * t[p0:p1] + phi)

    frame_ms = L / fs * 1e3
    meta = {
        "start_time": (p0 / fs * 1e3) / frame_ms,
        "end_time": (p1 / fs * 1e3) / frame_ms,
        "pulse_width": width,
        "frequency": f_c,
        "amplitude": gain,
        "jnr_db": 0.0,
    }
    return burst, meta

def generate_noise_am_interference():
    """Noise-AM jammer: a sine carrier whose envelope follows Gaussian noise.

    The noise is min-max normalised to [0, 1] and mapped to the envelope
    ``(1 - depth) + depth * noise``, which spans [1 - depth, 1]. The jammer is
    active only on a random interval.

    NOTE(review): carrier frequencies above fs/2 = 5 MHz alias.

    Returns:
        (interference, params): "modulation_depth" is reported in percent, and
        start/end are normalised to [0, 1] by the frame duration.
    """
    s0, s1 = _rand_interval()
    f_c = np.random.uniform(0.8e6, 5.2e6)
    depth = np.random.uniform(0.3, 1.0)
    gain = np.random.uniform(0.5, 1.5)
    raw = np.random.randn(s1 - s0)

    lo, hi = raw.min(), raw.max()
    unit = (raw - lo) / (hi - lo)
    envelope = (1 - depth) + depth * unit
    tone = gain * np.sin(2 * np.pi * f_c * t[s0:s1])

    out = np.zeros(L)
    out[s0:s1] = envelope * tone

    frame_ms = L / fs * 1e3
    meta = {
        "start_time": (s0 / fs * 1e3) / frame_ms,
        "end_time": (s1 / fs * 1e3) / frame_ms,
        "carrier_frequency": f_c,
        "modulation_depth": depth * 100,
        "amplitude": gain,
        "jnr_db": 0.0,
    }
    return out, meta

def generate_comb_spectra_interference():
    """Comb-spectrum jammer: 2-3 equal-amplitude tones spaced 1 MHz apart.

    The base tone is drawn from 0.5-4.5 MHz. Each further tooth is 1 MHz
    higher and clamped to at most 5.2 MHz. All teeth share one amplitude and
    zero phase, and the jammer is active only on a random interval.

    NOTE(review): when several teeth exceed 5.2 MHz they all collapse onto
    5.2 MHz and add coherently, so "num_teeth" can overstate the number of
    distinct lines; 5.2 MHz is also above fs/2 = 5 MHz and aliases.

    Returns:
        (interference, params): start/end are normalised to [0, 1] by the
        frame duration, and ``jnr_db`` is a 0.0 placeholder.
    """
    s0, s1 = _rand_interval()
    f_base = np.random.uniform(0.5e6, 4.5e6)
    n_teeth = np.random.randint(2, 4)
    gain = np.random.uniform(0.5, 1.5)

    tooth_freqs = [min(f_base + k * 1.0e6, 5.2e6) for k in range(n_teeth)]
    seg_t = t[s0:s1]
    out = np.zeros(L)
    for f_k in tooth_freqs:
        out[s0:s1] += gain * np.sin(2 * np.pi * f_k * seg_t)

    frame_ms = L / fs * 1e3
    meta = {
        "start_time": (s0 / fs * 1e3) / frame_ms,
        "end_time": (s1 / fs * 1e3) / frame_ms,
        "base_frequency": f_base,
        "num_teeth": n_teeth,
        "amplitude": gain,
        "jnr_db": 0.0,
    }
    return out, meta

def generate_sweeping_interference():
    """Linear-FM (sweep) jammer active on a random interval.

    The instantaneous frequency ramps linearly from ``f0`` (0.5-2.5 MHz)
    toward ``f1`` (2.5-5.2 MHz) across the active interval. The phase is the
    running sum of ``2*pi*f_inst*T``.

    NOTE(review): end frequencies above fs/2 = 5 MHz alias.

    Returns:
        (interference, params): start/end are normalised to [0, 1] by the frame
        duration, and ``jnr_db`` is a 0.0 placeholder.
    """
    start, end = _rand_interval()
    f0 = np.random.uniform(0.5e6, 2.5e6)
    f1 = np.random.uniform(2.5e6, 5.2e6)
    amp = np.random.uniform(0.5, 1.5)
    n = end - start
    sweep_rate = (f1 - f0) / (n * T)  # Hz per second
    # Bug fix: the ramp must advance with elapsed time (k * T), not with the
    # bare sample index k. The old expression drove the instantaneous
    # frequency up to about (f1 - f0) / T instead of f1.
    inst_freq = f0 + sweep_rate * (np.arange(n) * T)
    phase = 2 * np.pi * np.cumsum(inst_freq) * T
    intf = np.zeros(L)
    intf[start:end] = amp * np.sin(phase)

    # Time parameters normalised to [0, 1] by the frame duration.
    total_duration_ms = L / fs * 1e3
    params = {
        "start_time": (start / fs * 1e3) / total_duration_ms,
        "end_time":   (end   / fs * 1e3) / total_duration_ms,
        "start_freq": f0,
        "end_freq": f1,
        "amplitude": amp,
        "jnr_db": 0.0  # placeholder; set by the caller
    }
    return intf, params

def generate_frequency_hopping_interference(hop_rate_range=(20e3, 60e3)):
    """Frequency-hopping jammer active on a random interval.

    Cycles through a fixed hop set, starting at its first entry. Each
    frequency is held for ``int(fs / hop_rate)`` samples, and the phase
    restarts at every hop.

    Args:
        hop_rate_range: (low, high) bounds, in hops per second, for the
            uniformly drawn hop rate. The old hard-coded 1-3 kHz gave at least
            3333 samples per hop at fs = 10 MHz. That is longer than the whole
            L-sample frame, so the jammer never hopped and was always a 1.2 MHz
            tone. The default 20-60 kHz gives about 166-500 samples per hop, so
            several hops fall inside the active interval. Pass (1e3, 3e3) to
            reproduce the old behaviour.

    NOTE(review): 5.1 MHz is above fs/2 = 5 MHz and aliases.

    Returns:
        (interference, params): "hop_rate" is in hops/s, start/end are
        normalised to [0, 1] by the frame duration, and ``jnr_db`` is a
        0.0 placeholder.
    """
    start, end = _rand_interval()
    hop_seq = [1.2e6, 2.2e6, 3.2e6, 4.2e6, 5.1e6]
    hop_rate = np.random.uniform(*hop_rate_range)
    # At least one sample per hop, otherwise the loop below never advances.
    sph = max(1, int(fs / hop_rate))
    amp = np.random.uniform(0.5, 1.5)
    intf = np.zeros(L)
    pos, hop_idx = start, 0
    while pos < end:
        seg_len = min(sph, end - pos)
        freq = hop_seq[hop_idx % len(hop_seq)]
        # Local time axis: the phase restarts at zero on each hop.
        t_seg = t[pos:pos + seg_len] - t[pos]
        intf[pos:pos + seg_len] = amp * np.sin(2 * np.pi * freq * t_seg)
        pos += seg_len
        hop_idx += 1

    total_duration_ms = L / fs * 1e3
    params = {
        "start_time": (start / fs * 1e3) / total_duration_ms,
        "end_time":   (end   / fs * 1e3) / total_duration_ms,
        "hop_rate": hop_rate,
        "hop_sequence": hop_seq,
        "amplitude": amp,
        "jnr_db": 0.0  # placeholder; set by the caller
    }
    return intf, params

def generate_noise_fm_interference():
    """Noise-FM jammer: a sine whose instantaneous frequency follows Gaussian noise.

    The noise is peak-normalised to [-1, 1]. The instantaneous frequency is
    ``center * (1 + index * noise)``, i.e. a deviation of up to ``index``
    (0.1-0.5) times the centre frequency. The phase is the running sum of
    ``2*pi*f_inst*T``.

    NOTE(review): with a centre of up to 5.2 MHz and an index of up to 0.5,
    the instantaneous frequency can reach 7.8 MHz, well above fs/2 = 5 MHz.

    Returns:
        (interference, params): start/end are normalised to [0, 1] by the
        frame duration, and ``jnr_db`` is a 0.0 placeholder.
    """
    s0, s1 = _rand_interval()
    f_center = np.random.uniform(0.8e6, 5.2e6)
    beta = np.random.uniform(0.1, 0.5)
    gain = np.random.uniform(0.5, 1.5)
    wobble = np.random.randn(s1 - s0)
    wobble /= (np.max(np.abs(wobble)) + 1e-12)

    f_track = f_center * (1 + beta * wobble)
    theta = 2 * np.pi * np.cumsum(f_track) * T
    out = np.zeros(L)
    out[s0:s1] = gain * np.sin(theta)

    frame_ms = L / fs * 1e3
    meta = {
        "start_time": (s0 / fs * 1e3) / frame_ms,
        "end_time": (s1 / fs * 1e3) / frame_ms,
        "center_frequency": f_center,
        "modulation_index": beta,
        "amplitude": gain,
        "jnr_db": 0.0,
    }
    return out, meta

def generate_random_combination_interference():
    """Superimpose 2-4 distinct, randomly chosen interference types.

    Each selected ``generate_<type>_interference`` output is divided by
    ``(std + 1e-12) * sqrt(num)`` before summing, so every component ends up
    with a standard deviation of (almost exactly) ``1 / sqrt(num)``.

    Returns:
        (interference, params): params holds the per-component params, the
        earliest start and latest end time across components, the peak
        absolute amplitude of the sum, and a 0.0 ``jnr_db`` placeholder.
    """
    num = np.random.randint(2, 5)  # 2..4 components (upper bound exclusive)
    candidates = ["single_tone", "comb_spectra", "sweeping", "pulse",
                  "frequency_hopping", "noise_fm", "noise_am"]
    types_selected = np.random.choice(candidates, size=num, replace=False)
    intf = np.zeros(L)
    components = []
    for raw_tp in types_selected:
        # np.random.choice yields numpy.str_; store a plain str in the metadata.
        tp = str(raw_tp)
        sub_intf, sub_par = globals()[f"generate_{tp}_interference"]()
        sub_intf /= (np.std(sub_intf) + 1e-12) * np.sqrt(num)
        intf += sub_intf
        components.append({"type": tp, "params": sub_par})

    params = {
        "start_time": min(c["params"]["start_time"] for c in components),
        "end_time":   max(c["params"]["end_time"]   for c in components),
        "jnr_db": 0.0,
        "components": components,
        "total_amplitude": float(np.max(np.abs(intf))),
    }
    return intf, params

# ---------------- 主生成函数 ----------------
def _generate_one_sample(itype, jnr_db):
    """Produce one noisy frame of class ``itype`` at interference level ``jnr_db``.

    Returns:
        (noisy_signal, params, is_same_frequency). For "satellite_signal" no
        interferer is added, ``params["jnr_db"]`` is -inf and the flag is False.
    """
    # The wanted satellite signal is present in every sample.
    sig_sat, _ = generate_satellite_signal()

    if itype == "satellite_signal":
        sig_total = sig_sat
        # Include "is_same_frequency" here too so every metadata record has
        # the same keys as the interference records.
        params = {"has_interference": False, "jnr_db": -np.inf,
                  "is_same_frequency": False}
        is_same_frequency = False
    else:
        intf, params = globals()[f"generate_{itype}_interference"]()
        # The same-frequency check runs on the interferer before power scaling.
        is_same_frequency = detect_same_frequency(sig_sat, intf, fs)
        intf = adjust_jnr(intf, jnr_db, noise_power_dB)
        sig_total = sig_sat + intf
        params["jnr_db"] = float(jnr_db)
        params["is_same_frequency"] = is_same_frequency

    # Background noise at a fixed absolute power is added to every sample.
    return add_noise(sig_total, noise_power_dB), params, is_same_frequency


def generate_dataset(output_path="dataset/interference_signals_natural_same_freq_1019.npz"):
    """Generate, label and save the complete interference dataset.

    For every type in ``interference_types``, every level in ``jnr_levels`` and
    ``num_samples`` repetitions, one noisy L-sample frame is generated. It is
    stored with its integer label, JNR, same-frequency flag and a metadata
    record. Everything is written with ``np.savez``. The dict/list entries
    become object arrays, so ``np.load`` will need ``allow_pickle=True``.

    Args:
        output_path: Destination .npz file; its parent directory is created if
            needed. Note that ``np.savez`` appends ".npz" when it is missing.

    Returns:
        True once the dataset has been written.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    total_samples = len(interference_types) * len(jnr_levels) * num_samples
    signals  = np.zeros((total_samples, L))
    labels   = np.zeros(total_samples, dtype=int)
    jnr_vals = np.zeros(total_samples)
    same_freq_flags = np.zeros(total_samples, dtype=bool)
    metadata = []
    type2label = {k: i for i, k in enumerate(interference_types)}

    print("🚀 开始生成干扰信号数据集...")
    print(f"📊 总样本数: {total_samples}")
    print(f"🎯 JNR范围: {jnr_levels[0]}dB 到 {jnr_levels[-1]}dB")
    print(f"🔊 基底噪声: {noise_power_dB}dB")
    print("=" * 50)

    idx = 0
    same_freq_count = 0
    # Counted directly instead of assuming exactly one interference-free type.
    num_interference_samples = 0
    for itype in interference_types:
        print(f"🔧 生成 {itype} 信号...")
        for jnr_db in jnr_levels:
            for sample_idx in range(num_samples):
                noisy_sig, params, is_same_frequency = _generate_one_sample(itype, jnr_db)
                if itype != "satellite_signal":
                    num_interference_samples += 1
                    if is_same_frequency:
                        same_freq_count += 1

                signals[idx]  = noisy_sig
                labels[idx]   = type2label[itype]
                jnr_vals[idx] = jnr_db
                same_freq_flags[idx] = is_same_frequency
                metadata.append({
                    "type": itype,
                    "jnr_db": jnr_db,
                    "noise_power_db": noise_power_dB,
                    "is_same_frequency": is_same_frequency,
                    "params": params
                })
                idx += 1

                # Progress report every 200 samples.
                if (sample_idx + 1) % 200 == 0:
                    same_freq_status = "同频" if is_same_frequency else "非同频"
                    print(f"  进度: {sample_idx + 1}/{num_samples} samples at JNR={jnr_db}dB [{same_freq_status}]")

    # Share of interference samples flagged as same-frequency (0 when none).
    same_freq_ratio = (same_freq_count / num_interference_samples
                       if num_interference_samples else 0.0)
    print(f"\n📈 同频干扰统计: {same_freq_count}/{num_interference_samples} ({same_freq_ratio:.1%})")

    np.savez(output_path,
             signals=signals,
             labels=labels,
             jnr_values=jnr_vals,
             same_frequency_flags=same_freq_flags,
             type_to_label=type2label,
             interference_type_names=interference_type_names,
             fs=fs,
             L=L,
             metadata=metadata,
             noise_power_db=noise_power_dB)

    print("=" * 50)
    print("✅ 数据集生成完成！")
    print(f"📁 保存位置: {output_path}")
    print(f"📊 总样本数: {total_samples}")
    print(f"🎯 JNR范围: {jnr_levels[0]}dB 到 {jnr_levels[-1]}dB")
    print(f"🔊 基底噪声: {noise_power_dB}dB")
    print(f"📡 同频干扰比例: {same_freq_ratio:.1%}")

    return True

# Script entry point: build and save the full dataset when run directly.
if __name__ == "__main__":
    generate_dataset()