# In [None]:
# ===============================================================
# CSIRO Image2Biomass – Training (Weighted R² Validation)
# ===============================================================
import os, gc, cv2, numpy as np, pandas as pd
from tqdm import tqdm
import torch, torch.nn as nn, torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.model_selection import KFold

# ---------------------------------------------------------------
# 1. CONFIG (memory-safe + R² metric)
# ---------------------------------------------------------------
class CFG:
    """Static settings for the training run: paths, backbone, optimisation, targets."""

    # --- Data locations and cross-validation ------------------------------
    BASE_PATH = '/kaggle/input/csiro-biomass'
    TRAIN_CSV = os.path.join(BASE_PATH, 'train.csv')
    TRAIN_IMAGE_DIR = os.path.join(BASE_PATH, 'train')
    MODEL_DIR = '/kaggle/working/'      # per-fold checkpoints are written here
    N_FOLDS = 5

    # --- Backbone ----------------------------------------------------------
    MODEL_NAME = 'convnext_tiny'        # timm model name; keep in sync with inference
    IMG_SIZE = 512                      # each image half is resized to IMG_SIZE x IMG_SIZE
    PRETRAINED = True

    # --- Optimisation ------------------------------------------------------
    BATCH_SIZE = 2
    GRAD_ACC = 4                        # accumulation steps -> effective batch = 2 * 4 = 8
    NUM_WORKERS = 1
    EPOCHS = 25
    LR = 1e-4
    WD = 1e-2                           # AdamW weight decay
    PATIENCE = 5                        # epochs without a better val R² before early stop

    # --- Targets -----------------------------------------------------------
    # Columns named after the three quantities the model heads predict directly.
    TARGET_COLS = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g']
    # Columns reconstructed from differences of the predicted ones.
    DERIVED_COLS = ['Dry_Clover_g', 'Dry_Dead_g']
    # Column order of every label / prediction matrix used in this file.
    ALL_TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
    # Per-target metric weights, positionally aligned with ALL_TARGET_COLS.
    R2_WEIGHTS = np.array([0.1, 0.1, 0.1, 0.2, 0.5])

    # --- Hardware ----------------------------------------------------------
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Startup banner: report the compute device and backbone settings in use.
print(
    f"Device : {CFG.DEVICE}",
    f"Backbone: {CFG.MODEL_NAME} | Size: {CFG.IMG_SIZE}",
    sep="\n",
)

# ---------------------------------------------------------------
# 2. WEIGHTED R² METRIC (weighted mean of per-target R²)
# ---------------------------------------------------------------
def weighted_r2_score(y_true: np.ndarray, y_pred: np.ndarray, weights=None):
    """Return the weighted mean of per-target R² scores.

    Args:
        y_true: Ground-truth values, shape (N, T), one column per target.
        y_pred: Predictions with the same shape and column order as ``y_true``.
        weights: Optional length-T sequence of per-target weights. Defaults to
            ``CFG.R2_WEIGHTS`` (``[0.1, 0.1, 0.1, 0.2, 0.5]``), in which case
            T must be 5.

    Returns:
        Tuple ``(weighted_r2, r2_scores)``: the weight-normalised average and
        the length-T float64 array of per-target R² values. A target whose
        ground truth has zero variance contributes an R² of 0.0.

    Raises:
        ValueError: If the inputs are not 2-D arrays of identical shape, or if
            the number of weights differs from the number of target columns.
    """
    # Accumulate in float64 so float32 model outputs do not lose precision
    # when many squared errors are summed.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.ndim != 2 or y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must be 2-D with identical shapes, "
            f"got {y_true.shape} and {y_pred.shape}"
        )

    if weights is None:
        weights = CFG.R2_WEIGHTS
    weights = np.asarray(weights, dtype=np.float64)
    n_targets = y_true.shape[1]
    if weights.shape != (n_targets,):
        raise ValueError(
            f"expected {n_targets} weights (one per target column), "
            f"got shape {weights.shape}"
        )

    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    # Skip the column mean for an empty batch to avoid a "mean of empty slice" warning.
    centred = y_true - y_true.mean(axis=0) if y_true.shape[0] else y_true
    ss_tot = np.sum(centred ** 2, axis=0)

    # R² is undefined for constant targets; those columns are scored as 0.0.
    r2_scores = np.zeros(n_targets, dtype=np.float64)
    defined = ss_tot > 0
    r2_scores[defined] = 1.0 - ss_res[defined] / ss_tot[defined]

    weighted_r2 = np.sum(r2_scores * weights) / np.sum(weights)
    return weighted_r2, r2_scores

