In [83]:
from __future__ import annotations

import random
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from tqdm import tqdm
from argparse import ArgumentParser


In [84]:
def get_training_args(argv=None):
    """
    Parse the training configuration.

    Args:
        argv: list of command-line style tokens to parse. ``None`` makes
            argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option listed below.
    """
    # (option name, type, default) -- declared in the order they are registered.
    option_table = (
        # paths
        ("train_csv", str, "hackdata/sentinel-beetles/public_release/train.csv"),
        ("val_csv", str, "hackdata/sentinel-beetles/public_release/val.csv"),
        ("train_img_dir", str, "training_images"),
        ("val_img_dir", str, "validation_images"),
        ("save_dir", str, "ckpts"),
        # columns
        ("event_col", str, "eventID"),
        ("img_col", str, "relative_img_loc"),
        # targets
        ("spei30_col", str, "SPEI_30d"),
        ("spei1y_col", str, "SPEI_1y"),
        ("spei2y_col", str, "SPEI_2y"),
        # model / data
        ("img_size", int, 512),
        ("k_max", int, 8),                   # max images per event used
        ("batch_size", int, 1),              # events per batch (keep small!)
        ("num_workers", int, 1),
        ("seed", int, 0),
        # optimization
        ("epochs", int, 30),
        ("lr", float, 3e-5),                 # learning rate once the backbone is trainable
        ("weight_decay", float, 0.05),
        ("grad_accum", int, 1),              # increase if OOM
        ("freeze_backbone_epochs", int, 1),  # stabilize early training
    )

    parser = ArgumentParser()
    for opt_name, opt_type, opt_default in option_table:
        parser.add_argument(f"--{opt_name}", type=opt_type, default=opt_default)

    return parser.parse_args(argv)

