# Plant Disease Recognition Model: Inference Notebook

Name: Zihan

### Step 1 - (One-Time Utility) Create Class Map File

In [None]:
# import json
# import numpy as np

# # 1. 将你从Colab复制的字典粘贴在这里
# # 我们需要 import numpy as np 才能让Python理解 np.int64() 是什么
# original_map = {0: np.int64(0), 1: np.int64(5), 2: np.int64(6), 3: np.int64(7), 4: np.int64(9), 5: np.int64(10), 6: np.int64(13), 7: np.int64(18), 8: np.int64(22), 9: np.int64(31), 10: np.int64(34), 11: np.int64(38), 12: np.int64(49), 13: np.int64(52), 14: np.int64(56), 15: np.int64(61), 16: np.int64(67), 17: np.int64(75), 18: np.int64(82), 19: np.int64(83), 20: np.int64(92), 21: np.int64(98)}

# # 2. 清理数据，将其转换为纯粹的Python类型
# #   - 将key转换为字符串 (JSON标准)
# #   - 将value从np.int64转换为普通的int
# cleaned_map = {str(k): int(v) for k, v in original_map.items()}


# # 3. 定义输出文件名
# output_filename = "class_map.json"

# # 4. 使用json库将清理后的字典保存为格式正确的文件
# with open(output_filename, 'w') as f:
#     json.dump(cleaned_map, f, indent=4)

# print(f"✅ 成功创建了格式完全正确的 '{output_filename}' 文件！")

### Step 2 - Environment Configuration & Path Setup

In [32]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pathlib import Path
import json
import timm
import warnings
import os

# Ignore warnings that do not affect the results.
# NOTE(review): this silences *all* warnings process-wide (including library
# deprecation notices); narrow the filter if a specific warning must stay visible.
warnings.filterwarnings('ignore')

# ====================================================================
# 1. Configuration
# ====================================================================

# --- Base directory for the model / class-map files ---
# When run as a script, `__file__` exists and we use the directory that contains
# this file, resolved to an absolute path (so a relative invocation path or a
# symlink does not leak through). Inside Jupyter `__file__` is not defined, so we
# fall back to the current working directory — i.e. in a notebook the result
# DOES depend on where the kernel was started.
NOTEBOOK_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path().resolve()


class CFG:
    """Single place for every inference setting; edit values here only."""

    # --- Model settings (must be identical to the training run) ---
    MODEL_NAME = 'swin_base_patch4_window7_224.ms_in1k'  # timm architecture identifier
    IMAGE_SIZE = 224  # square input side length fed to the network

    # --- File locations, resolved relative to NOTEBOOK_DIR ---
    MODEL_PATH = NOTEBOOK_DIR.joinpath("model.pth")  # trained weights (state_dict)
    CLASS_MAP_PATH = NOTEBOOK_DIR.joinpath("class_map.json")  # model index -> class ID map

    # --- Inference device: CUDA when available, otherwise CPU ---
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Echo the resolved paths so they can be checked before anything is loaded ---
print(
    f"脚本/Notebook所在目录: {NOTEBOOK_DIR}",
    f"模型文件预期路径: {CFG.MODEL_PATH}",
    f"映射文件预期路径: {CFG.CLASS_MAP_PATH}",
    sep="\n",
)

脚本/Notebook所在目录: E:\05_YZH_DS\02_Monash_DS\2025_S2_FIT5120_Industry_Experience_Studio_Project\06_main_project\03_github_submission\03_github_submission\2025-08-SDG13-Plant-X-Website\03_machine_learning_models\02_plant_disease_recognition
模型文件预期路径: E:\05_YZH_DS\02_Monash_DS\2025_S2_FIT5120_Industry_Experience_Studio_Project\06_main_project\03_github_submission\03_github_submission\2025-08-SDG13-Plant-X-Website\03_machine_learning_models\02_plant_disease_recognition\model.pth
映射文件预期路径: E:\05_YZH_DS\02_Monash_DS\2025_S2_FIT5120_Industry_Experience_Studio_Project\06_main_project\03_github_submission\03_github_submission\2025-08-SDG13-Plant-X-Website\03_machine_learning_models\02_plant_disease_recognition\class_map.json


