In [67]:
"""
Agri-Foundational Hurdle Network (AFHN)
==================================================================
Production-ready PyTorch implementation for pasture biomass prediction
using Zero-Inflated LogNormal loss, DINOv2 backbone, and MIL aggregation.
"""

import os
import math
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import GroupKFold
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from transformers import AutoModel, AutoConfig
import albumentations as A
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings('ignore')

In [68]:
# ============================================================================
# CONSTANTS & CONFIGURATION
# ============================================================================

# DINOv2 native resolution
IMG_SIZE = 518
PATCH_STRIDE = 400  # ~23% overlap for smoother features
ORIG_H, ORIG_W = 1000, 2000
DINO_PATH = '/kaggle/input/dinov2/pytorch/large/1'

# Hyperparameters
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 2
PATIENCE = 10

# Target names and weights (as per competition)
TARGET_NAMES = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
TARGET_WEIGHTS = {
    'Dry_Green_g': 0.1,
    'Dry_Dead_g': 0.1,
    'Dry_Clover_g': 0.1,
    'GDM_g': 0.2,
    'Dry_Total_g': 0.5
}

# State encoding (one-hot)
STATES = ['NSW', 'Tas', 'Vic', 'WA']
STATE_TO_IDX = {s: i for i, s in enumerate(STATES)}

# Dataset paths
TRAIN_CSV = '/kaggle/input/csiro-biomass/train.csv'
TRAIN_IMG_DIR = '/kaggle/input/csiro-biomass/'

TEST_CSV = '/kaggle/input/csiro-biomass/test.csv'
TEST_IMG_DIR = 'test_images/'
MODEL_PATH = 'afhn_fold0_best.pth' 

# Device configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# ============================================================================
# LOSS FUNCTIONS
# ============================================================================

class ZeroInflatedLogNormalLoss(nn.Module):
    """
    Zero-Inflated LogNormal (ZILN) Loss for semi-continuous targets.
    
    Combines:
    1. Binary classification (zero vs non-zero)
    2. LogNormal regression for positive values
    
    Model outputs 3 values per target:
    - logit_prob: Logit for P(y > 0)
    - mu: Mean of log-normal distribution
    - sigma_raw: Raw scale parameter (transformed via softplus)
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
    
    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            preds: (batch, 3) - [logit_prob, mu, sigma_raw]
            target: (batch,) - Ground truth values (>= 0)
        
        Returns:
            Scalar loss value
        """
        logit_prob = preds[:, 0]
        mu = preds[:, 1]
        sigma_raw = preds[:, 2]
        
        # Ensure sigma > 0
        sigma = F.softplus(sigma_raw) + self.eps

        # Binary indicator: 1 if y > 0, else 0
        is_positive = (target > 0).float()
        
        # Classification loss (zero vs non-zero)
        class_loss = F.binary_cross_entropy_with_logits(
            logit_prob, is_positive, reduction='none'
        )
        
        # Regression loss (LogNormal NLL for positives)
        safe_target = torch.clamp(target, min=self.eps)
        log_target = torch.log(safe_target)
        
        reg_loss = (
            log_target +
            torch.log(sigma) +
            0.5 * math.log(2 * math.pi) +
            (log_target - mu).pow(2) / (2 * sigma.pow(2))
        )
        
        # Combined loss: classification always, regression only for positives
        total_loss = class_loss + (is_positive * reg_loss)
        
        return total_loss.mean()

