In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the input tree and print the full path of every file found, one per line.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    full_paths = [os.path.join(folder, name) for name in names]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session


## 0. Setup

- Installs (if missing): `timm`, `albumentations`, `pytorch-lightning`
- Imports and global config


In [None]:
import os
# The installs below run only when the KAGGLE_URL_BASE environment variable is unset or empty.
if not os.environ.get("KAGGLE_URL_BASE"):
    # only when running locally, not on Kaggle scoring
    !pip install --quiet "transformers==4.37.2"
    # NOTE(review): this install is unpinned; the following line then pins pytorch-lightning to 2.3.3.
    !pip install timm albumentations pytorch-lightning 
    !pip install -q "pytorch-lightning==2.3.3" "lightning-utilities==0.11.6"

In [None]:
import os, gc, json, time, random, math, glob, warnings
from pathlib import Path

# These Hugging Face Hub switches must be exported BEFORE timm (and, through
# it, huggingface_hub) is imported: the library reads them once at import
# time, so assigning them after `import timm` has no effect in this process.
os.environ["HF_HUB_OFFLINE"] = "1"           # timm will not try to hit the hub
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Global hyper-parameters read by the cells below.
CFG = {
    'seed': 42,
    'img_size': 512,          # square input side in pixels (576 is a candidate for later runs)
    'batch_size': 16,         # training batch size; validation/inference use twice this
    'num_workers': 2,
    'epochs': 20,             # max_epochs for the Trainer and length of the LR schedule
    'lr': 2e-4,
    'weight_decay': 1e-5,
    'backbone': 'tf_efficientnetv2_s_in21k',  # good speed/accuracy trade-off
    'train_val_split': 0.15,  # NOTE(review): not read by any visible cell; the split is 5-fold
    'tta': 4,                 # total forward passes per image at inference; 0 or 1 disables TTA
    'precision': '16-mixed',  # use AMP on Kaggle GPU
}

COMPETITION = "csiro-biomass"
DATA_DIR = Path("/kaggle/input/csiro-biomass")
OUTPUT_DIR = Path("./outputs"); OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Seed Python, NumPy and torch (and dataloader workers) for reproducibility.
seed_everything(CFG['seed'], workers=True)


In [None]:
# Load the sample submission so its layout can be inspected.
sample_csv = DATA_DIR.joinpath("sample_submission.csv")
sample = pd.read_csv(sample_csv)
sample.head()


## 1. Load data

The competition provides:
- **`train.csv`** with image IDs, metadata (species, location, etc.), and 5 target columns
- **`test.csv`** with image IDs and metadata (no targets)
- **`images/`** folder with JPG/PNG files

In [None]:

# Read the competition tables; the CSV paths stay available as globals.
train_csv, test_csv = DATA_DIR / 'train.csv', DATA_DIR / 'test.csv'
train, test = (pd.read_csv(csv_path) for csv_path in (train_csv, test_csv))
# --- Image roots (robust to different layouts) ---
def existing(*paths):
    """Return the first argument whose ``.exists()`` is true, or ``None`` when none exists."""
    return next((candidate for candidate in paths if candidate.exists()), None)

# True when either table has an `image_path` column. BiomassDataset then tries
# that path first and falls back to <img_root>/<id><ext> when it does not exist.
USE_ABS = ("image_path" in train.columns) or ("image_path" in test.columns)

# First existing candidate wins; None when no candidate exists.
TRAIN_IMG_DIR = existing(
    DATA_DIR/'train/images',
    DATA_DIR/'train',
    DATA_DIR/'images',
    DATA_DIR
)
TEST_IMG_DIR = existing(
    DATA_DIR/'test/images',
    DATA_DIR/'test',
    DATA_DIR/'images',
    DATA_DIR
)

# Image directories are only mandatory when the tables carry no `image_path` column.
if not USE_ABS:
    assert TRAIN_IMG_DIR is not None, "Could not locate train image directory."
    assert TEST_IMG_DIR  is not None, "Could not locate test image directory."
print("USE_ABS:", USE_ABS)
print("TRAIN_IMG_DIR:", TRAIN_IMG_DIR)
print("TEST_IMG_DIR:", TEST_IMG_DIR)

print(train.shape, test.shape)
display(train.head(3))
display(test.head(3))

# Directory used by the EDA cells below; fixed to <DATA_DIR>/train, independent of TRAIN_IMG_DIR.
IMG_DIR = DATA_DIR/'train'
assert IMG_DIR.exists(), f"Images dir not found: {IMG_DIR}"
# detect schema & normalize to WIDE for training 

# ID column: first column whose lower-cased name is one of the listed names or
# contains "image"; falls back to the first column of the table.
cands_id = [c for c in train.columns if c.lower() in ["image_id","image","id","sample_id"] or "image" in c.lower()]
IDCOL = cands_id[0] if cands_id else train.columns[0]

def to_image_id(x):
    """Reduce a path-like value to a bare image id.

    ``x`` is converted with ``str``; everything up to the last forward or
    back slash is removed, then the text after the last dot (if any) is
    dropped. A value without a dot is returned unchanged.
    """
    filename = str(x).replace("\\", "/").rpartition("/")[2]
    stem, dot, _extension = filename.rpartition(".")
    return stem if dot else filename

# Derive `image_id` from `image_path` whenever the detected ID column is something else.
if "image_path" in train.columns and IDCOL != "image_id":
    train["image_id"] = train["image_path"].apply(to_image_id)
    IDCOL = "image_id"
if "image_path" in test.columns and "image_id" not in test.columns:
    test["image_id"] = test["image_path"].apply(to_image_id)

# target list
TARGETS = ['Dry_Green_g','Dry_Dead_g','Dry_Clover_g','GDM_g','Dry_Total_g']