### Step 3 - Load Class Mapping Dictionary

In [33]:
# ====================================================================
# 2. 加载“翻译词典”
# ====================================================================
def load_class_map(json_path):
    """Load the model-index -> class-ID mapping from a JSON file.

    The file must contain a JSON object whose keys are the model's output
    indices written as strings (JSON object keys are always strings), e.g.
    ``{"0": 0, "1": 5, ...}``.

    Args:
        json_path: Path (str or ``pathlib.Path``) of the JSON file.

    Returns:
        A dict mapping integer output index -> class ID (values are kept exactly
        as stored in the JSON), or ``None`` if the file is missing, is not valid
        JSON, is not a JSON object, or has a key that is not an integer. The
        reason is printed in each failure case.
    """
    try:
        # Explicit encoding so the result does not depend on the OS locale.
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_map = json.load(f)
    except FileNotFoundError:
        print(f"🛑 错误: 映射文件未找到于 '{json_path}'")
        return None
    except json.JSONDecodeError:
        print(f"🛑 错误: '{json_path}' 不是一个有效的JSON文件。")
        return None

    # A JSON array / scalar has no .items(); reject it instead of crashing.
    if not isinstance(raw_map, dict):
        print(f"🛑 错误: '{json_path}' 的内容必须是一个JSON对象 (索引 -> ID)。")
        return None

    try:
        # JSON keys are strings; convert to int so lookups with the integer
        # indices produced by the model (torch.topk) succeed.
        idx_to_label = {int(k): v for k, v in raw_map.items()}
    except ValueError:
        print(f"🛑 错误: '{json_path}' 中存在无法转换为整数的索引键。")
        return None

    print("✅ 翻译词典加载成功。")
    return idx_to_label

### Step 4 - Load Trained Model and Define Image Preprocessing

In [34]:
# ====================================================================
# 3. 加载训练好的模型
# ====================================================================
def load_model(model_name, num_classes, model_path, device):
    """Build the timm architecture and load the locally trained weights into it.

    Args:
        model_name: timm model identifier; must match the one used for training.
        num_classes: Size of the classification head; must match the checkpoint.
        model_path: Path of a file holding a state_dict for this architecture.
        device: torch device the weights are mapped onto and the model is moved to.

    Returns:
        The model in eval mode on ``device``, or ``None`` if loading failed
        (the reason is printed).
    """
    try:
        # pretrained=False: the weights come from our own checkpoint below.
        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run arbitrary code.
            # map_location lets a GPU-saved checkpoint load on a CPU-only box.
            state_dict = torch.load(model_path, map_location=device, weights_only=True)
        except TypeError:
            # Older torch versions do not accept `weights_only`; fall back to a plain load.
            state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()  # evaluation mode (affects layers such as dropout); required for inference
        print("✅ 模型加载成功。")
        return model
    except FileNotFoundError:
        print(f"🛑 错误: 模型文件未找到于 '{model_path}'")
        return None
    except Exception as e:
        print(f"🛑 加载模型时发生未知错误: {e}")
        return None

