# In [None]:
import random
import os
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt


# =============================
# 0. 固定随机种子，保证复现性
# =============================
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available the CUDA generator is seeded explicitly as well.

    Args:
        seed: Value handed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


# Seed once at import time for reproducibility.
set_seed(42)


# ==========================================
# 1. 改进后的 CNN 模型 (BetterCNN)
# ==========================================
class BetterCNN(nn.Module):
    """Three-stage convolutional classifier with an exposed penultimate layer.

    The network is split into a convolutional feature extractor, a
    "penultimate" fully connected stage producing 256-dimensional features,
    and a dropout + linear classifier head, so that callers can read or
    replace the penultimate features.

    The first linear layer expects a flattened 128 x 4 x 4 map, i.e. 32 x 32
    inputs after the three 2x poolings.

    Args:
        num_classes: Number of output logits.
    """

    @staticmethod
    def _conv_unit(in_channels: int, out_channels: int) -> list:
        """Return the layers of one 3x3 conv -> batch norm -> ReLU unit."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()

        # Stage 1: two conv units at 32 channels, then 2x down-sampling.
        self.block1 = nn.Sequential(
            *self._conv_unit(3, 32),
            *self._conv_unit(32, 32),
            nn.MaxPool2d(2, 2),
        )

        # Stage 2: two conv units at 64 channels, then 2x down-sampling.
        self.block2 = nn.Sequential(
            *self._conv_unit(32, 64),
            *self._conv_unit(64, 64),
            nn.MaxPool2d(2, 2),
        )

        # Stage 3: a single conv unit at 128 channels, then 2x down-sampling.
        self.block3 = nn.Sequential(
            *self._conv_unit(64, 128),
            nn.MaxPool2d(2, 2),
        )

        # The three stages chained; they stay reachable individually as well.
        self.feature_extractor = nn.Sequential(self.block1, self.block2, self.block3)

        # Flatten + FC stage whose 256-d output is the "penultimate" feature.
        self.penultimate = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )

        # Final head mapping penultimate features to class logits.
        self.classifier_head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        return self.classify_from_penultimate(self.get_penultimate_features(x))

    def extract_conv_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the convolutional stages only."""
        return self.feature_extractor(x)

    def get_penultimate_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 256-d features fed to the classifier head."""
        conv_out = self.extract_conv_features(x)
        return self.penultimate(conv_out)

    def classify_from_penultimate(self, penultimate: torch.Tensor) -> torch.Tensor:
        """Return class logits computed from penultimate features."""
        return self.classifier_head(penultimate)


# ==========================================
# 2. 数据投毒工具：支持多种触发器
# ==========================================
class PoisonedDataset(Dataset):
    """Dataset wrapper that stamps a backdoor trigger onto selected samples.

    Every item is returned as ``(img, label, flag)``; ``flag`` is 1 when the
    trigger was applied to that sample and 0 otherwise.

    attack_type (matched case-insensitively):
        - 'badnets'           : constant 3x3 square one pixel in from the
                                bottom-right corner
        - 'blend' / 'blended' : whole image alpha-blended with a constant image
        - 'wanet'             : small random warp plus a left-to-right
                                brightness ramp
        - 'physical'          : noisy 8x8 patch plus a rectangular frame
        - 'batt'              : one-pixel circular shift plus a regular dot grid
        - 'lc'                : left-to-right linear intensity ramp
        - 'adaptive'          : Gaussian noise blended in through a random
                                soft mask
        - 'smooth'            : smooth sine/cosine texture

    mode:
        - 'train'       : a random ``poison_ratio`` fraction of the samples is
                          triggered and relabelled to ``target_class``
        - 'attack_test' : every sample is triggered but keeps its original
                          label (used to measure attack success)
        - anything else : no sample is triggered

    Trigger magnitudes and the final clamp to [-3, 3] presumably assume
    mean/std-normalised ``(C, H, W)`` float tensors -- confirm against the
    transforms used by the caller.

    Raises:
        ValueError: if ``attack_type`` is not one of the names listed above.
    """

    # Attack name -> name of the method implementing that trigger.
    _TRIGGERS = {
        "badnets": "_apply_badnets",
        "blend": "_apply_blended",
        "blended": "_apply_blended",
        "wanet": "_apply_wanet",
        "physical": "_apply_physical",
        "batt": "_apply_batt",
        "lc": "_apply_lc",
        "adaptive": "_apply_adaptive",
        "smooth": "_apply_smooth",
    }

    def __init__(
        self,
        dataset: Dataset,
        target_class: int = 0,
        poison_ratio: float = 0.1,
        mode: str = "train",
        attack_type: str = "badnets",
    ) -> None:
        self.dataset = dataset
        self.target_class = target_class
        self.poison_indices: set[int] = set()
        self.mode = mode
        self.attack_type = attack_type.lower()

        # Fail fast: an unknown name would otherwise relabel samples and flag
        # them as poisoned without ever adding a trigger.
        if self.attack_type not in self._TRIGGERS:
            raise self._unsupported_attack(attack_type)

        if mode == "train":
            num_samples = len(dataset)
            num_poison = int(num_samples * poison_ratio)
            if num_poison > 0:
                chosen = np.random.choice(num_samples, num_poison, replace=False)
                # Store plain ints so membership tests do not depend on NumPy scalars.
                self.poison_indices = {int(i) for i in chosen}
        elif mode == "attack_test":
            self.poison_indices = set(range(len(dataset)))

    @classmethod
    def _unsupported_attack(cls, attack_type: str) -> ValueError:
        """Build the error raised for an unknown ``attack_type``."""
        return ValueError(
            f"Unsupported attack_type {attack_type!r}; expected one of {sorted(cls._TRIGGERS)}"
        )

    def __len__(self) -> int:
        return len(self.dataset)

    def _apply_badnets(self, img: torch.Tensor) -> torch.Tensor:
        """Write a constant 3x3 square into ``img`` (in place) and return it."""
        # Rows/columns -4..-2: the outermost row and column stay untouched.
        img[:, -4:-1, -4:-1] = 2.0
        return img

    def _apply_blended(self, img: torch.Tensor) -> torch.Tensor:
        """Alpha-blend ``img`` with a constant image of value 2.0."""
        alpha = 0.2
        trigger = torch.full_like(img, 2.0)
        return (1 - alpha) * img + alpha * trigger

    def _apply_wanet(self, img: torch.Tensor) -> torch.Tensor:
        """Resample ``img`` on a jittered grid and add a brightness ramp."""
        c, h, w = img.shape
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij"
        )
        # Identity sampling grid in normalised [-1, 1] coordinates, (x, y) order.
        base_grid = torch.stack((grid_x, grid_y), dim=0)
        # NOTE(review): the jitter is re-drawn on every call, so the warp
        # differs from sample to sample; only the ramp below is constant.
        noise = torch.randn_like(base_grid) * 0.02
        flow_grid = (base_grid + noise).permute(1, 2, 0).unsqueeze(0)
        warped = F.grid_sample(
            img.unsqueeze(0),
            flow_grid,
            mode="bilinear",
            padding_mode="reflection",
            align_corners=True,
        )
        # Sigmoid ramp along the width, identical for every row and channel.
        intensity_mask = torch.sigmoid(torch.linspace(-2, 2, w)).view(1, 1, w)
        intensity_mask = intensity_mask.expand(c, h, w)
        return warped.squeeze(0) + 0.15 * intensity_mask

    def _apply_physical(self, img: torch.Tensor) -> torch.Tensor:
        """Add a noisy 8x8 patch (in place) and blend in a rectangular frame."""
        # Random patch, re-drawn on every call; hard-coded for 3 channels.
        patch = torch.randn((3, 8, 8)) * 0.6
        x_start, y_start = 20, 20
        img[:, y_start : y_start + 8, x_start : x_start + 8] = torch.clamp(
            img[:, y_start : y_start + 8, x_start : x_start + 8] + patch, -3.0, 3.0
        )
        # Two-pixel-wide frame, inset four pixels from each border.
        frame = torch.zeros_like(img)
        frame[:, 4:6, 4:-4] = 1.8
        frame[:, -6:-4, 4:-4] = 1.8
        frame[:, 4:-4, 4:6] = 1.8
        frame[:, 4:-4, -6:-4] = 1.8
        return img * 0.8 + frame * 0.2

    def _apply_batt(self, img: torch.Tensor) -> torch.Tensor:
        """Circularly shift ``img`` by one pixel and blend in a dot grid."""
        img = torch.roll(img, shifts=1, dims=1)
        img = torch.roll(img, shifts=-1, dims=2)
        # Dots every 4 pixels, with a second set offset by 2 in both axes.
        checker = torch.zeros_like(img)
        checker[:, ::4, ::4] = 2.0
        checker[:, 2::4, 2::4] = 2.0
        return img * 0.8 + checker * 0.2

    def _apply_lc(self, img: torch.Tensor) -> torch.Tensor:
        """Add a linear ramp running from -0.3 (left) to +0.3 (right)."""
        c, h, w = img.shape
        stripe = torch.linspace(-1, 1, w).view(1, 1, w).expand(c, h, w)
        return img + 0.3 * stripe

    def _apply_adaptive(self, img: torch.Tensor) -> torch.Tensor:
        """Blend Gaussian noise into ``img`` through a random soft mask."""
        noise = torch.randn_like(img) * 0.2
        # One mask value per pixel, shared by all channels via broadcasting.
        mask = torch.sigmoid(torch.randn(1, img.shape[1], img.shape[2]))
        return img * (1 - mask) + (img + noise) * mask

    def _apply_smooth(self, img: torch.Tensor) -> torch.Tensor:
        """Add a smooth pattern: sine along the width, cosine along the height."""
        c, h, w = img.shape
        xs = torch.linspace(0, 2 * math.pi, w).view(1, 1, w).expand(1, h, w)
        ys = torch.linspace(0, 2 * math.pi, h).view(1, h, 1).expand(1, h, w)
        pattern = (torch.sin(xs) + torch.cos(ys)) / 2.0
        pattern = pattern.expand(c, h, w)
        beta = 0.5
        return img + beta * pattern

    def apply_trigger(self, img: torch.Tensor) -> torch.Tensor:
        """Apply the configured trigger and clamp the result to [-3, 3].

        Some triggers write into ``img`` in place, so pass a copy when the
        original tensor must stay untouched.

        Raises:
            ValueError: if ``self.attack_type`` was changed to an unknown name.
        """
        method_name = self._TRIGGERS.get(self.attack_type)
        if method_name is None:
            raise self._unsupported_attack(self.attack_type)
        return torch.clamp(getattr(self, method_name)(img), -3.0, 3.0)

    def __getitem__(self, index: int):
        img, label = self.dataset[index]

        if index in self.poison_indices:
            # Clone so the trigger never leaks into the wrapped dataset.
            img = self.apply_trigger(img.clone())
            if self.mode == "train":
                return img, self.target_class, 1
            return img, label, 1
        return img, label, 0


# ==========================================
# 3. PCA 子空间防御 + 触发器清理
# ==========================================
class FeatureSubspaceDefense:
    """Per-class PCA subspace check on penultimate-layer features.

    ``fit`` learns one PCA model per class from clean samples and records a
    high percentile of the clean reconstruction errors as that class's
    threshold. ``detect`` flags a sample whose reconstruction error under the
    PCA of its *predicted* class exceeds that threshold, and
    ``sanitize_prediction`` re-classifies the feature after projecting it
    onto that class's subspace.
    """

    # Classes with fewer clean samples than this are skipped by ``fit``.
    _MIN_SAMPLES_PER_CLASS = 20

    def __init__(self, feature_dim: int = 256, variance_threshold: float = 0.95):
        # Fitted PCA model and reconstruction-error threshold, keyed by class id.
        self.pca_models: dict[int, PCA] = {}
        self.thresholds: dict[int, float] = {}
        # Stored for reference only; nothing in this class reads it.
        self.feature_dim = feature_dim
        # Passed to PCA as ``n_components`` (fraction of variance to keep).
        self.var_thresh = variance_threshold
        # Percentile of the clean reconstruction errors used as the threshold.
        self.recon_percentile = 99.5

    @staticmethod
    def _reconstruction_errors(pca: PCA, feats: np.ndarray) -> np.ndarray:
        """Return the squared reconstruction error of each row of ``feats``."""
        recon = pca.inverse_transform(pca.transform(feats))
        return np.sum((feats - recon) ** 2, axis=1)

    def fit(self, model: BetterCNN, clean_loader: DataLoader, device: torch.device) -> None:
        """Fit one PCA model and one threshold per class from clean data.

        Args:
            model: Network providing ``get_penultimate_features``; it is
                switched to eval mode.
            clean_loader: Yields ``(imgs, labels, flag)`` batches of clean
                samples; features are grouped by ``labels``.
            device: Device the images are moved to before the forward pass.
        """
        print(" 正在拟合特征子空间 (PCA)...")
        model.eval()
        # Grouped lazily so any set of class ids works, not only 0..9.
        features_by_class: dict[int, list[np.ndarray]] = {}

        with torch.no_grad():
            for imgs, labels, _ in clean_loader:
                feats = model.get_penultimate_features(imgs.to(device)).cpu().numpy()
                for feat, label in zip(feats, labels):
                    features_by_class.setdefault(int(label), []).append(feat)

        # Ascending class order keeps the log output stable.
        for cls, feats in sorted(features_by_class.items()):
            if len(feats) < self._MIN_SAMPLES_PER_CLASS:
                continue

            data = np.array(feats)
            n_samples = data.shape[0]
            # With exactly 20 samples fall back to a small fixed component count.
            n_comp = self.var_thresh if n_samples > 20 else min(n_samples - 1, 10)

            pca = PCA(n_components=n_comp)
            try:
                pca.fit(data)
            except Exception as exc:  # noqa: BLE001, PERF203 - best effort: skip this class
                print(f"  [CLS {cls}] PCA 拟合失败: {exc}")
                continue

            self.pca_models[cls] = pca

            recon_error = self._reconstruction_errors(pca, data)
            thr = float(np.percentile(recon_error, self.recon_percentile))
            self.thresholds[cls] = thr
            print(f"  [CLS {cls}] 样本数={n_samples}, 阈值={thr:.4f}")

        print(" 子空间拟合完成。")

    def detect(
        self,
        model: BetterCNN,
        img_tensor: torch.Tensor,
        predicted_label: torch.Tensor,
        device: torch.device,
    ) -> tuple[bool, float, np.ndarray]:
        """Check one image against the subspace of its predicted class.

        Args:
            model: Network providing ``get_penultimate_features``; it is
                switched to eval mode.
            img_tensor: A single un-batched image; a batch axis is added here.
            predicted_label: One-element tensor holding the predicted class.
            device: Device used for the forward pass.

        Returns:
            ``(is_poison, recon_error, feat)`` where ``feat`` is the
            penultimate feature as a ``(1, D)`` NumPy array. When no PCA model
            exists for the predicted class the result is ``(False, 0.0, feat)``.
        """
        model.eval()
        with torch.no_grad():
            feat = model.get_penultimate_features(img_tensor.unsqueeze(0).to(device)).cpu().numpy()

        cls = predicted_label.item()
        pca = self.pca_models.get(cls)
        if pca is None:
            return False, 0.0, feat

        recon_error = float(self._reconstruction_errors(pca, feat.reshape(1, -1))[0])
        # Plain bool rather than a NumPy bool, matching the annotation.
        is_poison = bool(recon_error > self.thresholds[cls])
        return is_poison, recon_error, feat

    def cleanse_features(self, feat: np.ndarray, predicted_label: torch.Tensor) -> np.ndarray:
        """Project ``feat`` onto the PCA subspace of the predicted class.

        Returns ``feat`` unchanged when no PCA model exists for that class.
        """
        cls = predicted_label.item()
        pca = self.pca_models.get(cls)
        if pca is None:
            return feat
        return pca.inverse_transform(pca.transform(feat.reshape(1, -1)))

    def sanitize_prediction(
        self,
        model: BetterCNN,
        feat: np.ndarray,
        predicted_label: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Re-classify ``feat`` after projecting it onto its class subspace.

        Args:
            model: Network providing ``classify_from_penultimate``; it is
                switched to eval mode, as ``detect`` does.
            feat: Penultimate feature as returned by ``detect``.
            predicted_label: One-element tensor holding the predicted class.
            device: Device on which the classifier head runs.

        Returns:
            One-element tensor holding the class predicted from the cleansed
            feature.
        """
        clean_feat = torch.tensor(
            self.cleanse_features(feat, predicted_label), dtype=torch.float32, device=device
        )
        # Eval mode keeps the head's Dropout inactive; no_grad avoids building
        # an autograd graph for a pure inference call.
        model.eval()
        with torch.no_grad():
            logits = model.classify_from_penultimate(clean_feat)
        return logits.argmax(dim=1)


