# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
import numpy as np
from sklearn.decomposition import PCA
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import os


# =============================
# 0. 固定随机种子，保证复现性
# =============================
def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: value used for every generator.
        deterministic: if True, additionally force cuDNN to use deterministic
            algorithms and disable its autotuner. Bit-exact GPU reruns usually
            need this, at some speed cost, so it is opt-in (default False).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible GPU, not just the current device.
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Seed the Python, NumPy and PyTorch RNGs once at import time (seed 42).
set_seed(42)


# ==========================================
# 1. 改进后的 CNN 模型 (BetterCNN)
# ==========================================
class BetterCNN(nn.Module):
    """Small VGG-style CNN for 3x32x32 images with a 256-d penultimate layer.

    Three conv stages each end in a 2x2 max-pool (32 -> 16 -> 8 -> 4 pixels),
    so the head sees a 128 x 4 x 4 feature map; the classifier head is
    Flatten -> Linear(2048, 256) -> BatchNorm1d -> ReLU -> Dropout -> Linear.
    """

    @staticmethod
    def _conv_unit(in_ch, out_ch):
        """Return the layers of one 3x3 'same' conv + BatchNorm + ReLU unit."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    def __init__(self, num_classes=10):
        super().__init__()
        unit = self._conv_unit

        # Layers are created in the same order as the Sequential indices, so
        # parameter names (block1.0.weight, ...) and init order stay fixed.
        self.block1 = nn.Sequential(*unit(3, 32), *unit(32, 32), nn.MaxPool2d(2, 2))
        self.block2 = nn.Sequential(*unit(32, 64), *unit(64, 64), nn.MaxPool2d(2, 2))
        self.block3 = nn.Sequential(*unit(64, 128), nn.MaxPool2d(2, 2))

        self.classifier = nn.Sequential(
            nn.Flatten(),                  # [0]
            nn.Linear(128 * 4 * 4, 256),   # [1]
            nn.BatchNorm1d(256),           # [2]
            nn.ReLU(),                     # [3]
            nn.Dropout(0.5),               # [4]
            nn.Linear(256, num_classes),   # [5]
        )

    def _backbone(self, x):
        """Run the three convolutional stages."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self._backbone(x))

    def get_penultimate_features(self, x):
        """Return the 256-d post-ReLU features that feed Dropout and the last Linear."""
        # classifier[:4] = Flatten, Linear, BatchNorm1d, ReLU (stops before Dropout).
        return self.classifier[:4](self._backbone(x))


