In [None]:
# ========= SwinV2 (only) — load folds + infer + make submission =========
import os, glob, math, cv2
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import torch.nn.functional as F
import timm

# ----------------- CONFIG -----------------
# Backbone and input geometry: IMG_SIZE is the square side that each image
# half is resized to before it enters the backbone.
MODEL_NAME = "swinv2_large_window12to24_192to384.ms_in22k_ft_in1k"
IMG_SIZE = 384
BATCH_SIZE = 4
NUM_WORKERS = 2

# Input / output locations.
TEST_CSV = "/kaggle/input/csiro-biomass/test.csv"
WEIGHTS_DIR = "/kaggle/input/swinv2/pytorch/default/1/swinv2"  # has best_model_fold0.pth ...
OUT_CSV = "submission.csv"

# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Speed-oriented backend switches: cuDNN autotuning plus TF32 matmul/conv paths.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Column order of the five regression targets in the wide prediction table.
ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

# Per-channel RGB normalisation statistics (the standard ImageNet values).
MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# ----------------- image utils -----------------
def clean_image(img_rgb):
    """Crop off the bottom 10% of an RGB image and inpaint saturated orange pixels.

    Parameters
    ----------
    img_rgb : np.ndarray
        Image in RGB channel order. ``cv2.inpaint`` needs an 8-bit image, so this
        is expected to be uint8 (the caller passes ``cv2.imread`` output).

    Returns
    -------
    np.ndarray
        The top ``int(0.9 * H)`` rows. Pixels inside the HSV band
        H in [5, 25], S in [150, 255], V in [150, 255] (mask dilated twice with
        a 3x3 kernel) are filled by Telea inpainting (radius 3). The crop is
        returned untouched when no pixel falls in the band.
    """
    keep_rows = int(img_rgb.shape[0] * 0.90)
    cropped = img_rgb[:keep_rows, :]

    # OpenCV's 8-bit HSV stores hue in [0, 179]; 5-25 is an orange band.
    # NOTE(review): presumably targets coloured markers in the frame — confirm.
    hsv = cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV)
    band_lo = np.array([5, 150, 150])
    band_hi = np.array([25, 255, 255])
    mask = cv2.inRange(hsv, band_lo, band_hi)
    # Grow the mask so the inpainted region also covers blob edges.
    mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=2)

    if not mask.any():
        return cropped
    return cv2.inpaint(cropped, mask, 3, cv2.INPAINT_TELEA)

