# CSIRO DINOv2 Tiling + Mixup Strategy (Multi-GPU Training)

**Architecture:**
1. **Global Stream**: Full image at `392x784` for environmental context.
2. **Local Stream**: Two high-res `1000x1000` crops (original scale) resized to `392x392`.
3. **Regularization**: Mixup applied synchronously to both streams.
4. **Speed**: Optimized for Dual T4 GPUs (`nn.DataParallel`).

In [None]:
!pip install -q -U albumentations timm opencv-python-headless

import os, sys, functools, json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import GroupKFold

# Rebind print so every call flushes stdout immediately.
print = functools.partial(print, flush=True)

# Dataset root (read-only input) and the checkpoint folder, which is created
# as soon as this cell runs.
DATA_DIR = '/kaggle/input/csiro-biomass'
CHECKPOINT_DIR = './models_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Regression targets, kept in STRICT ALPHABETICAL ORDER.
TARGET_COLUMNS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'Dry_Total_g',
    'GDM_g',
]
# Presumably one weight per entry of TARGET_COLUMNS, in the same order.
# NOTE(review): nothing in the visible code reads TARGET_WEIGHTS -- confirm.
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.5, 0.2]

# Experiment hyper-parameters shared by the cells below.
# NOTE(review): train_fold passes lr=5e-5 to AdamW directly rather than
# reading CONFIG['lr'] -- confirm which value is intended.
CONFIG = dict(
    model_name='vit_base_patch14_dinov2.lvd142m',
    img_h=392,
    img_w=784,
    tile_size=392,
    batch_size=16,
    lr=1e-4,
    epochs=25,
    mixup_prob=0.5,
    n_splits=5,
    resume_path=None,
    device='cuda',
)

In [None]:
class GlobalLocalDinoHydra(nn.Module):
    """Shared-backbone model fusing a full-image view with averaged tile features.

    A single timm backbone (created with ``num_classes=0``) embeds both the
    global image and every tile. Per sample, the tile embeddings are averaged
    and concatenated with the global embedding. Two auxiliary heads read that
    vector: ``meta_reg`` (2 outputs) and ``meta_cls`` (``num_species`` logits).
    The arg-max species is looked up in a 32-d embedding, and the
    concatenation ``[visual, meta_reg output, species embedding]`` feeds
    ``num_targets`` independent MLP heads with one scalar output each.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        num_species: Number of species classes (size of ``meta_cls`` output
            and of the species embedding table).
        num_targets: Number of scalar regression heads. Defaults to 5, the
            value that was previously hard-coded.
    """

    def __init__(self, model_name=CONFIG['model_name'], num_species=15, num_targets=5):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=True, num_classes=0, dynamic_img_size=True
        )
        embed_dim = self.backbone.num_features
        vis_dim = embed_dim * 2  # global embedding + averaged tile embedding
        self.meta_reg = nn.Linear(vis_dim, 2)
        self.meta_cls = nn.Linear(vis_dim, num_species)
        self.species_emb = nn.Embedding(num_species, 32)
        fusion_dim = vis_dim + 2 + 32
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 512), nn.GELU(), nn.Linear(512, 1))
            for _ in range(num_targets)
        ])

    def forward(self, x_global, x_tiles):
        """Run both streams and return target, auxiliary and species outputs.

        Args:
            x_global: Batch of full images fed straight to the backbone.
            x_tiles: 5-D tensor unpacked as ``(B, N, C, H, W)``: ``N`` tiles
                per sample.

        Returns:
            Tuple ``(targets, pr, pc)``: the head outputs concatenated along
            dim 1, the ``meta_reg`` output, and the ``meta_cls`` logits.
        """
        f_g = self.backbone(x_global)
        B, N, C, H, W = x_tiles.shape
        # reshape (rather than view) also accepts non-contiguous tile batches;
        # for contiguous input it is the same zero-copy operation.
        tile_feats = self.backbone(x_tiles.reshape(B * N, C, H, W))
        f_t = tile_feats.reshape(B, N, -1).mean(dim=1)
        vis = torch.cat([f_g, f_t], dim=1)
        pr, pc = self.meta_reg(vis), self.meta_cls(vis)
        # Hard arg-max lookup: no gradient reaches meta_cls through this path.
        se = self.species_emb(torch.argmax(pc, dim=1))
        f = torch.cat([vis, pr, se], dim=1)
        return torch.cat([h(f) for h in self.heads], dim=1), pr, pc

