# CSIRO Biomass Estimation

This notebook implements a complete pipeline for biomass estimation using:
- **Step 1**: Benchmark multiple image-model architectures from timm (CNNs and vision transformers)
- **Step 2**: K-Fold training with the best model
- **Step 3**: Ensemble inference on test set
- **Step 4**: Create submission file

## 1. Imports and Setup

In [None]:
# Standard library
import os
import random
from pathlib import Path

# Data science / imaging / display
import numpy as np
import pandas as pd
from PIL import Image
import joblib

# Scikit / utilities
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# Third-party model libs
import timm

# Warnings
import warnings
warnings.filterwarnings("ignore")

## 2. Configuration

In [None]:
# ========== CONFIG ==========
# Notebook-wide settings; the cells below read these names as globals.
SEED = 42                                    # seeds pl.seed_everything, train_test_split and KFold
INPUT = Path("/kaggle/input/csiro-biomass")  # train.csv / test.csv and the train/ and test/ image folders
WORK = Path("/kaggle/working")               # scaler, checkpoints, logs and submission.csv are written here

# ========== Image and Training Hyperparameters ==========
IMG = 224                   # images are resized to IMG x IMG by both transform pipelines
BATCH = 16                  # default batch size for BiomassDataModule and the test DataLoader
LEARNING_RATE = 1e-4        # default AdamW learning rate when BiomassModel gets lr=None
WEIGHT_DECAY = 1e-2         # default AdamW weight decay when BiomassModel gets weight_decay=None

# ========== LAYER FREEZING CONFIGURATION ==========
# BENCHMARK:
BENCHMARK_FREEZE_MODE = "percentage"      # Options: "only_head" | "percentage" | "none"
BENCHMARK_TRAINABLE_PCT = 0.20           # Only used when mode="percentage"

# K-FOLD: 
KFOLD_FREEZE_MODE = "percentage"         # Options: "only_head" | "percentage" | "none"
KFOLD_TRAINABLE_PCT = 0.30               # Trainable % when mode="percentage"

# ========== Benchmark Hyperparameters ==========
QUICK_EPOCHS = 10     # max_epochs of the benchmark Trainer
QUICK_PATIENCE = 1    # EarlyStopping patience (on val_r2) during the benchmark

# ========== K-Fold Hyperparameters ==========
KFOLD = 5                   # number of KFold splits
EPOCHS = 50                 # max_epochs of each fold's Trainer
MAX_PATIENCE = 5            # EarlyStopping patience (on val_r2) for each fold
LR_PATIENCE = 2             # ReduceLROnPlateau patience
LR_FACTOR = 0.5             # ReduceLROnPlateau multiplicative factor

# ========== Target and Weights ==========
# TARGET_COLS fixes the column order of every target / prediction array in the notebook.
TARGET_COLS = ["Dry_Clover_g","Dry_Dead_g","Dry_Green_g","Dry_Total_g","GDM_g"]
# Per-target weights used by the weighted R² computed in BiomassModel.on_validation_epoch_end.
WEIGHTS = {"Dry_Clover_g":0.1,"Dry_Dead_g":0.1,"Dry_Green_g":0.1,"Dry_Total_g":0.5,"GDM_g":0.2}

# ========== Models to Test ==========
# Specify exact timm name and weights path
# Each entry supplies the keys BiomassModel reads: "name", "weights" and "num_classes".
# Commented-out entries are skipped by the benchmark loop.
MODELS = [
    #{"name": "convnext_base.fb_in22k_ft_in1k", "weights": "/kaggle/input/convnext-base-fb-in22k-ft-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    #{"name": "convnext_tiny.fb_in22k_ft_in1k", "weights": "/kaggle/input/convnext-tiny-fb-in22k-ft-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    #{"name": "convnextv2_base.fcmae_ft_in22k_in1k", "weights": "/kaggle/input/convnextv2-base-fcmae-ft-in22k-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    {"name": "maxvit_tiny_tf_224.in1k", "weights": "/kaggle/input/maxvit-tiny-tf-224-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    #{"name": "swin_base_patch4_window7_224.ms_in22k_ft_in1k", "weights": "/kaggle/input/swin-base-patch4-window7-224-ms-in22k-ft-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    #{"name": "swin_tiny_patch4_window7_224.ms_in1k", "weights": "/kaggle/input/swin-tiny-patch4-window7-224-ms-in1k/pytorch/default/1/model.safetensors", "num_classes": 5},
    #{"name": "vit_base_patch16_224.augreg_in21k_ft_in1k", "weights": "/kaggle/input/vit-base-patch16-224-augreg-in21k-ft-in1k/pytorch/default/1/model.safetensors", "num_classes": 5}
]

# Set seed for reproducibility
pl.seed_everything(SEED, workers=True)

## 3. Dataset Class and Transforms

This class loads images and their target values (5 biomass types):
- **Train mode**: Reads images from CSV and applies augmentation transforms
- **Test mode**: Only reads images for prediction
- Normalizes targets using StandardScaler to facilitate training

In [None]:
class BiomassDataset(Dataset):
    """Image dataset for the biomass targets.

    In ``mode='train'`` ``data`` is a DataFrame with an ``image_path`` column
    and the ``TARGET_COLS`` columns; items are ``(image, scaled_targets)``.
    In any other mode ``data`` is an indexable collection of image paths and
    items are ``(image, file_stem)``.
    """

    def __init__(self, data, transform=None, scaler=None, mode='train'):
        self.transform = transform
        self.scaler = scaler
        self.mode = mode

        if mode != 'train':
            # Inference: only the paths are needed.
            self.paths = data
            return

        self.df = data.reset_index(drop=True)
        # Targets are scaled once up front so __getitem__ only has to index.
        raw_targets = data[TARGET_COLS].values.astype(np.float32)
        self.targets = self.scaler.transform(raw_targets)

    def __len__(self):
        if self.mode == 'train':
            return len(self.df)
        return len(self.paths)

    def _load_image(self, path):
        """Open ``path`` as RGB and apply the transform when one is set."""
        image = Image.open(path).convert("RGB")
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, i):
        if self.mode != 'train':
            path = self.paths[i]
            return self._load_image(path), Path(path).stem

        image = self._load_image(self.df.iloc[i]["image_path"])
        return image, torch.tensor(self.targets[i], dtype=torch.float32)

### Image Transform Functions

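**get_train_transform()**: Augmentation transforms for training
- Resize to 224x224 (`IMG` × `IMG`)
- Random rotation (±5°)
- Random horizontal/vertical flips, affine jitter, and color jitter (brightness, contrast, saturation, hue) are present in the code but currently commented out
- ImageNet mean/std normalization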

**get_val_transform()**: No augmentation for validation/test
- Only resize and normalize
- Ensures consistent evaluation

In [None]:
# Define transforms as functions to ensure IMG is defined
def get_train_transform():
    """Build the training pipeline: resize, small random rotation, tensor, normalize.

    Built inside a function so that ``IMG`` is looked up at call time.
    """
    steps = [
        transforms.Resize((IMG, IMG)),
        transforms.RandomRotation(degrees=5),
        # Disabled augmentations, kept for experimentation:
        #transforms.RandomAffine(degrees=5, translate=(0.01, 0.01), scale=(0.99, 1.01)),
        #transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomVerticalFlip(p=0.5),
        #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def get_val_transform():
    """Build the deterministic pipeline used for validation and test: resize, tensor, normalize."""
    steps = [
        transforms.Resize((IMG, IMG)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

## 4. Lightning Module with Flexible Layer Freezing

**Three Freezing Modes**:

| Mode | Frozen | Trainable | Use Case |
|------|--------|-----------|----------|
| **only_head** | Everything except last layer | Head only (1 layer) | Fast benchmark |
| **percentage** | All but X% | Last N% of layers | Fine-tuning balance |
| **none** | Nothing | All layers (100%) | Full training |

**`__init__`**: Initializes model with timm architecture and optional layer freezing
- **lr parameter**: Uses `LEARNING_RATE` from CONFIG if not provided
- **weight_decay parameter**: Uses `WEIGHT_DECAY` from CONFIG if not provided
- Loads pretrained weights from safetensors
- Applies freezing according to `freeze_mode` parameter
- Filters classifier layers (incompatible with 5 classes)

**`_apply_freezing(freeze_mode, trainable_pct)`**: Flexible layer freezing
- **"only_head"**: Keeps only the last classification layer trainable (backbone frozen)
- **"percentage"**: Freezes all but trainable_pct% of layers (e.g., 0.30 = last 30% trainable)
- **"none"**: All layers trainable (no freezing)
- Prints statistics showing frozen vs trainable layer count

**`training_step`**: Computes MSE loss for each training batch

**`validation_step`**: Stores predictions to compute R² at epoch end

**`on_validation_epoch_end`**: Computes weighted R²
- Inverse transforms predictions and targets
- Applies weights: Dry_Total_g=50%, GDM_g=20%, others=10%

**`configure_optimizers`**: AdamW + ReduceLROnPlateau
- Uses learning rate and weight decay from __init__
- Reduces learning rate if R² does not improve (LR_FACTOR, LR_PATIENCE from CONFIG)


In [None]:
class BiomassModel(pl.LightningModule):
    """timm backbone trained with MSE on scaled targets and monitored with a weighted R².

    Args:
        model_info: dict with the keys ``name`` (timm model name), ``weights``
            (path to a safetensors file) and ``num_classes`` (number of outputs).
        scaler: fitted scaler exposing ``inverse_transform``; used to bring
            predictions and targets back to the original units for the R² metric.
        lr: learning rate; falls back to ``LEARNING_RATE`` when ``None``.
        weight_decay: AdamW weight decay; falls back to ``WEIGHT_DECAY`` when ``None``.
        load_pretrained: when ``False`` the safetensors file is not read
            (the caller is expected to load weights itself).
        freeze_mode: ``"only_head"``, ``"percentage"`` or ``"none"``.
        trainable_pct: fraction (0.0-1.0) of parameter tensors left trainable
            when ``freeze_mode == "percentage"``.
    """

    def __init__(self, model_info, scaler, lr=None, weight_decay=None, load_pretrained=True, freeze_mode="none", trainable_pct=1.0):
        super().__init__()
        # Use CONFIG values if not provided
        if lr is None:
            lr = LEARNING_RATE
        if weight_decay is None:
            weight_decay = WEIGHT_DECAY

        # scaler and model_info are not stored in the checkpoint, so they must be
        # passed again to load_from_checkpoint.
        self.save_hyperparameters(ignore=['scaler', 'model_info'])
        self.model_name = model_info['name']
        self.scaler = scaler
        self.loss_fn = nn.MSELoss()
        self.validation_step_outputs = []
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_targets = model_info['num_classes']
        self.weights_path = model_info['weights']
        self.load_pretrained = load_pretrained
        self.freeze_mode = freeze_mode          # "only_head" | "percentage" | "none"
        self.trainable_pct = trainable_pct      # 0.0 to 1.0
        self.model = self._load_model()
        # Flag read by _apply_freezing to decide whether to print its summary.
        # Freezing runs right below, so changing it after construction has no
        # effect on that print.
        self.eval_mode = False

        # Apply layer freezing after model creation
        if self.freeze_mode != "none":
            self._apply_freezing(self.freeze_mode, self.trainable_pct)

    def _load_model(self):
        """Create the timm model and, if requested, load every compatible pretrained tensor."""
        # Create model architecture with safe defaults
        model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_targets,
            scriptable=True,  # Disable dynamic features
            exportable=True   # Use more stable implementations
        )

        # Skip weight loading if load_pretrained=False (K-Fold will load from checkpoint)
        if not self.load_pretrained:
            return model

        # Load safetensors weights
        from safetensors.torch import load_file
        state_dict = load_file(self.weights_path)

        # Keep only tensors that exist in the new model with the same shape.
        # A name-substring filter ('classifier'/'head'/'fc') would also throw away
        # backbone weights such as "...mlp.fc1.weight", leaving them randomly
        # initialised without any error because strict=False hides missing keys.
        model_state = model.state_dict()
        compatible = {
            key: tensor
            for key, tensor in state_dict.items()
            if key in model_state and tensor.shape == model_state[key].shape
        }
        skipped = len(state_dict) - len(compatible)
        print(f"Pretrained weights: loaded {len(compatible)} tensors, skipped {skipped} (missing or shape mismatch)")

        # strict=False: the tensors skipped above keep their fresh initialisation
        model.load_state_dict(compatible, strict=False)

        return model

    def _head_parameters(self):
        """Return the parameters of the classification head.

        Uses the model's ``get_classifier()`` when it returns a module with
        parameters; otherwise falls back to the last parameter tensor.
        """
        get_classifier = getattr(self.model, "get_classifier", None)
        head = get_classifier() if callable(get_classifier) else None
        if isinstance(head, nn.Module):
            head_params = list(head.parameters())
            if head_params:
                return head_params
        return list(self.model.parameters())[-1:]

    def _apply_freezing(self, freeze_mode, trainable_pct):
        """
        Apply layer freezing based on mode
        - "only_head": Freeze everything except the classification head (weight and bias)
        - "percentage": Freeze all but the last trainable_pct of the parameter tensors
        - "none": Train all layers

        Raises:
            ValueError: if freeze_mode is not one of the three modes above.
        """
        if freeze_mode == "none":
            # Train all layers
            for param in self.model.parameters():
                param.requires_grad = True
            print(f"All layers trainable (100%)")
            return

        params = list(self.model.parameters())
        total_params = len(params)

        if freeze_mode == "only_head":
            # Unfreeze the whole head module; keeping just the last parameter
            # tensor would typically train the head bias alone.
            trainable = self._head_parameters()
        elif freeze_mode == "percentage":
            # Number of trailing parameter tensors to keep trainable, clamped to
            # [1, total] so the slice below is always well-formed.
            num_trainable = min(total_params, max(1, int(total_params * trainable_pct)))
            trainable = params[total_params - num_trainable:]
        else:
            raise ValueError(f"Unknown freeze_mode: {freeze_mode}")

        # Compare by identity: tensors do not support plain == membership tests
        trainable_ids = {id(p) for p in trainable}
        for param in params:
            param.requires_grad = id(param) in trainable_ids

        # Only calculate and print freezing info during training, not during inference
        if not self.eval_mode:
            frozen_count = sum(1 for p in self.model.parameters() if not p.requires_grad)
            trainable_count = sum(1 for p in self.model.parameters() if p.requires_grad)
            print(f"Frozen {frozen_count}/{total_params} layers | Trainable {trainable_count}/{total_params} ({trainable_count/total_params*100:.1f}%)")

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss (in scaled target space) for one batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash targets/predictions for the epoch-level R²."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)
        output = {'y_true': y, 'y_pred': y_hat}
        self.validation_step_outputs.append(output)
        return output

    def on_validation_epoch_end(self):
        """Compute the weighted R² over all stored validation batches and log it as ``val_r2``."""
        outputs = self.validation_step_outputs
        y_true = torch.cat([x['y_true'] for x in outputs]).cpu().numpy()
        y_pred = torch.cat([x['y_pred'] for x in outputs]).cpu().numpy()

        # Inverse transform
        y_true = self.scaler.inverse_transform(y_true)
        y_pred = self.scaler.inverse_transform(y_pred)

        # Calculate weighted R2: all targets are flattened into one vector and each
        # value carries the weight of its target column.
        all_y, all_p, all_w = [], [], []
        for i, t in enumerate(TARGET_COLS):
            w = WEIGHTS[t]
            all_y.append(y_true[:, i])
            all_p.append(y_pred[:, i])
            all_w.append(np.full(len(y_true), w))

        y_all = np.concatenate(all_y)
        p_all = np.concatenate(all_p)
        w_all = np.concatenate(all_w)
        y_mean = np.average(y_all, weights=w_all)
        ss_res = np.sum(w_all * (y_all - p_all)**2)
        ss_tot = np.sum(w_all * (y_all - y_mean)**2)
        r2 = 1 - ss_res / ss_tot

        self.log('val_r2', r2, prog_bar=True)
        self.validation_step_outputs.clear()  # Free memory
        return r2

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau stepping on ``val_r2`` (higher is better)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
        # releases; confirm against the installed version before upgrading.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=LR_FACTOR,
            patience=LR_PATIENCE,
            verbose=True
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_r2'
            }
        }

## 5. Data Module (PyTorch Lightning)

Encapsulates dataloader preparation with configurable batch size:
- **setup()**: Creates train/val datasets with appropriate transforms
- **train_dataloader()**: Returns training loader with shuffle=True
- **val_dataloader()**: Returns validation loader with shuffle=False
- **batch_size** parameter: Uses `BATCH` from CONFIG if not provided
- **use_augmentation** parameter: Controls if augmentation is applied to training data


In [None]:
class BiomassDataModule(pl.LightningDataModule):
    """Builds the train/validation datasets and loaders for one split.

    ``batch_size`` falls back to ``BATCH`` when ``None``. With
    ``use_augmentation=False`` the training set uses the validation transform.
    """

    def __init__(self, train_df, val_df, scaler, batch_size=None, use_augmentation=True):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.scaler = scaler
        # Use CONFIG value if not provided
        self.batch_size = BATCH if batch_size is None else batch_size
        self.use_augmentation = use_augmentation

    def setup(self, stage=None):
        # Transforms are built here (not at import time) so IMG is already defined
        if self.use_augmentation:
            train_transform = get_train_transform()
        else:
            train_transform = get_val_transform()
        val_transform = get_val_transform()

        # Both splits carry targets, hence mode='train' for the validation set too
        self.train_ds = BiomassDataset(self.train_df, train_transform, self.scaler, mode='train')
        self.val_ds = BiomassDataset(self.val_df, val_transform, self.scaler, mode='train')

    def _loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._loader(self.train_ds, True)

    def val_dataloader(self):
        return self._loader(self.val_ds, False)

## 6. Data Loading Function

**load_and_prepare_data()**: Processes the training CSV
1. Reads train.csv and pivots: one row per image, 5 target columns
2. Fills missing targets with 0.0
3. Fixes paths: looks for images in /train or /test
4. Filters images that actually exist on disk
5. Returns DataFrame with columns: image_path, Dry_Clover_g, Dry_Dead_g, Dry_Green_g, Dry_Total_g, GDM_g

In [None]:
def load_and_prepare_data():
    """Read ``train.csv`` and return one row per image with the five target columns.

    Steps:
        1. Pivot the long CSV (``image_path``, ``target_name``, ``target``) to wide form.
        2. Fill missing targets with 0.0. This covers both a target column that is
           absent entirely and individual missing cells for an image.
        3. Resolve each image path against ``INPUT/train`` and ``INPUT/test``.
        4. Drop rows whose image file does not exist on disk.

    Returns:
        DataFrame with columns ``image_path`` followed by ``TARGET_COLS``.
    """
    df = pd.read_csv(INPUT / "train.csv")
    wide = df.pivot_table(index="image_path", columns="target_name", values="target", aggfunc="first").reset_index()

    # Target columns that never occur in the CSV
    for t in TARGET_COLS:
        if t not in wide.columns:
            wide[t] = 0.0

    # The pivot leaves NaN where an image lacks a row for some target; NaN targets
    # would pass through the scaler and turn the training loss into NaN.
    wide = wide[["image_path"] + TARGET_COLS].fillna({t: 0.0 for t in TARGET_COLS})

    def fix_path(p):
        # Prefer the file under INPUT/train, then INPUT/test; otherwise keep the CSV value
        for folder in ["train", "test"]:
            cand = INPUT / folder / Path(p).name
            if cand.exists():
                return str(cand)
        return str(p)

    wide["image_path"] = wide["image_path"].apply(fix_path)
    wide = wide[wide["image_path"].apply(os.path.exists)].reset_index(drop=True)
    print(f"Total images: {len(wide)}")
    return wide

## 7. Load Data and Prepare Scaler

Creates and saves the **StandardScaler** that normalizes the 5 targets:
- Fits on ALL training data
- Converts biomass values (grams) to ~N(0,1) distribution
- Facilitates model learning
- Saved to disk for reuse in K-Fold and inference

In [None]:
# Build the wide training table (one row per image)
wide = load_and_prepare_data()

# Fit the target scaler on every training row, then persist it for later steps
target_matrix = wide[TARGET_COLS].values.astype(np.float32)
target_scaler = StandardScaler().fit(target_matrix)
joblib.dump(target_scaler, WORK / "target_scaler.pkl")
print(f"Scaler saved")

---

# STEP 1: BENCHMARK - Testing Models

**Goal**: Quickly test every architecture listed in `MODELS` to pick the best one

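**Strategy - Train Only Part of the Network**:
- **Freeze Mode**: set by `BENCHMARK_FREEZE_MODE` in the CONFIG cell (currently `"percentage"`; `"only_head"` freezes everything except the last classification layer)
- **Trainable Layers**: with `"percentage"`, the last `BENCHMARK_TRAINABLE_PCT` (currently 20%) of the model; with `"only_head"`, only the task-specific head
- **Why?** Fast convergence, ideal for comparing architectures without overfitting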

**Process**:
1. Split train/validation data
2. For each model:
   - Load pretrained weights from the local safetensors file given in `MODELS`
   - **Freeze backbone** layers according to `BENCHMARK_FREEZE_MODE`
   - Train the layers left trainable (**only the classification head** when the mode is `only_head`)
   - Train WITHOUT augmentation for fast comparison
   - Compute R² on validation set
   - Track best performing architecture
3. Result: Saves `best_init.pth` (weights of the best model)

**Configuration from CONFIG cell**:
- See BENCHMARK_FREEZE_MODE and BENCHMARK_TRAINABLE_PCT for freezing settings
- See QUICK_EPOCHS and QUICK_PATIENCE for early stopping


In [None]:
import json

print("\n" + "="*60)
print("STEP 1: BENCHMARK - Testing models (NO augmentation)...")
print("="*60)

print(f"\nDataset preview (first 3 rows):")
print(wide.head(3).to_string(index=False))
print()

# Single hold-out split shared by every benchmarked model
train_df, val_df = train_test_split(wide, test_size=0.15, random_state=SEED)

best_r2 = -float("inf")
best_model_info = None

for model_info in MODELS:
    print(f"\nTesting {model_info['name']}...")
    print(f"  Freeze Mode: {BENCHMARK_FREEZE_MODE}")

    model = BiomassModel(
        model_info,
        target_scaler,
        load_pretrained=True,
        freeze_mode=BENCHMARK_FREEZE_MODE,
        trainable_pct=BENCHMARK_TRAINABLE_PCT
    )
    dm_benchmark = BiomassDataModule(train_df, val_df, target_scaler, use_augmentation=False)

    # Use checkpoint to save best model during training
    checkpoint_callback = ModelCheckpoint(
        dirpath=WORK / "temp_benchmark",
        filename=f'{model_info["name"]}_best',
        monitor='val_r2',
        mode='max',
        save_top_k=1
    )

    early_stop_benchmark = EarlyStopping(
        monitor='val_r2',
        patience=QUICK_PATIENCE,
        mode='max'
    )

    trainer = pl.Trainer(
        max_epochs=QUICK_EPOCHS,
        accelerator='auto',
        devices=1,
        logger=False,
        enable_checkpointing=True,
        enable_progress_bar=True,
        deterministic=True,
        callbacks=[checkpoint_callback, early_stop_benchmark]
    )

    trainer.fit(model, dm_benchmark)

    # Get best R2 from checkpoint callback
    r2 = checkpoint_callback.best_model_score.item() if checkpoint_callback.best_model_score is not None else -float("inf")

    print(f"{model_info['name']}: R2={r2:.4f}")

    if r2 > best_r2:
        best_r2 = r2
        best_model_info = model_info
        # Load best checkpoint and save its state_dict
        best_model = BiomassModel.load_from_checkpoint(
            checkpoint_callback.best_model_path,
            model_info=model_info,
            scaler=target_scaler
        )
        torch.save(best_model.state_dict(), WORK / "best_init.pth")

# Without this guard an empty MODELS list (or runs that never produce a usable
# val_r2) would only surface as a TypeError on best_model_info['name'] below.
if best_model_info is None:
    raise RuntimeError(
        "Benchmark did not produce a valid val_r2 for any entry in MODELS; "
        "best_init.pth was not written and STEP 2 cannot run"
    )

print(f"\nBEST MODEL: {best_model_info['name']} (R2={best_r2:.4f})")
with open(WORK / "best_model_info.json", 'w') as f:
    json.dump(best_model_info, f)

---

# STEP 2: K-FOLD TRAINING - Fine-tuning with Selective Layer Training

**Goal**: Train the best model with cross-validation for robustness and better accuracy

**K-Fold Process**:
1. Load best model weights from Step 1 (`best_init.pth`)
2. Split data into 5 folds using KFold
3. For each fold:
   - **Unfreeze last % of layers** - fine-tune backbone + head
   - Train WITH augmentation (improves generalization)
   - Use early stopping on validation R²
   - Save best checkpoint in `fold{N}/best.ckpt`
4. Result: 5 models trained with different data splits, ready for ensemble


In [None]:
print("\n" + "="*60)
print(f"STEP 2: K-FOLD TRAINING with {best_model_info['name']} (WITH augmentation)")
print("="*60)

kf = KFold(n_splits=KFOLD, shuffle=True, random_state=SEED)

# The benchmark weights are the same for every fold, so read the file once.
# load_state_dict copies values into each new model, so sharing this dict is safe.
init_state = torch.load(WORK / "best_init.pth", map_location='cpu')

fold_ckpts = []
for fold, (train_idx, val_idx) in enumerate(kf.split(wide), 1):
    print(f"\n--- Fold {fold}/{KFOLD} ---")

    train_part = wide.iloc[train_idx]
    val_part = wide.iloc[val_idx]

    # Create model without loading pretrained weights initially
    model = BiomassModel(
        best_model_info,
        target_scaler,
        load_pretrained=False,
        freeze_mode=KFOLD_FREEZE_MODE,
        trainable_pct=KFOLD_TRAINABLE_PCT
    )

    # Load weights from STEP 1 benchmark (best model so far)
    print(f"Loading best benchmark weights from STEP 1...")
    model.load_state_dict(init_state)
    print(f"Loaded best_init.pth")
    print(f"  Freeze Mode: {KFOLD_FREEZE_MODE} | Trainable: {KFOLD_TRAINABLE_PCT*100:.0f}%")

    # Use augmentation for K-Fold training (improves generalization)
    dm = BiomassDataModule(train_part, val_part, target_scaler, use_augmentation=True)

    # One checkpoint per fold: WORK/fold{N}/best.ckpt, chosen by highest val_r2
    checkpoint_callback = ModelCheckpoint(
        dirpath=WORK / f"fold{fold}",
        filename='best',
        monitor='val_r2',
        mode='max',
        save_top_k=1
    )

    early_stop_callback = EarlyStopping(
        monitor='val_r2',
        patience=MAX_PATIENCE,
        mode='max'
    )

    logger = CSVLogger(WORK / "logs", name=f"fold{fold}")

    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accelerator='auto',
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback],
        logger=logger,
        enable_progress_bar=True,
        deterministic=True,
    )

    trainer.fit(model, dm)

    # Collected for the ensemble in STEP 3
    fold_ckpts.append(checkpoint_callback.best_model_path)
    print(f"Fold {fold} best R2: {checkpoint_callback.best_model_score:.4f}")

print(f"\nCompleted {len(fold_ckpts)}/{KFOLD} folds - ready for ensemble voting")


---

# STEP 3: ENSEMBLE INFERENCE

**Goal**: Predict biomass on test images using multiple models

**Process**:
1. Load test set images
2. For each K-Fold checkpoint:
   - Load the trained model
   - Predict all test images
   - Accumulate predictions per image
3. **Average** all predictions for each image
4. Result: Final predictions are more robust than using a single model

In [None]:
# Predict every test image with each fold checkpoint and average the
# per-image predictions (in original units) into `final_preds`.
print("\n" + "="*60)
print("STEP 3: ENSEMBLE INFERENCE")
print("="*60)

# test.csv is only used to enumerate image paths; rows whose file is missing are dropped
test_df = pd.read_csv(INPUT / "test.csv")
test_df["image_path"] = test_df["image_path"].apply(lambda p: str(INPUT / "test" / Path(p).name))
test_df = test_df[test_df["image_path"].apply(os.path.exists)].reset_index(drop=True)
# Deduplicate so each image is predicted once per fold
test_paths = sorted(test_df["image_path"].unique())
print(f"Test images found: {len(test_paths)}")
print(f"Ensembling {len(fold_ckpts)} folds...\n")

# Get validation transform for test images
test_transform = get_val_transform()
test_ds = BiomassDataset(test_paths, test_transform, mode='test')
test_dl = DataLoader(test_ds, batch_size=BATCH, shuffle=False)

# file stem -> list with one prediction vector per fold
accum = {Path(p).stem: [] for p in test_paths}

# Detect device once (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for fold_idx, ckpt_path in enumerate(fold_ckpts, 1):
    print(f"[{fold_idx}/{len(fold_ckpts)}] Predicting with {Path(ckpt_path).parent.name}...", end=" ")
    # model_info and scaler are excluded from the saved hyperparameters, so pass them again
    model = BiomassModel.load_from_checkpoint(
        ckpt_path, 
        model_info=best_model_info,
        scaler=target_scaler
    )
    # NOTE(review): the model has already been constructed at this point, and the
    # freezing summary is printed from BiomassModel.__init__; setting this flag
    # here therefore does not suppress that output.
    model.eval_mode = True
    model.eval()
    model.to(device)

    batch_count = 0
    with torch.no_grad():
        for xb, stems in test_dl:
            xb = xb.to(model.device)
            # Model outputs are in scaled space; convert back to the original target units
            out = target_scaler.inverse_transform(model(xb).cpu().numpy())
            for stem, pred in zip(stems, out):
                accum[stem].append(pred)
            batch_count += len(stems)

    print(f"{batch_count} predictions")

# Mean over folds; an image with no predictions at all falls back to zeros
final_preds = {stem: np.mean(preds, axis=0) if preds else np.zeros(len(TARGET_COLS))
               for stem, preds in accum.items()}

print(f"\nEnsemble complete: {len(final_preds)} images processed")
print(f"Predictions per image: {len(fold_ckpts)} folds averaged\n")


### Prediction Preview

Shows the predictions in a readable format:
- Image name
- Predicted values for each biomass type (in grams)
- Useful to check that predictions make sense before submitting

In [None]:
# Show the first few ensemble predictions as a quick sanity check
separator = "-" * 60
preview_items = list(final_preds.items())[:3]

print("Preview of first predictions (grams):")
print(separator)
for position, (stem, vals) in enumerate(preview_items):
    print(f"Image: {stem}")
    for target, value in zip(TARGET_COLS, vals):
        print(f"  • {target:15s}: {value:7.2f} g")
    # Blank line between entries (not after the third one)
    if position < 2:
        print()
remaining = len(final_preds) - 3
if remaining > 0:
    print(f"  ... and {remaining} more images")
print(separator)

---

# STEP 4: CREATE SUBMISSION

**Kaggle format**: Converts predictions to required format

**CSV structure**:
- **sample_id**: `{image}__{target}` (e.g. `ID1234__Dry_Total_g`)
- **target**: Predicted value in grams (minimum 0.0)

**Result**:
- Creates `submission.csv` with N_images × 5 targets = total rows
- Ready to upload to Kaggle and get your public/private score

In [None]:
print("\n" + "="*60)
print("STEP 4: CREATE SUBMISSION")
print("="*60)

# One row per (image, target): id is "<stem>__<target>", value is clipped at 0.0
rows = [
    [f"{stem}__{target}", max(0.0, float(value))]
    for stem, vals in final_preds.items()
    for target, value in zip(TARGET_COLS, vals)
]

submission_path = WORK / "submission.csv"
submission = pd.DataFrame(rows, columns=["sample_id", "target"])
submission.to_csv(submission_path, index=False)

print(f"\nFull submission:")
print(submission.to_string(index=False))
print(f"\nSubmission saved: {submission_path}")
print("\nPIPELINE COMPLETE!")