In [1]:
import os
import argparse
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import warnings
import sys

In [2]:
!pip install torch torchvision pandas scikit-learn pillow tqdm timm
!conda update torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
TOKEN = user_secrets.get_secret("GITHUB_TOKEN")
USERNAME = 'ada-yl2425'
REPO_NAME = 'CSIRO---Image2Biomass-Prediction'
!git clone https://{USERNAME}:{TOKEN}@github.com/{USERNAME}/{REPO_NAME}.git
!git pull origin main
!ls

Cloning into 'CSIRO---Image2Biomass-Prediction'...
remote: Enumerating objects: 710, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 710 (delta 0), reused 0 (delta 0), pack-reused 706 (from 2)[K
Receiving objects: 100% (710/710), 2.69 GiB | 47.25 MiB/s, done.
Resolving deltas: 100% (183/183), done.
Updating files: 100% (386/386), done.
fatal: not a git repository (or any parent up to mount point /kaggle)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
CSIRO---Image2Biomass-Prediction


In [4]:
# Put the cloned repository on the import path (once) so its packages can be imported below.
project_root = 'CSIRO---Image2Biomass-Prediction'
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

from KnowledgeDistillation.teacher_model import TeacherModel
from KnowledgeDistillation.student_model import StudentModel
from KnowledgeDistillation.loss import WeightedMSELoss, StudentLoss, calculate_weighted_r2
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from KnowledgeDistillation.data_transform import transForms



In [5]:
# 忽略 PIL 的一些警告
warnings.filterwarnings("ignore", "(Possibly corrupt EXIF data|Truncated File Read)")

In [6]:
# --- 2. Custom dataset ---
# (Same as in teacher_train; the student training loop needs all of the data.)
class PastureDataset(Dataset):
    """Dataset yielding an image plus tabular features and biomass targets per row.

    Args:
        df: DataFrame whose index holds the image path of each sample and whose
            columns include the numeric, categorical, log-target and raw-target
            columns named in ``__init__``.
        img_dir: Directory containing the image files; only the final path
            component of each index entry is looked up inside it.
        transforms: Callable applied to the RGB ``PIL.Image`` after loading.
        img_size: Side length of the all-zero placeholder image returned when
            loading or transforming an image fails.
    """

    def __init__(self, df, img_dir, transforms, img_size):
        self.df = df
        self.img_dir = img_dir
        self.transforms = transforms
        self.img_size = img_size

        # Tabular inputs for the teacher models (improvement 3).
        self.numeric_cols = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'month_sin', 'month_cos']
        self.categorical_cols = ['State_encoded', 'Species_encoded']

        # Raw-scale targets and their log-scale counterparts ('log_' + raw name).
        # The precomputed 'teacher_log_*' columns were dropped with improvement 3.
        raw_targets = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
        self.orig_target_cols = list(raw_targets)
        self.log_target_cols = ['log_' + col for col in raw_targets]

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def _columns_to_tensor(self, record, columns, np_dtype, torch_dtype):
        """Convert the selected columns of one row into a 1-D tensor."""
        values = record[columns].values.astype(np_dtype)
        return torch.tensor(values, dtype=torch_dtype)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'numeric', 'categorical', 'log_target', 'orig_target'."""
        record = self.df.iloc[idx]
        # The index stores a path; only its last component is used inside img_dir.
        img_path = os.path.join(self.img_dir, record.name.rsplit('/', 1)[-1])

        # 1. Load the image; any failure falls back to an all-zero placeholder.
        try:
            image = self.transforms(Image.open(img_path).convert('RGB'))
        except Exception as e:
            print(f"Warning: Error loading image {img_path}. Using a dummy image. Error: {e}")
            image = torch.zeros((3, self.img_size, self.img_size))

        # 2. Extract the tabular data (improvement 3) and the targets.
        return {
            'image': image,
            'numeric': self._columns_to_tensor(
                record, self.numeric_cols, np.float32, torch.float32),
            'categorical': self._columns_to_tensor(
                record, self.categorical_cols, np.int64, torch.long),
            'log_target': self._columns_to_tensor(
                record, self.log_target_cols, np.float32, torch.float32),
            'orig_target': self._columns_to_tensor(
                record, self.orig_target_cols, np.float32, torch.float32),
        }

In [7]:
# --- 3. Training and validation loops ---

def train_one_epoch_student(student_model, teacher_model, loader,
                            criterion, optimizer, device):
    """Run one distillation epoch and return the mean training loss per batch.

    Args:
        student_model: Model called as ``student_model(image)`` and returning
            ``(prediction, features)``; put into train mode here.
        teacher_model: Model called as ``teacher_model(image, numeric, categorical)``
            and returning ``(prediction, features)``. It is only run under
            ``torch.no_grad()``; this function does not change its train/eval
            mode, so the caller is expected to have called ``.eval()`` on it.
        loader: Iterable of batches (dicts of tensors) with the keys 'image',
            'numeric', 'categorical' and 'log_target'; must support ``len()``.
        criterion: Callable invoked as ``criterion(student_pred, teacher_pred,
            student_feat_flat, teacher_feat_flat, target_tensor, log_target)``.
        optimizer: Optimizer over the student's parameters.
        device: Device every tensor of the batch is moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    student_model.train()
    total_loss = 0.0

    for batch in tqdm(loader, desc="Training"):
        batch = {k: v.to(device) for k, v in batch.items()}

        image = batch['image']
        numeric = batch['numeric']
        categorical = batch['categorical']
        log_target = batch['log_target']

        # Zero the gradients (student only).
        optimizer.zero_grad()

        # 1. Teacher forward pass (soft labels + features); no gradients for the teacher.
        with torch.no_grad():
            teacher_pred, teacher_features = teacher_model(image, numeric, categorical)

        # 2. Student forward pass; the student only needs the image.
        student_pred, student_features = student_model(image)

        # 3. Align the feature tensors for the cosine-embedding term.
        # Assumes student_features is [B, Q, D] and teacher_features is [B, D]
        # (per the original shape notes) -- TODO confirm against the model code.
        # Q is read from the student's output instead of being hard-coded to 5.
        num_slots = student_features.shape[1]
        # [B, D] -> [B, 1, D] -> [B, Q, D]
        teacher_features_expanded = teacher_features.unsqueeze(1).expand(-1, num_slots, -1)
        # [B, Q, D] -> [B*Q, D]
        student_feat_flat = student_features.reshape(-1, student_features.shape[-1])
        teacher_feat_flat = teacher_features_expanded.reshape(-1, teacher_features_expanded.shape[-1])
        # One "+1" (similar) target per flattened feature row: shape [B*Q].
        target_tensor = torch.ones(student_feat_flat.shape[0], device=student_features.device)

        # 4. Distillation loss.
        loss = criterion(student_pred, teacher_pred, student_feat_flat,
                         teacher_feat_flat, target_tensor, log_target)

        # 5. Backpropagate and update the student's weights only.
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)