# Decide name/value columns for LONG schema.
# Column membership is tested before `train['target']` is read, so a table
# that has neither `target_name` nor `target` (a truly wide table) yields
# name_col=None instead of raising KeyError.
if 'target_name' in train.columns:
    name_col = 'target_name'
elif 'target' in train.columns and train['target'].dtype == 'O':
    name_col = 'target'
else:
    name_col = None
value_col = None
for cand in ['value','target','target_value','biomass','biomass_g']:
    if cand in train.columns and np.issubdtype(train[cand].dtype, np.number):
        value_col = cand
        break

has_wide_targets = all(t in train.columns for t in TARGETS)
is_long_like = (name_col is not None) and (value_col is not None)

if is_long_like and not has_wide_targets:

    train[name_col] = train[name_col].astype(str).str.strip()

    # pivot -> wide (may include more than the 5 official names);
    # duplicate (id, target name) pairs are averaged
    wide = (train
            .pivot_table(index=IDCOL, columns=name_col, values=value_col, aggfunc='mean')
            .reset_index())

    # bring back a few metadata columns (first occurrence per image id)
    meta_cols = ['Sampling_Date','State','Species','Pre_GSHH_NDVI','Height_Ave_cm']
    meta_cols = [c for c in meta_cols if c in train.columns]
    if meta_cols:
        meta_first = train.drop_duplicates(subset=[IDCOL])[ [IDCOL] + meta_cols ]
        wide = wide.merge(meta_first, on=IDCOL, how='left')

    # report official targets that did not survive the pivot
    present_targets = [t for t in TARGETS if t in wide.columns]
    if len(present_targets) < len(TARGETS):
        print("Some targets missing in train after pivot:",
              [t for t in TARGETS if t not in wide.columns])
        # continue with intersection to let EDA/train run
    train = wide

# Final check
present_targets = [t for t in TARGETS if t in train.columns]
if not present_targets:
    raise ValueError(f"No official targets found after normalization. "
                     f"Columns present: {list(train.columns)}")

# From here on TARGETS only lists target columns that exist in `train`.
TARGETS = present_targets

display(train.head())
print("IDCOL:", IDCOL)
print("Detected schema:", "long→wide (target_name/target)" if is_long_like and not has_wide_targets else "wide")




## 2. EDA (compact, visual)

We'll quickly check:
- Target distributions (raw + log1p)
- Pairwise correlations of targets
- Metadata frequencies (species, season, site if provided)
- Image statistics (shape, brightness, contrast) sampled
- Visual grids by **`Dry_Total_g`** quantile bins




In [None]:
import matplotlib.pyplot as plt

def plot_hist(series, title, bins=50):
    """Show a histogram of ``series.values`` in a new 6x4-inch figure titled ``title``."""
    _fig, axis = plt.subplots(figsize=(6, 4))
    axis.hist(series.values, bins=bins)
    axis.set_title(title)
    axis.set_xlabel('value')
    axis.set_ylabel('count')
    plt.show()

# Target distributions: one raw and one log1p histogram per target
for t in TARGETS:
    plot_hist(train[t], f"{t} (raw)")
    plot_hist(np.log1p(train[t]), f"log1p({t})")

# Correlation between targets (DataFrame.corr with its default method), drawn as a heat map
plt.figure(figsize=(5,4))
corr = train[TARGETS].corr()
im = plt.imshow(corr, cmap='viridis')
plt.xticks(range(len(TARGETS)), TARGETS, rotation=45, ha='right')
plt.yticks(range(len(TARGETS)), TARGETS)
plt.colorbar(im, fraction=0.046, pad=0.04)
plt.title("Target Correlation")
plt.tight_layout(); plt.show()

# Metadata quick looks: bar chart of the 20 most frequent values for each of
# the first six non-target, non-ID columns that are object-typed or have <50 distinct values
meta_cols = [c for c in train.columns if c not in TARGETS+[IDCOL]]
for c in meta_cols[:6]:  # show a few
    if train[c].dtype == 'object' or train[c].nunique()<50:
        vc = train[c].value_counts().head(20)
        plt.figure(figsize=(6,3))
        vc.plot(kind='bar')
        plt.title(f"{c} (top 20)")
        plt.tight_layout(); plt.show()

In [None]:

# Image shapes and brightness on a random sample of at most 200 training images
sample_ids = train[IDCOL].sample(min(200, len(train)), random_state=CFG['seed']).tolist()
# NOTE(review): `shapes` is filled below but never plotted or printed in this cell.
shapes, means, stds = [], [], []
for img_id in tqdm(sample_ids, desc="Image stats"):
    # try each extension in turn; stop at the first file that decodes
    for ext in (".jpg",".jpeg",".png",".bmp"):
        p = IMG_DIR/f"{img_id}{ext}"
        if p.exists():
            img = cv2.imread(str(p), cv2.IMREAD_COLOR)
            # unreadable file: move on to the next extension
            if img is None: continue
            shapes.append(img.shape[:2])
            means.append(img.mean())
            stds.append(img.std())
            break

plt.figure(figsize=(6,4))
plt.hist(means, bins=50); plt.title("Brightness (mean pixel)"); plt.show()
plt.figure(figsize=(6,4))
plt.hist(stds, bins=50); plt.title("Contrast proxy (pixel std)"); plt.show()

# Image grid by Dry_Total_g quantiles
qt = pd.qcut(train['Dry_Total_g'], 4, labels=False, duplicates='drop')
train['_qt'] = qt

