In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import sys
import warnings


# Directory under /kaggle/input that is appended to sys.path (once) so that the
# `KnowledgeDistillation` package imported below can be resolved.
BASE_PATH = "/kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction"
if BASE_PATH not in sys.path:
    sys.path.append(BASE_PATH)

# Project-local import; it must come after the sys.path update above.
from KnowledgeDistillation.teacher_model import TeacherModel

# Silences every Python warning from this point on (this also hides any
# warnings.warn() calls made by code that runs later).
warnings.filterwarnings("ignore")

In [6]:
# --- 1. 配置 (Configuration) ---
class Config:
    """Static settings for OOF inference: input/output paths, backbone, fold and loader parameters."""

    # --- Inputs ---------------------------------------------------------
    # Directory that holds the "best_teacher_model_fold_k.pth" checkpoints.
    MODEL_DIR = (
        "/kaggle/input/csiro-teacher-model/pytorch/default/1/"
        "CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output"
    )

    # Directory that holds "train_processed.csv".
    DATA_DIR = (
        "/kaggle/input/csiro-teacher-model/pytorch/default/1/"
        "CSIRO---Image2Biomass-Prediction/csiro-biomass/preprocessing_output"
    )

    # Directory that holds the "train" image folder.
    IMG_DIR_BASE = (
        "/kaggle/input/csiro-teacher-model/pytorch/default/1/"
        "CSIRO---Image2Biomass-Prediction/csiro-biomass/"
    )

    # Paths derived from the directories above.
    TRAIN_CSV = os.path.join(DATA_DIR, "train_processed.csv")
    IMG_DIR = os.path.join(IMG_DIR_BASE, "train")

    # --- Output ---------------------------------------------------------
    OUTPUT_CSV = "/kaggle/working/train_with_soft_targets.csv"

    # --- Model (must match the values used when training the teacher) ---
    IMG_MODEL = 'efficientnet_b2'
    IMG_SIZE = 260

    # --- Runtime / split (must match the values used when training) -----
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Must equal the KFold random_state used when the teacher was trained.
    SEED = 42
    N_SPLITS = 5
    # Inference only, so the batch size may be somewhat larger than in training.
    BATCH_SIZE = 16
    NUM_WORKERS = 2

In [4]:
# --- 2. 数据集 (PastureDataset) ---
# [!!] 必须与训练教师时使用的 Dataset *完全*一致

class PastureDataset(Dataset):
    """Dataset yielding one image plus its tabular features per DataFrame row.

    Args:
        df: DataFrame whose index labels are image paths (only the part after
            the last '/' is used) and which contains the numeric columns
            'Pre_GSHH_NDVI', 'Height_Ave_cm', 'month_sin', 'month_cos' and the
            categorical columns 'State_encoded', 'Species_encoded'.
        img_dir: Directory in which the image files are looked up.
        transforms: Callable applied to the RGB PIL image; its result is
            returned under the 'image' key.
        img_size: Side length of the all-zero fallback image that is used when
            an image cannot be loaded or transformed.

    Each item is a dict with keys 'image', 'numeric' (float32 tensor),
    'categorical' (int64 tensor) and 'orig_index' (the row's index label).
    """

    def __init__(self, df, img_dir, transforms, img_size):
        self.df = df
        self.img_dir = img_dir
        self.transforms = transforms
        self.img_size = img_size
        self.numeric_cols = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'month_sin', 'month_cos']
        self.categorical_cols = ['State_encoded', 'Species_encoded']

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row.name  # the row label is the image path
        full_img_path = os.path.join(self.img_dir, img_path.split('/')[-1])

        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been converted and transformed.
            with Image.open(full_img_path) as raw_image:
                image = self.transforms(raw_image.convert('RGB'))
        except Exception as e:
            # Deliberate best-effort fallback: keep going with a blank image,
            # but report which file failed and why. print() is used rather
            # than warnings.warn() so the message is not lost when warnings
            # are filtered.
            print(
                f"[!!] Failed to load image '{full_img_path}' "
                f"({type(e).__name__}: {e}); substituting an all-zero image."
            )
            image = torch.zeros((3, self.img_size, self.img_size))

        numeric = torch.tensor(row[self.numeric_cols].values.astype(np.float32), dtype=torch.float32)
        categorical = torch.tensor(row[self.categorical_cols].values.astype(np.int64), dtype=torch.long)

        return {
            'image': image,
            'numeric': numeric,
            'categorical': categorical,
            'orig_index': row.name  # passed through so callers can map outputs back to rows
        }

