In [None]:
# ========= ConvNeXt/ConvNeXtV2 (only) — load folds + infer + make submission =========
import os, glob, math, cv2
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader

# ----------------- CONFIG -----------------
# Architecture identifier handed to timm.create_model; change if you trained a different one.
MODEL_NAME = "convnextv2_huge.fcmae_ft_in22k_in1k_512"
# Side length (pixels) each image half is resized to before entering the network.
IMG_SIZE = 512
# DataLoader settings used for inference.
BATCH_SIZE = 8
NUM_WORKERS = 2

# Input / output locations.
TEST_CSV = "/kaggle/input/csiro-biomass/test.csv"
# Directory that contains best_model_fold0.pth ...
WEIGHTS_DIR = "/kaggle/input/convnext2/pytorch/default/1/convnextv2"
OUT_CSV = "submission.csv"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# Output columns, in the same order as the tuple returned by the model's forward().
TARGET_COLS = [
    "Dry_Green_g",
    "Dry_Dead_g",
    "Dry_Clover_g",
    "GDM_g",
    "Dry_Total_g",
]
# Per-channel (RGB) normalisation constants applied in preprocess_half.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# ----------------- utils -----------------
def clean_image(img_rgb):
    """Crop the bottom of an RGB frame and inpaint pixels inside a fixed HSV band.

    The top 90% of the rows is kept. Pixels whose HSV value lies in
    H 5..25, S 150..255, V 150..255 are selected, the selection is grown
    with two passes of a 3x3 dilation, and the selected area is filled in
    with Telea inpainting.
    NOTE(review): the band presumably targets an orange overlay burned into
    the photos (e.g. a timestamp) - confirm against the training pipeline.

    Args:
        img_rgb: image array in RGB channel order, rows first.

    Returns:
        The cropped image; inpainted when at least one pixel matched,
        otherwise the plain crop of the input.
    """
    height, _width = img_rgb.shape[:2]
    cropped = img_rgb[: int(height * 0.90), :]

    hsv = cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV)
    lower_bound = np.array([5, 150, 150])
    upper_bound = np.array([25, 255, 255])
    mask = cv2.inRange(hsv, lower_bound, upper_bound)

    # Grow the selection slightly so the fringe around matched pixels is repainted too.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=2)

    if mask.sum() == 0:
        # Nothing matched: skip the inpainting call entirely.
        return cropped
    return cv2.inpaint(cropped, mask, 3, cv2.INPAINT_TELEA)

