In [None]:
!pip install torch torchvision pandas scikit-learn pillow tqdm timm

In [3]:
import os
import argparse
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import warnings

In [None]:
import sys
import os 
import argparse 

# Repository root, relative to the notebook's working directory. It is also the
# base for the default data/output paths used by the argument parser below.
project_root = 'CSIRO---Image2Biomass-Prediction'
# Register the root on the import path only once, so re-running this cell does
# not keep appending duplicate entries to sys.path.
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

from KnowledgeDistillation.teacher_model import TeacherModel
from KnowledgeDistillation.loss import WeightedMSELoss, calculate_weighted_r2

In [None]:
warnings.filterwarnings("ignore", "(Possibly corrupt EXIF data|Truncated File Read)")

In [None]:
class PastureDataset(Dataset):
    """Image + tabular dataset for pasture biomass regression.

    Each item pairs an RGB image with numeric and categorical tabular features,
    plus five biomass targets on the log scale (used as the training target)
    and on the original scale (used for validation metrics).

    Args:
        df: DataFrame indexed by image path (e.g. ``'train/xxx.jpg'``). Only the
            last ``'/'``-separated component of each index label is used as the
            file name inside ``img_dir``. It must contain every column listed in
            ``numeric_cols``, ``categorical_cols``, ``log_target_cols`` and
            ``orig_target_cols``.
        img_dir: Directory that holds the image files.
        transforms: Callable applied to the RGB PIL image. It should return a
            tensor of shape ``(3, img_size, img_size)`` so that real images
            collate with the zero placeholder used on load failure.
        img_size: Side length of that zero placeholder image.
    """

    def __init__(self, df, img_dir, transforms, img_size):
        self.df = df
        self.img_dir = img_dir
        self.transforms = transforms
        self.img_size = img_size

        # Tabular feature columns.
        self.numeric_cols = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'month_sin', 'month_cos']
        self.categorical_cols = ['State_encoded', 'Species_encoded']

        # Training targets (log scale).
        self.log_target_cols = ['log_Dry_Green_g', 'log_Dry_Dead_g',
                                'log_Dry_Clover_g', 'log_GDM_g', 'log_Dry_Total_g']

        # Validation targets (original scale).
        self.orig_target_cols = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g',
                                 'GDM_g', 'Dry_Total_g']

        # Materialise every column group as a typed array once, instead of
        # slicing and re-casting a pandas row on every __getitem__ call. This is
        # much cheaper per item, and a missing column fails here at construction
        # time rather than later inside a DataLoader worker.
        self._filenames = [label.split('/')[-1] for label in df.index]
        self._numeric = df[self.numeric_cols].to_numpy(dtype=np.float32)
        self._categorical = df[self.categorical_cols].to_numpy(dtype=np.int64)
        self._log_targets = df[self.log_target_cols].to_numpy(dtype=np.float32)
        self._orig_targets = df[self.orig_target_cols].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 1. Load the image.
        img_path = os.path.join(self.img_dir, self._filenames[idx])

        try:
            # The context manager closes the file handle deterministically.
            # convert() fully decodes the pixels before the file is closed.
            with Image.open(img_path) as img:
                rgb = img.convert('RGB')
            image = self.transforms(rgb)
        except Exception as e:
            # Deliberate best-effort: keep the epoch running on unreadable files,
            # but make the substitution visible. The placeholder still carries
            # the real targets, so frequent failures would bias training.
            print(f"Warning: Error loading image {img_path}. Using a dummy image. Error: {e}")
            image = torch.zeros((3, self.img_size, self.img_size))

        # 2. Tabular features and 3. targets. torch.tensor copies, so callers can
        # never mutate the cached arrays through the returned tensors.
        return {
            'image': image,
            'numeric': torch.tensor(self._numeric[idx], dtype=torch.float32),
            'categorical': torch.tensor(self._categorical[idx], dtype=torch.long),
            'log_target': torch.tensor(self._log_targets[idx], dtype=torch.float32),
            'orig_target': torch.tensor(self._orig_targets[idx], dtype=torch.float32),
        }