In [85]:
def r2_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Coefficient of determination: 1 - SS_res / SS_tot.

    Returns 0.0 when ``y_true`` has zero total variance around its mean
    (SS_tot == 0), instead of dividing by zero.
    """
    residuals = y_true - y_pred
    deviations = y_true - y_true.mean()

    ss_res = float((residuals ** 2).sum())
    ss_tot = float((deviations ** 2).sum())

    return 0.0 if ss_tot == 0.0 else 1.0 - ss_res / ss_tot

In [86]:
def evaluate_spei_r2(gts: np.ndarray, preds: np.ndarray) -> Tuple[float, float, float]:
    """
    Per-target R² for the first three columns of ``gts`` / ``preds``.

    Both arrays are indexed as ``[:, j]`` for j in 0..2; the result is the
    tuple (R² of column 0, R² of column 1, R² of column 2).
    """
    r2_col0, r2_col1, r2_col2 = (
        r2_score_np(gts[:, col], preds[:, col]) for col in range(3)
    )
    return r2_col0, r2_col1, r2_col2

In [87]:
class EventDataset(Dataset):
    """
    One sample per event; an event groups all CSV rows sharing the same value in ``event_col``.

    Each __getitem__ returns:
      x: (K, C, H, W) tensor of transformed images for this event, zero-padded if the event has < K images
      mask: (K,) float mask where 1 = real image, 0 = padding
      y: (3,) float targets for the event (read from the event's first row)

    where K = ``k_max`` and (C, H, W) is whatever shape ``tfm`` produces.
    """

    def __init__(self, csv_path: str, img_dir: str, event_col: str, img_col: str, target_cols: Tuple[str, str, str], tfm, k_max: int, train_mode: bool, seed: int = 0,):
        """
        Args:
            csv_path: CSV with one row per image.
            img_dir: root directory that ``img_col`` paths are resolved against.
            event_col: column whose value identifies the event a row belongs to.
            img_col: column holding the image path relative to ``img_dir``.
            target_cols: the three target column names, in output order.
            tfm: callable mapping a PIL image to a tensor; every call must
                return the same shape so the results can be stacked.
            k_max: number of image slots per event (must be >= 1).
            train_mode: if True, events with more than ``k_max`` images are
                randomly subsampled; otherwise the first ``k_max`` rows are used.
            seed: seed for the subsampling RNG.

        Raises:
            ValueError: if ``k_max`` < 1.
            KeyError: if a required column is missing from the CSV.
        """
        # With k_max < 1 there is nothing to stack and __getitem__ could only
        # fail with an opaque torch.stack error, so reject it up front.
        if k_max < 1:
            raise ValueError(f"k_max must be >= 1, got {k_max}")

        self.df = pd.read_csv(csv_path).reset_index(drop=True)
        self.img_root = Path(img_dir)
        self.event_col = event_col
        self.img_col = img_col
        self.target_cols = target_cols
        self.tfm = tfm
        self.k_max = k_max
        self.train_mode = train_mode
        self.seed = seed
        self.rng = random.Random(seed)

        # sanity columns
        for col in [event_col, img_col, *target_cols]:
            if col not in self.df.columns:
                raise KeyError(f"Missing column '{col}' in {csv_path}. Columns: {list(self.df.columns)}")

        # drop missing; the index is reset so row labels equal row positions
        self.df = self.df.dropna(subset=[event_col, img_col, *target_cols]).reset_index(drop=True)

        # Build event -> row indices in a single pass over the column
        # (first-appearance order of events is preserved by dict insertion order).
        self.event_to_rows: Dict[str, List[int]] = {}
        for i, ev in enumerate(self.df[self.event_col].tolist()):
            self.event_to_rows.setdefault(str(ev), []).append(i)
        self.events = list(self.event_to_rows.keys())

    def __len__(self) -> int:
        """Number of distinct events."""
        return len(self.events)

    def _sampling_rng(self) -> random.Random:
        """
        Return the RNG used to subsample specimens in the current process.

        In the main process this is ``self.rng``. Inside a DataLoader worker the
        dataset object is a copy, so every worker (and every epoch's freshly
        started workers) would otherwise replay the exact same stream from
        ``Random(seed)``. There, a worker-local RNG is derived from the dataset
        seed combined with the seed PyTorch assigned to that worker.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return self.rng
        if getattr(self, "_worker_rng_seed", None) != info.seed:
            self._worker_rng_seed = info.seed
            self._worker_rng = random.Random(f"{self.seed}:{info.seed}")
        return self._worker_rng

    def _open_image(self, rel: str) -> "Image.Image":
        """
        Load ``rel`` (relative to the image root) as an RGB PIL image.

        Falls back to looking up just the file's basename directly under the
        image root. Raises FileNotFoundError if neither location exists.
        """
        rel = str(rel).lstrip("/").replace("\\", "/")
        p = self.img_root / rel
        if not p.exists():
            # fallback: basename only
            p2 = self.img_root / Path(rel).name
            if p2.exists():
                p = p2
            else:
                raise FileNotFoundError(f"Image not found: {p} (also tried {p2})")
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(p) as im:
            return im.convert("RGB")

    def __getitem__(self, idx: int):
        """Return ``(x, mask, y)`` for the ``idx``-th event; see the class docstring."""
        rows = self.event_to_rows[self.events[idx]]

        # targets from first row
        # NOTE(review): assumes targets are constant within an event -- confirm against the data.
        first = rows[0]
        y = torch.tensor([float(self.df.at[first, c]) for c in self.target_cols], dtype=torch.float32)

        # sample up to K specimens for this event
        if self.train_mode and len(rows) > self.k_max:
            chosen = self._sampling_rng().sample(rows, self.k_max)
        else:
            # deterministic for validation
            chosen = rows[: self.k_max]

        # load + transform images; every event has at least one row and
        # k_max >= 1, so there is always at least one tensor to stack.
        imgs = torch.stack(
            [self.tfm(self._open_image(self.df.at[r, self.img_col])) for r in chosen],
            dim=0,
        )
        n = imgs.shape[0]

        # pad to K, taking shape and dtype from the real images rather than
        # from a private attribute on the transform
        n_pad = self.k_max - n
        if n_pad > 0:
            x = torch.cat([imgs, imgs.new_zeros((n_pad, *imgs.shape[1:]))], dim=0)
        else:
            x = imgs

        mask = torch.zeros(self.k_max, dtype=torch.float32)
        mask[:n] = 1.0

        return x, mask, y

In [88]:
class EventConvNeXtRegressor(nn.Module):
    """
    Event-level regressor: a timm image backbone applied to every image of an
    event, masked mean pooling over the event's images, and a linear head that
    outputs a mean and a positive scale for each of 3 targets.
    """

    def __init__(self, backbone_name: str = "convnext_small"):
        """
        Args:
            backbone_name: timm model name used to build the feature extractor
                (created with pretrained weights and ``num_classes=0``).
        """
        super().__init__()
        # Build the backbone that was actually requested; the name used to be
        # hard-coded, which silently ignored the constructor argument.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            cache_dir=str(Path("timm_cache").resolve()),
        )
        d = self.backbone.num_features
        self.head = nn.Linear(d, 6)  # 3 mu + 3 raw scale values (mapped to sigma via softplus)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        x: (B, K, C, H, W)
        mask: (B, K) 1 for real, 0 for padded

        Returns (mu, sigma), each (B, 3); sigma is strictly positive.
        """
        B, K, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        feats = self.backbone(x.reshape(B * K, C, H, W))  # (B*K, d)
        feats = feats.reshape(B, K, -1)         # (B, K, d)

        # masked mean pool: padded slots contribute nothing; the clamp keeps an
        # all-padding event from dividing by zero
        m = mask.unsqueeze(-1)                  # (B, K, 1)
        denom = m.sum(dim=1).clamp_min(1.0)     # (B, 1)
        event_feat = (feats * m).sum(dim=1) / denom  # (B, d)

        out = self.head(event_feat)             # (B, 6)
        mu = out[:, :3]                         # (B, 3)
        sigma = F.softplus(out[:, 3:]) + 1e-6   # (B, 3), > 0
        return mu, sigma

In [89]:
def gaussian_nll(y: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Mean negative log-likelihood of ``y`` under independent Gaussians
    N(mu, sigma^2), with the constant 0.5*log(2*pi) term left out.

    y, mu, sigma: (B, 3). Returns a scalar tensor.
    """
    z = (y - mu) / sigma                 # standardized residual
    per_element = 0.5 * z.pow(2) + torch.log(sigma)
    return per_element.mean()