def preprocess_half(img_rgb, size):
    """Turn one RGB image half into a normalised channels-first tensor.

    Args:
        img_rgb: image array in RGB order (height, width, channel).
        size: target side length; the output is square (size x size).

    Returns:
        A float32 torch tensor laid out as (channel, height, width), scaled
        to [0, 1] and then standardised with the module-level MEAN / STD.
    """
    resized = cv2.resize(img_rgb, (size, size), interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0
    standardised = (scaled - MEAN) / STD
    hwc = torch.from_numpy(standardised)
    return hwc.permute(2, 0, 1)  # HWC -> CHW

def resolve_path(rel_path, base_dirs):
    """Locate *rel_path* under one of several candidate root directories.

    Two passes are made over ``base_dirs``, in the order given: the first
    joins each root with the full relative path, the second joins each root
    with only the file name. The first candidate that exists wins.

    Args:
        rel_path: path as stored in the CSV, relative to some data root.
        base_dirs: directories to try, highest priority first.

    Returns:
        The first existing candidate path, or None when nothing matches.
    """
    # Pass 1 uses the path as given; pass 2 falls back to the bare file name.
    for tail in (rel_path, os.path.basename(rel_path)):
        for root in base_dirs:
            candidate = os.path.join(root, tail)
            if os.path.exists(candidate):
                return candidate
    return None

class TestDS(Dataset):
    """Inference dataset yielding the left and right half of each test image.

    Each item is ``(left, right, rel)`` where ``left`` / ``right`` are the
    tensors produced by ``preprocess_half`` for the two halves of the
    cleaned image and ``rel`` is the original entry of ``image_paths``.

    Args:
        image_paths: iterable of image paths as listed in the test CSV.
        base_dirs: iterable of directories searched by ``resolve_path``.
    """

    def __init__(self, image_paths, base_dirs):
        # Copy into lists so indexing works for any iterable (e.g. numpy arrays).
        self.image_paths = list(image_paths)
        self.base_dirs = list(base_dirs)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        rel = self.image_paths[i]
        fp = resolve_path(rel, self.base_dirs)
        img = cv2.imread(fp) if fp else None
        if img is None:
            # Best-effort fallback so a single bad file cannot abort the whole
            # run - but make it visible: the prediction for this image comes
            # from a black frame, not from real data.
            import warnings

            warnings.warn(
                f"TestDS: could not load image {rel!r} (resolved path: {fp!r}); "
                "substituting an all-black placeholder",
                RuntimeWarning,
            )
            img = np.zeros((1000, 2000, 3), np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR
        img = clean_image(img)
        # Split at the horizontal midpoint; each half is preprocessed independently.
        mid = img.shape[1] // 2
        left = preprocess_half(img[:, :mid], IMG_SIZE)
        right = preprocess_half(img[:, mid:], IMG_SIZE)
        return left, right, rel

# ----------------- model -----------------
class Local2DTokenMixerBlock(nn.Module):
    """Residual block that mixes neighbouring tokens on a channels-last grid.

    Branch: LayerNorm -> sigmoid self-gating -> depthwise k x k convolution
    -> linear projection -> dropout. The unmodified input is added back.

    Args:
        dim: channel width of the token grid (size of the last axis).
        k: depthwise kernel size; padding is ``k // 2``.
        drop: dropout probability applied to the branch output.
    """

    def __init__(self, dim, k=5, drop=0.1):
        super().__init__()
        # Attribute names are also the state_dict key prefixes, so they must
        # stay as they are for existing checkpoints to load.
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Linear(dim, dim)
        self.dw = nn.Conv2d(
            dim,
            dim,
            kernel_size=k,
            padding=k // 2,
            groups=dim,  # one filter per channel: depthwise
            bias=True,
        )
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):  # (B,H,W,C)
        normed = self.norm(x)
        gated = normed * torch.sigmoid(self.gate(normed))
        # Conv2d wants channels first; switch layout only around the convolution.
        mixed = self.dw(gated.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        branch = self.drop(self.proj(mixed))
        return x + branch

class BiomassConvNeXt(nn.Module):
    """Two-view regressor: shared timm backbone, token-mixer fusion, three heads.

    Both image halves go through the same backbone. Their last feature maps
    are joined along the width axis, refined by two
    ``Local2DTokenMixerBlock`` layers, averaged over all grid positions and
    layer-normalised. Three Softplus-terminated heads predict green, clover
    and dead; the remaining two outputs are sums of those.

    Args:
        name: timm model name used to build the backbone (no pretrained
            weights are downloaded; they are expected from a checkpoint).
    """

    def __init__(self, name):
        super().__init__()
        self.backbone = timm.create_model(name, pretrained=False, features_only=True)
        # Channel count of the deepest feature level drives every later layer.
        self.nf = int(self.backbone.feature_info.channels()[-1])

        mixers = [Local2DTokenMixerBlock(self.nf, k=5, drop=0.1) for _ in range(2)]
        self.fuse = nn.Sequential(*mixers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.ln = nn.LayerNorm(self.nf)

        self.h_green = self._make_head()
        self.h_clover = self._make_head()
        self.h_dead = self._make_head()

    def _make_head(self):
        """Build one regression head: Linear -> GELU -> Dropout -> Linear -> Softplus."""
        hidden = self.nf // 2
        layers = [
            nn.Linear(self.nf, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, 1),
            nn.Softplus(),
        ]
        return nn.Sequential(*layers)

    def _featmap(self, x):
        """Return the backbone's last feature map for *x*."""
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            return feats[-1]  # (B,C,H,W)
        return feats

    def forward(self, left, right):
        """Predict the five targets from the left and right image halves.

        Returns:
            Tuple ``(green, dead, clover, gdm, total)`` where
            ``gdm = green + clover`` and ``total = gdm + dead``.
        """
        halves = [self._featmap(view) for view in (left, right)]
        grid = torch.cat(halves, dim=3)              # (B,C,H,2W)
        grid = grid.permute(0, 2, 3, 1).contiguous()  # (B,H,2W,C)

        # Fusion, pooling and the final norm run in float32 regardless of any
        # autocast context the caller has enabled.
        with torch.amp.autocast("cuda", enabled=False):
            grid = self.fuse(grid.float())
            batch, rows, cols, chans = grid.shape
            tokens = grid.view(batch, rows * cols, chans)
            pooled = self.pool(tokens.transpose(1, 2)).squeeze(-1)
            pooled = self.ln(pooled)

        green = self.h_green(pooled)
        clover = self.h_clover(pooled)
        dead = self.h_dead(pooled)
        gdm = green + clover
        total = gdm + dead
        # Same order as TARGET_COLS.
        return green, dead, clover, gdm, total

# ----------------- load + infer -----------------
def load_sd(path):
    """Load a checkpoint file and return the model weights it contains.

    Supports bare state dicts as well as wrapper dicts that keep the
    weights under ``"model_state_dict"`` (checked first) or ``"state_dict"``.

    Args:
        path: checkpoint file readable by ``torch.load``.

    Returns:
        The unwrapped weights, or the loaded object itself when it is not a
        dict or has neither wrapper key.
    """
    # Always deserialise onto the CPU; the caller moves the model afterwards.
    ckpt = torch.load(path, map_location="cpu")
    if not isinstance(ckpt, dict):
        return ckpt
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in ckpt:
            return ckpt[wrapper_key]
    return ckpt

@torch.inference_mode()
def predict_ckpt(model, loader, ckpt):
    """Load one checkpoint into *model* and predict every item of *loader*.

    Args:
        model: network called as ``model(left, right)`` and returning five
            tensors with one value per sample.
        loader: DataLoader yielding ``(left, right, _)`` batches; it must
            iterate in dataset order because rows are written sequentially.
        ckpt: checkpoint path understood by ``load_sd``.

    Returns:
        float32 array of shape ``(len(loader.dataset), 5)``; column order is
        the order of the tuple returned by the model.

    Warns:
        RuntimeWarning: when the checkpoint's keys do not match the model's.
    """
    import warnings

    # strict=False keeps loading best-effort, but every parameter listed in
    # missing_keys stays at its random initialisation - report it instead of
    # silently producing predictions from untrained layers.
    load_result = model.load_state_dict(load_sd(ckpt), strict=False)
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {ckpt!r} does not match the model: "
            f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
            f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]})",
            RuntimeWarning,
        )
    model.to(DEVICE).eval()

    out = np.zeros((len(loader.dataset), 5), np.float32)
    off = 0
    for left, right, _ in loader:
        left = left.to(DEVICE, non_blocking=True)
        right = right.to(DEVICE, non_blocking=True)
        if DEVICE.type == "cuda":
            # Mixed precision is used only on the GPU.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                y = model(left, right)
        else:
            y = model(left, right)
        # Flatten each output to one value per sample, stack into (batch, 5),
        # and cast back to float32 before handing the values to numpy.
        y = torch.stack([t.view(-1) for t in y], dim=1).float().cpu().numpy()
        b = left.size(0)
        out[off:off + b] = y
        off += b
    return out

