In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import glob
import os

In [20]:
# -------------------------------
# 1️⃣ 读取谱线列表文件
# -------------------------------
def read_line_list(filepath):
    lines, freqs, mark07, mark09, mark10 = [], [], [], [], []
    with open(filepath, 'r') as f:
        
        # 跳过第一行
        next(f)
        
        for line in f:
            if line.strip() == '':
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            lines.append(parts[0])
            freqs.append(float(parts[1]))
            
            if len(parts) >= 5:
                mark07.append(int(parts[2]))
                mark09.append(int(parts[3]))
                mark10.append(int(parts[4]))
            else:
                mark07.append(0)
                mark09.append(0)
                mark10.append(0)
                
    return lines, np.array(freqs), mark07, mark09, mark10

# -------------------------------
# 2️⃣ 多文件光谱数据集类
# -------------------------------
class MultiSourceSpectrumDataset(Dataset):
    def __init__(self, data_dir, line_file, window=0.03, target_len=128,
                 normalize=True, velocity_dict=None):
        """
        初始化多源光谱数据集
        
        参数:
        - data_dir: 包含所有光谱文件的目录
        - line_file: 谱线列表文件路径
        - window: 提取窗口大小
        - target_len: 目标子光谱长度
        - normalize: 是否归一化
        - velocity_dict: 字典，键为源名称，值为速度修正值 (km/s)
        """
        self.line_names, self.line_rest, self.mark07, self.mark09, self.mark10 = read_line_list(line_file)
        self.line_name_to_index = {name: i for i, name in enumerate(self.line_names)}
        print(len(self.line_names))
        print(len(self.line_rest))
        
        self.marks_dict = {
        'Lh07': self.mark07,
        'Lh09': self.mark09,
        'Lh10': self.mark10,}
            
        self.c = 299792458  # m/s
        self.window = window
        self.target_len = target_len
        self.normalize = normalize
        
        # 如果没有提供速度字典，创建空字典
        if velocity_dict is None:
            velocity_dict = {}
        
        # 存储所有样本和对应的信息
        self.samples = []        # 光谱数据
        self.full_names = []     # 完整标识符，格式: "源-频率窗口-谱线名称"
        self.source_info = []    # 源信息，格式: (源名称, 频率窗口, 谱线名称)
        self.scales=[]           # 初始化缩放因子列表
        self.marks=[]            # 谱线标签（该频率是否存在信号）
        
        # 查找所有光谱文件
        # 假设文件命名格式: "spectrum.源名称.频率窗口.commonbeam.1arcsec.dat"
        pattern = os.path.join(data_dir, "spectrum.*.spw*.commonbeam.1arcsec.dat")
        spectrum_files = glob.glob(pattern)
        
        if not spectrum_files:
            print(f"警告: 在目录 {data_dir} 中没有找到匹配的光谱文件")
            return
        
        print(f"找到 {len(spectrum_files)} 个光谱文件")
        
        # 处理每个光谱文件
        for spectrum_file in spectrum_files:
            # 从文件名中提取源名称和频率窗口
            filename = os.path.basename(spectrum_file)
            parts = filename.split('.')
            if len(parts) < 3:
                print(f"跳过无法解析的文件: {filename}")
                continue
                
            source_name = parts[1]  # 例如 "Lh07"
            spw = parts[2]          # 例如 "spw0"
            
            print(f"处理文件: {filename}, 源: {source_name}, 频率窗口: {spw}")
            
            # 获取该源的速度修正值
            velocity = velocity_dict.get(source_name, None)
            
            # 处理单个光谱文件
            self._process_single_file(spectrum_file, source_name, spw, velocity)
            

        print(f"总共提取了 {len(self.samples)} 个谱线样本")

    
    def _process_single_file(self, spectrum_file, source_name, spw, velocity=None):
        """处理单个光谱文件"""
#         try:
            
        # 读取主光谱，跳过第一行
        data = np.loadtxt(spectrum_file, skiprows=1)
        freq_obs, flux = data[:, 0], data[:, 1]

        # 速度校正
        if velocity is not None:
            v = velocity * 1e3  # km/s → m/s
            freq = freq_obs / (1 - v / self.c)
            print(f"  [INFO] 应用速度修正: v = {velocity:.3f} km/s")
        else:
            freq = freq_obs

        print("开始提取子光谱")

        print(f"总共有{len(self.line_names)}条光谱")
        print("------------------------------")
        i=0
        number=0
        # 对每条谱线提取子光谱
        for line_name, f_rest,mark07, mark09, mark10 in zip(self.line_names, self.line_rest,
                                                           self.mark07, self.mark09, self.mark10):
            mask = (freq > f_rest - self.window) & (freq < f_rest + self.window)
            i+=1
            print(i)
            print(f"检查第{i}条曲线，为{line_name},静止频率为{f_rest}")
            if np.sum(mask) < 8:
                continue  # 跳过数据点太少的谱线

            f_sub = freq[mask]
            flux_sub = flux[mask]
            print(f"提取第{i}条曲线，为{line_name},静止频率为{f_rest}")
            number+=1
