In [None]:
# === 禁用 Innov2（防误用）===
class _Innov2Disabled:
    def __init__(self, *args, **kwargs):
        raise NotImplementedError("Innov2 已禁用，请仅使用 Baseline 或 Innovation1")

Innovation2FaultDetector = _Innov2Disabled
FusionClassifier = _Innov2Disabled
TransformerFusionClassifier = _Innov2Disabled
print('[Info] Innov2 已禁用：Innovation2FaultDetector / FusionClassifier / TransformerFusionClassifier')



# 变电站故障检测
- 支持 Baseline / 创新点1 / 创新点2（先以轻量模式稳定跑通）  
- 端到端：数据加载 → 训练 → 推理 → 评估  
- 已内置 **11 个正确标签**：
`'bj_bpmh', 'bj_bpps', 'bj_wkps', 'bjdsye', 'jyz_pl', 'sly_dmyw', 'hxg_gjbs', 'hxq_gjtps', 'xmbhyc', 'yw_gkxfw', 'yw_nc'`


In [14]:
import os, math, time, json, random
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
from PIL import Image
import xml.etree.ElementTree as ET

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from torchvision import transforms, models

CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if CUDA else 'cpu')
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if CUDA: torch.cuda.manual_seed_all(SEED)

SAFE_MAX_PIXELS = 384 * 384   # 安全调整size
USE_FP16 = True               # 启用amp
TIMESTEPS_SAFE = [200]        # 开启稳定single-timestep

CLASS_NAMES = [
    'bj_bpmh', 'bj_bpps', 'bj_wkps', 'bjdsyc', 'jyz_pl', 'sly_dmyw',
    'hxq_gjbs', 'hxq_gjtps', 'xmbhyc', 'yw_gkxfw', 'yw_nc'
]
print('Device:', DEVICE, '| Classes:', len(CLASS_NAMES))

Device: cuda | Classes: 11


In [15]:
#功能：将 4D 图像张量（B,C,H,W）安全缩放到不超过 SAFE_MAX_PIXELS 的像素上限，避免显存暴涨。

def safe_resize_bchw(x: torch.Tensor) -> torch.Tensor:
    B,C,H,W = x.shape
    if H*W <= SAFE_MAX_PIXELS:
        return x
    scale = (SAFE_MAX_PIXELS / float(H*W)) ** 0.5
    newH, newW = max(64, int(H*scale)), max(64, int(W*scale))
    return F.interpolate(x, size=(newH,newW), mode='bilinear', align_corners=False)
    

In [16]:
#功能：解析单个 VOC 风格 XML，提取类别名

def parse_xml(xml_path: str) -> Optional[str]:
    try:
        root = ET.parse(xml_path).getroot()
        for obj in root.findall('object'):
            name = obj.find('name').text.strip()
            return name
    except Exception:
        return None
    return None

#功能：自定义数据集类，从指定目录加载图像-标签对
class SubstationDataset(Dataset):
    def __init__(self, images_dir: str, annos_dir: str, transform=None,
                 class_names: List[str] = None):
        self.images_dir = Path(images_dir)
        self.annos_dir = Path(annos_dir)
        self.transform = transform
        self.class_names = class_names or list(CLASS_NAMES)
        self.class_to_idx = {c:i for i,c in enumerate(self.class_names)}
        self.data_pairs = []
        exts = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff'}
        annos = [p for p in self.annos_dir.iterdir() if p.suffix.lower()=='.xml']
        for ap in annos:
            label = parse_xml(str(ap))
            if label is None: 
                continue
            if label not in self.class_to_idx:
                continue
            img_stem = ap.stem
            img_candidate = None
            for ext in exts:
                ip = self.images_dir / f"{img_stem}{ext}"
                if ip.exists():
                    img_candidate = ip; break
            if img_candidate is None:
                for p in self.images_dir.iterdir():
                    if p.suffix.lower() in exts and p.stem == img_stem:
                        img_candidate = p; break
            if img_candidate is not None:
                self.data_pairs.append( (str(img_candidate), self.class_to_idx[label]) )
        if len(self.data_pairs)==0:
            print("[WARN] No matched image-xml pairs found.")

    def __len__(self): return len(self.data_pairs)

    def __getitem__(self, idx):
        ip, y = self.data_pairs[idx]
        img = Image.open(ip).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, y

#功能：返回训练/验证两套 torchvision.transforms 变换
#训练：随机水平/垂直翻转、颜色抖动、缩放、ToTensor
#验证：缩放、ToTensor
'''def get_transforms(img_size=256):
    train_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(0.2,0.2,0.2,0.1),
        transforms.ToTensor()
    ])
    val_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor()
    ])

    return train_tf, val_tf'''

    # === 更强数据增强（减轻过拟合）===
from torchvision.transforms import InterpolationMode

IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD  = [0.229, 0.224, 0.225]

def get_transforms(img_size=256):
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.6, 1.0), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.3,0.3,0.3,0.15)], p=0.8),
        transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 1.5))], p=0.3),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(IMG_MEAN, IMG_STD)
    ])
    val_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.14), interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(IMG_MEAN, IMG_STD)
    ])
    return train_tf, val_tf

print('[Info] 已启用更强数据增强与标准化')

[Info] 已启用更强数据增强与标准化


In [None]:
#新增：无LabelSmoothing + Mixup + TTA

# === 掩膜视图：在已有 train_tf 基础上叠加随机块遮挡 + RandomErasing ===
class MaskedViewTransform:
    def __init__(self, base_tfm, use_block_mask=True, use_random_erasing=True, max_blocks=2, erasing_p=0.25):
        self.base_tfm = base_tfm
        self.use_block_mask = use_block_mask
        self.use_random_erasing = use_random_erasing
        self.max_blocks = max_blocks
        from torchvision import transforms
        self.random_erasing = transforms.RandomErasing(p=erasing_p, value=0)

    @staticmethod
    def _block_mask(img_tensor, max_blocks=2):
        # img_tensor: Tensor [C, H, W]
        import random
        c, h, w = img_tensor.shape
        blocks = random.randint(1, max_blocks)
        for _ in range(blocks):
            rh = int(h * random.uniform(0.2, 0.35))
            rw = int(w * random.uniform(0.2, 0.35))
            y0 = random.randint(0, max(0, h - rh))
            x0 = random.randint(0, max(0, w - rw))
            img_tensor[:, y0:y0+rh, x0:x0+rw] = 0.0
        return img_tensor

    def __call__(self, pil_img):
        x = self.base_tfm(pil_img)
        if self.use_block_mask:
            x = self._block_mask(x, self.max_blocks)
        if self.use_random_erasing:
            x = self.random_erasing(x)
        return x






class SubstationDatasetFromPairsMV(Dataset):
    """训练：返回双视图 (view1, view2, label)"""
    def __init__(self, pairs, tfm_view1, tfm_view2):
        self.pairs = pairs
        self.tfm1 = tfm_view1
        self.tfm2 = tfm_view2
    def __len__(self): return len(self.pairs)
    def __getitem__(self, i):
        ip, y = self.pairs[i]
        from PIL import Image
        img = Image.open(ip).convert('RGB')
        x1 = self.tfm1(img)
        x2 = self.tfm2(img)
        return x1, x2, y



In [29]:
#功能：扩散过程类，包含 q_sample 和 p_sample 方法
#定义离散扩散过程的时间表与前向/反向采样公式（DDPM 公式族）
class DiffusionProcess(nn.Module):
    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1,0), value=1.0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        self.num_timesteps = num_timesteps
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

