# In [None]:
# ========= EVA02 (only) — load folds + infer + make submission =========
import os, math, glob, cv2
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader

# ----------------- CONFIG -----------------
# timm model identifier; passed to timm.create_model when BiomassModel is built.
MODEL_NAME  = "eva02_large_patch14_448.mim_in22k_ft_in22k_in1k"
# Side length, in pixels, that each image half is resized to before the backbone.
IMG_SIZE    = 448
# Batch size and worker-process count for the inference DataLoader.
BATCH_SIZE  = 4
NUM_WORKERS = 2

# CSV describing the rows to predict (read for sample_id / image_path / target_name).
TEST_CSV  = "/kaggle/input/csiro-biomass/test.csv"
# Folder that contains the trained fold weights (best_model_fold0.pth, best_model_fold1.pth, ...);
# every file matching best_model_fold*.pth in it is ensembled.
WEIGHTS_DIR = "/kaggle/input/csiro-eva02/pytorch/default/1/eva02"
# Path the submission CSV is written to.
OUT_CSV  = "submission.csv"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global PyTorch switches: enable cuDNN benchmark mode and allow the "high"
# (reduced-precision) float32 matmul setting.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# Target names in the column order produced by postprocess_3_to_5; the two must stay in sync.
ALL_TARGET_COLS = ["Dry_Green_g","Dry_Dead_g","Dry_Clover_g","GDM_g","Dry_Total_g"]
# Per-channel mean / std applied to RGB values already scaled to [0, 1].
# NOTE(review): these look like the usual ImageNet statistics — confirm they match training.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# ----------------- image utils -----------------
def clean_image(img_rgb: np.ndarray) -> np.ndarray:
    """Crop the bottom strip of an RGB image and inpaint a masked colour range.

    The bottom 10% of rows is discarded. Pixels whose HSV value falls inside
    the fixed range below are masked, the mask is dilated, and the masked
    region is filled in with Telea inpainting. When nothing matches, the
    cropped image is returned as-is.

    Args:
        img_rgb: Image array in RGB channel order (it is converted with
            ``cv2.COLOR_RGB2HSV``).

    Returns:
        The cropped (and possibly inpainted) RGB image.
    """
    height, _ = img_rgb.shape[:2]
    cropped = img_rgb[: int(height * 0.90), :]  # keep the top 90% of rows

    # NOTE(review): hue 5-25 with high saturation/value is roughly orange on
    # OpenCV's hue scale — presumably an overlay or marker; confirm intent.
    hsv_lo = np.array([5, 150, 150])
    hsv_hi = np.array([25, 255, 255])
    hit_mask = cv2.inRange(cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV), hsv_lo, hsv_hi)
    # Grow the mask slightly so the inpainting also covers the region's fringe.
    hit_mask = cv2.dilate(hit_mask, np.ones((3, 3), np.uint8), iterations=2)

    if hit_mask.sum() <= 0:
        return cropped
    return cv2.inpaint(cropped, hit_mask, 3, cv2.INPAINT_TELEA)

