# CSIRO Image2Biomass - DINOv2 Hydra (Multi-GPU + Mixup + Tiling)

**Ultimate Performance Strategy:**
1. **Backbone**: DINOv2-Base (`vit_base_patch14_dinov2`).
2. **Tiling**: Original 1000x2000 crops for maximum leaf-level detail.
3. **Mixup**: Blends both Global and Tiled views to improve regression stability.
4. **Multi-GPU**: Doubled batch size for 4-hour limit safety.

In [1]:
import os, sys, functools
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import GroupKFold

print = functools.partial(print, flush=True)

DATA_DIR = '/kaggle/input/csiro-biomass'
CHECKPOINT_DIR = './models_checkpoints'
RESUME_PATH = '/kaggle/input/dinov2-tiled/pytorch/default/1/models_checkpoints/best_dino_tiled_fold_1.pth'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]

CONFIG = {
    'model_name': 'vit_base_patch14_dinov2.lvd142m', 
    'img_h': 392, 
    'img_w': 784,
    'tile_size': 392,
    'batch_size': 16,     # Multi-GPU Ready
    'lr': 1e-4,
    'epochs': 20,        # Slightly more epochs for Mixup convergence
    'mixup_prob': 0.5,
    'alpha': 1.0,
    'n_splits': 5,
    'device': 'cuda',
    'resume_path': '/kaggle/input/dinov2-tiled/pytorch/default/1/models_checkpoints/best_dino_tiled_fold_1.pth'
}



In [2]:
class GlobalLocalDinoHydra(nn.Module):
    def __init__(self, model_name=CONFIG['model_name'], num_species=15):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0, dynamic_img_size=True)
        embed_dim = self.backbone.num_features
        self.meta_reg = nn.Linear(embed_dim * 2, 2) 
        self.meta_cls = nn.Linear(embed_dim * 2, num_species)
        self.species_emb = nn.Embedding(num_species, 32)
        fusion_dim = (embed_dim * 2) + 2 + 32
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, 1))
            for _ in range(5)
        ])
        
    def forward(self, x_global, x_tiles):
        feat_global = self.backbone(x_global) 
        B, N, C, H, W = x_tiles.shape
        feat_tiles = self.backbone(x_tiles.view(B*N, C, H, W))
        feat_tiles = feat_tiles.view(B, N, -1).mean(dim=1) 
        fused_vis = torch.cat([feat_global, feat_tiles], dim=1)
        p_reg = self.meta_reg(fused_vis)
        p_cls = self.meta_cls(fused_vis)
        # Use argmax for inference even during training (Mixup handles soft labels in loss)
        s_emb = self.species_emb(torch.argmax(p_cls, dim=1))
        f_all = torch.cat([fused_vis, p_reg, s_emb], dim=1) 
        out = torch.cat([h(f_all) for h in self.heads], dim=1)
        return out, p_reg, p_cls

In [3]:
def apply_mixup_tiled(g_imgs, t_imgs, bio_gt, reg_gt, cls_gt, num_species):
    if np.random.rand() > CONFIG['mixup_prob']:
        return g_imgs, t_imgs, bio_gt, reg_gt, torch.nn.functional.one_hot(cls_gt, num_species).float(), 1.0
    
    idx = torch.randperm(g_imgs.size(0)).to(g_imgs.device)
    lam = np.random.beta(CONFIG['alpha'], CONFIG['alpha'])
    
    # Mix both global and tiled streams
    g_imgs = lam * g_imgs + (1 - lam) * g_imgs[idx]
    t_imgs = lam * t_imgs + (1 - lam) * t_imgs[idx]
    
    bio_gt = lam * bio_gt + (1 - lam) * bio_gt[idx]
    reg_gt = lam * reg_gt + (1 - lam) * reg_gt[idx]
    
    cls_onehot = torch.nn.functional.one_hot(cls_gt, num_species).float()
    cls_mixed = lam * cls_onehot + (1 - lam) * cls_onehot[idx]
    
    return g_imgs, t_imgs, bio_gt, reg_gt, cls_mixed, lam

In [4]:
class HighResTiledDataset(Dataset):
    def __init__(self, df, img_dir, tf_global, tf_tile, species_map):
        self.df, self.img_dir, self.tf_global, self.tf_tile, self.species_map = df, img_dir, tf_global, tf_tile, species_map
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(os.path.join(self.img_dir, row['image_path'])).convert('RGB')
        img_np = np.array(img)
        H, W, _ = img_np.shape
        img_global = self.tf_global(image=img_np)['image']
        mid_w = W // 2
        tile_left = img_np[:, :mid_w, :]
        tile_right = img_np[:, mid_w:, :]
        tiles = [self.tf_tile(image=tile_left)['image'], self.tf_tile(image=tile_right)['image']]
        img_tiles = torch.stack(tiles) 
        bio = torch.tensor(row[TARGET_COLUMNS].values.astype(np.float32))
        reg = torch.tensor([row['Pre_GSHH_NDVI'], row['Height_Ave_cm']], dtype=torch.float32)
        spec_idx = torch.tensor(self.species_map[row['Species']], dtype=torch.long)
        return img_global, img_tiles, bio, reg, spec_idx

In [5]:
def competition_metric(y_true, y_pred):
    N = y_true.shape[0]
    w = np.tile(TARGET_WEIGHTS, (N, 1)).flatten()
    y_t, y_p = y_true.flatten(), y_pred.flatten()
    avg = np.sum(w * y_t) / np.sum(w)
    res = np.sum(w * (y_t - y_p)**2)
    tot = np.sum(w * (y_t - avg)**2)
    return 1 - (res/tot) if tot != 0 else 0

