In [None]:
import math
import os
import random
import sys
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Literal

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import GroupKFold
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from transformers import AutoModel, AutoConfig
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import transforms
import h5py
# Silence every warning notebook-wide to keep the training logs short.
# NOTE(review): this also hides deprecation warnings (e.g. albumentations argument
# renames) and pandas SettingWithCopy warnings; consider narrowing the filter to
# specific categories/modules.
warnings.filterwarnings('ignore')

In [None]:
# Central configuration: every path, model size and training knob lives here.
CONFIG = {
    # --- paths ---
    'dino_path': '/kaggle/input/dinov2/pytorch/large/1',
    'train_csv': '/kaggle/input/csiro-biomass/train.csv',
    'train_img_dir': '/kaggle/input/csiro-biomass/',
    'test_csv': '/kaggle/input/csiro-biomass/test.csv',
    # NOTE(review): this dir is joined with test.csv's `image_path`. If those paths already
    # carry a "test/" prefix (train_img_dir is the dataset root, which suggests train paths
    # carry "train/"), this should also be the dataset root -- confirm.
    'test_img_dir': '/kaggle/input/csiro-biomass/test/',
    'features_cache': 'biomass_features.h5',
    # --- backbone / tiling ---
    'backbone_dim': 1024,      # width of one DINOv2 CLS feature
    'img_size': 518,           # side of each square patch fed to DINOv2
    'patch_stride': 400,       # step between neighbouring patch windows
    'num_registers': 4,        # currently not read by the visible code
    # --- optimisation ---
    'batch_size': 4,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'num_epochs': 50,
    'patience': 10,            # early-stopping patience, in epochs
    'grad_clip': 1.0,
    'extract_batch_size': 16,  # patches per DINOv2 call during feature caching
    'num_workers': 4,
    # --- targets / metric ---
    'target_names': ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g'],
    'target_weights': {        # per-target weights of the validation weighted R²
        'Dry_Green_g': 0.1,
        'Dry_Dead_g': 0.1,
        'Dry_Clover_g': 0.1,
        'GDM_g': 0.2,
        'Dry_Total_g': 0.5
    },
    'states': ['NSW', 'Tas', 'Vic', 'WA'],
    # --- switches ---
    'use_cached_features': False,  # overridden by the run cell below
    'apply_frofa': True,           # global on/off for frozen-feature augmentation
    'seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
DEVICE = torch.device(CONFIG['device'])

# Seed every RNG the pipeline may draw from so Restart & Run All is reproducible:
# torch (weight init, FroFA, DataLoader shuffling; also seeds CUDA), NumPy, and
# Python's `random` (which some albumentations versions use internally).
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

print(f"Using device: {DEVICE}")

In [None]:
class FroFA:
    """Frozen-feature augmentation: random perturbations of backbone features.

    Each call to `apply_feature_augmentation` independently applies:
      * additive Gaussian noise (std 0.01) with probability 0.5,
      * one global scale factor drawn uniformly from [0.9, 1.1) with probability 0.5,
      * element-wise zeroing of roughly 10% of the values with probability 0.3.

    Args:
        features: Tensor to augment. It is never modified in place.
        training: When False, `apply_feature_augmentation` returns `features` untouched.
    """

    def __init__(self, features: torch.Tensor, training: bool = True):
        self.features = features
        self.training = training

    def apply_feature_augmentation(self) -> torch.Tensor:
        """Return an augmented copy of `self.features` (the original object when not training).

        All ops are out of place on purpose: the caller's tensor (e.g. a feature
        batch that is fed to the model again afterwards) must not be corrupted, and
        in-place edits would also break autograd if the features required grad.
        """
        if not self.training:
            return self.features
        augmented = self.features
        if torch.rand(1).item() < 0.5:
            augmented = augmented + torch.randn_like(augmented) * 0.01
        if torch.rand(1).item() < 0.5:
            scale = torch.rand(1).item() * 0.2 + 0.9
            augmented = augmented * scale
        if torch.rand(1).item() < 0.3:
            keep = torch.rand_like(augmented) > 0.1
            augmented = augmented * keep.to(augmented.dtype)
        self.features = augmented
        return augmented

In [None]:
class ZeroInflatedLogNormalLoss(nn.Module):
    """Mean negative log-likelihood of a zero-inflated log-normal (ZILN) model.

    ``preds`` columns are ``[logit of P(target > 0), mu, raw sigma]``. The loss is a
    Bernoulli cross-entropy on "target is positive" plus, for positive targets only,
    the Gaussian NLL of ``log(target)`` under ``N(mu, sigma)``. ``mu``, ``sigma`` and
    ``log(target)`` are clamped to keep the loss finite.
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        eps = self.eps
        nonzero = (target > 0).float()

        # Hurdle part: is the target positive at all?
        bce = F.binary_cross_entropy_with_logits(preds[:, 0], nonzero, reduction='none')

        # Log-normal part (only counted where the target is positive).
        mu = preds[:, 1].clamp(min=-10, max=10)
        sigma = (F.softplus(preds[:, 2]) + eps).clamp(min=eps, max=10.0)
        log_y = torch.log(target.clamp(min=eps)).clamp(min=-10, max=10)
        gaussian_nll = 0.5 * math.log(2 * math.pi) + torch.log(sigma) + (log_y - mu).pow(2) / (2 * sigma.pow(2))

        return (bce + nonzero * gaussian_nll).mean()
class TweedieLoss(nn.Module):
    """Tweedie negative log-likelihood, up to terms that do not depend on the prediction.

    Per element: ``-y * mu^(1-p) / (1-p) + mu^(2-p) / (2-p)``, averaged over all
    elements. Both ``pred`` (mu) and ``target`` (y) are floored at ``eps``. ``p`` must
    not be 1 or 2 (division by zero).
    """

    def __init__(self, p: float = 1.5, eps: float = 1e-8):
        super().__init__()
        self.p = p
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mu = pred.clamp(min=self.eps)
        y = target.clamp(min=self.eps)
        one_minus_p = 1 - self.p
        two_minus_p = 2 - self.p
        per_element = -y * mu.pow(one_minus_p) / one_minus_p + mu.pow(two_minus_p) / two_minus_p
        return per_element.mean()

In [None]:
class GatedAttentionMIL(nn.Module):
    """Gated-attention pooling of a bag of instance features into one vector.

    Each instance gets a score ``w^T (tanh(V x) * sigmoid(U x))``. Scores are
    softmax-normalised over the instance axis, and the bag embedding is the
    attention-weighted sum of the instances.
    """

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 256):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.attention_V = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Sigmoid())
        self.attention_weights = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool ``x`` of shape (batch, n_instances, input_dim).

        Returns:
            (pooled, weights): ``pooled`` is (batch, input_dim); ``weights`` is
            (batch, n_instances, 1) and sums to 1 over the instance axis.
        """
        gated = self.attention_V(x) * self.attention_U(x)
        weights = self.attention_weights(gated).softmax(dim=1)
        pooled = weights.transpose(1, 2).bmm(x)   # (batch, 1, input_dim)
        return pooled.squeeze(1), weights
class FiLM(nn.Module):
    """Feature-wise linear modulation: ``features * gamma(metadata) + beta(metadata)``.

    ``gamma`` and ``beta`` are independent linear projections of the metadata
    embedding to the feature width.
    """

    def __init__(self, meta_dim: int, feat_dim: int):
        super().__init__()
        self.scale = nn.Linear(meta_dim, feat_dim)   # -> gamma
        self.shift = nn.Linear(meta_dim, feat_dim)   # -> beta

    def forward(self, features: torch.Tensor, metadata: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.scale(metadata), self.shift(metadata)
        modulated = features * gamma
        return modulated + beta

In [None]:
class DINOv2Extractor(nn.Module):
    """Frozen DINOv2 backbone returning one CLS-token embedding per input image.

    Args:
        model_path: Local directory loadable by ``transformers.AutoConfig`` /
            ``AutoModel``.
    """

    def __init__(self, model_path: str = CONFIG['dino_path']):
        super().__init__()
        print(f"Loading local DINOv2 from: {model_path}")
        self.config = AutoConfig.from_pretrained(model_path)
        self.backbone = AutoModel.from_pretrained(model_path, config=self.config)
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Stored for reference only; nothing in this class reads it.
        self.num_registers = 4

    def train(self, mode: bool = True) -> 'DINOv2Extractor':
        """Switch mode while keeping the frozen backbone in eval mode (no dropout/drop-path)."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return CLS features for ``x`` (passed as ``pixel_values``; assumed (B, 3, H, W))."""
        with torch.no_grad():
            outputs = self.backbone(pixel_values=x)
            if hasattr(outputs, 'last_hidden_state'):
                return outputs.last_hidden_state[:, 0, :]
            if isinstance(outputs, dict):
                # Prefer an explicit CLS entry; only touch 'last_hidden_state' when it is
                # absent, so a dict without that key does not raise KeyError.
                if 'x_norm_clstoken' in outputs:
                    return outputs['x_norm_clstoken']
                return outputs['last_hidden_state'][:, 0, :]
            return outputs[:, 0, :]
            

In [None]:
class AFHN(nn.Module):
    """Attention + FiLM hurdle network for multi-component biomass regression.

    Pipeline: (optional frozen DINOv2) -> per-patch CLS features -> gated-attention
    MIL pooling -> FiLM modulation by the encoded metadata -> three heads:

    * ``head_total``: 3 ZILN parameters (logit P(total > 0), mu, raw sigma) when
      ``use_ziln``, otherwise one Softplus-activated total;
    * ``head_ratios``: softmax shares of the total over ``num_components`` targets;
    * ``head_gate``: per-component presence logits.

    Args:
        num_components: Number of predicted components.
        meta_dim: Width of the raw metadata vector.
        feat_dim: Width of each per-patch feature (1024 for the DINOv2-L in CONFIG).
        use_ziln: Use the ZILN total head instead of the Softplus head.
        online_mode: Load the backbone so ``forward`` also accepts raw patches
            ``(B, N, C, H, W)``; otherwise inputs must be features ``(B, N, feat_dim)``.

    NOTE(review): because the ratios are a softmax over all ``num_components`` outputs,
    the predicted components always sum to the predicted total. If ``Dry_Total_g`` and
    ``GDM_g`` are themselves sums of the other targets (as the names suggest), that
    constraint cannot be met exactly -- confirm the intended decomposition.
    """

    def __init__(self, num_components: int = 5, meta_dim: int = 8, feat_dim: int = 1024, use_ziln: bool = True, online_mode: bool = True):
        super().__init__()
        self.num_components = num_components
        self.use_ziln = use_ziln
        self.online_mode = online_mode
        if online_mode:
            config = AutoConfig.from_pretrained(CONFIG['dino_path'])
            self.backbone = AutoModel.from_pretrained(CONFIG['dino_path'], config=config)
            self.backbone.eval()
            for param in self.backbone.parameters():
                param.requires_grad = False
        else:
            self.backbone = None
        self.vis_dim = feat_dim
        self.mil = GatedAttentionMIL(input_dim=self.vis_dim, hidden_dim=256)
        self.meta_encoder = nn.Sequential(
            nn.Linear(meta_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64)
        )
        self.film = FiLM(meta_dim=64, feat_dim=self.vis_dim)
        if self.use_ziln:
            self.head_total = nn.Sequential(
                nn.Linear(self.vis_dim, 256),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(256, 3)
            )
        else:
            self.head_total = nn.Sequential(
                nn.Linear(self.vis_dim, 256),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(256, 1),
                nn.Softplus()
            )
        self.head_ratios = nn.Sequential(
            nn.Linear(self.vis_dim, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_components)
        )
        self.head_gate = nn.Sequential(
            nn.Linear(self.vis_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_components)
        )

    def train(self, mode: bool = True) -> 'AFHN':
        """Set train/eval mode for the heads; the frozen backbone always stays in eval mode."""
        super().train(mode)
        if getattr(self, 'backbone', None) is not None:
            self.backbone.eval()
        return self

    def extract_features(self, x_patches: torch.Tensor) -> torch.Tensor:
        """Map patches ``(B, N, C, H, W)`` to CLS features ``(B, N, D)`` with the frozen backbone.

        Raises:
            RuntimeError: if the model was built with ``online_mode=False``.
        """
        if not self.online_mode:
            raise RuntimeError("extract_features() only available in online mode")
        batch_size, n_patches, c, h, w = x_patches.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x_patches.reshape(batch_size * n_patches, c, h, w)
        with torch.no_grad():
            outputs = self.backbone(pixel_values=x_flat)
            if hasattr(outputs, 'last_hidden_state'):
                feat_flat = outputs.last_hidden_state[:, 0, :]
            else:
                feat_flat = outputs[:, 0, :]
        return feat_flat.reshape(batch_size, n_patches, -1)

    def forward(
        self,
        x: torch.Tensor,
        metadata: torch.Tensor,
        apply_frofa: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Run the heads on patches (online mode, 5-D ``x``) or on per-patch features.

        Args:
            x: Raw patches ``(B, N, C, H, W)`` (online mode only) or features ``(B, N, D)``.
            metadata: Metadata batch fed to ``meta_encoder``.
            apply_frofa: Request feature augmentation. It only happens while the module is
                in training mode and ``CONFIG['apply_frofa']`` is set.

        Returns:
            Dict with raw ``'total'`` head output, softmax ``'ratios'`` and ``'gates'`` logits.
        """
        if self.online_mode and x.dim() == 5:
            features = self.extract_features(x)
        else:
            features = x
        if apply_frofa and CONFIG['apply_frofa']:
            features = FroFA(features, training=self.training).apply_feature_augmentation()
        global_feat, _ = self.mil(features)
        meta_emb = self.meta_encoder(metadata)
        modulated_feat = self.film(global_feat, meta_emb)
        total = self.head_total(modulated_feat)
        ratios = F.softmax(self.head_ratios(modulated_feat), dim=1)
        gate_logits = self.head_gate(modulated_feat)
        return {
            'total': total,
            'ratios': ratios,
            'gates': gate_logits
        }

    def predict_components(
        self,
        x: torch.Tensor,
        metadata: torch.Tensor
    ) -> torch.Tensor:
        """Predict per-component values ``(B, num_components)``.

        The expected total (ZILN mean ``sigmoid(logit) * exp(mu + sigma^2 / 2)``, or the
        Softplus output) is split by the ratios after re-weighting them with the
        sigmoid gates and renormalising.
        """
        outputs = self.forward(x, metadata, apply_frofa=False)
        if self.use_ziln:
            total_ziln = outputs['total']
            prob_nonzero = torch.sigmoid(total_ziln[:, 0])
            mu = total_ziln[:, 1]
            sigma = F.softplus(total_ziln[:, 2]) + 1e-6
            expected_total = prob_nonzero * torch.exp(mu + 0.5 * sigma.pow(2))
        else:
            expected_total = outputs['total'].squeeze(-1)
        gates = torch.sigmoid(outputs['gates'])
        gated_ratios = outputs['ratios'] * gates
        final_ratios = gated_ratios / (gated_ratios.sum(dim=1, keepdim=True) + 1e-6)
        return expected_total.unsqueeze(1) * final_ratios

In [None]:
class PatchBiomassDataset(Dataset):
    """One item per image: overlapping square patches, a metadata vector and (train) targets.

    The long-format CSV (one row per image/target) is collapsed to one row per image:
    by ``image_path`` for test data, by ``sample_id`` otherwise.

    Items are dicts with ``'patches'`` ``(n_patches, 3, P, P)``, ``'metadata'`` (8 floats:
    height, NDVI, state one-hot, month sin/cos), ``'sample_id'`` and, when not testing,
    ``'targets'`` (the 5 target columns in CONFIG order).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        transform: Optional[A.Compose] = None,
        patch_size: int = 518,
        stride: int = 400,
        is_test: bool = False
    ):
        if is_test:
            self.df = df.groupby('image_path').first().reset_index()
        else:
            self.df = df.groupby('sample_id').first().reset_index()
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.patch_size = patch_size
        self.stride = stride
        self.is_test = is_test
        self.state_to_idx = {s: i for i, s in enumerate(CONFIG['states'])}

    @staticmethod
    def _window_starts(length: int, patch_size: int, stride: int) -> List[int]:
        """Window offsets covering [0, length); the last window is flush with the end."""
        if length <= patch_size:
            return [0]
        starts = list(range(0, length - patch_size + 1, stride))
        if starts[-1] != length - patch_size:
            starts.append(length - patch_size)
        return starts

    def extract_patches(self, image: np.ndarray) -> torch.Tensor:
        """Tile an HxWx3 image into patch_size squares (reflect-padded when the image is smaller)."""
        h, w = image.shape[:2]
        size = self.patch_size
        patches = []
        for y in self._window_starts(h, size, self.stride):
            for x in self._window_starts(w, size, self.stride):
                patch = image[y:y + size, x:x + size]
                pad_h, pad_w = size - patch.shape[0], size - patch.shape[1]
                if pad_h or pad_w:
                    patch = np.pad(patch, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
                if self.transform:
                    patch_tensor = self.transform(image=patch)['image']
                else:
                    patch_tensor = torch.from_numpy(patch).permute(2, 0, 1).float() / 255.0
                patches.append(patch_tensor)
        return torch.stack(patches)

    @staticmethod
    def _image_id(row: pd.Series) -> str:
        """Image-level id: the part of ``sample_id`` before '__', else the image file stem."""
        sample_id = row.get('sample_id')
        if isinstance(sample_id, str) and sample_id:
            return sample_id.split('__')[0]
        return Path(str(row['image_path'])).stem

    def _metadata(self, row: pd.Series) -> torch.Tensor:
        """Build [height, NDVI, state one-hot..., sin(month), cos(month)].

        Missing columns fall back to the same defaults CachedFeaturesDataset uses, so
        a test CSV without tabular metadata still yields a valid vector.
        """
        state_onehot = np.zeros(len(CONFIG['states']))
        state = row.get('State', 'Unknown')
        if state in self.state_to_idx:
            state_onehot[self.state_to_idx[state]] = 1.0
        month = pd.to_datetime(row.get('Sampling_Date', '2020-01-01')).month
        month_sin = np.sin(2 * np.pi * month / 12)
        month_cos = np.cos(2 * np.pi * month / 12)
        return torch.tensor(
            [row.get('Height_Ave_cm', 0), row.get('Pre_GSHH_NDVI', 0)] +
            state_onehot.tolist() +
            [month_sin, month_cos],
            dtype=torch.float32
        )

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict:
        row = self.df.iloc[idx]
        img_path = self.image_dir / row['image_path']
        with Image.open(img_path) as img:   # close the file handle promptly
            image = np.array(img.convert('RGB'))
        item = {
            'patches': self.extract_patches(image),
            'metadata': self._metadata(row),
        }
        if self.is_test:
            # A real image id (not the positional index) so submission ids are correct.
            item['sample_id'] = self._image_id(row)
            return item
        item['targets'] = torch.tensor([
            row['Dry_Green_g'],
            row['Dry_Dead_g'],
            row['Dry_Clover_g'],
            row['GDM_g'],
            row['Dry_Total_g']
        ], dtype=torch.float32)
        item['sample_id'] = row['sample_id']
        return item
        
class CachedFeaturesDataset(Dataset):
    """One item per image, backed by pre-extracted patch features in an HDF5 file.

    Each HDF5 group is keyed by image-level ``sample_id`` and holds a ``'features'``
    dataset plus metadata attributes; attributes missing from the file fall back to
    the DataFrame row, then to fixed defaults.

    Args:
        hdf5_path: Path of the feature cache.
        df: Long- or wide-format frame; it is copied, never modified.
        is_test: Group rows by ``image_path`` (test) instead of ``sample_id``, and omit targets.
        states: State names for the one-hot encoding (defaults to ``CONFIG['states']``).
    """

    def __init__(
        self,
        hdf5_path: str,
        df: pd.DataFrame,
        is_test: bool = False,
        states: Optional[List[str]] = None
    ):
        self.hdf5_path = hdf5_path
        self.is_test = is_test
        # Callers pass fold slices of a shared frame; work on a copy so they stay intact.
        df = df.copy()
        if 'sample_id' in df.columns:
            # "<image_id>__<target_name>" -> "<image_id>"
            df['sample_id'] = df['sample_id'].str.split('__').str[0]
        group_key = 'image_path' if is_test else 'sample_id'
        self.df = df.groupby(group_key).first().reset_index()
        # Opened lazily in each process (see _file): h5py handles must not be shared
        # across forked DataLoader workers.
        self.h5_file = None
        state_names = states if states is not None else CONFIG['states']
        self.state_to_idx = {s: i for i, s in enumerate(state_names)}

    def _file(self):
        """Return this process's HDF5 handle, opening it on first use."""
        if self.h5_file is None:
            self.h5_file = h5py.File(self.hdf5_path, 'r', swmr=True)
        return self.h5_file

    def __getstate__(self):
        # Never pickle an open HDF5 handle into spawned workers; they reopen lazily.
        state = self.__dict__.copy()
        state['h5_file'] = None
        return state

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict:
        row = self.df.iloc[idx]
        sample_id = row.get('sample_id', str(idx))
        try:
            grp = self._file()[str(sample_id)]
        except KeyError:
            grp = None
        if grp is not None:
            features = torch.from_numpy(grp['features'][:]).float()
            attrs = grp.attrs
        else:
            # Not in the cache: one zero feature vector and CSV-only metadata.
            # NOTE(review): a (1, D) tensor cannot be stacked with multi-patch items by the
            # default collate_fn when batch_size > 1.
            features = torch.zeros(1, CONFIG['backbone_dim'])
            attrs = {}
        height = float(attrs.get('Height_Ave_cm', row.get('Height_Ave_cm', 0)))
        ndvi = float(attrs.get('Pre_GSHH_NDVI', row.get('Pre_GSHH_NDVI', 0)))
        state = str(attrs.get('State', row.get('State', 'Unknown')))
        date_str = str(attrs.get('Sampling_Date', row.get('Sampling_Date', '2020-01-01')))

        state_onehot = np.zeros(len(self.state_to_idx))
        if state in self.state_to_idx:
            state_onehot[self.state_to_idx[state]] = 1.0
        try:
            month = pd.to_datetime(date_str).month
        except (ValueError, TypeError, OverflowError):
            month = 1   # unparseable date: fall back to January
        month_sin = np.sin(2 * np.pi * month / 12)
        month_cos = np.cos(2 * np.pi * month / 12)
        metadata = torch.tensor(
            [height, ndvi] + state_onehot.tolist() + [month_sin, month_cos],
            dtype=torch.float32
        )
        item = {'features': features, 'metadata': metadata}
        if not self.is_test:
            item['targets'] = torch.tensor([
                row['Dry_Green_g'],
                row['Dry_Dead_g'],
                row['Dry_Clover_g'],
                row['GDM_g'],
                row['Dry_Total_g']
            ], dtype=torch.float32)
        item['sample_id'] = sample_id
        return item

    def __del__(self):
        h5 = getattr(self, 'h5_file', None)
        if h5 is not None:
            h5.close()

In [None]:
def get_transforms(is_train: bool = True) -> A.Compose:
    """Build the albumentations pipeline applied to every patch.

    Both pipelines finish with ImageNet mean/std normalisation and conversion to a
    CHW tensor. The training pipeline first applies a random resized crop back to
    ``CONFIG['img_size']``, flips/rotations, and occasional blur, colour, noise and
    cutout augmentations.
    """
    finalize = [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2(),
    ]
    if not is_train:
        return A.Compose(finalize)

    side = CONFIG['img_size']
    blur = A.OneOf([
        A.MotionBlur(p=0.2),
        A.GaussianBlur(p=0.2),
    ], p=0.2)
    colour = A.OneOf([
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.HueSaturationValue(p=0.3),
    ], p=0.3)
    # NOTE(review): `var_limit` (GaussNoise) and `max_holes`/`max_height`/`max_width`/
    # `fill_value` (CoarseDropout) are older-style albumentations arguments, while
    # `size=` on RandomResizedCrop is newer-style; confirm the installed version honours all.
    augment = [
        A.RandomResizedCrop(size=(side, side), scale=(0.7, 1.0), ratio=(0.9, 1.1), p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        blur,
        colour,
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, fill_value=0, p=0.3),
    ]
    return A.Compose(augment + finalize)

In [None]:
def _expected_total(model: nn.Module, total_out: torch.Tensor) -> torch.Tensor:
    """Expected Dry_Total_g from the raw total-head output of one forward pass.

    Mirrors the total computation in ``AFHN.predict_components`` so training can
    reuse an existing forward pass instead of running the model a second time.
    """
    if getattr(model, 'use_ziln', True):
        # ZILN mean: P(total > 0) * E[log-normal] = sigmoid(logit) * exp(mu + sigma^2 / 2).
        prob_nonzero = torch.sigmoid(total_out[:, 0])
        mu = total_out[:, 1]
        sigma = F.softplus(total_out[:, 2]) + 1e-6
        return prob_nonzero * torch.exp(mu + 0.5 * sigma.pow(2))
    return total_out.squeeze(-1)


def train_epoch(
    model: AFHN,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion_ziln: ZeroInflatedLogNormalLoss,
    device: torch.device,
    use_cached: bool = False
) -> float:
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Every loss term comes from a SINGLE forward pass (with FroFA when enabled):

    * ``loss_total``      -- ZILN NLL of the total head vs the last target column (Dry_Total_g);
    * ``loss_ratios``     -- MSE between gated, renormalised ratios and ``targets / total``;
    * ``loss_gate``       -- BCE between gate logits and ``targets > 0``;
    * ``loss_components`` -- MSE between ``expected_total * ratios`` and the raw targets.

    Combined as ``loss_total + loss_ratios + 0.5 * loss_gate + 0.1 * loss_components``;
    gradients are clipped to ``CONFIG['grad_clip']``. Returns 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    pbar = tqdm(dataloader, desc='Training')
    for batch in pbar:
        x = (batch['features'] if use_cached else batch['patches']).to(device)
        metadata = batch['metadata'].to(device)
        targets = batch['targets'].to(device)

        optimizer.zero_grad()
        outputs = model(x, metadata, apply_frofa=True)

        # Effective ratios: softmax shares re-weighted by the presence gates, renormalised.
        pred_gates = torch.sigmoid(outputs['gates'])
        gated_ratios = outputs['ratios'] * pred_gates
        pred_effective_ratios = gated_ratios / (gated_ratios.sum(dim=1, keepdim=True) + 1e-6)

        # Same quantity predict_components() would return, without a second forward pass.
        pred_components = _expected_total(model, outputs['total']).unsqueeze(1) * pred_effective_ratios
        loss_components = F.mse_loss(pred_components, targets)

        target_total = targets[:, -1]
        loss_total = criterion_ziln(outputs['total'], target_total)

        target_gates = (targets > 0).float()
        loss_gate = F.binary_cross_entropy_with_logits(outputs['gates'], target_gates)

        # NOTE(review): gt_ratios includes Dry_Total_g (ratio ~1) and GDM_g, so its rows sum
        # to more than 1 while pred_effective_ratios sums to 1 -- this MSE cannot reach 0.
        gt_ratios = targets / (target_total.unsqueeze(1) + 1e-6)
        loss_ratios = F.mse_loss(pred_effective_ratios, gt_ratios)

        loss = loss_total + loss_ratios + 0.5 * loss_gate + 0.1 * loss_components
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    return total_loss / max(n_batches, 1)

def validate(
    model: AFHN,
    dataloader: DataLoader,
    device: torch.device,
    use_cached: bool = False
) -> Tuple[float, Dict[str, float]]:
    """Predict every batch of ``dataloader`` and score the predictions.

    Returns:
        ``(r2, metrics)``. ``r2`` is a single weighted R² over all (sample, target)
        pairs, each pair weighted by ``CONFIG['target_weights'][target]``. ``metrics``
        holds ``'Weighted_R2'`` plus a ``'<target>_MAE'`` entry per target.
    """
    model.eval()
    input_key = 'features' if use_cached else 'patches'
    target_batches, pred_batches = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validating'):
            preds = model.predict_components(batch[input_key].to(device), batch['metadata'].to(device))
            target_batches.append(batch['targets'])
            pred_batches.append(preds.cpu())
    y_true = torch.cat(target_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()

    names = CONFIG['target_names']
    per_target_w = np.array([CONFIG['target_weights'][n] for n in names])
    # Tiling matches the row-major order of flatten(): one weight per (sample, target).
    w = np.tile(per_target_w, len(y_true))
    t = y_true.flatten()
    p = y_pred.flatten()
    weighted_mean = np.average(t, weights=w)
    ss_res = np.sum(w * (t - p) ** 2)
    ss_tot = np.sum(w * (t - weighted_mean) ** 2)
    r2 = 1 - (ss_res / (ss_tot + 1e-8))

    metrics = {'Weighted_R2': r2}
    metrics.update({
        f'{name}_MAE': np.abs(y_true[:, col] - y_pred[:, col]).mean()
        for col, name in enumerate(names)
    })
    return r2, metrics

In [None]:
class ImageTilingDataset(Dataset):
    """One item per DataFrame row: normalised image patches, sample id and metadata.

    Items are ``(patches, sample_id, metadata)``. ``patches`` is ``(n_patches, 3, P, P)``;
    if the image cannot be read it is an EMPTY ``(0, 3, P, P)`` tensor and ``metadata``
    is ``{}``, so callers can detect and skip it.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        patch_size: int = 518,
        stride: int = 400
    ):
        self.df = df
        self.image_dir = Path(image_dir)
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return len(self.df)

    @staticmethod
    def _window_starts(length: int, patch_size: int, stride: int) -> List[int]:
        """Window offsets covering [0, length); the last window is flush with the end."""
        if length <= patch_size:
            return [0]
        starts = list(range(0, length - patch_size + 1, stride))
        if starts[-1] != length - patch_size:
            starts.append(length - patch_size)
        return starts

    def extract_patches(self, image: np.ndarray) -> torch.Tensor:
        """Tile an HxWx3 uint8 image into normalised patch_size squares (reflect-padded if small)."""
        h, w = image.shape[:2]
        size = self.patch_size
        patches = []
        for y in self._window_starts(h, size, self.stride):
            for x in self._window_starts(w, size, self.stride):
                patch = image[y:y + size, x:x + size]
                pad_h, pad_w = size - patch.shape[0], size - patch.shape[1]
                if pad_h or pad_w:
                    patch = np.pad(patch, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
                patches.append(self.transform(Image.fromarray(patch)))
        return torch.stack(patches)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Dict]:
        row = self.df.iloc[idx]
        sample_id = row.get('sample_id', str(idx))
        img_path = self.image_dir / row['image_path']
        try:
            with Image.open(img_path) as img:   # close the file handle promptly
                image = np.array(img.convert('RGB'))
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Zero patches (not a black patch) so the caller skips instead of caching junk.
            return torch.zeros(0, 3, self.patch_size, self.patch_size), str(sample_id), {}
        patches = self.extract_patches(image)
        metadata = {
            'Height_Ave_cm': float(row.get('Height_Ave_cm', 0)),
            'Pre_GSHH_NDVI': float(row.get('Pre_GSHH_NDVI', 0)),
            'State': str(row.get('State', 'Unknown')),
            'Sampling_Date': str(row.get('Sampling_Date', '2020-01-01'))
        }
        return patches, sample_id, metadata


def _scalar_attr(value):
    """Unwrap a batch-of-one collated value (list of str / 1-element tensor) into a scalar."""
    if isinstance(value, (list, tuple)):
        value = value[0]
    if isinstance(value, torch.Tensor):
        value = value.item()
    return value


def extract_and_cache_features(
    csv_path: str,
    image_dir: str,
    output_hdf5: str,
    device: torch.device = DEVICE
):
    """Extract DINOv2 CLS features for every image patch and write them to HDF5.

    One group per image-level ``sample_id``, holding a gzip-compressed
    ``'features'`` dataset of shape ``(n_patches, D)`` plus the image metadata as
    scalar attributes. Unreadable images are skipped. The output file is overwritten.
    """
    df = pd.read_csv(csv_path)
    if 'sample_id' in df.columns:
        # "<image_id>__<target_name>" -> "<image_id>"
        df['sample_id'] = df['sample_id'].str.split('__').str[0]
    df_unique = df.groupby('sample_id').first().reset_index()
    print(f"Found {len(df_unique)} unique images")
    dataset = ImageTilingDataset(df_unique, image_dir, patch_size=CONFIG['img_size'], stride=CONFIG['patch_stride'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True)
    model = DINOv2Extractor(CONFIG['dino_path']).to(device)
    chunk = CONFIG['extract_batch_size']
    print(f"Creating HDF5 cache: {output_hdf5}")
    with h5py.File(output_hdf5, 'w') as f:
        for patches, sample_id, metadata in tqdm(dataloader, desc="Extracting features"):
            sample_id = sample_id[0]
            patches = patches.squeeze(0)
            if patches.dim() != 4:
                print(f"Skipping {sample_id}: invalid shape")
                continue
            if patches.shape[0] == 0:
                print(f"Skipping {sample_id}: no patches (image could not be loaded)")
                continue
            patches = patches.to(device)
            features = torch.cat(
                [model(patches[i:i + chunk]).cpu() for i in range(0, len(patches), chunk)],
                dim=0
            ).numpy()
            grp = f.create_group(sample_id)
            grp.create_dataset('features', data=features, compression='gzip', compression_opts=4)
            for key, value in metadata.items():
                grp.attrs[key] = _scalar_attr(value)
    print(f"\nFeature extraction complete!")
    print(f"Saved to: {output_hdf5}")
    print(f"File size: {os.path.getsize(output_hdf5) / 1024**2:.1f} MB")

In [None]:
def main() -> List[float]:
    """Cross-validate AFHN and return each fold's best validation weighted R².

    The long-format train CSV is pivoted to one row per image with one column per
    target (missing targets become 0.0). Folds come from GroupKFold grouped by
    ``State``, so no state appears in both the train and validation split of a fold.
    Each fold's best weights are saved to ``afhn_<cached|online>_fold<k>_best.pth``.
    """
    train_df = pd.read_csv(CONFIG['train_csv'])
    # "<image_id>__<target_name>" -> "<image_id>"
    train_df['sample_id'] = train_df['sample_id'].str.split('__').str[0]
    train_df_wide = train_df.pivot_table(
        index=['sample_id', 'image_path', 'Sampling_Date', 'State',
               'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm'],
        columns='target_name',
        values='target'
    ).reset_index()
    train_df_wide = train_df_wide.fillna(0.0)

    groups = train_df_wide['State'].values
    n_splits = min(5, len(np.unique(groups)))
    gkf = GroupKFold(n_splits=n_splits)
    use_cached = CONFIG['use_cached_features']
    run_tag = "cached" if use_cached else "online"
    fold_scores = []

    for fold, (train_idx, val_idx) in enumerate(gkf.split(train_df_wide, groups=groups)):
        print(f"\n{'='*60}")
        print(f"Fold {fold + 1}/{n_splits}")
        print(f"{'='*60}")
        train_fold = train_df_wide.iloc[train_idx]
        val_fold = train_df_wide.iloc[val_idx]
        if use_cached:
            print("Using cached features from HDF5")
            train_dataset = CachedFeaturesDataset(CONFIG['features_cache'], train_fold, is_test=False)
            val_dataset = CachedFeaturesDataset(CONFIG['features_cache'], val_fold, is_test=False)
        else:
            print("Using online training (loading images)")
            train_dataset = PatchBiomassDataset(
                train_fold, CONFIG['train_img_dir'],
                transform=get_transforms(is_train=True),
                patch_size=CONFIG['img_size'], stride=CONFIG['patch_stride']
            )
            val_dataset = PatchBiomassDataset(
                val_fold, CONFIG['train_img_dir'],
                transform=get_transforms(is_train=False),
                patch_size=CONFIG['img_size'], stride=CONFIG['patch_stride']
            )
        train_loader = DataLoader(
            train_dataset, batch_size=CONFIG['batch_size'],
            shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=CONFIG['batch_size'],
            shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True
        )
        model = AFHN(
            num_components=5,
            meta_dim=8,
            feat_dim=CONFIG['backbone_dim'],
            online_mode=not use_cached
        ).to(DEVICE)
        # Only the heads train; the frozen backbone's parameters never receive gradients.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(
            trainable_params,
            lr=CONFIG['learning_rate'],
            weight_decay=CONFIG['weight_decay']
        )
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        criterion_ziln = ZeroInflatedLogNormalLoss()
        best_r2 = -float('inf')
        patience_counter = 0
        for epoch in range(CONFIG['num_epochs']):
            train_loss = train_epoch(
                model, train_loader, optimizer, criterion_ziln,
                DEVICE, use_cached=use_cached
            )
            val_r2, val_metrics = validate(model, val_loader, DEVICE, use_cached=use_cached)
            scheduler.step()
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Weighted R²: {val_r2:.4f}")
            for k, v in val_metrics.items():
                if k != 'Weighted_R2':
                    print(f"  {k}: {v:.4f}")
            if val_r2 > best_r2:
                best_r2 = val_r2
                patience_counter = 0
                model_name = f'afhn_{run_tag}_fold{fold}_best.pth'
                torch.save(model.state_dict(), model_name)
                print(f"Saved best model (R²={best_r2:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= CONFIG['patience']:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break
        print(f"\nFold {fold + 1} Best R²: {best_r2:.4f}")
        fold_scores.append(best_r2)

    if fold_scores:
        print(f"\nCV best R² per fold: {[round(float(s), 4) for s in fold_scores]}")
        print(f"CV mean best R²: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")
    return fold_scores

In [None]:
# Pick one run mode:
#   'extract' - only build the HDF5 feature cache;
#   'cached'  - train on cached features (builds the cache first if it is missing);
#   'online'  - train end-to-end from images (slower, self-contained).
RUN_MODE = 'cached'

if __name__ == '__main__':
    cache_kwargs = dict(
        csv_path=CONFIG['train_csv'],
        image_dir=CONFIG['train_img_dir'],
        output_hdf5=CONFIG['features_cache'],
    )
    if RUN_MODE == 'extract':
        extract_and_cache_features(**cache_kwargs)
    elif RUN_MODE in ('cached', 'online'):
        CONFIG['use_cached_features'] = RUN_MODE == 'cached'
        if CONFIG['use_cached_features'] and not os.path.exists(CONFIG['features_cache']):
            # Fresh kernel / new machine: the cache does not exist yet, so build it first.
            extract_and_cache_features(**cache_kwargs)
        main()
    else:
        raise ValueError(f"Unknown RUN_MODE: {RUN_MODE!r}")

In [None]:
def _load_head_state(model: nn.Module, path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint's non-backbone weights into ``model`` and return them.

    Checkpoints trained with cached features have no ``backbone.*`` keys, and online
    ones carry an identical frozen backbone, so only head weights are kept.

    Raises:
        RuntimeError: if the checkpoint's head keys do not match ``model``.
    """
    state = torch.load(path, map_location=DEVICE)
    head_state = {k: v for k, v in state.items() if not k.startswith('backbone.')}
    missing, unexpected = model.load_state_dict(head_state, strict=False)
    if unexpected or any(not k.startswith('backbone.') for k in missing):
        raise RuntimeError(
            f"Checkpoint {path} does not match AFHN: missing={missing}, unexpected={unexpected}"
        )
    return head_state


def inference(model_paths: Optional[List[str]] = None, output_csv: str = 'submission.csv') -> pd.DataFrame:
    """Predict the test set with the saved fold checkpoints and write a submission CSV.

    Args:
        model_paths: ``state_dict`` checkpoints whose predictions are averaged. Defaults
            to every ``afhn_<cached|online>_fold*_best.pth`` in the working directory
            (the names ``main()`` saves), following ``CONFIG['use_cached_features']``.
        output_csv: Destination of the submission file.

    Returns:
        The submission DataFrame with columns ``sample_id`` and ``target``.

    Raises:
        FileNotFoundError: if no checkpoint is given or found.
        RuntimeError: if a checkpoint does not match the model.
    """
    if model_paths is None:
        run_tag = "cached" if CONFIG["use_cached_features"] else "online"
        model_paths = sorted(str(p) for p in Path('.').glob(f'afhn_{run_tag}_fold*_best.pth'))
    if not model_paths:
        raise FileNotFoundError("No AFHN checkpoints found; run main() first or pass model_paths.")

    test_df = pd.read_csv(CONFIG['test_csv'])
    test_dataset = PatchBiomassDataset(
        test_df, CONFIG['test_img_dir'],
        transform=get_transforms(is_train=False),
        patch_size=CONFIG['img_size'], stride=CONFIG['patch_stride'],
        is_test=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=CONFIG['batch_size'],
        shuffle=False, num_workers=CONFIG['num_workers']
    )

    # Online model so raw patches can be featurised; heads are swapped per checkpoint.
    model = AFHN(num_components=5, meta_dim=8, feat_dim=CONFIG['backbone_dim'],
                 use_ziln=True, online_mode=True).to(DEVICE)
    model.eval()   # disable head dropout
    head_states = [_load_head_state(model, path) for path in model_paths]

    predictions = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Inference'):
            patches = batch['patches'].to(DEVICE)
            metadata = batch['metadata'].to(DEVICE)
            # Run the frozen backbone once per batch, then every checkpoint's heads.
            features = model.extract_features(patches)
            fold_preds = []
            for head_state in head_states:
                model.load_state_dict(head_state, strict=False)
                fold_preds.append(model.predict_components(features, metadata))
            predictions.append(torch.stack(fold_preds).mean(dim=0).cpu().numpy())
    predictions = np.vstack(predictions)

    # Dataset rows (one per image) are in loader order because shuffle=False.
    pred_by_path = dict(zip(test_dataset.df['image_path'], predictions))
    target_col = {name: j for j, name in enumerate(CONFIG['target_names'])}
    if {'sample_id', 'target_name'}.issubset(test_df.columns):
        # Long-format test.csv: answer exactly the rows it lists, using its own ids.
        targets = [
            pred_by_path[path][target_col[name]]
            for path, name in zip(test_df['image_path'], test_df['target_name'])
        ]
        submission_df = pd.DataFrame({'sample_id': test_df['sample_id'].values, 'target': targets})
    else:
        submission_df = pd.DataFrame([
            {'sample_id': f"{Path(str(path)).stem}__{name}", 'target': preds[j]}
            for path, preds in pred_by_path.items()
            for j, name in enumerate(CONFIG['target_names'])
        ])
    submission_df.to_csv(output_csv, index=False)
    print(f"Submission saved to {output_csv}")
    return submission_df

In [None]:
# Run test-set inference; inference() writes its predictions to submission.csv.
# NOTE(review): requires a trained checkpoint to be available to inference().
inference()