def preprocess_half(img_rgb, size):
    """Resize an RGB image to ``size`` x ``size`` and return a normalised CHW tensor.

    Pixels are scaled to [0, 1] and then standardised per channel with the
    module-level ``MEAN`` / ``STD``. The result is a float32 ``torch.Tensor`` of
    shape (C, size, size). It is a permuted view sharing memory with the NumPy
    buffer.
    """
    resized = cv2.resize(img_rgb, (size, size), interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0
    normalised = (scaled - MEAN) / STD
    hwc = torch.from_numpy(normalised)
    return hwc.permute(2, 0, 1)  # HWC -> CHW

def resolve_path(rel_path, base_dirs):
    """Return the first existing file path for ``rel_path`` under ``base_dirs``, or None.

    Candidates are tried in this order:
    1. ``base_dir / rel_path`` for every base dir;
    2. ``base_dir / basename(rel_path)`` for every base dir. This is a fallback
       for when the relative sub-directories do not exist on disk.
    """
    candidates = [os.path.join(root, rel_path) for root in base_dirs]
    name_only = os.path.basename(rel_path)
    candidates += [os.path.join(root, name_only) for root in base_dirs]
    return next((cand for cand in candidates if os.path.exists(cand)), None)

# ----------------- dataset -----------------
from torch.utils.data import Dataset, DataLoader

class TestDS(Dataset):
    """Test-time dataset yielding the left and right halves of each image.

    Each item is ``(left, right, rel)``:
    - ``left`` / ``right`` are the tensors from ``preprocess_half`` for the two
      vertical halves of the cleaned image (split at ``width // 2``);
    - ``rel`` is the ``image_paths`` entry, returned unchanged.

    An image that cannot be resolved or decoded is replaced by a black
    1000x2000 placeholder so inference can continue. A ``RuntimeWarning`` is
    emitted so the substitution is not silent.
    """

    def __init__(self, image_paths, base_dirs):
        self.image_paths = list(image_paths)
        # Directories searched, in order, by ``resolve_path``.
        self.base_dirs = list(base_dirs)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        rel = self.image_paths[i]
        fp = resolve_path(rel, self.base_dirs)
        img = cv2.imread(fp) if fp else None  # BGR array, or None when unreadable
        if img is None:
            # Best-effort fallback, but say so: otherwise a missing or corrupt file
            # silently turns into a plausible-looking prediction.
            import warnings
            warnings.warn(
                f"TestDS: could not read image {rel!r} (resolved path: {fp!r}); "
                "using a blank 1000x2000 placeholder.",
                RuntimeWarning,
                stacklevel=2,
            )
            img = np.zeros((1000, 2000, 3), np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = clean_image(img)
        mid = img.shape[1] // 2
        left = preprocess_half(img[:, :mid], IMG_SIZE)
        right = preprocess_half(img[:, mid:], IMG_SIZE)
        return left, right, rel

# ----------------- model blocks (multi-scale feature path) -----------------
# Newer timm exposes DropPath as ``timm.layers``; older releases only have
# ``timm.models.layers``. Catch ImportError only, so unrelated failures still surface.
try:
    from timm.layers import DropPath
except ImportError:
    from timm.models.layers import DropPath

class Local2DTokenMixerConvNextBlock(nn.Module):
    """ConvNeXt-style residual block operating on channels-last grids.

    Input and output are (B, H, W, C) with C == ``dim``. The residual branch is:
    depthwise ``kernel_size`` conv -> LayerNorm -> Linear(dim, dim*mlp_ratio)
    -> GELU -> Linear back to ``dim``. The branch is optionally scaled per
    channel by ``gamma`` (layer scale) and passed through DropPath.

    Attribute names double as the checkpoint's state-dict keys and must not change.
    """

    def __init__(self, dim, kernel_size=7, mlp_ratio=4.0, dropout=0.0, drop_path=0.0, layer_scale_init=1e-6):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, bias=True)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pw1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.pw2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)

        use_drop_path = bool(drop_path) and drop_path > 0
        self.drop_path = DropPath(drop_path) if use_drop_path else nn.Identity()

        use_layer_scale = bool(layer_scale_init) and layer_scale_init > 0
        # ``gamma`` is a plain None attribute (no parameter) when layer scale is off.
        self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim)) if use_layer_scale else None

    def forward(self, x_grid):  # (B,H,W,C) -> (B,H,W,C)
        residual = x_grid
        # The depthwise conv wants channels-first; everything after it is channels-last.
        y = self.dwconv(x_grid.permute(0, 3, 1, 2).contiguous())
        y = self.norm(y.permute(0, 2, 3, 1).contiguous())
        y = self.pw2(self.drop(self.act(self.pw1(y))))
        if self.gamma is not None:
            y = y * self.gamma
        # NOTE(review): ``self.drop`` is applied a second time here (after layer
        # scale). It is a no-op in eval mode or with dropout=0.
        y = self.drop_path(self.drop(y))
        return residual + y

