In [1]:
import os
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
TOKEN = user_secrets.get_secret("GITHUB_TOKEN")
USERNAME = 'ada-yl2425'
REPO_NAME = 'CSIRO---Image2Biomass-Prediction'
!git clone https://{USERNAME}:{TOKEN}@github.com/{USERNAME}/{REPO_NAME}.git
!git pull origin main
!ls

Cloning into 'CSIRO---Image2Biomass-Prediction'...
remote: Enumerating objects: 527, done.[K
remote: Counting objects: 100% (16/16), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 527 (delta 6), reused 0 (delta 0), pack-reused 511 (from 2)[K
Receiving objects: 100% (527/527), 1.08 GiB | 49.54 MiB/s, done.
Resolving deltas: 100% (93/93), done.
Updating files: 100% (380/380), done.
fatal: not a git repository (or any parent up to mount point /kaggle)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
CSIRO---Image2Biomass-Prediction  __notebook__.ipynb


In [2]:
!pip install torch torchvision pandas scikit-learn pillow tqdm timm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curan

In [3]:
import os
import argparse
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import warnings

In [4]:
# Ignore some noisy PIL warnings. The second argument is a regular expression
# that is matched against the warning message.
warnings.filterwarnings("ignore", "(Possibly corrupt EXIF data|Truncated File Read)")

In [5]:
# --- 1. Evaluation metric (weighted R2) ---
def calculate_weighted_r2(y_true, y_pred, device):
    """
    Compute one globally weighted R2 score on the original (non-log) scale.

    All elements of the inputs are pooled into a single score. The five target
    columns carry the weights 0.1, 0.1, 0.1, 0.2 and 0.5 (in column order),
    which are broadcast across the rows.

    Args:
        y_true: ground-truth tensor whose last dimension is 5 (the original
            author documents the shape as [N, 5]).
        y_pred: predictions, same shape as ``y_true``.
        device: device on which the weight vector is created; the inputs are
            expected to live on the same device.

    Returns:
        The score ``1 - SS_res / (SS_tot + 1e-6)`` as a Python float.
    """
    eps = 1e-6  # guards both divisions against a zero denominator
    col_weights = torch.tensor([0.1, 0.1, 0.1, 0.2, 0.5], dtype=torch.float32).to(device)

    # SS_res = sum_j w_j * (y_j - y_hat_j)^2, pooled over every element.
    residual_ss = (col_weights * (y_true - y_pred) ** 2).sum()

    # Global weighted mean: (sum_j w_j * y_j) / (sum_j w_j).
    # The denominator sums the weights after broadcasting them to y_true's
    # shape, so every row contributes its own copy of the five weights.
    total_weight = col_weights.expand_as(y_true).sum()
    weighted_mean = (col_weights * y_true).sum() / (total_weight + eps)

    # SS_tot = sum_j w_j * (y_j - weighted_mean)^2; the scalar mean broadcasts.
    total_ss = (col_weights * (y_true - weighted_mean) ** 2).sum()

    score = 1.0 - (residual_ss / (total_ss + eps))
    return score.item()

In [6]:
# --- 2. Custom dataset ---
class PastureDataset(Dataset):
    """
    Serves one sample per DataFrame row: image, tabular features and targets.

    The DataFrame index holds the image path; only its last '/'-separated
    component is kept and joined onto ``img_dir``.
    """
    def __init__(self, df, img_dir, transforms, img_size):
        self.df, self.img_dir = df, img_dir
        self.transforms = transforms
        # Side length of the all-zero placeholder used when an image cannot be loaded.
        self.img_size = img_size

        # Tabular input columns.
        self.numeric_cols = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'month_sin', 'month_cos']
        self.categorical_cols = ['State_encoded', 'Species_encoded']

        # Training targets (log scale).
        self.log_target_cols = ['log_Dry_Green_g', 'log_Dry_Dead_g',
                                'log_Dry_Clover_g', 'log_GDM_g', 'log_Dry_Total_g']

        # Validation targets (original scale).
        self.orig_target_cols = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g',
                                 'GDM_g', 'Dry_Total_g']

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'numeric', 'categorical', 'log_target' and 'orig_target'."""
        record = self.df.iloc[idx]

        # The index value carries a directory prefix (e.g. 'train/'); keep the file name only.
        img_path = os.path.join(self.img_dir, record.name.split('/')[-1])

        try:
            image = self.transforms(Image.open(img_path).convert('RGB'))
        except Exception as e:
            # Best effort: report the problem and substitute a blank image.
            print(f"Warning: Error loading image {img_path}. Using a dummy image. Error: {e}")
            image = torch.zeros((3, self.img_size, self.img_size))

        def as_tensor(columns, np_dtype, torch_dtype):
            # A row of mixed-type columns may come back with a generic dtype,
            # so cast explicitly before handing the values to torch.
            return torch.tensor(record[columns].values.astype(np_dtype), dtype=torch_dtype)

        return {
            'image': image,
            'numeric': as_tensor(self.numeric_cols, np.float32, torch.float32),
            'categorical': as_tensor(self.categorical_cols, np.int64, torch.long),
            'log_target': as_tensor(self.log_target_cols, np.float32, torch.float32),
            'orig_target': as_tensor(self.orig_target_cols, np.float32, torch.float32),
        }

In [7]:
# --- 3. Training and validation loops ---

def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is a dict with the keys 'image', 'numeric', 'categorical' and
    'log_target'. The loss is computed between the model output and
    'log_target', i.e. on the log scale.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.train()
    running_loss = 0.0

    for batch in tqdm(loader, desc="Training"):
        # Move everything the step needs onto the target device.
        images, numeric_feats, cat_feats, targets = (
            batch[key].to(device)
            for key in ('image', 'numeric', 'categorical', 'log_target')
        )

        optimizer.zero_grad()

        # Forward pass and log-scale loss.
        batch_loss = criterion(model(images, numeric_feats, cat_feats), targets)

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

def validate(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    The loss is computed on the log scale against 'log_target'. Predictions
    are mapped back to the original scale with ``expm1`` and scored against
    'orig_target' using ``calculate_weighted_r2`` over all batches at once.

    Returns:
        ``(mean batch loss, weighted R2)``.
    """
    model.eval()
    loss_sum = 0.0
    pred_chunks, target_chunks = [], []
    needed_keys = ('image', 'numeric', 'categorical', 'log_target', 'orig_target')

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            on_device = {key: batch[key].to(device) for key in needed_keys}

            # Prediction and loss, both on the log scale.
            log_pred = model(on_device['image'], on_device['numeric'], on_device['categorical'])
            loss_sum += criterion(log_pred, on_device['log_target']).item()

            # Back to the original scale for the metric.
            pred_chunks.append(torch.expm1(log_pred))
            target_chunks.append(on_device['orig_target'])

    # Stitch the per-batch results together so R2 is computed globally.
    stacked_preds = torch.cat(pred_chunks, dim=0)
    stacked_targets = torch.cat(target_chunks, dim=0)
    r2_score = calculate_weighted_r2(stacked_targets, stacked_preds, device)

    return loss_sum / len(loader), r2_score