class TweedieLoss(nn.Module):
    """
    Tweedie loss for compound Poisson-Gamma distribution.
    Suitable for strictly positive, heavy-tailed targets like Total Biomass.
    """
    def __init__(self, p: float = 1.5, epsilon: float = 1e-8):
        super().__init__()
        assert 1 < p < 2, "Tweedie power p must be in (1, 2)"
        self.p = p # Power parameter: 1 < p < 2 for Compound Poisson-Gamma
        self.epsilon = epsilon

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Predicted mean (batch,)
            target: Ground truth (batch,)
        """
        # Ensure positivity
        pred = torch.clamp(pred, min=self.epsilon) 
        target = torch.clamp(target, min=self.epsilon) 
        
        # Tweedie deviance
        term1 = -target * torch.pow(pred, 1 - self.p) / (1 - self.p)
        term2 = torch.pow(pred, 2 - self.p) / (2 - self.p)
        
        loss = term1 + term2
        return loss.mean()

In [70]:
# ============================================================================
# MULTIPLE INSTANCE LEARNING (MIL) MODULE
# ============================================================================

class GatedAttentionMIL(nn.Module):
    """
    Gated Attention mechanism for MIL aggregation.
    
    Based on: Ilse et al. (2018) "Attention-based Deep Multiple Instance Learning"
    Equation: a_k = softmax(w^T * (tanh(V*h_k) ⊙ sigmoid(U*h_k)))
    """

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 256):
        super().__init__()
        self.attention_V = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh()
        )
        self.attention_U = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid()
        )
        self.attention_weights = nn.Linear(hidden_dim, 1)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, n_patches, embed_dim)
        
        Returns:
            aggregated: (batch, embed_dim) - Weighted sum of patches
            attention: (batch, n_patches, 1) - Attention weights per patch
        """

        # Gated attention computation
        A_V = self.attention_V(x) # (batch, n_patches, hidden_dim)
        A_U = self.attention_U(x) # (batch, n_patches, hidden_dim)

        # Element-wise gating
        A = self.attention_weights(A_V * A_U) # (batch, n_patches, 1)

        # Softmax over patches
        A = torch.softmax(A, dim=1) # (batch, n_patches, 1)

        # Weighted aggregation
        M = torch.bmm(A.transpose(1, 2), x) # (batch, 1, embed_dim)

        return M.squeeze(1), A

In [71]:
# ============================================================================
# FEATURE-WISE LINEAR MODULATION (FiLM) LAYER
# ============================================================================