In [7]:
# --- 3. 主函数 (Main OOF Generation) ---

def generate_oof():
    """Generate out-of-fold (OOF) teacher predictions and save them with the training table.

    Steps:
      1. Read ``Config.TRAIN_CSV`` with 'image_path' as the index.
      2. Re-create the K-Fold split from ``Config.N_SPLITS`` / ``Config.SEED``.
      3. For each fold, load ``best_teacher_model_fold_{k}.pth`` from
         ``Config.MODEL_DIR`` and predict only that fold's validation rows.
      4. Write the predictions back by row position into five ``teacher_log_*``
         columns and save the result to ``Config.OUTPUT_CSV``.

    A fold whose checkpoint file is missing is skipped with a printed error;
    its rows keep NaN in the teacher columns and the final check reports it.
    If no fold produced predictions, nothing is written.

    Returns:
        None.

    Raises:
        ValueError: If the stacked predictions do not have one column per
            target, or the number of predictions differs from the number of
            validation rows.
    """
    args = Config()
    print("--- 生成 OOF 软目标 ---")
    print(f"加载教师模型来源: {args.MODEL_DIR}")
    print(f"加载数据来源: {args.TRAIN_CSV}")

    # --- 1. Load data ---
    df = pd.read_csv(args.TRAIN_CSV, index_col='image_path')
    # NOTE(review): the embedding sizes are taken from the number of distinct
    # codes; presumably this matches how the teacher was built -- confirm.
    num_states = df['State_encoded'].nunique()
    num_species = df['Species_encoded'].nunique()
    print(f"Found {num_states} states and {num_species} species.")

    # --- 2. Validation transforms (must match training) ---
    val_transforms = transforms.Compose([
        transforms.Resize((args.IMG_SIZE, args.IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # --- 3. K-Fold split (must match training) ---
    kf = KFold(n_splits=args.N_SPLITS, shuffle=True, random_state=args.SEED)

    # Per-fold predictions and the row positions (into df) they belong to.
    oof_preds_list = []
    oof_indices_list = []

    for fold, (train_indices, val_indices) in enumerate(kf.split(df)):
        fold_num = fold + 1
        print(f"\n========== FOLD {fold_num}/{args.N_SPLITS} - OOF 预测 ========== ")

        # --- 4. Load the checkpoint that belongs to this fold ---
        model_path = os.path.join(args.MODEL_DIR, f"best_teacher_model_fold_{fold_num}.pth")
        if not os.path.exists(model_path):
            print(f"[!!] 错误: 未找到模型 {model_path}。请检查路径。")
            continue

        print(f"加载模型: {model_path}")
        model = TeacherModel(num_states, num_species, img_model_name=args.IMG_MODEL).to(args.DEVICE)
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded here; consider weights_only=True where
        # the installed torch version supports it.
        model.load_state_dict(torch.load(model_path, map_location=args.DEVICE))
        model.eval()

        # --- 5. Predict ONLY this fold's validation rows ---
        val_df = df.iloc[val_indices]
        val_dataset = PastureDataset(val_df, args.IMG_DIR, val_transforms, args.IMG_SIZE)
        val_loader = DataLoader(val_dataset, batch_size=args.BATCH_SIZE, shuffle=False, num_workers=args.NUM_WORKERS)

        print(f"为 Fold {fold_num} 的 {len(val_df)} 个验证样本生成预测...")

        fold_preds = _predict_fold(model, val_loader, args.DEVICE, f"Fold {fold_num} 推理")
        oof_preds_list.append(fold_preds.numpy())
        oof_indices_list.append(val_indices)

    print("\n--- 所有 Folds 推理完毕 ---")

    # --- 6. Re-assemble the OOF predictions ---
    if not oof_preds_list:
        print("没有生成任何预测。退出。")
        return

    print("正在重新组合 OOF 预测...")
    log_target_cols = ['log_Dry_Green_g', 'log_Dry_Dead_g', 'log_Dry_Clover_g', 'log_GDM_g', 'log_Dry_Total_g']
    teacher_pred_cols = [f"teacher_{c}" for c in log_target_cols]

    # --- 7. Attach to the original table and save ---
    df_final = _attach_oof_columns(df, oof_preds_list, oof_indices_list, teacher_pred_cols)

    # Check every teacher column, not just the first one.
    if df_final[teacher_pred_cols].isnull().any().any():
        print("[!!] 警告: 合并后存在 NaN 值。K-Fold 拆分可能不匹配。")
    else:
        print("合并成功。")

    df_final.to_csv(args.OUTPUT_CSV)
    print(f"\n[成功] 带有 OOF 软目标的新文件已保存到:\n{args.OUTPUT_CSV}")
    print("\n文件头部内容:")
    print(df_final.head())


def _predict_fold(model, loader, device, desc):
    """Run ``model`` over every batch of ``loader`` (no grad) and return the stacked outputs on CPU."""
    batch_outputs = []
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            image = batch['image'].to(device)
            numeric = batch['numeric'].to(device)
            categorical = batch['categorical'].to(device)
            batch_outputs.append(model(image, numeric, categorical).cpu())
    return torch.cat(batch_outputs, dim=0)


def _attach_oof_columns(df, fold_preds, fold_positions, columns):
    """Return a copy of ``df`` with ``columns`` filled from per-fold predictions by row position.

    Rows not covered by any fold keep NaN. Writing by position (instead of
    joining on index labels) keeps the row count unchanged even if the index
    contains repeated labels.
    """
    preds = np.concatenate(fold_preds, axis=0)
    positions = np.concatenate(fold_positions)

    if preds.ndim != 2 or preds.shape[1] != len(columns):
        raise ValueError(
            f"Expected predictions of shape (N, {len(columns)}), got {preds.shape}."
        )
    if len(positions) != len(preds):
        raise ValueError(
            f"Got {len(preds)} predictions for {len(positions)} validation rows."
        )

    # Keep the predictions' own float dtype so the saved values are unchanged.
    oof = np.full((len(df), len(columns)), np.nan, dtype=preds.dtype)
    oof[positions] = preds

    result = df.copy()
    for col_pos, col_name in enumerate(columns):
        result[col_name] = oof[:, col_pos]
    return result


# Entry point: run the OOF generation when this code executes as the main module.
if __name__ == "__main__":
    generate_oof()

--- 生成 OOF 软目标 ---
加载教师模型来源: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output
加载数据来源: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/csiro-biomass/preprocessing_output/train_processed.csv
Found 4 states and 15 species.

加载模型: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output/best_teacher_model_fold_1.pth
为 Fold 1 的 72 个验证样本生成预测...


Fold 1 推理: 100%|██████████| 5/5 [00:02<00:00,  1.94it/s]



加载模型: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output/best_teacher_model_fold_2.pth
为 Fold 2 的 72 个验证样本生成预测...


Fold 2 推理: 100%|██████████| 5/5 [00:02<00:00,  1.83it/s]



加载模型: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output/best_teacher_model_fold_3.pth
为 Fold 3 的 71 个验证样本生成预测...


Fold 3 推理: 100%|██████████| 5/5 [00:02<00:00,  1.97it/s]



加载模型: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output/best_teacher_model_fold_4.pth
为 Fold 4 的 71 个验证样本生成预测...


Fold 4 推理: 100%|██████████| 5/5 [00:02<00:00,  2.19it/s]



加载模型: /kaggle/input/csiro-teacher-model/pytorch/default/1/CSIRO---Image2Biomass-Prediction/KnowledgeDistillation/teacher_model_output/best_teacher_model_fold_5.pth
为 Fold 5 的 71 个验证样本生成预测...


Fold 5 推理: 100%|██████████| 5/5 [00:02<00:00,  2.23it/s]


--- 所有 Folds 推理完毕 ---
正在重新组合 OOF 预测...
合并成功。

[成功] 带有 OOF 软目标的新文件已保存到:
/kaggle/working/train_with_soft_targets.csv

文件头部内容:
                       Sampling_Date State            Species  Pre_GSHH_NDVI  \
image_path                                                                     
train/ID1011485656.jpg    2015-09-04   Tas    Ryegrass_Clover      -0.246319   
train/ID1012260530.jpg    2015-04-01   NSW            Lucerne      -0.707060   
train/ID1025234388.jpg    2015-09-01    WA  SubcloverDalkeith      -1.826004   
train/ID1028611175.jpg    2015-05-18   Tas           Ryegrass       0.016962   
train/ID1035947949.jpg    2015-09-11   Tas           Ryegrass      -0.772880   

                        Height_Ave_cm  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  \
image_path                                                                     
train/ID1011485656.jpg      -0.285204        0.0000     31.9984      16.2751   
train/ID1012260530.jpg       0.818240        0.0000      0.0000       7.60


