In [162]:
from __future__ import annotations

import random
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
import torch
import torchvision.io as io
import math

import timm
from tqdm import tqdm
from argparse import ArgumentParser
import cv2


In [163]:
def get_training_args(argv=None):
    """
    Build the training command-line interface and parse `argv`.

    Every option takes a value and has a default, so an empty list yields a
    fully populated namespace. `argv` is passed straight to
    `ArgumentParser.parse_args`; when it is None, argparse reads
    `sys.argv[1:]` instead.

    Returns:
        The parsed `argparse.Namespace`, one attribute per flag below.
    """
    # (flag, value type, default) triples, registered in the order listed.
    option_table = (
        # paths
        ("--train_csv", str, "hackdata/sentinel-beetles/public_release/train.csv"),
        ("--val_csv", str, "hackdata/sentinel-beetles/public_release/val.csv"),
        ("--train_img_dir", str, "training_images"),
        ("--mask_train_img_dir", str, "masked_training_images"),
        ("--train_color_dir", str, "hackdata/sentinel-beetles/color_and_scale_images"),
        ("--val_img_dir", str, "validation_images"),
        ("--mask_val_img_dir", str, "masked_validation_images"),
        ("--val_color_dir", str, "hackdata/sentinel-beetles/color_and_scale_images"),
        ("--save_dir", str, "ckpts"),
        # columns
        ("--event_col", str, "eventID"),
        ("--img_col", str, "relative_img_loc"),
        ("--color_col", str, "colorpicker_path"),
        # targets
        ("--spei30_col", str, "SPEI_30d"),
        ("--spei1y_col", str, "SPEI_1y"),
        ("--spei2y_col", str, "SPEI_2y"),
        # model / data
        ("--img_size", int, 224),
        ("--k_max", int, 8),            # max images per event used
        ("--batch_size", int, 1),       # events per batch (keep small!)
        ("--num_workers", int, 1),
        ("--seed", int, 0),
        # optimization
        ("--epochs", int, 30),
        ("--lr", float, 3e-5),          # fine-tuning learning rate
        ("--weight_decay", float, 0.05),
        ("--grad_accum", int, 1),       # increase if OOM
        ("--freeze_backbone_epochs", int, 1),  # stabilize early training
    )

    parser = ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser.parse_args(argv)