# 功能：前向采样，生成带噪声的图像(前向加噪)
# 输入：原始图像 x_start, 时间步 t, 可选噪声 noise
# 输出：带噪声的图像 x_noisy
# 实现：根据时间步 t 的 alpha 和 beta 计算加噪强度，生成带噪声的图像
# 关键点：用 sqrt(ᾱ_t)、sqrt(1-ᾱ_t) 线性混合 x_start 与 noise
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        acp = self.alphas_cumprod.to(t.device)
        sqrt_acp = torch.sqrt(acp.gather(0, t)).view(-1,1,1,1)
        sqrt_om = torch.sqrt((1.0 - acp).gather(0, t)).view(-1,1,1,1)
        return sqrt_acp * x_start + sqrt_om * noise

# 功能：反向采样，从带噪声的图像恢复原始图像(反向去噪,调用 U-Net 预测噪声,支持类条件传递)
# 输入：U-Net 模型, 带噪声的图像 x, 时间步 t, 时间步索引 t_index, 可选类条件标签 class_labels
# 输出：恢复的原始图像
# 实现：根据时间步 t 的 alpha 和 beta 计算去噪强度，结合 U-Net 预测的噪声，恢复原始图像
# 关键点：用 sqrt(ᾱ_t)、sqrt(1-ᾱ_t) 线性混合 x 与 U-Net 预测的噪声(类条件会被透传到去噪器)
    def p_sample(self, model, x, t, t_index, class_labels=None):
        betas_t = self.betas.to(t.device).gather(0, t).view(-1,1,1,1)
        acp_t = self.alphas_cumprod.to(t.device).gather(0, t)
        alphas_t = self.alphas.to(t.device).gather(0, t)
        sqrt_one_minus = torch.sqrt(1.0 - acp_t).view(-1,1,1,1)
        sqrt_recip = torch.sqrt(1.0 / alphas_t).view(-1,1,1,1)
        pred_noise = model(x, t, class_labels=class_labels)
        model_mean = sqrt_recip * (x - betas_t * pred_noise / sqrt_one_minus)
        if t_index == 0:
            return model_mean
        post_var_t = self.posterior_variance.to(t.device).gather(0, t).view(-1,1,1,1)
        return model_mean + torch.sqrt(post_var_t) * torch.randn_like(x)

In [30]:
#功能：U-Net 的残差块，包含时间条件、类条件传递、短连接等.
#输入：输入特征 x, 时间嵌入 t_emb, 可选类条件 cond
#输出：输出特征 h,(与 out_ch 匹配的特征图)
#实现：
# 1) 两个 3x3 卷积 + 时间条件 + 类条件传递
# 2) 短连接：如果输入输出通道数相同，则直接相加；否则用 1x1 卷积调整通道数
# 3) 使用 SiLU 激活函数，可选 Dropout 正则化
class ResidBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, cond_dim=None, p=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_fc = nn.Linear(time_dim, out_ch)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(p)
        self.short = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.use_film = cond_dim is not None
        if self.use_film:
            self.gamma = nn.Linear(cond_dim, out_ch)
            self.beta  = nn.Linear(cond_dim, out_ch)

    def forward(self, x, t_emb, cond=None):
        h = self.conv1(x)
        h = h + self.time_fc(t_emb)[:, :, None, None]
        if self.use_film:
            gamma = self.gamma(cond)[:, :, None, None]
            beta  = self.beta(cond)[:, :, None, None]
            h = h * (1 + gamma) + beta
        h = self.act(h)
        h = self.drop(h)
        h = self.conv2(h)
        return h + self.short(x)

#功能：增强 U-Net 模型,扩散去噪器（U-Net），支持时间条件、类条件传递、短连接等.
#输入：输入特征 x, 时间步 timestep, 可选类条件 class_labels
#输出：预测的噪声(与输入图像同尺寸的 3 通道张量)
#实现：
# 1) 时间条件：将时间步转换为嵌入，与特征图拼接
# 2) 类条件：如果提供，则转换为可学习的提示，与特征图拼接
# 3) 残差块：包含两个 3x3 卷积 + 时间条件 + 类条件传递 + 短连接
# 4) 短连接：如果输入输出通道数相同，则直接相加；否则用 1x1 卷积调整通道数
# 5) 使用 SiLU 激活函数，可选 Dropout 正则化
# 6) 使用 avg_pool2d 下采样，使用 F.interpolate 上采样
# 7) 使用 torch.cat 拼接特征图，实现跳跃连接
# 8) 使用 nn.Conv2d 输出预测的噪声
class EnhancedUNet(nn.Module):
    def __init__(self, num_classes: int, base_ch=64, use_film=True, use_attention=False, width_mult=1.0):
        super().__init__()
        ch = int(base_ch * width_mult)
        self.use_film = use_film
        self.use_attention = use_attention
        time_dim = ch * 4
        cond_dim = ch * 2 if use_film else None

        self.time_mlp = nn.Sequential(
            nn.Linear(1, ch), nn.SiLU(),
            nn.Linear(ch, time_dim)
        )
        self.learnable_prompts = nn.Embedding(num_classes, cond_dim if use_film else 1)

        # encoder
        self.enc1 = ResidBlock(3,      ch,   time_dim, cond_dim)
        self.enc2 = ResidBlock(ch,     ch*2, time_dim, cond_dim)
        self.enc3 = ResidBlock(ch*2,   ch*4, time_dim, cond_dim)
        self.mid  = ResidBlock(ch*4,   ch*4, time_dim, cond_dim)

        # decoder (concat skips)
        self.dec3 = ResidBlock(ch*4 + ch*2, ch*2, time_dim, cond_dim)  # 6*ch -> 2*ch
        self.dec2 = ResidBlock(ch*2 + ch,   ch,   time_dim, cond_dim)  # 3*ch -> 1*ch
        self.dec1 = nn.Conv2d(ch, 3, 3, padding=1)

    def forward(self, x, timestep, class_labels=None):
        if class_labels is None:
            class_labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        t = timestep.view(-1, 1).float()
        t_emb = self.time_mlp(t)
        cond = self.learnable_prompts(class_labels) if self.use_film else None

        # encoder
        e1 = self.enc1(x,               t_emb, cond)        # [B, ch,   H,   W]
        e2 = self.enc2(F.avg_pool2d(e1, 2), t_emb, cond)    # [B, 2ch,  H/2, W/2]
        e3 = self.enc3(F.avg_pool2d(e2, 2), t_emb, cond)    # [B, 4ch,  H/4, W/4]
        m  = self.mid(e3,               t_emb, cond)        # [B, 4ch,  H/4, W/4]

        # decoder with concat skips
        up_m = F.interpolate(m, scale_factor=2, mode='nearest')        # [B, 4ch, H/2, W/2]
        #d3_in = torch.cat([up_m, e2], dim=1)                           # [B, 6ch, H/2, W/2]
        d3 = self.dec3(torch.cat([up_m, e2], dim=1), t_emb, cond)                             # [B, 2ch, H/2, W/2]

        up_d3 = F.interpolate(d3, scale_factor=2, mode='nearest')      # [B, 2ch, H, W]
        #d2_in = torch.cat([up_d3, e1], dim=1)                          # [B, 3ch, H, W]
        d2 = self.dec2(torch.cat([up_d3, e1], dim=1), t_emb, cond)                              # [B, ch,  H, W]

        out = self.dec1(d2)                                            # [B, 3,   H, W]
        return out
  # predicted noise

In [31]:
#残差提取与多尺度
#功能：多尺度残差提取器，支持多尺度残差提取与扩散去噪.利用扩散反推把输入图像“还原”到接近正常，再与原图取绝对差，得到残差；可按多时间步提取。
#输入：扩散去噪器, 时间步列表
#输出：多尺度残差列表
#实现：
# 1) 初始化：保存扩散去噪器与时间步列表
# 2) 前向：对每个时间步，调用扩散去噪器进行反向采样，得到残差
# 3) 返回：所有时间步的残差列表
# 4) 使用 torch.no_grad() 装饰器，避免梯度计算
class MultiScaleResidualExtractor(nn.Module):
    def __init__(self, diffusion_model: EnhancedUNet, timesteps: List[int]):
        super().__init__()
        self.diffusion_model = diffusion_model
        self.timesteps = timesteps
        self.diffusion_process = DiffusionProcess()