def show_grid(df, title, rows=2, cols=4):
    """Show up to ``rows*cols`` random images from ``df`` with their Dry_Total_g as title.

    Rows are sampled with ``CFG['seed']``. An image is skipped when no file
    with a known extension exists under ``IMG_DIR`` or when the file cannot be
    decoded; its grid slot is then left empty.
    """
    plt.figure(figsize=(cols*3, rows*3))
    subset = df.sample(min(rows*cols, len(df)), random_state=CFG['seed'])
    for i,(idx,row) in enumerate(subset.iterrows()):
        img_id = row[IDCOL]
        path = None
        for ext in ('.jpg','.jpeg','.png','.bmp'):
            p = IMG_DIR/f"{img_id}{ext}"
            if p.exists():
                path = p; break
        if path is None: continue
        raw = cv2.imread(str(path))
        # cv2.imread returns None for an unreadable file; cvtColor would raise on it
        if raw is None: continue
        img = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256,256))
        plt.subplot(rows,cols,i+1)
        plt.imshow(img); plt.axis('off')
        tt = f"{row['Dry_Total_g']:.0f}g"
        plt.title(tt, fontsize=9)
    plt.suptitle(title)
    plt.tight_layout(); plt.show()

# One grid per quantile bin that actually occurs
for q in sorted(train['_qt'].dropna().unique()):
    show_grid(train[train['_qt']==q], f"Samples – Dry_Total_g quantile {int(q)}")

# remove the helper column again
train.drop(columns=['_qt'], errors='ignore', inplace=True)


### Correlation

In [None]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns

# columns that may exist
# TARGETS is a notebook-wide global, so it is rebuilt here as "official
# targets that are present in train" rather than reset to the full list;
# later cells index train[TARGETS] and would fail on a missing column.
TARGETS = [t for t in ['Dry_Green_g','Dry_Dead_g','Dry_Clover_g','GDM_g','Dry_Total_g']
           if t in train.columns]
META    = [c for c in ['Pre_GSHH_NDVI','Height_Ave_cm'] if c in train.columns]
NUMCOLS = TARGETS + META

assert len(NUMCOLS) >= 2, f"Need at least 2 numeric columns; found {NUMCOLS}"

# Spearman is robust to skew
corr_s = train[NUMCOLS].corr(method='spearman')
corr_p = train[NUMCOLS].corr(method='pearson')

# Side-by-side annotated heat maps on a shared [-1, 1] colour scale
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
sns.heatmap(corr_s, vmin=-1, vmax=1, cmap="RdBu_r", annot=True, fmt=".2f", ax=axes[0])
axes[0].set_title("Spearman correlation")

sns.heatmap(corr_p, vmin=-1, vmax=1, cmap="RdBu_r", annot=True, fmt=".2f", ax=axes[1])
axes[1].set_title("Pearson correlation")
plt.tight_layout(); plt.show()


#### Histogram

In [None]:
# Targets that are present in the training table
cols = [c for c in TARGETS if c in train.columns]
assert len(cols) > 0, "No biomass target columns found."

# One panel per target, all on a single row
rows = 1
cols_per_row = len(cols)
fig, axes = plt.subplots(rows, cols_per_row, figsize=(4*cols_per_row, 3), squeeze=False)

for j, tgt in enumerate(cols):
    ax = axes[0, j]
    x = train[tgt].dropna()
    ax.hist(x, bins=40, alpha=0.65, label="raw")
    # second y-axis so the log1p histogram gets its own count scale
    ax2 = ax.twinx()
    ax2.hist(np.log1p(x), bins=40, alpha=0.35, color='tab:orange', label="log1p")
    ax.set_title(tgt)
    ax.set_xlabel("value"); ax.set_ylabel("count")
    ax2.set_ylabel("count (log1p)")
fig.suptitle("Biomass distributions: raw vs log1p (orange)", y=1.03)
plt.tight_layout(); plt.show()


#### Scatter plot

In [None]:
if 'Pre_GSHH_NDVI' in train.columns and 'Dry_Total_g' in train.columns:
    # 'State' is optional: select it only when present so the density plot
    # still works on tables without that column (previously a KeyError).
    has_state = 'State' in train.columns
    plot_cols = ['Pre_GSHH_NDVI','Dry_Total_g'] + (['State'] if has_state else [])
    df = train[plot_cols].dropna()
    # Density view to handle many points
    plt.figure(figsize=(6,4))
    plt.hexbin(df['Pre_GSHH_NDVI'], df['Dry_Total_g'], gridsize=40, cmap='viridis')
    plt.colorbar(label='count')
    sns.regplot(x='Pre_GSHH_NDVI', y='Dry_Total_g', data=df, scatter=False, color='red', ci=None)
    plt.title('NDVI vs Dry_Total_g (with linear trend)')
    plt.xlabel('Pre_GSHH_NDVI'); plt.ylabel('Dry_Total_g (g)')
    plt.tight_layout(); plt.show()

    # Optional: color by State (if many states, this can be busy)
    if has_state and train['State'].nunique() <= 8:
        plt.figure(figsize=(7,4))
        sns.scatterplot(data=df, x='Pre_GSHH_NDVI', y='Dry_Total_g', hue='State', s=18, alpha=0.7)
        plt.title('NDVI vs Dry_Total_g by State')
        plt.tight_layout(); plt.show()
else:
    print("Need both 'Pre_GSHH_NDVI' and 'Dry_Total_g' to plot the scatter.")


#### Setup & helpers for deeper EDA

In [None]:
# --- Paths & columns ---
from pathlib import Path
import pandas as pd, numpy as np, cv2, matplotlib.pyplot as plt

COMP = "csiro-biomass"
DATA_ROOT = Path(f"/kaggle/input/{COMP}")
TRAIN_DIR = (DATA_ROOT / "train")  # images
assert TRAIN_DIR.exists(), TRAIN_DIR

