In [None]:
import warnings

import astropy.units
import FunctionLib as FL
import inspect
from tqdm import tqdm
import astropy
import wave
from matplotlib.image import resample
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from collections import defaultdict
import re
import scipy

# Render all matplotlib text with a serif font.
mpl.rcParams['font.family'] = 'serif'


# Silences every Python warning for the rest of the session. NOTE(review):
# this also hides numpy RuntimeWarnings (e.g. log of a non-positive ratio,
# all-NaN slices) raised by the cells below.
warnings.filterwarnings("ignore")

# Load the spectrum catalog from a pickle file under the user's home directory.
# NOTE(review): presumably load_from_pkl unpickles the file; only load
# catalogs from a trusted source - confirm against FunctionLib.
DJAv4Catalog = FL.Spectrum_Catalog()
DJAv4Catalog.load_from_pkl(os.path.expanduser(
    '~/DustCurve/DJAv4Catalog.pkl'))
# Presumably the number of samples in the catalog - confirm against FunctionLib.
print(DJAv4Catalog.sample_num())

# Last expression of the cell, so the notebook displays whatever
# to_dataframe() returns.
DJAv4Catalog.to_dataframe()

In [None]:
def balmer_bin(balmer_decrement):
    """
    Assign a Balmer-decrement value to one of the coarse, unevenly spaced bins.

    The callers in this notebook pass ln((Halpha / Hbeta) / 2.88).

    NOTE: a later cell of this notebook defines another ``balmer_bin`` with a
    different signature (uniform bin width); once that cell has run, this
    definition is shadowed.

    Parameters:
    balmer_decrement: scalar value to classify

    Returns:
    int or None:
        -1 for values below -1
         0 for [-1, -0.5)
         1 for [-0.5, -0.25)
         2 for [-0.25, 0)
         3 for [0, 0.25)
         4 for [0.25, 0.5)
         5 for [0.5, 1)
         6 for values >= 1 (overflow bin; the original fell off the end of the
           if/elif chain and returned None implicitly)
        None for NaN, which cannot be ordered against any edge
    """
    # Exclusive upper edge of bins -1, 0, 1, ..., 5, in that order.
    upper_edges = (-1, -0.5, -0.25, 0, 0.25, 0.5, 1)

    # NaN is the only value that is not equal to itself.
    if balmer_decrement != balmer_decrement:
        return None

    for bin_index, upper_edge in enumerate(upper_edges, start=-1):
        if balmer_decrement < upper_edge:
            return bin_index

    # At or above the last edge: explicit overflow bin.
    return len(upper_edges) - 1


In [None]:
# One list of loaded spectra per Balmer bin. The numbered names are kept
# because the averaging cell further down refers to them directly.
balmer_bin0 = []
balmer_bin1 = []
balmer_bin2 = []
balmer_bin3 = []
balmer_bin4 = []
balmer_bin5 = []
# Every Balmer value that could be computed, binned or not.
balmer_number = []

# Position i holds the list for bin index i, replacing the if/elif dispatch.
spectra_by_bin = (balmer_bin0, balmer_bin1, balmer_bin2,
                  balmer_bin3, balmer_bin4, balmer_bin5)

for _entry_id, entry in DJAv4Catalog.catalog_iterator():
    properties = entry['properties']
    if not properties['Sample_Flag']:
        continue

    # Both line fits are required. Previously an entry without them fell
    # through and was binned with the `balmer` value left over from the
    # preceding entry (or raised NameError if it was the first one).
    if 'Halpha_Fit_Result' not in properties or 'Hbeta_Fit_Result' not in properties:
        continue

    halpha_intensity = properties['Halpha_Fit_Result']['integrated_flux']
    hbeta_intensity = properties['Hbeta_Fit_Result']['integrated_flux']
    # 2.88 is presumably the assumed intrinsic Halpha/Hbeta ratio - confirm.
    balmer = np.log(halpha_intensity / hbeta_intensity / 2.88)
    balmer_number.append(balmer)

    # Bins outside 0..5 (underflow, overflow, NaN) are not kept. The explicit
    # range check also stops -1 from indexing the last list.
    bin_index = balmer_bin(balmer)
    if bin_index is None or not 0 <= bin_index < len(spectra_by_bin):
        continue

    # Read the FITS file only once we know the spectrum will be kept.
    spectrum = FL.Load_Spectrum_From_Fits(entry['prism_filepath'],
                                          entry['determined_redshift'])
    spectra_by_bin[bin_index].append(spectrum)


