In [1]:
!ls /kaggle/input/timm-1-0-22/timm-1.0.22-py3-none-any.whl
!pip install -q /kaggle/input/timm-1-0-22/timm-1.0.22-py3-none-any.whl --no-deps

/kaggle/input/timm-1-0-22/timm-1.0.22-py3-none-any.whl


# Modules

In [2]:
from torch.utils.data import Dataset, DataLoader
from warnings import filterwarnings
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn.functional as F
import numpy as np
import albumentations as A
from pathlib import Path

from matplotlib import pyplot as plt

from PIL import Image
import seaborn as sns
from sklearn.metrics import r2_score
import random
import numpy as np
import os

import timm
import torch.nn as nn
from transformers import AutoModel
from dataclasses import dataclass
import gc

filterwarnings("ignore")

  data = fetch_version_info()


# Config

In [3]:
class Config:
    """Notebook-wide constants: dataset locations, RNG seed and compute device.

    Used as a plain namespace; values are read straight off the class
    (e.g. ``Config.seed``).
    """

    # Directory with the competition test images.
    test_dir: Path = Path("/kaggle/input/csiro-biomass/test")
    # Directory with the competition training images.
    train_dir: Path = Path("/kaggle/input/csiro-biomass/train")
    # Long-format training labels (one row per image/target pair).
    train_csv_path: Path = Path("/kaggle/input/csiro-biomass/train.csv")
    # Seed handed to set_seed().
    seed: int = 42
    # Device name as a plain string; resolved once, when the class body runs.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed value applied to every generator and exported (as a
            string) through the ``PYTHONHASHSEED`` environment variable.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Seed the current device and then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    print("set seed to ", seed)

# Seed every RNG once, before any of the model cells below run.
set_seed(Config.seed)

set seed to  42


# Model code

### Mamba Block

In [4]:
class LocalMambaBlock(nn.Module):
    """Residual token mixer: LayerNorm -> sigmoid gate -> depthwise Conv1d -> Linear.

    The depthwise convolution runs along the token axis with padding
    ``kernal_size // 2``, so the sequence length is preserved for odd
    kernel sizes and the residual addition is shape-compatible.

    Args:
        dim: Channel width D of the token embeddings.
        kernal_size: Width of the depthwise convolution along the token axis.
        dropout: Dropout probability applied to the projected branch.
    """

    def __init__(self, dim: int, kernal_size: int = 5, dropout: float = 0.1):
        super().__init__()
        half_width = kernal_size // 2
        self.norm = nn.LayerNorm(dim)
        # groups == channels: each channel gets its own 1-D filter.
        self.dwconv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernal_size,
            padding=half_width,
            groups=dim,
        )
        self.gate = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (B, T, D) to a tensor of the same shape."""
        residual = x
        normed = self.norm(x)
        gated = normed * torch.sigmoid(self.gate(normed))       # (B, T, D)

        # Conv1d wants channels first, so swap the T and D axes around it.
        mixed = self.dwconv(gated.transpose(1, 2)).transpose(1, 2)  # (B, T, D)

        branch = self.dropout(self.proj(mixed))
        return residual + branch

### Dinov3MultiReg

In [5]:
@dataclass
class Dinov3Config:
    """Backbone identifiers and tiling geometry consumed by ``Dinov3MultiReg``.

    ``model_id`` and ``patch_size`` are carried for reference; the model code
    in this cell only reads ``timm_id`` and the three ``tile_*`` fields.
    """

    model_id: str = "facebook/dinov3-vith16plus-pretrain-lvd1689m"  # HF (may be gated)
    timm_id: str = "vit_huge_plus_patch16_dinov3.lvd1689m"           # fallback timm
    patch_size: int = 16

    # tiling setup: 2x4 = 8 tiles
    tile_rows: int = 2
    tile_cols: int = 4
    tile_size: int = 512   # S (must be divisible by patch_size)


class Dinov3MultiReg(nn.Module):
    """
    Frozen timm backbone + token fusion + three non-negative regression heads.

    Input:
      img_full: (B, C, H, W) where:
        H = tile_rows * tile_size
        W = tile_cols * tile_size
      Example for 2x4 tiles: H=2S, W=4S

    Internally:
      tiles: (B, N, C, S, S) where N=tile_rows*tile_cols
      encode tiles in one shot: (B*N, T, D)
      reshape + concat tokens: (B, N*T, D)
      fusion -> pool -> heads -> (B,5)

    Output columns: [green, dead, clover, gdm, total] with
    gdm = green + clover and total = gdm + dead.
    """
    def __init__(self, cfg: Dinov3Config):
        super().__init__()
        self.cfg = cfg

        # Only the timm backend is built here; pretrained=False means no
        # weights are fetched at construction time.
        self.encoder = timm.create_model(cfg.timm_id, pretrained=False)
        self.backend = "timm"

        # Freeze backbone
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

        # Hidden size D comes from the timm model's `num_features` attribute.
        D = getattr(self.encoder, "num_features", None)
        if D is None:
            raise RuntimeError("Could not infer hidden size D from timm model.")
        self.num_features = D

        print(
            f"Encoder trainable params: {sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)}"
            f" | backend: {self.backend} | D: {self.num_features}"
        )

        self.fusion = nn.Sequential(
            LocalMambaBlock(self.num_features, kernal_size=5, dropout=0.1),
            LocalMambaBlock(self.num_features, kernal_size=5, dropout=0.1),
        )
        self.pool = nn.AdaptiveAvgPool1d(1)

        # Same architecture for every target; creation order (green, dead,
        # clover) is kept so parameter initialisation order is unchanged.
        self.head_green = self._make_head(self.num_features)
        self.head_dead = self._make_head(self.num_features)
        self.head_clover = self._make_head(self.num_features)

    @staticmethod
    def _make_head(dim: int) -> nn.Sequential:
        """Build one regression head: Linear -> GELU -> Dropout -> Linear -> Softplus."""
        return nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(dim // 2, 1),
            nn.Softplus(),  # keeps every prediction strictly positive
        )

    @torch.no_grad()
    def _encode_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the frozen encoder without gradient tracking.

        x: (B,C,S,S)
        returns tokens: (B,T,D)

        Raises:
            RuntimeError: if the timm features are neither 2-D nor 3-D.
        """
        # Kept for encoders swapped in with backend == "hf"; __init__ itself
        # always selects the timm backend.
        if self.backend == "hf":
            out = self.encoder(pixel_values=x)
            return out.last_hidden_state  # (B,T,D)

        # timm path: prefer forward_features
        if hasattr(self.encoder, "forward_features"):
            feats = self.encoder.forward_features(x)
        else:
            feats = self.encoder(x)

        # Many timm ViTs return (B,T,D). Some return (B,D).
        if feats.ndim == 3:
            return feats
        if feats.ndim == 2:
            # fallback: treat as pooled embedding, make it look like one token
            return feats.unsqueeze(1)  # (B,1,D)
        raise RuntimeError(f"Unexpected timm features shape: {feats.shape}")

    def _tile(self, img_full: torch.Tensor) -> torch.Tensor:
        """
        Cut the full image into a row-major grid of square tiles.

        img_full: (B,C,H,W) where H=R*S, W=C*S
        returns tiles: (B,N,C,S,S)

        Raises:
            AssertionError: if H or W does not match the configured grid.
        """
        B, C, H, W = img_full.shape
        R, Cc, S = self.cfg.tile_rows, self.cfg.tile_cols, self.cfg.tile_size

        expected_h = R * S
        expected_w = Cc * S
        if H != expected_h or W != expected_w:
            raise AssertionError(
                f"Expected img_full H,W=({expected_h},{expected_w}) "
                f"for tile_rows={R}, tile_cols={Cc}, tile_size={S}, "
                f"but got ({H},{W}). Ensure your dataset resize matches."
            )

        # reshape to grid then flatten tiles
        # (B,C,R,S,Cc,S) -> (B,R,Cc,C,S,S) -> (B,N,C,S,S)
        # The channel count is read from the input instead of assuming 3.
        grid = img_full.reshape(B, C, R, S, Cc, S)
        grid = grid.permute(0, 2, 4, 1, 3, 5)
        return grid.reshape(B, R * Cc, C, S, S)

    def forward(self, img_full: torch.Tensor) -> torch.Tensor:
        """
        img_full: (B,C,2S,4S) for 2x4 tiles
        returns: (B,5) in order [green, dead, clover, gdm, total]
        """
        # ensure backbone stays deterministic even when model.train()
        self.encoder.eval()

        tiles = self._tile(img_full)  # (B,N,C,S,S)
        B, N, C, S, _ = tiles.shape

        # Encode all tiles together
        tok = self._encode_tokens(tiles.reshape(B * N, C, S, S))  # (B*N,T,D)

        # reshape back + concat tokens across tiles.
        # reshape (not view) so non-contiguous encoder output is accepted.
        T = tok.shape[1]
        tok = tok.reshape(B, N * T, self.num_features)  # (B, N*T, D)

        # fuse + pool
        tok = self.fusion(tok)                          # (B, N*T, D)
        pooled = self.pool(tok.transpose(1, 2)).flatten(1)  # (B,D)

        green = self.head_green(pooled)   # (B,1)
        dead = self.head_dead(pooled)     # (B,1)
        clover = self.head_clover(pooled) # (B,1)

        # Aggregates are sums of the parts, so the identities hold exactly.
        gdm = green + clover
        total = gdm + dead
        return torch.cat([green, dead, clover, gdm, total], dim=1)  # (B,5)