#功能：提取单个时间步的残差,从给定时间步 t_value 反推到 0，得到还原图 x̂ 并计算 |x - x̂|
#输入：输入图像 x, 标签 labels, 时间步 t_value, 步数 steps
#输出：残差
#关键点：每步都携带 class_labels（innov1 的类条件反推）
#实现：
# 1) 初始化：保存扩散去噪器与时间步列表
# 2) 前向：对每个时间步，调用扩散去噪器进行反向采样，得到残差
# 3) 返回：所有时间步的残差列表
# 被谁调用：验证/推理，多尺度时会被多次调用后堆叠。
    @torch.no_grad()
    def extract_one_timestep(self, x, labels, t_value: int, steps: int = 8):
        dev = next(self.parameters()).device
        x = x.to(dev, non_blocking=True)
        x_t = torch.randn_like(x)
        stride = max(1, int(t_value // max(1, steps)))
        for i in range(t_value, -1, -stride):
            tt = torch.full((x.size(0),), i, device=dev, dtype=torch.long)
            x_t = self.diffusion_process.p_sample(self.diffusion_model, x_t, tt, i, class_labels=labels.to(dev))
        residual = (x - x_t).abs()
        return residual

In [32]:
#分类器与多尺度融合
#功能：残差分类器，将多尺度残差堆叠后，用 ResNet50 提取特征，最后用全连接层分类.对单张残差图做分类。
#输入：多尺度残差列表
#输出：分类结果
#关键点：ResNet50(IMAGENET1K_V2) 作为骨干，只替换最后全连接层输出通道数
#实现：
# 1) 初始化：加载 ResNet50 模型，替换最后一层全连接为 num_classes 输出
# 2) 前向：对每个多尺度残差，用 ResNet50 提取特征，最后用全连接层分类
# 3) 返回：分类结果
# 被谁调用：验证/推理，多尺度时会被多次调用后堆叠。
'''class ResidualClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        in_feat = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_feat, num_classes)
    def forward(self, x):
        return self.backbone(x)'''

# === 覆盖 ResidualClassifier：更强正则（Dropout/LabelSmoothing 支持）===
class ResidualClassifier(nn.Module):
    def __init__(self, num_classes: int, p_drop=0.4, label_smoothing=0.1):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        in_feat = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p_drop),
            nn.Linear(in_feat, num_classes)
        )
        self.label_smoothing = label_smoothing

    def forward(self, x):
        return self.backbone(x)

print('[Info] 已覆盖 ResidualClassifier：加入 Dropout 与 LabelSmoothing 支持（由 Trainer 使用）')