for bin_index, bin_spectra in enumerate(spectra_by_bin):
    print(f'Balmer bin {bin_index}: {len(bin_spectra)}')
print(f'Total: {sum(len(bin_spectra) for bin_spectra in spectra_by_bin)}')



In [None]:
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

class SpectrumAverager:
    """
    Resample a list of spectra onto one wavelength grid and average them.

    Intended call order: load_spectrum_list -> create_common_wavelength_grid
    -> interpolate_all_spectra -> calculate_average_spectrum. The result can
    then be plotted, fetched as a dict, or written to a text file.
    """

    def __init__(self):
        """
        Create an averager with no spectra loaded and no results computed.
        """
        # One dict per accepted spectrum: index, spectrum, wavelength, flux.
        self.spectra_data = []
        # Shared wavelength grid (1-D array) once created.
        self.common_wavelength = None
        # One flux array on the shared grid per successfully resampled spectrum.
        self.interpolated_fluxes = []
        # Per-wavelength average and standard deviation once computed.
        self.average_flux = None
        self.std_flux = None

    def load_spectrum_list(self, spectrum_list):
        """
        Validate and store the spectra that will be averaged.

        Each object must expose ``processing_wavelengths`` and
        ``processing_flux`` of equal, non-zero length. Samples where either
        value is NaN are dropped and the remainder is sorted by wavelength.
        Spectra that fail a check are reported on stdout and skipped.

        Loading also discards any grid, interpolated fluxes and averages left
        over from a previously loaded list, so stale results are never mixed
        with the new spectra.

        Parameters:
        spectrum_list: list of spectrum objects

        Returns:
        bool: True if at least one spectrum was accepted
        """
        self.spectra_data = []
        # Everything derived from an earlier list is stale now.
        self.common_wavelength = None
        self.interpolated_fluxes = []
        self.average_flux = None
        self.std_flux = None
        successful_loads = 0

        for i, spectrum in enumerate(spectrum_list):
            try:
                # The object must carry both arrays.
                if not hasattr(spectrum, 'processing_wavelengths') or \
                   not hasattr(spectrum, 'processing_flux'):
                    print(f"警告: spectrum {i} - 缺少processing_wavelengths或processing_flux属性")
                    continue

                wl = np.array(spectrum.processing_wavelengths)
                flux = np.array(spectrum.processing_flux)

                if len(wl) == 0 or len(flux) == 0:
                    print(f"警告: spectrum {i} - 波长或通量数据为空")
                    continue

                if len(wl) != len(flux):
                    print(f"警告: spectrum {i} - 波长和通量数据长度不匹配 ({len(wl)} vs {len(flux)})")
                    continue

                # Drop samples where either coordinate is NaN; at least two
                # points are needed to interpolate.
                mask = ~(np.isnan(wl) | np.isnan(flux))
                if np.sum(mask) < 2:
                    print(f"警告: spectrum {i} - 有效数据点太少")
                    continue

                wl_clean = wl[mask]
                flux_clean = flux[mask]

                # Sort by wavelength.
                sort_idx = np.argsort(wl_clean)
                wl_sorted = wl_clean[sort_idx]
                flux_sorted = flux_clean[sort_idx]

                self.spectra_data.append({
                    'index': i,
                    'spectrum': spectrum,
                    'wavelength': wl_sorted,
                    'flux': flux_sorted
                })

                successful_loads += 1
                print(f"成功处理 spectrum {i}:")
                print(f"  波长范围: {wl_sorted.min():.2f} - {wl_sorted.max():.2f}")
                print(f"  数据点数: {len(wl_sorted)}")

            except Exception as e:
                # Best effort: one malformed spectrum must not abort the batch.
                print(f"处理失败 spectrum {i}: {e}")

        print(f"\n总共成功处理了 {successful_loads}/{len(spectrum_list)} 个光谱")
        return successful_loads > 0

    def create_common_wavelength_grid(self, num_points=None, wavelength_range=None):
        """
        Build the evenly spaced wavelength grid shared by all spectra.

        Parameters:
        num_points: number of grid points. None picks the largest number of
            samples any single spectrum has inside the range, but never fewer
            than 1000.
        wavelength_range: (min_wl, max_wl). None uses the range covered by
            every loaded spectrum (the intersection).

        Returns:
        numpy.ndarray: the grid, also stored in ``self.common_wavelength``

        Raises:
        ValueError: if no spectra are loaded, or the range is empty
        """
        if not self.spectra_data:
            raise ValueError("请先加载光谱数据")

        # Coverage of each spectrum.
        all_min_wl = [spec['wavelength'].min() for spec in self.spectra_data]
        all_max_wl = [spec['wavelength'].max() for spec in self.spectra_data]

        if wavelength_range is None:
            # Intersection of all coverages.
            min_wl = max(all_min_wl)
            max_wl = min(all_max_wl)
        else:
            min_wl, max_wl = wavelength_range

        if min_wl >= max_wl:
            raise ValueError("所有光谱的波长范围没有交集，无法进行插值平均")

        if num_points is None:
            # Match the best-sampled spectrum. The sample count is used
            # directly: the former count / range * range round trip could lose
            # one point to float truncation.
            point_counts = []
            for spec in self.spectra_data:
                in_range = (spec['wavelength'] >= min_wl) & (spec['wavelength'] <= max_wl)
                valid_points = int(np.sum(in_range))
                if valid_points > 1:
                    point_counts.append(valid_points)

            if point_counts:
                num_points = max(1000, max(point_counts))
            else:
                num_points = 1000

        self.common_wavelength = np.linspace(min_wl, max_wl, num_points)

        print(f"\n创建统一波长网格:")
        print(f"  波长范围: {min_wl:.2f} - {max_wl:.2f}")
        print(f"  插值点数: {num_points}")

        return self.common_wavelength

    def interpolate_all_spectra(self, interpolation_method='linear'):
        """
        Resample every loaded spectrum onto the common grid.

        Creates the grid with default settings if none exists yet. Grid points
        a spectrum does not cover become NaN. Spectra with fewer than two
        samples inside the grid range, or whose resampled flux is entirely
        NaN, are skipped.

        Parameters:
        interpolation_method: ``kind`` passed to scipy's interp1d
            ('linear', 'cubic', 'quadratic')

        Returns:
        bool: True if at least one spectrum was resampled
        """
        if self.common_wavelength is None:
            self.create_common_wavelength_grid()

        self.interpolated_fluxes = []
        successful_interpolations = 0

        grid_min = self.common_wavelength.min()
        grid_max = self.common_wavelength.max()

        for spec_data in self.spectra_data:
            try:
                wl = spec_data['wavelength']
                flux = spec_data['flux']

                # Coverage check only: how many samples fall inside the grid.
                in_range = (wl >= grid_min) & (wl <= grid_max)

                if np.sum(in_range) < 2:
                    print(f"警告: spectrum {spec_data['index']} 在公共波长范围内的数据点太少，跳过")
                    continue

                # Interpolate from the full spectrum. Trimming to the grid
                # range first throws away the samples that bracket the grid
                # end points, which turned the edges into NaN for every
                # spectrum not sampled exactly at those end points.
                interp_func = interp1d(wl, flux,
                                       kind=interpolation_method,
                                       bounds_error=False,
                                       fill_value=np.nan)

                interpolated_flux = interp_func(self.common_wavelength)

                if np.all(np.isnan(interpolated_flux)):
                    print(f"警告: spectrum {spec_data['index']} 插值结果全为NaN，跳过")
                    continue

                self.interpolated_fluxes.append(interpolated_flux)
                successful_interpolations += 1

            except Exception as e:
                # Best effort: report and continue with the other spectra.
                print(f"插值失败 spectrum {spec_data['index']}: {e}")

        print(f"成功插值 {successful_interpolations} 个光谱")
        return successful_interpolations > 0

    def calculate_average_spectrum(self, method='mean', ignore_nan=True):
        """
        Combine the resampled spectra into one average spectrum.

        Parameters:
        method: 'mean' or 'median'
        ignore_nan: if True, NaN entries are left out at each wavelength;
            if False, any NaN makes that wavelength NaN

        Returns:
        tuple: (common_wavelength, average_flux, std_flux)

        Raises:
        ValueError: if nothing has been interpolated yet, or ``method`` is
            not one of the supported names (previously an unknown method was
            silently ignored and left the average unset or stale)
        """
        if not self.interpolated_fluxes:
            raise ValueError("请先进行插值")

        if method not in ('mean', 'median'):
            raise ValueError(f"不支持的平均方法: {method!r} (可选: 'mean', 'median')")

        # Rows are spectra, columns are grid points.
        flux_matrix = np.array(self.interpolated_fluxes)

        if ignore_nan:
            average_func = np.nanmean if method == 'mean' else np.nanmedian
            spread_func = np.nanstd
        else:
            average_func = np.mean if method == 'mean' else np.median
            spread_func = np.std

        self.average_flux = average_func(flux_matrix, axis=0)
        # The scatter is a standard deviation for both methods.
        self.std_flux = spread_func(flux_matrix, axis=0)

        # Number of spectra contributing at each wavelength.
        valid_count = np.sum(~np.isnan(flux_matrix), axis=0) if ignore_nan else len(self.interpolated_fluxes)

        print(f"\n使用 {method} 方法计算平均光谱 (ignore_nan={ignore_nan})")
        print(f"平均每个波长点有 {np.mean(valid_count):.1f} 个有效光谱参与计算")

        return self.common_wavelength, self.average_flux, self.std_flux

    def plot_spectra(self, show_individual=True, show_average=True,
                    show_std=True, figsize=(12, 8), alpha=0.3,
                    ylim=(-1e-18, 2e-18)):
        """
        Plot the resampled spectra and/or their average.

        Parameters:
        show_individual: draw every resampled spectrum
        show_average: draw the average spectrum, if computed
        show_std: shade average +/- one standard deviation
        figsize: figure size passed to matplotlib
        alpha: opacity of the individual spectra
        ylim: (bottom, top) limits of the flux axis; None lets matplotlib
            autoscale. The default keeps the limits that used to be
            hard-coded.
        """
        plt.figure(figsize=figsize)

        if show_individual and self.interpolated_fluxes:
            for i, flux in enumerate(self.interpolated_fluxes):
                label = None
                if len(self.interpolated_fluxes) <= 5:
                    # Label individual curves only when there are few of them.
                    label = f'Spectrum {i+1}'
                plt.plot(self.common_wavelength, flux,
                        alpha=alpha, linewidth=1, label=label)

        if show_average and self.average_flux is not None:
            plt.plot(self.common_wavelength, self.average_flux,
                    'r-', linewidth=2, label='Averaged Spectrum')

            if show_std and self.std_flux is not None:
                # One-sigma band around the average.
                plt.fill_between(self.common_wavelength,
                               self.average_flux - self.std_flux,
                               self.average_flux + self.std_flux,
                               alpha=0.2, color='red', label='±1σ')

        plt.xlabel('Wavelength')
        plt.ylabel('Flux')
        if ylim is not None:
            plt.ylim(*ylim)
        plt.title(f'Spectrum Data and Averaged Spectrum ({len(self.interpolated_fluxes)} spectra)')
        plt.grid(True, alpha=0.3)

        if (len(self.interpolated_fluxes) <= 10) or show_average:
            plt.legend()

        plt.tight_layout()
        plt.show()

    def get_average_spectrum_data(self):
        """
        Return copies of the averaged spectrum arrays.

        Returns:
        dict: keys 'wavelength', 'flux' and 'std' ('std' is None when no
            standard deviation is available)

        Raises:
        ValueError: if the average has not been computed yet
        """
        if self.average_flux is None:
            raise ValueError("请先计算平均光谱")

        return {
            'wavelength': self.common_wavelength.copy(),
            'flux': self.average_flux.copy(),
            'std': self.std_flux.copy() if self.std_flux is not None else None
        }

    def save_average_spectrum(self, filename, include_std=True):
        """
        Write the averaged spectrum to a tab-separated text file.

        Parameters:
        filename: output path
        include_std: also write the standard-deviation column when available

        Raises:
        ValueError: if the average has not been computed yet
        """
        if self.average_flux is None:
            raise ValueError("请先计算平均光谱")

        if include_std and self.std_flux is not None:
            data = np.column_stack([self.common_wavelength,
                                  self.average_flux,
                                  self.std_flux])
            header = "Wavelength\tFlux\tStd"
        else:
            data = np.column_stack([self.common_wavelength,
                                  self.average_flux])
            header = "Wavelength\tFlux"

        # comments='' keeps the header line free of a leading '# '.
        np.savetxt(filename, data, delimiter='\t', header=header, comments='')
        print(f"平均光谱已保存到: {filename}")