# ----------------- run -----------------
test_df = pd.read_csv(TEST_CSV)
# Predict each distinct image once; rows of test_df are re-attached by the merge below.
uniq_imgs = test_df["image_path"].drop_duplicates().values

# Roots tried, in this order, when resolving an image path.
base_dirs = [
    "/kaggle/input/csiro-biomass",
    "/kaggle/input/csiro-biomass/test_images",
    "/kaggle/input/csiro-biomass/test",
]

ds = TestDS(uniq_imgs, base_dirs)
dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,  # prediction rows must stay aligned with uniq_imgs
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

ckpts = glob.glob(os.path.join(WEIGHTS_DIR, "best_model_fold*.pth"))
ckpts.sort()
if len(ckpts) == 0:
    raise FileNotFoundError(f"No checkpoints like best_model_fold*.pth in: {WEIGHTS_DIR}")

# One model instance is reused; predict_ckpt swaps in each fold's weights.
model = BiomassConvNeXt(MODEL_NAME)
pred = np.zeros((len(ds), 5), np.float32)
for p in ckpts:
    pred += predict_ckpt(model, dl, p)
pred /= float(len(ckpts))  # plain average over folds

# Wide table: one row per image, one column per target.
preds_wide = pd.DataFrame(pred, columns=TARGET_COLS)
preds_wide.insert(0, "image_path", uniq_imgs)

# Long table: one row per (image, target) pair, ready to join onto test_df.
preds_long = pd.melt(
    preds_wide,
    id_vars=["image_path"],
    value_vars=TARGET_COLS,
    var_name="target_name",
    value_name="target",
)

sub = test_df[["sample_id", "image_path", "target_name"]].merge(
    preds_long, on=["image_path", "target_name"], how="left"
)
# Rows without a matching prediction are submitted as 0.0.
sub = sub[["sample_id", "target"]].fillna(0.0)
sub = sub.sort_values("sample_id").reset_index(drop=True)

sub.to_csv(OUT_CSV, index=False)
print("saved:", OUT_CSV)
sub.head()