### DinoV3Structured

In [6]:
class DinoV3StructuredConfig:
    """Settings for ``DinoV3Structured``.

    Note that the ``model_id`` argument is stored as ``timm_id``. ``img_size``
    is stored but not read by the model code in this cell.
    """
    def __init__(
        self,
        model_id: str = "vit_huge_plus_patch16_dinov3.lvd1689m",
        img_size: int = 512,
        tiles: int = 2,          # 2 = left/right, 8 = your tile idea
        freeze_backbone: bool = True,
        dropout: float = 0.1,
        use_mamba: bool = True,
        mamba_depth: int = 2,
    ):
        self.timm_id = model_id
        self.img_size = img_size
        self.tiles = tiles
        self.freeze_backbone = freeze_backbone
        self.dropout = dropout
        self.use_mamba = use_mamba
        self.mamba_depth = mamba_depth


class DinoV3Structured(nn.Module):
    """
    Input: full image tensor (B,3,H,W) already resized consistently (e.g. 2S x 4S).
    We tile inside the model into N tiles, encode each tile with a timm ViT
    (frozen when cfg.freeze_backbone), fuse tile embeddings, then output
    structured predictions.

    Outputs:
      dict with:
        green_pred, total_pred, dead_ratio_pred, clover_ratio_pred
        dead_pred, gdm_pred, clover_pred
        pred5 in order [Green, Dead, Clover, GDM, Total]
    """
    def __init__(self, cfg: DinoV3StructuredConfig):
        super().__init__()
        self.cfg = cfg

        # timm backbone (offline ok if weights are available; else set pretrained=False)
        self.encoder = timm.create_model(
            cfg.timm_id,
            pretrained=False,
            num_classes=0,            # removes classification head
            global_pool="avg",        # gives (B, D)
        )

        self.D = self.encoder.num_features

        if cfg.freeze_backbone:
            for p in self.encoder.parameters():
                p.requires_grad = False
            self.encoder.eval()

        # Fusion (tile tokens -> one vector)
        # You can run Mamba blocks on (B, T, D) then pool.
        if cfg.use_mamba:
            self.fusion = nn.Sequential(*[
                LocalMambaBlock(self.D, kernal_size=5, dropout=cfg.dropout)
                for _ in range(cfg.mamba_depth)
            ])
        else:
            self.fusion = nn.Identity()

        self.tile_pool = nn.AdaptiveAvgPool1d(1)  # (B, D, T) -> (B, D, 1)

        # Heads
        # 1) regression heads (green, total) — force non-negative via softplus
        self.head_reg = nn.Sequential(
            nn.LayerNorm(self.D),
            nn.Dropout(cfg.dropout),
            nn.Linear(self.D, self.D // 2),
            nn.GELU(),
            nn.Linear(self.D // 2, 2),  # [green_raw, total_raw]
        )

        # 2) ratio heads (dead_ratio, clover_ratio) — sigmoid
        self.head_ratio = nn.Sequential(
            nn.LayerNorm(self.D),
            nn.Dropout(cfg.dropout),
            nn.Linear(self.D, self.D // 2),
            nn.GELU(),
            nn.Linear(self.D // 2, 2),  # [dead_ratio_logit, clover_ratio_logit]
        )

    # -------------------------
    # Tiling utilities
    # -------------------------
    def _make_tiles(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B,3,H,W)
        returns tiles: (T*B, 3, tileH, tileW), stacked tile-major: the first B
        entries are tile 0 of every image, the next B are tile 1, and so on.
        Supported:
          tiles=2  -> split width into left/right
          tiles=8  -> 2 rows x 4 cols grid

        Raises:
            ValueError: for any other value of cfg.tiles.
        """
        B, C, H, W = x.shape
        if self.cfg.tiles == 2:
            mid = W // 2
            left = x[:, :, :, :mid]
            right = x[:, :, :, mid:]
            tiles = torch.cat([left, right], dim=0)  # (2B,3,H,W/2)
            return tiles

        if self.cfg.tiles == 8:
            # 2x4 grid, visited row by row
            rows, cols = 2, 4
            th, tw = H // rows, W // cols
            tile_list = [
                x[:, :, r * th:(r + 1) * th, c * tw:(c + 1) * tw]
                for r in range(rows)
                for c in range(cols)
            ]
            tiles = torch.cat(tile_list, dim=0)  # (8*B, 3, th, tw)
            return tiles

        raise ValueError(f"Unsupported tiles={self.cfg.tiles}. Use 2 or 8.")

    def _encode_tiles(self, tiles: torch.Tensor) -> torch.Tensor:
        """
        tiles: (T*B,3,tH,tW) -> embeddings (T*B,D)

        When the backbone is frozen it is forced into eval mode and run without
        gradient tracking, so a surrounding model.train() cannot switch it back
        to training behaviour.
        """
        if self.cfg.freeze_backbone:
            self.encoder.eval()
            with torch.no_grad():
                return self.encoder(tiles)  # (T*B, D)
        return self.encoder(tiles)

    # -------------------------
    # Forward
    # -------------------------
    def forward(self, x_full: torch.Tensor) -> dict:
        """Predict the structured biomass outputs for a batch of full images.

        Returns a dict of (B,) tensors plus ``pred5`` of shape (B,5) in order
        [Green, Dead, Clover, GDM, Total].
        """
        B = x_full.size(0)
        T = self.cfg.tiles

        tiles = self._make_tiles(x_full)          # (T*B,3,th,tw)
        z = self._encode_tiles(tiles)             # (T*B, D)

        # Embeddings are tile-major, so unfold as (T, B, D) and then swap to
        # (B, T, D). reshape (not view) tolerates non-contiguous encoder output.
        z = z.reshape(T, B, self.D).transpose(0, 1)  # (B, T, D)

        # fusion expects (B, T, D) for mamba blocks
        z = self.fusion(z)                        # (B, T, D)

        # pool over T -> (B, D)
        z = z.transpose(1, 2)                     # (B, D, T)
        z = self.tile_pool(z).squeeze(-1)         # (B, D)

        # heads
        reg_raw = self.head_reg(z)                # (B,2)
        ratio_logit = self.head_ratio(z)          # (B,2)

        green = F.softplus(reg_raw[:, 0])         # >=0
        total = F.softplus(reg_raw[:, 1])         # >=0

        # optional: enforce total >= green softly
        # total = green + F.softplus(reg_raw[:, 1])

        dead_ratio = torch.sigmoid(ratio_logit[:, 0])    # (0,1)
        clover_ratio = torch.sigmoid(ratio_logit[:, 1])  # (0,1)

        # Dead is a fraction of total; clover is a fraction of what remains.
        dead = dead_ratio * total
        gdm = total - dead
        clover = clover_ratio * gdm

        pred5 = torch.stack([green, dead, clover, gdm, total], dim=1)

        return {
            "green_pred": green,
            "total_pred": total,
            "dead_ratio_pred": dead_ratio,
            "clover_ratio_pred": clover_ratio,
            "dead_pred": dead,
            "gdm_pred": gdm,
            "clover_pred": clover,
            "pred5": pred5,
        }


### DinoV3Hybrid

In [7]:

class DinoV3HybridConfig:
    """Settings for ``DinoV3Hybrid``.

    ``init_mix_logits`` holds the initial (dead, clover) mixing logits;
    sigmoid(logit) is the weight given to the direct head, so negative values
    favour the derived estimate at initialisation.
    """
    def __init__(
        self,
        model_id="vit_huge_plus_patch16_dinov3.lvd1689m",
        pretrained_backbone=False,
        dropout=0.2,
        mamba_depth=2,
        use_grad_checkpointing=False,
        freeze_backbone=True,
        # blending
        init_mix_logits=(-1.0, -1.0),  # (dead_mix, clover_mix) negative => prefer derived early
    ):
        self.model_id = model_id
        self.pretrained_backbone = pretrained_backbone
        self.dropout = dropout
        self.mamba_depth = mamba_depth
        self.use_grad_checkpointing = use_grad_checkpointing
        self.freeze_backbone = freeze_backbone
        self.init_mix_logits = init_mix_logits


class DinoV3Hybrid(nn.Module):
    """
    Token-level fusion (left/right), separate heads + ratio heads.

    Direct heads (each passed through Softplus):
      - green, gdm_d, total_d, dead_d, clover_d

    Ratio heads (each passed through sigmoid):
      - dead_ratio in (0,1)
      - clover_ratio in (0,1)

    Derived:
      - dead_from_core   = relu(total_d - gdm_d)
      - clover_from_core = relu(gdm_d - green)
      - dead_from_ratio  = dead_ratio * total_d
      - clover_from_ratio= clover_ratio * gdm_d
      - dead_derived     = 0.5 * dead_from_core   + 0.5 * dead_from_ratio
      - clover_derived   = 0.5 * clover_from_core + 0.5 * clover_from_ratio

    Final:
      dead   = mix_dead * dead_d + (1-mix_dead) * dead_derived
      clover = mix_clover * clover_d + (1-mix_clover) * clover_derived
      gdm    = green + clover
      total  = gdm + dead

    Outputs pred5: [Green, Dead, Clover, GDM, Total]
    """
    def __init__(self, cfg: DinoV3HybridConfig):
        super().__init__()
        self.cfg = cfg

        # keep patch tokens: (B, N, D)
        self.backbone = timm.create_model(
            cfg.model_id,
            pretrained=cfg.pretrained_backbone,
            num_classes=0,
            global_pool="",  # IMPORTANT: keep tokens
        )

        if hasattr(self.backbone, "set_grad_checkpointing") and cfg.use_grad_checkpointing:
            self.backbone.set_grad_checkpointing(True)

        self.D = self.backbone.num_features

        if cfg.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()

        # Mamba fusion over token sequence
        self.fusion = nn.Sequential(*[
            LocalMambaBlock(self.D, kernal_size=5, dropout=cfg.dropout)
            for _ in range(cfg.mamba_depth)
        ])

        self.pool = nn.AdaptiveAvgPool1d(1)

        # ------- Heads -------
        # Regression and ratio heads share one architecture; only the
        # activation applied in forward() differs. Creation order is kept so
        # parameter initialisation order is unchanged.

        # direct regressors (raw -> softplus)
        self.head_green = self._make_head()
        self.head_gdm   = self._make_head()
        self.head_total = self._make_head()
        self.head_dead  = self._make_head()    # direct dead (optional but useful)
        self.head_clover= self._make_head()    # direct clover (optional)

        # ratio heads (raw -> sigmoid)
        self.head_dead_ratio   = self._make_head()
        self.head_clover_ratio = self._make_head()

        # learnable mixing (logits)
        # sigmoid(mix_logit) = weight on direct prediction
        dead_mix0, clover_mix0 = cfg.init_mix_logits
        self.dead_mix_logit   = nn.Parameter(torch.tensor(float(dead_mix0)))
        self.clover_mix_logit = nn.Parameter(torch.tensor(float(clover_mix0)))

    def _make_head(self) -> nn.Sequential:
        """Build one scalar head: LayerNorm -> Linear -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.LayerNorm(self.D),
            nn.Linear(self.D, self.D // 2),
            nn.GELU(),
            nn.Dropout(self.cfg.dropout),
            nn.Linear(self.D // 2, 1),
        )

    def _encode_pair(self, left: torch.Tensor, right: torch.Tensor):
        """Return backbone tokens for both halves, each of shape (B, N, D).

        A frozen backbone is forced into eval mode and run without gradient
        tracking, so a surrounding model.train() cannot re-enable its
        training-time behaviour.
        """
        if self.cfg.freeze_backbone:
            self.backbone.eval()
            with torch.no_grad():
                return self.backbone(left), self.backbone(right)
        return self.backbone(left), self.backbone(right)

    def forward(self, left: torch.Tensor, right: torch.Tensor) -> dict:
        """Predict biomass components from the left and right image halves.

        Returns a dict of (B,) tensors, the two scalar mixing weights
        (detached), and ``pred5`` of shape (B,5).
        """
        # tokens: (B, N, D)
        x_l, x_r = self._encode_pair(left, right)

        # concat token sequences: (B, 2N, D)
        x = torch.cat([x_l, x_r], dim=1)

        # fuse
        x = self.fusion(x)  # (B, 2N, D)

        # pool: (B, D)
        x_pool = self.pool(x.transpose(1, 2)).squeeze(-1)

        # --- direct preds ---
        green = F.softplus(self.head_green(x_pool).squeeze(1))
        gdm_d = F.softplus(self.head_gdm(x_pool).squeeze(1))
        total_d = F.softplus(self.head_total(x_pool).squeeze(1))
        dead_d = F.softplus(self.head_dead(x_pool).squeeze(1))
        clover_d = F.softplus(self.head_clover(x_pool).squeeze(1))

        # --- ratio preds ---
        dead_ratio = torch.sigmoid(self.head_dead_ratio(x_pool).squeeze(1))
        clover_ratio = torch.sigmoid(self.head_clover_ratio(x_pool).squeeze(1))

        # --- derived candidates (two options) ---
        dead_from_core   = F.relu(total_d - gdm_d)     # ensures >=0
        clover_from_core = F.relu(gdm_d - green)       # ensures >=0

        dead_from_ratio  = dead_ratio * total_d
        clover_from_ratio= clover_ratio * gdm_d

        # Equal-weight blend of the core-derived and ratio-derived estimates.
        dead_derived = 0.5 * dead_from_core + 0.5 * dead_from_ratio
        clover_derived = 0.5 * clover_from_core + 0.5 * clover_from_ratio

        # --- mix direct vs derived ---
        mix_dead = torch.sigmoid(self.dead_mix_logit)       # scalar
        mix_clover = torch.sigmoid(self.clover_mix_logit)   # scalar

        dead = mix_dead * dead_d + (1.0 - mix_dead) * dead_derived
        clover = mix_clover * clover_d + (1.0 - mix_clover) * clover_derived

        # Recompose the aggregates so gdm = green + clover and
        # total = gdm + dead hold exactly in the output.
        gdm = green + clover
        total = gdm + dead

        pred5 = torch.stack([green, dead, clover, gdm, total], dim=1)

        return {
            "green_pred": green,
            "gdm_pred": gdm,
            "total_pred": total,
            "dead_pred": dead,
            "clover_pred": clover,
            "dead_ratio_pred": dead_ratio,
            "clover_ratio_pred": clover_ratio,
            "mix_dead": mix_dead.detach(),
            "mix_clover": mix_clover.detach(),
            "pred5": pred5,
        }


# Inference

In [8]:
# -----------------------------------------------------------------------------------
# Inference dataset: loads each image, resizes it to (2S, 4S), normalizes it
# -----------------------------------------------------------------------------------
class InferenceDataset(Dataset):
    """Yield ``(image_id, tensor)`` pairs for the images listed in a DataFrame.

    Each image is read from ``image_path``, converted to RGB, resized to
    (2 * img_size, 4 * img_size), normalized with the standard ImageNet
    mean/std constants and converted to a tensor.

    Args:
        df: Frame with at least ``image_id`` and ``image_path`` columns.
        img_size: Tile edge length S; the resized image is 2S high and 4S wide.
    """

    def __init__(self, df: pd.DataFrame, img_size: int):
        self.df = df.reset_index(drop=True)
        self.S = int(img_size)

        # Target canvas: 2 tile rows by 4 tile columns.
        self.H = self.S * 2
        self.W = self.S * 4

        pipeline = [
            A.Resize(self.H, self.W),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            A.pytorch.ToTensorV2(),
        ]
        self.aug = A.Compose(pipeline)

    def __len__(self):
        """Number of images in the frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_id, image_tensor)`` for row ``idx``; the tensor is (3, H, W)."""
        record = self.df.iloc[idx]
        source_path = str(record["image_path"])

        with open(source_path, "rb") as handle:
            rgb = np.array(Image.open(handle).convert("RGB"))

        tensor = self.aug(image=rgb)["image"]
        return record["image_id"], tensor


def load_wide_train(train_csv: str) -> pd.DataFrame:
    """Read the long-format training CSV and pivot it to one row per image.

    ``image_id`` is derived from the file name in ``image_path`` with its
    extension removed. The metadata columns become the row key and each
    distinct ``target_name`` becomes a column holding the first ``target``
    value found for that image.

    Args:
        train_csv: Path to the training CSV.

    Returns:
        Wide DataFrame with the metadata columns plus one column per target.

    Raises:
        ValueError: if any of the five expected target columns is absent
            after the pivot.
    """
    long_df = pd.read_csv(train_csv)
    long_df["image_id"] = long_df["image_path"].apply(
        lambda p: os.path.splitext(os.path.basename(p))[0]
    )

    key_columns = [
        "image_id",
        "image_path",
        "Sampling_Date",
        "State",
        "Species",
        "Pre_GSHH_NDVI",
        "Height_Ave_cm",
    ]
    pivoted = long_df.pivot_table(
        index=key_columns,
        columns="target_name",
        values="target",
        aggfunc="first",
    )
    wide = pivoted.reset_index()

    required = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")
    absent = [name for name in required if name not in wide.columns]
    if absent:
        # Report the first missing target, in the order listed above.
        raise ValueError(f"Missing target column after pivot: {absent[0]}")

    return wide



# -----------------------------------------------------------------------------------
# Post-processing helpers: non-negativity, linear calibration, clipping, exact sums
# -----------------------------------------------------------------------------------

def _nonneg_np(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)

def recompute_from_parts_np(green: np.ndarray, dead: np.ndarray, clover: np.ndarray) -> np.ndarray:
    """
    Rebuild all five targets from the three component predictions.

    Each component is first floored at zero, then the aggregates are derived
    so that the identities hold exactly:
      GDM = green + clover
      Total = GDM + dead

    Returns a float32 array of shape (N,5) in the order
    [green, dead, clover, GDM, total].
    """
    g, d, c = (_nonneg_np(part) for part in (green, dead, clover))
    growing = g + c
    everything = growing + d
    columns = [g, d, c, growing, everything]
    return np.stack(columns, axis=1).astype(np.float32)

def apply_postprocess_from_calib(preds_df: pd.DataFrame, calib: dict) -> pd.DataFrame:
    """
    Calibrate, clip and re-sum the five target columns of a prediction frame.

    preds_df must have columns: image_id + TARGETS
    calib is loaded from best_calibration.json (or dict with same keys); the
    keys read here are calib_a, calib_b, clip_lo and clip_hi, each holding
    five values in TARGETS order.

    Steps: floor at 0 -> per-target ``a * x + b`` -> floor at 0 -> clip to
    [lo, hi] -> rebuild GDM and Total from the green/dead/clover columns.
    Because of that last step the calibrated GDM/Total values are replaced
    by the sums of the parts.

    NOTE(review): ``calib["cfg"]`` (e.g. ``use_linear_calib``) is not read, so
    the linear calibration is always applied — confirm this is intended.

    Returns a modified copy; ``preds_df`` itself is left untouched.
    """
    target_cols = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]
    result = preds_df.copy()

    scale, shift, lower, upper = (
        np.asarray(calib[key], dtype=np.float32)
        for key in ("calib_a", "calib_b", "clip_lo", "clip_hi")
    )

    # 1) non-neg on the raw predictions
    values = _nonneg_np(result[target_cols].to_numpy(np.float32))

    # 2) linear calib, then 3) non-neg again
    values = _nonneg_np(values * scale.reshape(1, 5) + shift.reshape(1, 5))

    # 4) clip each target to its own range
    values = np.clip(values, lower.reshape(1, 5), upper.reshape(1, 5))

    # 5) recompute exact sums from parts (green, dead, clover)
    green_col, dead_col, clover_col = values[:, 0], values[:, 1], values[:, 2]
    result[target_cols] = recompute_from_parts_np(green_col, dead_col, clover_col)
    return result


# -----------------------------------------------------------------------------------
# Corrected inference with proper ensemble + plotting
# -----------------------------------------------------------------------------------
def inference(
    model_id: str,
    patch_size: int,
    model_name: str,
    image_paths: list[Path],
    weight_paths: list[Path],
    model_weights: list[float] | None,
    submission_mode: bool,
    device: torch.device,
    img_size: int,
    use_tta: bool,
    apply_post_proc: bool,
) -> pd.DataFrame:

    # -----------------------------
    # Validate / default weights
    # -----------------------------
    if model_weights is None:
        model_weights = [1.0 / len(weight_paths)] * len(weight_paths)

    if len(model_weights) != len(weight_paths):
        raise ValueError(f"model_weights len {len(model_weights)} != weight_paths len {len(weight_paths)}")

    w_sum = float(sum(model_weights))
    if w_sum <= 0:
        raise ValueError("Sum of model_weights must be > 0")

    # Normalize weights
    model_weights = [float(w) / w_sum for w in model_weights]

    # -----------------------------
    # Build inference dataframe
    # -----------------------------
    rows = [{"image_id": p.stem, "image_path": str(p)} for p in image_paths]
    df = pd.DataFrame(rows)

    ds = InferenceDataset(df=df, img_size=img_size)
    data_loader = DataLoader(
        dataset=ds,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        drop_last=False
    )

    print(f"Loaded inference data loader: {len(ds)} images")

    labels = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    # -----------------------------
    # Accumulator for weighted ensemble
    # image_id -> np.array(5,)
    # -----------------------------
    accum: dict[str, np.ndarray] = {}

    with torch.no_grad():
        for fold_idx, (weight_path, w) in enumerate(zip(weight_paths, model_weights)):
            print(f"\n[Fold {fold_idx}] Loading weights: {weight_path} (weight={w:.4f})")

            # -----------------------------
            # Instantiate model ONCE
            # -----------------------------
            if model_name == "Dinov3MultiReg":
                model = Dinov3MultiReg(Dinov3Config())
            elif model_name == "DinoV3Structured":
                model = DinoV3Structured(DinoV3StructuredConfig())
            elif model_name == "DinoV3Hybrid":
                model = DinoV3Hybrid(DinoV3HybridConfig())
            else:
                raise ValueError(f"model name not found: {model_name}")
        
            model.to(device)
            model.eval()
            
            ckpt = torch.load(weight_path, map_location=device)
            state = ckpt["model_state"] if isinstance(ckpt, dict) and "model_state" in ckpt else ckpt
            model.load_state_dict(state, strict=True)

            def forward_pred5(img_full_t: torch.Tensor) -> torch.Tensor:
                if model_name == "DinoV3Structured":
                    out = model(img_full_t)
                    return out["pred5"]
                elif model_name == "DinoV3Hybrid":
                    _, _, _, W = img_full_t.shape
                    mid = W // 2
                    left = img_full_t[:, :, :, :mid]
                    right = img_full_t[:, :, :, mid:]
                    out = model(left, right)
                    return out["pred5"]
                else:
                    return model(img_full_t)
            
            pbar = tqdm(data_loader, desc=f"Inference fold {fold_idx}", dynamic_ncols=True)
            for img_ids, img_full_t in pbar:
                img_full_t = img_full_t.to(device, non_blocking=True)

                preds = forward_pred5(img_full_t)

                if use_tta:
                    preds_h = forward_pred5(torch.flip(img_full_t, dims=[3]))
                    preds_v = forward_pred5(torch.flip(img_full_t, dims=[2]))
                    preds = 0.5 * preds + 0.25 * preds_h + 0.25 * preds_v

                preds_np = preds.detach().cpu().numpy()  # (B,5)

                for img_id, pred in zip(img_ids, preds_np):
                    pred = np.maximum(pred.astype(np.float32), 0.0)  # clamp non-negative

                    # weighted accumulate
                    if img_id not in accum:
                        accum[img_id] = w * pred
                    else:
                        accum[img_id] += w * pred

            del model
            gc.collect()
            torch.cuda.empty_cache()

    # Optional warning if something went missing
    if len(accum) != len(ds):
        print(f"⚠️ Warning: accum has {len(accum)} ids but dataset has {len(ds)}. Check duplicate/missing image_id.")

    # -----------------------------
    # Build final ensembled preds_df
    # -----------------------------
    out_rows = []
    for img_id, vec in accum.items():
        out_rows.append({
            "image_id": img_id,
            "Dry_Green_g":  float(vec[0]),
            "Dry_Dead_g":   float(vec[1]),
            "Dry_Clover_g": float(vec[2]),
            "GDM_g":        float(vec[3]),
            "Dry_Total_g":  float(vec[4]),
        })

    preds_df = pd.DataFrame(out_rows)

    # Ensure same order as input list (robust)
    order_map = {p.stem: i for i, p in enumerate(image_paths)}
    preds_df["__order"] = preds_df["image_id"].map(order_map).fillna(10**12).astype(np.int64)
    preds_df = preds_df.sort_values("__order").drop(columns="__order").reset_index(drop=True)

    # -----------------------------
    # Plotting on train/val (when not submission)
    # -----------------------------
    if not submission_mode:
        print("Plotting results (GT vs Ensemble Preds)")
        train_df = load_wide_train(Config.train_csv_path)
        merged_df = train_df.merge(
            preds_df,
            on="image_id",
            how="inner",
            suffixes=("_gt", "_pred")
        )

        if len(merged_df) == 0:
            print("⚠️ merged_df is empty. Check that image_ids match between train_csv and image_paths.")
        else:
            sns.set_theme(style="whitegrid")
            targets = labels

            fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(14, 18), constrained_layout=True)

            for i, t in enumerate(targets):
                gt = merged_df[f"{t}_gt"].astype(float).to_numpy()
                pred = merged_df[f"{t}_pred"].astype(float).to_numpy()  # ✅ FIXED

                # Left: Histogram
                ax_hist = axes[i, 0]
                sns.histplot(gt, bins=40, stat="density", kde=True, alpha=0.45, label="GT", ax=ax_hist)
                sns.histplot(pred, bins=40, stat="density", kde=True, alpha=0.45, label="Pred", ax=ax_hist)
                ax_hist.set_title(f"{t} — Distribution")
                ax_hist.set_xlabel(t)
                ax_hist.set_ylabel("Density")
                ax_hist.legend()

                # Right: Scatter + R²
                ax_scatter = axes[i, 1]
                r2 = r2_score(gt, pred)

                sns.scatterplot(x=gt, y=pred, s=25, alpha=0.6, ax=ax_scatter)

                mn = min(gt.min(), pred.min())
                mx = max(gt.max(), pred.max())
                ax_scatter.plot([mn, mx], [mn, mx], "r--", linewidth=1)

                ax_scatter.set_title(f"{t} — Pred vs GT (R² = {r2:.4f})")
                ax_scatter.set_xlabel("GT")
                ax_scatter.set_ylabel("Pred")

            plt.show()

    # -----------------------------
    # Create submission
    # -----------------------------
    order = ["Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"]


    # Applying postprocessing
    def post_proc_preds_df(preds_df: pd.DataFrame):
        """
        Post-process wide predictions using the stored calibration of the active model.

        Looks up the hard-coded parameter set keyed by the enclosing function's
        `model_name` ("DinoV3Hybrid" or "DinoV3Structured") and hands it, together
        with `preds_df`, to `apply_postprocess_from_calib`.

        Args:
            preds_df: predictions to post-process. The asserts below read the columns
                "GDM_g", "Dry_Green_g", "Dry_Clover_g", "Dry_Total_g" and "Dry_Dead_g".

        Returns:
            Whatever `apply_postprocess_from_calib` returns for (preds_df, calib).

        Raises:
            KeyError: if `model_name` is not one of the keys in `params`.
            AssertionError: outside submission mode, if the mass-balance identities
                checked below are violated by more than 1e-3.
        """
        # Each entry carries a "cfg" dict plus four 5-element lists
        # (calib_a, calib_b, clip_lo, clip_hi) -- presumably one value per target;
        # the target ordering is defined by apply_postprocess_from_calib (TODO confirm).
        params = dict(
            DinoV3Hybrid = {
              "best_score": 0.7583875060081482,
              "cfg": {
                "use_ratio_recompose": False,
                "alpha_dead_derive": 0.0,
                "use_linear_calib": True,
                "clip_mode": "quantile",
                "q_low": 0.1,
                "q_high": 99.9
              },
              "calib_a": [
                1.161050796508789,
                0.8045432567596436,
                1.1122260093688965,
                1.0957939624786377,
                1.0523895025253296
              ],
              "calib_b": [
                -0.6085014343261719,
                1.8143501281738281,
                -0.6774587631225586,
                0.352935791015625,
                0.31894683837890625
              ],
              "clip_lo": [
                0.0,
                0.0,
                0.0,
                1.1681599617004395,
                1.5526399612426758
              ],
              "clip_hi": [
                147.78758239746094,
                74.90797424316406,
                71.10720825195312,
                147.78758239746094,
                178.72264099121094
              ]
            },
            DinoV3Structured = {
              "best_score": 0.7661679983139038,
              "cfg": {
                "use_ratio_recompose": True,
                "alpha_dead_derive": 0.7,
                "use_linear_calib": False,
                "clip_mode": "quantile",
                "q_low": 0.1,
                "q_high": 99.9
              },
              "calib_a": [
                1.058223843574524,
                0.9063701629638672,
                0.9845092296600342,
                1.0237661600112915,
                1.0476360321044922
              ],
              "calib_b": [
                0.6819171905517578,
                1.508835792541504,
                -0.19695329666137695,
                1.05670166015625,
                0.17140579223632812
              ],
              "clip_lo": [
                0.0,
                0.0,
                0.0,
                1.1681599617004395,
                1.5526399612426758
              ],
              "clip_hi": [
                147.78758239746094,
                74.90797424316406,
                71.10720825195312,
                147.78758239746094,
                178.72264099121094
              ]
            },
        )
        # `model_name` is captured from the enclosing function's scope.
        calib = params[model_name]
        final_preds_df = apply_postprocess_from_calib(preds_df, calib)

        # Local-run sanity check of the mass-balance identities
        # GDM = Green + Clover and Total = GDM + Dead.
        # NOTE(review): these asserts inspect the *input* `preds_df`, not
        # `final_preds_df`, so they do not validate the post-processed output
        # (unless apply_postprocess_from_calib mutates its argument in place).
        # Confirm which frame was meant to be checked.
        if not submission_mode:
            assert (preds_df["GDM_g"] - (preds_df["Dry_Green_g"] + preds_df["Dry_Clover_g"])).abs().max() < 1e-3
            assert (preds_df["Dry_Total_g"] - (preds_df["GDM_g"] + preds_df["Dry_Dead_g"])).abs().max() < 1e-3

        return final_preds_df


    if apply_post_proc:
        preds_df = post_proc_preds_df(preds_df)
    
    submission_rows = []
    for _, row in preds_df.iterrows():
        sample_id = row["image_id"]
        for target_name in order:
            submission_rows.append({
                "sample_id": f"{sample_id}__{target_name}",
                "target": float(row[target_name]),
            })

    submission_df = pd.DataFrame(submission_rows, columns=["sample_id", "target"])
    submission_df.to_csv("submission.csv", index=False)
    print("Saved results to submission.csv")

    return submission_df


In [9]:
# Free cached Python objects and CUDA memory before loading the fold checkpoints.
gc.collect()
torch.cuda.empty_cache()

# Scan the test directory once. Sorting makes the image order (and therefore the
# local `[:10]` subset below) reproducible instead of filesystem-dependent.
test_image_paths = sorted(Config.test_dir.glob("*"))
submission_mode = len(test_image_paths) >= 1

if submission_mode:
    image_paths = list(test_image_paths)
else:
    # No test images available: smoke-test on a small, fixed subset of train.
    image_paths = sorted(Config.train_dir.glob("*"))[:10]

# # Dinov3 mamba weights 8 tiles
# ----------------------------------
# model_local_cv = [
#     0.6726178856038335,
#     0.6315996830280011,
#     0.7813904246726593,
#     0.7476579813247032,
#     0.717584684713563,
# ]

# model_lb_score = [
#     0.66,
#     0.68,
#     0.66,
#     0.65,
#     0.62
# ]

# Dinov3 mamba weights 2 tiles
# ----------------------------------
# Per-fold local CV scores, listed in the same order as `weight_paths` below.
model_local_cv = [
    0.681189797707458,
    0.6198682876733633,
    0.751599745316939,
    0.7379660302020133,
    0.7346212232290809,
]

# Equal fold weights; the CV-proportional alternative is kept for reference.
# model_weights = [score/sum(model_local_cv) for score in model_local_cv]
model_weights = [1 / len(model_local_cv)] * len(model_local_cv)
print(model_weights)

submission_df_dinov3 = inference(
    model_id = "vit_huge_plus_patch16_dinov3.lvd1689m",
    patch_size = 16,
    model_name = "DinoV3Hybrid",
    image_paths = image_paths,
    weight_paths = [
        # Dinov3-mamba-8-tiles-approach
        # Path("/kaggle/input/csiro-dinov3-mamba-5-fold-ensemble/30bd90d9235742d098ef18efced22c98/artifacts/fold_0_best/DinoV3Structured_fold0.pt"),
        # Path("/kaggle/input/csiro-dinov3-mamba-5-fold-ensemble/dee2f45859bb41c88e29727b07c96726/artifacts/fold_1_best/DinoV3Structured_fold1.pt"),
        # Path("/kaggle/input/csiro-dinov3-mamba-5-fold-ensemble/9daf136ace1f41029e4b0be4e8ee6229/artifacts/fold_2_best/DinoV3Structured_fold2.pt"),
        # Path("/kaggle/input/csiro-dinov3-mamba-5-fold-ensemble/536f0aad17dc4e9d93770f5b30d9d72d/artifacts/fold_3_best/DinoV3Structured_fold3.pt"),
        # Path("/kaggle/input/csiro-dinov3-mamba-5-fold-ensemble/artifacts/fold_4_best/DinoV3Structured_fold4.pt"),


        # Dinov3 mamba 2 tiles approach
        "/kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/2776e2c07fd544d79afcbfff8db8f429/artifacts/fold_0_best/DinoV3Hybrid_fold0.pt",
        "/kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/6a52b6a027b74fc0b65234a84fe50cc8/artifacts/fold_1_best/DinoV3Hybrid_fold1.pt",
        "/kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/b3a849cab7954bf5a98e62ab3d60964f/artifacts/fold_2_best/DinoV3Hybrid_fold2.pt",
        "/kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/9567cd0f28af4b8d9d407df1c30ddb1d/artifacts/fold_3_best/DinoV3Hybrid_fold3.pt",
        "/kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/7085174f4dc9463f9ee7c349453368e1/artifacts/fold_4_best/DinoV3Hybrid_fold4.pt",

    ],
    model_weights = model_weights,
    submission_mode = submission_mode,
    # Single source of truth for the device, shared with the classical pipeline.
    device = Config.device,
    img_size = 512,
    use_tta = False,
    apply_post_proc = False,
)

submission_df_dinov3

[0.2, 0.2, 0.2, 0.2, 0.2]
Loaded inference data loader: 1 images

[Fold 0] Loading weights: /kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/2776e2c07fd544d79afcbfff8db8f429/artifacts/fold_0_best/DinoV3Hybrid_fold0.pt (weight=0.2000)


Inference fold 0: 100%|██████████| 1/1 [00:04<00:00,  4.56s/it]



[Fold 1] Loading weights: /kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/6a52b6a027b74fc0b65234a84fe50cc8/artifacts/fold_1_best/DinoV3Hybrid_fold1.pt (weight=0.2000)


Inference fold 1: 100%|██████████| 1/1 [00:03<00:00,  3.70s/it]



[Fold 2] Loading weights: /kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/b3a849cab7954bf5a98e62ab3d60964f/artifacts/fold_2_best/DinoV3Hybrid_fold2.pt (weight=0.2000)


Inference fold 2: 100%|██████████| 1/1 [00:03<00:00,  3.73s/it]



[Fold 3] Loading weights: /kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/9567cd0f28af4b8d9d407df1c30ddb1d/artifacts/fold_3_best/DinoV3Hybrid_fold3.pt (weight=0.2000)


Inference fold 3: 100%|██████████| 1/1 [00:03<00:00,  3.74s/it]



[Fold 4] Loading weights: /kaggle/input/dinov3-mamba-gdm-green-total-ratio-2-tile-approach/7085174f4dc9463f9ee7c349453368e1/artifacts/fold_4_best/DinoV3Hybrid_fold4.pt (weight=0.2000)


Inference fold 4: 100%|██████████| 1/1 [00:03<00:00,  3.74s/it]


Saved results to submission.csv


Unnamed: 0,sample_id,target
0,ID1001187975__Dry_Clover_g,0.624027
1,ID1001187975__Dry_Dead_g,25.299662
2,ID1001187975__Dry_Green_g,23.91939
3,ID1001187975__Dry_Total_g,49.843082
4,ID1001187975__GDM_g,24.543417


# Classical preds

In [10]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import torch
import joblib
import json
import timm
import random
import gc
from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.linear_model import Ridge, ElasticNet, LinearRegression
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score

from warnings import filterwarnings
filterwarnings("ignore")


# ---------------------------
# Mass balance projection
# ---------------------------
def apply_mass_balance_projection(pred: np.ndarray) -> np.ndarray:
    """
    Orthogonally project predictions onto the mass-balance subspace.

    The two linear constraints are
      Green + Clover - GDM   = 0
      Dead  + GDM    - Total = 0

    Args:
        pred: array of shape (N, 5) with columns [Green, Clover, Dead, GDM, Total].

    Returns:
        float32 array of shape (N, 5): the L2-closest point to each row of `pred`
        that satisfies both constraints.
    """
    # Constraint matrix C (2, 5); feasible points satisfy C @ x == 0.
    constraints = np.array(
        [
            [1.0, 1.0, 0.0, -1.0, 0.0],
            [0.0, 0.0, 1.0, 1.0, -1.0],
        ],
        dtype=np.float64,
    )

    values = pred.astype(np.float64)

    # Standard projection onto the null space: x' = x - C^T (C C^T)^-1 C x
    gram_inverse = np.linalg.inv(constraints @ constraints.T)  # (2, 2)
    violation = (constraints @ values.T).T                     # (N, 2) per-sample residual
    adjustment = (constraints.T @ (gram_inverse @ violation.T)).T  # (N, 5)

    return (values - adjustment).astype(np.float32)


# ---------------------------
# DINOv3 offline loader
# ---------------------------
def build_dinov3_backbone_offline(model_id: str, img_size: int, device: str, weights_dir: Path):
    """
    Build a timm backbone from locally stored weights (no network access).

    Expects two files inside `weights_dir`:
      - ``{model_id}.pt``   : state dict loaded with ``strict=True``
      - ``{model_id}.json`` : metadata providing the keys "model_id" and "num_classes"

    Args:
        model_id: file stem used to locate the weight and metadata files.
        img_size: square input resolution written into the timm data config.
        device: device the model is moved to.
        weights_dir: directory holding the two files (str or Path).

    Returns:
        (model, transform): the model in eval mode on `device`, and the timm
        evaluation transform resolved for that model at `img_size`.

    Raises:
        FileNotFoundError: if the weight or metadata file is missing.
    """
    weights_dir = Path(weights_dir)
    weights_path = weights_dir / f"{model_id}.pt"
    meta_path = weights_dir / f"{model_id}.json"

    # Explicit exceptions instead of `assert`: asserts are stripped under `python -O`.
    if not weights_path.exists():
        raise FileNotFoundError(f"Missing weights: {weights_path}")
    if not meta_path.exists():
        raise FileNotFoundError(f"Missing meta: {meta_path}")

    meta = json.loads(meta_path.read_text())

    # The architecture name comes from the metadata file, not from the `model_id` argument.
    model = timm.create_model(meta["model_id"], pretrained=False, num_classes=meta["num_classes"])
    # NOTE(security): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider `weights_only=True` if the installed torch supports it).
    sd = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.eval().to(device)

    # Override the pretrained input size so the transform yields img_size x img_size inputs.
    cfg = timm.data.resolve_model_data_config(model)
    cfg["input_size"] = (3, img_size, img_size)
    transform = timm.data.create_transform(**cfg, is_training=False)

    return model, transform


# ---------------------------
# Embedding extraction
# ---------------------------
@torch.inference_mode()
def extract_embeddings_from_paths(
    image_paths: list[Path],
    model,
    transform,
    device: str,
    batch_size: int = 8,
) -> np.ndarray:
    """
    Run `model` over the images and collect one embedding row per image.

    Args:
        image_paths: image files to embed, processed in the given order.
        model: callable returning features of rank 2 or rank 3 for a batched tensor;
            for rank-3 output the first token (index 0 on axis 1) is kept.
        transform: callable mapping a PIL RGB image to a tensor.
        device: device the batched tensor is moved to before the forward pass.
        batch_size: number of images per forward pass.

    Returns:
        float32 array with one row per input image, in input order.

    Raises:
        ValueError: if `image_paths` is empty (there is nothing to concatenate and
            the embedding width is unknown).
    """
    if len(image_paths) == 0:
        raise ValueError("image_paths is empty; no embeddings to extract.")

    embs = []
    batch = []

    def flush_batch() -> None:
        """Forward the pending batch and store its features on the CPU."""
        bx = torch.stack(batch, dim=0).to(device, non_blocking=True)
        feats = model(bx)

        # timm may return (B,D) or (B,T,D)
        if feats.ndim == 3:
            feats = feats[:, 0, :]  # CLS

        embs.append(feats.detach().float().cpu().numpy())
        batch.clear()

    for p in tqdm(image_paths, desc="Extracting embeddings"):
        # Context manager closes the file handle as soon as the pixels are decoded.
        with Image.open(p) as raw_img:
            img = raw_img.convert("RGB")
        batch.append(transform(img))  # (C,H,W)

        if len(batch) == batch_size:
            flush_batch()

    # Final partial batch.
    if batch:
        flush_batch()

    return np.concatenate(embs, axis=0).astype(np.float32)


# ---------------------------
# Ensemble helpers
# ---------------------------
def ensemble_equal(preds_list: list[np.ndarray]) -> np.ndarray:
    """Return the element-wise unweighted mean of equally shaped prediction arrays."""
    stacked = np.stack(preds_list, axis=0)  # (n_models, ...)
    return stacked.mean(axis=0)

def ensemble_weighted(preds_list: list[np.ndarray], weights: list[float]) -> np.ndarray:
    """
    Weighted average of equally shaped prediction arrays.

    Weights are normalised to sum to 1 (a 1e-12 term guards against division by
    zero) and the accumulation runs in float64.

    Args:
        preds_list: one prediction array per model, all of the same shape.
        weights: one non-normalised weight per entry of `preds_list`.

    Returns:
        float32 array with the shape of `preds_list[0]`.

    Raises:
        ValueError: if `preds_list` is empty or the number of weights does not
            match the number of prediction arrays (previously a mismatch was
            silently truncated by `zip`).
    """
    if len(preds_list) == 0:
        raise ValueError("preds_list is empty; nothing to ensemble.")

    w = np.asarray(weights, dtype=np.float64)
    if w.shape != (len(preds_list),):
        raise ValueError(
            f"Expected {len(preds_list)} weights (one per model), got shape {w.shape}."
        )

    w = w / (w.sum() + 1e-12)
    out = np.zeros_like(preds_list[0], dtype=np.float64)
    for wi, pi in zip(w, preds_list):
        out += wi * pi.astype(np.float64)
    return out.astype(np.float32)


def load_wide_train(train_csv: str) -> pd.DataFrame:
    """
    Read the long-format training CSV and pivot it to one row per image.

    An `image_id` column is derived from `image_path` (file name without its
    extension). Rows are then pivoted so every value of `target_name` becomes
    its own column holding the first `target` seen for that image.

    Raises:
        ValueError: if any of the five expected target columns is absent after
            the pivot (the first missing one is reported).
    """
    key_columns = [
        "image_id",
        "image_path",
        "Sampling_Date",
        "State",
        "Species",
        "Pre_GSHH_NDVI",
        "Height_Ave_cm",
    ]
    required_targets = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    def file_stem(path_str):
        # "train/ID123.jpg" -> "ID123"
        return os.path.splitext(os.path.basename(path_str))[0]

    long_df = pd.read_csv(train_csv)
    long_df["image_id"] = long_df["image_path"].apply(file_stem)

    pivoted = long_df.pivot_table(
        index=key_columns,
        columns="target_name",
        values="target",
        aggfunc="first",
    )
    wide = pivoted.reset_index()

    absent = [name for name in required_targets if name not in wide.columns]
    if absent:
        raise ValueError(f"Missing target column after pivot: {absent[0]}")

    return wide


def plot_gt_vs_pred(merged_df, target_names, title=""):
    """
    Draw one figure row per target: GT/Pred histograms on the left, a Pred-vs-GT
    scatter with the identity line and R² on the right.

    Args:
        merged_df: frame holding `<target>_gt` and `<target>_pred` columns.
        target_names: targets to plot, one figure row each.
        title: optional figure-level title; skipped when empty.

    Rows where either value is non-finite are dropped per target; a target with
    no remaining rows only gets "EMPTY" panel titles.
    """
    sns.set_theme(style="whitegrid")

    num_targets = len(target_names)
    fig, axes = plt.subplots(
        nrows=num_targets,
        ncols=2,
        figsize=(14, 3.5 * num_targets),
        constrained_layout=True,
    )
    # A single row comes back as a 1-D pair of axes; lift it to shape (1, 2).
    axes = np.atleast_2d(axes)

    if title:
        fig.suptitle(title)

    for row, name in enumerate(target_names):
        truth = merged_df[f"{name}_gt"].to_numpy(dtype=np.float64)
        estimate = merged_df[f"{name}_pred"].to_numpy(dtype=np.float64)

        # Keep only rows where both values are finite.
        finite = np.isfinite(truth) & np.isfinite(estimate)
        truth, estimate = truth[finite], estimate[finite]

        ax_hist, ax_scatter = axes[row, 0], axes[row, 1]

        if truth.size == 0:
            ax_hist.set_title(f"{name} — EMPTY (check merge)")
            ax_scatter.set_title(f"{name} — EMPTY")
            continue

        # Left panel: overlaid density histograms.
        for series, label in ((truth, "GT"), (estimate, "Pred")):
            sns.histplot(series, bins=40, stat="density", kde=True, alpha=0.45, label=label, ax=ax_hist)
        ax_hist.legend()
        ax_hist.set_title(f"{name} — Distribution")

        # Right panel: scatter with the y = x reference line.
        score = r2_score(truth, estimate)
        sns.scatterplot(x=truth, y=estimate, s=25, alpha=0.6, ax=ax_scatter)

        lo = min(truth.min(), estimate.min())
        hi = max(truth.max(), estimate.max())
        ax_scatter.plot([lo, hi], [lo, hi], "r--", linewidth=1)

        ax_scatter.set_title(f"{name} — Pred vs GT (R² = {score:.4f})")
        ax_scatter.set_xlabel("GT")
        ax_scatter.set_ylabel("Pred")

    plt.show()

# ---------------------------
# Main inference
# ---------------------------
def inference_classical(
    image_paths: list[Path],
    submission_mode: bool,
    apply_mass_balance: bool,
    model_id: str,  # timm id (NOT .json)
    embedding_model_weights_dir: str,
    classical_models: list,          # list of sklearn pipelines loaded via joblib
    model_weights: list[float] | None,
    device: str,
    batch_size: int = 8,
    img_size: int = 512,
    out_csv: str = "submission.csv",
) -> pd.DataFrame:
    """
    End-to-end inference:
      1) Load DINOv3 embedding model offline (.pt + .json meta)
      2) Extract embeddings for image_paths
      3) Predict with multiple classical models
      4) Ensemble them (weighted when `model_weights` is given, else equal)
      5) Clip negatives + optional mass-balance projection
      6) Write competition submission in long format: sample_id__target_name, target

    Args:
        image_paths: images to predict; the file stem becomes the sample id.
        submission_mode: when False, additionally merges predictions with the
            training labels and shows GT-vs-Pred plots (best effort).
        apply_mass_balance: whether to project predictions onto the mass-balance
            constraints.
        model_id: file stem of the embedding model's weight/meta files.
        embedding_model_weights_dir: directory holding those files.
        classical_models: fitted estimators whose `predict` returns (N, 5).
        model_weights: per-model ensemble weights, or None for an equal mean.
        device: device used for embedding extraction.
        batch_size: embedding batch size.
        img_size: square input resolution for the embedding model.
        out_csv: path the long-format submission is written to.

    Returns:
        Long-format submission DataFrame with columns ["sample_id", "target"].

    Raises:
        ValueError: if `classical_models` is empty or a model returns predictions
            that are not shaped (N, 5).
    """
    if len(classical_models) == 0:
        raise ValueError("classical_models is empty; need at least one fitted model.")

    # ---------------------------
    # small helper: postprocess
    # ---------------------------
    def postprocess_preds(pred: np.ndarray, do_mb: bool) -> np.ndarray:
        pred = np.asarray(pred, dtype=np.float32)

        # masses can't be negative (models can output negatives)
        pred = np.maximum(pred, 0.0)

        # enforce constraints if requested
        if do_mb:
            pred = apply_mass_balance_projection(pred).astype(np.float32)

        # projection can introduce tiny negatives again; clipping them may in turn
        # leave a small residual in the mass-balance identities
        pred = np.maximum(pred, 0.0)
        return pred

    # ---------------------------
    # 1) load embedding model offline
    # ---------------------------
    weights_dir = Path(embedding_model_weights_dir)
    emb_model, transform = build_dinov3_backbone_offline(
        model_id=model_id,
        img_size=img_size,
        device=device,
        weights_dir=weights_dir,
    )

    # IMPORTANT: this must match your mass-balance A-matrix ordering:
    # [Green, Clover, Dead, GDM, Total]
    target_names = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g", "GDM_g", "Dry_Total_g"]

    # ---------------------------
    # 2) embeddings
    # ---------------------------
    X = extract_embeddings_from_paths(
        image_paths=image_paths,
        model=emb_model,
        transform=transform,
        device=device,
        batch_size=batch_size,
    )

    # ---------------------------
    # 3) predict per classical model
    # ---------------------------
    preds_list = []
    for m in classical_models:
        pred = m.predict(X).astype(np.float32)  # (N, 5)
        if pred.ndim == 1:
            raise ValueError("Model returned 1D pred; expected (N,5). Did you forget MultiOutput?")
        if pred.shape[1] != len(target_names):
            raise ValueError(f"Pred shape {pred.shape} does not match targets={len(target_names)}")
        preds_list.append(pred)

    # ---------------------------
    # 4) ensemble: only the variant that is actually submitted is computed
    # ---------------------------
    if model_weights is None:
        final_pred = ensemble_equal(preds_list)
    else:
        final_pred = ensemble_weighted(preds_list, model_weights)

    # ---------------------------
    # 5) postprocess (clip negatives + optional mass balance)
    # ---------------------------
    final_pred = postprocess_preds(final_pred, do_mb=apply_mass_balance)

    # ---------------------------
    # 6) build wide preds df (optional debug)
    # ---------------------------
    ids = [p.stem for p in image_paths]
    preds_wide = pd.DataFrame({"sample_id": ids})
    for i, t in enumerate(target_names):
        preds_wide[t] = final_pred[:, i]

    # Optional local validation plots (only when not in submission mode)
    if not submission_mode:
        # Deliberately best effort: a plotting/merge failure must not block the
        # submission file from being written below.
        try:
            train_df = load_wide_train(Config.train_csv_path)
            merged_df = train_df.merge(
                preds_wide,
                left_on="image_id",
                right_on="sample_id",
                how="inner",
                suffixes=("_gt", "_pred")
            )

            print("Merged rows:", len(merged_df))

            # Title reflects what was actually applied instead of always
            # claiming a weighted, mass-balanced ensemble.
            ensemble_label = "Equal" if model_weights is None else "Weighted"
            mb_label = "after Mass Balance" if apply_mass_balance else "no Mass Balance"
            plot_gt_vs_pred(
                merged_df,
                target_names,
                title=f"{ensemble_label} Ensemble ({mb_label})"
            )
        except Exception as e:
            print(f"[WARN] Skipping debug plots due to: {e}")

    # -----------------------------
    # 7) Create competition submission (long format)
    # -----------------------------
    # Competition expects this order:
    order = ["Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"]

    # One output row per (sample, target), samples in input order and targets in
    # `order`; itertuples avoids the per-row Series construction of iterrows.
    submission_rows = [
        {"sample_id": f"{sid}__{target_name}", "target": float(value)}
        for sid, *values in preds_wide[["sample_id", *order]].itertuples(index=False, name=None)
        for target_name, value in zip(order, values)
    ]

    submission_df = pd.DataFrame(submission_rows, columns=["sample_id", "target"])
    submission_df.to_csv(out_csv, index=False)
    print(f"Saved results to {out_csv}")

    # quick sanity
    print("Final preds min per target:", preds_wide[target_names].min().to_dict())
    print("Final preds max per target:", preds_wide[target_names].max().to_dict())

    return submission_df


# Free cached Python objects and CUDA memory before the classical pipeline runs.
gc.collect()
torch.cuda.empty_cache()

# Scan the test directory once; sorting keeps the image order reproducible.
test_image_paths = sorted(Config.test_dir.glob("*"))
submission_mode = len(test_image_paths) >= 1

if submission_mode:
    image_paths = list(test_image_paths)
else:
    # No test images available: run on the full training set for local validation.
    image_paths = sorted(Config.train_dir.glob("*"))

# Loading classical models
model_paths = [
    "/kaggle/input/dinov3-embedding-classical-models/classical_model_results/Ridge_pca64.joblib",
    "/kaggle/input/dinov3-embedding-classical-models/classical_model_results/ElasticNet_pca64.joblib",
    "/kaggle/input/dinov3-embedding-classical-models/classical_model_results/LinearRegression_pca64.joblib",
]

# Loading best pca param
# (context manager so the file handle is closed instead of leaked)
with open(
    "/kaggle/input/dinov3-embedding-classical-models/classical_model_results/ensemble_metadata.json",
    "r"
) as metadata_file:
    ensemble_metadata = json.load(metadata_file)

# Local CV scores, listed in the same order as `model_paths`.
local_cv_scores = [
    0.77473443,
    0.76917793,
    0.76910136
]
# Ensemble weights proportional to each model's local CV score.
model_weights = [
    score/sum(local_cv_scores)
    for score in local_cv_scores
]

# model_weights = [
#     1/len(local_cv_scores)
#     for score in local_cv_scores
# ]

print(model_weights)
# NOTE(security): joblib.load unpickles the file; only load models from a trusted dataset.
classical_models = [
    joblib.load(path)
    for path in model_paths
]

from pprint import pprint
pprint(classical_models)

submission_df_classical = inference_classical(
    image_paths=image_paths,
    submission_mode=submission_mode,
    apply_mass_balance=True,

    model_id="vit_huge_plus_patch16_dinov3.lvd1689m",
    embedding_model_weights_dir="/kaggle/input/dinov3-embedding-classical-models/classical_model_results/embedding_model",

    classical_models=classical_models,
    model_weights=model_weights,

    device=str(Config.device),
    batch_size=8,
    img_size=512,
)

submission_df_classical

[0.33494588609703535, 0.33254360895014495, 0.33251050495281975]
[Pipeline(steps=[('scaler', StandardScaler()),
                ('pca', PCA(n_components=64, random_state=42)),
                ('model', Ridge(alpha=np.float64(316.2277660168379)))]),
 Pipeline(steps=[('scaler', StandardScaler()),
                ('pca', PCA(n_components=64, random_state=42)),
                ('model',
                 ElasticNet(alpha=0.003, l1_ratio=0.8, max_iter=50000))]),
 Pipeline(steps=[('scaler', StandardScaler()),
                ('pca', PCA(n_components=64, random_state=42)),
                ('model', LinearRegression())])]


Extracting embeddings:   0%|          | 0/1 [00:00<?, ?it/s]

Saved results to submission.csv
Final preds min per target: {'Dry_Green_g': 28.218172073364258, 'Dry_Clover_g': 1.704144835472107, 'Dry_Dead_g': 27.984580993652344, 'GDM_g': 29.922317504882812, 'Dry_Total_g': 57.906898498535156}
Final preds max per target: {'Dry_Green_g': 28.218172073364258, 'Dry_Clover_g': 1.704144835472107, 'Dry_Dead_g': 27.984580993652344, 'GDM_g': 29.922317504882812, 'Dry_Total_g': 57.906898498535156}


Unnamed: 0,sample_id,target
0,ID1001187975__Dry_Clover_g,1.704145
1,ID1001187975__Dry_Dead_g,27.984581
2,ID1001187975__Dry_Green_g,28.218172
3,ID1001187975__Dry_Total_g,57.906898
4,ID1001187975__GDM_g,29.922318


# Ensemble

In [11]:
import pandas as pd
import numpy as np

# Presumably the public leaderboard scores of the two submissions, in the same
# order as `submissions_df` (TODO confirm); only used by the commented-out
# LB-weighted option further down.
lb_scores = [0.71, 0.66]
# Long-format submissions to blend: [DINOv3 fold ensemble, classical ensemble].
submissions_df = [submission_df_dinov3, submission_df_classical]

def ensemble_submissions(sub_dfs, weights=None, strict=None):
    """
    Blend several long-format submissions into one by weighted averaging.

    Submissions are aligned on `sample_id` with an inner join, so only ids present
    in every submission survive.

    Args:
        sub_dfs: DataFrames with at least the columns "sample_id" and "target".
        weights: one non-normalised weight per submission, or None for an equal
            average. Weights are normalised to sum to 1.
        strict: whether to run the alignment sanity checks. None (default) keeps
            the historical behaviour of checking only when the notebook-level
            `submission_mode` flag is False; pass True/False to decide explicitly
            without relying on that global.

    Returns:
        (out, w): `out` has columns "sample_id" and "target" (float32);
        `w` is the normalised weight vector that was applied.

    Raises:
        ValueError: if `sub_dfs` is empty, the number of weights does not match
            the number of submissions, or (strict only) the join dropped rows of
            the first submission or produced NaNs.
    """
    if len(sub_dfs) == 0:
        raise ValueError("sub_dfs is empty; nothing to ensemble.")

    if strict is None:
        # Backward-compatible default: driven by the notebook-level flag.
        strict = not submission_mode

    # keep only required cols + ensure float
    dfs = []
    for df in sub_dfs:
        d = df[["sample_id", "target"]].copy()
        d["target"] = d["target"].astype(np.float64)
        dfs.append(d.set_index("sample_id"))

    # join on sample_id (inner ensures same ordering / keys)
    joined = pd.concat(dfs, axis=1, join="inner")
    joined.columns = [f"m{i}" for i in range(len(dfs))]

    # sanity checks (row count is compared against the first submission only,
    # since later submissions may legitimately cover more samples)
    if strict:
        if joined.shape[0] != sub_dfs[0].shape[0]:
            raise ValueError("Mismatch in sample_id rows across submissions!")
        if joined.isna().sum().sum() != 0:
            raise ValueError("Found NaNs after joining—some sample_ids missing.")

    if weights is None:
        w = np.ones(len(dfs), dtype=np.float64) / len(dfs)
    else:
        w = np.asarray(weights, dtype=np.float64)
        if w.shape != (len(dfs),):
            raise ValueError(
                f"Expected {len(dfs)} weights (one per submission), got shape {w.shape}."
            )
        w = w / (w.sum() + 1e-12)

    # weighted average across columns
    ens = joined.values @ w
    out = pd.DataFrame({"sample_id": joined.index, "target": ens.astype(np.float32)})
    return out, w

# Three alternative blending schemes; exactly one should be active at a time,
# since each of them writes to the same "submission.csv".

# # Option 1) weights proportional to the LB scores
# ens_weighted, w_used = ensemble_submissions(submissions_df, weights=lb_scores)
# print("Weights used:", w_used)

# Option 2) manual weights (active), in the same order as `submissions_df`
weights = [0.65, 0.35]
ens_weighted, w_used = ensemble_submissions(submissions_df, weights=weights)
print("Weights used:", w_used)

# Overwrites the submission.csv written earlier in the notebook with the blend.
ens_weighted.to_csv("submission.csv", index=False)
print("Saved: submission.csv")

# # Option 3) equal-weight ensemble (optional)
# ens_equal, _ = ensemble_submissions(submissions_df, weights=None)
# ens_equal.to_csv("submission.csv", index=False)
# print("Saved: submission.csv")


Weights used: [0.65 0.35]
Saved: submission.csv
