# CSIRO Image2Biomass Prediction - Complete Model Training Strategy
# 🎯 Project Overview
This pipeline predicts 5 biomass components from pasture images using a multi-modal deep learning approach:

Dry_Green_g - Dry green biomass

Dry_Dead_g - Dry dead biomass

Dry_Clover_g - Dry clover biomass

GDM_g - Green dry matter

Dry_Total_g - Total dry biomass

# 🏗️ Architecture Strategy
Multi-Modal Fusion Model
The model combines visual features from images with tabular metadata for robust predictions:

1. Image Encoder (EfficientNetV2-M)
Backbone: tf_efficientnetv2_m (pre-trained on ImageNet)

Input: 512×512 RGB images

Features: Global average pooling → 1,280-dimensional feature vector

Advantage: Pre-trained weights enable effective feature extraction from pasture images

# 2. Tabular Encoder (MLP)
Inputs:

Pre_GSHH_NDVI - Normalized Difference Vegetation Index

Height_Ave_cm - Average vegetation height

Architecture:

2 → 64 → 128 dimensions

BatchNorm + ReLU + Dropout (0.3)

Purpose: Capture domain-specific environmental context

# 3. Fusion Layer
Input: Image features (1,280D) + Tabular features (128D) = 1,408D

Architecture:

1,408 → 512 → 256 dimensions

BatchNorm + ReLU + Dropout (0.4→0.3)

Function: Learn complex interactions between visual and environmental features

# 4. Multi-Head Output
5 Separate Heads: Each biomass component has dedicated output layer

Architecture: 256 → 1 (linear layer for each target)

Benefit: Prevents interference between different biomass types