def preprocess_half(img_rgb: np.ndarray, size: int) -> torch.Tensor:
    """Turn one image half into a normalized, channels-first float tensor.

    Args:
        img_rgb: HxWx3 image array.
        size: Target side length; the image is resized to ``size`` x ``size``
            with bilinear interpolation (aspect ratio is not preserved).

    Returns:
        A float32 tensor of shape ``(3, size, size)``, scaled to [0, 1] and
        then standardized with the module-level ``MEAN`` / ``STD``.
    """
    resized = cv2.resize(img_rgb, (size, size), interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0
    standardized = (scaled - MEAN) / STD
    # HWC -> CHW
    return torch.from_numpy(standardized).permute(2, 0, 1)

# ----------------- dataset -----------------
def resolve_path(rel_path: str, base_dirs):
    """Locate ``rel_path`` on disk by probing a list of base directories.

    Two passes are made over ``base_dirs``, in order: first the full relative
    path is tried under each base, then only its basename. An exact
    relative-path hit under any base therefore wins over a basename hit.

    Args:
        rel_path: Path as it appears in the CSV.
        base_dirs: Iterable of directories to search.

    Returns:
        The first existing candidate path, or ``None`` when nothing matches.
    """
    def first_existing(name):
        candidates = (os.path.join(root, name) for root in base_dirs)
        return next((c for c in candidates if os.path.exists(c)), None)

    found = first_existing(rel_path)
    if found is None:
        # Fall back to the bare file name in case the CSV's sub-folder is absent.
        found = first_existing(os.path.basename(rel_path))
    return found

class TestDS(Dataset):
    """Inference dataset that yields the left and right halves of each image.

    Each item is ``(left, right, rel)``: two tensors produced by
    ``preprocess_half`` at ``IMG_SIZE`` and the original relative path, so
    that predictions can be joined back to the CSV.

    An image that cannot be located or decoded is replaced by an all-black
    placeholder so that inference never aborts; a ``RuntimeWarning`` is
    emitted so the substitution is visible in the logs instead of silently
    producing meaningless predictions.
    """

    def __init__(self, image_paths, base_dirs):
        """Store the relative image paths and the directories to search."""
        self.image_paths = list(image_paths)
        self.base_dirs = list(base_dirs)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        import warnings  # local import keeps the file's import block untouched

        rel = self.image_paths[i]
        fp = resolve_path(rel, self.base_dirs)
        # cv2.imread returns None (rather than raising) on unreadable files.
        img = cv2.imread(fp) if fp else None
        if img is None:
            reason = "could not be decoded" if fp else "was not found under any base dir"
            warnings.warn(
                f"Image {rel!r} {reason}; substituting an all-black placeholder.",
                RuntimeWarning,
            )
            img = np.zeros((1000, 2000, 3), np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = clean_image(img)
        # Split down the middle; with an odd width the right half gets the extra column.
        mid = img.shape[1] // 2
        left = preprocess_half(img[:, :mid], IMG_SIZE)
        right = preprocess_half(img[:, mid:], IMG_SIZE)
        return left, right, rel

# ----------------- model (EVA02 token mode only) -----------------
class Local2DTokenMixerBlock(nn.Module):
    """Residual block that mixes neighbouring tokens on a 2D grid.

    Pipeline: LayerNorm -> sigmoid self-gating -> depthwise ``kernel_size`` x
    ``kernel_size`` convolution over the (H, W) grid -> linear projection ->
    dropout, added back onto the input.

    Args:
        dim: Channel (embedding) dimension of the tokens.
        kernel_size: Side of the depthwise convolution kernel; padding is
            ``kernel_size // 2``.
        dropout: Dropout probability applied after the projection.
    """

    def __init__(self, dim, kernel_size=5, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint format.
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Linear(dim, dim)
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,  # one filter per channel (depthwise)
            bias=True,
        )
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x_grid):
        """Apply the block to a channels-last grid of shape (B, H, W, D)."""
        normed = self.norm(x_grid)
        gated = normed * self.gate(normed).sigmoid()
        # Conv2d wants channels-first, so go (B,H,W,D) -> (B,D,H,W) and back.
        mixed = self.dwconv(gated.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return x_grid + self.drop(self.proj(mixed))

class BiomassModel(nn.Module):
    """Two-view biomass regressor built on a timm token backbone.

    The left and right image halves are encoded separately by the same
    backbone. Their patch tokens are laid out on a 2D grid, stitched side by
    side, refined by two ``Local2DTokenMixerBlock`` layers, average-pooled and
    layer-normalized. Three Softplus-terminated heads predict green, clover
    and dead mass; GDM and total are derived as sums.

    Attribute names and registration order are kept stable because they
    determine the ``state_dict`` keys expected by the fold checkpoints.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        self.nf = int(self.backbone.num_features)
        # Number of leading non-patch tokens to strip from the sequence.
        self.num_prefix = int(getattr(self.backbone, "num_prefix_tokens", 1))

        # Cache the backbone's patch grid when it advertises one as a 2-sequence.
        self.gh, self.gw = None, None
        grid = getattr(getattr(self.backbone, "patch_embed", None), "grid_size", None)
        if isinstance(grid, (tuple, list)) and len(grid) == 2:
            self.gh, self.gw = int(grid[0]), int(grid[1])

        self.fusion2d = nn.Sequential(
            *(Local2DTokenMixerBlock(self.nf, kernel_size=5, dropout=0.1) for _ in range(2))
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.feat_ln = nn.LayerNorm(self.nf)

        self.head_green = self._make_pos_head()
        self.head_clover = self._make_pos_head()
        self.head_dead = self._make_pos_head()

    def _make_pos_head(self):
        """Build a small MLP head whose output is non-negative (Softplus)."""
        hidden = self.nf // 2
        return nn.Sequential(
            nn.Linear(self.nf, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, 1),
            nn.Softplus(),
        )

    @staticmethod
    def _infer_grid_hw(num_patches: int):
        """Factor ``num_patches`` into (rows, cols) with rows as close to the
        square root as possible while dividing evenly (rows <= cols)."""
        rows = math.isqrt(num_patches)
        while rows > 1 and num_patches % rows:
            rows -= 1
        return rows, num_patches // rows

    def _grid_from_input(self, x: torch.Tensor, num_patches: int):
        """Derive the patch grid from the input's spatial size and the
        backbone's patch size, falling back to ``_infer_grid_hw`` when the
        patch size is unavailable or inconsistent with ``num_patches``."""
        embed = getattr(self.backbone, "patch_embed", None)
        if hasattr(embed, "patch_size"):
            patch = embed.patch_size
            patch_h, patch_w = patch if isinstance(patch, (tuple, list)) else (patch, patch)
            rows = int(x.shape[-2] // patch_h)
            cols = int(x.shape[-1] // patch_w)
            if rows * cols == num_patches:
                return rows, cols
        return self._infer_grid_hw(num_patches)

    @staticmethod
    def _concat_halves(p_l, p_r, gh, gw):
        """Stitch two (B, gh*gw, D) token sequences side by side and return
        the row-major (B, gh*2*gw, D) sequence of the combined grid."""
        batch, _, dim = p_l.shape
        grids = [tokens.view(batch, gh, gw, dim) for tokens in (p_l, p_r)]
        # Concatenate along the width axis: (B, gh, 2*gw, D).
        return torch.cat(grids, dim=2).reshape(batch, gh * (2 * gw), dim)

    def forward(self, left, right):
        """Predict from the two image halves.

        Returns:
            ``(total, gdm, green)``, each with one value per sample; clover
            and dead are not returned and are recovered by the caller.
        """
        tokens_l = self.backbone.forward_features(left)   # (B, N, D)
        tokens_r = self.backbone.forward_features(right)  # (B, N, D)

        skip = self.num_prefix
        patches_l = tokens_l[:, skip:, :]
        patches_r = tokens_r[:, skip:, :]
        n_patches = int(patches_l.size(1))

        # Use the cached grid only when it matches the token count actually seen.
        rows, cols = self.gh, self.gw
        cached_ok = rows is not None and cols is not None and rows * cols == n_patches
        if not cached_ok:
            rows, cols = self._grid_from_input(left, n_patches)

        stitched = self._concat_halves(patches_l, patches_r, rows, cols)
        batch, _, dim = stitched.shape
        full_h, full_w = rows, 2 * cols

        # Fusion, pooling and LayerNorm run in float32 even under outer autocast.
        with torch.amp.autocast("cuda", enabled=False):
            grid = stitched.float().view(batch, full_h, full_w, dim).contiguous()
            grid = self.fusion2d(grid)
            flat = grid.view(batch, full_h * full_w, dim)
            feat = self.feat_ln(self.pool(flat.transpose(1, 2)).squeeze(-1))

        green = self.head_green(feat)
        clover = self.head_clover(feat)
        dead = self.head_dead(feat)
        gdm = green + clover
        return gdm + dead, gdm, green

# ----------------- load folds + infer -----------------
def load_sd(path):
    """Load a checkpoint onto the CPU and return its state dict.

    If the loaded object is a dict that wraps the weights under
    ``"model_state_dict"`` (preferred) or ``"state_dict"``, the wrapped value
    is returned; otherwise the loaded object is returned unchanged.
    """
    # NOTE: torch.load unpickles the file — only point this at trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state_dict", "state_dict"):
            if wrapper_key in checkpoint:
                return checkpoint[wrapper_key]
    return checkpoint

@torch.inference_mode()
def predict_one_checkpoint(model, loader, ckpt_path):
    """Load one fold checkpoint into ``model`` and predict the whole loader.

    Weights are loaded with ``strict=False`` so that a partially matching
    checkpoint does not abort the run, but any missing or unexpected keys are
    reported with a ``RuntimeWarning``: parameters that are not overwritten
    keep whatever values they had before (random init or the previous fold),
    which would otherwise corrupt the ensemble silently.

    Args:
        model: Network whose ``forward(left, right)`` returns
            ``(total, gdm, green)``.
        loader: DataLoader yielding ``(left, right, _)`` batches in a fixed
            order (no shuffling), so rows line up with ``loader.dataset``.
        ckpt_path: Path of the checkpoint, read through ``load_sd``.

    Returns:
        A float32 array of shape ``(len(loader.dataset), 3)`` with columns
        ``[total, gdm, green]``.
    """
    import warnings  # local import keeps the file's import block untouched

    load_result = model.load_state_dict(load_sd(ckpt_path), strict=False)
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing or unexpected:
        warnings.warn(
            f"{ckpt_path}: checkpoint does not fully match the model "
            f"({len(missing)} missing keys, e.g. {missing[:3]}; "
            f"{len(unexpected)} unexpected keys, e.g. {unexpected[:3]}). "
            "Unmatched parameters keep their previous values.",
            RuntimeWarning,
        )
    model.to(DEVICE).eval()

    preds = np.zeros((len(loader.dataset), 3), np.float32)  # [total, gdm, green]
    off = 0
    for l, r, _ in loader:
        l = l.to(DEVICE, non_blocking=True)
        r = r.to(DEVICE, non_blocking=True)
        if DEVICE.type == "cuda":
            # bfloat16 autocast on GPU; the CPU path stays in full precision.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                total, gdm, green = model(l, r)
        else:
            total, gdm, green = model(l, r)

        b = l.size(0)
        pred3 = torch.stack([total.view(-1), gdm.view(-1), green.view(-1)], dim=1).float().cpu().numpy()
        # Write into the preallocated buffer at this batch's row offset.
        preds[off:off+b] = pred3
        off += b
    return preds

def postprocess_3_to_5(pred3):
    """Expand ``[total, gdm, green]`` predictions to all five targets.

    Clover and dead mass are recovered as differences (``gdm - green`` and
    ``total - gdm``) and clamped at zero.

    Args:
        pred3: Array of shape (N, 3) with columns ``[total, gdm, green]``.

    Returns:
        A float32 array of shape (N, 5) with columns
        ``[green, dead, clover, gdm, total]``.
    """
    total, gdm, green = (pred3[:, col] for col in range(3))
    dead = np.maximum(0.0, total - gdm)
    clover = np.maximum(0.0, gdm - green)
    # Column order expected downstream: green, dead, clover, gdm, total.
    columns = (green, dead, clover, gdm, total)
    return np.column_stack(columns).astype(np.float32)

# ----------------- run -----------------
# The test CSV is in long format: one row per (sample_id, image_path, target_name).
test_df = pd.read_csv(TEST_CSV)
# Run inference once per distinct image, however many rows reference it.
uniq_imgs = test_df["image_path"].drop_duplicates().values

# base dirs to try (add more if you keep images elsewhere)
base_dirs = [
    "/kaggle/input/csiro-biomass",              # sometimes image_path is relative under this
    "/kaggle/input/csiro-biomass/test_images",  # common layout
    "/kaggle/input/csiro-biomass/test",         # alt layout
]

ds = TestDS(uniq_imgs, base_dirs)
# shuffle=False keeps prediction rows aligned with uniq_imgs.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Sorted so the fold order is deterministic.
ckpts = sorted(glob.glob(os.path.join(WEIGHTS_DIR, "best_model_fold*.pth")))
if not ckpts:
    raise FileNotFoundError(f"No fold checkpoints found in: {WEIGHTS_DIR}")

# One model instance is reused; each fold's weights are loaded into it in turn.
model = BiomassModel(MODEL_NAME)
preds_sum = np.zeros((len(ds), 3), np.float32)

for p in ckpts:
    preds_sum += predict_one_checkpoint(model, dl, p)
# Ensemble = plain mean over folds of [total, gdm, green].
preds_3 = preds_sum / float(len(ckpts))
preds_5 = postprocess_3_to_5(preds_3)

preds_wide = pd.DataFrame(preds_5, columns=ALL_TARGET_COLS)
preds_wide.insert(0, "image_path", uniq_imgs)

# Wide (one row per image) -> long (one row per image/target) to match test_df.
preds_long = preds_wide.melt(
    id_vars=["image_path"],
    value_vars=ALL_TARGET_COLS,
    var_name="target_name",
    value_name="target",
)

merged = test_df[["sample_id", "image_path", "target_name"]].merge(
    preds_long, on=["image_path", "target_name"], how="left"
)
# Rows left without a value (unmatched image_path/target_name, or a NaN model
# output) are still filled with 0.0 below, but report them so an entirely
# mismatched join cannot produce an all-zero submission unnoticed.
n_unfilled = int(merged["target"].isna().sum())
if n_unfilled:
    print(f"WARNING: {n_unfilled} of {len(merged)} submission rows have no prediction; filling with 0.0")

sub = (
    merged[["sample_id", "target"]]
    .fillna(0.0)
    .sort_values("sample_id")
    .reset_index(drop=True)
)

sub.to_csv(OUT_CSV, index=False)
print("saved:", OUT_CSV)
sub.head()