#功能：多尺度融合分类器，将多尺度残差堆叠后，用 2 层 3x3 卷积 + 平均池化 + 全连接层分类.
#输入：残差栈 (B,T,C,H,W)
#关键点：先把 (T,C,H,W) 当“伪通道”或展平做轻量卷积/池化，再接全连接层。
#被谁调用：Innov2（多时间步） 
class FusionClassifier(nn.Module):
    def __init__(self, num_classes: int, num_timesteps: int = 1, in_channels=3):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU()
        )
        self.pool = nn.AdaptiveAvgPool2d((8,8))
        self.fc = nn.Sequential(
            nn.Linear(64*8*8, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
    def forward(self, x): # [B,T,C,H,W]
        if x.dim()==4:
            x = x.unsqueeze(1)
        B,T,C,H,W = x.shape
        x = x.mean(dim=1)  # avg over time
        h = self.conv(x)
        h = self.pool(h).view(B,-1)
        return self.fc(h)

# 功能：Transformer 融合分类器，将多尺度残差堆叠后（把残差栈先做网格池化到固定维度，投影为序列），用 Transformer 提取特征，最后用全连接层分类.
#把 (B,T,C,H,W) 先做网格池化到固定维度，投影为序列，再用 TransformerEncoder 做时序/尺度建模。
#残差栈 (B,T,C,H,W) → 网格池化 → 序列 (B,T,C*g*g) → TransformerEncoder → 平均池化 (B,D) → 全连接 (B,num_classes)
#关键点：每一帧残差提取一个 token，Transformer 编码，做时序平均池化再分类。
#被谁调用：Innov2（多时间步 + Transformer 融合）
class TransformerFusionClassifier(nn.Module):
    def __init__(self, num_classes=11, num_timesteps=1, in_channels=3, grid=8,
                 proj_dim=512, num_layers=2):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.grid = grid
        self.proj = nn.Linear(in_channels * grid * grid, proj_dim)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=proj_dim, nhead=8, dim_feedforward=1024, dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.fc = nn.Sequential(
            nn.Linear(proj_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
    def forward(self, x):  # [B,T,C,H,W]
        if x.dim()==4:
            x = x.unsqueeze(1)
        B,T,C,H,W = x.shape
        feats = []
        for i in range(T):
            f = F.adaptive_avg_pool2d(x[:,i], (self.grid, self.grid))  # [B,C,g,g]
            f = f.view(B, -1)                                          # [B, C*g*g]
            f = self.proj(f)                                           # [B, D]
            feats.append(f)
        seq = torch.stack(feats, dim=1)                                # [B,T,D]
        enc = self.encoder(seq).mean(dim=1)                            # [B,D]
        return self.fc(enc)

[Info] 已覆盖 ResidualClassifier：加入 Dropout 与 LabelSmoothing 支持（由 Trainer 使用）


In [33]:
# 端到端模型封装
# 功能：封装了创新1/2的模型，包括扩散去噪器、多尺度残差提取器、残差分类器、Transformer 融合分类器。
# 基线总成：DiffusionProcess + EnhancedUNet(use_film=False) + ResidualClassifier
# 输入：图像 x, 标签 labels, 模式 mode
# 输出：预测的噪声, 噪声, 时间步, 标签
# 关键点：训练时残差支路不回传梯度（通过 no_grad 或单独函数），扩散器只用 MSE 学“正常先验”，分类器用 CE 学“看残差分类”


class BaselineFaultDetector(nn.Module):
    def __init__(self, num_classes=11, img_size=256):
        super().__init__()
        self.diffusion = DiffusionProcess()
        self.unet = EnhancedUNet(num_classes=num_classes, use_film=False, width_mult=0.75)
        self.classifier = ResidualClassifier(num_classes)

    def forward(self, x, labels=None, mode='train'):
        if mode == 'train':
            b = x.shape[0]
            t = torch.randint(0, self.diffusion.num_timesteps, (b,), device=x.device).long()
            noise = torch.randn_like(x)
            x_noisy = self.diffusion.q_sample(x, t, noise)
            pred_noise = self.unet(x_noisy, t, class_labels=None)
            return pred_noise, noise, t, None
        else:
            with torch.inference_mode():
                res = MultiScaleResidualExtractor(self.unet, TIMESTEPS_SAFE).extract_one_timestep(
                    x, labels if labels is not None else torch.zeros(x.size(0), dtype=torch.long, device=x.device),
                    t_value=200, steps=8
                )
                logits = self.classifier(res)
                return res.unsqueeze(1), logits

# 创新1总成：DiffusionProcess + EnhancedUNet(use_film=True) + ResidualClassifier
# U-Net 开启 use_film=True，并在反推/训练都传入 class_labels
# 关键点：FiLM 条件：nn.Embedding(num_classes, cond_dim) 生成类条件向量，注入各残差块；
# 反推时也传 labels 给 p_sample，确保“按该类正常形态”还原；
# 保持损失组合 MSE + 0.5×CE 与解耦训练不变。

class Innovation1FaultDetector(nn.Module):
    def __init__(self, num_classes=11):
        super().__init__()
        self.diffusion_process = DiffusionProcess()
        self.diffusion_model = EnhancedUNet(num_classes=num_classes, use_film=True, width_mult=0.75)
        self.classifier = ResidualClassifier(num_classes)

    @torch.no_grad()
    def reconstruct_residual_fast(self, x, labels, t_start=200, steps=10):
        dev = next(self.parameters()).device
        x = x.to(dev)
        x_t = torch.randn_like(x)
        stride = max(1, int(t_start // max(1, steps)))
        for i in range(t_start, -1, -stride):
            tt = torch.full((x.size(0),), i, device=dev, dtype=torch.long)
            x_t = self.diffusion_process.p_sample(self.diffusion_model, x_t, tt, i, class_labels=labels.to(dev))
        return (x - x_t).abs()

    def forward(self, x, labels=None, mode='train'):
        if labels is None:
            labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        if mode == 'train':
            b = x.shape[0]
            t = torch.randint(0, self.diffusion_process.num_timesteps, (b,), device=x.device).long()
            noise = torch.randn_like(x)
            x_noisy = self.diffusion_process.q_sample(x, t, noise)
            pred_noise = self.diffusion_model(x_noisy, t, class_labels=labels)
            return pred_noise, noise, t, None
        else:
            with torch.inference_mode():
                res = self.reconstruct_residual_fast(x, labels, t_start=200, steps=8)
                logits = self.classifier(res)
                return res.unsqueeze(1), logits

# 创新2总成：DiffusionProcess + EnhancedUNet(use_film=True) + MultiScaleResidualExtractor + FusionClassifier/TransformerFusionClassifier
class Innovation2FaultDetector(nn.Module):
    def __init__(self, num_classes=11, use_transformer=True, timesteps=[200], lightweight=True):
        super().__init__()
        self.num_classes = num_classes
        self.timesteps = timesteps
        self.lightweight = lightweight

        self.diffusion_process = DiffusionProcess()
        self.diffusion_model = EnhancedUNet(num_classes=num_classes, use_film=True,
                                            use_attention=(False if lightweight else True),
                                            width_mult=(0.75 if lightweight else 1.0))
        self.residual_extractor = MultiScaleResidualExtractor(self.diffusion_model, self.timesteps)
        if lightweight or not use_transformer:
            self.fusion_classifier = FusionClassifier(num_classes=num_classes, num_timesteps=len(self.timesteps))
        else:
            self.fusion_classifier = TransformerFusionClassifier(num_classes=num_classes, num_timesteps=len(self.timesteps))

    @torch.no_grad()
    def reconstruct_residual_fast(self, x, labels, t_start=200, steps=8):
        dev = next(self.parameters()).device
        x = x.to(dev)
        x_t = torch.randn_like(x)
        stride = max(1, int(t_start // max(1, steps)))
        for i in range(t_start, -1, -stride):
            tt = torch.full((x.size(0),), i, device=dev, dtype=torch.long)
            x_t = self.diffusion_process.p_sample(self.diffusion_model, x_t, tt, i, class_labels=labels.to(dev))
        return (x - x_t).abs()

    def forward(self, x, labels=None, mode='train'):
        if labels is None:
            labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        if mode == 'train':
            b = x.shape[0]
            t = torch.randint(0, self.diffusion_process.num_timesteps, (b,), device=x.device).long()
            noise = torch.randn_like(x)
            masked_x = x
            x_noisy = self.diffusion_process.q_sample(masked_x, t, noise)
            pred_noise = self.diffusion_model(x_noisy, t, class_labels=labels)
            return pred_noise, noise, t, None
        else:
            self.eval()
            dev = next(self.parameters()).device
            timesteps = self.timesteps if (self.lightweight is False) else TIMESTEPS_SAFE
            with torch.inference_mode():
                residual_list = []
                for t in timesteps:
                    res_t = self.reconstruct_residual_fast(x, labels, t_start=int(t), steps=8)  # [B,C,H,W]
                    residual_list.append(res_t)
                res_stack = torch.stack(residual_list, dim=1)  # [B,T,C,H,W]
                logits = self.fusion_classifier(res_stack.to(next(self.fusion_classifier.parameters()).device))
                return res_stack, logits

In [34]:
from torch.amp import autocast, GradScaler

AMP_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# 训练器
# 功能：训练模型，包括前向、损失计算、反向传播、梯度裁剪、优化器更新等；封装优化器、损失、AMP、单轮训练与验证逻辑、保存最优模型。
# 关键点：使用 GradScaler 处理混合精度训练，使用交叉熵损失计算分类误差，使用 MSE 损失计算扩散损失。
# 优化器：AdamW；混合精度：GradScaler；
# 损失：mse = MSE(pred_noise, noise)；ce = CE(logits, y)；总损失 loss = mse + 0.5*ce；
# 保存：当 val_acc 创新高就 torch.save
class Trainer:
    def __init__(self, model: nn.Module, lr=2e-4, weight_decay=5e-4):
        self.model = model.to(DEVICE)
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        #self.scaler = torch.cuda.amp.GradScaler(enabled=USE_FP16 and torch.cuda.is_available())
        self.scaler = GradScaler(AMP_DEVICE, enabled=USE_FP16 and torch.cuda.is_available())
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()


#对每个 batch：采样 t 加噪 → U-Net 预测噪声 → MSE；无梯度重构残差 → 分类器 → CE；loss = MSE + 0.5×CE 反传更新；统计训练准确率。
# 可选：用 train_eval_loader 在自然分布上再测一次 train_acc。
    def train_epoch(self, loader):
        self.model.train()
        total, correct, loss_sum = 0, 0, 0.0
        for imgs, ys in loader:
            imgs = imgs.to(DEVICE); ys = ys.to(DEVICE)
            self.optim.zero_grad(set_to_none=True)
            #with torch.cuda.amp.autocast(enabled=USE_FP16 and torch.cuda.is_available()):
            with autocast(AMP_DEVICE, enabled=USE_FP16 and torch.cuda.is_available()):
                out = self.model(imgs, labels=ys, mode='train')
                if isinstance(out, tuple):
                    pred_noise, noise, t, _ = out
                    diffusion_loss = self.mse(pred_noise, noise)
                else:
                    diffusion_loss = torch.tensor(0., device=DEVICE)

                with torch.no_grad():
                    if hasattr(self.model, 'reconstruct_residual_fast'):
                        res = self.model.reconstruct_residual_fast(imgs, ys, t_start=200, steps=8)
                    else:
                        res = MultiScaleResidualExtractor(self.model.unet, TIMESTEPS_SAFE).extract_one_timestep(imgs, ys, t_value=200, steps=8)
                if hasattr(self.model, 'fusion_classifier'):
                    logits = self.model.fusion_classifier(res.unsqueeze(1).to(next(self.model.fusion_classifier.parameters()).device))
                else:
                    logits = self.model.classifier(res)
                ce = self.ce(logits, ys)
                loss = diffusion_loss + 0.5 * ce

            self.scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optim)
            self.scaler.update()

            with torch.no_grad():
                pred = logits.argmax(1)
                correct += (pred == ys).sum().item()
                total += ys.size(0)
                loss_sum += loss.item() * ys.size(0)

        return loss_sum/total, correct/total

# 验证器:验证评估。走推理路径（反推→残差→分类），统计 val_loss/val_ac
# 关键点：不做反传；与训练的残差支路保持一致的生成逻辑；遇到更高 acc 即保存 best
    @torch.no_grad()
    def validate(self, loader):
        self.model.eval()
        total, correct, loss_sum = 0, 0, 0.0
        for imgs, ys in loader:
            imgs = imgs.to(DEVICE); ys = ys.to(DEVICE)
            res_stack, logits = self.model(imgs, labels=ys, mode='eval')
            ce = self.ce(logits, ys)
            pred = logits.argmax(1)
            correct += (pred==ys).sum().item()
            total += ys.size(0)
            loss_sum += ce.item() * ys.size(0)
        return loss_sum/total, correct/total

In [35]:
from sklearn.model_selection import train_test_split

IMAGES_DIR = '/mnt/e/code/project/Dataset-total/images'   # 修改为你的路径
ANNOTATIONS_DIR = '/mnt/e/code/project/Dataset-total/annotations/xmls'  # 修改为你的路径

img_size = 256
val_ratio = 0.2
bs = 8
workers = 4

'''train_tf, val_tf = get_transforms(img_size)
base = SubstationDataset(IMAGES_DIR, ANNOTATIONS_DIR, transform=None, class_names=list(CLASS_NAMES))
indices = np.arange(len(base))
if len(indices)==0:
    raise RuntimeError('No data found. Please check paths.')
labels_np = np.array([ base.data_pairs[i][1] for i in indices ])

train_idx, val_idx = train_test_split(indices, test_size=val_ratio, random_state=SEED, stratify=labels_np)

train_ds = SubstationDataset(IMAGES_DIR, ANNOTATIONS_DIR, transform=train_tf, class_names=list(CLASS_NAMES))
val_ds   = SubstationDataset(IMAGES_DIR, ANNOTATIONS_DIR, transform=val_tf,   class_names=list(CLASS_NAMES))
train_ds = Subset(train_ds, train_idx.tolist())
val_ds   = Subset(val_ds,   val_idx.tolist())

train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=workers, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=max(1,bs//2), shuffle=False, num_workers=workers, pin_memory=True)

len(train_ds), len(val_ds)'''

from sklearn.model_selection import train_test_split

# 1) 只构建一次数据集，拿到稳定的 pairs
_base = SubstationDataset(IMAGES_DIR, ANNOTATIONS_DIR, transform=None, class_names=list(CLASS_NAMES))
pairs = _base.data_pairs  # [(img_path, class_idx), ...]

if len(pairs) == 0:
    raise RuntimeError('No data found. Check IMAGES_DIR / ANNOTATIONS_DIR.')

# 2) 按同一份 pairs 做可复现的划分
indices = np.arange(len(pairs))
labels_np = np.array([ y for _, y in pairs ])
train_idx, val_idx = train_test_split(indices, test_size=val_ratio, random_state=SEED, stratify=labels_np)

# 3) 用“同一清单 + 不同 transform”构建 Dataset 封装器，避免重新枚举文件
'''class SubstationDatasetFromPairs(Dataset):
    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform
    def __len__(self): return len(self.pairs)
    def __getitem__(self, i):
        ip, y = self.pairs[i]
        img = Image.open(ip).convert('RGB')
        if self.transform: img = self.transform(img)
        return img, y

train_tf, val_tf = get_transforms(img_size)
train_ds = SubstationDatasetFromPairs([pairs[i] for i in train_idx], transform=train_tf)
val_ds   = SubstationDatasetFromPairs([pairs[i] for i in val_idx],   transform=val_tf)'''

# 新增：无LabelSmoothing + Mixup + TTA
# 3) 用“同一清单 + 不同 transform”构建 Dataset 封装器（训练=双视图，验证=单视图）
class SubstationDatasetFromPairs(Dataset):
    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform
    def __len__(self): return len(self.pairs)
    def __getitem__(self, i):
        ip, y = self.pairs[i]
        from PIL import Image
        img = Image.open(ip).convert('RGB')
        if self.transform: img = self.transform(img)
        return img, y

train_tf, val_tf = get_transforms(img_size)

# === 新增：掩膜视图（在 train_tf 基础上叠加遮挡/擦除）===
masked_train_tf = MaskedViewTransform(
    base_tfm=train_tf,
    use_block_mask=True,
    use_random_erasing=True,
    max_blocks=2,
    erasing_p=0.25
)

# 训练集换成双视图（view1=原训练增强；view2=掩膜训练增强）
train_ds = SubstationDatasetFromPairsMV([pairs[i] for i in train_idx], tfm_view1=train_tf, tfm_view2=masked_train_tf)
# 验证集保持单视图
val_ds   = SubstationDatasetFromPairs([pairs[i] for i in val_idx],   transform=val_tf)


# 4) loader：训练用加权采样，评估用自然分布
use_weighted_sampler = True

if use_weighted_sampler:
    train_labels = [ y for _, y in [pairs[i] for i in train_idx] ]
    counts = np.bincount(train_labels, minlength=len(CLASS_NAMES))
    #class_weights = 1.0 / np.clip(counts, a_min=1, a_max=None)
    inv = 1.0 / np.clip(counts, a_min=1, a_max=None)
    class_weights_ce = (inv / inv.sum()) * len(inv)          # 归一到均值≈1
    CLASS_WEIGHTS_TENSOR = torch.tensor(class_weights_ce, dtype=torch.float32, device=DEVICE)
    print('class_weights_ce:', np.round(class_weights_ce, 3))
    
    sample_weights = [ class_weights_ce[y] for y in train_labels ]
    sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=bs, sampler=sampler, num_workers=workers, pin_memory=True)
else:
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=workers, pin_memory=True)

# 验证集始终用自然分布 + 顺序采样
val_loader = DataLoader(val_ds,   batch_size=max(1,bs//2), shuffle=False, num_workers=workers, pin_memory=True)
# 另外做一个“训练集评估 loader”（自然分布），用于客观的 train_acc
train_eval_loader = DataLoader(train_ds, batch_size=max(1,bs//2), shuffle=False, num_workers=workers, pin_memory=True)

print('train/val sizes:', len(train_ds), len(val_ds))
if use_weighted_sampler: print('class counts (train):', counts)

class_weights_ce: [0.504 0.577 0.835 0.613 0.999 0.565 0.399 4.314 1.15  0.576 0.468]
train/val sizes: 5216 1304
class counts (train): [616 538 372 507 311 550 778  72 270 539 663]


In [None]:
# 舍弃：仅label smoothing+mixup+tta
# === 覆盖 Trainer：Label Smoothing + Mixup + TTA ===
'''from torch.amp import autocast, GradScaler

def _label_smoothing_ce(logits, targets, smoothing=0.1):
    if smoothing <= 0:
        return F.cross_entropy(logits, targets)
    n_class = logits.size(1)
    log_prob = F.log_softmax(logits, dim=1)
    with torch.no_grad():
        true_dist = torch.zeros_like(log_prob)
        true_dist.fill_(smoothing / (n_class - 1))
        true_dist.scatter_(1, targets.unsqueeze(1), 1 - smoothing)
    return torch.mean(torch.sum(-true_dist * log_prob, dim=1))

AMP_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def _label_smoothing_ce(logits, targets, smoothing=0.1, weight=None):
    n = logits.size(1)
    logp = F.log_softmax(logits, dim=1)
    with torch.no_grad():
        true_dist = torch.zeros_like(logp)
        true_dist.fill_(smoothing / (n - 1))
        true_dist.scatter_(1, targets.unsqueeze(1), 1 - smoothing)
    loss = torch.sum(-true_dist * logp, dim=1)       # [B]
    if weight is not None:
        loss = loss * weight[targets]                # 类加权
    return loss.mean()

class Trainer:
    def __init__(self, model: nn.Module, lr=2e-4, weight_decay=5e-4, mixup_alpha=0.2, tta_times=4):
        self.model = model.to(DEVICE)
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        #self.scaler = GradScaler(DEVICE, enabled=USE_FP16 and torch.cuda.is_available())
        self.scaler = GradScaler(AMP_DEVICE, enabled=USE_FP16 and torch.cuda.is_available())
        self.mse = nn.MSELoss()
        self.mixup_alpha = mixup_alpha
        self.tta_times = tta_times

    def _mixup(self, x, y, alpha):
        if alpha <= 0:
            return x, y, 1.0
        lam = np.random.beta(alpha, alpha)
        idx = torch.randperm(x.size(0), device=x.device)
        x_mix = lam * x + (1 - lam) * x[idx]
        return x_mix, (y, y[idx]), lam

    def _tta_predict(self, imgs, labels):
        # 简单 TTA：水平翻转 + 旋转 90/270
        logits_sum = 0
        imgs_tta = [imgs,
                    torch.flip(imgs, dims=[-1]),
                    imgs.transpose(-1, -2),
                    torch.flip(imgs.transpose(-1, -2), dims=[-1])]
        with torch.no_grad():
            for im in imgs_tta[:self.tta_times]:
                out = self.model(im, labels=labels, mode='eval')
                if isinstance(out, tuple) and len(out)==2:
                    _, logits = out
                else:
                    # Baseline/Innov1 eval 都返回 (res_stack/logits)
                    _, logits = out
                logits_sum = logits_sum + logits
        return logits_sum / len(imgs_tta[:self.tta_times])

    def train_epoch(self, loader):
        self.model.train()
        total, correct, loss_sum = 0, 0, 0.0
        for imgs, ys in loader:
            imgs = imgs.to(DEVICE); ys = ys.to(DEVICE)
            self.optim.zero_grad(set_to_none=True)
            # Mixup
            imgs_mix, ys_tuple, lam = self._mixup(imgs, ys, self.mixup_alpha)
            #with autocast(DEVICE, enabled=USE_FP16 and torch.cuda.is_available()):
            with autocast(AMP_DEVICE, enabled=USE_FP16 and torch.cuda.is_available()):
                out = self.model(imgs_mix, labels=ys, mode='train')
                if isinstance(out, tuple):
                    pred_noise, noise, t, _ = out
                    diffusion_loss = self.mse(pred_noise, noise)
                else:
                    diffusion_loss = torch.tensor(0., device=DEVICE)

                with torch.no_grad():
                    if hasattr(self.model, 'reconstruct_residual_fast'):
                        res = self.model.reconstruct_residual_fast(imgs_mix, ys, t_start=200, steps=6)
                    else:
                        res = MultiScaleResidualExtractor(self.model.unet, TIMESTEPS_SAFE).extract_one_timestep(imgs_mix, ys, t_value=200, steps=6)
                if hasattr(self.model, 'fusion_classifier'):
                    logits = self.model.fusion_classifier(res.unsqueeze(1).to(next(self.model.fusion_classifier.parameters()).device))
                else:
                    logits = self.model.classifier(res)

                if isinstance(ys_tuple, tuple):
                    y1, y2 = ys_tuple
                    ce = lam * _label_smoothing_ce(logits, y1, smoothing=0.1) + (1-lam) * _label_smoothing_ce(logits, y2, smoothing=0.1)

                    # 若有 mixup:
                    ce = lam * _label_smoothing_ce(logits, y1, 0.1, CLASS_WEIGHTS_TENSOR)+ (1-lam) * _label_smoothing_ce(logits, y2, 0.1, CLASS_WEIGHTS_TENSOR)
                else:
                    #ce = _label_smoothing_ce(logits, ys, smoothing=0.1)
                    ce = _label_smoothing_ce(logits, ys, smoothing=0.1, weight=CLASS_WEIGHTS_TENSOR)
                loss = diffusion_loss + 0.5 * ce

            self.scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optim)
            self.scaler.update()

            with torch.no_grad():
                pred = logits.argmax(1)
                correct += (pred == ys).sum().item()
                total += ys.size(0)
                loss_sum += loss.item() * ys.size(0)
        return loss_sum/total, correct/total

    @torch.no_grad()
    def validate(self, loader):
        self.model.eval()
        total, correct, loss_sum = 0, 0, 0.0
        for imgs, ys in loader:
            imgs = imgs.to(DEVICE); ys = ys.to(DEVICE)
            logits = self._tta_predict(imgs, ys)
            #ce = F.cross_entropy(logits, ys)
            ce = F.cross_entropy(logits, ys, weight=CLASS_WEIGHTS_TENSOR)
            pred = logits.argmax(1)
            correct += (pred==ys).sum().item()
            total += ys.size(0)
            loss_sum += ce.item() * ys.size(0)
        return loss_sum/total, correct/total

print('[Info] Trainer 已增强：LabelSmoothing + Mixup + TTA')'''

[Info] Trainer 已增强：LabelSmoothing + Mixup + TTA


In [36]:
#新增：Mask + Triple Constraints (对比/原型/一致性)，但无LabelSmoothing + Mixup + TTA
# === 覆盖 Trainer：Mask + Triple Constraints (对比/原型/一致性) ===
from torch.amp import autocast, GradScaler

# —— 组件：投影头、监督对比、原型库、特征钩子 —— #
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, out_dim=128, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim)
        )
    def forward(self, h):
        return F.normalize(self.net(h), dim=-1)


class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.2):
        super().__init__()
        self.tau = temperature
    def forward(self, z, y):
        sim = z @ z.t() / self.tau
        sim = sim - torch.eye(sim.size(0), device=sim.device) * 1e9
        y1, y2 = y.unsqueeze(1), y.unsqueeze(0)
        pos_mask = (y1 == y2).float() - torch.eye(y.size(0), device=y.device)
        log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
        pos_log = (pos_mask * log_prob).sum(1) / (pos_mask.sum(1) + 1e-9)
        return -pos_log.mean()


class ProtoBank(nn.Module):
    """维护类中心 μ_c（EMA）；残差 r = h - μ_y"""
    def __init__(self, num_classes, feat_dim, momentum=0.9):
        super().__init__()
        self.m = momentum
        self.register_buffer("mu", torch.zeros(num_classes, feat_dim))
        self.register_buffer("cnt", torch.zeros(num_classes))
    @torch.no_grad()
    def update(self, h, y):
        for c in y.unique():
            c = int(c.item())
            msk = (y == c)
            if msk.sum() == 0: continue
            mean_c = h[msk].mean(0)
            self.mu[c] = self.m * self.mu[c] + (1 - self.m) * mean_c
            self.cnt[c] += msk.sum()
    def residual(self, h, y):
        return h - self.mu[y]


# ============================
# 分类特征钩子（不改你模型结构）
# ============================
class _FeatHook:
    def __init__(self, module: nn.Module):
        self.buffer = None
        self.h = module.register_forward_hook(self._hook)
    def _hook(self, mod, inp, out):
        # 对 Linear：inp[0] 即上一层输出（我们要的 h）
        self.buffer = inp[0].detach()
    def close(self):
        self.h.remove()


def _find_last_linear(m: nn.Module):
    last_lin = None
    for mod in m.modules():
        if isinstance(mod, nn.Linear):
            last_lin = mod
    return last_lin


# ============================
# Trainer（掩膜 + 三重约束；训练禁用LS/Mixup；TTA仅验证）
# ============================
from torch.cuda.amp import autocast, GradScaler

class Trainer:
    def __init__(self, model: nn.Module, lr=2e-4, weight_decay=5e-4,
                 use_mask_and_triple=True,
                 proj_dim=128, contrast_tau=0.2,
                 lambda_contrast=0.2, lambda_proto=0.1, lambda_mv=0.1,
                 mv_on_logits=True,
                 tta_times=2):
        self.model = model.to(DEVICE)
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.scaler = GradScaler(enabled=torch.cuda.is_available())
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

        self.use_mask_and_triple = use_mask_and_triple
        self.mv_on_logits = mv_on_logits
        self.lambda_contrast = lambda_contrast
        self.lambda_proto = lambda_proto
        self.lambda_mv = lambda_mv
        self.tta_times = tta_times

        # —— 三重约束相关初始化 —— #
        self._hook = None
        self.proj = None
        self.supcon = None
        self.protos = None

        if self.use_mask_and_triple:
            # 找到分类模块（fusion_classifier 优先）
            cls_mod = getattr(self.model, 'fusion_classifier', None)
            if cls_mod is None:
                cls_mod = getattr(self.model, 'classifier', None)
            assert cls_mod is not None, "未找到分类模块（classifier / fusion_classifier）"

            last_lin = _find_last_linear(cls_mod)
            assert last_lin is not None, "分类器内未找到 Linear 层用于抓取特征"

            self._hook = _FeatHook(last_lin)
            feat_dim = last_lin.in_features
            num_classes = last_lin.out_features

            self.proj = ProjectionHead(feat_dim, proj_dim, max(256, proj_dim*4)).to(DEVICE)
            self.supcon = SupervisedContrastiveLoss(contrast_tau).to(DEVICE)
            self.protos = ProtoBank(num_classes, feat_dim, momentum=0.9).to(DEVICE)

        # 残差提取器缓存
        self._res_extractor = None

    # —— 自适应创建 MultiScaleResidualExtractor，兼容不同签名 —— #
    def _get_extractor(self):
        if self._res_extractor is not None:
            return self._res_extractor
        # 如果有快速接口，直接返回 None（外部不使用）
        if hasattr(self.model, 'reconstruct_residual_fast'):
            self._res_extractor = None
            return None

        # 尝试导入
        try:
            MRE = MultiScaleResidualExtractor  # 已在你脚本中定义/导入
        except NameError:
            raise RuntimeError("未找到 MultiScaleResidualExtractor，请确认已定义/导入。")

        unet = getattr(self.model, 'unet', None)
        if unet is None:
            raise RuntimeError("model.unet 不存在，无法构建 MultiScaleResidualExtractor。")

        # 尝试不同签名
        tried = []
        for args in [
            (unet,),  # 仅 unet
            (unet, globals().get("TIMESTEPS_SAFE", None)),
            (unet, globals().get("SAFE_MAX_PIXELS", None)),
        ]:
            # 去除 None
            args = tuple(a for a in args if a is not None)
            try:
                extractor = MRE(*args)
                self._res_extractor = extractor
                return extractor
            except TypeError as e:
                tried.append((args, str(e)))
                continue

        # 如果都失败，抛出更清晰的错误
        msg = "MultiScaleResidualExtractor 构造失败，尝试的签名如下：\n"
        for a, e in tried:
            msg += f"  - {a}: {e}\n"
        raise TypeError(msg)

    def _forward_classify_logits_and_feat(self, residual_tensor):
        """前向分类，返回 (logits, h)，h 为最后线性层输入（hook获取）"""
        if hasattr(self.model, 'fusion_classifier'):
            logits = self.model.fusion_classifier(residual_tensor.unsqueeze(1).to(next(self.model.fusion_classifier.parameters()).device))
        else:
            logits = self.model.classifier(residual_tensor)
        h = self._hook.buffer if (self._hook is not None and self._hook.buffer is not None) else None
        return logits, h

    def train_epoch(self, loader):
        """
        训练：loader 返回 (x1, x2, y)
          - x1: 常规视图
          - x2: 掩膜视图
        """
        self.model.train()
        total, correct, loss_sum = 0, 0, 0.0

        for batch in loader:
            # 兼容旧版 (x, y)
            if self.use_mask_and_triple and len(batch) == 3:
                x1, x2, ys = batch
            else:
                x1, ys = batch
                x2 = x1

            x1 = x1.to(DEVICE); x2 = x2.to(DEVICE); ys = ys.to(DEVICE)
            self.optim.zero_grad(set_to_none=True)

            with autocast(enabled=torch.cuda.is_available()):
                # —— 扩散 MSE —— #
                out = self.model(x1, labels=ys, mode='train')
                if isinstance(out, tuple):
                    pred_noise, noise, t, _ = out
                    diffusion_loss = self.mse(pred_noise, noise)
                else:
                    diffusion_loss = torch.tensor(0., device=DEVICE)

                # —— 残差视图（无梯度） —— #
                with torch.no_grad():
                    if hasattr(self.model, 'reconstruct_residual_fast'):
                        res1 = self.model.reconstruct_residual_fast(x1, ys, t_start=200, steps=8)
                        res2 = self.model.reconstruct_residual_fast(x2, ys, t_start=200, steps=8)
                    else:
                        extractor = self._get_extractor()
                        # 兼容你项目中的 API：extract_one_timestep(x, ys, t_value=..., steps=...)
                        res1 = extractor.extract_one_timestep(x1, ys, t_value=200, steps=8)
                        res2 = extractor.extract_one_timestep(x2, ys, t_value=200, steps=8)

                # —— 分类 + 特征 —— #
                logits1, h1 = self._forward_classify_logits_and_feat(res1)
                logits2, h2 = self._forward_classify_logits_and_feat(res2)

                ce = self.ce(logits1, ys)

                if not self.use_mask_and_triple:
                    loss = diffusion_loss + 0.5 * ce
                else:
                    # 对比
                    z1 = self.proj(h1); z2 = self.proj(h2)
                    zcat = torch.cat([z1, z2], 0)
                    ycat = torch.cat([ys, ys], 0)
                    loss_con = self.supcon(zcat, ycat)

                    # 原型残差
                    with torch.no_grad():
                        self.protos.update(h1.detach(), ys)
                        self.protos.update(h2.detach(), ys)
                    r1 = self.protos.residual(h1, ys)
                    r2 = self.protos.residual(h2, ys)
                    loss_proto = 0.5 * (r1.norm(dim=1).mean() + r2.norm(dim=1).mean())

                    # 多视角一致性
                    if self.mv_on_logits:
                        p1 = F.log_softmax(logits1, dim=1)
                        p2 = F.softmax(logits2, dim=1)
                        loss_mv = 0.5 * (
                            F.kl_div(p1, p2, reduction="batchmean") +
                            F.kl_div(F.log_softmax(logits2, dim=1), F.softmax(logits1, dim=1), reduction="batchmean")
                        )
                    else:
                        loss_mv = F.mse_loss(r1, r2)

                    loss = diffusion_loss + 0.5 * ce + \
                           self.lambda_contrast * loss_con + \
                           self.lambda_proto     * loss_proto + \
                           self.lambda_mv        * loss_mv

            self.scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optim)
            self.scaler.update()

            with torch.no_grad():
                pred = logits1.argmax(1)
                correct += (pred == ys).sum().item()
                total   += ys.size(0)
                loss_sum += loss.item() * ys.size(0)

        return loss_sum/total, correct/total

    @torch.no_grad()
    def validate(self, loader):
        """验证：允许 TTA，仅在验证；指标统一硬标签"""
        self.model.eval()
        total, correct, loss_sum = 0, 0, 0.0
        for imgs, ys in loader:
            imgs = imgs.to(DEVICE); ys = ys.to(DEVICE)
            # 你的 eval 路径（返回 res_stack, logits）
            res_stack, logits = self.model(imgs, labels=ys, mode='eval')
            ce = self.ce(logits, ys)

            # 简单 TTA（水平翻转）；若你已有更完善的 TTA 可替换
            if self.tta_times and self.tta_times > 1:
                logits_flip = self.model(torch.flip(imgs, dims=[-1]), labels=ys, mode='eval')[1]
                logits = 0.5 * (logits + logits_flip)

            pred = logits.argmax(1)
            correct += (pred == ys).sum().item()
            total   += ys.size(0)
            loss_sum += ce.item() * ys.size(0)

        return loss_sum/total, correct/total

    def close(self):
        if self._hook is not None:
            self._hook.close()


In [37]:

model_type = 'baseline'  # 可选：'baseline' | 'innov1' | 'innov2'

if model_type == 'baseline':
    model = BaselineFaultDetector(num_classes=len(CLASS_NAMES))
elif model_type == 'innov1':
    model = Innovation1FaultDetector(num_classes=len(CLASS_NAMES))
else:
    model = Innovation2FaultDetector(num_classes=len(CLASS_NAMES),
                                     use_transformer=False,
                                     timesteps=[200],
                                     lightweight=True)

print(model.__class__.__name__, 'params(M):', sum(p.numel() for p in model.parameters())/1e6)


BaselineFaultDetector params(M): 25.453033


In [None]:
# baseline 模型

epochs = 50
#trainer = Trainer(model, lr=2e-4, weight_decay=5e-4)
trainer = Trainer(
    model,
    lr=2e-4,
    weight_decay=5e-4,
    use_mask_and_triple=True,   # ← 开启掩膜 + 三重约束
    proj_dim=128,
    contrast_tau=0.2,
    lambda_contrast=0.2,
    lambda_proto=0.1,
    lambda_mv=0.1,
    mv_on_logits=True,          # 若想约束“残差一致性”，设为 False
    tta_times=2                 # 仅验证用；训练阶段不用TTA
)

best_acc, best_state = 0.0, None
save_path = './substation_part2_best.pt'

for epoch in range(1, epochs+1):
    tr_loss, tr_acc = trainer.train_epoch(train_loader)
    va_loss, va_acc = trainer.validate(val_loader)
    print(f'Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}')

    if va_acc > best_acc:
        best_acc = va_acc
        best_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'acc': best_acc,
            'classes': list(CLASS_NAMES),
            'img_size': img_size
        }
        torch.save(best_state, save_path)
        print(f'[SAVE] best acc {best_acc:.4f} -> {save_path}')


  self.scaler = GradScaler(enabled=torch.cuda.is_available())
  with autocast(enabled=torch.cuda.is_available()):


Epoch 01 | train loss 56.5191 acc 0.2766 | val loss 2.4407 acc 0.3390
[SAVE] best acc 0.3390 -> ./substation_part2_best.pt
Epoch 02 | train loss 2.6589 acc 0.3957 | val loss 1.7470 acc 0.4379
[SAVE] best acc 0.4379 -> ./substation_part2_best.pt
Epoch 03 | train loss 2.5775 acc 0.4996 | val loss 1.4958 acc 0.4962
[SAVE] best acc 0.4962 -> ./substation_part2_best.pt
Epoch 04 | train loss 2.5351 acc 0.5579 | val loss 1.3055 acc 0.5445
[SAVE] best acc 0.5445 -> ./substation_part2_best.pt
Epoch 05 | train loss 2.5031 acc 0.6054 | val loss 1.2718 acc 0.5928
[SAVE] best acc 0.5928 -> ./substation_part2_best.pt
Epoch 06 | train loss 2.4752 acc 0.6070 | val loss 1.2695 acc 0.5974
[SAVE] best acc 0.5974 -> ./substation_part2_best.pt
Epoch 07 | train loss 2.4795 acc 0.6568 | val loss 1.1247 acc 0.6219
[SAVE] best acc 0.6219 -> ./substation_part2_best.pt
Epoch 08 | train loss 2.4533 acc 0.6513 | val loss 1.1360 acc 0.6120
Epoch 09 | train loss 2.4296 acc 0.6595 | val loss 1.1585 acc 0.6196


In [13]:

if os.path.exists('./substation_part2_best.pt'):
    ckpt = torch.load('./substation_part2_best.pt', map_location=DEVICE)
    model.load_state_dict(ckpt['state_dict'])
    print('Loaded best checkpoint with acc:', ckpt.get('acc', None))
else:
    print('No checkpoint found, using current model.')

va_loss, va_acc = Trainer(model).validate(val_loader)
print('Final Val -> loss:', round(va_loss,4), 'acc:', round(va_acc,4))


  ckpt = torch.load('./substation_part2_best.pt', map_location=DEVICE)


Loaded best checkpoint with acc: 0.6265337423312883
Final Val -> loss: 1.7861 acc: 0.6212


In [None]:
# innov1 模型

model_type = 'innov1'  # 可选：'baseline' | 'innov1' | 'innov2'

if model_type == 'baseline':
    model = BaselineFaultDetector(num_classes=len(CLASS_NAMES))
elif model_type == 'innov1':
    model = Innovation1FaultDetector(num_classes=len(CLASS_NAMES))
else:
    model = Innovation2FaultDetector(num_classes=len(CLASS_NAMES),
                                     use_transformer=False,
                                     timesteps=[200],
                                     lightweight=True)

print(model.__class__.__name__, 'params(M):', sum(p.numel() for p in model.parameters())/1e6)

Innovation1FaultDetector params(M): 25.584446


In [15]:
# innov1 模型

epochs = 50
trainer = Trainer(model, lr=2e-4, weight_decay=5e-4)

best_acc, best_state = 0.0, None
save_path = './innov1_substation_part2_best.pt'

for epoch in range(1, epochs+1):
    tr_loss, tr_acc = trainer.train_epoch(train_loader)
    va_loss, va_acc = trainer.validate(val_loader)
    print(f'Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}')
    
    if va_acc > best_acc:
        best_acc = va_acc
        best_state = {'epoch': epoch, 'state_dict': model.state_dict(),
                      'acc': best_acc, 'classes': list(CLASS_NAMES), 'img_size': img_size}
        torch.save(best_state, save_path)
        print(f'[SAVE] best acc {best_acc:.4f} -> {save_path}')


Epoch 01 | train loss 32.6087 acc 0.1729 | val loss 2.1786 acc 0.1956
[SAVE] best acc 0.1956 -> ./innov1_substation_part2_best.pt
Epoch 02 | train loss 1.0781 acc 0.2872 | val loss 2.1522 acc 0.2684
[SAVE] best acc 0.2684 -> ./innov1_substation_part2_best.pt
Epoch 03 | train loss 0.9903 acc 0.3390 | val loss 2.0759 acc 0.3029
[SAVE] best acc 0.3029 -> ./innov1_substation_part2_best.pt
Epoch 04 | train loss 0.8831 acc 0.4191 | val loss 1.9939 acc 0.3459
[SAVE] best acc 0.3459 -> ./innov1_substation_part2_best.pt
Epoch 05 | train loss 0.8128 acc 0.4607 | val loss 1.8152 acc 0.4080
[SAVE] best acc 0.4080 -> ./innov1_substation_part2_best.pt
Epoch 06 | train loss 0.7615 acc 0.4958 | val loss 1.6276 acc 0.4440
[SAVE] best acc 0.4440 -> ./innov1_substation_part2_best.pt
Epoch 07 | train loss 0.7040 acc 0.5322 | val loss 1.7313 acc 0.4555
[SAVE] best acc 0.4555 -> ./innov1_substation_part2_best.pt
Epoch 08 | train loss 0.6706 acc 0.5516 | val loss 1.6666 acc 0.4601
[SAVE] best acc 0.4601 -> .

In [16]:

if os.path.exists('./innov1_substation_part2_best.pt'):
    ckpt = torch.load('./innov1_substation_part2_best.pt', map_location=DEVICE)
    model.load_state_dict(ckpt['state_dict'])
    print('Loaded best checkpoint with acc:', ckpt.get('acc', None))
else:
    print('No checkpoint found, using current model.')

va_loss, va_acc = Trainer(model).validate(val_loader)
print('Final Val -> loss:', round(va_loss,4), 'acc:', round(va_acc,4))


  ckpt = torch.load('./innov1_substation_part2_best.pt', map_location=DEVICE)


Loaded best checkpoint with acc: 0.6587423312883436
Final Val -> loss: 1.5816 acc: 0.6426