def train_mixup_tiled(df_wide, species_map):
    gkf = GroupKFold(n_splits=CONFIG['n_splits'])
    train_idx, val_idx = next(gkf.split(df_wide, groups=df_wide['Sampling_Date']))
    train_df, val_df = df_wide.iloc[train_idx], df_wide.iloc[val_idx]
    
    num_species = len(species_map)
    tf_g = A.Compose([A.Resize(CONFIG['img_h'], CONFIG['img_w']), A.HorizontalFlip(), A.Normalize(), ToTensorV2()])
    tf_t = A.Compose([A.Resize(CONFIG['tile_size'], CONFIG['tile_size']), A.HorizontalFlip(), A.Normalize(), ToTensorV2()])
    
    loader_t = DataLoader(HighResTiledDataset(train_df, DATA_DIR, tf_g, tf_t, species_map), batch_size=CONFIG['batch_size'], shuffle=True, num_workers=4, drop_last=True)
    loader_v = DataLoader(HighResTiledDataset(val_df, DATA_DIR, tf_g, tf_t, species_map), batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4)
    
    model = GlobalLocalDinoHydra(num_species=num_species).to(CONFIG['device'])
    # LOAD PREVIOUS WEIGHTS
    if os.path.exists(CONFIG['resume_path']):
        print(f"Resuming from: {CONFIG['resume_path']}")
        sd = torch.load(CONFIG['resume_path'], map_location=CONFIG['device'])
        model.load_state_dict({k.replace('module.', ''): v for k, v in sd.items()})
        
    if torch.cuda.device_count() > 1:
        print(f'Detected {torch.cuda.device_count()} GPUs. Using DataParallel.')
        model = nn.DataParallel(model)
    
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'])
    
    crit_bio = nn.HuberLoss()
    crit_reg = nn.MSELoss()
    crit_cls = nn.BCEWithLogitsLoss() # Use BCE for Mixed labels
    scaler = torch.amp.GradScaler('cuda')
    
    best_r2 = -float('inf')
    for epoch in range(CONFIG['epochs']):
        model.train(); loss_acc = 0
        for g, t, b, r, c in loader_t:
            g, t, b, r, c = g.to(CONFIG['device']), t.to(CONFIG['device']), b.to(CONFIG['device']), r.to(CONFIG['device']), c.to(CONFIG['device'])
            
            # Apply Mixup to BOTH streams simultaneously
            g_m, t_m, b_m, r_m, c_m, _ = apply_mixup_tiled(g, t, b, r, c, num_species)
            
            optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
                pb, pr, pc = model(g_m, t_m)
                loss = crit_bio(pb, b_m) + 0.1*crit_reg(pr, r_m) + 0.2*crit_cls(pc, c_m)
            
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); scaler.step(optimizer); scaler.update(); loss_acc += loss.item()
        
        model.eval(); all_p, all_t = [] , []
        with torch.no_grad():
            for g, t, b, _, _ in loader_v:
                p, _, _ = model(g.to(CONFIG['device']), t.to(CONFIG['device']))
                all_p.append(p.cpu().numpy()); all_t.append(b.numpy())
        
        r2 = competition_metric(np.vstack(all_t), np.vstack(all_p))
        print(f'Epoch {epoch+1} | R2: {r2:.4f} | Loss: {loss_acc/len(loader_t):.4f}')
        if r2 > best_r2: 
            best_r2 = r2
            save_sd = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(save_sd, f'{CHECKPOINT_DIR}/best_dino_mixup_tiled.pth')
        scheduler.step()
    print(f'Final Best R2: {best_r2:.4f}')

In [6]:
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
df_wide = df.pivot_table(index=['image_path', 'Sampling_Date', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm'], columns='target_name', values='target').reset_index()
species_map = {s: i for i, s in enumerate(sorted(df_wide['Species'].unique()))}
print('Starting Mixup + Tiling + Multi-GPU Training...')
train_mixup_tiled(df_wide, species_map)

Starting Mixup + Tiling + Multi-GPU Training...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Resuming from: /kaggle/input/dinov2-tiled/pytorch/default/1/models_checkpoints/best_dino_tiled_fold_1.pth
Detected 2 GPUs. Using DataParallel.
Epoch 1 | R2: 0.6261 | Loss: 14.0349
Epoch 2 | R2: 0.4035 | Loss: 12.7722
Epoch 3 | R2: 0.3498 | Loss: 15.8199
Epoch 4 | R2: 0.3494 | Loss: 13.9585
Epoch 5 | R2: 0.4627 | Loss: 11.7235
Epoch 6 | R2: 0.5446 | Loss: 10.7164
Epoch 7 | R2: 0.5318 | Loss: 10.3651
Epoch 8 | R2: 0.3576 | Loss: 9.9597
Epoch 9 | R2: 0.5305 | Loss: 9.5868
Epoch 10 | R2: 0.6136 | Loss: 8.2950
Epoch 11 | R2: 0.5992 | Loss: 8.0576
Epoch 12 | R2: 0.6982 | Loss: 7.4100
Epoch 13 | R2: 0.6633 | Loss: 6.3249
Epoch 14 | R2: 0.6951 | Loss: 6.9932
Epoch 15 | R2: 0.6719 | Loss: 6.3986
Epoch 16 | R2: 0.7350 | Loss: 5.4784
Epoch 17 | R2: 0.7348 | Loss: 5.6322
Epoch 18 | R2: 0.7159 | Loss: 6.0133
Epoch 19 | R2: 0.7232 | Loss: 4.6080
Epoch 20 | R2: 0.7258 | Loss: 5.4030
Final Best R2: 0.7350