# Use `image_id` when the table really has that column; otherwise keep the ID
# column detected when the data was loaded, so lookups by IDCOL cannot KeyError.
IDCOL   = "image_id" if "image_id" in train.columns else IDCOL
TARGET = ['Dry_Clover_g','Dry_Dead_g','Dry_Green_g','Dry_Total_g','GDM_g']

# Make sure Sampling_Date is datetime (the column is optional metadata)
if 'Sampling_Date' in train.columns and not np.issubdtype(train['Sampling_Date'].dtype, np.datetime64):
    train['Sampling_Date'] = pd.to_datetime(train['Sampling_Date'])

# Image resolver (handles common extensions) 
def find_image_path(image_id, train_dir=TRAIN_DIR):
    """Return the first existing ``train_dir/<image_id><ext>`` over the known extensions, else ``None``."""
    extensions = (".jpg",".jpeg",".png",".bmp",".JPG",".JPEG",".PNG",".BMP")
    candidates = (train_dir / f"{image_id}{ext}" for ext in extensions)
    return next((path for path in candidates if path.exists()), None)

# Thumbnail grid 
def show_grid(df, title, rows=2, cols=4, seed=42, label_fn=None):
    """Show up to ``rows*cols`` random thumbnails from ``df`` in one figure.

    Rows are sampled with ``random_state=seed``. Images that cannot be found
    by ``find_image_path`` or cannot be decoded are skipped, and the remaining
    thumbnails are packed into consecutive grid slots. When ``label_fn`` is
    given, it is called with the row and its result is used as the thumbnail title.
    """
    subset = df.sample(min(rows*cols, len(df)), random_state=seed)
    plt.figure(figsize=(cols*3, rows*3))
    k = 0
    for _, r in subset.iterrows():
        p = find_image_path(r[IDCOL])
        if p is None: continue
        raw = cv2.imread(str(p), cv2.IMREAD_COLOR)
        # cv2.imread returns None for an unreadable file; cvtColor would raise on it
        if raw is None: continue
        img = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256,256))
        k += 1
        plt.subplot(rows, cols, k)
        plt.imshow(img); plt.axis("off")
        if label_fn:
            plt.title(label_fn(r), fontsize=9)
    plt.suptitle(title, y=1.02)
    plt.tight_layout(); plt.show()

#### “By State → which Species?” (counts + example images)

In [None]:
# Counts table: number of rows per (State, Species), largest first within each State
cnt = (train.groupby(["State","Species"])
             .size().reset_index(name="n")
             .sort_values(["State","n"], ascending=[True, False]))
display(cnt.head(20))

# Montage per State, restricted to that State's four most frequent Species
for state, g in cnt.groupby("State"):
    species_order = g.sort_values("n", ascending=False)["Species"].tolist()
    df = train[train["State"].eq(state) & train["Species"].isin(species_order[:4])]
    show_grid(df, f"State={state} — sample Species", rows=2, cols=4,
              label_fn=lambda r: f"{r['Species']}")


#### “By Sampling Date” (timeline sampling)

In [None]:
# Month buckets to reduce sparsity (stored as "YYYY-MM" strings in a temporary column)
train["_month"] = train["Sampling_Date"].dt.to_period("M").astype(str)

# counts per month
display(train["_month"].value_counts().sort_index())

# montage: sample per month (last 6 month labels in sorted order)
for m in sorted(train["_month"].unique())[-6:]:
    df = train[train["_month"].eq(m)]
    if len(df) == 0: continue
    show_grid(df, f"Sampling month {m}", rows=2, cols=4,
              label_fn=lambda r: f"{r['Species']} • {r['State']}")


#### NDVI & Height: bins → thumbnails

In [None]:
# NDVI bins: quartile index per row (fewer bins if duplicate edges are dropped)
if "Pre_GSHH_NDVI" in train.columns:
    train["_ndvi_bin"] = pd.qcut(train["Pre_GSHH_NDVI"], q=4, labels=False, duplicates="drop")
    for b in sorted(train["_ndvi_bin"].dropna().unique()):
        df = train[train["_ndvi_bin"]==b]
        show_grid(df, f"NDVI quartile {int(b)} • NDVI~[{df['Pre_GSHH_NDVI'].min():.2f},{df['Pre_GSHH_NDVI'].max():.2f}]",
                  rows=2, cols=4,
                  label_fn=lambda r: f"NDVI {r['Pre_GSHH_NDVI']:.2f} • Total {r['Dry_Total_g']:.0f}g")

# Height bins: same procedure on Height_Ave_cm
if "Height_Ave_cm" in train.columns:
    train["_h_bin"] = pd.qcut(train["Height_Ave_cm"], q=4, labels=False, duplicates="drop")
    for b in sorted(train["_h_bin"].dropna().unique()):
        df = train[train["_h_bin"]==b]
        show_grid(df, f"Height quartile {int(b)} • H~[{df['Height_Ave_cm'].min():.1f},{df['Height_Ave_cm'].max():.1f}]",
                  rows=2, cols=4,
                  label_fn=lambda r: f"H {r['Height_Ave_cm']:.1f}cm • Total {r['Dry_Total_g']:.0f}g")

# cleanup: drop the temporary columns (including `_month` created by the sampling-date cell)
train.drop(columns=["_month","_ndvi_bin","_h_bin"], errors="ignore", inplace=True)


#### Quick numeric EDA alongside thumbnails

In [None]:
# Distributions: kernel-density estimate of every target on one axis, clipped at x >= 0
ax = train[TARGETS].plot(kind="kde", figsize=(8,4), title="Target distributions (raw)")
ax.set_xlim(left=0); plt.show()