def calculate_spectrum_average(spectrum_list,
                             num_points=None, wavelength_range=None,
                             interpolation_method='linear', average_method='mean',
                             plot=True, save_file=None):
    """
    Convenience wrapper: run the whole SpectrumAverager pipeline in one call.

    Parameters:
    spectrum_list: list of spectrum objects
    num_points: number of points on the common wavelength grid
    wavelength_range: (min_wl, max_wl) of the common grid
    interpolation_method: interpolation kind used for resampling
    average_method: how the spectra are combined
    plot: draw the spectra and their average when True
    save_file: if truthy, path the averaged spectrum is written to

    Returns:
    dict: keys 'wavelength', 'flux' and 'std'

    Raises:
    ValueError: if no spectrum could be loaded or none could be interpolated
    """
    pipeline = SpectrumAverager()

    # Step 1: load and validate the input spectra.
    loaded = pipeline.load_spectrum_list(spectrum_list)
    if not loaded:
        raise ValueError("没有成功处理任何光谱对象")

    # Step 2: common wavelength grid.
    pipeline.create_common_wavelength_grid(num_points=num_points,
                                           wavelength_range=wavelength_range)

    # Step 3: resample onto the grid.
    interpolated = pipeline.interpolate_all_spectra(
        interpolation_method=interpolation_method)
    if not interpolated:
        raise ValueError("插值失败")

    # Step 4: combine.
    pipeline.calculate_average_spectrum(method=average_method)

    # Optional outputs, in the same order as before: figure first, then file.
    if plot:
        pipeline.plot_spectra()
    if save_file:
        pipeline.save_average_spectrum(filename=save_file)

    return pipeline.get_average_spectrum_data()