#                 print(line_name)
#                 plt.plot(f_sub,flux_sub)
#                 plt.show()

            # 插值到目标长度
            f_new = np.linspace(f_sub.min(), f_sub.max(), self.target_len)
            flux_new = np.interp(f_new, f_sub, flux_sub)

            # 归一化 (可选)
            if self.normalize:
                #flux_new = (flux_new - np.mean(flux_new)) / (np.std(flux_new) + 1e-6)
                scale = max(np.max(np.abs(flux_new)), 1e-6)
                flux_norm = flux_new / scale
            else:
                scale=1.0
                flux_norm=flux_new


        
            # 创建完整标识符
            full_name = f"{source_name}-{spw}-{line_name}"

            # 保存样本和相关信息
            self.samples.append(torch.tensor(flux_norm, dtype=torch.float32).unsqueeze(0))
            self.full_names.append(full_name)
            self.source_info.append((source_name, spw, line_name))
            self.scales.append(scale)  # 保存了缩放因子
            
            # 保存谱线标签（是否存在信号）
            if source_name=="Lh07":
                self.marks.append(mark07)
            elif source_name=="Lh09":
                self.marks.append(mark09)
            elif source_name=="Lh10":
                self.marks.append(self.mark10)

        print(f"这个文件共提取了{number}条谱线")

#     except Exception as e:
#         print(f"处理文件 {spectrum_file} 时出错: {e}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx], self.full_names[idx], self.scales[idx]
    
    def get_mark_by_line(self, idx=None, source_name=None, line_name=None):

    #获取特定谱线的标记（是否存在信号）

        # 如果提供了索引，获取源名称和谱线名称
        if idx is not None:
            if idx >= len(self.source_info):
                print(f"警告: 索引 {idx} 超出范围")
                return None
            source_name, _, line_name = self.source_info[idx]

        # 验证参数
        if source_name is None or line_name is None:
            print("错误: 必须提供 source_name 和 line_name，或者提供 idx")
            return None

        # 检查源是否有对应的标记数据
        if source_name not in self.marks_dict:  # 使用 marks_dict
            print(f"警告: 源 {source_name} 没有标记数据")
            return None

        # 获取谱线索引
        if line_name not in self.line_name_to_index:
            print(f"警告: 谱线 {line_name} 不存在")
            return None

        line_idx = self.line_name_to_index[line_name]

        # 获取标记值
        marks_list = self.marks_dict[source_name]  # 从 marks_dict 获取
        return marks_list[line_idx]
    
    def get_source_info(self, idx):
        """获取指定索引样本的源信息"""
        return self.source_info[idx]
    
    def get_samples_by_source(self, source_name):
        """获取特定源的所有样本索引"""
        return [i for i, info in enumerate(self.source_info) if info[0] == source_name]
    
    def get_samples_by_spw(self, spw):
        """获取特定频率窗口的所有样本索引"""
        return [i for i, info in enumerate(self.source_info) if info[1] == spw]
    
    def get_samples_by_line(self, line_name):
        """获取特定谱线的所有样本索引"""
        return [i for i, info in enumerate(self.source_info) if info[2] == line_name]


In [21]:
def collate_fn(batch):
    """处理批次数据，包括缩放因子"""
    xs = torch.stack([b[0] for b in batch])  # 光谱数据
    names = [b[1] for b in batch]            # 名称列表
    
    scales_list = [b[2] for b in batch]
    scales = torch.tensor(scales_list, dtype=torch.float32)
    scales = scales.view(-1, 1, 1)  # 重塑为 [batch_size, 1, 1]
    
    return xs, names, scales

In [22]:
if __name__ == "__main__":
    # 数据目录和谱线列表
    data_dir = "C:\\Users\\zyx\\Desktop\\Spectral with Machine Learning\\data\\manysource"
    line_file = "C:\\Users\\zyx\\Desktop\\Spectral with Machine Learning\\data\\linelist_with_mark.txt"
    
    # 定义各源的速度修正值 (km/s)
    velocity_dict = {
        "Lh07": 239.5157166,
        "Lh09": 235.751358,  
        "Lh10": 251.2273865,  
        # 添加更多源和对应的速度...
    }
    
    # 创建多源数据集
    dataset = MultiSourceSpectrumDataset(
        data_dir=data_dir,
        line_file=line_file,
        velocity_dict=velocity_dict,
        target_len=256,
        normalize=True
    )
    

65
65
找到 15 个光谱文件
处理文件: spectrum.Lh07.spw0.commonbeam.1arcsec.dat, 源: Lh07, 频率窗口: spw0
  [INFO] 应用速度修正: v = 239.516 km/s