class FiLM(nn.Module):
    """
    Feature-wise Linear Modulation for metadata injection.
    
    Modulates visual features with metadata (state, height, NDVI):
    z_mod = γ(m) ⊙ z + β(m)
    """

    def __init__(self, meta_dim: int, feat_dim: int):
        super().__init__()
        self.scale = nn.Linear(meta_dim, feat_dim)
        self.shift = nn.Linear(meta_dim, feat_dim)
    
    def forward(self, features: torch.Tensor, metadata: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: (batch, feat_dim)
            metadata: (batch, meta_dim)
        
        Returns:
            modulated: (batch, feat_dim)
        """
        gamma = self.scale(metadata) # (batch, feat_dim)
        beta = self.shift(metadata)  # (batch, feat_dim)

        return features * gamma + beta

In [72]:
# ============================================================================
# AGRI-FOUNDATIONAL HURDLE NETWORK (AFHN) MODEL
# ============================================================================

class AFHN(nn.Module):
    """
    Agri-Foundational Hurdle Network for biomass prediction.
    
    Architecture:
    1. DINOv2-Large (frozen) backbone for patch features
    2. Gated Attention MIL for patch aggregation
    3. FiLM layers for metadata modulation
    4. Hierarchical heads with physical constraints
    """

    def __init__(
        self,
        num_components: int = 5,
        meta_dim: int = 10,
        use_ziln: bool = True 
    ):
        super().__init__()
        self.num_components = num_components
        self.use_ziln = use_ziln

        # 1. Visual backbone: DINOv2-Large (frozen)
        print("Loading DINOv2-Large backbone from local storage...")
        # Load Config and Model offline
        config = AutoConfig.from_pretrained(DINO_PATH)
        self.backbone = AutoModel.from_pretrained(DINO_PATH, config=config)

        # Freeze all backbone parameters
        for param in self.backbone.parameters():
            param.requires_grad = False
        
        self.vis_dim = 1024  # DINOv2-Large embedding dimension

        # 2. MIL Aggregator
        self.mil = GatedAttentionMIL(input_dim=self.vis_dim, hidden_dim=256)

        # 3. Metadata Encoder + FiLM
        self.meta_encoder = nn.Sequential(
            nn.Linear(meta_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64)
        )
        self.film = FiLM(meta_dim=64, feat_dim=self.vis_dim)

        # 4. Prediction Heads

        # A. Total Biomass Head (ZILN or Tweedie)
        if self.use_ziln:
            self.head_total = nn.Sequential(
                nn.Linear(self.vis_dim, 256),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(256, 3)  # [logit_prob, mu, sigma_raw]
            )
        else:
            self.head_total = nn.Sequential(
                nn.Linear(self.vis_dim, 256),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(256, 1),
                nn.Softplus()  # Ensure positive prediction
            )
        
        # B. Component Ratios Head (Softmax for sum-to-one constraint)
        self.head_ratios = nn.Sequential(
            nn.Linear(self.vis_dim, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_components)
        )

        # C. Component Gates (Hurdle mechanism)
        self.head_gate = nn.Sequential(
            nn.Linear(self.vis_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_components)
        )
    
    def forward_features(
        self,
        x_patches: torch.Tensor,
        metadata: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract and aggregate patch features.
        
        Args:
            x_patches: (batch, n_patches, 3, H, W)
            metadata: (batch, meta_dim)
        
        Returns:
            modulated_feat: (batch, vis_dim)
            attn_weights: (batch, n_patches, 1)
        """
        batch_size, n_patches, c, h, w = x_patches.shape

        # Flaten patches for batch processing
        x_flat = x_patches.view(batch_size * n_patches, c, h, w) # (B*n_patches, 3, H, W)

        # Extract patch features (frozen backbone, no grad)
        with torch.no_grad():
            # Transformers expects 'pixel_values' argument
            outputs = self.backbone(pixel_values=x_flat)
            
            # Get last_hidden_state: (B*n_patches, sequence_length, hidden_size)
            last_hidden_state = outputs.last_hidden_state
            
            # The CLS token is usually at index 0
            feat_flat = last_hidden_state[:, 0, :]  # (batch*n_patches, embed_dim)

        # Reshape to (batch, n_patches, embed_dim)
        feat_seq = feat_flat.view(batch_size, n_patches, -1)

        # MIL aggregation
        global_feat, attn_weights = self.mil(feat_seq) # (batch, vis_dim), (batch, n_patches, 1)

        # Meta Modulation via FiLM
        meta_emb = self.meta_encoder(metadata) # (batch, meta_feat_dim)
        modulated_feat = self.film(global_feat, meta_emb) # (batch, vis_dim)

        return modulated_feat, attn_weights

    def forward(
        self,
        x_patches: torch.Tensor,
        metadata: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass returning all predictions.
        
        Returns:
            Dict with keys: 'total', 'ratios', 'gates'
        """
        feat, _ = self.forward_features(x_patches, metadata)
        
        total = self.head_total(feat)
        raw_ratios = self.head_ratios(feat)
        ratios = F.softmax(raw_ratios, dim=1)
        gate_logits = self.head_gate(feat)
        
        return {
            'total': total,
            'ratios': ratios,
            'gates': gate_logits
        }
    
    def predict_components(
        self,
        x_patches: torch.Tensor,
        metadata: torch.Tensor
    ) -> torch.Tensor:
        """
        Inference: predict all 5 component biomass values.
        
        Returns:
            components: (batch, 5) - Predicted biomass for each component
        """
        outputs = self.forward(x_patches, metadata)

        # Decode total biomass
        if self.use_ziln:
            total_ziln = outputs['total']
            prob_nonzero = torch.sigmoid(total_ziln[:, 0])
            mu = total_ziln[:, 1]
            sigma = F.softplus(total_ziln[:, 2]) + 1e-6
            # Expected valud of ZILN
            expected_total = prob_nonzero * torch.exp(mu +0.5 * sigma.pow(2))
        else:
            expected_total = outputs['total'].squeeze(-1)
        
        # Decode gates (Bernoulli probability)
        gates = torch.sigmoid(outputs['gates'])

        # Apply gates to ratios and renormalize
        gated_ratios = outputs['ratios'] * gates
        sum_gated = gated_ratios.sum(dim=1, keepdim=True) + 1e-6
        final_ratios = gated_ratios / sum_gated

        # Component predictions (enforces sum constraint)
        components = expected_total.unsqueeze(1) * final_ratios

        return components

In [73]:
# ============================================================================
# DATASET CLASS
# ============================================================================

class PatchBiomassDataset(Dataset):
    """
    Dataset for CSIRO Biomass with on-the-fly patch extraction.
    
    Each sample contains:
    - Image patches (multiple instances)
    - Metadata (height, NDVI, state, species, date)
    - Target biomass values (5 components)
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        transform: Optional[A.Compose] = None,
        patch_size: int = IMG_SIZE,
        stride: int = PATCH_STRIDE,
        is_test: bool = False
    ):
        # For test set, group by image_path (5 rows per image)
        if is_test:
            self.df = df.groupby('image_path').first().reset_index()
        else:
            # For train, group by sample_id (each image appears once)
            self.df = df.groupby('sample_id').first().reset_index()
        
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.patch_size = patch_size
        self.stride = stride
        self.is_test = is_test

    def __len__(self) -> int:
        return len(self.df)

    def extract_patches(self, image: np.ndarray) -> torch.Tensor:
        """
        Extract overlapping patches from high-res image.
        
        Args:
            image: (H, W, 3) numpy array
        
        Returns:
            patches: (n_patches, 3, patch_size, patch_size)
        """
        h, w = image.shape[:2]
        patches = []

        for y in range(0, h - self.patch_size + 1, self.stride):
            for x in range(0, w - self.patch_size + 1, self.stride):
                patch = image[y:y+self.patch_size, x:x+self.patch_size]
                
                if self.transform:
                    augmented = self.transform(image=patch)
                    patch_tensor = augmented['image']
                else:
                    patch_tensor = torch.from_numpy(patch).permute(2, 0, 1).float() / 255.0
                
                patches.append(patch_tensor)
        
        return torch.stack(patches)


    def __getitem__(self, idx: int) -> Dict:
        row = self.df.iloc[idx]

        # Load image
        img_path = self.image_dir / row['image_path']
        image = np.array(Image.open(img_path).convert('RGB'))

        # Extract patches
        patches = self.extract_patches(image)

        # Encode metadata
        # State one-hot
        state_onehot = np.zeros(len(STATES))
        if row['State'] in STATE_TO_IDX:
            state_onehot[STATE_TO_IDX[row['State']]] = 1.0

        # Date encoding (cyclical)
        date = pd.to_datetime(row['Sampling_Date'])
        month = date.month
        month_sin = np.sin(2 * np.pi * month / 12)
        month_cos = np.cos(2 * np.pi * month / 12)

        # Combine metadata: [height, ndvi, state_onehot(4), month_sin, month_cos]
        metadata = np.concatenate([
            [row['Height_Ave_cm'], row['Pre_GSHH_NDVI']],
            state_onehot,
            [month_sin, month_cos]
        ]).astype(np.float32)

        metadata = torch.from_numpy(metadata)

        # Targets (if not test)
        if not self.is_test:
            targets = torch.tensor([
                row['Dry_Green_g'],
                row['Dry_Dead_g'],
                row['Dry_Clover_g'],
                row['GDM_g'],
                row['Dry_Total_g']
            ], dtype=torch.float32)
            
            return {
                'patches': patches,
                'metadata': metadata,
                'targets': targets,
                'sample_id': row['sample_id']
            }
        else:
            return {
                'patches': patches,
                'metadata': metadata,
                'sample_id': row.name # Index for test
            }

In [74]:
# ============================================================================
# DATA TRANSFORMS
# ============================================================================

def get_transforms(is_train: bool = True) -> A.Compose:
    """
    Albumentations transforms for image augmentation.
    
    DINOv2 expects ImageNet normalization.
    """
    if is_train:
        return A.Compose([
            A.RandomResizedCrop(
                size=(IMG_SIZE, IMG_SIZE),
                scale=(0.7, 1.0), 
                ratio=(0.9, 1.1), 
                p=1.0
            ),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),

            A.OneOf([
                A.MotionBlur(p=0.2),
                A.GaussianBlur(p=0.2),
            ], p=0.2),

            # Photometric (cautious with hue to avoid confusing green/dead)
            A.OneOf([
                A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                A.RandomGamma(gamma_limit=(80, 120), p=0.3),
                A.HueSaturationValue(p=0.3),
            ], p=0.3),

            # Noise and regularization
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.CoarseDropout(max_holes=8, max_height=32, max_width=32, fill_value=0, p=0.3),

            # Normalization (DINOv2 uses ImageNet stats)
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])