In [None]:
# --- 3. Training and validation loops ---

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch and return the mean training loss.

    Each batch is moved to ``device`` and passed through
    ``model(image, numeric, categorical)``. The predictions are scored against
    the log-scale targets with ``criterion``, and ``optimizer`` applies one
    update per batch.

    Args:
        model: Module called as ``model(image, numeric, categorical)``.
        loader: Iterable of batch dicts with keys ``'image'``, ``'numeric'``,
            ``'categorical'`` and ``'log_target'``.
        criterion: Loss callable ``criterion(pred, log_target)``. It is assumed
            to return the mean over the batch (TODO confirm for WeightedMSELoss).
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device for the batch tensors.

    Returns:
        The loss averaged over samples: each batch's loss is weighted by its
        size, so a smaller final batch does not skew the epoch mean. Returns
        ``float('nan')`` if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0

    for batch in tqdm(loader, desc="Training"):
        # Move the batch to the target device.
        image = batch['image'].to(device)
        numeric = batch['numeric'].to(device)
        categorical = batch['categorical'].to(device)
        log_target = batch['log_target'].to(device)

        optimizer.zero_grad()

        # Forward pass; the loss is computed on the log scale.
        pred = model(image, numeric, categorical)
        loss = criterion(pred, log_target)

        loss.backward()
        optimizer.step()

        batch_size = log_target.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    # An empty loader has no meaningful mean; avoid ZeroDivisionError.
    return loss_sum / n_samples if n_samples else float('nan')

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(image, numeric, categorical)``. Its
            outputs are on the log scale.
        loader: Iterable of batch dicts with keys ``'image'``, ``'numeric'``,
            ``'categorical'``, ``'log_target'`` and ``'orig_target'``.
        criterion: Loss callable ``criterion(pred_log, log_target)``. It is
            assumed to return the batch mean (TODO confirm for WeightedMSELoss).
        device: Target device for the batch tensors.

    Returns:
        ``(avg_val_loss, val_r2)``. ``avg_val_loss`` is the log-scale loss
        averaged over samples. ``val_r2`` is whatever ``calculate_weighted_r2``
        returns for the original-scale targets versus ``expm1(pred_log)``.
        ``expm1`` presumes the ``log_*`` targets were produced with ``log1p``;
        verify against the preprocessing.

    Raises:
        ValueError: If ``loader`` yields no batches, since R2 is undefined then.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    preds_orig = []
    targets_orig = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            image = batch['image'].to(device)
            numeric = batch['numeric'].to(device)
            categorical = batch['categorical'].to(device)
            log_target = batch['log_target'].to(device)
            orig_target = batch['orig_target'].to(device)

            # Prediction on the log scale; the validation loss is also log scale.
            pred_log = model(image, numeric, categorical)
            loss = criterion(pred_log, log_target)

            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = log_target.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            # Map predictions back to the original scale for the R2 metric.
            preds_orig.append(torch.expm1(pred_log))
            targets_orig.append(orig_target)

    if n_samples == 0:
        raise ValueError("validate() received an empty loader; cannot compute R2.")

    # Concatenate all batches and score on the original scale.
    all_preds_orig = torch.cat(preds_orig, dim=0)
    all_targets_orig = torch.cat(targets_orig, dim=0)
    val_r2 = calculate_weighted_r2(all_targets_orig, all_preds_orig, device)

    return loss_sum / n_samples, val_r2