In [None]:
list_list = [balmer_bin0, balmer_bin1, balmer_bin2, balmer_bin3, balmer_bin4, balmer_bin5]

# Averaged spectrum of every non-empty bin, keyed by bin index. Previously
# only the result of the last bin survived the loop.
average_by_bin = {}

for i, spectrum_list in enumerate(list_list):
    if not spectrum_list:
        # calculate_spectrum_average raises ValueError when nothing can be
        # loaded, which would abort the loop before the remaining bins run.
        print(f'Skipping Balmer bin {i}: no spectra')
        continue

    print(f'Calculating average for Balmer bin {i}')
    # One call runs load -> grid -> interpolate -> average -> plot.
    result = calculate_spectrum_average(
        spectrum_list=spectrum_list,
        interpolation_method='linear',
        average_method='mean',
        plot=True
    )
    average_by_bin[i] = result

In [None]:
def balmer_bin(balmer_decrement, bin_width=0.1, min_value=-1.0, max_value=1.0):
    """
    Assign a Balmer-decrement value to a uniform-width bin.

    Regular bin k covers [min_value + k * bin_width, min_value + (k + 1) * bin_width).

    Parameters:
    balmer_decrement: scalar value to classify
    bin_width: width of each bin, default 0.1; must be positive
    min_value: lower edge of bin 0, default -1.0
    max_value: lower edge of the overflow bin, default 1.0

    Returns:
    int: -1 for values below min_value; the overflow index (the number of
        regular bins, 20 with the defaults) for values >= max_value;
        otherwise the zero-based index of the regular bin

    Raises:
    ValueError: if bin_width is not positive, or balmer_decrement is NaN
        (NaN previously surfaced as an opaque int() conversion error)
    """
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width!r}")
    if np.isnan(balmer_decrement):
        raise ValueError("balmer_decrement is NaN and cannot be binned")

    # Quotients that should be whole numbers often land just below them in
    # floating point (0.6 / 0.1 == 5.999999999999999). Rounding to 9 decimals
    # before truncating puts edge values in the bin they belong to.
    # ceil() gives the overflow bin its own index even when the range is not
    # a whole number of bins, so it never collides with the last regular bin.
    overflow_index = int(np.ceil(round((max_value - min_value) / bin_width, 9)))

    if balmer_decrement < min_value:
        return -1  # underflow
    if balmer_decrement >= max_value:
        return overflow_index

    regular_index = int(round((balmer_decrement - min_value) / bin_width, 9))
    # A value a hair below max_value may round up to the overflow index;
    # it still belongs to the last regular bin.
    return min(regular_index, overflow_index - 1)