# Correlations among targets and with NDVI/Height (if present), shown as a colour-graded table
num_cols = TARGETS + [c for c in ["Pre_GSHH_NDVI","Height_Ave_cm"] if c in train.columns]
corr = train[num_cols].corr()
display(corr.style.background_gradient(cmap="RdBu_r", vmin=-1, vmax=1).format("{:.2f}"))



## 3. Augmentations

We use a light but effective set suitable for pasture canopies:
- RandomResizedCrop / Resize
- Horizontal/vertical flip (pasture has no canonical orientation)
- Random brightness/contrast & CLAHE (handle lighting/contrast variability)
- Small rotation / shift / scale
- Coarse dropout (Cutout-style square holes, applied with probability 0.3)



In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def _RRC(sz, **kwargs):
    """Build a square ``A.RandomResizedCrop`` of side ``sz``.

    The ``size=(H, W)`` call form is tried first; if it raises, the
    ``height=``/``width=`` form is used. Extra ``kwargs`` are forwarded either way.
    """
    try:
        crop = A.RandomResizedCrop(size=(sz, sz), **kwargs)        # v2 call form
    except Exception:
        crop = A.RandomResizedCrop(height=sz, width=sz, **kwargs)  # v1 call form
    return crop

def _Resize(sz):
    """Build a square ``A.Resize`` of side ``sz``, trying ``size=`` first and ``height=``/``width=`` on failure."""
    try:
        resize = A.Resize(size=(sz, sz))
    except Exception:
        resize = A.Resize(height=sz, width=sz)
    return resize

def get_transforms(img_size, is_train=True):
    """Return the Albumentations pipeline for training or evaluation.

    Evaluation: resize to ``img_size`` x ``img_size``, normalize, convert to tensor.
    Training: random resized crop, flips, shift/scale/rotate, brightness/contrast,
    CLAHE and coarse dropout, followed by the same normalize + tensor steps.
    """
    if not is_train:
        return A.Compose([
            _Resize(img_size),
            A.Normalize(),
            ToTensorV2(),
        ])

    # side of each dropped square: 8% of the image side
    hole_side = int(img_size*0.08)
    augmentations = [
        _RRC(img_size, scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, p=0.5),
        A.RandomBrightnessContrast(p=0.4),
        A.CLAHE(p=0.2),
        # NOTE(review): these CoarseDropout keyword names may differ between Albumentations versions; confirm.
        A.CoarseDropout(max_holes=4,
                        max_height=hole_side,
                        max_width=hole_side,
                        p=0.3),
    ]
    finishing = [A.Normalize(), ToTensorV2()]
    return A.Compose(augmentations + finishing)



## 4. Dataset / DataLoader


