# CSIRO SigLIP Biomass Predictor (Log-Target + High-Res Tiling)

**Key Strategies:**
1. **Backbone**: `vit_base_patch16_siglip_512.v2_webli` (timm SigLIP ViT-B/16 at 512 px; strong semantic understanding).
2. **Target Transform**: `log1p(y)` training to handle heavy-tailed biomass distribution.
3. **Input Resolution**: images are resized to 512×512 (`A.Resize`) to match the backbone's input size; the `RandomCrop(384)` tiling idea is not used by the code below.
4. **Loss**: `WeightedHuberLoss` to match competition metric and be robust to outliers.
5. **Hydra Architecture**: Multi-task learning with auxiliary heads for Species, State (region), and NDVI/Height, plus 5 separate heads for the Biomass targets.

In [7]:
import os, sys, json, torch, timm
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import GroupShuffleSplit

# Input data location and the folder where the best checkpoint is written.
# The folder is created up front so saving a checkpoint never fails on a
# missing directory.
DATA_DIR = '/kaggle/input/csiro-biomass'
CHECKPOINT_DIR = './models_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Biomass targets and their per-target weights. The two sequences are aligned
# by position: TARGET_WEIGHTS[i] is the weight applied to TARGET_COLS[i] in
# both the training loss and the validation metric.
TARGET_COLS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
TARGET_WEIGHTS = torch.tensor([0.1, 0.1, 0.1, 0.2, 0.5])

# Training hyper-parameters shared by the model, the transforms and the loop.
CONFIG = {
    'model_name': 'vit_base_patch16_siglip_512.v2_webli',  # timm/ViT-B-16-SigLIP-512
    'img_h': 512,
    'img_w': 512,
    'batch_size': 16,  # Large resolution benefits from smaller batches
    'lr': 2e-5,
    'epochs': 100,
    # Fall back to CPU when no GPU is visible instead of failing on the first
    # `.to('cuda')` call; on a GPU machine this is still 'cuda'.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

In [9]:
class SigLIPBiomassModel(nn.Module):
    """Vision backbone with auxiliary metadata heads and five biomass heads.

    The backbone feature vector feeds three auxiliary heads (a 2-value
    regression for NDVI/Height, species logits, region logits). The biomass
    heads consume the feature vector concatenated with the auxiliary
    regression output, a species embedding and a region embedding.

    Args:
        model_name: timm model identifier for the backbone.
        num_species: number of species classes (logit width / embedding rows).
        num_region: number of region classes (logit width / embedding rows).
    """

    def __init__(self, model_name=CONFIG['model_name'], num_species=15, num_region = 4):
        super().__init__()
        # num_classes=0 requests a headless backbone, so its output is a
        # feature vector of width `num_features`.
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        d = self.backbone.num_features

        # Auxiliary tasks
        self.meta_reg = nn.Linear(d, 2)             # Predicted NDVI/Height
        self.meta_cls = nn.Linear(d, num_species)   # Species logits
        self.species_emb = nn.Embedding(num_species, 32)
        self.meta_plc = nn.Linear(d, num_region)    # Region logits
        self.region_emb = nn.Embedding(num_region, 8)

        # 5 Hydra heads for biomass (log-scale prediction).
        # Fusion: visual features + meta regression + species emb + region emb
        fusion_dim = d + 2 + 32 + 8
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 256), nn.GELU(), nn.Linear(256, 1))
            for _ in range(5)
        ])

    def forward(self, x, s_lbls=None, r_lbls=None):
        """Run the backbone and all heads.

        Args:
            x: image batch passed straight to the backbone.
            s_lbls: optional species indices; when given they select the
                species embedding, otherwise the argmax of the predicted
                species logits is used.
            r_lbls: optional region indices; same fallback rule, using the
                predicted region logits.

        Returns:
            Tuple ``(out, pr, pc, pplc)``: the five biomass predictions
            concatenated along dim 1, the auxiliary regression output, the
            species logits and the region logits.
        """
        feat = self.backbone(x)
        pr = self.meta_reg(feat)
        pc = self.meta_cls(feat)
        pplc = self.meta_plc(feat)

        # Prefer ground-truth indices when supplied, otherwise fall back to
        # the model's own predictions.
        species_idx = s_lbls if s_lbls is not None else torch.argmax(pc, dim=1)
        region_idx = r_lbls if r_lbls is not None else torch.argmax(pplc, dim=1)
        se = self.species_emb(species_idx)
        plce = self.region_emb(region_idx)

        fus = torch.cat([feat, pr, se, plce], dim=1)
        out = torch.cat([h(fus) for h in self.heads], dim=1)

        # Return the region *logits* (not the region embedding) so that a
        # classification loss on this output actually trains `meta_plc`.
        return out, pr, pc, pplc