In [None]:
# ============================================================================
# CSIRO Image2Biomass Prediction - Complete End-to-End Pipeline
# ============================================================================
# This pipeline predicts 5 biomass components from pasture images:
# - Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g
# ============================================================================

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================
class CFG:
    """Static settings shared by every stage of the training pipeline."""

    # --- Data locations (Kaggle input layout) ---
    train_csv = "/kaggle/input/csiro-biomass/train.csv"
    test_csv = "/kaggle/input/csiro-biomass/test.csv"
    train_dir = "/kaggle/input/csiro-biomass/train"
    test_dir = "/kaggle/input/csiro-biomass/test/"

    # --- Image backbone ---
    model_name = "tf_efficientnetv2_m"  # timm identifier of the image encoder
    img_size = 512  # images are resized to img_size x img_size
    pretrained = True

    # --- Optimisation schedule ---
    n_folds = 5
    seed = 42
    epochs = 50
    batch_size = 16
    num_workers = 4
    lr = 3e-4
    weight_decay = 1e-5
    warmup_epochs = 2  # epochs of learning-rate warmup

    # --- Test-time augmentation ---
    use_tta = True
    tta_steps = 5

    # --- Regression targets; the order fixes the model's output columns ---
    targets = [
        "Dry_Green_g",
        "Dry_Dead_g",
        "Dry_Clover_g",
        "GDM_g",
        "Dry_Total_g",
    ]
    # Per-target weights, aligned with `targets` (taken from the evaluation criteria)
    target_weights = [0.1, 0.1, 0.1, 0.2, 0.5]

    # Standardise the targets before training when True
    use_target_scaling = True

    # --- Compute device: GPU when one is visible, CPU otherwise ---
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seeds for reproducibility
def set_seed(seed):
    """
    Seed every random number generator this pipeline can draw from.

    Covers Python's built-in `random`, NumPy and PyTorch (CPU and all CUDA
    devices) and switches cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally because the file's import block does not provide `random`.
    import random

    # Seed the stdlib generator as well: any dependency drawing from it would
    # otherwise stay unseeded and make runs differ.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed once, as soon as the module is loaded.
set_seed(CFG.seed)

# ============================================================================
# DATA PREPROCESSING
# ============================================================================
def prepare_data(train_csv_path):
    """
    Load the long-format training CSV and return one row per image.

    The CSV holds one row per (image, target) pair. The rows are collapsed so
    that every image carries its metadata plus one column per target, and a
    `biomass_bin` column is added for stratified fold assignment.

    Args:
        train_csv_path: Path to the long-format CSV. It must provide the
            columns `sample_id`, `target_name`, `target` and the metadata
            columns listed below.

    Returns:
        DataFrame with `image_id`, the metadata columns, every column named in
        CFG.targets (missing values filled with 0.0) and an integer
        `biomass_bin`.
    """
    df = pd.read_csv(train_csv_path)

    # The image identifier is the part of sample_id before '__'. Splitting is
    # applied to every row (ids without the separator come back unchanged), so
    # the result no longer depends on what the first row happens to look like.
    df['image_id'] = df['sample_id'].str.split('__').str[0]

    metadata_cols = ['image_path', 'Sampling_Date', 'State', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm']

    # Metadata: first non-null value per image. Grouping on image_id alone keeps
    # images whose metadata contains NaN; using the metadata columns as group /
    # pivot keys would silently drop those images, because pandas discards
    # group keys that contain NaN.
    metadata = df.groupby('image_id')[metadata_cols].first()

    # Targets: long -> wide, one column per target name.
    targets_wide = df.pivot_table(
        index='image_id',
        columns='target_name',
        values='target',
        aggfunc='first'  # Use first value if duplicates exist
    )

    # Inner join: only images that have at least one target value are kept.
    df_pivot = metadata.join(targets_wide, how='inner').reset_index()

    # Ensure all target columns exist and fill any NaN with 0
    for target in CFG.targets:
        if target not in df_pivot.columns:
            df_pivot[target] = 0.0
        else:
            df_pivot[target] = df_pivot[target].fillna(0.0)

    # Stratification bins over total biomass so folds cover similar biomass
    # ranges. Quantile bins first; equal-width bins if quantile binning fails.
    try:
        df_pivot['biomass_bin'] = pd.qcut(
            df_pivot['Dry_Total_g'],
            q=10,
            labels=False,
            duplicates='drop'
        )
    except ValueError:
        df_pivot['biomass_bin'] = pd.cut(
            df_pivot['Dry_Total_g'],
            bins=10,
            labels=False
        )

    # Any value that could not be binned falls back to bin 0.
    df_pivot['biomass_bin'] = df_pivot['biomass_bin'].fillna(0).astype(int)

    print(f"Prepared {len(df_pivot)} unique images")
    print(f"Target columns: {CFG.targets}")
    print("Sample biomass statistics:")
    for target in CFG.targets:
        print(f"  {target}: mean={df_pivot[target].mean():.2f}, std={df_pivot[target].std():.2f}")

    return df_pivot

# ============================================================================
# DATASET CLASS
# ============================================================================
class BiomassDataset(Dataset):
    """
    Dataset yielding a pasture image together with its tabular metadata.

    Items are `(image, tabular, targets)` for training/validation data and
    `(image, tabular)` when `is_test` is True.

    Args:
        df: DataFrame with `image_path`, `Pre_GSHH_NDVI`, `Height_Ave_cm` and,
            unless `is_test`, the five target columns.
        img_dir: Directory containing the image files.
        transform: Optional callable invoked as `transform(image=...)` whose
            result is indexed with `['image']`.
        is_test: When True no targets are returned.
        scaler: Optional fitted scaler for the tabular features. When omitted,
            a StandardScaler is fitted on this dataset's own features
            (non-test data) or the raw features are used (test data).

    Attributes:
        scaler: The scaler applied to the tabular features, or None when the
            features were left unscaled.
        tabular_features: Array with one (NDVI, height) row per sample.
    """

    # Column order of the target tensor returned by __getitem__.
    TARGET_COLUMNS = ('Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g')

    def __init__(self, df, img_dir, transform=None, is_test=False, scaler=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

        # Tabular features (NDVI and height); missing values become 0 before scaling.
        tabular_data = self.df[['Pre_GSHH_NDVI', 'Height_Ave_cm']].fillna(0).values

        if scaler is not None:
            # Reuse statistics fitted elsewhere (e.g. on the training split).
            self.scaler = scaler
            self.tabular_features = scaler.transform(tabular_data)
        elif not is_test:
            self.scaler = StandardScaler()
            self.tabular_features = self.scaler.fit_transform(tabular_data)
        else:
            # Always define the attribute so callers can test `scaler is None`
            # instead of hitting an AttributeError.
            self.scaler = None
            self.tabular_features = tabular_data

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Only the file name of `image_path` is used; it is resolved inside img_dir.
        img_path = f"{self.img_dir}/{row['image_path'].split('/')[-1]}"
        image = cv2.imread(img_path)

        # cv2.imread returns None instead of raising when the file cannot be read.
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        tabular = torch.tensor(self.tabular_features[idx], dtype=torch.float32)

        if self.is_test:
            return image, tabular

        targets = torch.tensor(
            [row[column] for column in self.TARGET_COLUMNS],
            dtype=torch.float32,
        )
        return image, tabular, targets

# ============================================================================
# AUGMENTATION STRATEGIES
# ============================================================================
def get_train_transforms():
    """
    Build the training-time augmentation pipeline.

    The pipeline resizes to CFG.img_size, applies random geometric changes,
    then (with some probability) one colour change and one quality
    degradation, followed by coarse dropout, normalisation and tensor
    conversion.
    """
    side = CFG.img_size

    # Geometry: resize plus random flips, quarter turns and small affine jitter.
    geometric = [
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
    ]

    # Colour: exactly one of these is applied, 70% of the time.
    colour = A.OneOf(
        [
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1),
        ],
        p=0.7,
    )

    # Quality degradation: noise or blur, 30% of the time.
    degradation = A.OneOf(
        [
            A.GaussNoise(var_limit=(10.0, 50.0), p=1),
            A.GaussianBlur(blur_limit=(3, 7), p=1),
            A.MotionBlur(blur_limit=5, p=1),
        ],
        p=0.3,
    )

    # Final steps: random patch dropout, channel normalisation, HWC -> tensor.
    finishing = [
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    return A.Compose([*geometric, colour, degradation, *finishing])

def get_valid_transforms():
    """Build the deterministic validation pipeline: resize, normalise, convert to tensor."""
    side = CFG.img_size
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        A.Resize(side, side),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================
class BiomassModel(nn.Module):
    """
    Multi-modal regressor combining image and tabular inputs.

    Components:
        1. A timm backbone producing a pooled image feature vector.
        2. An MLP encoding the two tabular features (NDVI, height) to 128 dims.
        3. A fusion MLP over the concatenated features (-> 512 -> 256).
        4. Five linear heads, one per target.

    Args:
        model_name: timm model identifier for the image backbone.
        pretrained: Whether timm should load pretrained weights.
    """
    def __init__(self, model_name, pretrained=True):
        super().__init__()

        # Image encoder with the classification head removed.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg'
        )

        # Probe the backbone's output width with a dummy batch. The probe runs in
        # eval mode: torch.no_grad() alone does not stop BatchNorm layers from
        # updating their running statistics, so a train-mode probe would nudge
        # the pretrained statistics. A constant input also leaves the global RNG
        # state untouched.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, CFG.img_size, CFG.img_size)
            img_features = self.backbone(dummy_input).shape[1]
        self.backbone.train(was_training)

        # Tabular feature encoder (for NDVI and Height): 2 -> 64 -> 128
        self.tabular_encoder = nn.Sequential(
            nn.Linear(2, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Fusion layer over concatenated image + tabular features
        fusion_dim = img_features + 128
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # One linear head per target. Attribute names are kept so that existing
        # checkpoints (state_dict keys) still load.
        self.head_green = nn.Linear(256, 1)
        self.head_dead = nn.Linear(256, 1)
        self.head_clover = nn.Linear(256, 1)
        self.head_gdm = nn.Linear(256, 1)
        self.head_total = nn.Linear(256, 1)

    def forward(self, image, tabular):
        """
        Predict all five targets.

        Args:
            image: Image batch accepted by the backbone.
            tabular: Batch of the two tabular features, shape [batch_size, 2].

        Returns:
            Tensor of shape [batch_size, 5], columns ordered
            green, dead, clover, gdm, total.
        """
        img_features = self.backbone(image)
        tab_features = self.tabular_encoder(tabular)

        fused = self.fusion(torch.cat([img_features, tab_features], dim=1))

        heads = (
            self.head_green,
            self.head_dead,
            self.head_clover,
            self.head_gdm,
            self.head_total,
        )
        return torch.cat([head(fused) for head in heads], dim=1)

# ============================================================================
# LOSS FUNCTION
# ============================================================================
class WeightedMSELoss(nn.Module):
    """
    Mean squared error with a fixed weight per target column.

    The squared error of every element is multiplied by the weight of its
    column and the mean over all elements is returned.

    Args:
        weights: Sequence with one weight per target column.
    """
    def __init__(self, weights):
        super().__init__()
        # Registered as a buffer so the weights follow `.to(device)` / `.cuda()`
        # with the module; non-persistent so the state_dict stays empty.
        self.register_buffer(
            'weights',
            torch.tensor(weights, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, predictions, targets):
        """Return the weighted MSE between `predictions` and `targets` (same shape, [batch, n_targets])."""
        # Callers may leave the criterion on the CPU; move a local copy rather
        # than mutating module state on every call.
        weights = self.weights.to(predictions.device)

        squared_error = (predictions - targets) ** 2
        return (squared_error * weights.unsqueeze(0)).mean()

# ============================================================================
# METRIC CALCULATION (R² Score)
# ============================================================================
def calculate_r2_score(y_true, y_pred):
    """
    Compute the coefficient of determination, R^2 = 1 - SS_res / SS_tot.

    Args:
        y_true: 1-D array of ground-truth values.
        y_pred: 1-D array of predictions, same shape as `y_true`.

    Returns:
        The R^2 score, or 0.0 when `y_true` is constant (SS_tot == 0), which
        avoids a division by zero.
    """
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)

    if ss_tot == 0:
        return 0.0

    return 1 - (ss_res / ss_tot)


def calculate_weighted_r2(y_true, y_pred, weights):
    """
    Compute per-target R^2 scores and their weighted sum.

    Args:
        y_true: 2-D array [n_samples, n_targets] of ground-truth values.
        y_pred: 2-D array of predictions with the same shape.
        weights: One weight per target column.

    Returns:
        Tuple `(weighted_score, scores)` where `scores` lists the R^2 of each
        column and `weighted_score` is the sum of score * weight.

    Raises:
        ValueError: If the number of weights differs from the number of target
            columns (a mismatch would otherwise be truncated silently).
    """
    n_targets = y_true.shape[1]
    if len(weights) != n_targets:
        raise ValueError(
            f"Expected {n_targets} weights (one per target column), got {len(weights)}"
        )

    scores = [
        calculate_r2_score(y_true[:, col], y_pred[:, col])
        for col in range(n_targets)
    ]

    weighted_score = sum(s * w for s, w in zip(scores, weights))
    return weighted_score, scores

# ============================================================================
# TRAINING FUNCTION
# ============================================================================
def train_epoch(model, loader, optimizer, criterion, device, scaler):
    """
    Run one training epoch with mixed-precision autocast.

    Args:
        model: Network called as `model(images, tabular)`.
        loader: Iterable yielding `(images, tabular, targets)` batches.
        optimizer: Optimizer updating `model`'s parameters.
        criterion: Loss called as `criterion(outputs, targets)`.
        device: Device the batches are moved to.
        scaler: Gradient scaler used for the scaled backward pass and the
            optimizer step.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    running_loss = 0.0

    pbar = tqdm(loader, desc='Training')
    for batch_idx, (images, tabular, targets) in enumerate(pbar):
        images = images.to(device)
        tabular = tabular.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Mixed precision forward pass
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images, tabular)
            loss = criterion(outputs, targets)

        # Scale the loss before backward, then step and update through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        # Divide by the number of batches processed so far. `pbar.n` is not
        # reliable for this: tqdm only refreshes it at display intervals, so it
        # can lag behind the real batch count.
        pbar.set_postfix({'loss': running_loss / (batch_idx + 1)})

        # Debug aid: show one prediction/target pair from the first batch
        if batch_idx == 0:
            print(f"\n  Sample predictions: {outputs[0].detach().cpu().numpy()}")
            print(f"  Sample targets:     {targets[0].cpu().numpy()}")

    return running_loss / len(loader)