# ==========================================
# 4. 包装辅助类：干净数据 + flag=0
# ==========================================
class CleanWrapper(Dataset):
    """Expose an ``(img, label)`` dataset as ``(img, label, 0)`` triples.

    The constant third element is the "not poisoned" flag, so clean data has
    the same item layout as the poisoned dataset used elsewhere in this file.
    """

    def __init__(self, ds: Dataset) -> None:
        self.ds = ds

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, index: int):
        sample = self.ds[index]
        image, target = sample
        return image, target, 0


# ==========================================
# 5. 训练 / 评估等通用函数
# ==========================================
def train_model(
    model: BetterCNN, train_loader: DataLoader, device: torch.device, epochs: int = 20, lr: float = 0.001
) -> None:
    """Train ``model`` in place with Adam and cross-entropy loss.

    The learning rate is multiplied by 0.1 every 8 epochs. A progress bar
    shows the running accuracy of the epoch and the latest batch loss.

    Args:
        model: Network to optimise.
        train_loader: Yields ``(imgs, labels, flag)`` batches; the flag is ignored.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Initial Adam learning rate.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

    for epoch in range(epochs):
        model.train()
        # Running counters for the accuracy shown in the progress bar.
        total = 0
        correct = 0
        progress = tqdm(train_loader, desc=f"Ep {epoch + 1}/{epochs}")
        for batch_imgs, batch_labels, _flags in progress:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            outputs = model(batch_imgs)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            predicted = outputs.max(1)[1]
            total += batch_labels.size(0)
            correct += predicted.eq(batch_labels).sum().item()

            progress.set_postfix({
                "Acc": f"{100.0 * correct / total:.1f}%",
                "Loss": f"{loss.item():.3f}",
            })

        # One scheduler step per epoch.
        scheduler.step()


def fit_defense(
    model: BetterCNN,
    test_ds_raw: Dataset,
    device: torch.device,
    batch_size: int = 128,
    num_fit_samples: int = 5000,
) -> FeatureSubspaceDefense:
    """Fit the PCA subspace defense on the leading samples of ``test_ds_raw``.

    Args:
        model: Trained network whose penultimate features are modelled.
        test_ds_raw: Clean dataset yielding ``(img, label)`` pairs.
        device: Device used for feature extraction.
        batch_size: Batch size of the fitting loader.
        num_fit_samples: Number of leading samples used for fitting; clipped
            to the dataset length so smaller datasets do not index out of range.

    Returns:
        The fitted ``FeatureSubspaceDefense``.

    Raises:
        ValueError: if ``num_fit_samples`` is not positive.
    """
    if num_fit_samples <= 0:
        raise ValueError(f"num_fit_samples must be positive, got {num_fit_samples}")

    fit_count = min(num_fit_samples, len(test_ds_raw))
    clean_val_ds = Subset(test_ds_raw, list(range(fit_count)))
    # Fixed order: fitting does not depend on shuffling.
    clean_val_loader = DataLoader(CleanWrapper(clean_val_ds), batch_size=batch_size, shuffle=False)
    defense = FeatureSubspaceDefense(feature_dim=256)
    defense.fit(model, clean_val_loader, device)
    return defense


def evaluate(
    model: BetterCNN,
    defense: FeatureSubspaceDefense,
    test_ds_raw: Dataset,
    device: torch.device,
    target_class: int = 0,
    attack_type: str = "badnets",
    eval_start: int = 5000,
    eval_count: int = 500,
):
    """Measure the attack success rate and the effect of the defense.

    The same slice of ``test_ds_raw`` is evaluated twice: once clean (for
    false positives) and once with the trigger applied to every sample.

    The default slice starts at index 5000 because ``fit_defense`` fits the
    defense on the first 5000 samples of the same dataset; evaluating inside
    that range would score the defense on its own fitting data.

    Args:
        model: Trained (possibly backdoored) network.
        defense: Fitted defense.
        test_ds_raw: Clean dataset yielding ``(img, label)`` pairs.
        device: Device used for inference.
        target_class: Class the backdoor is supposed to force.
        attack_type: Trigger name forwarded to ``PoisonedDataset``.
        eval_start: First dataset index of the evaluation slice.
        eval_count: Size of the evaluation slice; the slice is clipped to the
            dataset length.

    Returns:
        Dict with the keys:
            ``fp_rate``        - share of clean samples that were flagged AND
                                 whose sanitized prediction changed
            ``asr_no_def``     - share of triggered non-target samples
                                 predicted as ``target_class``
            ``detection_rate`` - share of those successful attacks that were
                                 flagged and sanitized away from ``target_class``
            ``asr_with_def``   - ``asr_no_def * (1 - detection_rate)``
            ``raw_stats``      - the underlying counters
    """
    model.eval()

    stop = min(eval_start + eval_count, len(test_ds_raw))
    eval_indices = list(range(eval_start, stop))

    clean_test_loader = DataLoader(
        CleanWrapper(Subset(test_ds_raw, eval_indices)), batch_size=1, shuffle=False
    )
    attack_test_ds = PoisonedDataset(
        Subset(test_ds_raw, eval_indices),
        target_class=target_class,
        poison_ratio=1.0,
        mode="attack_test",
        attack_type=attack_type,
    )
    attack_loader = DataLoader(attack_test_ds, batch_size=1, shuffle=False)

    stats = {
        "clean_total": 0,
        "clean_fp": 0,
        "attack_total_candidates": 0,
        "attack_success_no_def": 0,
        "attack_blocked": 0,
    }

    # --- Clean pass: how often does the defense change a clean prediction? ---
    for img, _label, _flag in clean_test_loader:
        img = img.to(device)
        with torch.no_grad():
            pred = model(img).argmax(dim=1)
        stats["clean_total"] += 1

        is_poison, _, feat = defense.detect(model, img[0], pred, device)
        if not is_poison:
            continue
        sanitized_pred = defense.sanitize_prediction(model, feat, pred, device)
        # A flag only counts as a false positive when it alters the prediction.
        if sanitized_pred.item() != pred.item():
            stats["clean_fp"] += 1

    # --- Attack pass: triggered samples whose true class is not the target. ---
    for img, label, _flag in attack_loader:
        if label.item() == target_class:
            # Predicting the target for a genuine target sample is no attack.
            continue
        img = img.to(device)

        stats["attack_total_candidates"] += 1
        with torch.no_grad():
            pred = model(img).argmax(dim=1)
        if pred.item() != target_class:
            continue

        stats["attack_success_no_def"] += 1
        is_poison, _, feat = defense.detect(model, img[0], pred, device)
        if not is_poison:
            continue
        sanitized_pred = defense.sanitize_prediction(model, feat, pred, device)
        if sanitized_pred.item() != target_class:
            stats["attack_blocked"] += 1

    fp_rate = (stats["clean_fp"] / stats["clean_total"]) if stats["clean_total"] > 0 else 0.0

    asr_no_def = 0.0
    detection_rate = 0.0
    asr_with_def = 0.0
    if stats["attack_total_candidates"] > 0:
        asr_no_def = stats["attack_success_no_def"] / stats["attack_total_candidates"]
        if stats["attack_success_no_def"] > 0:
            detection_rate = stats["attack_blocked"] / stats["attack_success_no_def"]
        asr_with_def = asr_no_def * (1 - detection_rate)

    return {
        "fp_rate": fp_rate,
        "asr_no_def": asr_no_def,
        "detection_rate": detection_rate,
        "asr_with_def": asr_with_def,
        "raw_stats": stats,
    }


# ==========================================
# 6. 多配置实验 + 画图
# ==========================================
def run_experiment(config: dict, train_ds_raw: Dataset, test_ds_raw: Dataset, device: torch.device):
    """Run one poison -> train -> fit defense -> evaluate cycle and report it.

    Args:
        config: Requires ``name``, ``attack_type`` and ``poison_ratio``;
            ``target_class`` (0), ``epochs`` (20), ``batch_size`` (128) and
            ``lr`` (0.001) are optional with the defaults shown.
        train_ds_raw: Clean training set to be poisoned.
        test_ds_raw: Clean test set used for the defense and the evaluation.
        device: Device used for training and inference.

    Returns:
        The metrics dict produced by ``evaluate``.
    """
    # Required settings first (missing keys raise KeyError), then optional ones.
    name, attack_type, poison_ratio = (
        config[key] for key in ("name", "attack_type", "poison_ratio")
    )
    target_class = config.get("target_class", 0)
    epochs = config.get("epochs", 20)
    batch_size = config.get("batch_size", 128)
    lr = config.get("lr", 0.001)

    banner = "#" * 70
    print("\n" + banner)
    print(f"开始实验: {name}")
    print(f"  attack_type = {attack_type}, poison_ratio = {poison_ratio}, target_class = {target_class}")
    print(banner)

    # Train a fresh model on the poisoned training set.
    train_loader = DataLoader(
        PoisonedDataset(
            train_ds_raw,
            target_class=target_class,
            poison_ratio=poison_ratio,
            mode="train",
            attack_type=attack_type,
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    model = BetterCNN(num_classes=10).to(device)
    train_model(model, train_loader, device, epochs=epochs, lr=lr)

    # Fit the defense on clean test data, then score attack and defense.
    defense = fit_defense(model, test_ds_raw, device, batch_size=batch_size)
    metrics = evaluate(
        model,
        defense,
        test_ds_raw,
        device,
        target_class=target_class,
        attack_type=attack_type,
    )
    stats = metrics["raw_stats"]

    print("\n  实验结果: ", name)
    report_lines = (
        f"    [正常样本] FP: {metrics['fp_rate'] * 100:.2f}% "
        f"({stats['clean_fp']}/{stats['clean_total']})",
        f"    [攻击样本] 候选数(非目标类): {stats['attack_total_candidates']}",
        f"    [攻击样本] 无防御 ASR: {metrics['asr_no_def'] * 100:.2f}% "
        f"({stats['attack_success_no_def']}/{stats['attack_total_candidates']})",
        f"    [攻击样本] 拦截率: {metrics['detection_rate'] * 100:.2f}% "
        f"({stats['attack_blocked']}/{stats['attack_success_no_def'] if stats['attack_success_no_def']>0 else 1})",
        f"    [攻击样本] 带防御 ASR: {metrics['asr_with_def'] * 100:.2f}%",
    )
    for line in report_lines:
        print(line)

    return metrics


def plot_results(results, save_path: str = "defense_results.png") -> None:
    """Draw a grouped bar chart of the experiment metrics and save it.

    Args:
        results: One dict per experiment with the keys ``name``, ``fp_rate``,
            ``asr_no_def`` and ``asr_with_def`` (the rates as fractions).
        save_path: File the figure is written to.
    """
    names = [entry["name"] for entry in results]
    # (legend label, heights in percent), in left-to-right bar order.
    percent_series = [
        ("FP Rate (Clean)", [entry["fp_rate"] * 100 for entry in results]),
        ("ASR No Defense", [entry["asr_no_def"] * 100 for entry in results]),
        ("ASR With Defense", [entry["asr_with_def"] * 100 for entry in results]),
    ]

    x = np.arange(len(names))
    width = 0.25

    plt.figure(figsize=(12, 6))
    for slot, (legend_label, heights) in enumerate(percent_series):
        # Slots 0, 1, 2 map to offsets -width, 0, +width around each tick.
        plt.bar(x + (slot - 1) * width, heights, width=width, label=legend_label)

    plt.xticks(x, names, rotation=30)
    plt.ylabel("Percentage (%)")
    plt.title("Backdoor Attack & PCA Defense Results (Multi-Config)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    print(f"\n结果图已保存为: {os.path.abspath(save_path)}")


# ==========================================
# 7. 主程序：多配置自动跑
# ==========================================
def _select_device() -> torch.device:
    """Pick CUDA if available, then Apple MPS, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main() -> None:
    """Run every configured backdoor experiment on CIFAR-10 and plot the results.

    Downloads CIFAR-10 into ``./data`` if needed, runs one experiment per
    attack type at a 10% poison ratio, and saves a summary chart to
    ``defense_results.png``.
    """
    device = _select_device()
    print(f"当前使用的计算设备: {device}")

    # Per-channel normalisation statistics applied to both splits.
    norm_mean = (0.4914, 0.4822, 0.4465)
    norm_std = (0.2023, 0.1994, 0.2010)

    # Training images are augmented (random crop + flip); test images are not.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]
    )

    print("下载/加载 CIFAR10 数据中...")
    train_ds_raw = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    test_ds_raw = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

    # One experiment per attack; everything else is shared across experiments.
    attack_types = ("badnets", "blended", "wanet", "physical", "batt", "lc", "adaptive")
    experiments = [
        {
            "name": f"{attack}_p10",
            "attack_type": attack,
            "poison_ratio": 0.10,
            "target_class": 0,
            "epochs": 20,
        }
        for attack in attack_types
    ]

    results = []
    for cfg in experiments:
        metrics = run_experiment(cfg, train_ds_raw, test_ds_raw, device)
        results.append({"name": cfg["name"], **metrics})

    plot_results(results, save_path="defense_results.png")


if __name__ == "__main__":
    main()