# ---------------------------------------------------------------
# 3. AUGMENTATIONS
# ---------------------------------------------------------------
def get_train_transforms():
    """Build the training augmentation pipeline.

    Resizes to ``CFG.IMG_SIZE`` square, applies random flips, a small random
    rotation and mild colour jitter, then normalises with the mean/std values
    below and converts to a tensor.
    """
    side = CFG.IMG_SIZE

    geometric = [
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.Rotate(
            limit=(-10, 10),
            p=0.3,
            interpolation=cv2.INTER_LINEAR,
            border_mode=cv2.BORDER_REFLECT_101,
        ),
    ]
    photometric = [
        A.ColorJitter(brightness=0.1, contrast=0.1, p=0.3),
    ]
    finalise = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    return A.Compose(geometric + photometric + finalise, p=1.0)

def get_val_transforms():
    """Build the deterministic validation pipeline: resize, normalise, to tensor."""
    side = CFG.IMG_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps, p=1.0)

# ---------------------------------------------------------------
# 4. DATASET
# ---------------------------------------------------------------
class BiomassDataset(Dataset):
    """Dataset yielding the left and right halves of an image plus its 5 targets.

    Args:
        df: DataFrame with an ``image_path`` column and every column listed in
            ``CFG.ALL_TARGET_COLS``.
        transform: Callable invoked as ``transform(image=array)`` and returning
            a mapping with an ``'image'`` entry (e.g. an albumentations Compose).
        img_dir: Directory holding the image files; only the basename of
            ``image_path`` is used to locate a file inside it.
    """

    def __init__(self, df, transform, img_dir):
        self.df        = df
        self.transform = transform
        self.img_dir   = img_dir
        self.paths     = df['image_path'].values
        # Labels are stored once, as float32, in CFG.ALL_TARGET_COLS order.
        self.labels    = df[CFG.ALL_TARGET_COLS].values.astype(np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(left, right, label)`` for row ``idx``.

        ``left``/``right`` are the transformed halves of the image split at
        half its width; ``label`` is a float32 tensor of the target values.
        """
        path = os.path.join(self.img_dir, os.path.basename(self.paths[idx]))
        img  = cv2.imread(path)
        if img is None:
            # Best-effort fallback: keep going with a black frame, but make the
            # problem visible — the sample still carries its real label.
            import warnings
            warnings.warn(
                f"Could not read image '{path}'; using a black placeholder",
                RuntimeWarning,
            )
            img = np.zeros((1000, 2000, 3), dtype=np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Split into two halves along the width axis.
        mid   = img.shape[1] // 2
        left  = img[:, :mid]
        right = img[:, mid:]

        # The transform is called separately per half, so any random
        # augmentation is drawn independently for each side.
        left  = self.transform(image=left)['image']
        right = self.transform(image=right)['image']

        label = torch.from_numpy(self.labels[idx])
        return left, right, label

# ---------------------------------------------------------------
# 5. MODEL (safe pretrained loading)
# ---------------------------------------------------------------
class BiomassModel(nn.Module):
    """Two-view regressor: one shared backbone, three single-output heads.

    The same backbone encodes the left and right image halves; the two feature
    vectors are concatenated and fed to separate heads for total, GDM and
    green biomass.

    Args:
        model_name: timm architecture name used for the backbone.
        pretrained: If True, try to load pretrained backbone weights after
            construction (see ``load_pretrained``).
    """

    def __init__(self, model_name, pretrained=True):
        super().__init__()
        # Remembered so load_pretrained() fetches weights for the architecture
        # that was actually built, not whatever CFG happens to name.
        self.model_name = model_name
        self.backbone = timm.create_model(
            model_name, pretrained=False, num_classes=0, global_pool='avg')
        nf = self.backbone.num_features
        comb = nf * 2  # left + right features concatenated

        def head():
            return nn.Sequential(
                nn.Linear(comb, comb // 2),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(comb // 2, 1)
            )
        self.head_total = head()
        self.head_gdm   = head()
        self.head_green = head()

        if pretrained:
            self.load_pretrained()

    def load_pretrained(self):
        """Copy pretrained weights into the backbone, best-effort.

        Any failure (for example no network access to download the weights) is
        reported and the backbone keeps its random initialisation.
        """
        try:
            donor = timm.create_model(self.model_name, pretrained=True, num_classes=0)
            result = self.backbone.load_state_dict(donor.state_dict(), strict=False)
            # strict=False tolerates key mismatches; surface them instead of
            # silently leaving part of the backbone uninitialised.
            if result.missing_keys or result.unexpected_keys:
                print(f"Warning: pretrained key mismatch | "
                      f"missing: {len(result.missing_keys)} | "
                      f"unexpected: {len(result.unexpected_keys)}")
            print("Pretrained weights loaded (CPU)")
        except Exception as e:
            # Deliberately broad: training should proceed even without weights.
            print(f"Warning: Pretrained load failed: {e}")

    def forward(self, left, right):
        """Return ``(total, gdm, green)`` predictions for a batch of image halves."""
        fl = self.backbone(left)
        fr = self.backbone(right)
        x  = torch.cat([fl, fr], dim=1)
        return (self.head_total(x), self.head_gdm(x), self.head_green(x))

# ---------------------------------------------------------------
# 6. LOSS (MSE on all 5)
# ---------------------------------------------------------------
def biomass_loss(p_total, p_gdm, p_green, labels):
    """Mean of five MSE terms: three predicted targets plus two derived ones.

    Args:
        p_total, p_gdm, p_green: Model outputs holding one value per sample
            (shape (B, 1) or (B,)).
        labels: Tensor of shape (B, 5) with columns ordered
            [Dry_Green, Dry_Dead, Dry_Clover, GDM, Dry_Total].

    Returns:
        Scalar tensor: the unweighted average of the five MSE losses.
    """
    # Flatten explicitly. A bare .squeeze() turns a batch of one into a 0-d
    # tensor, which no longer matches the (1,)-shaped label column and makes
    # the loss rely on broadcasting (with a shape-mismatch warning).
    total = p_total.reshape(-1)
    gdm   = p_gdm.reshape(-1)
    green = p_green.reshape(-1)

    # Derived components are differences of the predicted ones, floored at 0.
    clover = torch.clamp(gdm - green, min=0)
    dead   = torch.clamp(total - gdm, min=0)

    mse = nn.functional.mse_loss
    loss_total  = mse(total,  labels[:, 4])
    loss_gdm    = mse(gdm,    labels[:, 3])
    loss_green  = mse(green,  labels[:, 0])
    loss_clover = mse(clover, labels[:, 2])
    loss_dead   = mse(dead,   labels[:, 1])

    return (loss_total + loss_gdm + loss_green + loss_clover + loss_dead) / 5

# ---------------------------------------------------------------
# 7. VALIDATION WITH WEIGHTED R²
# ---------------------------------------------------------------
@torch.no_grad()
def valid_epoch(model, loader, device):
    """Run one validation pass and score it with the weighted R² metric.

    Args:
        model: Network called as ``model(left, right)`` and returning the
            ``(total, gdm, green)`` predictions.
        loader: DataLoader yielding ``(left, right, label)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, weighted_r2, per_target_r2, pred_all, true_labels)``,
        where ``pred_all`` and ``true_labels`` are (N, 5) arrays with columns
        ordered [Dry_Green, Dry_Dead, Dry_Clover, GDM, Dry_Total].
    """
    model.eval()

    loss_sum = 0.0
    total_vals, gdm_vals, green_vals = [], [], []
    label_rows = []

    for left, right, target in tqdm(loader, desc='valid', leave=False):
        left = left.to(device, non_blocking=True)
        right = right.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        out_total, out_gdm, out_green = model(left, right)
        batch_loss = biomass_loss(out_total, out_gdm, out_green, target)
        # Weight by batch size so the final figure is a per-sample mean.
        loss_sum += batch_loss.item() * left.size(0)

        total_vals.extend(out_total.cpu().numpy().ravel())
        gdm_vals.extend(out_gdm.cpu().numpy().ravel())
        green_vals.extend(out_green.cpu().numpy().ravel())
        label_rows.extend(target.cpu().numpy())

    pred_total = np.array(total_vals)
    pred_gdm = np.array(gdm_vals)
    pred_green = np.array(green_vals)
    true_labels = np.stack(label_rows)  # (N, 5)

    # Derived targets: differences of the predicted ones, floored at zero.
    pred_clover = np.clip(pred_gdm - pred_green, 0, None)
    pred_dead = np.clip(pred_total - pred_gdm, 0, None)

    # Column order must mirror the label layout.
    columns = [pred_green, pred_dead, pred_clover, pred_gdm, pred_total]
    pred_all = np.stack(columns, axis=1)

    weighted_r2, per_target_r2 = weighted_r2_score(true_labels, pred_all)

    mean_loss = loss_sum / len(loader.dataset)
    return mean_loss, weighted_r2, per_target_r2, pred_all, true_labels

# ---------------------------------------------------------------
# 8. TRAINING LOOP
# ---------------------------------------------------------------
def train_epoch(model, loader, opt, scheduler, device):
    """Train for one epoch with gradient accumulation and return the mean loss.

    Args:
        model: Network called as ``model(left, right)``.
        loader: DataLoader yielding ``(left, right, label)`` batches.
        opt: Optimizer stepped once every ``CFG.GRAD_ACC`` batches (and after
            the final batch).
        scheduler: LR scheduler, stepped once at the end of the epoch.
        device: Device the batches are moved to.

    Returns:
        Per-sample mean training loss over the samples actually processed.
    """
    model.train()
    running = 0.0
    seen = 0
    n_batches = len(loader)
    opt.zero_grad()
    for step, (l, r, lab) in enumerate(tqdm(loader, desc='train', leave=False), start=1):
        l, r, lab = l.to(device, non_blocking=True), r.to(device, non_blocking=True), lab.to(device, non_blocking=True)
        p_tot, p_gdm, p_green = model(l, r)
        batch_loss = biomass_loss(p_tot, p_gdm, p_green, lab)
        # Scale so the accumulated gradient averages over the micro-batches.
        (batch_loss / CFG.GRAD_ACC).backward()

        batch_size = l.size(0)
        running += batch_loss.item() * batch_size
        seen += batch_size

        if step % CFG.GRAD_ACC == 0 or step == n_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad()

    scheduler.step()
    # Divide by the samples seen, not len(loader.dataset): a loader built with
    # drop_last=True skips the trailing partial batch, which would otherwise
    # bias the reported loss downwards.
    return running / max(seen, 1)

# ---------------------------------------------------------------
# 9. MAIN – 5-FOLD WITH R² TRACKING
# ---------------------------------------------------------------
# Entry point: K-fold training. For each fold a fresh model is trained, the
# checkpoint with the best validation weighted R² is saved, and training stops
# early after CFG.PATIENCE epochs without improvement.
if __name__ == '__main__':
    print("Loading data...")
    df_long = pd.read_csv(CFG.TRAIN_CSV)
    # Long format (one row per image/target pair) -> one row per image.
    df_wide = df_long.pivot(index='image_path', columns='target_name', values='target').reset_index()
    df_wide = df_wide[['image_path'] + CFG.ALL_TARGET_COLS]
    print(f"{len(df_wide)} training images")

    # Pinned host memory only speeds up host-to-GPU copies; asking for it on a
    # CPU-only run is useless and makes DataLoader emit a warning.
    pin_memory = CFG.DEVICE.type == 'cuda'

    kfold = KFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=42)

    for fold, (tr_idx, val_idx) in enumerate(kfold.split(df_wide)):
        print('\n' + '='*70)
        print(f'   FOLD {fold+1}/{CFG.N_FOLDS}   |   {len(tr_idx)} train / {len(val_idx)} val')
        print('='*70)

        torch.cuda.empty_cache()
        gc.collect()

        tr_df  = df_wide.iloc[tr_idx].reset_index(drop=True)
        val_df = df_wide.iloc[val_idx].reset_index(drop=True)

        tr_set  = BiomassDataset(tr_df,  get_train_transforms(), CFG.TRAIN_IMAGE_DIR)
        val_set = BiomassDataset(val_df, get_val_transforms(),   CFG.TRAIN_IMAGE_DIR)

        tr_loader  = DataLoader(tr_set,  batch_size=CFG.BATCH_SIZE, shuffle=True,
                                num_workers=CFG.NUM_WORKERS, pin_memory=pin_memory, drop_last=True)
        val_loader = DataLoader(val_set, batch_size=CFG.BATCH_SIZE, shuffle=False,
                                num_workers=CFG.NUM_WORKERS, pin_memory=pin_memory)

        print("Building model...")
        model = BiomassModel(CFG.MODEL_NAME, pretrained=CFG.PRETRAINED)
        model = model.to(CFG.DEVICE)
        model = nn.DataParallel(model)

        optimizer = optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WD)
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS)

        best_r2 = -np.inf
        patience = 0

        for epoch in range(1, CFG.EPOCHS+1):
            tr_loss = train_epoch(model, tr_loader, optimizer, scheduler, CFG.DEVICE)
            val_loss, val_r2, per_r2, _, _ = valid_epoch(model, val_loader, CFG.DEVICE)

            per_r2_str = " | ".join([f"{CFG.ALL_TARGET_COLS[i][:5]}: {r2:.3f}" for i, r2 in enumerate(per_r2)])

            # Evaluated once so the log marker and the save decision cannot diverge.
            is_best = val_r2 > best_r2

            print(f'Epoch {epoch:02d} | '
                  f'TrainLoss {tr_loss:.5f} | '
                  f'ValLoss {val_loss:.5f} | '
                  f'ValR² {val_r2:.4f} {"(BEST)" if is_best else ""}')
            print(f'     → {per_r2_str}')

            if is_best:
                best_r2 = val_r2
                # Note: the file name uses the 0-based fold index.
                save_path = os.path.join(CFG.MODEL_DIR, f'best_model_fold{fold}.pth')
                # Save the unwrapped module so the checkpoint has no 'module.' key prefix.
                torch.save(model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), save_path)
                print(f'   → SAVED (R²: {best_r2:.4f})')
                patience = 0
            else:
                patience += 1
                if patience >= CFG.PATIENCE:
                    print(f'   → EARLY STOP (no improvement in {CFG.PATIENCE} epochs)')
                    break

        # Cleanup: drop every per-fold object before the next fold allocates its own.
        del model, tr_loader, val_loader, tr_set, val_set, optimizer, scheduler
        torch.cuda.empty_cache()
        gc.collect()

    print('\nTraining complete! Best models saved in:', CFG.MODEL_DIR)
    print('Use these in inference with:')
    print('   MODEL_NAME = "convnext_tiny"')
    print('   IMG_SIZE = 512')