In [1]:
import os
import re
import math
import seaborn as sns
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, RobustScaler, OneHotEncoder
from sklearn.metrics import mean_squared_error, r2_score
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchsummary import summary
from torch.cuda.amp import autocast, GradScaler

## 去噪

In [None]:
# ========== Data path and loading configuration ==========
# Directory containing the *.h5 waveform files. Set the GW_DATA_DIR environment
# variable to use another location without editing the notebook; the original
# hard-coded path remains the default.
data_folder = os.environ.get("GW_DATA_DIR", "D:/data/waveforms2")
# SNR values to load, written exactly as they appear in the file names
target_snr = ["50.00", "100.00", "200.00"]

# ========== Custom Dataset ==========
class GWDataset(Dataset):
    """Serve one sample per ``*_SNR<value>.h5`` file found in a folder.

    At construction the folder is scanned (in sorted order, so sample indices
    are stable across runs), files whose SNR tag is not in ``snr_list`` are
    ignored, and every remaining file is opened once to check that it has a
    ``Data`` entry carrying the attributes listed in ``REQUIRED_ATTRS``.

    Args:
        folder_path: Directory to scan.
        snr_list: SNR tags (strings such as ``"50.00"``) to keep. Defaults to
            ``["50.00", "100.00", "200.00"]``.
        transform: Optional callable applied to the sample dict before it is
            returned from ``__getitem__``.
        fallback_len: Length of the all-zero placeholder waveforms returned
            when a file cannot be read in ``__getitem__``.

    Raises:
        FileNotFoundError: ``folder_path`` is not a directory.
        ValueError: No file passed the filters and validation.
    """

    REQUIRED_ATTRS = ("mc_true", "phis_true", "thetas_true")
    REQUIRED_DATASETS = ("white_Data", "white_signal")

    def __init__(self, folder_path, snr_list=None, transform=None, fallback_len=6184):
        self.folder_path = folder_path
        self.transform = transform
        self.snr_list = snr_list or ["50.00", "100.00", "200.00"]
        self.fallback_len = fallback_len

        if not os.path.isdir(folder_path):
            raise FileNotFoundError(f"数据目录不存在: {folder_path}")

        self.file_index = []  # list of (full_path, snr_tag)
        snr_pattern = re.compile(r"_SNR(\d+\.\d+)\.h5")

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and therefore any seeded split) reproducible.
        for fname in sorted(os.listdir(folder_path)):
            if not fname.endswith(".h5"):
                continue

            full_path = os.path.join(folder_path, fname)

            # Skip directories and other non-file entries
            if not os.path.isfile(full_path):
                print(f"跳过非文件项: {full_path}")
                continue

            # Skip files we are not allowed to read
            if not os.access(full_path, os.R_OK):
                print(f"警告: 文件不可读，跳过: {full_path}")
                continue

            match = snr_pattern.search(fname)
            if not match:
                continue
            snr = match.group(1)
            if snr not in self.snr_list:
                continue

            if self._is_valid_file(full_path):
                self.file_index.append((full_path, snr))

        if not self.file_index:
            raise ValueError("没有匹配到任何指定 SNR 的文件，请检查路径或 snr_list 设置")
        print(f"成功加载 {len(self.file_index)} 个文件")

    def _is_valid_file(self, full_path):
        """Open ``full_path`` once and report whether it has the expected layout."""
        try:
            with h5py.File(full_path, "r") as f:
                if "Data" not in f:
                    print(f"警告: 文件缺少'Data'组，跳过: {full_path}")
                    return False

                data_group = f["Data"]
                if not all(attr in data_group.attrs for attr in self.REQUIRED_ATTRS):
                    print(f"警告: 文件缺少必要属性，跳过: {full_path}")
                    return False

                # Group-format files are read by dataset name in __getitem__;
                # catching a missing dataset here avoids a KeyError mid-epoch.
                if isinstance(data_group, h5py.Group) and not all(
                    name in data_group for name in self.REQUIRED_DATASETS
                ):
                    print(f"警告: 文件缺少 white_Data/white_signal 数据集，跳过: {full_path}")
                    return False
            return True
        except OSError as e:
            print(f"文件打开错误 {full_path}: {str(e)}，跳过")
        except Exception as e:
            print(f"处理文件时出错 {full_path}: {str(e)}，跳过")
        return False

    def __len__(self):
        return len(self.file_index)

    def __getitem__(self, idx):
        """Return a dict with the noisy/clean waveforms, labels and metadata for file ``idx``."""
        file_path, snr = self.file_index[idx]

        try:
            with h5py.File(file_path, "r") as f:
                data_group = f["Data"]

                if isinstance(data_group, h5py.Group):
                    white_data = torch.tensor(data_group["white_Data"][:], dtype=torch.float32)
                    white_signal = torch.tensor(data_group["white_signal"][:], dtype=torch.float32)
                elif isinstance(data_group, h5py.Dataset):
                    # Legacy layout: row 0 is the noisy data, row 1 the clean signal
                    white_data = torch.tensor(data_group[0].flatten(), dtype=torch.float32)
                    white_signal = torch.tensor(data_group[1].flatten(), dtype=torch.float32)
                else:
                    raise ValueError(f"未知的数据结构: {file_path}")

                attrs = {k: data_group.attrs[k] for k in data_group.attrs}

        except OSError as e:
            # Best effort: a file that became unreadable after indexing yields an
            # all-zero placeholder sample (with zero labels) instead of aborting.
            print(f"读取文件错误 {file_path}: {str(e)}")
            white_data = torch.zeros(self.fallback_len, dtype=torch.float32)
            white_signal = torch.zeros(self.fallback_len, dtype=torch.float32)
            attrs = {
                "mc_true": 0.0,
                "phis_true": 0.0,
                "thetas_true": 0.0
            }
        except Exception as e:
            print(f"处理文件时发生意外错误 {file_path}: {str(e)}")
            raise  # bare raise keeps the original traceback

        sample = {
            "white_data": white_data,
            "white_signal": white_signal,
            "attributes": attrs,
            "mc_true": attrs.get("mc_true", 0.0),
            "phis_true": attrs.get("phis_true", 0.0),
            "thetas_true": attrs.get("thetas_true", 0.0),
            "snr": snr,
            "filename": os.path.basename(file_path)
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

# ========== Build the dataset object ==========
# Initialisation errors are reported and then re-raised. Substituting an empty
# placeholder dataset would only postpone the failure: the following cells index
# dataset[0] and build a shuffling DataLoader, both of which need real samples.
try:
    dataset = GWDataset(data_folder, snr_list=target_snr)
except Exception as e:
    print(f"数据集初始化失败: {str(e)}")
    raise
print(f"共加载样本数量: {len(dataset)}")

In [None]:
# Quick sanity check: tensor shapes of the first sample
sample = dataset[0]
for key in ("white_data", "white_signal"):
    print(f"{key} shape:", sample[key].shape)

In [None]:
# Build the region mask: 1 = weak signal before the strong segment (the default),
# 0 = strong segment, 2 = weak signal after the strong segment.
def generate_mask(data: torch.Tensor, threshold_factor=2.0):
    """Label every sample of each waveform as leading-weak (1), strong (0) or trailing-weak (2).

    For each row, the strong segment spans from the first to the last sample
    whose absolute value exceeds ``threshold_factor`` times that row's standard
    deviation. Rows with no sample above the threshold stay all 1.

    Args:
        data: 2-D tensor ``(B, L)``.
        threshold_factor: Multiplier applied to the per-row standard deviation.

    Returns:
        ``torch.long`` tensor with the same shape and device as ``data``.
    """
    num_rows, _ = data.shape  # also enforces a 2-D input
    mask = torch.ones_like(data, dtype=torch.long)
    thresholds = threshold_factor * torch.std(data, dim=1, keepdim=True)  # (B, 1)
    above = data.abs() > thresholds                                       # (B, L)

    for row in range(num_rows):
        # flatten() keeps a 1-D index vector even when exactly one sample is above
        # the threshold; squeeze() would collapse it to 0-dim and break [0] / [-1].
        strong_indices = torch.nonzero(above[row]).flatten()
        if strong_indices.numel() == 0:
            continue

        start = strong_indices[0].item()
        end = strong_indices[-1].item() + 1
        mask[row, start:end] = 0  # strong segment
        mask[row, end:] = 2       # trailing weak segment

    return mask

In [None]:
# Robust standardisation of a batch
def standardize_batch(data: torch.Tensor, signal: torch.Tensor, mask: torch.Tensor, amplification=10.0):
    """Amplify the weak part of ``signal`` and scale both tensors by median/IQR.

    The median and inter-quartile range are computed over every element of
    ``data`` whose mask value is 1 (pooled across the whole batch), mimicking
    ``RobustScaler``. If the batch contains no such element, the statistics
    are taken over all of ``data`` instead.

    Args:
        data: Noisy input; same shape as ``mask``.
        signal: Clean target; same shape as ``mask``. Not modified in place.
        mask: Integer mask; elements equal to 1 mark the weak region.
        amplification: Factor applied to ``signal`` inside the weak region
            before standardisation.

    Returns:
        ``(data_std, signal_std, stats)`` where ``stats`` is
        ``{"median": float, "iqr": float}``.
    """
    weak_indices = (mask == 1)

    # Amplify the weak part of the target on a copy
    signal_amplified = signal.clone()
    signal_amplified[weak_indices] *= amplification

    weak_values = data[weak_indices].reshape(-1)
    if weak_values.numel() == 0:
        # No weak region in this batch: fall back to whole-batch statistics
        # rather than failing on the median of an empty tensor.
        weak_values = data.reshape(-1)

    count = weak_values.numel()
    median = weak_values.median()
    # kthvalue is 1-indexed; clamp so that fewer than 4 values never asks for k = 0
    k_q1 = max(1, int(count * 0.25))
    k_q3 = max(1, int(count * 0.75))
    q1 = weak_values.kthvalue(k_q1)[0]
    q3 = weak_values.kthvalue(k_q3)[0]
    iqr = q3 - q1 + 1e-8  # avoid division by zero

    # (x - median) / IQR
    data_std = (data - median) / iqr
    signal_std = (signal_amplified - median) / iqr

    stats = {"median": median.item(), "iqr": iqr.item()}
    return data_std, signal_std, stats

In [None]:
# Dataset split (train / validation / test)
def split_dataset(dataset, train_ratio=0.2, val_ratio=0.2, seed=None):
    """Randomly split ``dataset`` into train, validation and test subsets.

    The test subset receives whatever is left after the train and validation
    shares, i.e. 60 % with the default ratios.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``.
        train_ratio: Fraction of samples for training.
        val_ratio: Fraction of samples for validation.
        seed: Optional integer; when given, the split is reproducible.

    Returns:
        The three subsets produced by ``random_split``, in the order
        train, validation, test.

    Raises:
        ValueError: A ratio is negative or the two ratios sum to more than 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            f"train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got {train_ratio} and {val_ratio}"
        )

    total_len = len(dataset)
    train_len = int(total_len * train_ratio)
    val_len = int(total_len * val_ratio)
    test_len = total_len - train_len - val_len
    lengths = [train_len, val_len, test_len]

    if seed is None:
        return random_split(dataset, lengths)
    return random_split(dataset, lengths, generator=torch.Generator().manual_seed(seed))

# Split the GWDataset instance (`dataset`) and wrap each part in a DataLoader
train_set, val_set, test_set = split_dataset(dataset)

# Pinned host memory only matters for host-to-GPU copies; requesting it on a
# CPU-only machine just produces a warning, so enable it only when CUDA exists.
_loader_kwargs = dict(batch_size=16, num_workers=0, pin_memory=torch.cuda.is_available())

train_loader = DataLoader(train_set, shuffle=True,  **_loader_kwargs)
val_loader   = DataLoader(val_set,   shuffle=False, **_loader_kwargs)
test_loader  = DataLoader(test_set,  shuffle=False, **_loader_kwargs)

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution 1-D residual block (kernel sizes 7 then 3) with LeakyReLU.

    When the stride is not 1 or the channel count changes, the identity path
    is replaced by a 1x1 convolution followed by batch norm so both branches
    have matching shapes before they are summed.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        widened = self.expansion * out_channels

        # Main branch
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=stride, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.LeakyReLU = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Shortcut branch: identity unless the shape changes
        needs_projection = stride != 1 or in_channels != widened
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(widened),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = self.LeakyReLU(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return self.LeakyReLU(main)

class ResNet(nn.Module):
    """1-D ResNet regressor: stem, four residual stages, global average pool, scalar head.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(in_channels, out_channels, stride)``.
        layers: Sequence of four ints giving the number of blocks per stage.
    """

    def __init__(self, block, layers):
        super().__init__()
        self.in_channels = 64  # running channel count, updated by make_layer

        # Stem
        self.conv1 = nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.LeakyReLU = nn.LeakyReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages layer1..layer4 (widths 64/128/256/512, strides 1/2/2/2)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for number, (width, stride) in enumerate(stage_specs, start=1):
            stage = self.make_layer(block, width, layers[number - 1], stride=stride)
            setattr(self, f"layer{number}", stage)

        # Head: one scalar per input sequence
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512 * block.expansion, 1)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Stack ``blocks`` blocks; only the first one uses ``stride``."""
        stage = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels * block.expansion
        stage.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.maxpool(self.LeakyReLU(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        return self.fc(torch.flatten(out, 1))


def ResNetModel():
    """Build the regression ResNet with two BasicBlocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, layers=blocks_per_stage)

In [None]:
class WaveUNetWithTransformer(nn.Module):
    """1-D convolutional encoder/decoder with a Transformer bottleneck.

    Four conv + max-pool stages shrink the sequence by 16x, a 4-layer
    Transformer encoder processes the bottleneck, and four upsample + conv
    stages bring it back. The final output is assembled from the mask:
    the network prediction where ``mask == 1``, the untouched input where
    ``mask == 0`` and zero where ``mask == 2``.
    """

    @staticmethod
    def _down_stage(cin, cout):
        """Conv(k=15) -> ReLU -> MaxPool(2)."""
        return nn.Sequential(
            nn.Conv1d(cin, cout, kernel_size=15, padding=7), nn.ReLU(),
            nn.MaxPool1d(2),
        )

    @staticmethod
    def _up_stage(cin, cout, activate=True):
        """Linear 2x upsample -> Conv(k=15), optionally followed by ReLU."""
        modules = [
            nn.Upsample(scale_factor=2, mode="linear", align_corners=True),
            nn.Conv1d(cin, cout, kernel_size=15, padding=7),
        ]
        if activate:
            modules.append(nn.ReLU())
        return nn.Sequential(*modules)

    def __init__(self):
        super().__init__()

        # Encoder (large kernels, four stages)
        self.encoder1 = self._down_stage(1, 16)
        self.encoder2 = self._down_stage(16, 32)
        self.encoder3 = self._down_stage(32, 64)
        self.encoder4 = self._down_stage(64, 128)

        # Transformer bottleneck, with 1x1 convs projecting 128 <-> 256 channels
        self.transformer_input_proj = nn.Conv1d(128, 256, kernel_size=1)
        bottleneck_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, batch_first=True)
        self.transformer = nn.TransformerEncoder(bottleneck_layer, num_layers=4)
        self.transformer_output_proj = nn.Conv1d(256, 128, kernel_size=1)

        # Decoder (mirror of the encoder; the last stage has no activation)
        self.decoder1 = self._up_stage(128, 64)
        self.decoder2 = self._up_stage(64, 32)
        self.decoder3 = self._up_stage(32, 16)
        self.decoder4 = self._up_stage(16, 1, activate=False)

    def forward(self, x, mask):
        input_len = x.shape[-1]
        noisy_input = x  # passed straight through in the strong region

        feats = x
        for encode in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            feats = encode(feats)

        # (B, C, T) -> (B, T, C) for the batch-first Transformer, then back
        tokens = self.transformer_input_proj(feats).permute(0, 2, 1)
        tokens = self.transformer(tokens)
        feats = self.transformer_output_proj(tokens.permute(0, 2, 1))

        for decode in (self.decoder1, self.decoder2, self.decoder3, self.decoder4):
            feats = decode(feats)

        # Pooling floors odd lengths, so crop or right-pad back to the input length
        produced_len = feats.shape[-1]
        if produced_len > input_len:
            feats = feats[:, :, :input_len]
        elif produced_len < input_len:
            feats = F.pad(feats, (0, input_len - produced_len))

        # Compose the output by region
        leading_weak = (mask == 1)
        strong = (mask == 0)
        output = feats * leading_weak + noisy_input * strong
        output = output * (mask != 2)  # trailing weak region is zeroed

        return output

In [None]:
# Pick the device, then build both networks on it (denoiser first, predictor second)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

denoising_model = WaveUNetWithTransformer()
pred_model = ResNetModel()
for network in (denoising_model, pred_model):
    network.to(device)

# One Adam optimiser per network
optimizer = optim.Adam(denoising_model.parameters(), lr=0.00005, weight_decay=1e-5)
optimizer_pred = optim.Adam(pred_model.parameters(), lr=0.0001, weight_decay=1e-3)

In [None]:
# Masked loss
def masked_loss(output, target, mask, lambda_mse=5.0, lambda_smooth=1.0, stability_penalty=1.0):
    """Loss evaluated only on the leading weak region (``mask == 1``).

    Sum of three weighted terms:
      * mean squared error over elements with ``mask == 1``;
      * mean squared first difference of ``output`` over adjacent pairs that
        both have ``mask == 1`` (smoothness);
      * ``1 / (std + 1e-4)`` of the weak-region output, which penalises
        collapsing to a constant. This term is skipped (0) when the weak
        region has fewer than two elements, where the standard deviation is
        undefined and would otherwise turn the whole loss into NaN.

    Args:
        output, target, mask: Tensors of shape ``(B, 1, T)``.
        lambda_mse, lambda_smooth, stability_penalty: Weights of the three terms.

    Returns:
        Scalar tensor.
    """
    weak = (mask == 1)
    weak_f = weak.float()

    # ---- main term: masked MSE ----
    sq_err = F.mse_loss(output, target, reduction='none')  # (B, 1, T)
    core = (sq_err * weak_f).sum() / (weak_f.sum() + 1e-8)

    # ---- smoothness term: neighbouring samples that are both weak ----
    diff = output[:, :, 1:] - output[:, :, :-1]
    pair = (weak[:, :, 1:] & weak[:, :, :-1]).float()
    smooth_penalty = (diff ** 2 * pair).sum() / (pair.sum() + 1e-8)

    # ---- anti-collapse term ----
    weak_output = output[weak]
    if weak_output.numel() > 1:
        std_penalty = 1.0 / (torch.std(weak_output) + 1e-4)
    else:
        std_penalty = output.new_zeros(())

    total_loss = lambda_mse * core + lambda_smooth * smooth_penalty + stability_penalty * std_penalty
    return total_loss

In [None]:
# Training + validation + best-checkpoint saving
def _denoising_step(model, batch, device):
    """Run one batch through mask -> standardise -> model -> masked_loss.

    Returns ``(loss, output, batch_size)``.
    """
    white_data = batch["white_data"].to(device)      # (B, L)
    white_signal = batch["white_signal"].to(device)  # (B, L)

    mask = generate_mask(white_data)                 # (B, L)
    x_std, y_std, _ = standardize_batch(white_data, white_signal, mask)

    mask = mask.unsqueeze(1)                         # (B, 1, L)
    output = model(x_std.unsqueeze(1), mask)
    loss = masked_loss(output, y_std.unsqueeze(1), mask)
    return loss, output, white_data.size(0)


def train_model(model, train_loader, val_loader, optimizer, device, num_epochs=300, save_path="best_denoising_model.pt"):
    """Train the denoiser and keep the checkpoint with the lowest validation loss.

    Each epoch runs one pass over ``train_loader`` (gradient norm clipped to 1.0)
    and one pass over ``val_loader``; losses are averaged per sample. The
    learning rate is halved by ``ReduceLROnPlateau`` after 10 epochs without
    validation improvement.

    Args:
        model: Network called as ``model(x, mask)``.
        train_loader, val_loader: Loaders yielding dicts with ``"white_data"``
            and ``"white_signal"``.
        optimizer: Optimiser over ``model``'s parameters.
        device: Device the batches are moved to (the model must already be there).
        num_epochs: Number of epochs.
        save_path: Where the best ``state_dict`` is written.

    Raises:
        ValueError: The training or validation dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        # Fail before training starts instead of dividing by zero after an epoch
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    best_val_loss = float('inf')
    # `verbose=` is deprecated in recent PyTorch releases; LR reductions are logged explicitly below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            loss, _, batch_size = _denoising_step(model, batch, device)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * batch_size
        train_loss /= n_train

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        last_output = None
        with torch.no_grad():
            for batch in val_loader:
                loss, last_output, batch_size = _denoising_step(model, batch, device)
                val_loss += loss.item() * batch_size
            # Spread of the last validation batch's output, a quick collapse indicator
            output_std = last_output.std().item()
        val_loss /= n_val

        current_lr = optimizer.param_groups[0]['lr']
        print(f"{epoch+1}/{num_epochs}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}, Output Std = {output_std:.4f}, LR = {current_lr:.2e}")

        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            print(f"  >> Learning rate reduced to {new_lr:.2e}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("  >> Best model saved.")

In [None]:
# Mixed-precision training
def _prepare_denoising_batch_amp(batch, device):
    """Move a batch to ``device`` and return ``(x_std, y_std, mask)``, each shaped (B, 1, L)."""
    x = batch["white_data"].to(device, non_blocking=True)    # (B, L)
    y = batch["white_signal"].to(device, non_blocking=True)  # (B, L)
    mask = generate_mask(x)
    x_std, y_std, _ = standardize_batch(x, y, mask)
    return x_std.unsqueeze(1), y_std.unsqueeze(1), mask.unsqueeze(1)


def train_model_amp(model, train_loader, val_loader, optimizer, device, num_epochs=100, save_path="best_denoising_model.pt"):
    """Mixed-precision counterpart of ``train_model``.

    Uses ``masked_loss`` (the masked loss defined in this notebook), clips the
    unscaled gradient norm to 1.0 and averages losses per sample, matching
    ``train_model``. Autocast and gradient scaling are only enabled when
    ``device`` is a CUDA device; on CPU the loop runs in full precision.

    Args:
        model: Network called as ``model(x, mask)``; moved to ``device`` here.
        train_loader, val_loader: Loaders yielding dicts with ``"white_data"``
            and ``"white_signal"``.
        optimizer: Optimiser over ``model``'s parameters.
        device: Target device.
        num_epochs: Number of epochs.
        save_path: Where the best ``state_dict`` is written.

    Raises:
        ValueError: The training or validation dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    loss_fn = masked_loss
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    model.to(device)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            x_std, y_std, mask = _prepare_denoising_batch_amp(batch, device)

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                output = model(x_std, mask)
                loss = loss_fn(output, y_std, mask)

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping so the threshold applies to their true norm
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item() * x_std.size(0)

        train_loss /= n_train

        # ---- validation pass (also under autocast) ----
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                x_std, y_std, mask = _prepare_denoising_batch_amp(batch, device)

                with autocast(enabled=use_amp):
                    output = model(x_std, mask)
                    loss = loss_fn(output, y_std, mask)

                val_loss += loss.item() * x_std.size(0)

        val_loss /= n_val
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)

In [None]:
# Train the denoiser (the commented call below is the mixed-precision alternative)
train_model(
    model=denoising_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=120,
)
# train_model_amp(denoising_model, train_loader, val_loader, optimizer, device, num_epochs=120)

In [None]:
# Test-set export
def test_model_and_save_labels_batched(model, test_loader, device, output_dir="pt_chunks", base_name="denoised_batch", target_names=None):
    """Denoise every test batch and save it, with its labels, as one ``.pt`` file per batch.

    Each file ``<output_dir>/<base_name>_<NNNN>.pt`` holds a dict with
    ``"denoised"`` (model output moved to CPU), ``"targets"`` (float32,
    one column per entry of ``target_names``) and ``"target_names"``.

    Args:
        model: Trained denoiser called as ``model(x, mask)``.
        test_loader: Loader yielding dicts with ``"white_data"`` and one entry
            per target name.
        device: Device used for inference.
        output_dir: Destination folder (created if missing). Files left over
            from an earlier run are not removed.
        base_name: File name prefix.
        target_names: Batch keys to store as label columns, in order. Defaults
            to ``["mc_true", "phis_true", "thetas_true"]``.
    """
    if target_names is None:
        target_names = ["mc_true", "phis_true", "thetas_true"]
    else:
        target_names = list(target_names)

    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    batch_count = 0

    with torch.no_grad():
        for batch in test_loader:
            white_data = batch["white_data"].to(device)  # (B, L)
            mask = generate_mask(white_data)             # (B, L)

            # Only the standardised input is needed here, so the data doubles as the "signal" argument
            x_std, _, _ = standardize_batch(white_data, white_data, mask)

            # Denoised output, shape (B, 1, L)
            output = model(x_std.unsqueeze(1), mask.unsqueeze(1)).cpu()

            # One (B, 1) float32 column per requested label. as_tensor avoids the
            # copy-construct warning when the collated value is already a tensor.
            columns = [
                torch.as_tensor(batch[name], dtype=torch.float32).reshape(-1, 1)
                for name in target_names
            ]
            targets = torch.cat(columns, dim=1)

            file_path = os.path.join(output_dir, f"{base_name}_{batch_count:04d}.pt")
            torch.save({
                "denoised": output,
                "targets": targets,
                "target_names": target_names
            }, file_path)

            torch.cuda.empty_cache()  # release cached GPU blocks right after saving

            print(f"[Saved] {file_path} ← {output.shape[0]} samples")
            batch_count += 1

    print(f"[Done] 共保存 {batch_count} 个批次文件于: {output_dir}")


# Reload the best checkpoint and export the denoised test set.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
denoising_model.load_state_dict(torch.load("best_denoising_model.pt", map_location=device))
test_model_and_save_labels_batched(model=denoising_model, test_loader=test_loader, device=device, output_dir="pt_chunks")

In [None]:
# Undo the normalisation on the weak-signal part only
def selective_inverse_transform(signal, mask, stats, amplification=5.0):
    """Return a copy of ``signal`` with entries where ``mask == 1`` mapped back to the original scale.

    Selected entries are first de-standardised (``x * iqr + median``) and then
    divided by ``amplification``; all other entries are returned unchanged.
    Note that the default ``amplification`` here (5.0) differs from the default
    of ``standardize_batch`` (10.0), so pass it explicitly.

    Args:
        signal: NumPy array (left unmodified).
        mask: NumPy array aligned with ``signal``.
        stats: Dict with ``"iqr"`` and ``"median"``.
        amplification: Divisor applied after de-standardising.
    """
    restored = signal.copy()
    weak_positions = np.nonzero(mask == 1)[0]
    if len(weak_positions) == 0:
        return restored

    # Step 1: undo the standardisation
    restored[weak_positions] = restored[weak_positions] * stats["iqr"] + stats["median"]
    # Step 2: undo the amplification
    restored[weak_positions] /= amplification
    return restored

# Visualise the denoising result (noisy input, prediction, clean signal)
def visualize_denoising_subplots(model, test_loader, device, sample_index=0):
    """Plot noisy input, denoised output and ground truth for one sample of the first test batch.

    The noisy input and the clean signal are plotted as loaded: they were never
    standardised, so no inverse transform is applied to them. Only the model
    output lives in standardised space and is mapped back:

      * leading weak region (mask == 1): de-standardised, then divided by the
        amplification of 10 used by ``standardize_batch``;
      * strong region (mask == 0): the model passes the standardised input
        through, so it is de-standardised without amplification;
      * trailing weak region (mask == 2): zeroed by the model, left at zero.

    Args:
        model: Trained denoiser called as ``model(x, mask)``.
        test_loader: Loader yielding dicts with ``"white_data"`` and ``"white_signal"``.
        device: Device used for inference.
        sample_index: Index of the sample within the first batch.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(test_loader))
        white_data = batch["white_data"].to(device)      # (B, L)
        white_signal = batch["white_signal"].to(device)  # (B, L)
        mask = generate_mask(white_data)                 # (B, L)

        # Standardise and keep the batch statistics for the inverse transform
        x_std, _, stats = standardize_batch(white_data, white_signal, mask)

        output = model(x_std.unsqueeze(1), mask.unsqueeze(1)).squeeze(1).cpu().numpy()

    # Pick the requested sample
    input_signal = white_data[sample_index].cpu().numpy()
    clean_signal = white_signal[sample_index].cpu().numpy()
    signal_mask = mask[sample_index].cpu().numpy()
    denoised_signal = output[sample_index]

    # Map the model output back to the original scale, region by region
    denoised_signal = selective_inverse_transform(denoised_signal, signal_mask, stats, amplification=10.0)
    strong_as_one = (signal_mask == 0).astype(signal_mask.dtype)  # mark the strong region with 1
    denoised_signal = selective_inverse_transform(denoised_signal, strong_as_one, stats, amplification=1.0)

    # Plot
    fig, axs = plt.subplots(3, 1, figsize=(15, 9), sharex=True)
    axs[0].plot(input_signal, color='orange')
    axs[0].set_title("Noisy Input")
    axs[1].plot(denoised_signal, color='green')
    axs[1].set_title("Denoised Output")
    axs[2].plot(clean_signal, color='blue')
    axs[2].set_title("Ground Truth Signal")

    for ax in axs:
        ax.set_ylim(-0.5, 0.5)
        ax.grid(True)
    plt.tight_layout()
    plt.show()

visualize_denoising_subplots(
    model=denoising_model,
    test_loader=test_loader,
    device=device,
    sample_index=1,
)

## 预测

In [None]:
class LazyDenoisedDataset(Dataset):
    """Sample-level view over a folder of ``.pt`` chunk files.

    Every chunk is a dict with ``"denoised"`` (first dimension = samples) and
    ``"targets"`` (2-D, one row per sample). The constructor opens each chunk
    once to count its samples; ``__getitem__`` then loads chunks on demand and
    keeps the most recently used ones in memory so that consecutive reads from
    the same chunk do not hit the disk again.

    Args:
        chunk_folder: Directory containing the ``.pt`` files.
        label_idx: Column of ``"targets"`` returned as the label.
        max_cached_chunks: Number of chunks kept in memory (0 disables caching).
    """

    def __init__(self, chunk_folder, label_idx=0, max_cached_chunks=16):
        self.chunk_paths = sorted(
            os.path.join(chunk_folder, f)
            for f in os.listdir(chunk_folder)
            if f.endswith(".pt")
        )
        self.label_idx = label_idx
        self.max_cached_chunks = max_cached_chunks
        self._chunk_cache = {}  # file_idx -> loaded chunk, least recently used first
        self.index_map = []     # global index -> (file_idx, index inside the chunk)

        # Build the (file_idx, local_idx) table up front
        for file_idx, file_path in enumerate(self.chunk_paths):
            data = torch.load(file_path, map_location="cpu")
            count = data["denoised"].shape[0]
            self.index_map.extend((file_idx, i) for i in range(count))

    def _load_chunk(self, file_idx):
        """Return chunk ``file_idx``, reading it from disk only on a cache miss."""
        chunk = self._chunk_cache.pop(file_idx, None)
        if chunk is None:
            chunk = torch.load(self.chunk_paths[file_idx], map_location="cpu")
        if self.max_cached_chunks > 0:
            # Re-inserting moves the entry to the "most recently used" end of the dict
            self._chunk_cache[file_idx] = chunk
            while len(self._chunk_cache) > self.max_cached_chunks:
                self._chunk_cache.pop(next(iter(self._chunk_cache)))
        return chunk

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        file_idx, local_idx = self.index_map[idx]
        data = self._load_chunk(file_idx)

        # clone() so callers never hold views into the cached chunk
        signal = data["denoised"][local_idx].float().clone()     # (1, L)
        label = data["targets"][local_idx, self.label_idx].unsqueeze(0).float().clone()  # (1,)

        return {"signal": signal, "label": label}

In [None]:
def train_mc_model(model, train_loader, val_loader, optimizer, scaler, device, num_epochs=150, save_path="best_model_Mc.pt"):
    """Train the parameter regressor with MSE on scaled labels; keep the best checkpoint.

    Losses are averaged per sample. The learning rate is halved by
    ``ReduceLROnPlateau`` after 10 epochs without validation improvement.

    Args:
        model: Network mapping ``batch["signal"]`` to a ``(B, 1)`` prediction;
            moved to ``device`` here.
        train_loader, val_loader: Loaders yielding dicts with ``"signal"`` and ``"label"``.
        optimizer: Optimiser over ``model``'s parameters.
        scaler: Fitted label scaler exposing ``transform`` on NumPy arrays.
        device: Target device.
        num_epochs: Number of epochs.
        save_path: Where the best ``state_dict`` is written.

    Raises:
        ValueError: The training or validation dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    model.to(device)
    best_val_loss = float('inf')

    def batch_loss(batch):
        """Forward one batch and return (loss, batch_size)."""
        x = batch["signal"].to(device)  # (B, 1, L)
        # Labels come off the loader on the CPU; scale them there, then move once
        y_scaled = scaler.transform(batch["label"].cpu().numpy())
        y_scaled = torch.as_tensor(y_scaled, dtype=torch.float32, device=device)  # (B, 1)
        return criterion(model(x), y_scaled), x.size(0)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            loss, batch_size = batch_loss(batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch_size
        avg_train_loss = total_loss / n_train

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, batch_size = batch_loss(batch)
                val_loss += loss.item() * batch_size
        avg_val_loss = val_loss / n_val
        scheduler.step(avg_val_loss)

        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(">> Best model saved.")

In [None]:
# Mixed-precision training
def train_mc_model_amp(model, train_loader, val_loader, optimizer, scaler, device, num_epochs=150, save_path="best_model_Mc.pt"):
    """Mixed-precision counterpart of ``train_mc_model``.

    Autocast and gradient scaling are enabled only when ``device`` is a CUDA
    device; on CPU the loop runs in full precision. Losses are averaged per
    sample, so the reported values are comparable with ``train_mc_model``
    even when the last batch is smaller.

    Args:
        model: Network mapping ``batch["signal"]`` to a ``(B, 1)`` prediction;
            moved to ``device`` here.
        train_loader, val_loader: Loaders yielding dicts with ``"signal"`` and ``"label"``.
        optimizer: Optimiser over ``model``'s parameters.
        scaler: Fitted label scaler exposing ``transform`` on NumPy arrays
            (not the AMP gradient scaler, which is created internally).
        device: Target device.
        num_epochs: Number of epochs.
        save_path: Where the best ``state_dict`` is written.

    Raises:
        ValueError: The training or validation dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    model.to(device)
    loss_fn = nn.MSELoss()
    use_amp = torch.device(device).type == "cuda"
    amp_scaler = GradScaler(enabled=use_amp)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

    best_val_loss = float('inf')

    def batch_loss(batch):
        """Forward one batch under autocast and return (loss, batch_size)."""
        x = batch["signal"].to(device, non_blocking=True)
        y_scaled = scaler.transform(batch["label"].cpu().numpy())
        y_scaled = torch.as_tensor(y_scaled, dtype=torch.float32, device=device)
        with autocast(enabled=use_amp):
            pred = model(x)
            loss = loss_fn(pred, y_scaled)
        return loss, x.size(0)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            optimizer.zero_grad()
            loss, batch_size = batch_loss(batch)

            amp_scaler.scale(loss).backward()
            amp_scaler.step(optimizer)
            amp_scaler.update()

            train_loss += loss.item() * batch_size

        train_loss /= n_train

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, batch_size = batch_loss(batch)
                val_loss += loss.item() * batch_size

        val_loss /= n_val
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)

In [None]:
def evaluate_mc_model(model, test_loader, scaler, device):
    """Evaluate the regressor on the test set: print MSE and R², then draw three plots.

    Predictions are mapped back to label units with ``scaler.inverse_transform``
    before being compared with the raw labels. The plots are a true-vs-predicted
    scatter, a histogram of the prediction error and a violin plot comparing the
    two distributions.

    Args:
        model: Trained network; moved to ``device`` here so a freshly loaded
            CPU model can be passed directly.
        test_loader: Loader yielding dicts with ``"signal"`` and ``"label"``.
        scaler: Fitted label scaler exposing ``inverse_transform``.
        device: Device used for inference.

    Raises:
        ValueError: ``test_loader`` yields no batches.
    """
    model.to(device)
    model.eval()
    preds, trues = [], []

    with torch.no_grad():
        for batch in test_loader:
            x = batch["signal"].to(device)
            preds.append(model(x).cpu().numpy())
            trues.append(batch["label"].cpu().numpy())

    if not preds:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    preds = scaler.inverse_transform(np.vstack(preds)).flatten()
    trues = np.vstack(trues).flatten()

    # === Metrics ===
    mse = mean_squared_error(trues, preds)
    r2 = r2_score(trues, preds)
    print("Test MSE:", mse)
    print("Test R² :", r2)

    # === Scatter plot, with the y = x reference line ===
    plt.figure(figsize=(6, 6))
    plt.scatter(trues, preds, alpha=0.5)
    plt.plot([trues.min(), trues.max()], [trues.min(), trues.max()], 'r--')
    plt.xlabel("True Value")
    plt.ylabel("Predicted Value")
    plt.title("Prediction Scatter Plot")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # === Error histogram ===
    errors = preds - trues
    plt.figure(figsize=(6, 4))
    plt.hist(errors, bins=50, color='steelblue', edgecolor='black', alpha=0.7)
    plt.axvline(x=0, color='red', linestyle='--', label='Zero Error')
    plt.xlabel("Prediction Error")
    plt.ylabel("Frequency")
    plt.title("Prediction Error Histogram")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # === Violin plot ===
    data = pd.DataFrame({
        "Type": ["True"] * len(trues) + ["Predicted"] * len(preds),
        "Value": np.concatenate([trues, preds])
    })

    plt.figure(figsize=(6, 5))
    sns.violinplot(x="Type", y="Value", data=data, inner="quartile", palette="muted")
    plt.title("Distribution of True vs Predicted Values")
    plt.grid(True, axis='y')
    plt.tight_layout()
    plt.show()

In [None]:
# ---------- Configuration ----------
chunk_dir = "pt_chunks"         # folder with the per-batch .pt files
label_idx = 0                   # 0=mc_true, 1=phis_true, 2=thetas_true
batch_size = 16
model_save_path = "best_model_Mc.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Build the dataset ----------
dataset = LazyDenoisedDataset(chunk_dir, label_idx=label_idx)
print(f"[INFO] 样本总数: {len(dataset)}")

# ---------- Split ----------
n = len(dataset)
train_len = int(n * 0.6)
val_len = int(n * 0.2)
test_len = n - train_len - val_len
train_set, val_set, test_set = random_split(dataset, [train_len, val_len, test_len])

# ---------- Fit the label scaler ----------
# Fit on the training labels only, so no statistics of the validation/test
# labels leak into the targets the model is trained on.
train_labels = torch.stack([dataset[i]["label"] for i in train_set.indices])
label_scaler = StandardScaler()
label_scaler.fit(train_labels.numpy())

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_set, batch_size=batch_size)
test_loader  = DataLoader(test_set, batch_size=batch_size)

# ---------- Model ----------
model = ResNetModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# ---------- Train ----------
train_mc_model(model, train_loader, val_loader, optimizer, label_scaler, device)
# train_mc_model_amp(model, train_loader, val_loader, optimizer, label_scaler, device)

# ---------- Test ----------
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one
model.load_state_dict(torch.load(model_save_path, map_location=device))
evaluate_mc_model(model, test_loader, label_scaler, device)