In [None]:
import torch
import time
import requests
from PIL import Image
# AutoImageProcessor is the generic processor loader: it resolves the matching image processor for both ViT and Swin checkpoints
from transformers import AutoImageProcessor 
from transformers import ViTForImageClassification, SwinForImageClassification

# ==========================================
# 1. Setup
# ==========================================
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"运行设备: {device}")

# Load the test image (a COCO val2017 sample; the original note calls it a
# giant panda -- TODO confirm the actual subject of this image).
url = "http://images.cocodataset.org/val2017/000000000285.jpg"
try:
    # The timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=30)
    # Fail on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(response.raw)
except (requests.RequestException, OSError):
    # RequestException covers network/HTTP failures; OSError covers PIL being
    # unable to identify or read the downloaded bytes.
    print("图片下载失败，请检查网络。")
    # Re-raise: without an image the rest of the cell cannot run, and
    # continuing would only surface later as a confusing NameError on `image`.
    raise

# ==========================================
# 2. 定义对比函数
# ==========================================
def run_inference(model_name, model_class, image):
    """Load a pretrained classifier, classify ``image`` once, and time it.

    The processor and the model are both loaded from ``model_name``. One
    untimed warmup forward pass is run before the timed forward pass. The
    model and inputs are placed on the module-level ``device``.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both ``AutoImageProcessor`` and ``model_class``.
        model_class: Model class exposing ``from_pretrained`` whose outputs
            carry ``logits`` (e.g. ``ViTForImageClassification``).
        image: Image handed to the processor via ``images=``.

    Returns:
        A tuple ``(label, elapsed_ms)``: the predicted label looked up in
        ``model.config.id2label`` and the duration of the timed forward pass
        in milliseconds.
    """
    print(f"\n--- 正在加载模型: {model_name} ---")

    # AutoImageProcessor picks the processor matching the checkpoint, so the
    # caller does not need to pass a processor class.
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name).to(device)

    # Preprocess the image into tensors on the target device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # CUDA kernels are launched asynchronously; without an explicit
    # synchronize the clock would stop before the GPU has finished.
    on_cuda = str(device).startswith("cuda")

    with torch.no_grad():
        # Warmup pass so one-time initialisation cost is not measured.
        _ = model(**inputs)
        if on_cuda:
            torch.cuda.synchronize()

        # Timed pass. perf_counter is monotonic and high-resolution, unlike
        # time.time(), which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        outputs = model(**inputs)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start_time) * 1000

    # Decode the prediction: arg-max class index -> label, plus its softmax
    # probability as the confidence.
    logits = outputs.logits
    predicted_label_idx = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label_idx]
    confidence = torch.softmax(logits, dim=-1).max().item()

    print(f"预测结果: {label} (置信度: {confidence:.2%})")
    print(f"推理耗时: {elapsed_ms:.2f} ms")
    return label, elapsed_ms

# ==========================================
# 3. Run the comparison (ViT vs Swin section of the report)
# ==========================================

# Checkpoints under test: Vision Transformer (ViT-Base) and Swin Transformer (Tiny).
vit_model = "google/vit-base-patch16-224"
swin_model = "microsoft/swin-tiny-patch4-window7-224"

# Experiment A (ViT) runs first, then experiment B (Swin). No processor class
# is passed; run_inference loads the processor itself.
for checkpoint, model_cls in (
    (vit_model, ViTForImageClassification),
    (swin_model, SwinForImageClassification),
):
    run_inference(checkpoint, model_cls, image)

# ==========================================
# 4. Show the image
# ==========================================
print("\n测试图片:")
thumbnail = image.resize((300, 300))
display(thumbnail)

运行设备: cpu

--- 正在加载模型: google/vit-base-patch16-224 ---


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]