In [None]:
class TiledDataset(Dataset):
    """Dataset yielding a global view, two half-image tiles and the labels for one row.

    Args:
        df: DataFrame with one sample per row; must provide ``image_path``,
            the ``TARGET_COLUMNS``, ``Pre_GSHH_NDVI``, ``Height_Ave_cm`` and
            ``Species``.
        tf_g: Callable applied to the full image as ``tf_g(image=...)['image']``.
        tf_t: Callable applied to each half image in the same way.
        sm: Mapping from species name to integer class index.
    """

    def __init__(self, df, tf_g, tf_t, sm):
        self.df, self.tf_g, self.tf_t, self.sm = df, tf_g, tf_t, sm

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(global, tiles, targets, aux_regression, species_index)`` for row ``idx``."""
        row = self.df.iloc[idx]
        # The context manager closes the underlying file as soon as the pixels
        # have been copied into the array, instead of leaving it to the GC.
        with Image.open(os.path.join(DATA_DIR, row['image_path'])) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        mid = img.shape[1] // 2  # split along the width axis
        g = self.tf_g(image=img)['image']
        # Left and right halves; tf_t is invoked separately for each half.
        t = torch.stack([
            self.tf_t(image=img[:, :mid])['image'],
            self.tf_t(image=img[:, mid:])['image'],
        ])
        b = torch.tensor(row[TARGET_COLUMNS].values.astype(np.float32))
        r = torch.tensor([row['Pre_GSHH_NDVI'], row['Height_Ave_cm']], dtype=torch.float32)
        c = torch.tensor(self.sm[row['Species']], dtype=torch.long)
        return g, t, b, r, c

In [None]:
def apply_mixup(g, t, b, r, c, ns):
    """Optionally blend every sample of the batch with a randomly chosen partner.

    A uniform draw is compared against ``CONFIG['mixup_prob']``; when the draw
    is larger the batch is returned untouched. Otherwise a permutation of the
    batch and a coefficient drawn from Beta(1.0, 1.0) are used to blend all
    five inputs with the same pairing and the same coefficient.

    Args:
        g, t, b, r: Tensors indexed by sample along dim 0.
        c: Integer class indices, converted to one-hot with ``ns`` classes.
        ns: Number of classes for the one-hot encoding.

    Returns:
        ``(g, t, b, r, soft_labels, lam)`` where ``soft_labels`` is the float
        one-hot (or blended one-hot) encoding of ``c`` and ``lam`` is ``1.0``
        when no mixing happened, else the sampled coefficient.
    """
    draw = np.random.rand()
    soft_labels = nn.functional.one_hot(c, ns).float()
    if draw > CONFIG['mixup_prob']:
        return g, t, b, r, soft_labels, 1.0

    perm = torch.randperm(g.size(0)).to(g.device)
    lam = np.random.beta(1.0, 1.0)

    def blend(x):
        # Convex combination of each sample with its permuted partner.
        return lam * x + (1 - lam) * x[perm]

    return blend(g), blend(t), blend(b), blend(r), blend(soft_labels), lam

def _make_transforms(train):
    """Return the (global, tile) albumentations pipelines, with random flips only when ``train``."""
    def build(height, width):
        steps = [A.Resize(height, width)]
        if train:
            steps.append(A.HorizontalFlip())
        steps += [A.Normalize(), ToTensorV2()]
        return A.Compose(steps)

    return (
        build(CONFIG['img_h'], CONFIG['img_w']),
        build(CONFIG['tile_size'], CONFIG['tile_size']),
    )


def _pooled_r2(y_true, y_pred):
    """R^2 pooled over all entries of the arrays (single global mean); 0.0 if the variance is not positive."""
    denom = np.sum((y_true - y_true.mean()) ** 2)
    if not denom > 0:
        return 0.0
    return 1 - np.sum((y_true - y_pred) ** 2) / denom


def train_fold(fold, t_df, v_df, sm):
    """Train and validate one cross-validation fold, checkpointing the best epoch.

    Each epoch runs mixed-precision training with mixup and gradient-norm
    clipping, then evaluates on the validation split and saves the weights to
    ``dino_fold{fold}.pth`` whenever the pooled R2 improves.

    Args:
        fold: Fold index, used in log lines and in the checkpoint file name.
        t_df: Training rows for ``TiledDataset``.
        v_df: Validation rows for ``TiledDataset``.
        sm: Mapping from species name to class index; its length sets the
            number of species classes.

    Returns:
        The best validation R2 seen across epochs (``-inf`` if no epoch ever
        produced a comparable score).

    Raises:
        ValueError: If the training split yields no full batch.
    """
    ns = len(sm)
    tf_g_train, tf_t_train = _make_transforms(train=True)
    # Validation must be deterministic: no random flip, otherwise the R2 used
    # for checkpoint selection changes from run to run on identical weights.
    tf_g_eval, tf_t_eval = _make_transforms(train=False)
    ld_t = DataLoader(
        TiledDataset(t_df, tf_g_train, tf_t_train, sm),
        batch_size=CONFIG['batch_size'], shuffle=True, num_workers=4, drop_last=True,
    )
    ld_v = DataLoader(
        TiledDataset(v_df, tf_g_eval, tf_t_eval, sm),
        batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4,
    )
    n_batches = len(ld_t)
    if n_batches == 0:
        # drop_last=True discards a partial batch, so a tiny split trains on nothing.
        raise ValueError(
            f"Fold {fold}: training split has fewer than {CONFIG['batch_size']} samples"
        )

    m = GlobalLocalDinoHydra(num_species=ns).cuda()
    if CONFIG['resume_path']:
        # weights_only=True restricts unpickling to tensors/containers.
        sd = torch.load(CONFIG['resume_path'], map_location='cuda', weights_only=True)
        prefix = 'module.'
        # Strip only a leading DataParallel prefix, never an inner occurrence.
        m.load_state_dict({
            (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()
        })
    if torch.cuda.device_count() > 1:
        m = nn.DataParallel(m)

    # STABILITY: Lower LR + Weight Decay
    # NOTE(review): CONFIG['lr'] is not consulted here -- confirm that is intended.
    opt = optim.AdamW(m.parameters(), lr=5e-5, weight_decay=0.01)
    scaler = torch.amp.GradScaler('cuda')  # Instantiate OUTSIDE loop
    huber, mse, bce = nn.HuberLoss(), nn.MSELoss(), nn.BCEWithLogitsLoss()

    best_r2 = -float('inf')
    for e in range(CONFIG['epochs']):
        m.train()
        loss_acc = 0.0
        for g, t, b, r, c in ld_t:
            g, t, b, r, c = g.cuda(), t.cuda(), b.cuda(), r.cuda(), c.cuda()
            gm, tm, bm, rm, cm, _ = apply_mixup(g, t, b, r, c, ns)
            opt.zero_grad()
            with torch.amp.autocast('cuda'):
                pb, pr, pc = m(gm, tm)
                # Targets (Huber) + auxiliary regression (MSE) + species soft labels (BCE).
                loss = huber(pb, bm) + 0.1 * mse(pr, rm) + 0.2 * bce(pc, cm)

            # STABILITY: Correct AMP scaling -- unscale before clipping so the
            # norm threshold applies to the true gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(m.parameters(), 1.0)  # Prevents NaN
            scaler.step(opt)
            scaler.update()
            loss_acc += loss.item()

        m.eval()
        ap, at = [], []
        with torch.no_grad():
            for g, t, b, _, _ in ld_v:
                pb, _, _ = m(g.cuda(), t.cuda())
                ap.append(pb.cpu().numpy())
                at.append(b.numpy())

        r2 = _pooled_r2(np.vstack(at), np.vstack(ap))
        print(f'Fold {fold} | Ep {e+1} | R2: {r2:.4f} | Loss: {loss_acc/n_batches:.4f}')
        if r2 > best_r2:
            best_r2 = r2
            # Save the unwrapped module so keys carry no DataParallel prefix.
            # NOTE(review): written to the working directory, not CHECKPOINT_DIR -- confirm.
            save_sd = m.module.state_dict() if hasattr(m, "module") else m.state_dict()
            torch.save(save_sd, f"dino_fold{fold}.pth")
    return best_r2

In [None]:
# Load the training table from the shared dataset root. Presumably it is in
# long format (one row per image/target pair), given the pivot below.
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
# Pivot to one row per image, with one column per value of `target_name`.
w = df.pivot_table(
    index=['image_path', 'Sampling_Date', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm'],
    columns='target_name',
    values='target',
).reset_index()
# Species name -> class index, assigned in sorted order so it is reproducible.
sm = {s: i for i, s in enumerate(sorted(w['Species'].unique()))}
# Group by sampling date so rows sharing a date never straddle train/validation.
gkf = GroupKFold(n_splits=CONFIG['n_splits'])
for f, (ti, vi) in enumerate(gkf.split(w, groups=w['Sampling_Date'])):
    if f > 1:
        break  # Train 2 folds initially
    train_fold(f, w.iloc[ti], w.iloc[vi], sm)