In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
import os
import pickle
import time
import numpy as np
import argparse
import json
import gzip
from pathlib import Path
import psutil
import sys
import time

# ----------------- Spark 相关导入 -----------------
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number, rand, monotonically_increasing_id, input_file_name, regexp_extract
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType, LongType, BooleanType

# ----------------- 自定义工具导入 -----------------
# 请确保 utils、ml.utils 模块在 Python 环境中可用
from utils import load_config
from ml.utils import load_application_classification_cnn_model, load_traffic_classification_cnn_model, normalise_cm

# ----------------- PyTorch 相关导入 -----------------
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# ----------------- ART、聚类、PCA 和可视化相关导入 -----------------
from art.estimators.classification import PyTorchClassifier
from art.defences.detector.poison import ClusteringAnalyzer
from art.defences.detector.poison.poison_filtering_defence import PoisonFilteringDefence
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
import matplotlib.cm as cm

# Compute device for this notebook. The default pins everything to CPU (the previous
# behaviour); set USE_GPU = True to use CUDA when it is available.
USE_GPU = False
device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")


In [2]:
# Number of classes and their string names ("0" .. "14").
class_num = 15
labels = list(map(str, range(class_num)))

In [None]:
from torch.utils.data import DataLoader
from ml.dataset import dataset_collate_function
import datasets
import multiprocessing
def train_dataloader(parquet_path, batch_size=128, num_workers=None, shuffle=True):
    """Build a DataLoader over the first split returned by ``datasets.load_dataset``.

    Args:
        parquet_path: Path passed to ``datasets.load_dataset`` (the training parquet data).
        batch_size: Samples per batch (default 128, the previously hard-coded value).
        num_workers: DataLoader worker processes. ``None`` means one per CPU core,
            falling back to 1 if the core count cannot be determined.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A ``torch.utils.data.DataLoader`` that collates batches with
        ``dataset_collate_function``.

    Raises:
        ValueError: if ``load_dataset`` returns no splits.
    """
    dataset_dict = datasets.load_dataset(parquet_path, keep_in_memory=False)
    if not dataset_dict:
        raise ValueError(f"No dataset splits found at {parquet_path!r}")
    # Use the first split; for a single parquet path this is presumably "train" — TODO confirm.
    dataset = next(iter(dataset_dict.values()))

    if num_workers is None:
        try:
            num_workers = multiprocessing.cpu_count()
        except NotImplementedError:
            # cpu_count() raises NotImplementedError when the core count is unknown.
            num_workers = 1

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=dataset_collate_function,
        shuffle=shuffle,
    )
    print("num_workers:", num_workers)
    return dataloader


parquet_path = "train.parquet"
dataloader = train_dataloader(parquet_path)

In [None]:
# Load the trained application-classification CNN and switch it to inference mode.
model_path = "application_classification.cnn.model"
from ml.utils import load_application_classification_cnn_model
# Only request the GPU when CUDA is actually present. An unconditional gpu=True is
# presumably unusable on the CPU-only setup chosen at the top of the notebook.
model = load_application_classification_cnn_model(model_path, gpu=torch.cuda.is_available())
model.eval()  # evaluation mode (disables dropout, batch-norm uses running statistics)

In [5]:
# The loss and optimizer exist only because ART's PyTorchClassifier constructor requires
# them; none of the cells shown use them for training.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [6]:
# Wrap the PyTorch model in ART's classifier interface so ART defences can query it.
class_num = 15
input_dimension = 1500  # length of one input vector; set to the model's real input size

classifier = PyTorchClassifier(
    model=model,
    nb_classes=class_num,
    input_shape=(input_dimension,),  # e.g. a signal of length 1500
    loss=criterion,
    optimizer=optimizer,
)

In [7]:
# Collect every sample from `dataloader` into flat numpy arrays.
# Tensors are converted to numpy directly. The previous code moved each batch to
# `model.device` and then immediately copied it back with `.cpu()`, which was a wasted
# host -> device -> host transfer per batch.
feature_chunks = []
label_chunks = []
poisoned_chunks = []

for batch in dataloader:
    # `.cpu()` is a no-op for tensors that are already on the CPU.
    feature_chunks.append(batch["feature"].float().cpu().numpy())
    label_chunks.append(batch["label"].long().cpu().numpy())
    poisoned_chunks.append(batch["is_poisoned"].bool().cpu().numpy())

if not feature_chunks:
    raise ValueError("dataloader produced no batches; check parquet_path")

# Concatenate once at the end rather than growing arrays per batch.
all_features = np.concatenate(feature_chunks, axis=0)
all_labels = np.concatenate(label_chunks, axis=0)
all_is_poisoned = np.concatenate(poisoned_chunks, axis=0)

In [8]:
# Release cached CUDA memory blocks; there is nothing to release without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from art.defences.detector.poison import ActivationDefence

# Extract activations in chunks of BATCH_SIZE samples to bound peak memory.
BATCH_SIZE = 1000  # tune to available (GPU) memory

all_activations = []
all_labels_batch = []
all_is_poisoned_batch = []
failed_batches = []  # 1-based numbers of batches whose activation extraction raised

n_samples = len(all_features)
# Ceiling division. The old `n // BATCH_SIZE + 1` over-counted by one whenever n was
# an exact multiple of BATCH_SIZE.
n_batches = -(-n_samples // BATCH_SIZE)

for batch_no, start in enumerate(range(0, n_samples, BATCH_SIZE), start=1):
    stop = start + BATCH_SIZE
    batch_features = all_features[start:stop]
    batch_labels = all_labels[start:stop]
    batch_is_poisoned = all_is_poisoned[start:stop]

    # A throwaway per-batch defence, used only to call ART's activation extraction on this chunk.
    defence_batch = ActivationDefence(
        classifier=classifier,
        x_train=batch_features,
        y_train=batch_labels,
        generator=None,
        ex_re_threshold=None,
    )
    try:
        activations = defence_batch._get_activations()
    except Exception as e:
        print(f"处理批次 {batch_no} 时出错: {str(e)}")
        failed_batches.append(batch_no)
        continue
    finally:
        # Drop the per-batch defence and release cached GPU memory, including on failure.
        del defence_batch
        torch.cuda.empty_cache()

    if not isinstance(activations, np.ndarray):
        activations = activations.cpu().numpy()
    all_activations.append(activations)
    all_labels_batch.append(batch_labels)
    all_is_poisoned_batch.append(batch_is_poisoned)

    print(f"处理批次 {batch_no}/{n_batches}")

# Later cells pair all_activations with all_labels / all_is_poisoned sample-for-sample.
# Skipping a batch would silently misalign them, so fail loudly instead.
if failed_batches:
    raise RuntimeError(
        f"Activation extraction failed for batch(es) {failed_batches} of {n_batches}; "
        "downstream cells require activations for every sample."
    )

# Merge the per-batch results.
all_activations = np.concatenate(all_activations, axis=0)
all_labels_batch = np.concatenate(all_labels_batch, axis=0)
all_is_poisoned_batch = np.concatenate(all_is_poisoned_batch, axis=0)

torch.cuda.empty_cache()

In [10]:
# 1. Build the defence over the full training set and configure its dimensionality
#    reduction and clustering (the reduction used here is PCA, not FastICA).
REDUCE_METHOD = "PCA"  # other values must be ones ART's ActivationDefence accepts
CLUSTERING_METHOD = "KMeans"

final_defence = ActivationDefence(
    classifier=classifier,
    x_train=all_features,
    y_train=all_labels,
    generator=None,
    ex_re_threshold=None,
)
# set_params runs ART's parameter validation, unlike plain attribute assignment.
final_defence.set_params(reduce=REDUCE_METHOD, clustering_method=CLUSTERING_METHOD)

In [11]:
# 2. Segment the precomputed activations by class and cluster them.
# The activations must line up 1:1 with all_labels, otherwise samples land in the wrong class.
if len(all_activations) != len(all_labels):
    raise ValueError(
        f"all_activations has {len(all_activations)} rows but all_labels has "
        f"{len(all_labels)}; re-run the activation extraction cell"
    )
final_defence.activations_by_class = final_defence._segment_by_class(all_activations, all_labels)
(
    final_defence.clusters_by_class,
    final_defence.red_activations_by_class,
) = final_defence.cluster_activations()

In [None]:
# ----------------- Poison detection via activation clustering -----------------

# Run detection on the defence configured above (no arguments: it uses its own state).
report, is_clean_pred_lst = final_defence.detect_poison()
# Convert to bool: True = clean, False = poison.
is_clean_pred = np.array(is_clean_pred_lst, dtype=bool)

# Ground truth: True = clean, False = poisoned. It is built from all_is_poisoned, which has
# the same sample order as all_labels (the y_train given to final_defence). ART presumably
# evaluates is_clean against that y_train, so the two must align sample-for-sample.
is_clean_gt = np.logical_not(np.asarray(all_is_poisoned, dtype=bool))
if is_clean_gt.shape[0] != len(all_labels):
    raise ValueError(
        f"ground truth has {is_clean_gt.shape[0]} entries but y_train has {len(all_labels)}"
    )

result_json = final_defence.evaluate_defence(is_clean=is_clean_gt)
import json
# default=str keeps printing from failing if the report holds values json cannot
# serialise natively (e.g. numpy bools) — TODO confirm the report's value types.
print(json.dumps(report, indent=4, default=str))

In [None]:
# Summarise the dataset and the raw ART evaluation output.
poison_ratio = np.mean(all_is_poisoned)
for summary_line in (
    f"特征维度: {all_features.shape}",
    f"标签数量: {len(all_labels)}",
    f"投毒样本比例: {poison_ratio:.2%}",
):
    print(summary_line)
print("检测评估结果：", result_json)

In [None]:
import json

# Class whose detection metrics are reported (presumably the backdoor target class — TODO confirm).
TARGET_CLASS = 2

# evaluate_defence may return a JSON string; normalise it to a dict.
if isinstance(result_json, str):
    result_json = json.loads(result_json)

target_class_result = result_json[f"class_{TARGET_CLASS}"]
class_2_result = target_class_result  # name kept for backward compatibility

TP_num = target_class_result["TruePositive"]["numerator"]
FN_num = target_class_result["FalseNegative"]["numerator"]
FP_num = target_class_result["FalsePositive"]["numerator"]
TN_num = target_class_result["TrueNegative"]["numerator"]

# Positives are presumably poisoned samples and negatives clean ones.
total_positive = TP_num + FN_num
total_negative = FP_num + TN_num


def _percent(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or ``None`` when ``denominator`` is 0."""
    return numerator / denominator * 100 if denominator != 0 else None


TPR = _percent(TP_num, total_positive)  # detection rate (TPR)
FNR = _percent(FN_num, total_positive)  # miss rate (FNR)
FPR = _percent(FP_num, total_negative)  # false-alarm rate (FPR)

print("\n检测性能指标：")
print(f"检出率: {TPR:.2f}%" if TPR is not None else "检出率: 未定义（无正样本）")
# FPR is undefined when there are no negative (clean) samples, not positive ones.
print(f"误报率：{FPR:.2f}%" if FPR is not None else "误报率: 未定义（无负样本）")
print(f"漏检率: {FNR:.2f}%" if FNR is not None else "漏检率: 未定义（无正样本）")

print(parquet_path)


In [None]:
import json

# Do not rely on a previous cell having parsed result_json already.
if isinstance(result_json, str):
    result_json = json.loads(result_json)


def _valid_rate(rate):
    """Return ``rate`` as a float, or ``None`` if it is marked undefined.

    The original cells treat ``-1`` as "undefined". A non-numeric placeholder
    (e.g. ``"N/A"``) is also treated as undefined, in case the installed ART
    version emits one — TODO confirm.
    """
    try:
        value = float(rate)
    except (TypeError, ValueError):
        return None
    return None if value == -1 else value


def _fmt_rate(rate):
    """Format a rate from ``_valid_rate`` as ``"12.34%"``, or ``"N/A"`` when undefined."""
    return "N/A" if rate is None else f"{rate:.2f}%"


# Per-class rates with undefined entries normalised to None.
rates_by_class = {
    class_name: {
        kind: _valid_rate(metrics[kind]["rate"])
        for kind in ("TruePositive", "FalseNegative", "FalsePositive")
    }
    for class_name, metrics in result_json.items()
}

tp_rates = [r["TruePositive"] for r in rates_by_class.values() if r["TruePositive"] is not None]
fn_rates = [r["FalseNegative"] for r in rates_by_class.values() if r["FalseNegative"] is not None]
fp_rates = [r["FalsePositive"] for r in rates_by_class.values() if r["FalsePositive"] is not None]

valid_tp_count, tp_sum = len(tp_rates), sum(tp_rates)
valid_fn_count, fn_sum = len(fn_rates), sum(fn_rates)
# Classes with an undefined FP rate (no clean samples) are skipped instead of adding -1 to the sum.
total_fp_count, fp_sum = len(fp_rates), sum(fp_rates)

# Averages fall back to 0 when no class has a defined rate.
avg_detection_rate = tp_sum / valid_tp_count if valid_tp_count > 0 else 0
avg_miss_rate = fn_sum / valid_fn_count if valid_fn_count > 0 else 0
avg_false_alarm_rate = fp_sum / total_fp_count if total_fp_count > 0 else 0

print("\n所有类别的平均检测性能指标：")
print(f"平均检出率: {avg_detection_rate:.2f}%")
print(f"平均误报率: {avg_false_alarm_rate:.2f}%")
print(f"平均漏检率: {avg_miss_rate:.2f}%")

print("\n各类别详细指标：")
for class_name, rates in rates_by_class.items():
    print(f"\n{class_name}:")
    print(f"检出率: {_fmt_rate(rates['TruePositive'])}")
    print(f"误报率: {_fmt_rate(rates['FalsePositive'])}")
    print(f"漏检率: {_fmt_rate(rates['FalseNegative'])}")

In [None]:
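# # 4. 可视化：每个类别随机选取 50 个正常样本并绘制实心方块，投毒样本绘制红色空心三角形
# NOTE(review): if this visualisation is re-enabled, `all_red_acts` below is concatenated
# class by class, while `all_is_poisoned_batch` (used in the next cell) follows the original
# sample order, so the poison masks would not line up with the plotted points. Presumably the
# poison flags should be grouped the same way first, e.g. with
# `final_defence._segment_by_class(all_is_poisoned, all_labels)`, before concatenating.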
# import numpy as np

# # 如果 red_acts_dict 是列表
# red_acts_list = final_defence.red_activations_by_class  # 其实是一个 list

# all_red_acts = []
# all_labels_arr = []

# # 如果你知道它是 [class_0_acts, class_1_acts, ...] 这种结构，
# # 或文档说明第 i 个元素对应第 i 个类别，可以使用枚举
# for i, acts in enumerate(red_acts_list):
#     # acts 就是第 i 个类别的降维激活
#     all_red_acts.append(acts)
#     # 如果你能对应上类别标签，假设 i 就是类别标签，或者另有某个映射
#     # 这里演示假设 i 即为类别
#     all_labels_arr.extend([i]*len(acts))

# all_red_acts = np.concatenate(all_red_acts, axis=0)
# all_labels_arr = np.array(all_labels_arr)
# print(all_labels_arr)


# print("合并后的降维激活：", all_red_acts.shape)
# print("合并后的标签：", all_labels_arr.shape)

In [None]:
# features_2d=all_red_acts
# all_labels_batch=all_labels_arr

# plt.figure(figsize=(10, 8),dpi=1200)
# unique_labels = np.unique(all_labels_batch)

# # 给不同类别分配颜色
# colors = plt.cm.get_cmap("tab20b", len(unique_labels))
# # colors = plt.cm.get_cmap("plasma", len(unique_labels))
# label_to_color = {lab: colors(i) for i, lab in enumerate(unique_labels)}

# n_per_class = 50

# for lab in unique_labels:
#     # 找到该类别、且未投毒的样本索引
#     idx_clean_mask = np.where((all_labels_batch == lab) & (~all_is_poisoned_batch))[0]
#     if len(idx_clean_mask) == 0:
#         continue
#     # 随机选取 50 个，或不足 50 个则全部选
#     select_num = min(n_per_class, len(idx_clean_mask))
#     selected_idx = np.random.choice(idx_clean_mask, size=select_num, replace=False)
#     if lab != 2:
#         # 绘制散点（实心方块）
#         plt.scatter(
#             features_2d[selected_idx, 0],
#             features_2d[selected_idx, 1],
#             color=label_to_color[lab],
#             marker='s',  # 实心方块
#             s=100,
#             alpha=0.8,
#             # label=f"Other Class (Clean)" if lab not in plt.gca().get_legend_handles_labels()[1] else ""
#             # 为防止重复图例，每类只在首次绘制时加 legend
#         )
#     if lab == 2:
#         # 绘制散点（）
#         plt.scatter(
#             features_2d[selected_idx, 0],
#             features_2d[selected_idx, 1],
#             color="#C82423",
#             marker='.',  # 
#             s=130,
#             alpha=0.8,
#             label=f"Target Class (Clean)" if lab not in plt.gca().get_legend_handles_labels()[1] else ""
#             # 为防止重复图例，每类只在首次绘制时加 legend
#         )
# poison_idx = np.where(all_is_poisoned_batch)[0]

# # 如果投毒样本超过50个，则随机选取50个
# if len(poison_idx) > 50:
#     np.random.shuffle(poison_idx)
#     poison_idx = poison_idx[:50]

# if len(poison_idx) > 0:
#     plt.scatter(
#         features_2d[poison_idx, 0],
#         features_2d[poison_idx, 1],
#         marker='^',       # 三角形
#         facecolors='none', 
#         edgecolors='red',
#         s=140,
#         alpha=0.8,
#         linewidth=1.5,
#         label='Poisoned'
#     )

# plt.title("Feature Representation")
# # plt.xlabel("Principal Component 1")
# # plt.ylabel("Principal Component 2")
# # plt.xlabel("Principal Component 1")
# # plt.ylabel("Principal Component 2")
# plt.grid(True, linestyle='--', alpha=0.3)
# plt.legend(loc='best')
# plt.tight_layout()
# plt.savefig("pca_subsampled_0.1-2.png",dpi=1200)
# plt.show()