In [8]:
# --- 4. Main entry point (5-fold CV + early stopping) ---
def main(args):
    """
    Train a TeacherModel with 5-fold cross-validation and per-fold early stopping.

    For every fold a fresh model, loss and optimizer are created. Whenever the
    validation R2 improves, the model's state dict is written to
    ``<args.output_dir>/best_teacher_model_fold_<k>.pth``. After the last fold
    the per-fold best R2 scores, their mean and their standard deviation are
    printed.

    ``TeacherModel`` and ``WeightedMSELoss`` are expected to be defined
    elsewhere in the notebook.

    Args:
        args: namespace-like object providing ``data_csv``, ``img_dir``,
            ``img_size``, ``batch_size``, ``num_workers``, ``lr``, ``epochs``,
            ``output_dir`` and ``early_stopping_patience``.
    """
    # Select the device: CUDA first, then Apple MPS, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the data; the image path becomes the index (PastureDataset reads it via row.name).
    df = pd.read_csv(args.data_csv, index_col='image_path')

    # Category counts, passed to the model constructor (used for embeddings).
    # NOTE(review): nunique() equals the required embedding size only if the
    # encoded ids are contiguous 0..K-1 — confirm against the encoder.
    num_states = df['State_encoded'].nunique()
    num_species = df['Species_encoded'].nunique()
    print(f"Found {num_states} states and {num_species} species.")

    # Make sure the checkpoint directory exists before the first torch.save;
    # otherwise saving fails on a fresh environment. An empty output_dir means
    # "current directory" and needs no creation.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Image preprocessing for training.
    train_transforms = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),

        # --- 1. Geometric transforms (force the model to "look for" the target) ---
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(90),

        # Affine transform: translation and shear.
        transforms.RandomAffine(
            degrees=0,
            translate=(0.15, 0.15),  # random translation of up to 15%
            shear=15                 # random shear of up to 15 degrees
        ),

        # --- 2. Colour transforms (simulate different lighting / seasons) ---
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.1
        ),

        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The validation set gets no augmentation, only Resize and Normalize.
    val_transforms = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # --- K-fold cross-validation setup ---
    N_SPLITS = 5
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

    all_fold_best_r2 = []  # best validation R2 of every fold

    # Modules whose parameters are treated as the "head" (trained from scratch).
    # str.startswith accepts a tuple, so one call checks every prefix.
    head_param_prefixes = (
        'tab_mlp',
        'state_embedding',
        'species_embedding',
        'img_kv_projector',
        'tab_q_projector',
        'cross_attn',
        'attn_norm',
        'fusion_head',
    )

    # --- K-fold training loop ---
    for fold, (train_indices, val_indices) in enumerate(kf.split(df)):
        print(f"========== FOLD {fold + 1}/{N_SPLITS} ==========")

        # 1. Data for the current fold.
        train_df = df.iloc[train_indices]
        val_df = df.iloc[val_indices]

        # 2. Datasets and DataLoaders.
        train_dataset = PastureDataset(train_df, args.img_dir, train_transforms, args.img_size)
        val_dataset = PastureDataset(val_df, args.img_dir, val_transforms, args.img_size)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

        # 3. Re-initialise the model and loss so no state leaks between folds.
        model = TeacherModel(num_states, num_species).to(device)
        criterion = WeightedMSELoss()

        # 4. Differential learning rates: split the trainable parameters
        #    (requires_grad=True) into head and backbone groups.
        head_params = []
        backbone_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name.startswith(head_param_prefixes):
                head_params.append(param)
            else:
                backbone_params.append(param)

        # The backbone uses the base LR; the head uses 10x the base LR.
        param_groups = [
            {'params': backbone_params, 'lr': args.lr},
            {'params': head_params, 'lr': args.lr * 10}
        ]

        # The top-level lr is only a default; both groups override it.
        optimizer = optim.AdamW(param_groups,
                                lr=args.lr,
                                weight_decay=1e-3)

        # LR scheduler: 'max' mode because it is stepped with the validation R2.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=7, factor=0.1)

        # 5. Training loop for the current fold.
        best_val_r2 = -float('inf')
        patience_counter = 0  # epochs since the last R2 improvement

        for epoch in range(args.epochs):
            print(f"--- Fold {fold+1}, Epoch {epoch+1}/{args.epochs} ---")

            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_r2 = validate(model, val_loader, criterion, device)

            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val R2: {val_r2:.4f}")

            # Update the learning rate.
            scheduler.step(val_r2)

            # --- Early stopping and checkpointing ---
            if val_r2 > best_val_r2:
                best_val_r2 = val_r2
                patience_counter = 0  # reset patience

                # Save the best model for the current fold.
                save_path = os.path.join(args.output_dir, f"best_teacher_model_fold_{fold+1}.pth")
                torch.save(model.state_dict(), save_path)
                print(f"New best model for fold {fold+1} saved with R2: {best_val_r2:.4f}")
            else:
                patience_counter += 1
                print(f"No improvement. Patience: {patience_counter}/{args.early_stopping_patience}")

            # Stop this fold once patience is exhausted.
            if patience_counter >= args.early_stopping_patience:
                print(f"--- Early stopping triggered at epoch {epoch+1} ---")
                break

        print(f"Fold {fold+1} complete. Best Validation R2: {best_val_r2:.4f}")
        all_fold_best_r2.append(best_val_r2)
        print("=============================\n")

    # --- After all folds: report the per-fold scores and their mean / std ---
    print("\n--- K-Fold Cross-Validation Complete ---")
    print(f"R2 scores for each fold: {all_fold_best_r2}")
    print(f"Average R2: {np.mean(all_fold_best_r2):.4f}")
    print(f"Std Dev R2: {np.std(all_fold_best_r2):.4f}")