def create_balmer_bins(catalog_iterator, load_spectrum_func,
                      bin_width=0.1, min_value=-1.0, max_value=1.0):
    """
    Group spectra into uniform-width bins of their Balmer value.

    Only entries whose 'Sample_Flag' is truthy and that carry both an
    'Halpha_Fit_Result' and an 'Hbeta_Fit_Result' are used. The Balmer value
    is ln((Halpha / Hbeta) / 2.88), computed from the 'integrated_flux' of the
    two fits. Entries that fail at any step are reported on stdout and
    skipped, and a value is recorded only once its spectrum has been loaded
    and binned.

    Parameters:
    catalog_iterator: iterable yielding (id, catalog entry) pairs
    load_spectrum_func: callable(prism_filepath, determined_redshift) that
        returns a spectrum object
    bin_width: width of each bin
    min_value: lower edge of bin 0
    max_value: lower edge of the overflow bin

    Returns:
    dict with keys:
        'bins': {bin index: list of spectra}, non-empty bins only, in
            ascending index order
        'balmer_values': Balmer value of every binned spectrum
        'bin_width', 'min_value', 'max_value': the arguments, echoed back
        'total_count': number of binned spectra
    """
    # Ask the binning function which index it gives the overflow bin, so the
    # range labels below stay in step with it. The former `num_bins - 1` test
    # was off by one and labelled the overflow bin as a regular interval.
    overflow_idx = balmer_bin(max_value, bin_width, min_value, max_value)

    # Bins are created on first use; only non-empty ones are reported.
    balmer_bins = defaultdict(list)
    balmer_values = []

    print("开始处理光谱数据...")
    processed_count = 0

    for entry_id, catalog in catalog_iterator:
        properties = catalog['properties']

        # Keep flagged sample members only.
        if not properties['Sample_Flag']:
            continue

        # Both line fits are required.
        if ('Halpha_Fit_Result' not in properties or
            'Hbeta_Fit_Result' not in properties):
            continue

        try:
            halpha_intensity = properties['Halpha_Fit_Result']['integrated_flux']
            hbeta_intensity = properties['Hbeta_Fit_Result']['integrated_flux']

            # 2.88 is presumably the assumed intrinsic Halpha/Hbeta ratio - confirm.
            balmer = np.log(halpha_intensity / hbeta_intensity / 2.88)

            # A zero or negative flux ratio yields inf/NaN. Reject it here so
            # it neither lands in a bin nor poisons the summary statistics.
            if not np.isfinite(balmer):
                raise ValueError(f"Balmer衰减值不是有限数: {balmer}")

            bin_index = balmer_bin(balmer, bin_width, min_value, max_value)

            spectrum = load_spectrum_func(catalog['prism_filepath'],
                                        catalog['determined_redshift'])

            # Record value and spectrum together, only after every step
            # succeeded, so 'balmer_values' describes the binned spectra.
            balmer_values.append(balmer)
            balmer_bins[bin_index].append(spectrum)

            processed_count += 1
            if processed_count % 100 == 0:
                print(f"已处理 {processed_count} 个光谱...")

        except Exception as e:
            # Best effort: one bad entry must not abort the whole catalog.
            print(f"处理光谱 {entry_id} 时出错: {e}")
            continue

    # Per-bin summary table.
    print(f"\n光谱分档统计:")
    print(f"{'分档':<8} {'范围':<20} {'数量':<8}")
    print("-" * 40)

    total_spectra = 0
    valid_bins = {}

    for bin_idx in sorted(balmer_bins):
        count = len(balmer_bins[bin_idx])
        if count == 0:
            continue

        if bin_idx == -1:
            range_str = f"< {min_value:.1f}"
        elif bin_idx == overflow_idx:
            range_str = f">= {max_value:.1f}"
        else:
            bin_start = min_value + bin_idx * bin_width
            bin_end = bin_start + bin_width
            range_str = f"[{bin_start:.1f}, {bin_end:.1f})"

        print(f"Bin {bin_idx:<3} {range_str:<20} {count:<8}")
        valid_bins[bin_idx] = balmer_bins[bin_idx]
        total_spectra += count

    print("-" * 40)
    print(f"总计: {total_spectra} 个光谱")

    # Distribution of the Balmer values.
    if balmer_values:
        balmer_array = np.array(balmer_values)
        print(f"\nBalmer衰减值统计:")
        print(f"最小值: {np.min(balmer_array):.3f}")
        print(f"最大值: {np.max(balmer_array):.3f}")
        print(f"平均值: {np.mean(balmer_array):.3f}")
        print(f"标准差: {np.std(balmer_array):.3f}")

    return {
        'bins': valid_bins,
        'balmer_values': balmer_values,
        'bin_width': bin_width,
        'min_value': min_value,
        'max_value': max_value,
        'total_count': total_spectra
    }