In [75]:
# ============================================================================
# METRICS
# ============================================================================

def weighted_r2_score(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    weights: np.ndarray
) -> float:
    """
    Compute weighted R2 score as per competition metric.
    
    R2_w = 1 - (SS_res_w / SS_tot_w)
    """
    # Weighted mean
    y_mean = np.average(y_true, weights=weights)

    # Weighted residual sum of squares
    ss_res = np.sum(weights * (y_true - y_pred) ** 2)

    # Weighted total sum of squares
    ss_tot = np.sum(weights * (y_true - y_mean) ** 2)

    r2 = 1 - (ss_res / ss_tot + 1e-8)
    return r2

def compute_metrics(
    targets: torch.Tensor,
    predictions: torch.Tensor
) -> Dict[str, float]:
    """
    Compute all evaluation metrics.
    
    Args:
        targets: (n_samples, 5)
        predictions: (n_samples, 5)
    
    Returns:
        Dict of metric names to values
    """
    targets_np = targets.cpu().numpy()
    preds_np = predictions.cpu().numpy()

    # Per-component MAE
    component_mae = {}
    for i, name in enumerate(TARGET_NAMES):
        mae = np.abs(targets_np[:, i] - preds_np[:, i]).mean()
        component_mae[f'{name}_MAE'] = mae

    # Weighted R2
    # Expand weights to match row-wise structure
    n_samples = len(targets_np)
    weights_expanded = np.array([TARGET_WEIGHTS[name] for name in TARGET_NAMES])
    weights_tiled = np.tile(weights_expanded, n_samples)
    
    targets_flat = targets_np.flatten()
    preds_flat = preds_np.flatten()
    
    r2_weighted = weighted_r2_score(targets_flat, preds_flat, weights_tiled)
    
    return {
        **component_mae,
        'Weighted_R2': r2_weighted
    }