In [10]:
class BiomassDataset(Dataset):
    """Map-style dataset yielding one image plus its targets and metadata.

    Each item is the tuple ``(img, bio, reg, cls, plc)``:
    the (optionally transformed) RGB image, the ``log1p`` of the
    ``TARGET_COLS`` values, the ``[Pre_GSHH_NDVI, Height_Ave_cm]`` pair,
    the species index and the state/region index.
    """

    def __init__(self, df, species_map, region_map, transform=None, is_train=True):
        # Lookup tables translating the 'Species' / 'State' values to indices.
        self.species_map = species_map
        self.region_map = region_map
        self.df = df
        self.transform = transform
        # Stored for callers; not consulted when building an item.
        self.is_train = is_train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Decode the image to an RGB array.
        image_file = os.path.join(DATA_DIR, record['image_path'])
        img = np.array(Image.open(image_file).convert('RGB'))
        if self.transform:
            # The transform is called with an `image=` keyword and returns a
            # mapping holding the result under 'image'.
            img = self.transform(image=img)['image']

        # Biomass targets are trained in log1p space.
        raw_targets = record[TARGET_COLS].values.astype(np.float32)
        bio = torch.tensor(np.log1p(raw_targets))

        reg = torch.tensor(
            [record['Pre_GSHH_NDVI'], record['Height_Ave_cm']],
            dtype=torch.float32,
        )
        cls = torch.tensor(self.species_map[record['Species']], dtype=torch.long)
        plc = torch.tensor(self.region_map[record['State']], dtype=torch.long)

        return img, bio, reg, cls, plc

In [11]:
def weighted_huber_loss(pred, target, weights):
    """Return the mean of the element-wise Huber loss scaled by ``weights``.

    ``weights`` is moved to ``pred``'s device and multiplied into the
    unreduced loss before the mean over all elements is taken.
    """
    per_element = nn.functional.huber_loss(pred, target, reduction='none')
    scaled = per_element * weights.to(pred.device)
    return scaled.mean()

def _weighted_r2(y_true, y_pred, weights):
    """Return the column-weighted R² of ``y_pred`` against ``y_true``.

    Squared errors are scaled per column by ``weights`` and summed over all
    rows and columns; the baseline uses the per-column mean of ``y_true``.
    """
    ss_res = np.sum(weights * (y_true - y_pred) ** 2)
    ss_tot = np.sum(weights * (y_true - y_true.mean(axis=0)) ** 2)
    return 1 - (ss_res / ss_tot)