In [164]:
def r2_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Coefficient of determination (R^2) between two equally shaped arrays.

    Returns 1 - SS_res / SS_tot, or 0.0 when `y_true` has zero total
    variation (SS_tot == 0), which avoids a division by zero.
    """
    residual = y_true - y_pred
    centered = y_true - y_true.mean()

    ss_res = float((residual ** 2).sum())
    ss_tot = float((centered ** 2).sum())

    # Constant targets: R^2 is undefined, report 0.0 by convention.
    return 0.0 if ss_tot == 0.0 else 1.0 - ss_res / ss_tot

In [165]:
def evaluate_spei_r2(gts: np.ndarray, preds: np.ndarray) -> Tuple[float, float, float]:
    """
    Per-target R^2 for the first three columns of `gts` / `preds`.

    Both arrays are indexed as [:, 0], [:, 1] and [:, 2]; the three scores
    are returned in that column order.
    """
    score_30d, score_1y, score_2y = (
        r2_score_np(gts[:, col], preds[:, col]) for col in range(3)
    )
    return score_30d, score_1y, score_2y

In [166]:
class RandomRotate90:
    """
    Callable transform that rotates an image by a randomly chosen angle.

    A uniform draw from `random.random()` is compared against `p`; when the
    draw exceeds `p` the input is returned untouched. Otherwise an angle is
    picked with `random.choice(angles)` and applied through `TF.rotate` with
    `expand=True`.
    """

    def __init__(self, angles=(0, 90, 180, 270), p=1.0):
        self.angles = angles  # candidate rotation angles
        self.p = p            # threshold for applying a rotation

    def __call__(self, img):
        skip_rotation = random.random() > self.p
        if not skip_rotation:
            return TF.rotate(img, random.choice(self.angles), expand=True)
        return img


In [167]:
def calculate_normalization_factors(color_img_rgb):
    """
    Estimate color-correction factors from the extreme pixels of an image.

    Input: color_img_rgb - a PIL image (any object whose `.convert("RGB")`
           result `np.array` turns into an (H, W, 3) uint8 array).
           Presumably the color/scale reference photo of a specimen -
           TODO confirm against the caller's data.
    Output: (hue_shift, sat_scale, contrast, gamma, brightness), Python floats:
      hue_shift  - offset in [-0.5, 0.5] (fraction of the hue circle) that
                   moves the selected "red" pixel to hue 0
      sat_scale  - 1 / saturation of that same pixel
      contrast   - 1 / (white value - black value)
      gamma      - exponent g such that (gray value * brightness) ** g == 0.5
      brightness - 1 / white value
    The values are NOT clamped here; callers are expected to bound them.
    """
    # 1. Convert to HSV. OpenCV's 8-bit HSV stores H as degrees / 2, i.e. in
    #    [0, 180), and S, V in [0, 255]; rescale all three to [0, 1].
    rgb = np.array(color_img_rgb.convert("RGB"))  # (H, W, 3) uint8
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV).astype(np.float32)
    # Divide by 180 (not 179): H = 179 is 358 degrees, which must stay one
    # step away from red instead of collapsing onto hue 1.0 == 0.0.
    h = hsv[..., 0] / 180.0
    s = hsv[..., 1] / 255.0
    v = hsv[..., 2] / 255.0

    h_f = torch.from_numpy(h).flatten()
    s_f = torch.from_numpy(s).flatten()
    v_f = torch.from_numpy(v).flatten()

    # --- IDENTIFY PIXELS ---
    # Black: the darkest pixel overall.
    v_black = v_f.min().item()

    # White: the brightest low-saturation pixel (brightest pixel as fallback).
    low_sat = s_f < 0.2
    v_white = v_f[low_sat].max().item() if low_sat.any() else v_f.max().item()

    # Gray: the pixel whose value is closest to 0.5, again preferring
    # low-saturation pixels when any exist.
    gray_pool = v_f[low_sat] if low_sat.any() else v_f
    v_gray = gray_pool[(gray_pool - 0.5).abs().argmin()].item()

    # Red: among saturated pixels, the one maximising
    # saturation + value - circular hue distance to 0.
    sat_mask = s_f > 0.2
    if sat_mask.any():
        hh = h_f[sat_mask]
        ss = s_f[sat_mask]
        vv = v_f[sat_mask]
        dist_to_red = torch.minimum((hh - 0.0).abs(), (hh - 1.0).abs())
        idx = ((ss + vv) - dist_to_red).argmax()
        h_red = hh[idx].item()
        s_red = ss[idx].item()
    else:
        # No saturated pixel at all: fall back to an identity correction.
        h_red, s_red = 0.0, 1.0

    # --- FACTORS ---
    brightness = 1.0 / v_white if v_white > 1e-6 else 1.0
    contrast = 1.0 / ((v_white - v_black) + 1e-6)

    # Clamp strictly inside (0, 1) so log(vg) is finite and non-zero.
    vg = max(1e-6, min(0.999999, v_gray * brightness))
    gamma = math.log(0.5) / math.log(vg)

    # Wrap the shift into [-0.5, 0.5] so it takes the short way round.
    hue_shift = -h_red
    if hue_shift < -0.5:
        hue_shift += 1.0
    if hue_shift > 0.5:
        hue_shift -= 1.0

    sat_scale = 1.0 / s_red if s_red > 1e-6 else 1.0

    return hue_shift, sat_scale, contrast, gamma, brightness


In [168]:
class TrainingParamTfm:
    """
    Training-time image transform: color correction followed by tensor conversion.

    `__call__` opens the color image at `color_dir / color`, derives five
    correction factors from it with `calculate_normalization_factors`, clamps
    them, applies them to `img`, and then runs the `post` pipeline
    (random horizontal flip, resize to `img_size` x `img_size`, `ToTensor`,
    mean/std normalization).
    """

    def __init__(self, img_size: int = 224):
        self._img_size = img_size  # read by the dataset to size padding tensors
        pipeline = [
            transforms.RandomHorizontalFlip(p=0.5),      # mirror (left-right)
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),
        ]
        self.post = transforms.Compose(pipeline)

    @staticmethod
    def _clamp(value, lower, upper):
        """Bound `value` to [lower, upper] and return it as a float."""
        return float(max(lower, min(upper, value)))

    def __call__(self, img, color, color_dir):
        reference = Image.open(color_dir/Path(color)).convert("RGB")
        raw_hue, raw_sat, raw_contrast, raw_gamma, raw_brightness = (
            calculate_normalization_factors(reference)
        )

        # Keep every factor in a sane range before applying it.
        hue_shift = self._clamp(raw_hue, -0.5, 0.5)
        sat_scale = self._clamp(raw_sat, 0.1, 3.0)
        contrast = self._clamp(raw_contrast, 0.1, 3.0)
        brightness = self._clamp(raw_brightness, 0.1, 3.0)
        gamma = self._clamp(raw_gamma, 0.1, 3.0)

        # Application order matters: hue, saturation, contrast, gamma, brightness.
        corrected = TF.adjust_hue(img, hue_shift)
        corrected = TF.adjust_saturation(corrected, sat_scale)
        corrected = TF.adjust_contrast(corrected, contrast)
        corrected = TF.adjust_gamma(corrected, gamma=gamma, gain=1.0)
        corrected = TF.adjust_brightness(corrected, brightness)
        return self.post(corrected)

In [169]:
class ValidationParamTfm:
    """
    Validation-time image transform: resize, `ToTensor`, mean/std normalization.

    `__call__` accepts the same `(img, color, color_dir)` arguments as
    `TrainingParamTfm` so the two are interchangeable, but `color` and
    `color_dir` are ignored here.

    NOTE(review): `TrainingParamTfm` applies a color correction derived from
    the color image, whereas this transform does not - confirm that this
    train/validation preprocessing difference is intended.
    """

    def __init__(self, img_size: int = 224):
        self._img_size = img_size  # read by the dataset to size padding tensors
        pipeline = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),
        ]
        self.post = transforms.Compose(pipeline)

    def __call__(self, img, color, color_dir):
        # `color` / `color_dir` are intentionally unused (interface parity only).
        processed = self.post(img)
        return processed

In [170]:
class EventDataset(Dataset):
    """
    Groups the rows of a CSV by event and serves one event per item.

    Each __getitem__ returns:
      x: (K, 3, H, W) tensor of beetle images for this event (padded if <K)
      mask: (K,) float mask where 1 = real image, 0 = padding
      y: (3,) float targets for the event

    K is `k_max` and H = W = `tfm._img_size`.

    Args:
        csv_path: CSV with one row per image.
        masked_img_dir: directory searched first for each image.
        img_dir: fallback directory searched when the image is not found above.
        color_dir: directory handed to `tfm` together with the row's color path.
        event_col, img_col, color_col: column names for the event id, the
            relative image path and the color image path.
        target_cols: the three target column names.
        tfm: callable `tfm(pil_image, color_path, color_dir) -> tensor` that
            also exposes `_img_size`.
        k_max: number of image slots per event.
        train_mode: when True, events with more than `k_max` rows are randomly
            subsampled; otherwise the first `k_max` rows are used.
        seed: seed for the private `random.Random` used for subsampling.

    Raises:
        KeyError: if any required column is absent from the CSV.
    """

    def __init__(self, csv_path: str, masked_img_dir: str, img_dir: str, color_dir: str, event_col: str, img_col: str, color_col: str, target_cols: Tuple[str, str, str], tfm, k_max: int, train_mode: bool, seed: int = 0,):
        self.df = pd.read_csv(csv_path).reset_index(drop=True)
        self.img_root = Path(masked_img_dir)
        self.raw_img_root = Path(img_dir)
        self.color_dir = Path(color_dir)
        self.event_col = event_col
        self.img_col = img_col
        self.color_col = color_col
        self.target_cols = target_cols
        self.tfm = tfm
        self.k_max = k_max
        self.train_mode = train_mode
        # NOTE(review): DataLoader workers each receive a copy of this RNG in
        # the same state - confirm the resulting sampling pattern is acceptable.
        self.rng = random.Random(seed)

        # sanity columns - color_col is included because __getitem__ reads it
        # for every image; failing here gives a clear message up front.
        for col in [event_col, img_col, color_col, *target_cols]:
            if col not in self.df.columns:
                raise KeyError(f"Missing column '{col}' in {csv_path}. Columns: {list(self.df.columns)}")

        # drop rows missing the event id, image path or any target
        # (rows with a missing color path are kept, as before)
        self.df = self.df.dropna(subset=[event_col, img_col, *target_cols]).reset_index(drop=True)

        # Build event -> row indices. The index was just reset, so positions
        # from enumerate() are also valid `.loc` labels.
        self.event_to_rows: Dict[str, List[int]] = {}
        for i, ev in enumerate(self.df[self.event_col].tolist()):
            self.event_to_rows.setdefault(str(ev), []).append(i)
        self.events = list(self.event_to_rows.keys())

    def __len__(self) -> int:
        """Number of distinct events."""
        return len(self.events)

    def _open_image(self, rel: str) -> "Image.Image":
        """
        Open the image for relative path `rel` as RGB.

        Tries, in order: masked root + rel, masked root + basename,
        raw root + rel, raw root + basename.

        Raises:
            FileNotFoundError: if none of the four candidates exists.
        """
        rel = str(rel).lstrip("/").replace("\\", "/")
        name = Path(rel).name
        candidates = (
            self.img_root / rel,
            self.img_root / name,
            self.raw_img_root / rel,
            self.raw_img_root / name,
        )
        for p in candidates:
            if p.exists():
                return Image.open(p).convert("RGB")
        raise FileNotFoundError(f"Image not found: {candidates[-1]}")

    def __getitem__(self, idx: int):
        """Return `(x, mask, y)` for the event at position `idx`."""
        ev = self.events[idx]
        rows = self.event_to_rows[ev]

        # targets from first row
        # (assumes every row of an event carries the same targets - TODO confirm)
        row0 = self.df.loc[rows[0]]
        y = torch.tensor([row0[c] for c in self.target_cols], dtype=torch.float32)

        # sample up to K specimens for this event
        if self.train_mode and len(rows) > self.k_max:
            chosen = self.rng.sample(rows, self.k_max)
        else:
            # deterministic for validation
            chosen = rows[: self.k_max]

        # load + transform images
        xs: List[torch.Tensor] = []
        for r in chosen:
            rel_path = self.df.loc[r, self.img_col]
            color_path = self.df.loc[r, self.color_col]
            xs.append(self.tfm(self._open_image(rel_path), color_path, self.color_dir))

        n = len(xs)
        H = W = self.tfm._img_size

        # pad to K with all-zero images; the mask marks which slots are real
        if n < self.k_max:
            pad = torch.zeros((self.k_max - n, 3, H, W), dtype=torch.float32)
            x = torch.cat([torch.stack(xs, dim=0), pad], dim=0) if n > 0 else pad
            mask = torch.cat([torch.ones(n), torch.zeros(self.k_max - n)]).float()
        else:
            x = torch.stack(xs[: self.k_max], dim=0)
            mask = torch.ones(self.k_max).float()

        return x, mask, y

In [171]:
class EventConvNeXtRegressor(nn.Module):
    """
    Event-level regressor: a timm backbone per image, masked mean pooling
    over the images of an event, and a linear head producing a Gaussian
    (mean, sigma) for each of three targets.

    Args:
        backbone_name: timm model name passed to `timm.create_model`
            (defaults to "convnext_small"). The model is created with
            `pretrained=True` and `num_classes=0`, and weights are cached
            under `./timm_cache`.
    """

    def __init__(self, backbone_name: str = "convnext_small"):
        super().__init__()
        # Honour the requested backbone instead of a hard-coded model name.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            cache_dir=str(Path("timm_cache").resolve()),
        )
        d = self.backbone.num_features
        # 3 means + 3 raw scale outputs (softplus turns them into sigma)
        self.head = nn.Linear(d, 6)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """
        x: (B, K, 3, H, W)
        mask: (B, K) 1 for real, 0 for padded

        Returns (mu, sigma), each (B, 3); sigma is strictly positive.
        """
        B, K, C, H, W = x.shape
        # reshape (rather than view) also accepts non-contiguous input
        feats = self.backbone(x.reshape(B * K, C, H, W))  # (B*K, d)
        d = feats.shape[-1]
        feats = feats.reshape(B, K, d)          # (B, K, d)

        # masked mean pool: padded slots contribute nothing, and the clamp
        # guards against dividing by zero for an all-padding event
        m = mask.unsqueeze(-1)                  # (B, K, 1)
        denom = m.sum(dim=1).clamp_min(1.0)     # (B, 1)
        event_feat = (feats * m).sum(dim=1) / denom  # (B, d)

        out = self.head(event_feat)             # (B, 6)
        mu = out[:, :3]                         # (B, 3)
        sigma = F.softplus(out[:, 3:]) + 1e-6   # (B, 3), > 0
        return mu, sigma

In [172]:
def gaussian_nll(y: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Mean Gaussian negative log-likelihood (constant term dropped).

    y, mu, sigma: (B, 3)
    Each element contributes 0.5 * ((y - mu) / sigma)^2 + log(sigma); the
    result is the scalar mean over all elements.
    """
    standardized = (y - mu) / sigma
    per_element = 0.5 * standardized.pow(2) + torch.log(sigma)
    return per_element.mean()

