In [None]:
#!/usr/bin/env python3
"""
使用MPI4Py并行处理光谱文件分析
"""

from mpi4py import MPI
import numpy as np
import os
import glob
import pickle
import time
from pathlib import Path

# 假设这是你的光谱分析类
class SpectrumLineFitter:
    """Placeholder spectrum analyzer -- replace ``analyze_spectrum`` with your own logic."""

    def __init__(self):
        # No fitting parameters are configured yet; add them here.
        pass

    def analyze_spectrum(self, spectrum_file):
        """Analyze a single spectrum file and return a result record.

        The returned dict always has ``filename`` and ``status`` keys. On
        success it also carries ``analysis_data``; on failure it carries
        ``error`` with the exception message. Exceptions are captured rather
        than raised, so one bad file does not stop the whole batch.

        The body is example code -- substitute your actual implementation.
        """
        try:
            # Read and analyze the spectrum file here.
            analysis = 'your_analysis_results_here'
        except Exception as exc:
            return {
                'filename': spectrum_file,
                'status': 'error',
                'error': str(exc),
            }
        else:
            return {
                'filename': spectrum_file,
                'status': 'success',
                'analysis_data': analysis,
            }

def get_spectrum_files(directory_path, pattern="*.fits"):
    """Return the paths in ``directory_path`` that match ``pattern``, sorted.

    Args:
        directory_path: Directory to search.
        pattern: Glob pattern applied inside ``directory_path`` (e.g. ``"*.fits"``,
            ``"*.dat"``).

    Returns:
        A sorted list of matching paths, each joined onto ``directory_path``.
        The list is empty when nothing matches.
    """
    # glob.glob yields entries in arbitrary, filesystem-dependent order.
    # Sorting makes the split of files across MPI ranks, and the order of the
    # saved results, reproducible from run to run.
    return sorted(glob.glob(os.path.join(directory_path, pattern)))

def chunk_list(lst, n):
    """Split ``lst`` into ``n`` consecutive slices whose lengths differ by at most one.

    The first ``len(lst) % n`` slices get one extra element. Slices may be
    empty when ``n`` exceeds ``len(lst)``. ``n == 0`` raises
    ``ZeroDivisionError``.
    """
    base, extra = divmod(len(lst), n)
    chunks = []
    start = 0
    for idx in range(n):
        # Slices with index below ``extra`` absorb one leftover element each.
        stop = start + base + (1 if idx < extra else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks

def _analyze_files(fitter, files, rank):
    """Run ``fitter.analyze_spectrum`` on every path in ``files`` and return the results.

    An exception escaping the analyzer is turned into an error record instead
    of propagating. If one rank died here, the other ranks would block forever
    in the collective ``gather`` that follows in ``main``.
    """
    results = []
    for i, spectrum_file in enumerate(files):
        # Only rank 0 reports progress, to keep the combined output readable.
        if rank == 0 and i % 100 == 0:
            print(f"进程 {rank}: 已处理 {i}/{len(files)} 个文件")

        try:
            result = fitter.analyze_spectrum(spectrum_file)
        except Exception as e:  # must not escape before the collective gather
            result = {
                'filename': spectrum_file,
                'status': 'error',
                'error': str(e),
            }
        results.append(result)
    return results


def _report_and_save(final_results, output_directory, total_time):
    """Print run statistics, pickle all results, and write an error log if any failed.

    Intended to run on rank 0 only, after all results have been gathered.
    """
    print(f"\n并行处理完成!")
    print(f"总处理时间: {total_time:.2f} 秒")
    print(f"处理文件数量: {len(final_results)}")
    if final_results:  # avoid ZeroDivisionError when there are no results
        print(f"平均每个文件处理时间: {total_time/len(final_results):.4f} 秒")

    # Count successes and failures.
    success_count = sum(1 for r in final_results if r.get('status') == 'success')
    error_count = len(final_results) - success_count

    print(f"成功处理: {success_count} 个文件")
    print(f"处理失败: {error_count} 个文件")

    # Save the full result list.
    output_file = os.path.join(output_directory, "spectrum_analysis_results.pkl")
    with open(output_file, 'wb') as f:
        pickle.dump(final_results, f)
    print(f"结果已保存到: {output_file}")

    # If anything failed, also write a plain-text error log. UTF-8 is explicit
    # because file names and messages may contain non-ASCII text, and the
    # platform default encoding may not be able to represent it.
    if error_count > 0:
        error_file = os.path.join(output_directory, "error_log.txt")
        with open(error_file, 'w', encoding='utf-8') as f:
            for result in final_results:
                if result.get('status') == 'error':
                    f.write(f"{result['filename']}: {result.get('error', '')}\n")
        print(f"错误日志已保存到: {error_file}")


def main():
    """Distribute spectrum files across MPI ranks, analyze them, and save the results.

    Rank 0 finds the input files, splits them into one chunk per rank and
    scatters them. Every rank, including 0, analyzes its own chunk. The
    results are gathered back on rank 0, which prints a summary and writes
    them under ``output_directory``.
    """
    # Initialize MPI.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Configuration -- adjust to your environment.
    spectrum_directory = "/path/to/your/spectrum/files"  # directory holding the spectra
    output_directory = "/path/to/output"  # where results are written
    file_pattern = "*.fits"  # file glob pattern, e.g. "*.dat", "*.txt"

    start_time = time.time()

    file_chunks = None  # only rank 0 fills this in; other ranks pass None to scatter
    if rank == 0:
        print(f"开始使用 {size} 个进程进行并行光谱分析...")

        # Rank 0 alone discovers the input files.
        spectrum_files = get_spectrum_files(spectrum_directory, file_pattern)
        total_files = len(spectrum_files)

        if total_files == 0:
            print(f"在目录 {spectrum_directory} 中没有找到匹配模式 {file_pattern} 的文件")
            # Abort with a non-zero code so the launcher/scheduler sees a failure
            # (Abort's default errorcode is 0). Aborting the whole communicator
            # is needed because the other ranks may already be blocked in the
            # scatter below.
            comm.Abort(1)

        print(f"找到 {total_files} 个光谱文件")

        # One chunk per rank; chunk sizes differ by at most one.
        file_chunks = chunk_list(spectrum_files, size)

        # Make sure the output directory exists (only rank 0 writes to it).
        Path(output_directory).mkdir(parents=True, exist_ok=True)

    # Scatter (not broadcast): rank r receives only file_chunks[r].
    my_files = comm.scatter(file_chunks, root=0)

    if rank == 0:
        print(f"文件分配完成，每个进程分配到 {[len(chunk) for chunk in file_chunks]} 个文件")

    # Each rank processes the files assigned to it.
    fitter = SpectrumLineFitter()
    print(f"进程 {rank}: 开始处理 {len(my_files)} 个文件...")
    local_results = _analyze_files(fitter, my_files, rank)
    print(f"进程 {rank}: 完成处理 {len(my_files)} 个文件")

    # Collective call: every rank must reach it, otherwise the job hangs.
    all_results = comm.gather(local_results, root=0)

    if rank == 0:
        # gather returns one list per rank; flatten them into a single list.
        final_results = [r for process_results in all_results for r in process_results]
        total_time = time.time() - start_time
        _report_and_save(final_results, output_directory, total_time)

if __name__ == "__main__":
    main()