# ==========================================
# 2. 数据投毒工具：支持多种触发器
# ==========================================
class PoisonedDataset(Dataset):
    """Wrap a dataset and stamp a backdoor trigger onto some of its samples.

    Each item is returned as ``(img, label, flag)``, where ``flag`` is 1 for a
    triggered sample and 0 otherwise.

    attack_type:
        - 'badnets': small bright 3x3 patch near the bottom-right corner (classic BadNets)
        - 'blend'  : blend the whole image towards a constant bright pattern (subtler)
        - 'smooth' : add a smooth sinusoidal texture (a more natural-looking trigger)
        Any other value leaves the images unchanged.

    mode:
        - 'train': pick ``poison_ratio`` of the samples with the global NumPy RNG,
          stamp the trigger and relabel them as ``target_class``.
        - 'attack_test': stamp the trigger on every sample but keep the true
          label, for measuring the attack success rate (ASR).
        Any other mode poisons nothing.

    Raises:
        ValueError: in 'train' mode, when ``poison_ratio`` is outside [0, 1].
    """

    def __init__(self, dataset, target_class=0, poison_ratio=0.1,
                 mode='train', attack_type='badnets'):
        self.dataset = dataset
        self.target_class = target_class
        self.poison_indices = set()
        self.mode = mode
        self.attack_type = attack_type

        if mode == 'train':
            if not 0.0 <= poison_ratio <= 1.0:
                raise ValueError(
                    f"poison_ratio must be in [0, 1], got {poison_ratio}")
            num_samples = len(dataset)
            # The tiny epsilon keeps floor semantics but absorbs float noise:
            # 100 * 0.29 == 28.999999999999996, which int() alone turns into 28.
            num_poison = int(num_samples * poison_ratio + 1e-9)
            if num_poison > 0:
                chosen = np.random.choice(num_samples, num_poison, replace=False)
                # Plain ints, so membership checks don't rely on NumPy scalar hashing.
                self.poison_indices = {int(i) for i in chosen}
        elif mode == 'attack_test':
            # Attack evaluation: every sample gets the trigger.
            self.poison_indices = set(range(len(dataset)))

    def __len__(self):
        return len(self.dataset)

    def apply_trigger(self, img):
        """Return ``img`` with this dataset's trigger applied.

        Assumes ``img`` is an already-normalized ``(C, H, W)`` float tensor
        (e.g. (3, 32, 32) for CIFAR-10) -- TODO confirm for other datasets.
        The 'badnets' branch writes into ``img`` in place; the other branches
        build a new tensor.
        """
        if self.attack_type == 'badnets':
            # 3x3 patch at rows/cols -4..-2, i.e. one pixel in from the border.
            # 2.0 is a bright value in normalized units (with the CIFAR-10
            # mean/std used in main(), pure white is roughly 2.5-2.75).
            img[:, -4:-1, -4:-1] = 2.0

        elif self.attack_type == 'blend':
            alpha = 0.2
            trigger = torch.full_like(img, 2.0)  # constant bright pattern
            img = (1 - alpha) * img + alpha * trigger

        elif self.attack_type == 'smooth':
            C, H, W = img.shape
            xs = torch.linspace(0, 2 * np.pi, W).view(1, 1, W).expand(1, H, W)
            ys = torch.linspace(0, 2 * np.pi, H).view(1, H, 1).expand(1, H, W)
            pattern = (torch.sin(xs) + torch.cos(ys)) / 2.0  # values in [-1, 1]
            pattern = pattern.expand(C, H, W)  # same texture on every channel
            beta = 0.5  # trigger strength
            img = img + beta * pattern

        # Unknown attack types deliberately leave the image untouched.
        return img

    def __getitem__(self, index):
        img, label = self.dataset[index]

        if index not in self.poison_indices:
            return img, label, 0

        # Clone first: 'badnets' writes in place and must not alter the tensor
        # returned by the wrapped dataset.
        img = self.apply_trigger(img.clone())

        if self.mode == 'train':
            # Training: relabel as the attacker's target class.
            return img, self.target_class, 1
        # 'attack_test': keep the true label so ASR can be measured.
        return img, label, 1