In [None]:
# ============================================================================
# TRAINING LOOP
# ============================================================================

def train_epoch(
    model: AFHN,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    use_ziln: bool = True
) -> float:
    """Train for one epoch."""
    model.train()
    total_loss = 0
    
    pbar = tqdm(dataloader, desc='Training')
    for batch in pbar:
        patches = batch['patches'].to(device)
        metadata = batch['metadata'].to(device)
        targets = batch['targets'].to(device)
        
        optimizer.zero_grad()
        
        # Get predicted components (final output)
        pred_components = model.predict_components(patches, metadata)
        
        # Component-wise losses
        if use_ziln:
            # Use ZILN loss for each component
            criterion = ZeroInflatedLogNormalLoss()
            # Need to get ZILN params for each component
            # Since model outputs total×ratios×gates, use MSE instead
            loss_components = F.mse_loss(pred_components, targets)
        else:
            loss_components = F.mse_loss(pred_components, targets)
        
        # Gate loss (encourage correct zero prediction)
        outputs = model(patches, metadata)
        gate_targets = (targets > 0).float()
        loss_gate = F.binary_cross_entropy_with_logits(
            outputs['gates'], gate_targets
        )
        
        # Total biomass consistency loss
        pred_total = pred_components.sum(dim=1)
        target_total = targets[:, -1]  # Dry_Total_g is last column
        loss_total = F.mse_loss(pred_total, target_total)
        
        # Combined loss
        loss = loss_components + 0.1 * loss_gate + 0.2 * loss_total
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(dataloader)

def validate(
    model: AFHN,
    dataloader: DataLoader,
    device: torch.device
) -> Tuple[float, Dict[str, float]]:
    """Validate model."""
    model.eval()

    all_targets = []
    all_predictions = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
            patches = batch['patches'].to(device)
            metadata = batch['metadata'].to(device)
            targets = batch['targets']

            predictions = model.predict_components(patches, metadata)

            all_targets.append(targets)
            all_predictions.append(predictions.cpu())
    
    all_targets = torch.cat(all_targets)
    all_predictions = torch.cat(all_predictions)

    metrics = compute_metrics(all_targets, all_predictions)

    return metrics['Weighted_R2'], metrics

In [77]:
# ============================================================================
# MAIN TRAINING SCRIPT
# ============================================================================