class GeM2d(nn.Module):
    """Generalised-mean (GeM) global pooling over the two spatial dims.

    Maps (B, C, H, W) -> (B, C) as ``mean(clamp(x, eps) ** p) ** (1 / p)``,
    with a learnable scalar exponent ``p``. p=1 is average pooling; larger p
    moves the result towards max pooling.
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.full((1,), float(p)))
        self.eps = float(eps)

    def forward(self, x):
        powered = x.clamp(min=self.eps) ** self.p  # clamp keeps the fractional power real
        pooled = powered.mean(dim=(2, 3))
        return pooled ** (1.0 / self.p)

class SeparableConvGNAct(nn.Module):
    """Depthwise-separable conv (kxk depthwise + 1x1 pointwise) -> GroupNorm -> GELU.

    Keeps ``c`` channels. GroupNorm uses the largest group count
    <= min(gn_groups, c) that divides ``c``.
    """

    def __init__(self, c, kernel_size=3, gn_groups=32):
        super().__init__()
        groups = min(int(gn_groups), int(c))
        while groups > 1 and c % groups:
            groups -= 1
        self.dw = nn.Conv2d(c, c, kernel_size, padding=kernel_size // 2, groups=c, bias=False)
        self.pw = nn.Conv2d(c, c, 1, bias=False)
        self.gn = nn.GroupNorm(groups, c)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.gn(self.pw(self.dw(x))))

def _resize_like(x, ref):
    """Resize the NCHW tensor ``x`` to the spatial size (last two dims) of ``ref``.

    Returns ``x`` itself when the sizes already match. Uses ``"area"``
    interpolation when either spatial dim shrinks, otherwise ``"nearest"``.
    """
    target_hw = ref.shape[-2:]
    if x.shape[-2:] == target_hw:
        return x
    src_h, src_w = x.shape[-2:]
    shrinks = src_h > target_hw[0] or src_w > target_hw[1]
    mode = "area" if shrinks else "nearest"
    return F.interpolate(x, size=target_hw, mode=mode)

class StereoCrossGate2D(nn.Module):
    """Cross-gating between the left and right feature maps of an image pair.

    Each side is GroupNorm-ed, and a sigmoid gate is predicted from it with a
    depthwise kxk conv followed by a 1x1 conv. That gate modulates the *other*
    side residually:

        xL' = xL + ls * xL * sigmoid(gate_from_R(normR(xR)))
        xR' = xR + ls * xR * sigmoid(gate_from_L(normL(xL)))

    ``ls`` is a per-channel scale initialised to ``layer_scale_init``.
    Inputs and outputs are (B, c, H, W).
    """

    def __init__(self, c, kernel_size=3, gn_groups=32, layer_scale_init=1e-3):
        super().__init__()
        groups = min(int(gn_groups), int(c))
        while groups > 1 and c % groups:
            groups -= 1
        self.normL = nn.GroupNorm(groups, c)
        self.normR = nn.GroupNorm(groups, c)
        self.gate_from_L = self._make_gate(c, kernel_size)
        self.gate_from_R = self._make_gate(c, kernel_size)
        self.ls = nn.Parameter(torch.ones(c) * float(layer_scale_init))

    @staticmethod
    def _make_gate(c, kernel_size):
        # Depthwise spatial conv, then a 1x1 channel mix. The Sequential indices
        # (0, 1) are part of the state-dict keys, so the layer order is fixed.
        return nn.Sequential(
            nn.Conv2d(c, c, kernel_size, padding=kernel_size // 2, groups=c, bias=False),
            nn.Conv2d(c, c, 1, bias=True),
        )

    def forward(self, xL, xR):
        gate_for_R = torch.sigmoid(self.gate_from_L(self.normL(xL)))
        gate_for_L = torch.sigmoid(self.gate_from_R(self.normR(xR)))
        scale = self.ls.to(xL.dtype)[None, :, None, None]
        return xL + scale * (xL * gate_for_L), xR + scale * (xR * gate_for_R)

class FuseToP16(nn.Module):
    """Fuse three feature levels onto the ``p16`` grid with per-pixel softmax weights.

    ``p8`` and ``p32`` are resized to ``p16``'s spatial size with ``_resize_like``.
    Each level gets a 1x1-conv score map, and the three scores are
    softmax-normalised across levels at every pixel. The weighted sum is then
    refined by a ``SeparableConvGNAct``. All inputs are (B, c, H, W); the output
    has ``p16``'s H and W.
    """

    def __init__(self, c, gn_groups=32):
        super().__init__()
        self.w8 = nn.Conv2d(c, 1, 1, bias=True)
        self.w16 = nn.Conv2d(c, 1, 1, bias=True)
        self.w32 = nn.Conv2d(c, 1, 1, bias=True)
        self.out = SeparableConvGNAct(c, kernel_size=3, gn_groups=gn_groups)

    def forward(self, p8, p16, p32):
        levels = (_resize_like(p8, p16), p16, _resize_like(p32, p16))
        scorers = (self.w8, self.w16, self.w32)
        scores = torch.cat([score(feat) for score, feat in zip(scorers, levels)], dim=1)  # (B,3,H,W)
        weights = torch.softmax(scores, dim=1)
        fused = sum(weights[:, k:k + 1] * feat for k, feat in enumerate(levels))
        return self.out(fused)

class SeamBlend2D(nn.Module):
    """Residual depthwise 3x9 conv -> GELU -> 1x1 conv that keeps the input shape.

    The kernel is wider along W than along H. NOTE(review): presumably meant to
    mix information across the seam where the two halves are concatenated along
    width — confirm.
    """

    def __init__(self, c):
        super().__init__()
        self.dw = nn.Conv2d(c, c, kernel_size=(3, 9), padding=(1, 4), groups=c, bias=False)
        self.pw = nn.Conv2d(c, c, 1, bias=True)
        self.act = nn.GELU()

    def forward(self, x):
        mixed = self.pw(self.act(self.dw(x)))
        return x + mixed

class BiomassSwinV2(nn.Module):
    """Two-view (left/right image half) biomass regressor on a timm Swin-V2 backbone.

    Pipeline:
    1. Both halves go through the same ``features_only`` backbone.
    2. The stages whose reduction is closest to 8, 16 and 32 are projected to
       ``nf`` channels and cross-gated between the two views.
    3. The two views are concatenated side by side along width and fused onto
       the stride-16 grid.
    4. The fused map is refined by two ConvNeXt-style blocks and GeM-pooled.
    5. Three Softplus heads predict green, clover and dead. GDM and total are
       derived as sums of those heads.

    NOTE: submodule attribute names are the checkpoint's state-dict keys;
    renaming any of them stops existing checkpoints from loading into them.
    """
    def __init__(self, model_name, pretrained=False, head_dropout=0.1, fuse_dim=None):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True, out_indices=(0,1,2,3))
        fi = self.backbone.feature_info
        chans = list(fi.channels())
        reds  = list(fi.reduction())

        def pick(target):
            # Index of the backbone stage whose reduction (stride) is closest to `target`.
            return int(min(range(len(reds)), key=lambda i: abs(int(reds[i]) - int(target))))
        self.idx8, self.idx16, self.idx32 = pick(8), pick(16), pick(32)
        ch8, ch16, ch32 = int(chans[self.idx8]), int(chans[self.idx16]), int(chans[self.idx32])

        # Shared channel width for every fused level; defaults to the stride-16 stage width.
        self.nf = int(fuse_dim or ch16)

        def proj(in_ch):
            # 1x1 conv -> GroupNorm -> GELU mapping a stage to `nf` channels.
            # GroupNorm uses the largest group count <= 32 that divides `nf`.
            g = min(32, self.nf)
            while g > 1 and (self.nf % g) != 0: g -= 1
            return nn.Sequential(
                nn.Conv2d(in_ch, self.nf, 1, bias=False),
                nn.GroupNorm(g, self.nf),
                nn.GELU(),
            )
        self.proj8, self.proj16, self.proj32 = proj(ch8), proj(ch16), proj(ch32)

        # One left<->right cross-gate per pyramid level.
        self.stereo8  = StereoCrossGate2D(self.nf, kernel_size=3, gn_groups=32, layer_scale_init=1e-3)
        self.stereo16 = StereoCrossGate2D(self.nf, kernel_size=3, gn_groups=32, layer_scale_init=1e-3)
        self.stereo32 = StereoCrossGate2D(self.nf, kernel_size=3, gn_groups=32, layer_scale_init=1e-3)

        self.fuse_to_p16 = FuseToP16(self.nf, gn_groups=32)
        self.seam16 = SeamBlend2D(self.nf)

        # Channels-last refinement on the fused stride-16 grid.
        # drop_path=0.10 > 0, so these blocks instantiate timm's DropPath.
        self.fusion2d = nn.Sequential(
            Local2DTokenMixerConvNextBlock(self.nf, kernel_size=7, mlp_ratio=4.0, dropout=0.0, drop_path=0.10, layer_scale_init=1e-6),
            Local2DTokenMixerConvNextBlock(self.nf, kernel_size=7, mlp_ratio=4.0, dropout=0.0, drop_path=0.10, layer_scale_init=1e-6),
        )

        self.gem2d = GeM2d(p=3.0)
        self.feat_ln = nn.LayerNorm(self.nf)

        def pos_head():
            # MLP head ending in Softplus, so every prediction is positive.
            return nn.Sequential(
                nn.Linear(self.nf, self.nf//2),
                nn.GELU(),
                nn.Dropout(head_dropout),
                nn.Linear(self.nf//2, 1),
                nn.Softplus(),
            )
        self.head_1 = pos_head()  # green
        self.head_2 = pos_head()  # clover
        self.head_3 = pos_head()  # dead
 
    @staticmethod
    def _ensure_nchw(x: torch.Tensor, in_ch: int) -> torch.Tensor:
        # If x is NHWC (B,H,W,C) convert to NCHW (B,C,H,W)
        # NOTE(review): the layout is guessed from shapes, so an NHWC map whose
        # height equals `in_ch` would be left unpermuted — confirm this cannot
        # happen for the backbone/input size in use.
        if x.ndim == 4 and x.shape[1] != in_ch and x.shape[-1] == in_ch:
            return x.permute(0, 3, 1, 2).contiguous()
        return x

    def forward(self, left, right):
        """Return ``(total, gdm, green, clover, dead)``, each of shape (B, 1).

        ``left`` and ``right`` are the two image halves, batched identically.
        """
        # Same backbone (shared weights) for both views; each call returns a list of stage features.
        fl = self.backbone(left)
        fr = self.backbone(right)

        f8_l  = self._ensure_nchw(fl[self.idx8],  self.proj8[0].in_channels)
        f16_l = self._ensure_nchw(fl[self.idx16], self.proj16[0].in_channels)
        f32_l = self._ensure_nchw(fl[self.idx32], self.proj32[0].in_channels)
    
        f8_r  = self._ensure_nchw(fr[self.idx8],  self.proj8[0].in_channels)
        f16_r = self._ensure_nchw(fr[self.idx16], self.proj16[0].in_channels)
        f32_r = self._ensure_nchw(fr[self.idx32], self.proj32[0].in_channels)
    
        p8_l  = self.proj8(f8_l);   p8_r  = self.proj8(f8_r)
        p16_l = self.proj16(f16_l); p16_r = self.proj16(f16_r)
        p32_l = self.proj32(f32_l); p32_r = self.proj32(f32_r)


        p8_l,  p8_r  = self.stereo8(p8_l,  p8_r)
        p16_l, p16_r = self.stereo16(p16_l, p16_r)
        p32_l, p32_r = self.stereo32(p32_l, p32_r)

        # Stitch left|right along width (dim=3 in NCHW) into one wide map per level.
        p8  = torch.cat([p8_l,  p8_r],  dim=3)
        p16 = torch.cat([p16_l, p16_r], dim=3)
        p32 = torch.cat([p32_l, p32_r], dim=3)

        p16_fused = self.fuse_to_p16(p8, p16, p32)
        p16_fused = self.seam16(p16_fused)

        # fusion2d blocks expect channels-last (B,H,W,C).
        grid = p16_fused.permute(0,2,3,1).contiguous()
        grid = self.fusion2d(grid)
        p16m = grid.permute(0,3,1,2).contiguous()

        feat = self.gem2d(p16m)
        feat = self.feat_ln(feat)

        green  = self.head_1(feat)
        clover = self.head_2(feat)
        dead   = self.head_3(feat)
        # Composite targets are sums of the component heads.
        gdm    = green + clover
        total  = gdm + dead
        return total, gdm, green, clover, dead

# ----------------- inference helpers -----------------
def load_sd(path):
    """Load a checkpoint file and return a state dict usable by ``load_state_dict``.

    Accepts either a bare state dict or a training checkpoint that wraps one
    under ``"model_state_dict"`` (preferred) or ``"state_dict"``.

    Key prefixes added by ``nn.DataParallel``/DDP (``"module."``) and by
    ``torch.compile`` (``"_orig_mod."``) are stripped, so the keys match an
    unwrapped model. Without this, ``load_state_dict(strict=False)`` would
    silently skip every such weight.

    When no key carries such a prefix, the loaded object is returned
    unchanged. The same applies to non-dict payloads.
    """
    s = torch.load(path, map_location="cpu")
    if isinstance(s, dict) and ("model_state_dict" in s or "state_dict" in s):
        s = s.get("model_state_dict", s.get("state_dict"))
    if not isinstance(s, dict):
        return s

    prefixes = ("module.", "_orig_mod.")
    if not any(isinstance(k, str) and k.startswith(prefixes) for k in s):
        return s  # already clean: keep the original object (and its metadata)

    cleaned = {}
    for key, value in s.items():
        # Both prefixes end at the key's first ".", so dropping up to it removes
        # exactly one prefix; loop to handle nested wrappers.
        while isinstance(key, str) and key.startswith(prefixes):
            key = key[key.index(".") + 1:]
        cleaned[key] = value
    return cleaned

def postprocess_3_to_5(pred3):
    """Expand [total, gdm, green] predictions into the five submission targets.

    Parameters
    ----------
    pred3 : np.ndarray
        (N, 3) array whose columns are total, GDM and green.

    Returns
    -------
    np.ndarray
        (N, 5) float32 array with columns green, dead, clover, gdm, total.
        Here clover = max(0, gdm - green) and dead = max(0, total - gdm).
    """
    total, gdm, green = pred3[:, :3].T
    clover = np.maximum(gdm - green, 0.0)
    dead = np.maximum(total - gdm, 0.0)
    columns = (green, dead, clover, gdm, total)
    return np.column_stack(columns).astype(np.float32)

def predict_ckpt(model, loader, ckpt):
    """Load one fold checkpoint into ``model`` and predict every sample of ``loader``.

    Parameters
    ----------
    model : nn.Module
        Takes ``(left, right)`` and returns a tuple whose first three entries
        are the total, GDM and green predictions.
    loader : DataLoader
        Yields ``(left, right, _)`` batches. It must not shuffle, because result
        rows are filled in iteration order.
    ckpt : str
        Path to a checkpoint readable by ``load_sd``.

    Returns
    -------
    np.ndarray
        ``(len(loader.dataset), 3)`` float32 array with columns [total, gdm, green].

    Weights are loaded with ``strict=False``. Model keys absent from the
    checkpoint are reported via ``RuntimeWarning``: those layers would
    otherwise silently keep whatever weights they already had (random init,
    or a previous fold's). On CUDA, bfloat16 autocast is used only when the
    device supports it; otherwise inference runs in float32.
    """
    import contextlib
    import warnings

    # Load and move weights *outside* inference_mode so parameters remain
    # ordinary tensors rather than inference tensors across repeated calls.
    incompatible = model.load_state_dict(load_sd(ckpt), strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f"{os.path.basename(ckpt)}: {len(incompatible.missing_keys)} model key(s) "
            f"missing from checkpoint (e.g. {incompatible.missing_keys[:3]}); "
            "those layers keep their current weights.",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(DEVICE).eval()

    # torch.autocast refuses bfloat16 on CUDA devices without bf16 support;
    # fall back to float32 there instead of crashing.
    use_bf16 = DEVICE.type == "cuda" and torch.cuda.is_bf16_supported()

    out = np.zeros((len(loader.dataset), 3), np.float32)  # [total, gdm, green]
    off = 0
    with torch.inference_mode():
        for left, right, _ in loader:
            left = left.to(DEVICE, non_blocking=True)
            right = right.to(DEVICE, non_blocking=True)
            amp = torch.autocast("cuda", dtype=torch.bfloat16) if use_bf16 else contextlib.nullcontext()
            with amp:
                outs = model(left, right)
            total, gdm, green = (t.reshape(-1) for t in outs[:3])
            pred3 = torch.stack([total, gdm, green], dim=1).float().cpu().numpy()
            b = left.size(0)
            out[off:off + b] = pred3
            off += b
    return out

# ----------------- run -----------------
test_df = pd.read_csv(TEST_CSV)
# Predict once per distinct image. Image paths repeat in test.csv (presumably
# one row per image x target_name).
uniq_imgs = test_df["image_path"].drop_duplicates().values

base_dirs = [
    "/kaggle/input/csiro-biomass",
    "/kaggle/input/csiro-biomass/test",
    "/kaggle/input/csiro-biomass/test_images",
]

ds = TestDS(uniq_imgs, base_dirs)
dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,  # predictions are matched back to uniq_imgs by position
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE.type == "cuda"),  # pinned host memory only helps host->GPU copies
)

ckpts = sorted(glob.glob(os.path.join(WEIGHTS_DIR, "best_model_fold*.pth")))
if not ckpts:
    raise FileNotFoundError(f"No checkpoints like best_model_fold*.pth in: {WEIGHTS_DIR}")
print(f"ensembling {len(ckpts)} checkpoint(s) from {WEIGHTS_DIR}")

# Average raw [total, gdm, green] outputs across folds, then derive the
# remaining targets from the averaged values.
model = BiomassSwinV2(MODEL_NAME, pretrained=False, head_dropout=0.1)
pred3_sum = np.zeros((len(ds), 3), np.float32)
for p in ckpts:
    pred3_sum += predict_ckpt(model, dl, p)
pred3 = pred3_sum / float(len(ckpts))
pred5 = postprocess_3_to_5(pred3)

# Wide (one row per image) -> long (one row per image x target_name).
preds_wide = pd.DataFrame(pred5, columns=ALL_TARGET_COLS)
preds_wide.insert(0, "image_path", uniq_imgs)

preds_long = preds_wide.melt(
    id_vars=["image_path"],
    value_vars=ALL_TARGET_COLS,
    var_name="target_name",
    value_name="target",
)

merged = test_df[["sample_id", "image_path", "target_name"]].merge(
    preds_long, on=["image_path", "target_name"], how="left"
)
# Rows without a matching prediction (e.g. a target_name outside
# ALL_TARGET_COLS) are still filled with 0.0, but are reported rather than hidden.
n_missing = int(merged["target"].isna().sum())
if n_missing:
    print(f"WARNING: {n_missing} test row(s) had no prediction; filled with 0.0")

sub = (
    merged[["sample_id", "target"]]
    .fillna(0.0)
    .sort_values("sample_id")
    .reset_index(drop=True)
)

sub.to_csv(OUT_CSV, index=False)
print("saved:", OUT_CSV)
sub.head()