# ============================================================================
# VALIDATION FUNCTION
# ============================================================================
def validate_epoch(model, loader, criterion, device):
    """
    Evaluate the model on a validation loader.

    Returns:
        Tuple `(mean_loss, weighted_r2, individual_r2, all_preds, all_targets)`
        where the last two are the stacked predictions and targets as NumPy
        arrays, and the R^2 values use CFG.target_weights.
    """
    model.eval()
    loss_sum = 0.0
    pred_chunks, target_chunks = [], []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, tabular, targets in tqdm(loader, desc='Validation'):
            images, tabular, targets = (
                images.to(device),
                tabular.to(device),
                targets.to(device),
            )

            outputs = model(images, tabular)
            loss_sum += criterion(outputs, targets).item()

            pred_chunks.append(outputs.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    # Stack the per-batch arrays into [n_samples, n_targets].
    all_preds = np.vstack(pred_chunks)
    all_targets = np.vstack(target_chunks)

    weighted_r2, individual_r2 = calculate_weighted_r2(
        all_targets, all_preds, CFG.target_weights
    )

    mean_loss = loss_sum / len(loader)
    return mean_loss, weighted_r2, individual_r2, all_preds, all_targets

# ============================================================================
# TRAINING LOOP (K-FOLD CROSS-VALIDATION)
# ============================================================================
def _warmup_cosine_factor(epoch, warmup_epochs, base_lr, t_0=10, t_mult=2, eta_min=1e-6):
    """Return the LR multiplier for `epoch`: linear warmup, then cosine annealing with warm restarts."""
    import math

    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs

    # Locate the epoch inside its cosine cycle; cycles last t_0, t_0*t_mult, ... epochs.
    t_cur = epoch - warmup_epochs
    t_i = t_0
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult

    # eta_min is an absolute learning rate; express it relative to the base rate.
    floor = eta_min / base_lr if base_lr > 0 else 0.0
    return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * t_cur / t_i))