def validate_student(student_model, loader, criterion, device):
    """Evaluate the student on a loader; return ``(mean loss per batch, R2)``.

    The loss is ``criterion(pred_log, log_target)`` averaged over the batches.
    R2 comes from ``calculate_weighted_r2`` applied to the raw-scale targets and
    to ``expm1`` of the student's predictions, concatenated over all batches.
    """
    student_model.eval()  # evaluation mode for the student
    loss_sum = 0.0
    preds_per_batch, targets_per_batch = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            batch = {key: tensor.to(device) for key, tensor in batch.items()}

            log_pred, _ = student_model(batch['image'])
            loss_sum += criterion(log_pred, batch['log_target']).item()

            # Map predictions back to the original scale before collecting them.
            preds_per_batch.append(torch.expm1(log_pred))
            targets_per_batch.append(batch['orig_target'])

    # Concatenate the per-batch results along the batch dimension.
    stacked_preds = torch.cat(preds_per_batch, dim=0)
    stacked_targets = torch.cat(targets_per_batch, dim=0)

    # R2 is computed on the original scale.
    val_r2 = calculate_weighted_r2(stacked_targets, stacked_preds, device)

    return loss_sum / len(loader), val_r2

In [8]:
# --- 4. Main function (student K-fold CV) ---
def _build_student_optimizer(student_model, base_lr):
    """Create AdamW with the backbone at ``base_lr`` and the head modules at 10x that rate."""
    # Parameters whose names start with one of these prefixes count as "head".
    head_prefixes = (
        'patch_projector',
        'query_tokens',
        'transformer_decoder',
        'prediction_head',
    )
    head_params = []
    backbone_params = []
    for name, param in student_model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith(head_prefixes):
            head_params.append(param)
        else:
            backbone_params.append(param)

    param_groups = [
        {'params': backbone_params, 'lr': base_lr},
        {'params': head_params, 'lr': base_lr * 10},
    ]
    return optim.AdamW(param_groups, lr=base_lr, weight_decay=1e-3)