def train_fold(fold, t_df, v_df, species_map, region_map):
    """Train one model on ``t_df``, validate on ``v_df`` and keep the best epoch.

    Args:
        fold: identifier used in log lines and in the checkpoint file name.
        t_df: training rows (one row per image).
        v_df: validation rows.
        species_map: mapping from species name to class index.
        region_map: mapping from state name to class index.

    Returns:
        The best validation weighted R² seen over all epochs. The weights of
        that epoch are written to
        ``{CHECKPOINT_DIR}/siglip_best_fold{fold}.pth``.
    """
    device = CONFIG['device']

    # Training transform: resize the whole image to the configured input
    # size, then light flip / colour augmentation.
    t_trans = A.Compose([
        A.Resize(CONFIG['img_h'], CONFIG['img_w']),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.1, contrast=0.1, p=0.3),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2()
    ])
    # Validation transform: same resize and normalisation, no augmentation.
    v_trans = A.Compose([
        A.Resize(CONFIG['img_h'], CONFIG['img_w']),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2()
    ])

    t_ds = BiomassDataset(t_df, species_map, region_map, t_trans)
    v_ds = BiomassDataset(v_df, species_map, region_map, v_trans)

    ld_t = DataLoader(t_ds, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
    ld_v = DataLoader(v_ds, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4)

    model = SigLIPBiomassModel(num_species=len(species_map), num_region = len(region_map)).to(device)

    # LOAD PREVIOUS WEIGHTS
    # CONFIG['resume_path'] = f"/kaggle/input/siglip-512/pytorch/default/6/models_checkpoints/siglip_best_fold0.pth"
    # if CONFIG['resume_path'] != None and os.path.exists(CONFIG['resume_path']):
    #     print(f"Resuming from: {CONFIG['resume_path']}")
    #     sd = torch.load(CONFIG['resume_path'], map_location=CONFIG['device'])
    #     model.load_state_dict({k.replace('module.', ''): v for k, v in sd.items()})
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # One AdamW parameter group: backbone and heads share the same LR.
    opt = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=0.01)
    # Cosine annealing over the whole run, stepped once per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CONFIG['epochs'], eta_min=1e-7)
    scal = torch.amp.GradScaler('cuda')

    # Loss modules are stateless, so build them once rather than per batch.
    mse_loss = nn.MSELoss()
    ce_loss = nn.CrossEntropyLoss()
    metric_weights = TARGET_WEIGHTS.numpy()

    best_r2 = -float('inf')

    for epoch in range(CONFIG['epochs']):
        # ---- training pass ----
        model.train()
        l_acc = 0
        for imgs, bios, regs, clss, plcs in ld_t:
            imgs = imgs.to(device)
            bios = bios.to(device)
            regs = regs.to(device)
            clss = clss.to(device)
            plcs = plcs.to(device)

            opt.zero_grad()
            with torch.amp.autocast('cuda'):
                # Ground-truth species / region labels are fed to the model
                # during training.
                p_bio, p_reg, p_cls, p_plc = model(imgs, s_lbls=clss, r_lbls=plcs)
                l_bio = weighted_huber_loss(p_bio, bios, TARGET_WEIGHTS)
                l_reg = mse_loss(p_reg, regs)
                l_cls = ce_loss(p_cls, clss)
                l_plc = ce_loss(p_plc, plcs)
                loss = l_bio + 0.3*l_reg + 0.1*l_cls + 0.1*l_plc
            scal.scale(loss).backward()
            scal.step(opt)
            scal.update()
            l_acc += loss.item()

        # ---- validation pass (no labels given: the model uses its own
        # species / region predictions) ----
        model.eval()
        all_p, all_t = [], []
        with torch.no_grad():
            for imgs, bios, _, _, _ in ld_v:
                pb, _, _, _ = model(imgs.to(device))
                # Clip log-predictions to [0, 10] so expm1 cannot blow up,
                # then convert both sides back to the linear scale.
                pb = torch.clamp(pb, 0, 10)
                all_p.append(torch.expm1(pb).cpu().numpy())
                all_t.append(torch.expm1(bios).numpy())

        y_p, y_t = np.vstack(all_p), np.vstack(all_t)
        r2 = _weighted_r2(y_t, y_p, metric_weights)

        scheduler.step()  # Step per epoch for Cosine
        print(f'Fold {fold} | Ep {epoch+1} | Loss: {l_acc/len(ld_t):.4f} | Val R2: {r2:.4f}')
        if r2 > best_r2:
            best_r2 = r2
            # NOTE: if the model was wrapped in nn.DataParallel above, the
            # saved keys carry a 'module.' prefix that loaders must strip.
            torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/siglip_best_fold{fold}.pth')
    return best_r2

In [None]:
# df = pd.read_csv(f'{DATA_DIR}/train.csv')
# df_w = df.pivot_table(index=['image_path','Sampling_Date','Species','State','Pre_GSHH_NDVI','Height_Ave_cm'], columns='target_name', values='target').reset_index()
# print(df_w.head())
# species_map = {s: i for i, s in enumerate(sorted(df['Species'].unique()))}
# region_map = {s: i for i, s in enumerate(sorted(df['State'].unique()))}


# gkf = GroupKFold(n_splits=5)
# for fold, (t_idx, v_idx) in enumerate(gkf.split(df_w, groups=df_w['Sampling_Date'])):
#     # if fold == 2:
#     print(f"----Training fold {fold}----")
#     score = train_fold(fold, df_w.iloc[t_idx], df_w.iloc[v_idx], species_map, region_map)
#     print(f"Fold {fold} Best R2 Score: {score:.4f}")
#     # else:
#     #     continue
    
    

In [None]:
# Load the long-format training table (one row per image/target pair) and
# pivot it to one row per image with one column per target name.
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
df_w = df.pivot_table(
    index=['image_path', 'Sampling_Date', 'Species', 'State', 'Pre_GSHH_NDVI', 'Height_Ave_cm'],
    columns='target_name',
    values='target',
).reset_index()
print(df_w.head())

# Stable label -> index tables: indices follow the sorted unique values.
species_map = {name: code for code, name in enumerate(sorted(df['Species'].unique()))}
region_map = {name: code for code, name in enumerate(sorted(df['State'].unique()))}

# Single 90/10 train/validation split; rows sharing a Sampling_Date always
# land on the same side of the split.
gss = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
t_idx, v_idx = next(gss.split(df_w, groups=df_w['Sampling_Date']))

t_df = df_w.iloc[t_idx]
v_df = df_w.iloc[v_idx]

print(f"Training on {len(t_df)} samples, Validating on {len(v_df)} samples")
score = train_fold(0, t_df, v_df, species_map, region_map)
print(f"Best Val R2 Score: {score:.4f}")