def train_kfold(df, fold):
    """
    Train and validate the model on one cross-validation fold.

    Args:
        df: Wide-format DataFrame with a `fold` column, the metadata columns
            and every column named in CFG.targets.
        fold: Index of the fold held out for validation.

    Returns:
        Best weighted R^2 reached on the validation fold (computed on the
        original target scale when CFG.use_target_scaling is enabled).

    Side effects:
        Writes `best_model_fold{fold}.pth` (model weights plus both scalers)
        every time the validation score improves.
    """
    print(f"\n{'='*50}")
    print(f"Training Fold {fold + 1}/{CFG.n_folds}")
    print(f"{'='*50}")

    # Split data
    train_df = df[df['fold'] != fold].copy()
    valid_df = df[df['fold'] == fold].copy()

    print(f"Train size: {len(train_df)}, Valid size: {len(valid_df)}")

    print(f"\nTarget statistics (training set):")
    for target in CFG.targets:
        print(f"  {target}: mean={train_df[target].mean():.2f}, std={train_df[target].std():.2f}, "
              f"min={train_df[target].min():.2f}, max={train_df[target].max():.2f}")

    # Standardise the targets with statistics from the training split only.
    target_scaler = None
    if CFG.use_target_scaling:
        target_scaler = StandardScaler()
        target_scaler.fit(train_df[CFG.targets].values)

        train_df[CFG.targets] = target_scaler.transform(train_df[CFG.targets].values)
        valid_df[CFG.targets] = target_scaler.transform(valid_df[CFG.targets].values)
        print("\n‚úì Targets scaled to zero mean and unit variance")

    # The validation set reuses the tabular scaler fitted on the training set.
    train_dataset = BiomassDataset(
        train_df, CFG.train_dir, transform=get_train_transforms()
    )
    valid_dataset = BiomassDataset(
        valid_df, CFG.train_dir, transform=get_valid_transforms(),
        scaler=train_dataset.scaler
    )

    # A trailing batch holding a single sample cannot pass through BatchNorm1d
    # in training mode, so it is dropped in exactly that case.
    drop_single_tail = len(train_dataset) % CFG.batch_size == 1
    train_loader = DataLoader(
        train_dataset, batch_size=CFG.batch_size,
        shuffle=True, num_workers=CFG.num_workers, pin_memory=True,
        drop_last=drop_single_tail
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=CFG.batch_size * 2,
        shuffle=False, num_workers=CFG.num_workers, pin_memory=True
    )

    # Initialize model, loss, optimizer
    model = BiomassModel(CFG.model_name, CFG.pretrained).to(CFG.device)
    criterion = WeightedMSELoss(CFG.target_weights)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    # One scheduler drives both warmup and cosine restarts (T_0=10, T_mult=2,
    # eta_min=1e-6). Attaching two separate schedulers to the same optimizer
    # does not work: each constructor sets the learning rate, so the one built
    # last overwrites the warmup value before the first epoch.
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: _warmup_cosine_factor(epoch, CFG.warmup_epochs, CFG.lr),
    )

    scaler = torch.cuda.amp.GradScaler()

    best_score = -np.inf
    patience_counter = 0
    patience = 13  # epochs without improvement tolerated before stopping

    for epoch in range(CFG.epochs):
        print(f"\nEpoch {epoch + 1}/{CFG.epochs}")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, criterion, CFG.device, scaler)

        # Validate
        valid_loss, weighted_r2, individual_r2, all_preds, all_targets = validate_epoch(
            model, valid_loader, criterion, CFG.device
        )

        if target_scaler is not None:
            # Map predictions and targets back to the original units.
            all_preds_original = target_scaler.inverse_transform(all_preds)
            all_targets_original = target_scaler.inverse_transform(all_targets)

            weighted_r2_original, individual_r2_original = calculate_weighted_r2(
                all_targets_original, all_preds_original, CFG.target_weights
            )

            print(f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
            print(f"Weighted R¬≤ (scaled): {weighted_r2:.4f}")
            print(f"Weighted R¬≤ (original): {weighted_r2_original:.4f}")
            print(f"Individual R¬≤ (original): {individual_r2_original}")

            # Model selection uses the original-scale score.
            score_to_use = weighted_r2_original
        else:
            print(f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
            print(f"Weighted R¬≤: {weighted_r2:.4f}")
            print(f"Individual R¬≤: {individual_r2}")
            score_to_use = weighted_r2

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

        if score_to_use > best_score:
            best_score = score_to_use
            # The scalers are stored with the weights so inference can reproduce
            # the exact preprocessing.
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'tabular_scaler': train_dataset.scaler,
                'target_scaler': target_scaler
            }
            torch.save(checkpoint, f'best_model_fold{fold}.pth')
            print(f"‚úì Saved best model (R¬≤: {best_score:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    return best_score


# ============================================================================
# MAIN EXECUTION
# ============================================================================
def main():
    """
    Run the full training pipeline.

    Steps: load the training CSV (pivoting it when it is in long format),
    assign stratified cross-validation folds on `biomass_bin`, train every
    fold with `train_kfold` and print the per-fold and mean scores.
    """
    # 1. Prepare data
    print("Loading and preparing data...")
    print(f"Reading from: {CFG.train_csv}")

    # Inspect the raw CSV first to decide whether it needs pivoting.
    df_raw = pd.read_csv(CFG.train_csv)
    print(f"\nRaw data shape: {df_raw.shape}")
    print(f"Columns: {df_raw.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_raw.head(10))

    if 'target_name' in df_raw.columns and 'target' in df_raw.columns:
        print("\n‚úì Data is in long format, will pivot...")
        df = prepare_data(CFG.train_csv)
    else:
        print("\n‚úì Data appears to be in wide format already")
        df = df_raw.copy()
        # Quantile bins first; equal-width bins when quantile binning fails.
        # Only ValueError is caught so unrelated errors are not hidden.
        try:
            df['biomass_bin'] = pd.qcut(df['Dry_Total_g'], q=10, labels=False, duplicates='drop')
        except ValueError:
            df['biomass_bin'] = pd.cut(df['Dry_Total_g'], bins=10, labels=False)
        df['biomass_bin'] = df['biomass_bin'].fillna(0).astype(int)

    print(f"\n‚úì Final processed data shape: {df.shape}")
    print(f"‚úì Checking for NaN values in targets:")
    for target in CFG.targets:
        nan_count = df[target].isna().sum()
        print(f"  {target}: {nan_count} NaN values")

    # 2. Create folds
    print("\nCreating cross-validation folds...")
    skf = StratifiedKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed)
    df['fold'] = -1

    # Ensure no NaN in biomass_bin before splitting
    assert df['biomass_bin'].isna().sum() == 0, "NaN found in biomass_bin!"

    # The splitter yields positional indices, so write them positionally; this
    # stays correct even if the frame's index is not 0..n-1.
    fold_col = df.columns.get_loc('fold')
    for fold, (_, val_idx) in enumerate(skf.split(df, df['biomass_bin'])):
        df.iloc[val_idx, fold_col] = fold

    print("‚úì Fold distribution:")
    print(df['fold'].value_counts().sort_index())

    # 3. Train all folds
    fold_scores = [train_kfold(df, fold) for fold in range(CFG.n_folds)]

    print(f"\n{'='*50}")
    print(f"Cross-Validation Results:")
    print(f"{'='*50}")
    for i, score in enumerate(fold_scores):
        print(f"Fold {i+1}: {score:.4f}")
    print(f"Mean CV Score: {np.mean(fold_scores):.4f} ¬± {np.std(fold_scores):.4f}")


if __name__ == '__main__':
    main()

# CSIRO Image2Biomass Prediction - Inference Pipeline Documentation

## 🎯 Inference Pipeline Overview

This pipeline performs **ensemble inference** using multiple trained models to predict 5 biomass components from pasture images. The system combines predictions from cross-validation folds with optional Test-Time Augmentation (TTA) for robust performance.

## 🏗️ Inference Architecture

### Multi-Model Ensemble Strategy
```python
# Ensemble Configuration
n_folds = 5                    # Number of trained models
use_tta = True                 # Test-Time Augmentation
tta_steps = 5                  # Number of augmentation variations
batch_size = 32               # Larger batches for inference speed
```

### Prediction Flow
1. **Load Test Data** → 2. **Load Model Checkpoints** → 3. **Generate Fold Predictions** → 4. **Ensemble Averaging** → 5. **Create Submission**

## 📊 Data Processing for Inference

### Test Data Structure Handling
```python
# Test data format: 5 rows per image (one per target)
# Required transformation: Extract unique images for prediction

test_df_unique = test_df.drop_duplicates(subset=['image_path'])
# Input: 5 rows per image → Output: 1 row per image
```

### Metadata Handling Strategy
```python
# Smart metadata detection
has_metadata = 'Pre_GSHH_NDVI' in test_df.columns and 'Height_Ave_cm' in test_df.columns

if has_metadata:
    # Use actual environmental data
    tabular_data = df[['Pre_GSHH_NDVI', 'Height_Ave_cm']].fillna(0)
else:
    # Fallback: Use zeros (scaled mean from training)
    tabular_data = np.zeros((len(df), 2))
    print("⚠ No metadata - using scaled mean values")
```

## 🎨 Test-Time Augmentation (TTA) Strategy

### TTA Transform Variations
```python
tta_transforms = [
    # 1. Original (no augmentation) - Baseline
    A.Compose([Resize, Normalize, ToTensor]),
    
    # 2. Horizontal Flip - Mirror image
    A.Compose([Resize, HorizontalFlip(p=1.0), Normalize, ToTensor]),
    
    # 3. Vertical Flip - Upside-down
    A.Compose([Resize, VerticalFlip(p=1.0), Normalize, ToTensor]),
    
    # 4. 90° Rotation - Different perspective
    A.Compose([Resize, Rotate(limit=(90, 90), p=1.0), Normalize, ToTensor]),
    
    # 5. Brightness/Contrast - Lighting variations
    A.Compose([Resize, RandomBrightnessContrast(brightness_limit=0.1), Normalize, ToTensor]),
]
```

### TTA Prediction Averaging
```python
# For each TTA variation:
tta_predictions = []
for tta_transform in tta_transforms:
    preds = model_predict_with_transform(tta_transform)
    tta_predictions.append(preds)

# Ensemble average
final_predictions = np.mean(tta_predictions, axis=0)
```

**Benefits of TTA:**
- ✅ **Improved robustness** to image variations
- ✅ **Better generalization** without retraining
- ✅ **Reduced overfitting** to specific image orientations
- ✅ **No additional training cost**

## 🔧 Model Loading & Compatibility

### Checkpoint Loading System
```python
def load_model_checkpoint(fold):
    """
    Load the checkpoint saved for one fold.

    Args:
        fold: Fold index used in the file name 'best_model_fold{fold}.pth'.

    Returns:
        Tuple (model_state, tabular_scaler, target_scaler). Both scalers are
        None when the file holds a bare state dict.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. It is needed to read the scaler objects stored in the
    # checkpoint, so only load checkpoint files from a trusted source.
    checkpoint = torch.load(f'best_model_fold{fold}.pth', 
                          map_location=CFG.device, 
                          weights_only=False)
    
    # Handle different checkpoint formats
    if 'model_state_dict' in checkpoint:
        # New format with scalers
        model_state = checkpoint['model_state_dict']
        tabular_scaler = checkpoint['tabular_scaler']
        target_scaler = checkpoint['target_scaler']
    else:
        # Legacy format support: the file is the state dict itself
        model_state = checkpoint
        tabular_scaler, target_scaler = None, None
    
    return model_state, tabular_scaler, target_scaler
```

### Critical Compatibility Checks
- ✅ **Model architecture** matches training exactly
- ✅ **Image size** (512×512) consistent
- ✅ **Feature dimensions** aligned
- ✅ **Scaler objects** preserved for consistent preprocessing

## ⚙️ Inference Optimization

### Memory Management
```python
# Batch size optimization
batch_size = 32  # Larger than training (no gradient computation)

# GPU memory cleanup
del model
torch.cuda.empty_cache()  # After each fold prediction
```

### Parallel Processing
```python
DataLoader(
    dataset, 
    batch_size=CFG.batch_size,
    shuffle=False,           # No need to shuffle for inference
    num_workers=CFG.num_workers,  # Parallel data loading
    pin_memory=True          # Faster GPU transfer
)
```

## 📈 Ensemble Strategy

### Fold-Level Ensemble
```python
# Collect predictions from all trained folds
all_fold_predictions = []
for fold in range(n_folds):
    fold_preds = predict_with_tta(model_fold, dataset)
    all_fold_predictions.append(fold_preds)

# Simple averaging ensemble
final_predictions = np.mean(all_fold_predictions, axis=0)
```

**Ensemble Benefits:**
- ✅ **Reduces variance** from individual model randomness
- ✅ **Improves generalization** across different data splits
- ✅ **More stable predictions** for competition evaluation
- ✅ **Leverages full training data** via cross-validation

## 📋 Submission Format Generation

### Output Transformation
```python
# Convert from wide to long format
# Input: 1 row per image with 5 predictions
# Output: 5 rows per image (competition format)

for image_idx, image_row in test_df_unique.iterrows():
    image_id = extract_image_id(image_row['image_path'])
    
    for target_idx, target_name in enumerate(CFG.targets):
        sample_id = f"{image_id}__{target_name}"
        prediction = final_predictions[image_idx, target_idx]
        
        submission_rows.append({
            'sample_id': sample_id,
            'target': max(0.0, prediction)  # Ensure non-negative
        })
```

### Submission Validation
```python
# Critical checks before submission
expected_rows = n_images * 5
actual_rows = len(submission_df)

assert actual_rows == expected_rows, "Row count mismatch"
assert submission_df['target'].isna().sum() == 0, "NaN values detected"
assert (submission_df['target'] >= 0).all(), "Negative values found"
```

## 🚀 Performance Optimizations

### 1. **Efficient Image Loading**
```python
# OpenCV for fast image reading
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Model expects RGB
```

### 2. **Vectorized Operations**
```python
# Batch prediction instead of single image
# Reduces GPU overhead and improves throughput
outputs = model(batch_images, batch_tabular)  # [32, 5] predictions
```

### 3. **Memory-Efficient TTA**
```python
# Reuse tabular features across TTA variations
# Avoids redundant scaling operations
dataset_tta.tabular_features = original_dataset.tabular_features
```

## 🔍 Error Handling & Robustness

### Graceful Fallbacks
```python
# Metadata availability check
if not has_metadata:
    print("⚠ Using zero values for missing metadata")
    # System continues with reasonable defaults

# Checkpoint availability
available_folds = [f for f in range(n_folds) if checkpoint_exists(f)]
if not available_folds:
    raise FileNotFoundError("No models available for inference")
```

### Prediction Sanitization
```python
# Ensure physically plausible predictions
prediction = max(0.0, raw_prediction)  # Biomass cannot be negative

# Handle extreme outliers (optional)
if prediction > reasonable_threshold:
    prediction = reasonable_threshold
```

## 📊 Output Quality Assurance

### Prediction Statistics
```python
# Monitor prediction distribution
for target_idx, target_name in enumerate(CFG.targets):
    preds = final_predictions[:, target_idx]
    print(f"{target_name}: min={preds.min():.2f}, max={preds.max():.2f}, "
          f"mean={preds.mean():.2f}, std={preds.std():.2f}")
```

### Validation Against Training
- ✅ **Prediction ranges** similar to training data
- ✅ **No extreme outliers** in biomass estimates
- ✅ **Consistent relationships** between target variables
- ✅ **Physically plausible** values (non-negative, reasonable magnitudes)

## 🎯 Key Advantages of This Inference Pipeline

### 1. **Robustness**
- Handles missing metadata gracefully
- TTA reduces orientation/lighting sensitivity
- Ensemble averaging stabilizes predictions

### 2. **Efficiency**
- Batch processing for speed
- Memory management for large datasets
- Parallel data loading

### 3. **Reproducibility**
- Exact model architecture matching
- Consistent preprocessing pipelines
- Deterministic operations (where possible)

### 4. **Competition-Ready**
- Correct submission format generation
- Comprehensive validation checks
- Error handling for production deployment

This inference pipeline represents a production-grade system that transforms trained models into reliable predictions, incorporating best practices for robustness, efficiency, and competition success.

In [None]:
# ============================================================================
# CSIRO Image2Biomass Prediction - INFERENCE ONLY
# ============================================================================

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from tqdm import tqdm
import warnings
import os
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================
class CFG:
    """Static settings for the inference run: paths, model, TTA and targets."""

    # --- Input / output locations ---
    test_csv = "/kaggle/input/csiro-biomass/test.csv"
    test_dir = "/kaggle/input/csiro-biomass/test"
    # Folder that is searched for the per-fold checkpoint files
    model_dir = "/kaggle/input/csiro-models"
    output_file = "submission.csv"

    # --- Model definition (has to mirror the training configuration) ---
    model_name = "tf_efficientnetv2_m"
    img_size = 512
    # How many fold checkpoints are looked up for the ensemble
    n_folds = 5

    # --- Inference behaviour ---
    # No gradients are kept at inference, so this may exceed the training batch size
    batch_size = 32
    num_workers = 4
    # Switch for Test-Time Augmentation and the number of TTA pipelines used
    use_tta = True
    tta_steps = 5

    # --- Model output columns; the order is significant ---
    targets = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    # --- Compute device: CUDA when torch reports it available, CPU otherwise ---
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    
# Echo the effective run settings before any heavy work starts
print(f"Using device: {CFG.device}")
print(f"Will ensemble {CFG.n_folds} models")
if CFG.use_tta:
    print(f"TTA enabled: {CFG.use_tta} ({CFG.tta_steps} augmentations)")
else:
    print("TTA disabled")

# ============================================================================
# DATASET CLASS
# ============================================================================
class BiomassTestDataset(Dataset):
    """
    Inference-time dataset yielding ``(image, tabular)`` pairs.

    Args:
        df: Frame with an ``image_path`` column and, when ``has_metadata`` is
            true, ``Pre_GSHH_NDVI`` and ``Height_Ave_cm`` columns.
        img_dir: Directory the image files are read from; only the file name
            part of ``image_path`` is used.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``'image'``.
        tabular_scaler: Optional object with a ``transform`` method applied to
            the two metadata columns after missing values are filled with 0.
        has_metadata: When false, the tabular features are all zeros.
    """
    def __init__(self, df, img_dir, transform=None, tabular_scaler=None, has_metadata=True):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.has_metadata = has_metadata

        if not has_metadata:
            # Zeros stand in for the feature mean, on the assumption that the
            # training-time scaler centres features at 0 (not verifiable here).
            print("  ‚ö† No metadata in test set - using zero values (scaled mean)")
            self.tabular_features = np.zeros((len(df), 2), dtype=np.float32)
            return

        # Metadata comes straight from the CSV; gaps are filled with 0 before scaling
        raw_features = df[['Pre_GSHH_NDVI', 'Height_Ave_cm']].fillna(0).values
        if tabular_scaler is None:
            self.tabular_features = raw_features
        else:
            self.tabular_features = tabular_scaler.transform(raw_features)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Only the file name is kept; the directory part of image_path is dropped
        file_name = record['image_path'].split('/')[-1]
        img_path = f"{self.img_dir}/{file_name}"

        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals an unreadable file by returning None
            raise ValueError(f"Failed to load image: {img_path}")

        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        features = torch.tensor(self.tabular_features[idx], dtype=torch.float32)
        return image, features

# ============================================================================
# AUGMENTATION TRANSFORMS
# ============================================================================
def get_inference_transforms():
    """Build the plain preprocessing pipeline (resize, normalise, to tensor)."""
    side = CFG.img_size
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_tta_transforms():
    """
    Build the list of Test-Time Augmentation pipelines.

    Returns:
        A list of five ``A.Compose`` pipelines, in this order: unmodified,
        horizontal flip, vertical flip, 90-degree rotation, and a slight
        brightness/contrast adjustment.
    """
    def _pipeline(*extra_steps):
        # Every view shares the same frame: resize -> augmentation -> normalise -> tensor
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),
            *extra_steps,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    return [
        # 1. Original (no augmentation)
        _pipeline(),
        # 2. Horizontal flip
        _pipeline(A.HorizontalFlip(p=1.0)),
        # 3. Vertical flip
        _pipeline(A.VerticalFlip(p=1.0)),
        # 4. Rotate 90 degrees (both ends of the limit range are 90)
        _pipeline(A.Rotate(limit=(90, 90), p=1.0)),
        # 5. Slight brightness adjustment
        # NOTE(review): the limits look like sampling ranges, so this view is
        # presumably not deterministic between runs - confirm against the
        # albumentations documentation.
        _pipeline(A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0)),
    ]

# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================
class BiomassModel(nn.Module):
    """
    Multi-modal regressor combining an image backbone with a tabular MLP.

    The module and attribute names define the state-dict keys, so they MUST
    match the architecture used during training.

    Args:
        model_name: timm model identifier used for the image backbone.
        pretrained: Forwarded to ``timm.create_model``.
    """
    def __init__(self, model_name, pretrained=False):
        super().__init__()

        # Image encoder: classifier removed (num_classes=0), average pooling kept
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg'
        )

        # Probe the backbone once to learn the width of its pooled features.
        # The probe runs in eval mode: a freshly built module is in training
        # mode, and a forward pass in that mode would let any BatchNorm layers
        # update their running statistics from this random input. The previous
        # mode is restored afterwards so callers see the usual default.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, CFG.img_size, CFG.img_size)
            img_features = self.backbone(dummy_input).shape[1]
        self.backbone.train(was_training)

        # Tabular feature encoder: 2 -> 64 -> 128
        self.tabular_encoder = nn.Sequential(
            nn.Linear(2, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Fusion layer over the concatenated image and tabular features
        fusion_dim = img_features + 128
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # One linear output head per target (5 targets)
        self.head_green = nn.Linear(256, 1)
        self.head_dead = nn.Linear(256, 1)
        self.head_clover = nn.Linear(256, 1)
        self.head_gdm = nn.Linear(256, 1)
        self.head_total = nn.Linear(256, 1)

    def forward(self, image, tabular):
        """
        Predict the five targets for a batch.

        Args:
            image: Image batch fed to the backbone.
            tabular: Batch of two tabular features per sample.

        Returns:
            Tensor with one column per head, ordered green, dead, clover,
            gdm, total.
        """
        # Extract features from both modalities
        img_features = self.backbone(image)
        tab_features = self.tabular_encoder(tabular)

        # Fuse along the feature dimension
        combined = torch.cat([img_features, tab_features], dim=1)
        fused = self.fusion(combined)

        # Run every head on the shared representation; the tuple order fixes the column order
        heads = (
            self.head_green,
            self.head_dead,
            self.head_clover,
            self.head_gdm,
            self.head_total,
        )
        return torch.cat([head(fused) for head in heads], dim=1)

# ============================================================================
# INFERENCE FUNCTIONS
# ============================================================================
def predict_single_tta(model, dataset, device, tta_transform, has_metadata):
    """
    Run the model over the whole dataset under one TTA pipeline.

    Args:
        model: Network called as ``model(images, tabular)``.
        dataset: Source dataset; its ``df``, ``img_dir`` and
            ``tabular_features`` are reused.
        device: Device the batches are moved to.
        tta_transform: Image pipeline applied for this pass.
        has_metadata: Forwarded to the dataset constructor.

    Returns:
        Array of all batch outputs stacked row-wise, in dataset order.
    """
    # Same rows and image folder as the source dataset, different image pipeline
    view = BiomassTestDataset(
        dataset.df,
        dataset.img_dir,
        transform=tta_transform,
        tabular_scaler=None,  # scaling was already done by the source dataset
        has_metadata=has_metadata
    )
    # Share the source dataset's tabular features instead of the recomputed ones
    view.tabular_features = dataset.tabular_features

    batches = DataLoader(
        view,
        batch_size=CFG.batch_size,
        shuffle=False,  # keep row order so outputs line up with dataset.df
        num_workers=CFG.num_workers,
        pin_memory=True
    )

    model.eval()
    with torch.no_grad():
        chunks = [
            model(images.to(device), tabular.to(device)).cpu().numpy()
            for images, tabular in batches
        ]

    return np.vstack(chunks)

def predict_with_tta(model, dataset, device, has_metadata):
    """
    Predict with every selected TTA pipeline and average the results.

    With ``CFG.use_tta`` disabled a single pass with the plain inference
    pipeline is made; otherwise the first ``CFG.tta_steps`` TTA pipelines
    are used.

    Returns:
        Element-wise mean of the per-pipeline prediction arrays.
    """
    if CFG.use_tta:
        pipelines = get_tta_transforms()[:CFG.tta_steps]
    else:
        pipelines = [get_inference_transforms()]

    total = len(pipelines)
    print(f"  Making predictions with {total} TTA variations...")

    per_view = []
    for step, pipeline in enumerate(pipelines, start=1):
        per_view.append(predict_single_tta(model, dataset, device, pipeline, has_metadata))
        print(f"    TTA {step}/{total} complete")

    # Collapse the pipeline axis
    return np.mean(per_view, axis=0)

def load_model_checkpoint(fold):
    """
    Read the checkpoint of one fold from ``CFG.model_dir``.

    Args:
        fold: Fold number used in the file name ``best_model_fold{fold}.pth``.

    Returns:
        Tuple ``(model_state, tabular_scaler, target_scaler)``. Both scalers
        are ``None`` when the checkpoint does not carry them.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint_path = os.path.join(CFG.model_dir, f'best_model_fold{fold}.pth')
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found: {checkpoint_path}")

    # weights_only=False permits arbitrary pickled objects (needed for scalers
    # stored in the file); only use this with checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=CFG.device, weights_only=False)

    is_bundle = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    if not is_bundle:
        # Older layout: the file holds the model weights and nothing else
        print(f"  Warning: Fold {fold} checkpoint doesn't contain scalers")
        return checkpoint, None, None

    # Newer layout: weights plus optional scalers under named keys
    return (
        checkpoint['model_state_dict'],
        checkpoint.get('tabular_scaler', None),
        checkpoint.get('target_scaler', None),
    )

# ============================================================================
# MAIN INFERENCE PIPELINE
# ============================================================================
def main():
    """
    Run the full inference pipeline and write the submission CSV.

    Steps: load the test CSV and reduce it to unique images, find the fold
    checkpoints that exist, predict with each available fold (optionally with
    TTA), average the folds, and write one ``sample_id``/``target`` row per
    image and target to ``CFG.output_file``.

    Raises:
        FileNotFoundError: if none of the expected fold checkpoints exist.
    """
    print("="*70)
    print("CSIRO BIOMASS PREDICTION - INFERENCE")
    print("="*70)
    
    # 1. Load test data
    print("\n[1/5] Loading test data...")
    test_df = pd.read_csv(CFG.test_csv)
    print(f"‚úì Loaded {len(test_df)} test samples")
    print(f"‚úì Columns in test data: {test_df.columns.tolist()}")
    
    # Keep one row per image (test.csv is expected to repeat each image once per target)
    test_df_unique = test_df.drop_duplicates(subset=['image_path']).reset_index(drop=True)
    print(f"‚úì Found {len(test_df_unique)} unique images")
    
    # Check if metadata features are available
    has_metadata = 'Pre_GSHH_NDVI' in test_df_unique.columns and 'Height_Ave_cm' in test_df_unique.columns
    
    if has_metadata:
        print(f"‚úì Test data has metadata features (NDVI, Height)")
    else:
        print(f"‚ö† Test data does NOT have metadata features")
        print(f"  Will use zero values (scaled mean) for tabular features")
        # Add dummy columns so dataset creation doesn't fail
        test_df_unique['Pre_GSHH_NDVI'] = 0.0
        test_df_unique['Height_Ave_cm'] = 0.0
    
    # 2. Verify model checkpoints exist
    print("\n[2/5] Checking model checkpoints...")
    available_folds = []
    for fold in range(CFG.n_folds):
        checkpoint_path = os.path.join(CFG.model_dir, f'best_model_fold{fold}.pth')
        if os.path.exists(checkpoint_path):
            available_folds.append(fold)
            print(f"‚úì Found checkpoint for fold {fold}")
        else:
            print(f"‚úó Missing checkpoint for fold {fold}")
    
    if not available_folds:
        raise FileNotFoundError("No model checkpoints found! Train models first.")
    
    print(f"\n‚úì Will use {len(available_folds)} models for ensemble")
    
    # 3. Generate predictions from each fold
    print("\n[3/5] Generating predictions...")
    all_fold_predictions = []
    
    for fold in available_folds:
        print(f"\n--- Fold {fold + 1}/{CFG.n_folds} ---")
        
        # Load checkpoint
        model_state, tabular_scaler, target_scaler = load_model_checkpoint(fold)
        
        # Create model and load weights
        model = BiomassModel(CFG.model_name, pretrained=False).to(CFG.device)
        model.load_state_dict(model_state)
        model.eval()
        print(f"‚úì Model loaded successfully")
        
        # Create dataset; tabular features are scaled with this fold's scaler
        test_dataset = BiomassTestDataset(
            test_df_unique,
            CFG.test_dir,
            transform=get_inference_transforms(),
            tabular_scaler=tabular_scaler,
            has_metadata=has_metadata
        )
        
        # Make predictions with TTA
        fold_predictions = predict_with_tta(model, test_dataset, CFG.device, has_metadata)
        
        # Inverse transform if target scaler exists
        if target_scaler is not None:
            fold_predictions = target_scaler.inverse_transform(fold_predictions)
            print(f"‚úì Predictions scaled back to original range")
        
        all_fold_predictions.append(fold_predictions)
        print(f"‚úì Fold {fold} predictions: shape {fold_predictions.shape}")
        
        # Print sample predictions
        print(f"  Sample prediction: {fold_predictions[0]}")
        
        # Release the model before the next fold is loaded
        del model
        torch.cuda.empty_cache()
    
    # 4. Ensemble predictions
    print("\n[4/5] Ensembling predictions from all folds...")
    final_predictions = np.mean(all_fold_predictions, axis=0)
    print(f"‚úì Final predictions shape: {final_predictions.shape}")
    print(f"‚úì Prediction statistics:")
    for idx, target_name in enumerate(CFG.targets):
        preds = final_predictions[:, idx]
        print(f"  {target_name}: min={preds.min():.2f}, max={preds.max():.2f}, "
              f"mean={preds.mean():.2f}, std={preds.std():.2f}")
    
    # The clamp below turns NaN into 0.0 (max(0.0, nan) keeps its first
    # argument), so NaN predictions have to be reported here or they would
    # vanish silently from the written file.
    nan_count = int(np.isnan(final_predictions).sum())
    if nan_count > 0:
        print(f"  Warning: {nan_count} NaN predictions will be written as 0.0")
    
    # 5. Create submission file
    print("\n[5/5] Creating submission file...")
    submission_rows = []
    
    # Positional enumeration keeps rows aligned with final_predictions
    for idx, image_path in enumerate(test_df_unique['image_path']):
        # Image ID is the file name without its extension (any extension, any case)
        image_id = os.path.splitext(image_path.split('/')[-1])[0]
        
        # Create one row per target
        for target_idx, target_name in enumerate(CFG.targets):
            sample_id = f"{image_id}__{target_name}"
            
            # Get prediction and ensure non-negative
            prediction = max(0.0, final_predictions[idx, target_idx])
            
            submission_rows.append({
                'sample_id': sample_id,
                'target': prediction
            })
    
    # Create DataFrame and save
    submission_df = pd.DataFrame(submission_rows)
    submission_df.to_csv(CFG.output_file, index=False)
    
    print(f"\n{'='*70}")
    print(f"‚úì INFERENCE COMPLETE!")
    print(f"{'='*70}")
    print(f"‚úì Submission file saved: {CFG.output_file}")
    print(f"‚úì Total predictions: {len(submission_df)}")
    print(f"‚úì Expected format: sample_id, target")
    print(f"\nFirst few rows:")
    print(submission_df.head(10))
    print(f"\nLast few rows:")
    print(submission_df.tail(5))
    
    # Validation checks
    print(f"\n--- Submission Validation ---")
    expected_rows = len(test_df_unique) * len(CFG.targets)  # one row per image and target
    if len(submission_df) == expected_rows:
        print(f"‚úì Row count correct: {len(submission_df)} rows")
    else:
        print(f"‚ö† Warning: Expected {expected_rows} rows, got {len(submission_df)}")
    
    # Check for any NaN or negative values
    if submission_df['target'].isna().sum() > 0:
        print(f"‚ö† Warning: {submission_df['target'].isna().sum()} NaN values found")
    else:
        print(f"‚úì No NaN values")
    
    if (submission_df['target'] < 0).sum() > 0:
        print(f"‚ö† Warning: {(submission_df['target'] < 0).sum()} negative values found")
    else:
        print(f"‚úì No negative values")
    
    print(f"\n‚úì Ready for submission!")

# Run the pipeline only when executed as the main module, not on import
if __name__ == "__main__":
    main()