In [None]:
def _build_transforms(img_size):
    """Return ``(train_transforms, val_transforms)`` for square ``img_size`` inputs.

    Training applies geometric and colour augmentation. Validation only resizes
    and normalises. Both normalise with the ImageNet mean/std.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),

        # --- 1. Geometric transforms ---
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(90),

        # Affine transform: translation and shear (rotation is handled above).
        transforms.RandomAffine(
            degrees=0,
            translate=(0.15, 0.15),  # random translation up to 15%
            shear=15,                # random shear up to 15 degrees
        ),

        # --- 2. Colour transforms (simulate different lighting / seasons) ---
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),

        transforms.ToTensor(),
        normalize,
    ])

    # The validation set gets no augmentation: Resize + Normalize only.
    val_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transforms, val_transforms


def _build_optimizer(model, base_lr, head_lr_mult=10.0, weight_decay=1e-3):
    """Create AdamW with differential learning rates (backbone vs. head).

    Trainable parameters whose name starts with one of the head prefixes are
    trained from scratch and get ``base_lr * head_lr_mult``. All other trainable
    parameters (the pretrained backbone) get ``base_lr``. Frozen parameters
    (``requires_grad=False``) are skipped.
    """
    head_prefixes = (
        'tab_mlp',
        'state_embedding',
        'species_embedding',
        'img_kv_projector',
        'tab_q_projector',
        'cross_attn',
        'attn_norm',
        'fusion_head',
    )

    head_params = []
    backbone_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # str.startswith accepts a tuple: True if any prefix matches.
        if name.startswith(head_prefixes):
            head_params.append(param)
        else:
            backbone_params.append(param)

    param_groups = [
        {'params': backbone_params, 'lr': base_lr},
        {'params': head_params, 'lr': base_lr * head_lr_mult},
    ]
    return optim.AdamW(param_groups, lr=base_lr, weight_decay=weight_decay)


def _train_fold(fold, model, train_loader, val_loader, criterion, optimizer, args, device):
    """Train one fold with LR-on-plateau and early stopping; return its best validation R2.

    Every time validation R2 improves, the weights are saved to
    ``<args.output_dir>/best_teacher_model_fold_<fold+1>.pth``. Training stops
    after ``args.early_stopping_patience`` epochs without improvement.
    """
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=7, factor=0.1)

    best_val_r2 = -float('inf')
    patience_counter = 0  # epochs since the last improvement

    for epoch in range(args.epochs):
        print(f"--- Fold {fold+1}, Epoch {epoch+1}/{args.epochs} ---")

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_r2 = validate(model, val_loader, criterion, device)
        # calculate_weighted_r2 may return a (possibly CUDA) tensor. Normalise it
        # to a Python float so comparisons, the scheduler and the final
        # np.mean / np.std over folds all work.
        val_r2 = float(val_r2)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val R2: {val_r2:.4f}")

        # Lower the LR when validation R2 plateaus.
        scheduler.step(val_r2)

        # --- Early stopping and checkpointing ---
        if val_r2 > best_val_r2:
            best_val_r2 = val_r2
            patience_counter = 0

            save_path = os.path.join(args.output_dir, f"best_teacher_model_fold_{fold+1}.pth")
            torch.save(model.state_dict(), save_path)
            print(f"New best model for fold {fold+1} saved with R2: {best_val_r2:.4f}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{args.early_stopping_patience}")

        if patience_counter >= args.early_stopping_patience:
            print(f"--- Early stopping triggered at epoch {epoch+1} ---")
            break  # leave this fold's epoch loop

    return best_val_r2


def main(args):
    """Train the teacher model with K-fold cross-validation and report per-fold R2.

    ``args`` must provide ``data_csv``, ``img_dir``, ``output_dir``,
    ``img_size``, ``lr``, ``batch_size``, ``epochs``, ``num_workers`` and
    ``early_stopping_patience``. The optional ``seed`` (default 42) and
    ``n_splits`` (default 5) are read with ``getattr`` so older argument
    namespaces keep working.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # Seed the RNGs so weight init, shuffling and augmentation are reproducible.
    seed = getattr(args, 'seed', 42)
    torch.manual_seed(seed)
    np.random.seed(seed)

    df = pd.read_csv(args.data_csv, index_col='image_path')

    n_state_codes = df['State_encoded'].nunique()
    n_species_codes = df['Species_encoded'].nunique()
    print(f"Found {n_state_codes} states and {n_species_codes} species.")
    # Size the lookups by the largest code + 1, not just the number of distinct
    # codes. With non-contiguous or 1-based codes, nunique() would be too small
    # and an index-based lookup (presumably nn.Embedding) would go out of range.
    # For codes 0..n-1 both values agree.
    num_states = max(n_state_codes, int(df['State_encoded'].max()) + 1)
    num_species = max(n_species_codes, int(df['Species_encoded'].max()) + 1)

    train_transforms, val_transforms = _build_transforms(args.img_size)

    # --- K-fold cross-validation setup ---
    n_splits = getattr(args, 'n_splits', 5)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    all_fold_best_r2 = []  # best validation R2 of each fold

    for fold, (train_indices, val_indices) in enumerate(kf.split(df)):
        print(f"========== FOLD {fold + 1}/{n_splits} ==========")

        train_dataset = PastureDataset(df.iloc[train_indices], args.img_dir, train_transforms, args.img_size)
        val_dataset = PastureDataset(df.iloc[val_indices], args.img_dir, val_transforms, args.img_size)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

        # Fresh model, loss and optimizer per fold so no weights leak across folds.
        model = TeacherModel(num_states, num_species).to(device)
        criterion = WeightedMSELoss()
        optimizer = _build_optimizer(model, args.lr)

        best_val_r2 = _train_fold(fold, model, train_loader, val_loader, criterion, optimizer, args, device)

        print(f"Fold {fold+1} complete. Best Validation R2: {best_val_r2:.4f}")
        all_fold_best_r2.append(best_val_r2)
        print("=============================\n")

        # Release this fold's model/optimizer before the next fold builds new ones.
        del model, optimizer

    # --- After all folds: report mean and spread of the per-fold R2 ---
    print("\n--- K-Fold Cross-Validation Complete ---")
    print(f"R2 scores for each fold: {all_fold_best_r2}")
    print(f"Average R2: {np.mean(all_fold_best_r2):.4f}")
    print(f"Std Dev R2: {np.std(all_fold_best_r2):.4f}")

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Teacher Model")

    # --- Data locations (defaults are relative to project_root) ---
    parser.add_argument('--data_csv', type=str,
                        default=os.path.join(project_root, 'outputs/datasets/train_processed.csv'),
                        help='Path to the processed training CSV file')

    parser.add_argument('--img_dir', type=str,
                        default=os.path.join(project_root, 'csiro-biomass/train'),
                        help='Path to the directory containing training images')

    # Dedicated output directory for the per-fold best checkpoints.
    output_path = os.path.join(project_root, 'outputs/models/teacher_model_output')
    parser.add_argument('--output_dir', type=str,
                        default=output_path,
                        help='Directory to save the best model of each fold')

    # --- Training hyper-parameters ---
    parser.add_argument('--img_size', type=int, default=260,
                        help='Image size for the model (B2 uses 260)')
    parser.add_argument('--lr', type=float, default=5e-5,
                        help='Base learning rate for the backbone; head layers are trained with 10x this value')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size (use 8 or 16 for small datasets)')
    parser.add_argument('--epochs', type=int, default=150,
                        help='Maximum number of training epochs per fold')
    parser.add_argument('--val_split', type=float, default=0.2,
                        help='Unused: validation is done with K-fold cross-validation')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='Number of workers for DataLoader')

    # --- Early stopping ---
    parser.add_argument('--early_stopping_patience', type=int, default=15,
                        help='Epochs without validation-R2 improvement before a fold stops early')

    # Parse an explicit empty list, so the defaults above always apply.
    # Presumably this keeps argparse away from the notebook kernel's own
    # launcher arguments in sys.argv; edit the defaults above to change settings.
    args = parser.parse_args(args=[])

    # Make sure the output directory exists before any checkpoint is written.
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Model output will be saved to: {args.output_dir}")
    print(f"Reading data from: {args.data_csv}")

    main(args)