In [None]:
# ---------------------------------
# 阶段 1: 生成教师的软目标 (Soft Targets)
# (在 K-Fold 循环之前运行)
# ---------------------------------
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from KnowledgeDistillation.teacher_model import TeacherModel
from KnowledgeDistillation.student_model import StudentModel # <-- 导入
from KnowledgeDistillation.student_loss import DistillationLoss # <-- 导入
from KnowledgeDistillation.loss import WeightedMSELoss
# (导入所有其他必要的库: os, transforms, PastureDataset, etc.)

def generate_soft_targets(teacher_model, full_df, args, device):
    """
    Run the teacher model over every row of ``full_df`` and attach its outputs.

    The teacher is moved to ``device``, switched to eval mode and evaluated
    without gradient tracking. Its raw outputs are stored in five new columns
    named ``teacher_<log target column>``.

    Args:
        teacher_model: Callable as ``teacher_model(image, numeric, categorical)``.
            It is left on ``device`` and in eval mode when this function returns.
        full_df: DataFrame describing all samples; passed as-is to
            ``PastureDataset``. It is not modified.
        args: Object providing ``img_size``, ``img_dir`` and ``batch_size``.
        device: Torch device used for inference.

    Returns:
        A copy of ``full_df`` with the ``teacher_*`` columns added. If those
        columns already exist (e.g. the cell was re-run) they are overwritten.

    Raises:
        ValueError: If the stacked predictions do not have shape
            ``(len(full_df), 5)``.
    """
    print("--- 阶段 1: 正在生成教师的软目标 ---")

    log_target_cols = ['log_Dry_Green_g', 'log_Dry_Dead_g', 'log_Dry_Clover_g',
                       'log_GDM_g', 'log_Dry_Total_g']
    teacher_pred_cols = [f"teacher_{c}" for c in log_target_cols]

    # Deterministic (validation-style) preprocessing: resize, tensor conversion
    # and normalisation with the commonly used ImageNet channel statistics.
    # NOTE(review): `transforms` is not imported in the visible part of this
    # file (presumably torchvision.transforms) — confirm it is in scope.
    val_transforms = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The teacher's dataset must be used here: its batches carry the
    # 'image', 'numeric' and 'categorical' entries read in the loop below.
    from KnowledgeDistillation.teacher_train import PastureDataset

    inference_dataset = PastureDataset(full_df, args.img_dir, val_transforms, args.img_size)

    # shuffle=False and num_workers=0 keep batches in dataset order, which the
    # positional write-back below relies on.
    # NOTE(review): assumes PastureDataset yields one item per row of
    # `full_df`, in row order — the shape check below catches count mismatches.
    inference_loader = DataLoader(inference_dataset, batch_size=args.batch_size * 2,
                                  shuffle=False, num_workers=0)

    teacher_model.to(device)
    teacher_model.eval()

    batch_preds = []
    with torch.no_grad():
        for batch in tqdm(inference_loader, desc="生成软目标"):
            image = batch['image'].to(device)
            numeric = batch['numeric'].to(device)
            categorical = batch['categorical'].to(device)

            # Teacher forward pass; results are moved to CPU right away so the
            # accumulated list does not hold device memory.
            pred_log = teacher_model(image, numeric, categorical)
            batch_preds.append(pred_log.cpu())

    if batch_preds:
        all_teacher_preds = torch.cat(batch_preds, dim=0).numpy()
    else:
        # torch.cat rejects an empty list; an empty input frame simply yields
        # an empty (0, n_targets) prediction matrix.
        all_teacher_preds = torch.empty((0, len(teacher_pred_cols))).numpy()

    expected_shape = (len(full_df), len(teacher_pred_cols))
    if all_teacher_preds.shape != expected_shape:
        raise ValueError(
            f"Teacher predictions have shape {all_teacher_preds.shape}, "
            f"expected {expected_shape} (rows of full_df x target columns)."
        )

    # Attach predictions by position rather than with an index-aligned join:
    # a join multiplies rows when `full_df.index` contains duplicate labels and
    # raises when the teacher_* columns are already present.
    df_with_soft_targets = full_df.copy()
    for col_pos, col_name in enumerate(teacher_pred_cols):
        df_with_soft_targets[col_name] = all_teacher_preds[:, col_pos]

    print(f"软目标生成完毕。总共 {len(df_with_soft_targets)} 条记录。")
    return df_with_soft_targets