# ==========================================
# 3. PCA 子空间防御
# ==========================================
class FeatureSubspaceDefense:
    """Per-class PCA reconstruction-error detector on penultimate-layer features.

    ``fit`` learns, for every class with enough clean samples, a PCA subspace
    of the model's penultimate features plus a threshold at the
    ``percentile``-th percentile of the clean reconstruction errors.
    ``detect`` flags an input whose features reconstruct worse than the
    threshold of its predicted class.
    """

    def __init__(self, feature_dim=256, variance_threshold=0.95,
                 percentile=99.5, min_samples=20):
        self.pca_models = {}
        self.thresholds = {}
        # Kept for API compatibility; PCA infers the feature size from the data.
        self.feature_dim = feature_dim
        self.var_thresh = variance_threshold
        self.percentile = percentile
        self.min_samples = min_samples

    @staticmethod
    def _reconstruction_error(pca, data):
        """Return the squared L2 PCA reconstruction error of each row of ``data``."""
        recon = pca.inverse_transform(pca.transform(data))
        return np.sum((data - recon) ** 2, axis=1)

    def fit(self, model, clean_loader, device):
        """Fit one PCA model and one detection threshold per class.

        Args:
            model: network exposing ``get_penultimate_features(imgs)``.
            clean_loader: iterable of ``(imgs, labels, flags)`` batches of clean
                data; ``labels`` is a tensor, ``flags`` is ignored.
            device: device the images are moved to before the forward pass.

        Classes with fewer than ``min_samples`` samples, or whose PCA fit
        raises, get no model, so ``detect`` never flags them.
        """
        print(" 正在拟合特征子空间 (PCA)...")
        model.eval()
        # Buckets are created on demand, so any label values work (not just 0-9).
        features_by_class = {}

        with torch.no_grad():
            for imgs, labels, _ in clean_loader:
                feats = model.get_penultimate_features(imgs.to(device)).cpu().numpy()
                for feat, label in zip(feats, labels.tolist()):
                    features_by_class.setdefault(label, []).append(feat)

        for cls in sorted(features_by_class):
            feats = features_by_class[cls]
            if len(feats) < self.min_samples:
                continue

            data = np.asarray(feats)
            n_samples = data.shape[0]

            # Normally keep enough components to explain `var_thresh` of the
            # variance; exactly at the minimum sample count use a small fixed size.
            if n_samples > self.min_samples:
                n_comp = self.var_thresh
            else:
                n_comp = min(n_samples - 1, 10)

            pca = PCA(n_components=n_comp)
            try:
                pca.fit(data)
            except Exception as e:  # best-effort: log and skip this class
                print(f"  [CLS {cls}] PCA 拟合失败: {e}")
                continue

            self.pca_models[cls] = pca
            recon_error = self._reconstruction_error(pca, data)
            thr = np.percentile(recon_error, self.percentile)
            self.thresholds[cls] = thr
            print(f"  [CLS {cls}] 样本数={n_samples}, 阈值={thr:.4f}")

        print(" 子空间拟合完成。")

    def detect(self, model, img_tensor, predicted_label):
        """Return ``(is_poison, recon_error)`` for one image.

        Args:
            model: network exposing ``get_penultimate_features``.
            img_tensor: a single un-batched image; a batch dim is added here.
            predicted_label: one-element tensor holding the predicted class.

        Returns:
            ``(False, 0.0)`` when no subspace was fitted for the predicted class,
            otherwise ``(error > class threshold, error)`` as Python bool/float.
        """
        model.eval()
        with torch.no_grad():
            feat = model.get_penultimate_features(img_tensor.unsqueeze(0)).cpu().numpy()

        cls = predicted_label.item()
        if cls not in self.pca_models:
            return False, 0.0

        errors = self._reconstruction_error(self.pca_models[cls], feat.reshape(1, -1))
        recon_error = float(errors[0])
        return bool(recon_error > self.thresholds[cls]), recon_error


# ==========================================
# 4. 包装辅助类：干净数据 + flag=0
# ==========================================
class CleanWrapper(Dataset):
    """Expose an ``(img, label)`` dataset in the ``(img, label, flag)`` format.

    The flag is always 0 (not poisoned), matching the clean items produced by
    ``PoisonedDataset`` so both kinds can share the same loader code.
    """

    _CLEAN_FLAG = 0

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        image, target = self.ds[i]
        return image, target, self._CLEAN_FLAG