In [None]:
class BiomassDataset(Dataset):
    """Image + target dataset backed by a DataFrame.

    Each item is ``(image, targets, image_id)``: the RGB image after
    ``transform`` (when given), a float32 tensor of the ``targets`` columns,
    and the value of the ID column.

    Args:
        df: table with the ID column named by the global ``IDCOL`` and the target columns.
        img_root: directory searched for ``<id><ext>`` files; may be None.
        transform: callable invoked as ``transform(image=img)['image']``.
        targets: target column names; the default is the value of the global
            ``TARGETS`` at the time this class was defined.
        use_abs: when true and the row has ``image_path``, that path is tried first.
    """
    def __init__(self, df, img_root=None, transform=None, targets=TARGETS, use_abs=False):
        self.df = df.reset_index(drop=True)
        self.img_root = Path(img_root) if img_root else None
        self.transform = transform
        self.targets = targets
        self.idcol = IDCOL
        self.use_abs = use_abs

    def _find_path(self, row):
        """Return the image path for ``row``, or None when nothing is found."""
        if self.use_abs and 'image_path' in row:
            p = Path(row['image_path'])
            if p.exists(): return p
        # fall back to <img_root>/<id><ext>
        img_id = row[self.idcol]
        if self.img_root is None: return None
        for ext in ('.jpg','.jpeg','.png','.bmp'):
            p = self.img_root/f"{img_id}{ext}"
            if p.exists(): return p
        return None

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Raises:
            FileNotFoundError: no image file could be located for the row.
            OSError: the file exists but OpenCV could not decode it.
        """
        row = self.df.iloc[idx]
        p = self._find_path(row)
        if p is None:
            raise FileNotFoundError(row.get(self.idcol, row.get('image_path', 'UNKNOWN_ID')))
        bgr = cv2.imread(str(p), cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising
        if bgr is None:
            raise OSError(f"Could not decode image: {p}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if self.transform: img = self.transform(image=img)['image']
        y = torch.tensor(row[self.targets].values.astype('float32'))
        return img, y, row[self.idcol]





## 5. Train / Val split


In [None]:
from sklearn.model_selection import StratifiedKFold

# Stratify by Dry_Total_g quantiles to balance difficulty
bins = pd.qcut(train['Dry_Total_g'], q=min(10, train['Dry_Total_g'].nunique()), labels=False, duplicates='drop')
train['strat'] = bins
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=CFG['seed'])
train['fold'] = -1
# skf.split yields POSITIONAL row indices, so the fold id is written through
# .iloc; label-based .loc is only equivalent while the index is 0..n-1.
fold_pos = train.columns.get_loc('fold')
for i,(tr,va) in enumerate(skf.split(train, train['strat'])):
    train.iloc[va, fold_pos] = i
display(train['fold'].value_counts())

# Fold used for validation; the remaining folds form the training set
FOLD = 0 
trn_df = train[train['fold']!=FOLD].copy()
val_df = train[train['fold']==FOLD].copy()

trn_ds = BiomassDataset(trn_df, TRAIN_IMG_DIR, get_transforms(CFG['img_size'], True), use_abs=USE_ABS)
val_ds = BiomassDataset(val_df, TRAIN_IMG_DIR, get_transforms(CFG['img_size'], False), use_abs=USE_ABS)

# Validation runs without gradients, so it uses twice the training batch size
trn_loader = DataLoader(trn_ds, batch_size=CFG['batch_size'], shuffle=True,  num_workers=CFG['num_workers'], pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=CFG['batch_size']*2, shuffle=False, num_workers=CFG['num_workers'], pin_memory=True)



## 6. Model (timm EfficientNet, multi-head regression)

- Backbone: `tf_efficientnetv2_s_in21k` (created with `pretrained=False`, i.e. trained from randomly initialised weights)
- Head: Linear to 5 outputs
- Loss: `SmoothL1Loss` on log1p targets 
- Metric: Weighted $R^2$ 


In [None]:

def r2_score_torch(y_true, y_pred, eps=1e-9):
    """Column-wise coefficient of determination.

    Reductions run over dim 0, so the result has one R² value per remaining
    position. ``eps`` is added to the total sum of squares to avoid division by zero.
    """
    centered = y_true - y_true.mean(dim=0)
    residual = y_true - y_pred
    total_ss = centered.pow(2).sum(dim=0)
    residual_ss = residual.pow(2).sum(dim=0)
    return 1 - residual_ss / (total_ss + eps)

# Per-target weights of the weighted R²; position i applies to output column i.
WEIGHTS = torch.tensor([0.1,0.1,0.1,0.2,0.5], dtype=torch.float32)

class BiomassModel(pl.LightningModule):
    """timm backbone with a linear head that regresses log1p-transformed targets.

    Logged metrics:
        ``{stage}_loss``: SmoothL1 loss between predictions and ``log1p(y)``.
        ``train_r2w``: weighted R² of the current training batch (a rough progress signal).
        ``val_r2w``: weighted R² over ALL validation samples of the epoch. R² is
            not additive across batches, and a batch holding a single sample has
            zero total variance, so averaging per-batch values distorts the
            number that checkpointing and early stopping monitor.

    Args:
        backbone: timm model name.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
    """
    def __init__(self, backbone=CFG['backbone'], lr=CFG['lr'], wd=CFG['weight_decay']):
        super().__init__()
        self.save_hyperparameters()
        # pretrained=False: the backbone weights are randomly initialised
        self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
        in_features = self.backbone.num_features
        self.head = nn.Linear(in_features, len(TARGETS))
        self.loss_fn = nn.SmoothL1Loss()
        self.lr = lr; self.wd = wd
        # per-epoch buffers of validation targets / predictions (CPU, float32)
        self._val_true, self._val_pred = [], []

    def forward(self, x):
        """Return one log1p-space prediction per target for a batch of images."""
        f = self.backbone(x)
        out = self.head(f)
        return out

    @staticmethod
    def _weighted_r2(y_t, y_hat):
        """Weighted sum of the per-target R² values, as a Python float."""
        r2s = r2_score_torch(y_t.float(), y_hat.float())
        return (r2s * WEIGHTS.to(r2s.device)).sum().item()

    def common_step(self, batch, stage='train'):
        """Compute and log the loss for one batch; returns the loss tensor."""
        x, y, _ = batch
        y_t = torch.log1p(y)
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y_t)
        self.log(f"{stage}_loss", loss, prog_bar=True)

        with torch.no_grad():
            if stage == 'val':
                # keep everything; the metric is computed in on_validation_epoch_end
                self._val_true.append(y_t.detach().float().cpu())
                self._val_pred.append(y_hat.detach().float().cpu())
            elif y_t.shape[0] > 1:
                # R² is undefined for a single sample (zero total variance)
                self.log(f"{stage}_r2w", self._weighted_r2(y_t, y_hat), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self.common_step(batch, 'val')

    def on_validation_epoch_end(self):
        """Log ``val_r2w`` over the whole validation pass and reset the buffers."""
        if not self._val_true:
            return
        y_t = torch.cat(self._val_true)
        y_hat = torch.cat(self._val_pred)
        self._val_true.clear(); self._val_pred.clear()
        self.log("val_r2w", self._weighted_r2(y_t, y_hat), prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a per-epoch LambdaLR: linear warmup, then cosine decay."""
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        # cosine with warmup
        warmup = 1  # epochs; with 1 the factor at epoch 0 is already 1.0
        total = CFG['epochs']
        def lr_lambda(epoch):
            if epoch < warmup:  # linear warmup
                return float(epoch + 1) / float(max(1, warmup))
            # cosine decay  [warmup..total); clamped so the factor stays at 0
            # instead of rising again if training runs past `total` epochs
            progress = min(1.0, (epoch - warmup) / float(max(1, total - warmup)))
            return 0.5 * (1.0 + math.cos(math.pi * progress))
        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
        return {'optimizer': opt, 'lr_scheduler': sch}




## 7. Train


In [None]:

# Fresh model with the defaults taken from CFG
model = BiomassModel()

# Keep only the single checkpoint with the highest val_r2w
ckpt_cb = ModelCheckpoint(
    dirpath=str(OUTPUT_DIR),
    filename="model-{epoch:02d}-{val_r2w:.4f}",
    monitor="val_r2w",
    mode="max",
    save_top_k=1
)
lr_cb = LearningRateMonitor(logging_interval='epoch')
# Stop when val_r2w has not improved for 3 consecutive validation runs
es_cb = EarlyStopping(monitor="val_r2w", mode="max", patience=3)

# Metrics are written as CSV under OUTPUT_DIR/logs
logger = CSVLogger(save_dir=str(OUTPUT_DIR), name="logs")