def main():
    """Main training function."""
    # Load data
    print("Loading training data...")
    train_df = pd.read_csv(TRAIN_CSV)

    # Pivot from long to wide format
    train_df_wide = train_df.pivot_table(
        index=['sample_id', 'image_path', 'Sampling_Date', 'State',
               'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm'],
        columns='target_name',
        values='target'
    ).reset_index()

    # Group K-Fold split (by State to ensure generalization)
    groups = train_df_wide['State'].values
    n_groups = len(np.unique(groups))
    print(f"Found {n_groups} unique states: {np.unique(groups)}")
    
    # Adjust n_splits to not exceed the number of groups
    n_splits = 5
    if n_groups < n_splits:
        print(f"Warning: Requested {n_splits} splits but only found {n_groups} groups.")
        print(f"Adjusting n_splits to {n_groups} (Leave-One-Group-Out CV).")
        n_splits = n_groups

    # Initialize GroupKFold with the safe number of splits
    gkf = GroupKFold(n_splits=n_splits)
    
    fold = 0
    for train_idx, val_idx in gkf.split(train_df_wide, groups=groups):
        print(f"\n{'='*60}")
        print(f"Fold {fold + 1}/5")
        print(f"{'='*60}")
        
        train_fold = train_df_wide.iloc[train_idx]
        val_fold = train_df_wide.iloc[val_idx]
        
        # Datasets
        train_dataset = PatchBiomassDataset(
            train_fold, TRAIN_IMG_DIR,
            transform=get_transforms(is_train=True)
        )
        val_dataset = PatchBiomassDataset(
            val_fold, TRAIN_IMG_DIR,
            transform=get_transforms(is_train=False)
        )
        
        train_loader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE,
            shuffle=True, num_workers=4, pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=BATCH_SIZE,
            shuffle=False, num_workers=4, pin_memory=True
        )
        
        # Model
        model = AFHN(num_components=5, meta_dim=8, use_ziln=True).to(DEVICE)
        
        # Optimizer & Scheduler
        optimizer = optim.AdamW(
            model.parameters(),
            lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY
        )
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2
        )
        
        # Loss functions
        criterion_components = ZeroInflatedLogNormalLoss()
        criterion_total = TweedieLoss(p=1.5)
        
        # Training loop
        best_r2 = -float('inf')
        patience_counter = 0
        
        for epoch in range(NUM_EPOCHS):
            print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
            
            train_loss = train_epoch(
                model, train_loader, optimizer,
                criterion_components, criterion_total, DEVICE
            )
            
            val_r2, val_metrics = validate(model, val_loader, DEVICE)
            
            scheduler.step()
            
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Weighted R2: {val_r2:.4f}")
            for k, v in val_metrics.items():
                if k != 'Weighted_R2':
                    print(f"  {k}: {v:.4f}")
            
            # Early stopping
            if val_r2 > best_r2:
                best_r2 = val_r2
                patience_counter = 0
                torch.save(model.state_dict(), f'afhn_fold{fold}_best.pth')
                print(f"Saved best model (R2={best_r2:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= PATIENCE:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break
        
        print(f"\nFold {fold + 1} Best R2: {best_r2:.4f}")
        fold += 1
        
        # Only train one fold for demo (remove break for full CV)
        break


def inference_example():
    """Example inference on test set."""
    # Load test data
    test_df = pd.read_csv(TEST_CSV)

    test_dataset = PatchBiomassDataset(
        test_df, TEST_IMG_DIR,
        transform=get_transforms(is_train=False),
        is_test=True
    )

    test_loader = DataLoader(
        test_dataset, batch_size=4,
        shuffle=False, num_workers=4
    )

    # Load model
    model = AFHN(num_components=5, meta_dim=8, use_ziln=True).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH))

    # Predict
    predictions = []
    sample_ids = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Inference'):
            patches = batch['patches'].to(DEVICE)
            metadata = batch['metadata'].to(DEVICE)
            
            preds = model.predict_components(patches, metadata)
            predictions.append(preds.cpu().numpy())
            sample_ids.extend(batch['sample_id'])
    
    predictions = np.vstack(predictions)

    # Create submission
    submission_rows = []

    for i, img_id in enumerate(sample_ids):
        for j, target_name in enumerate(TARGET_NAMES):
            submission_rows.append({
                'sample_id': f"{img_id}__{target_name}",
                'target': predictions[i, j]
            })
    
    submission_df = pd.DataFrame(submission_rows)
    submission_df.to_csv('submission.csv', index=False)
    print("Submission saved to submission.csv")

In [78]:
if __name__ == '__main__':
    # Run training
    main()
    # inference_example()

Loading training data...
Found 4 unique states: ['NSW' 'Tas' 'Vic' 'WA']
Adjusting n_splits to 4 (Leave-One-Group-Out CV).

Fold 1/5
Loading DINOv2-Large backbone from local storage...

Epoch 1/2


Training: 100%|██████████| 274/274 [48:17<00:00, 10.58s/it, loss=nan]
Validation: 100%|██████████| 173/173 [30:30<00:00, 10.58s/it]


Train Loss: nan
Val Weighted R2: nan
  Dry_Green_g_MAE: nan
  Dry_Dead_g_MAE: nan
  Dry_Clover_g_MAE: nan
  GDM_g_MAE: nan
  Dry_Total_g_MAE: nan

Epoch 2/2


Training:   1%|          | 2/274 [00:33<1:16:25, 16.86s/it, loss=nan]


KeyboardInterrupt: 