In [173]:
def main(argv=None):
    """
    Train the event-level regressor and keep the best checkpoint.

    Steps: parse `argv` with `get_training_args`, seed the global RNGs, build
    the train/validation `EventDataset`s and loaders, optionally train only
    the head for `--freeze_backbone_epochs` epochs, then fine-tune the whole
    model. After every epoch the mean validation R^2 over the three targets
    is computed and the model state dict is saved to
    `<save_dir>/event_convnext_small_mask_img_aug_best.pth` whenever it improves.

    Args:
        argv: list of command-line tokens, or None to let argparse read
            `sys.argv[1:]`.
    """
    args = get_training_args(argv)
    img_size = args.img_size

    # --seed previously only reached the dataset sampler; seed the global
    # RNGs too so shuffling, random flips and head init are repeatable.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Both transforms store img_size as `_img_size`, which EventDataset reads
    # to size its padding tensors.
    train_tfm = TrainingParamTfm(img_size)
    val_tfm = ValidationParamTfm(img_size)

    target_cols = (args.spei30_col, args.spei1y_col, args.spei2y_col)

    train_ds = EventDataset(
        csv_path=args.train_csv,
        masked_img_dir=args.mask_train_img_dir,
        img_dir=args.train_img_dir,
        color_dir=args.train_color_dir,
        event_col=args.event_col,
        img_col=args.img_col,
        color_col=args.color_col,
        target_cols=target_cols,
        tfm=train_tfm,
        k_max=args.k_max,
        train_mode=True,
        seed=args.seed,
    )
    val_ds = EventDataset(
        csv_path=args.val_csv,
        masked_img_dir=args.mask_val_img_dir,
        img_dir=args.val_img_dir,
        color_dir=args.val_color_dir,
        event_col=args.event_col,
        img_col=args.img_col,
        color_col=args.color_col,
        target_cols=target_cols,
        tfm=val_tfm,
        k_max=args.k_max,
        train_mode=False,
        seed=args.seed,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = (device.type == "cuda")
    model = EventConvNeXtRegressor("convnext_small").to(device)

    def set_backbone_trainable(trainable: bool):
        """Toggle requires_grad on every backbone parameter."""
        for p in model.backbone.parameters():
            p.requires_grad = trainable

    # freeze backbone initially (optional but recommended for stability)
    if args.freeze_backbone_epochs > 0:
        set_backbone_trainable(False)

    # Head-only warm-up uses a fixed 1e-3 learning rate; otherwise --lr.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=(1e-3 if args.freeze_backbone_epochs > 0 else args.lr),
        weight_decay=args.weight_decay
    )

    # Mixed precision is only enabled on CUDA; on CPU both are no-ops.
    scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=use_cuda
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=use_cuda
    )

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    best_path = save_dir / "event_convnext_small_mask_img_aug_best.pth"

    # R^2 is unbounded below, so start from -inf: a -1.0 floor would skip
    # saving entirely whenever validation R^2 stays below -1.
    best_avg = float("-inf")
    best_epoch = -1

    # Guard against --grad_accum <= 0 (modulo / division by zero).
    grad_accum = max(1, args.grad_accum)
    n_train_batches = len(train_loader)

    for epoch in range(args.epochs):
        # End of warm-up: unfreeze the backbone and restart the optimizer
        # over all parameters at the fine-tuning learning rate.
        if epoch == args.freeze_backbone_epochs and args.freeze_backbone_epochs > 0:
            set_backbone_trainable(True)
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # ---- train ----
        model.train()
        tr_loss = 0.0
        tr_preds, tr_gts = [], []
        optimizer.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"train {epoch}", leave=False)
        for step, (x, mask, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=use_cuda):
                mu, sigma = model(x, mask)
                # Scale so accumulated gradients average over the window.
                loss = gaussian_nll(y, mu, sigma) / grad_accum

            scaler.scale(loss).backward()

            # Step on every full window AND on the last batch, so gradients
            # from a trailing partial window are applied instead of being
            # discarded by the next epoch's zero_grad.
            window_full = (step + 1) % grad_accum == 0
            last_batch = (step + 1) == n_train_batches
            if window_full or last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            tr_loss += float(loss.item()) * grad_accum  # undo the scaling for logging
            tr_preds.append(mu.detach().cpu().numpy())
            tr_gts.append(y.detach().cpu().numpy())
            pbar.set_postfix({"loss": tr_loss / max(1, step + 1)})

        tr_preds = np.concatenate(tr_preds, axis=0)
        tr_gts = np.concatenate(tr_gts, axis=0)
        tr30, tr1y, tr2y = evaluate_spei_r2(tr_gts, tr_preds)

        # ---- val ----
        model.eval()
        va_loss = 0.0
        va_preds, va_gts = [], []
        pbar = tqdm(val_loader, desc=f"val {epoch}", leave=False)
        with torch.no_grad():
            for step, (x, mask, y) in enumerate(pbar):
                x = x.to(device, non_blocking=True)
                mask = mask.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                with torch.amp.autocast("cuda", enabled=use_cuda):
                    mu, sigma = model(x, mask)
                    loss = gaussian_nll(y, mu, sigma)

                va_loss += float(loss.item())
                va_preds.append(mu.detach().cpu().numpy())
                va_gts.append(y.detach().cpu().numpy())
                pbar.set_postfix({"loss": va_loss / max(1, step + 1)})

        va_preds = np.concatenate(va_preds, axis=0)
        va_gts = np.concatenate(va_gts, axis=0)
        v30, v1y, v2y = evaluate_spei_r2(va_gts, va_preds)
        avg = (v30 + v1y + v2y) / 3.0

        # Update the best checkpoint before logging so the printed
        # best_avg / best_epoch already include this epoch.
        if avg > best_avg:
            best_avg = avg
            best_epoch = epoch
            torch.save(model.state_dict(), best_path)

        print(
            f"epoch {epoch:03d} | "
            f"train_loss={tr_loss/len(train_loader):.4f} val_loss={va_loss/len(val_loader):.4f} | "
            f"train_r2=({tr30:.3f},{tr1y:.3f},{tr2y:.3f}) "
            f"val_r2=({v30:.3f},{v1y:.3f},{v2y:.3f}) avg={avg:.3f} | "
            f"best_avg={best_avg:.3f} @ {best_epoch}"
        )

    # Only claim a checkpoint exists if one was actually written.
    if best_epoch >= 0:
        print("Saved best weights to:", best_path)
    else:
        print("No checkpoint saved: validation R2 never improved on -inf (no epochs run or non-finite metrics).")

In [None]:
# Launch training. Each (flag, value) pair below is flattened into the
# argv-style token list that main() forwards to the argument parser.
main([
    token
    for flag_and_value in (
        ("--train_img_dir", "training_images"),
        ("--val_img_dir", "validation_images"),
        ("--event_col", "eventID"),
        ("--img_col", "relative_img_loc"),
        ("--spei30_col", "SPEI_30d"),
        ("--spei1y_col", "SPEI_1y"),
        ("--spei2y_col", "SPEI_2y"),
        ("--batch_size", "2"),
        ("--k_max", "4"),
        ("--epochs", "20"),
        ("--lr", "3e-5"),
    )
    for token in flag_and_value
])

train 0:  10%|█         | 57/544 [01:51<13:06,  1.61s/it, loss=1.37]