def _build_lr_scheduler(optimizer, total_epochs, warmup_epochs=5):
    """Create a linear warm-up followed by cosine annealing, stepped once per epoch."""
    warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    # Guard against a non-positive T_max when total_epochs <= warmup_epochs.
    cosine_epochs = max(1, total_epochs - warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=1e-7)
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])


def main(args):
    """Train the student by distillation with 5-fold cross-validation.

    For each fold, a fresh ``StudentModel`` is trained against the teacher
    checkpoint ``best_teacher_model_fold_{k}.pth`` from ``args.teacher_model_dir``.
    The student weights with the best validation R2 are saved to
    ``args.output_dir`` as ``best_student_model_fold_{k}.pth``. Training of a
    fold stops early after ``args.early_stopping_patience`` epochs without
    improvement. Per-fold best R2 values and their mean / standard deviation are
    printed at the end.

    Args:
        args: Namespace providing ``data_csv``, ``img_dir``, ``img_size``,
            ``batch_size``, ``num_workers``, ``teacher_model_dir``,
            ``output_dir``, ``lr``, ``epochs``, ``early_stopping_patience``,
            ``alpha``, ``beta`` and ``gamma``. It is also passed to ``transForms``.
    """
    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the data; the index holds the image path of each sample.
    df = pd.read_csv(args.data_csv, index_col='image_path')

    num_states = df['State_encoded'].nunique()
    num_species = df['Species_encoded'].nunique()
    print(f"Found {num_states} states and {num_species} species.")

    # Image preprocessing.
    train_transforms = transForms(args, num_bins=31,
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                        p=0.3, type='train')

    # NOTE(review): the validation pipeline is also built with type='train', so
    # validation images presumably receive the training augmentations. This looks
    # like a copy-paste slip; the values `transForms` accepts for `type` are not
    # visible here, so the call is left unchanged -- confirm and switch to the
    # evaluation variant.
    val_transforms = transForms(args, num_bins=31,
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                        p=0.3, type='train')

    # K-fold CV settings (these MUST match the ones used to train the teacher).
    N_SPLITS = 5
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
    all_fold_best_r2 = []

    # K-fold training loop.
    for fold, (train_indices, val_indices) in enumerate(kf.split(df)):
        print(f"========== FOLD {fold + 1}/{N_SPLITS} ==========\n")

        # 1. Build the datasets and loaders for this fold.
        train_df = df.iloc[train_indices]
        val_df = df.iloc[val_indices]
        train_dataset = PastureDataset(train_df, args.img_dir, train_transforms, args.img_size)
        val_dataset = PastureDataset(val_df, args.img_dir, val_transforms, args.img_size)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

        # 2. Fresh student model for every fold.
        student_model = StudentModel().to(device)

        # 2.5 Load the teacher trained on the same fold and put it in eval mode.
        teacher_model = TeacherModel(num_states, num_species).to(device)
        teacher_model_path = os.path.join(
            args.teacher_model_dir,
            f"best_teacher_model_fold_{fold+1}.pth"
        )
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
        teacher_model.eval()

        # 3. Loss functions: distillation loss for training, weighted MSE for validation.
        criterion_train = StudentLoss(alpha=args.alpha, beta=args.beta, gamma=args.gamma)
        criterion_val = WeightedMSELoss()

        # 4. Differential learning rates for the student, plus the LR schedule
        #    (5 warm-up epochs, then cosine annealing).
        optimizer = _build_student_optimizer(student_model, args.lr)
        scheduler = _build_lr_scheduler(optimizer, args.epochs, warmup_epochs=5)

        # 5. Training loop with early stopping on validation R2.
        best_val_r2 = -float('inf')
        patience_counter = 0

        for epoch in range(args.epochs):
            print(f"--- Fold {fold+1}, Epoch {epoch+1}/{args.epochs} ---")

            train_loss = train_one_epoch_student(
                student_model,
                teacher_model,
                train_loader,
                criterion_train,
                optimizer,
                device
            )

            val_loss, val_r2 = validate_student(
                student_model, val_loader, criterion_val, device
            )
            # Keep R2 as a plain float so np.mean / np.std below also work if the
            # metric helper returns a (possibly CUDA) tensor.
            val_r2 = float(val_r2)

            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val R2: {val_r2:.4f}")

            scheduler.step()

            if val_r2 > best_val_r2:
                best_val_r2 = val_r2
                patience_counter = 0
                save_path = os.path.join(args.output_dir, f"best_student_model_fold_{fold+1}.pth")
                torch.save(student_model.state_dict(), save_path)
                print(f"New best model for fold {fold+1} saved with R2: {best_val_r2:.4f}")
            else:
                patience_counter += 1
                print(f"No improvement. Patience: {patience_counter}/{args.early_stopping_patience}")

            if patience_counter >= args.early_stopping_patience:
                print(f"--- Early stopping triggered at epoch {epoch+1} ---")
                break

        print(f"Fold {fold+1} complete. Best Validation R2: {best_val_r2:.4f}")
        all_fold_best_r2.append(best_val_r2)
        print("=============================\n")


    print("\n--- Student K-Fold Cross-Validation Complete ---")
    print(f"R2 scores for each fold: {all_fold_best_r2}")
    print(f"Average R2: {np.mean(all_fold_best_r2):.4f}")
    print(f"Std Dev R2: {np.std(all_fold_best_r2):.4f}")

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Student Model via Distillation")

    # --- Paths (defaults resolved inside the cloned repository) ---
    for flag, relative in (
        ('--data_csv', 'outputs/datasets/train_processed.csv'),
        ('--img_dir', 'csiro-biomass/train'),
        ('--teacher_model_dir', 'outputs/models/teacher_model_output'),
    ):
        parser.add_argument(flag, type=str, default=os.path.join(project_root, relative))

    # Student output directory.
    # NOTE(review): this evaluates to the relative path 'kaggle/working' (resolved
    # against the current working directory), not the absolute '/kaggle/working'
    # -- confirm which one is intended.
    output_path = os.path.join('kaggle/', 'working')
    parser.add_argument('--output_dir', type=str,
                        default=output_path,
                        help='Directory to save the best student model')

    # --- Training hyper-parameters ---
    parser.add_argument('--img_size', type=int, default=260)

    # Same learning rate as used for teacher fine-tuning.
    parser.add_argument('--lr', type=float, default=5e-5,
                        help='Base learning rate (Backbone)')

    for flag, default in (
        ('--batch_size', 16),
        ('--epochs', 150),
        ('--num_workers', 2),
        ('--early_stopping_patience', 15),
    ):
        parser.add_argument(flag, type=int, default=default)

    # --- Distillation hyper-parameters ---
    for flag, default in (('--alpha', 0.5), ('--beta', 0.1), ('--gamma', 0.4)):
        parser.add_argument(flag, type=float, default=default)

    # ------------------------
    # Parse an empty argv so the notebook kernel's own command line is ignored
    # and every option takes its default.
    args = parser.parse_args(args=[])
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Student models will be saved to: {args.output_dir}")
    print(f"Reading data from: {args.data_csv}")

    main(args)

Student models will be saved to: kaggle/working
Reading data from: CSIRO---Image2Biomass-Prediction/outputs/datasets/train_processed.csv
Using device: cuda
Found 4 states and 15 species.



model.safetensors:   0%|          | 0.00/100M [00:00<?, ?B/s]

--- Fold 1, Epoch 1/150 ---


Training:   0%|          | 0/18 [00:00<?, ?it/s]