trainer = pl.Trainer(
    max_epochs=CFG['epochs'],
    precision=CFG['precision'],
    callbacks=[ckpt_cb, lr_cb, es_cb],
    logger=logger,
    default_root_dir=str(OUTPUT_DIR),
    gradient_clip_val=1.0,
    deterministic=False,
    accumulate_grad_batches=1,
)

trainer.fit(model, trn_loader, val_loader)

# Path of the best checkpoint; the inference cell loads it
best_ckpt = ckpt_cb.best_model_path
print("Best ckpt:", best_ckpt)


In [None]:
#Plot train vs val metric over epochs + their averages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

#Locate the metrics.csv that CSVLogger just wrote
try:
    log_dir = Path(logger.log_dir)  # preferred (available after trainer.fit)
except Exception:
    # fallback: find the newest logs/version_*
    def _version_number(path):
        """Numeric suffix of a ``version_<n>`` directory, or -1 when it is not a number."""
        suffix = path.name.rsplit("_", 1)[-1]
        return int(suffix) if suffix.isdigit() else -1
    # sort numerically: a plain path sort would place version_10 before version_2
    versions = sorted((Path(OUTPUT_DIR) / "logs").glob("version_*"), key=_version_number)
    assert versions, "No CSVLogger versions found under OUTPUT_DIR/logs"
    log_dir = versions[-1]
metrics_csv = log_dir / "metrics.csv"
assert metrics_csv.exists(), f"metrics.csv not found at {metrics_csv}"

metrics = pd.read_csv(metrics_csv)

# Pick the metric columns (robust to name variants)
# Change METRIC_KEY if your model logs a different name.
METRIC_KEY = "r2w"  # e.g., 'r2w' or 'loss' or 'rmse'
train_col_candidates = [f"train_{METRIC_KEY}", "train_r2w", "train_r2", "train_score", "train_loss"]
val_col_candidates   = [f"val_{METRIC_KEY}",   "val_r2w",   "val_r2",   "val_score",   "val_loss"]

def first_present(cols):
    """Return the first name in ``cols`` that is a column of the global ``metrics`` frame, else ``None``."""
    return next((name for name in cols if name in metrics.columns), None)

TRAIN_COL = first_present(train_col_candidates)
VAL_COL   = first_present(val_col_candidates)
# both a train and a val column are required for the comparison plot
assert not (TRAIN_COL is None or VAL_COL is None), f"No train/val metric columns found in {metrics.columns.tolist()}"

# Reduce to one value per epoch (last logged row for that epoch)
def per_epoch_last(df, col):
    """Return one row per epoch holding the last non-null ``col`` value.

    The result has the columns ``epoch`` and ``value``; epochs where ``col``
    was never logged are absent.
    """
    return (
        df.loc[:, ["epoch", col]]
          .dropna()
          .groupby("epoch", as_index=False)
          .last()
          .rename(columns={col: "value"})
    )

# One value per epoch for the chosen train and validation metric
train_e = per_epoch_last(metrics, TRAIN_COL)
val_e   = per_epoch_last(metrics, VAL_COL)

# Align epochs (outer-join, in case early epochs log train but not val or vice versa)
plot_df = pd.merge(train_e.rename(columns={"value": "train"}),
                   val_e.rename(columns={"value": "val"}),
                   on="epoch", how="outer").sort_values("epoch")
plot_df.reset_index(drop=True, inplace=True)

# Compute averages across all available epochs (NaNs from the outer join are ignored)
train_avg = np.nanmean(plot_df["train"].values)
val_avg   = np.nanmean(plot_df["val"].values)

# Plot both curves plus a dashed horizontal line at each average
plt.figure(figsize=(8.5, 5.0))
plt.plot(plot_df["epoch"], plot_df["train"], marker="o", label=f"Train ({TRAIN_COL})")
plt.plot(plot_df["epoch"], plot_df["val"],   marker="s", label=f"Val ({VAL_COL})")
plt.axhline(train_avg, linestyle="--", linewidth=1.2, label=f"Train avg = {train_avg:.4f}")
plt.axhline(val_avg,   linestyle="--", linewidth=1.2, label=f"Val avg = {val_avg:.4f}")
plt.xlabel("Epoch")
plt.ylabel(METRIC_KEY.upper())
plt.title("Training vs Validation metric per epoch")
plt.grid(True, alpha=0.25)
plt.legend()
plt.tight_layout()
plt.show()

# (Optional) show the per-epoch table with both averages attached as constant columns
display(plot_df.assign(train_avg=train_avg, val_avg=val_avg))



## 8. Inference (+ optional TTA) and Submission

- Loads the best checkpoint
- Predicts on **test.csv**
- Applies `expm1` to invert the training transform
- Writes `submission.csv` in the long format of `sample_submission.csv`: one row per `sample_id` with a single `target` value


In [None]:
# ===========================================
# FINAL SUBMISSION (uses test.csv for mapping,
# sample_submission.csv for order)
# ===========================================
import os, json, hashlib
from pathlib import Path
import numpy as np, pandas as pd, torch, cv2
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Competition slug: taken from KAGGLE_COMPETITION when set, else the default below
COMP = os.environ.get("KAGGLE_COMPETITION", "csiro-biomass")
ROOT = Path(f"/kaggle/input/{COMP}")
assert ROOT.exists(), f"Competition input not mounted: {ROOT}"

#Load runtime files
test_rt = pd.read_csv(ROOT / "test.csv")
sub_tmpl = pd.read_csv(ROOT / "sample_submission.csv")

# test.csv must be in long format; the template must have exactly these two columns
assert {"sample_id","image_path","target_name"}.issubset(test_rt.columns), test_rt.columns.tolist()
assert list(sub_tmpl.columns) == ["sample_id","target"], sub_tmpl.columns.tolist()

