# 多模态检测一致性实验 - 演示流水线

本notebook演示了如何使用多模态检测一致性实验代码进行对抗样本检测。

## 目录
1. [环境设置](#环境设置)
2. [数据准备](#数据准备)
3. [模型初始化](#模型初始化)
4. [文本增强演示](#文本增强演示)
5. [检索系统演示](#检索系统演示)
6. [SD参考生成演示](#SD参考生成演示)
7. [对抗检测演示](#对抗检测演示)
8. [完整流水线演示](#完整流水线演示)
9. [结果可视化](#结果可视化)
10. [总结](#总结)

## 环境设置

In [None]:
import sys
import os
import warnings

# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')

# Put the project on the import path.
# The repository root ('..') is what makes the `from src.<module> import ...`
# statements below resolvable; the '../src' entry is kept so modules can still
# be imported directly by their bare name. This assumes the notebook is run
# from a directory one level below the repository root (the same assumption
# the '../configs' and '../data' paths make).
# Absolute paths keep the entries valid if the working directory changes
# later, and the membership test avoids duplicates when the cell is re-run.
for _relative_dir in ('..', '../src'):
    _project_dir = os.path.abspath(_relative_dir)
    if _project_dir not in sys.path:
        sys.path.append(_project_dir)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import yaml

# Project modules
from src.text_augment import TextAugmenter, TextAugmentConfig
from src.retrieval import MultiModalRetriever, RetrievalConfig
from src.sd_ref import SDReferenceGenerator, SDReferenceConfig
from src.detector import AdversarialDetector, DetectorConfig
from src.pipeline import DefensePipeline, PipelineConfig
from src.utils.visualizer import ResultVisualizer, VisualizationConfig

# The figures in this notebook use Chinese titles and axis labels, which
# Matplotlib's default font (DejaVu Sans) has no glyphs for. Because warnings
# are suppressed above, the missing glyphs would silently render as empty
# boxes. Prefer a CJK-capable sans-serif font when one is installed and keep
# the previously configured fonts as fallbacks at the end of the list.
_cjk_fonts = [
    'Noto Sans CJK SC',
    'Noto Sans CJK JP',
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'WenQuanYi Micro Hei',
    'Arial Unicode MS',
]
plt.rcParams['font.sans-serif'] = _cjk_fonts + [
    name for name in plt.rcParams['font.sans-serif'] if name not in _cjk_fonts
]
# Many CJK fonts lack the Unicode minus sign; use the ASCII hyphen instead.
plt.rcParams['axes.unicode_minus'] = False

# Report the runtime environment.
print("环境设置完成！")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU数量: {torch.cuda.device_count()}")
    print(f"当前GPU: {torch.cuda.get_device_name()}")


## 数据准备

In [None]:
# Sample captions, each listed next to the image path at the same position.
# The image paths are placeholders; adjust them to the local data layout.
_caption_image_pairs = [
    ("A cat sitting on a windowsill", "../data/sample_images/cat.jpg"),
    ("A dog playing in the park", "../data/sample_images/dog.jpg"),
    ("A beautiful sunset over the ocean", "../data/sample_images/sunset.jpg"),
    ("A person riding a bicycle", "../data/sample_images/bicycle.jpg"),
    ("A red car parked on the street", "../data/sample_images/car.jpg"),
]

# Split the pairs back into the two parallel lists used by the later cells.
sample_texts = [caption for caption, _ in _caption_image_pairs]
sample_image_paths = [image_path for _, image_path in _caption_image_pairs]

print(f"准备了 {len(sample_texts)} 个文本样本")
print(f"准备了 {len(sample_image_paths)} 个图像路径")

## 模型初始化

In [None]:
# Read the experiment settings from the default YAML file.
with open('../configs/default.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"使用设备: {device}")

# Text augmentation settings
augment_section = config['text_augment']
text_augment_config = TextAugmentConfig(
    num_variants=augment_section['num_variants'],
    similarity_threshold=augment_section['similarity_threshold'],
    device=device,
)

# Retrieval settings; the CLIP model name is shared with the detector below.
clip_model_name = config['models']['clip']['model_name']
retrieval_section = config['retrieval']
retrieval_config = RetrievalConfig(
    clip_model_name=clip_model_name,
    device=device,
    batch_size=retrieval_section['batch_size'],
    top_k=retrieval_section['top_k'],
)

# Stable Diffusion reference-generation settings
sd_model_name = config['models']['stable_diffusion']['model_name']
sd_config = SDReferenceConfig(
    model_name=sd_model_name,
    device=device,
    num_images_per_text=config['sd_reference']['num_images_per_text'],
)

# Detector settings
detector_config = DetectorConfig(
    clip_model_name=clip_model_name,
    device=device,
)

print("配置初始化完成！")

## 文本增强演示

In [None]:
# Build the text augmenter from the configuration prepared earlier.
print("初始化文本增强器...")
text_augmenter = TextAugmenter(text_augment_config)

# Use the first sample caption as the text to augment.
original_text = sample_texts[0]
print(f"原始文本: {original_text}")

# Generate the variants of the original text.
print("\n生成文本变体...")
variants = text_augmenter.generate_variants(original_text)

# Only the first five variants are listed and scored.
shown_variants = variants[:5]

print(f"\n生成了 {len(variants)} 个文本变体:")
for position, candidate in enumerate(shown_variants, start=1):
    print(f"{position}. {candidate}")

# Similarity of each listed variant to the original text.
similarities = [
    text_augmenter.compute_text_similarity(original_text, candidate)
    for candidate in shown_variants
]

formatted_similarities = [f'{value:.3f}' for value in similarities]
print(f"\n与原文本的相似度: {formatted_similarities}")

## 检索系统演示

In [None]:
# Build the retriever from the configuration prepared earlier.
print("初始化多模态检索器...")
retriever = MultiModalRetriever(retrieval_config)

# Image index: only the progress message is printed here. Building the index
# needs real image files, so the calls are left commented out; uncomment them
# once the sample images exist.
print("构建图像索引...")
# image_features = retriever.build_image_index(sample_image_paths)
# print(f"图像索引构建完成，包含 {len(image_features)} 个图像特征")

# Text index over the sample captions.
print("构建文本索引...")
text_features = retriever.build_text_index(sample_texts)
print(f"文本索引构建完成，包含 {len(text_features)} 个文本特征")

# Text-to-text retrieval with a paraphrased query.
query_text = "A feline animal resting near a window"
print(f"\n查询文本: {query_text}")

results = retriever.search_text_to_text(query_text, top_k=3)
print("\n检索结果 (Top-3):")
for rank, hit in enumerate(results, start=1):
    # Each hit is unpacked as (index into sample_texts, similarity score).
    matched_index, similarity = hit
    print(f"{rank}. 索引: {matched_index}, 相似度: {similarity:.3f}, 文本: {sample_texts[matched_index]}")

## SD参考生成演示

In [None]:
# Build the Stable Diffusion reference generator.
print("初始化Stable Diffusion参考生成器...")
sd_generator = SDReferenceGenerator(sd_config)

# Use the first sample caption as the generation prompt.
test_prompt = sample_texts[0]
print(f"生成提示: {test_prompt}")

print("生成参考图像...")
reference_images = sd_generator.generate_reference_images([test_prompt])

# Only one prompt was submitted, so its images are the first entry.
prompt_images = reference_images[0]
image_count = len(prompt_images)
print(f"生成了 {image_count} 张参考图像")

# One subplot per generated image, laid out in a single row.
fig, axes = plt.subplots(1, image_count, figsize=(15, 5))
# With a single column plt.subplots returns a bare Axes; wrap it so the loop
# below can treat both cases the same way.
axes = [axes] if image_count == 1 else axes

for number, (axis, picture) in enumerate(zip(axes, prompt_images), start=1):
    axis.imshow(picture)
    axis.set_title(f"参考图像 {number}")
    axis.axis('off')

plt.suptitle(f"生成的参考图像 - 提示: {test_prompt}")
plt.tight_layout()
plt.show()

## 对抗检测演示

In [None]:
# Build the adversarial detector from the configuration prepared earlier.
print("初始化对抗检测器...")
detector = AdversarialDetector(detector_config)

# Test inputs: a plain caption and a paraphrase of it that is treated as a
# possible adversarial sample.
clean_text = "A cat sitting on a windowsill"
adversarial_text = "A feline creature positioned atop a window ledge"  # possible adversarial sample

print(f"干净文本: {clean_text}")
print(f"可疑文本: {adversarial_text}")

# Single-sample detection on the clean text (no image is supplied).
print("\n检测干净样本...")
clean_result = detector.detect_adversarial_sample(clean_text, None)
print(f"检测结果: {clean_result}")

# Single-sample detection on the suspicious text (no image is supplied).
print("\n检测可疑样本...")
adv_result = detector.detect_adversarial_sample(adversarial_text, None)
print(f"检测结果: {adv_result}")

# Batch detection over both texts plus the first three sample captions.
test_texts = [clean_text, adversarial_text] + sample_texts[:3]
print(f"\n批量检测 {len(test_texts)} 个样本...")
batch_results = detector.detect_batch_samples(test_texts, [None] * len(test_texts))

# Longest text shown in full in the per-sample summary lines.
max_preview_chars = 50

for i, (text, result) in enumerate(zip(test_texts, batch_results), start=1):
    status = "对抗" if result['is_adversarial'] else "干净"
    confidence = result['confidence']
    # Append the ellipsis only when the text was actually cut, so short texts
    # are not shown as if they had been truncated.
    if len(text) > max_preview_chars:
        preview = f"{text[:max_preview_chars]}..."
    else:
        preview = text
    print(f"{i}. [{status}] (置信度: {confidence:.3f}) {preview}")

## 完整流水线演示

In [None]:
# Configure the full pipeline with every stage and profiling switched on.
pipeline_config = PipelineConfig(
    enable_text_augmentation=True,
    enable_retrieval=True,
    enable_sd_reference=True,
    enable_detection=True,
    enable_profiling=True
)

print("初始化防御流水线...")
pipeline = DefensePipeline(pipeline_config)

# Process a single sample: text only; an image path may be supplied instead
# of None.
test_text = sample_texts[0]
test_image_path = None

print(f"处理样本: {test_text}")
result = pipeline.process_sample(test_text, test_image_path)

# Optional stage outputs may be empty or missing, so their counts fall back
# to 0.
print(f"\n流水线处理结果:")
print(f"- 原始输入: {result.original_input['text']}")
print(f"- 生成变体数量: {len(result.text_variants) if result.text_variants else 0}")
print(f"- 检索结果数量: {len(result.retrieved_items) if result.retrieved_items else 0}")
print(f"- 生成参考数量: {len(result.generated_references) if result.generated_references else 0}")
print(f"- 检测结果: {result.detection_result}")
print(f"- 处理时间: {result.get_total_time():.3f}秒")

# Process all sample captions as one batch (no images are supplied).
print(f"\n批量处理 {len(sample_texts)} 个样本...")
batch_results = pipeline.process_batch_samples(sample_texts, [None] * len(sample_texts))

print(f"批量处理完成，处理了 {len(batch_results)} 个样本")

# Batch summary. A sample counts as adversarial only when it has a detection
# result whose 'is_adversarial' entry is truthy.
adversarial_count = sum(
    1
    for r in batch_results
    if r.detection_result and r.detection_result.get('is_adversarial', False)
)
total_time = sum(r.get_total_time() for r in batch_results)
# Guard the average so an empty batch result reports 0 instead of raising
# ZeroDivisionError.
avg_time = total_time / len(batch_results) if batch_results else 0.0

print(f"\n批量处理摘要:")
print(f"- 检测到对抗样本: {adversarial_count}/{len(batch_results)}")
print(f"- 总处理时间: {total_time:.3f}秒")
print(f"- 平均处理时间: {avg_time:.3f}秒/样本")

## 结果可视化

In [None]:
# Create the project visualizer (plots are shown, not saved).
viz_config = VisualizationConfig(
    save_plots=False,
    show_plots=True,
    figure_size=(12, 8)
)

visualizer = ResultVisualizer(viz_config)

# Collect one label, one confidence and one timing per batch result.
detection_results = []
confidence_scores = []
processing_times = []

for sample_result in batch_results:
    # A missing or empty detection result counts as clean with confidence 0.5.
    verdict = sample_result.detection_result or {}
    detection_results.append(1 if verdict.get('is_adversarial', False) else 0)
    confidence_scores.append(verdict.get('confidence', 0.5))
    processing_times.append(sample_result.get_total_time())

# Three panels side by side on one figure.
summary_figure, (ax_pie, ax_hist, ax_bar) = plt.subplots(1, 3, figsize=(12, 4))

# 1. Share of clean vs. adversarial verdicts
labels = ['干净样本', '对抗样本']
adversarial_total = sum(detection_results)
counts = [len(detection_results) - adversarial_total, adversarial_total]
ax_pie.pie(counts, labels=labels, autopct='%1.1f%%', startangle=90)
ax_pie.set_title('检测结果分布')

# 2. Histogram of detection confidence
ax_hist.hist(confidence_scores, bins=10, alpha=0.7, edgecolor='black')
ax_hist.set_xlabel('置信度')
ax_hist.set_ylabel('频次')
ax_hist.set_title('检测置信度分布')
ax_hist.grid(True, alpha=0.3)

# 3. Processing time of each sample
ax_bar.bar(range(len(processing_times)), processing_times, alpha=0.7)
ax_bar.set_xlabel('样本索引')
ax_bar.set_ylabel('处理时间 (秒)')
ax_bar.set_title('样本处理时间')
ax_bar.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# 4. Pipeline performance statistics; floats are shown with three decimals.
pipeline_stats = pipeline.get_pipeline_statistics()

print(f"\n流水线性能统计:")
for key, value in pipeline_stats.items():
    rendered = f"{value:.3f}" if isinstance(value, float) else f"{value}"
    print(f"- {key}: {rendered}")

## 总结

本notebook演示了多模态检测一致性实验代码的主要功能：

1. **文本增强**: 生成语义相似但表达不同的文本变体
2. **多模态检索**: 在文本和图像之间进行相似性检索（本演示仅运行了文本到文本检索；图像索引需在提供实际图像文件后启用）
3. **SD参考生成**: 使用Stable Diffusion生成参考图像
4. **对抗检测**: 基于一致性检测对抗样本
5. **完整流水线**: 集成所有组件的端到端处理
6. **结果可视化**: 直观展示检测结果和性能指标

### 下一步

- 尝试使用真实的对抗样本进行测试
- 调整各组件的参数以优化性能
- 在更大的数据集上进行评估
- 探索不同的检测策略和阈值设置