In [90]:
def main(argv=None):
    """
    Train the event-level regressor and keep the weights of the epoch with the
    best mean validation R² over the three targets.

    Args:
        argv: list of command-line style tokens forwarded to
            ``get_training_args`` (``None`` lets argparse read ``sys.argv``).

    Side effects:
        Creates ``args.save_dir`` and writes ``event_convnext_small_best.pth``
        there whenever the validation score improves; prints one summary line
        per epoch.

    Raises:
        ValueError: if the training or validation dataset contains no events.
    """
    args = get_training_args(argv)

    # Seed every RNG involved (Python, NumPy, torch) so that shuffling, weight
    # init of the head and DataLoader worker seeds follow --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Guard against --grad_accum 0 (division by zero / modulo by zero below).
    grad_accum = max(1, args.grad_accum)

    # Learning rate used while only the head is trainable.
    head_warmup_lr = 1e-3

    tfm = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])
    # EventDataset may read this attribute to size its padding tensors.
    tfm._img_size = args.img_size

    target_cols = (args.spei30_col, args.spei1y_col, args.spei2y_col)

    train_ds = EventDataset(
        csv_path=args.train_csv,
        img_dir=args.train_img_dir,
        event_col=args.event_col,
        img_col=args.img_col,
        target_cols=target_cols,
        tfm=tfm,
        k_max=args.k_max,
        train_mode=True,
        seed=args.seed,
    )
    val_ds = EventDataset(
        csv_path=args.val_csv,
        img_dir=args.val_img_dir,
        event_col=args.event_col,
        img_col=args.img_col,
        target_cols=target_cols,
        tfm=tfm,
        k_max=args.k_max,
        train_mode=False,
        seed=args.seed,
    )
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise ValueError(
            f"Empty dataset: {len(train_ds)} training events, {len(val_ds)} validation events"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EventConvNeXtRegressor("convnext_small").to(device)

    def set_backbone_trainable(trainable: bool):
        for p in model.backbone.parameters():
            p.requires_grad = trainable

    # freeze backbone initially (optional but recommended for stability)
    if args.freeze_backbone_epochs > 0:
        set_backbone_trainable(False)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=(head_warmup_lr if args.freeze_backbone_epochs > 0 else args.lr),
        weight_decay=args.weight_decay
    )

    use_cuda = (device.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=use_cuda
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=use_cuda
    )
    n_train_steps = len(train_loader)

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    best_path = save_dir / "event_convnext_small_best.pth"

    # R² is unbounded below, so start from -inf rather than -1: otherwise a run
    # whose scores all sit below -1 would never write a checkpoint.
    best_avg = float("-inf")
    best_epoch = -1

    for epoch in range(args.epochs):
        if epoch == args.freeze_backbone_epochs and args.freeze_backbone_epochs > 0:
            # Unfreeze and restart the optimizer over all parameters at the fine-tuning LR.
            set_backbone_trainable(True)
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # ---- train ----
        model.train()
        tr_loss = 0.0
        tr_preds, tr_gts = [], []
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"train {epoch}", leave=False)
        for step, (x, mask, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=use_cuda):
                mu, sigma = model(x, mask)
                loss = gaussian_nll(y, mu, sigma) / grad_accum

            scaler.scale(loss).backward()

            # Step at every accumulation boundary AND on the last batch of the
            # epoch; otherwise gradients of a trailing partial group would be
            # thrown away by the zero_grad at the start of the next epoch.
            at_boundary = (step + 1) % grad_accum == 0
            at_epoch_end = (step + 1) == n_train_steps
            if at_boundary or at_epoch_end:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            tr_loss += float(loss.item()) * grad_accum  # undo the accumulation scaling for logging
            tr_preds.append(mu.detach().cpu().numpy())
            tr_gts.append(y.detach().cpu().numpy())
            pbar.set_postfix({"loss": tr_loss / max(1, step + 1)})

        tr_preds = np.concatenate(tr_preds, axis=0)
        tr_gts = np.concatenate(tr_gts, axis=0)
        tr30, tr1y, tr2y = evaluate_spei_r2(tr_gts, tr_preds)

        # ---- val ----
        model.eval()
        va_loss = 0.0
        va_preds, va_gts = [], []
        pbar = tqdm(val_loader, desc=f"val {epoch}", leave=False)
        with torch.no_grad():
            for step, (x, mask, y) in enumerate(pbar):
                x = x.to(device, non_blocking=True)
                mask = mask.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                with torch.amp.autocast("cuda", enabled=use_cuda):
                    mu, sigma = model(x, mask)
                    loss = gaussian_nll(y, mu, sigma)

                va_loss += float(loss.item())
                va_preds.append(mu.detach().cpu().numpy())
                va_gts.append(y.detach().cpu().numpy())
                pbar.set_postfix({"loss": va_loss / max(1, step + 1)})

        va_preds = np.concatenate(va_preds, axis=0)
        va_gts = np.concatenate(va_gts, axis=0)
        v30, v1y, v2y = evaluate_spei_r2(va_gts, va_preds)
        avg = (v30 + v1y + v2y) / 3.0

        # Update the best score before logging so the line shows the current best.
        if avg > best_avg:
            best_avg = avg
            best_epoch = epoch
            torch.save(model.state_dict(), best_path)

        print(
            f"epoch {epoch:03d} | "
            f"train_loss={tr_loss/len(train_loader):.4f} val_loss={va_loss/len(val_loader):.4f} | "
            f"train_r2=({tr30:.3f},{tr1y:.3f},{tr2y:.3f}) "
            f"val_r2=({v30:.3f},{v1y:.3f},{v2y:.3f}) avg={avg:.3f} | "
            f"best_avg={best_avg:.3f} @ {best_epoch}"
        )

    if best_epoch >= 0:
        print("Saved best weights to:", best_path)
    else:
        # Happens when no epoch ran or every validation score was NaN.
        print("No checkpoint was saved: no epoch produced a valid validation score.")

In [None]:
# Launch training with explicit settings. Values are strings because they are
# parsed by argparse exactly as if they had been typed on the command line.
main([
    token
    for flag, value in (
        ("--train_img_dir", "training_images"),
        ("--val_img_dir", "validation_images"),
        ("--event_col", "eventID"),
        ("--img_col", "relative_img_loc"),
        ("--spei30_col", "SPEI_30d"),
        ("--spei1y_col", "SPEI_1y"),
        ("--spei2y_col", "SPEI_2y"),
        ("--batch_size", "1"),
        ("--k_max", "8"),
        ("--epochs", "5"),
        ("--lr", "3e-5"),
    )
    for token in (flag, value)
])

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

train 0:   0%|          | 4/1088 [01:07<4:56:09, 16.39s/it, loss=1.64]