# Canonical targets for indexing model outputs (target name -> output column)
TARGETS = ['Dry_Green_g','Dry_Dead_g','Dry_Clover_g','GDM_g','Dry_Total_g']
tgt2idx = {t:i for i,t in enumerate(TARGETS)}

# Resolve absolute paths for unique images
def resolve_path(image_path):
    """Map a ``test.csv`` image path to a file under ``ROOT``.

    ``ROOT / image_path`` is returned when it exists. Otherwise the file stem
    is searched for in ``ROOT/test`` with each known extension. When nothing
    is found, ``ROOT / image_path`` is returned anyway.
    """
    direct = ROOT / image_path
    if direct.exists():
        return direct
    stem = Path(image_path).stem
    extensions = (".jpg",".jpeg",".png",".bmp",".JPG",".JPEG",".PNG",".BMP")
    candidates = (ROOT / "test" / f"{stem}{ext}" for ext in extensions)
    # may not exist locally; will exist at scoring
    return next((found for found in candidates if found.exists()), direct)

# One row per distinct image: relative path, id (file stem) and resolved path
uniq = test_rt[["image_path"]].drop_duplicates().copy()
uniq["image_id"] = uniq["image_path"].apply(lambda s: Path(str(s)).stem)
uniq["abs_path"] = uniq["image_path"].apply(resolve_path)

# Minimal dataset for unique images
class _UDS(Dataset):
    """Yield ``(rgb_image, image_id)`` for every row of ``df``.

    ``df`` must provide the columns ``abs_path`` and ``image_id``. Images are
    returned untransformed; the caller applies the transforms.
    """
    def __init__(self, df): self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        row = self.df.iloc[i]
        p = row["abs_path"]
        bgr = cv2.imread(str(p), cv2.IMREAD_COLOR)
        # cv2.imread returns None for a missing or unreadable file; fail with
        # the offending path instead of an opaque cvtColor error
        if bgr is None:
            raise FileNotFoundError(f"Could not read test image: {p}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return img, row["image_id"]

uds = _UDS(uniq)
# collate into (tuple_of_images, tuple_of_ids); images are not stacked here
ldr = DataLoader(
    uds, batch_size=CFG['batch_size']*2, shuffle=False, num_workers=CFG['num_workers'],
    collate_fn=lambda b: tuple(zip(*b))
)

# Model & transforms: weights are loaded on CPU, then moved to `device` in eval mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BiomassModel.load_from_checkpoint(best_ckpt, backbone=CFG['backbone'], map_location='cpu').to(device).eval()
base_tf = get_transforms(CFG['img_size'], False)
# NOTE(review): the extra TTA passes below use the full random training pipeline.
train_tf = get_transforms(CFG['img_size'], True)

def prep(imgs, tfm):
    """Apply ``tfm`` to every image and return the stacked batch on ``device``."""
    xs = [tfm(image=im)['image'] for im in imgs]
    return torch.stack(xs).to(device, non_blocking=True)

# Predict once per unique image (TTA optional)
id2vec = {}
with torch.no_grad():
    for imgs, ids_batch in tqdm(ldr, desc="Predict unique images"):
        # first pass: deterministic evaluation transform
        xb = prep(list(imgs), base_tf)
        ys = [model(xb)]
        # remaining CFG['tta'] - 1 passes: randomly augmented views
        for _ in range(max(0, CFG.get('tta',1)-1)):
            xa = prep(list(imgs), train_tf)
            ys.append(model(xa))
        # average over passes while still in log1p space
        y = torch.stack(ys).mean(0).cpu().numpy()  # [B,5] log-space
        y = np.expm1(y)                             # back to grams
        for k, iid in enumerate(ids_batch):
            id2vec[iid] = y[k]

# Build long predictions in test.csv order, then map to sample_id
vals = []
for _, r in test_rt.iterrows():
    iid = Path(str(r["image_path"])).stem
    tname = str(r["target_name"]).strip()
    vec = id2vec.get(iid)
    j = tgt2idx.get(tname)
    # default 0.0 covers: image without prediction, unknown target name,
    # non-finite prediction, negative prediction
    v = 0.0
    if (vec is not None) and (j is not None):
        v = float(vec[j])
        if not np.isfinite(v) or v < 0: v = 0.0
    vals.append(v)

pred_long = pd.DataFrame({
    "sample_id": test_rt["sample_id"],
    "target": vals
})

# Left-join onto the runtime template to guarantee exact row count & order
final_sub = sub_tmpl[["sample_id"]].merge(pred_long, on="sample_id", how="left")
# Fill any missing with 0.0 (shouldn't happen if everything resolved)
final_sub["target"] = final_sub["target"].fillna(0.0).astype("float64")

# Strict validation & save
assert list(final_sub.columns) == ["sample_id","target"]
assert len(final_sub) == len(sub_tmpl), f"Row count mismatch {len(final_sub)} vs {len(sub_tmpl)}"
assert set(final_sub["sample_id"]) == set(sub_tmpl["sample_id"])
assert not final_sub["sample_id"].duplicated().any()
assert np.isfinite(final_sub["target"]).all()

# The same file is written twice: under ./outputs and in the working directory
Path("./outputs").mkdir(parents=True, exist_ok=True)
final_sub.to_csv("./outputs/submission.csv", index=False)
final_sub.to_csv("submission.csv", index=False)

def _md5(p):
    h = hashlib.md5()
    with open(p,"rb") as f:
        for ch in iter(lambda: f.read(8192), b""): h.update(ch)
    return h.hexdigest()

# Short summary of the written file: row count, MD5 checksum and first three rows
print(json.dumps({
    "rows": int(len(final_sub)),
    "md5": _md5("submission.csv"),
    "head": final_sub.head(3).to_dict(orient="records")
}, indent=2))