# ==========================================
# 5. 训练 / 评估等通用函数
# ==========================================
def train_model(model, train_loader, device, epochs=20, lr=0.001,
                step_size=8, gamma=0.1):
    """Train ``model`` in place with Adam, cross-entropy and a StepLR schedule.

    Args:
        model: network to train; it is not moved here, so it is expected to
            already live on ``device``.
        train_loader: iterable of ``(imgs, labels, flags)`` batches; flags are ignored.
        device: device each batch is moved to.
        epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        step_size: epochs between learning-rate decays (StepLR).
        gamma: multiplicative learning-rate decay factor (StepLR).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0
        loss_sum = 0.0
        pbar = tqdm(train_loader, desc=f"Ep {epoch + 1}/{epochs}")
        for imgs, labels, _ in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()
            # Sample-weighted running mean, so 'Loss' is an epoch statistic like
            # 'Acc' rather than the noisy loss of the latest batch only.
            loss_sum += loss.item() * batch_size

            pbar.set_postfix({
                'Acc': f"{100. * correct / total:.1f}%",
                'Loss': f"{loss_sum / total:.3f}"
            })

        scheduler.step()


def fit_defense(model, test_ds_raw, device, batch_size=128, num_fit=5000):
    """Fit a FeatureSubspaceDefense on clean test images held out from evaluation.

    The per-class PCA subspaces and thresholds are fitted on the LAST
    ``num_fit`` images of ``test_ds_raw``. ``evaluate`` measures false
    positives and ASR on test indices 500-1000, and each threshold is a
    percentile of the fitting samples' own errors. Fitting on the first 5000
    images would therefore include the evaluation images and bias the
    false-positive rate downwards. The two ranges are disjoint whenever
    ``len(test_ds_raw) >= num_fit + 1000`` (10000 for CIFAR-10).

    Args:
        model: trained network exposing ``get_penultimate_features``.
        test_ds_raw: clean ``(img, label)`` dataset.
        device: device used for feature extraction.
        batch_size: loader batch size for feature extraction.
        num_fit: number of trailing images used for fitting.

    Returns:
        The fitted FeatureSubspaceDefense.
    """
    n = len(test_ds_raw)
    start = max(0, n - num_fit)
    clean_val_ds = Subset(test_ds_raw, list(range(start, n)))
    clean_val_loader = DataLoader(
        CleanWrapper(clean_val_ds),
        batch_size=batch_size,
        shuffle=False
    )
    defense = FeatureSubspaceDefense(feature_dim=256)
    defense.fit(model, clean_val_loader, device)
    return defense


def evaluate(model, defense, test_ds_raw, device,
             target_class=0, attack_type='badnets'):
    """Measure the defense's clean false-positive rate and its effect on ASR.

    Test images 500-1000 are used twice: untouched, to count false positives
    of ``defense.detect``; and with the trigger stamped on (true labels kept),
    to measure attack success with and without the defense. Triggered images
    whose true label already equals ``target_class`` are skipped, and the
    defense is only consulted on attacks that succeed without it.

    Returns:
        dict with 'fp_rate', 'asr_no_def', 'detection_rate', 'asr_with_def'
        (fractions in [0, 1]) and 'raw_stats' holding the raw counters.
    """
    model.eval()
    eval_indices = list(range(500, 1000))

    clean_loader = DataLoader(
        CleanWrapper(Subset(test_ds_raw, eval_indices)),
        batch_size=1,
        shuffle=False,
    )
    triggered_ds = PoisonedDataset(
        Subset(test_ds_raw, eval_indices),
        target_class=target_class,
        poison_ratio=1.0,  # ignored in 'attack_test' mode
        mode='attack_test',
        attack_type=attack_type,
    )
    attack_loader = DataLoader(triggered_ds, batch_size=1, shuffle=False)

    def predict(batch):
        # Top-1 class of a one-image batch, without tracking gradients.
        with torch.no_grad():
            return model(batch).argmax(dim=1)

    # --- False positives on clean images ---
    clean_total = 0
    clean_fp = 0
    for img, _, _ in clean_loader:
        img = img.to(device)
        pred = predict(img)
        clean_total += 1
        flagged, _ = defense.detect(model, img[0], pred)
        if flagged:
            clean_fp += 1

    # --- Attack success / interception on triggered images ---
    candidates = 0
    successes = 0
    blocked = 0
    for img, label, _ in attack_loader:
        if label.item() == target_class:
            continue  # already the target class: not an attack candidate
        candidates += 1
        img = img.to(device)
        pred = predict(img)
        if pred.item() != target_class:
            continue  # the attack failed even without the defense
        successes += 1
        flagged, _ = defense.detect(model, img[0], pred)
        if flagged:
            blocked += 1

    # --- Aggregate metrics ---
    fp_rate = clean_fp / clean_total if clean_total > 0 else 0.0
    if candidates > 0:
        asr_no_def = successes / candidates
        detection_rate = blocked / successes if successes > 0 else 0.0
        asr_with_def = asr_no_def * (1 - detection_rate)
    else:
        asr_no_def = 0.0
        detection_rate = 0.0
        asr_with_def = 0.0

    stats = {
        'clean_total': clean_total,
        'clean_fp': clean_fp,
        'attack_total_candidates': candidates,  # samples with label != target_class
        'attack_success_no_def': successes,     # successful attacks without defense
        'attack_blocked': blocked,              # successful attacks the defense flagged
    }
    return {
        'fp_rate': fp_rate,
        'asr_no_def': asr_no_def,
        'detection_rate': detection_rate,
        'asr_with_def': asr_with_def,
        'raw_stats': stats,
    }


# ==========================================
# 6. 多配置实验 + 画图
# ==========================================
def run_experiment(config, train_ds_raw, test_ds_raw, device):
    """Train a backdoored model for one config, fit the PCA defense, evaluate it.

    Required config keys: 'name', 'attack_type', 'poison_ratio'.
    Optional keys (defaults): 'target_class' (0), 'epochs' (20),
    'batch_size' (128), 'lr' (0.001), 'seed' (42).

    All RNGs are re-seeded from 'seed' first, so the poison selection, weight
    init and shuffling of each experiment do not depend on which experiments
    ran before it.

    Returns:
        The metrics dict produced by ``evaluate``.
    """
    name = config['name']
    attack_type = config['attack_type']
    poison_ratio = config['poison_ratio']
    target_class = config.get('target_class', 0)
    epochs = config.get('epochs', 20)
    batch_size = config.get('batch_size', 128)
    lr = config.get('lr', 0.001)
    seed = config.get('seed', 42)

    # Without this, RNG state carries over from earlier experiments in the loop.
    set_seed(seed)

    print("\n" + "#" * 70)
    print(f"开始实验: {name}")
    print(f"  attack_type = {attack_type}, poison_ratio = {poison_ratio}, "
          f"target_class = {target_class}, seed = {seed}")
    print("#" * 70)

    # 1. Build the poisoned training set.
    poisoned_train_ds = PoisonedDataset(
        train_ds_raw,
        target_class=target_class,
        poison_ratio=poison_ratio,
        mode='train',
        attack_type=attack_type
    )
    train_loader = DataLoader(
        poisoned_train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )

    # 2. Fresh model, trained on the poisoned data.
    model = BetterCNN(num_classes=10).to(device)
    train_model(model, train_loader, device, epochs=epochs, lr=lr)

    # 3. Fit the feature-subspace defense.
    defense = fit_defense(model, test_ds_raw, device, batch_size=batch_size)

    # 4. Evaluate attack and defense.
    metrics = evaluate(
        model, defense, test_ds_raw, device,
        target_class=target_class, attack_type=attack_type
    )

    # Short report.
    stats = metrics['raw_stats']
    print("\n  实验结果: ", name)
    print(f"    [正常样本] FP: {metrics['fp_rate'] * 100:.2f}% "
          f"({stats['clean_fp']}/{stats['clean_total']})")
    print(f"    [攻击样本] 候选数(非目标类): {stats['attack_total_candidates']}")
    print(f"    [攻击样本] 无防御 ASR: {metrics['asr_no_def'] * 100:.2f}% "
          f"({stats['attack_success_no_def']}/{stats['attack_total_candidates']})")
    # Show the real denominator (0 when no attack succeeded); a placeholder 1
    # made "0/1" look like a measured rate.
    print(f"    [攻击样本] 拦截率: {metrics['detection_rate'] * 100:.2f}% "
          f"({stats['attack_blocked']}/{stats['attack_success_no_def']})")
    print(f"    [攻击样本] 带防御 ASR: {metrics['asr_with_def'] * 100:.2f}%")

    return metrics


def plot_results(results, save_path="defense_results.png"):
    """Save a grouped bar chart of FP rate and ASR with/without the defense.

    Args:
        results: list of dicts, each with 'name', 'fp_rate', 'asr_no_def' and
            'asr_with_def' (fractions, plotted as percentages).
        save_path: output image path (written at 200 dpi).
    """
    series = (
        ('fp_rate', 'FP Rate (Clean)'),
        ('asr_no_def', 'ASR No Defense'),
        ('asr_with_def', 'ASR With Defense'),
    )
    labels = [entry['name'] for entry in results]
    heights_per_series = [
        [entry[key] * 100 for entry in results] for key, _ in series
    ]

    bar_w = 0.25
    positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(10, 6))
    # Series sit left of, on, and right of each tick.
    for offset, (_, legend_label), heights in zip((-1, 0, 1), series, heights_per_series):
        ax.bar(positions + offset * bar_w, heights, width=bar_w, label=legend_label)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=30)
    ax.set_ylabel('Percentage (%)')
    ax.set_title('Backdoor Attack & PCA Defense Results (Multi-Config)')
    ax.legend()
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    print(f"\n结果图已保存为: {os.path.abspath(save_path)}")


# ==========================================
# 7. 主程序：多配置自动跑
# ==========================================
def main():
    """Run every configured backdoor experiment on CIFAR-10 and plot a summary.

    Picks CUDA, then Apple MPS, then CPU. Downloads CIFAR-10 into ./data if
    needed. Trains a new model for each entry of EXPERIMENTS and saves the
    comparison chart to defense_results.png.
    """
    # Device preference: CUDA, then MPS, then CPU. `torch.backends.mps` is
    # looked up defensively because it may be absent on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"当前使用的计算设备: {device}")

    # CIFAR-10 per-channel normalization statistics.
    norm_mean = (0.4914, 0.4822, 0.4465)
    norm_std = (0.2023, 0.1994, 0.2010)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    print("下载/加载 CIFAR10 数据中...")
    train_ds_raw = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_ds_raw = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # ===== Experiment configurations =====
    # Add or remove entries as needed; every entry trains a new model, so
    # each one adds a full training run.
    EXPERIMENTS = [
        {
            'name': 'badnets_p10',
            'attack_type': 'badnets',
            'poison_ratio': 0.10,
            'target_class': 0,
            'epochs': 20,
        },
        {
            'name': 'blend_p10',
            'attack_type': 'blend',
            'poison_ratio': 0.10,
            'target_class': 0,
            'epochs': 20,
        },
        {
            'name': 'smooth_p10',
            'attack_type': 'smooth',
            'poison_ratio': 0.10,
            'target_class': 0,
            'epochs': 20,
        },
        # Other poison ratios can be added the same way, e.g. 5% or 20%:
        # {
        #     'name': 'badnets_p05',
        #     'attack_type': 'badnets',
        #     'poison_ratio': 0.05,
        #     'target_class': 0,
        #     'epochs': 20,
        # },
        # {
        #     'name': 'badnets_p20',
        #     'attack_type': 'badnets',
        #     'poison_ratio': 0.20,
        #     'target_class': 0,
        #     'epochs': 20,
        # },
    ]

    results = []
    for cfg in EXPERIMENTS:
        metrics = run_experiment(cfg, train_ds_raw, test_ds_raw, device)
        results.append({'name': cfg['name'], **metrics})

    # Overall comparison chart.
    plot_results(results, save_path="defense_results.png")


# Run the full multi-config experiment when this file is executed directly.
if __name__ == '__main__':
    main()