开始提取子光谱
总共有65条光谱
------------------------------
1
检查第1条曲线，为SO2,静止频率为345.3385391
2
检查第2条曲线，为SO2,静止频率为345.3387862
3
检查第3条曲线，为SO2,静止频率为345.4489815
4
检查第4条曲线，为SO2,静止频率为346.5238776
5
检查第5条曲线，为SO2,静止频率为346.6521672
6
检查第6条曲线，为34SO2,静止频率为344.2453476
提取第6条曲线，为34SO2,静止频率为344.2453476
7
检查第7条曲线，为34SO2,静止频率为344.581045
提取第7条曲线，为34SO2,静止频率为344.581045
8
检查第8条曲线，为34SO2,静止频率为344.8079157
提取第8条曲线，为34SO2,静止频率为344.8079157
9
检查第9条曲线，为34SO2,静止频率为344.9875851
提取第9条曲线，为34SO2,静止频率为344.9875851
10
检查第10条曲线，为34SO2,静止频率为344.9981616
提取第10条曲线，为34SO2,静止频率为344.9981616
11
检查第11条曲线，为34SO2,静止频率为345.168666
12
检查第12条曲线，为34SO2,静止频率为345.2856217
13
检查第13条曲线，为34SO2,静止频率为345.5196584
14
检查第14条曲线，为34SO2,静止频率为345.5530949
15
检查第15条曲线，为34SO2,静止频率为345.6512957
16
检查第16条曲线，为34SO2,静止频率为345.6787895
17
检查第17条曲线，为34SO2,静止频率为345.9292848
18
检查第18条曲线，为SO,静止频率为344.310612
提取第18条曲线，为SO,静止频率为344.310612
19
检查第19条曲线，为SO,静止频率为346.528

In [23]:
# 新建一个类保存上面的dataset
class create_data_Results:
    def __init__(self):
        self.dataset = dataset

In [28]:
mark1=dataset.marks
list07=mark1[0:67]
positions = [i for i, value in enumerate(list07) if value == 1]
print(len(positions))
print(positions)
print(list07)
print(len(list07))

39
[0, 1, 3, 5, 6, 9, 10, 11, 17, 21, 22, 24, 28, 29, 31, 32, 33, 34, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 55, 56, 58, 59, 60, 62, 63, 64, 65]
[1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0]
67


In [34]:
# print("=== 所有样本的静止频率 ===")

# # 创建列表存储所有样本的频率和源信息
# sample_info_list = []

# for i in range(len(dataset)):
#     source_name, spw, line_name = dataset.get_source_info(i)
    
#     # 获取静止频率
#     if line_name in dataset.line_name_to_index:
#         line_idx = dataset.line_name_to_index[line_name]
#         rest_freq = dataset.line_rest[line_idx]
        
#         sample_info_list.append({
#             'index': i,
#             'source': source_name,
#             'spw': spw,
#             'line': line_name,
#             'rest_freq': rest_freq
#         })
        
#         # 打印每个样本的信息
#         print(f"样本{i:3d}: {source_name}-{spw}-{line_name:10s}, 静止频率: {rest_freq:.6f}")

# # 按频率排序查看
# print("\n=== 按静止频率排序 ===")
# sorted_samples = sorted(sample_info_list, key=lambda x: x['rest_freq'])

# for sample in sorted_samples:
#     print(f"频率 {sample['rest_freq']:.6f}: {sample['source']}-{sample['spw']}-{sample['line']} (样本{sample['index']})")

# # 找出重复的频率
# print("\n=== 找出重复的静止频率 ===")
# freq_dict = {}
# for sample in sample_info_list:
#     freq = sample['rest_freq']
#     if freq not in freq_dict:
#         freq_dict[freq] = []
#     freq_dict[freq].append(sample)

# # 只显示重复出现的频率
# for freq, samples in sorted(freq_dict.items()):
#     if len(samples) > 1:
#         print(f"\n静止频率 {freq:.6f} 出现在 {len(samples)} 个样本中:")
#         for sample in samples:
#             print(f"  - {sample['source']}-{sample['spw']}-{sample['line']}")

In [32]:
# print("=== 验证每个源的样本数 ===")

# for source in ['Lh07', 'Lh09', 'Lh10']:
#     # 获取该源所有样本的spw分布
#     spw_counts = {}
#     for i, info in enumerate(dataset.source_info):
#         if info[0] == source:
#             spw = info[1]
#             spw_counts[spw] = spw_counts.get(spw, 0) + 1
    
#     total = sum(spw_counts.values())
#     print(f"{source}: 总计 {total} 个样本")
    
#     for spw in sorted(spw_counts.keys()):
#         print(f"  {spw}: {spw_counts[spw]} 个样本")
    
#     # 验证是否等于67
#     if total != 67:
#         print(f"  注意: {source} 的样本数 {total} 不等于67")

In [33]:
# print(dataset[0])
# print(len(dataset))