# Usage example

# Bin the catalog in steps of 0.1 between -1.0 and 1.0. Swap in your own
# iterator and spectrum loader as needed.
result = create_balmer_bins(
    catalog_iterator=DJAv4Catalog.catalog_iterator(),
    load_spectrum_func=FL.Load_Spectrum_From_Fits,
    bin_width=0.1,
    min_value=-1.0,
    max_value=1.0
)

# Spectra of a single bin can be fetched by index. Only non-empty bins are
# present in result['bins'], so a missing index raises KeyError:
# bin_0_spectra = result['bins'][0]  # spectra of bin 0
# bin_5_spectra = result['bins'][5]  # spectra of bin 5


# Average every non-empty bin, write the result to a text file and plot it.
# A failure in one bin is reported and the loop moves on to the next bin.
for bin_idx, spectra_list in result['bins'].items():
    if not spectra_list:
        continue

    print(f"\n计算分档 {bin_idx} 的平均光谱...")

    averager = SpectrumAverager()

    if not averager.load_spectrum_list(spectra_list):
        print(f"分档 {bin_idx} 的数据加载失败")
        continue

    try:
        averager.create_common_wavelength_grid()
    except ValueError as err:
        # Raised when the spectra of this bin share no wavelength range.
        # Left uncaught, it aborted the loop and skipped all remaining bins.
        print(f"分档 {bin_idx} 的波长网格创建失败: {err}")
        continue

    if not averager.interpolate_all_spectra():
        print(f"分档 {bin_idx} 的插值失败")
        continue

    averager.calculate_average_spectrum()
    averager.save_average_spectrum(f'balmer_bin_{bin_idx}_average.txt')
    # Optional: show the figure.
    averager.plot_spectra()