In [35]:
# ====================================================================
# 4. 定义图像预处理流程
# ====================================================================
# Image preprocessing used at inference time. It must be identical to the
# validation/test transform used during training.
inference_transforms = transforms.Compose(transforms=[
    # Resize to an exact square; the aspect ratio is not preserved.
    transforms.Resize((CFG.IMAGE_SIZE, CFG.IMAGE_SIZE)),
    # PIL image -> channels-first float tensor (scaled to [0, 1] for 8-bit images).
    transforms.ToTensor(),
    # Per-channel normalisation with the commonly used ImageNet mean / std values.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### Step 5 - Define Top-3 Prediction Function

In [36]:
# ====================================================================
# 5. 核心预测函数
# ====================================================================
def predict_top3(model, image_path, transforms, idx_to_label_map, device, top_k=3):
    """Classify one image and return the top-k (default 3) class IDs with probabilities.

    Args:
        model: Classification model in eval mode that maps a (1, C, H, W)
            tensor to (1, num_classes) logits.
        image_path: Path of the image file to classify.
        transforms: Callable turning a PIL RGB image into a (C, H, W) tensor.
            (Inside this function the name shadows the torchvision module.)
        idx_to_label_map: Dict mapping integer model output index -> class ID.
        device: torch device the model lives on.
        top_k: Number of predictions to return; clamped to the number of classes.

    Returns:
        A list of dicts ``{"predicted_id": int, "probability": "12.34%"}``
        ordered by descending probability, or an error-message string if the
        image could not be opened or preprocessed.
    """
    try:
        # Use a context manager so the underlying file handle is closed;
        # convert() returns a new, fully loaded image that outlives the `with`.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # unsqueeze(0) adds the batch dimension: the model expects [B, C, H, W].
        image_tensor = transforms(image).unsqueeze(0).to(device)
    except FileNotFoundError:
        return f"🛑 错误: 图片文件未找到于 '{image_path}'"
    except Exception as e:
        return f"🛑 处理图片时发生错误: {e}"

    # No gradients are needed for inference: faster and uses less memory.
    with torch.no_grad():
        logits = model(image_tensor)
        probabilities = F.softmax(logits, dim=1)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)

    # Take row 0 (the single batch item) instead of squeeze(), which would
    # collapse to a 0-d value when k == 1. tolist() yields plain Python numbers.
    top_probs = top_probs[0].cpu().tolist()
    top_indices = top_indices[0].cpu().tolist()

    # Translate each internal model index into the real class ID.
    return [
        {
            "predicted_id": int(idx_to_label_map[class_idx]),
            "probability": f"{prob:.2%}",  # formatted as a percentage string
        }
        for class_idx, prob in zip(top_indices, top_probs)
    ]

### Step 6 - Execute Inference and Display Results

In [37]:
# ====================================================================
# 6. 主执行入口
# ====================================================================
if __name__ == '__main__':
    print("--- 开始植物病害识别推理 ---")
    print(f"使用设备: {CFG.DEVICE}")

    # Step 1: load the model-index -> class-ID mapping.
    print("\n[1/3] 正在加载标签映射表...")
    idx_to_label = load_class_map(CFG.CLASS_MAP_PATH)

    # Step 2: load the model. Its number of output classes must equal the
    # number of entries in the mapping.
    model = None
    if idx_to_label is not None:
        print("\n[2/3] 正在加载模型...")
        num_classes = len(idx_to_label)
        model = load_model(CFG.MODEL_NAME, num_classes, CFG.MODEL_PATH, CFG.DEVICE)

    # Step 3: run the prediction.
    if model is not None:
        print("\n[3/3] 正在执行预测...")
        # !! Set this to a local test image. A relative path is resolved against
        # NOTEBOOK_DIR, the same base directory used for the model and class map.
        TEST_IMAGE_PATH = "test_images/011d0.jfif"  # <--- change this to your local test image path
        test_image = NOTEBOOK_DIR / TEST_IMAGE_PATH

        if test_image.exists():
            top3_predictions = predict_top3(model, test_image, inference_transforms, idx_to_label, CFG.DEVICE)

            if isinstance(top3_predictions, str):
                # predict_top3 reports image-loading failures as a message string;
                # iterating it as if it were a list of dicts would crash.
                print(top3_predictions)
            else:
                print("\n" + "="*30)
                print("--- 预测结果 (Top 3) ---")
                print(f"图片: {TEST_IMAGE_PATH}")
                print("="*30)
                for pred in top3_predictions:
                    print(f"病害ID: {pred['predicted_id']:<5} | 概率: {pred['probability']}")
        else:
            print(f"🛑 错误: 测试图片 '{test_image}' 不存在。请修改路径后重试。")

--- 开始植物病害识别推理 ---
使用设备: cuda

[1/3] 正在加载标签映射表...
✅ 翻译词典加载成功。

[2/3] 正在加载模型...


✅ 模型加载成功。

[3/3] 正在执行预测...

--- 预测结果 (Top 3) ---
图片: test_images/011d0.jfif
病害ID: 0     | 概率: 89.83%
病害ID: 31    | 概率: 0.72%
病害ID: 22    | 概率: 0.68%
