# In [None]:  (notebook cell marker, commented out so the file parses as Python)
# ============================================================
# CSIRO Image2Biomass Prediction - Corrected & Optimized
# ============================================================
# Improvements:
# - fixed backbone feature dimension detection (use backbone.num_features)
# - corrected WeightedMSELoss to weight per-target MSE correctly
# - safer image path handling
# - reduced default img_size and batch_size for Kaggle P100
# - consistent target scaling and inverse transform for R² evaluation
# ============================================================

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# CONFIGURATION
# ============================================================
class CFG:
    """Static configuration namespace read by the rest of the script via ``CFG.<name>``."""
    # Paths
    train_csv = '/kaggle/input/csiro-biomass/train.csv'
    # test_csv / test_dir are declared but not referenced by any code visible in this file.
    test_csv = '/kaggle/input/csiro-biomass/test.csv'
    train_dir = '/kaggle/input/csiro-biomass/train'
    test_dir = '/kaggle/input/csiro-biomass/test'
    # Directory where per-fold checkpoints are written by train_kfold.
    output_dir = '/kaggle/working'
    
    # Model
    model_name = 'tf_efficientnetv2_m'  # EfficientNetV2-M
    img_size = 512        # lowered from 800 to be safer on P100
    pretrained = True     # forwarded to timm.create_model
    
    # Training
    n_folds = 5
    seed = 42
    epochs = 30
    batch_size = 8        # lowered from 16 to reduce OOM risk
    num_workers = 4
    lr = 3e-4
    weight_decay = 1e-5
    warmup_epochs = 2     # number of initial epochs handled by the warm-up schedule
    
    # Augmentation / TTA
    # NOTE(review): use_tta / tta_steps are not read by any code visible in this file.
    use_tta = True
    tta_steps = 5
    
    # Targets
    # target_weights pairs positionally with targets; it is used both by the
    # weighted MSE loss and by the weighted R² metric.
    targets = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
    target_weights = [0.1, 0.1, 0.1, 0.2, 0.5]
    
    # When True, train_kfold standardises targets with a scaler fitted on the training fold.
    use_target_scaling = True
    
    # Evaluated once, when the class body runs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# reproducibility
def set_seed(seed=42):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and ``random`` with *seed*.

    Also switches cuDNN to deterministic mode and disables its autotuner.
    """
    import random

    # Same seeding order as before: NumPy, torch CPU, torch CUDA, stdlib random.
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, as soon as the module body runs, using the configured seed.
set_seed(CFG.seed)

# ============================================================
# DATA PREPROCESSING
# ============================================================
def prepare_data(train_csv_path):
    """Load the training CSV and return one row per image with one column per target.

    Steps:
      1. Derive ``image_id`` from ``sample_id`` (text before the first ``__``) or,
         when ``sample_id`` is absent, from the basename of ``image_path``.
      2. If the CSV is in long format (``target_name`` / ``target`` columns),
         pivot it to wide format on ``image_id`` and re-attach the metadata
         columns taken from the first row seen for each image.
      3. Guarantee every name in ``CFG.targets`` exists as a column (missing
         columns and NaN values become 0.0).
      4. Add an integer ``biomass_bin`` column (deciles of ``Dry_Total_g``)
         for stratified fold assignment.

    Args:
        train_csv_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with ``image_id``, any available metadata columns,
        the target columns and ``biomass_bin``.
    """
    df = pd.read_csv(train_csv_path)

    # create image_id (if sample_id contains __)
    if 'sample_id' in df.columns:
        # str.split returns the whole string when the separator is absent, so
        # ids without "__" pass through unchanged.
        df['image_id'] = df['sample_id'].apply(lambda x: str(x).split('__')[0])
    else:
        df['image_id'] = df['image_path'].apply(lambda x: os.path.basename(x).split('.')[0])

    # metadata columns commonly present; guard if missing
    metadata_cols = [
        c
        for c in ['image_path', 'Sampling_Date', 'State', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm']
        if c in df.columns
    ]

    # pivot long -> wide if necessary
    if 'target_name' in df.columns and 'target' in df.columns:
        # Pivot on image_id only. Using the metadata columns as additional
        # index keys makes the underlying groupby discard every image that has
        # a NaN in any of them, silently shrinking the training set.
        targets_wide = df.pivot_table(
            index='image_id',
            columns='target_name',
            values='target',
            aggfunc='first',
        )
        # One metadata row per image (first occurrence), aligned to the pivot's
        # image order so the output stays sorted by image_id.
        meta = df.drop_duplicates(subset='image_id').set_index('image_id')[metadata_cols]
        df_pivot = meta.loc[targets_wide.index].join(targets_wide).reset_index()
        df_pivot.columns.name = None
    else:
        df_pivot = df.copy()
        if 'image_path' not in df_pivot.columns:
            # try to reconstruct image_path from image_id
            df_pivot['image_path'] = df_pivot['image_id'].astype(str) + '.jpg'

    # ensure targets exist
    for t in CFG.targets:
        if t not in df_pivot.columns:
            df_pivot[t] = 0.0
        else:
            df_pivot[t] = df_pivot[t].fillna(0.0)

    # biomass bin for stratification; fall back to equal-width bins when
    # quantile binning is not possible for this distribution
    try:
        df_pivot['biomass_bin'] = pd.qcut(df_pivot['Dry_Total_g'], q=10, labels=False, duplicates='drop')
    except Exception:
        df_pivot['biomass_bin'] = pd.cut(df_pivot['Dry_Total_g'], bins=10, labels=False)
    df_pivot['biomass_bin'] = df_pivot['biomass_bin'].fillna(0).astype(int)

    print(f"Prepared {len(df_pivot)} unique images")
    return df_pivot

# ============================================================
# DATASET
# ============================================================
class BiomassDataset(Dataset):
    """Dataset yielding an image tensor, two tabular features and (for training) targets.

    Args:
        df: DataFrame with one row per image. ``image_path`` or ``image_id``
            identifies the file; ``Pre_GSHH_NDVI`` and ``Height_Ave_cm`` are
            used as tabular features when both are present (zeros otherwise).
        img_dir: Directory that image basenames are resolved against.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result exposes the transformed image under the ``'image'`` key.
        is_test: When True, ``__getitem__`` returns no targets.
        scaler: Optional pre-fitted scaler for the tabular features. If omitted,
            a new ``StandardScaler`` is fitted in training mode, and the raw
            features are used unscaled in test mode.

    Attributes:
        scaler: The scaler applied to the tabular features, or ``None`` when
            the features were left unscaled.
        tabular_features: Array of shape ``(len(df), 2)``.
    """

    # Order of the target vector returned by __getitem__; must stay aligned
    # with CFG.targets / CFG.target_weights.
    TARGET_COLUMNS = ('Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g')

    def __init__(self, df, img_dir, transform=None, is_test=False, scaler=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

        # prepare tabular features
        if 'Pre_GSHH_NDVI' in df.columns and 'Height_Ave_cm' in df.columns:
            tabular_data = df[['Pre_GSHH_NDVI', 'Height_Ave_cm']].fillna(0).values
        else:
            tabular_data = np.zeros((len(df), 2), dtype=np.float32)

        if scaler is not None:
            # Re-use statistics fitted elsewhere (validation / inference).
            self.scaler = scaler
            self.tabular_features = self.scaler.transform(tabular_data)
        elif not is_test:
            self.scaler = StandardScaler()
            self.tabular_features = self.scaler.fit_transform(tabular_data)
        else:
            # Always define the attribute so callers can read ds.scaler
            # without hitting AttributeError on an unscaled test dataset.
            self.scaler = None
            self.tabular_features = tabular_data.astype(np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, tabular)`` in test mode, else ``(image, tabular, targets)``.

        Raises:
            ValueError: The row has neither a usable ``image_path`` nor ``image_id``.
            FileNotFoundError: The image could not be read from any candidate path.
        """
        row = self.df.iloc[idx]

        # build image path robustly
        # allow either 'image_path' containing full path or just filename
        if 'image_path' in row and isinstance(row['image_path'], str) and row['image_path'].strip() != '':
            fname = os.path.basename(row['image_path'])
        elif 'image_id' in row:
            fname = str(row['image_id']) + '.jpg'
        else:
            raise ValueError("No valid image identifier for row idx {}".format(idx))

        img_path = os.path.join(self.img_dir, fname)
        image = cv2.imread(img_path)
        if image is None:
            # try alternate path directly from image_path if absolute
            alt = row.get('image_path', None)
            if isinstance(alt, str) and os.path.exists(alt):
                image = cv2.imread(alt)
            # cv2.imread signals failure by returning None, so re-check after
            # the fallback instead of letting cvtColor fail with an opaque error.
            if image is None:
                raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        tabular = torch.tensor(self.tabular_features[idx], dtype=torch.float32)

        if self.is_test:
            return image, tabular

        targets = torch.tensor(
            [row.get(col, 0.0) for col in self.TARGET_COLUMNS],
            dtype=torch.float32,
        )
        return image, tabular, targets

# ============================================================
# AUGMENTATIONS
# ============================================================
def get_train_transforms():
    """Build the training-time augmentation pipeline.

    The pipeline resizes to ``CFG.img_size`` square, applies random geometric,
    colour, noise/blur and dropout augmentations, then normalises with the
    mean/std constants below and converts the image to a tensor.
    """
    side = CFG.img_size
    channel_mean = [0.485,0.456,0.406]
    channel_std = [0.229,0.224,0.225]

    # Exactly one colour perturbation is chosen when this group fires.
    colour_group = A.OneOf([
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
    ], p=0.7)

    # Exactly one degradation (noise or blur) is chosen when this group fires.
    degrade_group = A.OneOf([
        A.GaussNoise(var_limit=(5.0, 30.0), p=1),
        A.GaussianBlur(blur_limit=(3, 7), p=1),
    ], p=0.3)

    steps = [
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.08, scale_limit=0.12, rotate_limit=15, p=0.5),
        colour_group,
        degrade_group,
        A.CoarseDropout(max_holes=6, max_height=32, max_width=32, p=0.25),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_valid_transforms():
    """Build the deterministic validation pipeline: resize, normalise, to-tensor."""
    side = CFG.img_size
    resize = A.Resize(side, side)
    normalize = A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    return A.Compose([resize, normalize, ToTensorV2()])

# ============================================================
# MODEL
# ============================================================
class BiomassModel(nn.Module):
    """Image backbone fused with a small tabular MLP, regressing ``len(CFG.targets)`` values.

    Args:
        model_name: timm model identifier for the image backbone.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, pretrained=True):
        super().__init__()
        # num_classes=0 with average pooling makes the backbone return pooled features.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.img_features = self._infer_feature_dim(self.backbone)

        # tabular encoder: 2 -> 64 -> 128
        self.tabular_encoder = nn.Sequential(
            *self._dense_block(2, 64, 0.3),
            *self._dense_block(64, 128, 0.3),
        )

        # fusion of concatenated image + tabular features: (img + 128) -> 200 -> 200
        self.fusion = nn.Sequential(
            *self._dense_block(self.img_features + 128, 200, 0.4),
            *self._dense_block(200, 200, 0.3),
        )

        self.head = nn.Linear(200, len(CFG.targets))

    @staticmethod
    def _dense_block(in_dim, out_dim, drop_p):
        """Return the layers of one Linear -> BatchNorm1d -> ReLU -> Dropout stage."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Dropout(drop_p),
        ]

    @staticmethod
    def _infer_feature_dim(backbone):
        """Return the backbone's output width.

        Prefers the ``num_features`` attribute; otherwise probes the backbone
        with a dummy CPU batch, and falls back to 1280 if the probe fails.
        """
        width = getattr(backbone, 'num_features', None)
        if width is not None:
            return width
        # fallback: forward a small tensor (on CPU) to inspect shape (rare)
        backbone.eval()
        with torch.no_grad():
            try:
                probe = torch.randn(1,3,CFG.img_size,CFG.img_size)
                return backbone(probe).shape[1]
            except Exception:
                return 1280  # reasonable default for many backbones

    def forward(self, image, tabular):
        """Return predictions of shape ``[B, len(CFG.targets)]``."""
        # List elements evaluate left to right: backbone first, then tabular encoder.
        joined = torch.cat(
            [self.backbone(image), self.tabular_encoder(tabular)],
            dim=1,
        )
        return self.head(self.fusion(joined))

# ============================================================
# LOSS & METRIC
# ============================================================
class WeightedMSELoss(nn.Module):
    """Mean-squared error averaged per target, then combined with fixed weights.

    Args:
        weights: One weight per target column. Stored as a float32 buffer so it
            follows the module across devices.
    """

    def __init__(self, weights):
        super().__init__()
        weight_tensor = torch.tensor(weights, dtype=torch.float32)
        self.register_buffer('weights', weight_tensor)

    def forward(self, preds, targets):
        """Return the scalar weighted loss for ``preds`` / ``targets`` of shape ``[B, T]``."""
        squared_error = (preds - targets).pow(2)
        per_target = squared_error.mean(dim=0)          # [T], averaged over the batch
        numerator = (per_target * self.weights).sum()
        # Normalising by the weight total keeps the loss scale independent of
        # how the weights are scaled.
        return numerator / self.weights.sum()

def calculate_r2_score(y_true, y_pred):
    """Return the coefficient of determination of *y_pred* against *y_true*.

    Returns 0.0 when *y_true* has zero variance, where R² is undefined.
    """
    residual = y_true - y_pred
    centered = y_true - y_true.mean()
    total_ss = np.sum(centered ** 2)
    if total_ss == 0:
        return 0.0
    residual_ss = np.sum(residual ** 2)
    return 1 - (residual_ss / total_ss)

def calculate_weighted_r2(y_true, y_pred, weights):
    """Return ``(weighted_sum, per_target_scores)`` of column-wise R².

    ``y_true`` and ``y_pred`` are 2-D arrays indexed ``[sample, target]``.
    Each per-target score is multiplied by the weight at the same position and
    the products are summed; the sum is not divided by the weight total.
    """
    n_targets = y_true.shape[1]
    scores = [
        calculate_r2_score(y_true[:, col], y_pred[:, col])
        for col in range(n_targets)
    ]

    weighted = 0
    for score, weight in zip(scores, weights):
        weighted = weighted + score * weight
    return weighted, scores

# ============================================================
# TRAIN / VALID
# ============================================================
def train_epoch(model, loader, optimizer, criterion, device, scaler):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Args:
        model: Module called as ``model(images, tabular)``.
        loader: Iterable yielding ``(images, tabular, targets)`` batches.
        optimizer: Optimiser stepped once per batch through *scaler*.
        criterion: Callable ``criterion(preds, targets)`` returning a scalar loss.
        device: Device the batch tensors are moved to.
        scaler: Gradient scaler exposing ``scale`` / ``step`` / ``update``.

    Returns:
        Average of the per-batch loss values (0.0 for an empty loader).
    """
    model.train()
    # Mixed precision only helps on CUDA; elsewhere run the forward pass in
    # full precision instead of requesting CUDA autocast unconditionally.
    use_amp = torch.device(device).type == 'cuda'

    running = 0.0
    n_batches = 0
    pbar = tqdm(loader, desc='Train', leave=False)
    for images, tabular, targets in pbar:
        images = images.to(device)
        tabular = tabular.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            preds = model(images, tabular)
            loss = criterion(preds, targets)
        # Scale before backward so small half-precision gradients do not underflow.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running += loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': running / n_batches})

    return running / n_batches if n_batches else 0.0

def validate_epoch(model, loader, criterion, device, target_scaler=None):
    """Evaluate *model* on *loader* without gradient tracking.

    The loss is computed on the values as they come out of the loader. When
    *target_scaler* is given, predictions and targets are inverse-transformed
    before the weighted R² is computed and before being returned.

    Returns:
        ``(mean_loss, weighted_r2, per_target_r2, predictions, targets)``.
    """
    model.eval()
    loss_sum = 0.0
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for images, tabular, targets in tqdm(loader, desc='Valid', leave=False):
            images, tabular, targets = images.to(device), tabular.to(device), targets.to(device)
            preds = model(images, tabular)
            loss_sum += criterion(preds, targets).item()
            pred_chunks.append(preds.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    preds_out = np.vstack(pred_chunks)
    targets_out = np.vstack(target_chunks)

    # Undo target standardisation (if any) so the metric is on the original scale.
    if target_scaler is not None:
        preds_out = target_scaler.inverse_transform(preds_out)
        targets_out = target_scaler.inverse_transform(targets_out)

    weighted_r2, indiv = calculate_weighted_r2(targets_out, preds_out, CFG.target_weights)
    mean_loss = loss_sum / (len(loader) + 1e-12)
    return mean_loss, weighted_r2, indiv, preds_out, targets_out

# ============================================================
# TRAIN K-FOLD
# ============================================================
def train_kfold(df, fold):
    """Train on all folds except *fold*, validate on *fold*, and checkpoint the best epoch.

    Args:
        df: DataFrame with a ``fold`` column plus the image, tabular and target
            columns consumed by ``BiomassDataset``.
        fold: Index of the fold held out for validation.

    Returns:
        Best validation weighted R² reached during training.

    Side effects:
        Writes ``best_model_fold{fold}.pth`` to ``CFG.output_dir`` whenever the
        validation score improves. The checkpoint holds the model state dict
        together with the fitted tabular and target scaler objects.
    """
    print(f"\n=== Fold {fold + 1}/{CFG.n_folds} ===")
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    valid_df = df[df['fold'] == fold].reset_index(drop=True)
    print(f"Train size: {len(train_df)}, Valid size: {len(valid_df)}")

    # Optionally scale targets; statistics come from the training fold only so
    # nothing leaks from the validation fold.
    target_scaler = None
    if CFG.use_target_scaling:
        target_scaler = StandardScaler()
        target_scaler.fit(train_df[CFG.targets].values)
        train_df[CFG.targets] = target_scaler.transform(train_df[CFG.targets].values)
        valid_df[CFG.targets] = target_scaler.transform(valid_df[CFG.targets].values)
        print("Targets scaled (zero mean, unit var)")

    train_ds = BiomassDataset(train_df, CFG.train_dir, transform=get_train_transforms())
    valid_ds = BiomassDataset(valid_df, CFG.train_dir, transform=get_valid_transforms(), scaler=train_ds.scaler)

    # BatchNorm1d raises in training mode on a batch of one sample, so drop the
    # trailing batch only in the case where it would contain exactly one item.
    drop_single_tail = len(train_ds) % CFG.batch_size == 1
    train_loader = DataLoader(
        train_ds,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=drop_single_tail,
    )
    valid_loader = DataLoader(
        valid_ds,
        batch_size=CFG.batch_size * 2,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
    )

    model = BiomassModel(CFG.model_name, pretrained=CFG.pretrained).to(CFG.device)
    criterion = WeightedMSELoss(CFG.target_weights).to(CFG.device)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    # Only one scheduler object owns the optimizer. Building a warm-up
    # scheduler and then a cosine scheduler on the same optimizer lets the
    # second constructor reset the LR to its base value, which cancels the
    # warm-up. The warm-up LR is therefore set by hand below.
    main_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.device.type == 'cuda')

    best_score = -np.inf
    patience = 8
    patience_counter = 0

    for epoch in range(CFG.epochs):
        if epoch < CFG.warmup_epochs:
            # Linear warm-up that reaches CFG.lr on the last warm-up epoch.
            warm_lr = CFG.lr * (epoch + 1) / CFG.warmup_epochs
            for group in optimizer.param_groups:
                group['lr'] = warm_lr

        print(f"\nEpoch {epoch+1}/{CFG.epochs} LR={optimizer.param_groups[0]['lr']:.6f}")
        train_loss = train_epoch(model, train_loader, optimizer, criterion, CFG.device, scaler)
        valid_loss, weighted_r2, indiv_r2, _, _ = validate_epoch(
            model, valid_loader, criterion, CFG.device, target_scaler
        )

        # print metrics
        print(f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")
        print(f"Weighted R²: {weighted_r2:.4f} | Individual R²: {['{:.4f}'.format(x) for x in indiv_r2]}")

        # The cosine schedule advances only once warm-up is over, so its first
        # cycle starts at the base LR on epoch == CFG.warmup_epochs.
        if epoch >= CFG.warmup_epochs:
            main_scheduler.step()

        # save best by original-scale weighted R2
        if weighted_r2 > best_score:
            best_score = weighted_r2
            # NOTE(review): the scalers are stored as pickled Python objects;
            # loading this file presumably needs full (non weights-only)
            # unpickling - confirm against the loader.
            ckpt = {
                'model_state_dict': model.state_dict(),
                'tabular_scaler': train_ds.scaler,
                'target_scaler': target_scaler,
            }
            torch.save(ckpt, os.path.join(CFG.output_dir, f'best_model_fold{fold}.pth'))
            print(f"Saved best model (R²={best_score:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping (patience {patience})")
            break

        # free some memory
        torch.cuda.empty_cache()

    return best_score

# ============================================================
# MAIN
# ============================================================
def main():
    """Load the training table, assign stratified folds, and train every fold.

    Long-format CSVs (with ``target_name`` / ``target`` columns) are reshaped by
    ``prepare_data``; any other CSV is used as-is, with ``image_path`` and
    ``biomass_bin`` derived when missing. Prints the per-fold scores and their
    mean and standard deviation.
    """
    print("Loading CSV:", CFG.train_csv)
    df_raw = pd.read_csv(CFG.train_csv)
    print("Raw shape:", df_raw.shape)
    if 'target_name' in df_raw.columns and 'target' in df_raw.columns:
        df = prepare_data(CFG.train_csv)
    else:
        df = df_raw.copy()
        if 'image_path' not in df.columns and 'image_id' in df.columns:
            df['image_path'] = df['image_id'].astype(str) + '.jpg'
        if 'biomass_bin' not in df.columns:
            # Same binning as prepare_data. A bare ``except:`` would also
            # swallow KeyboardInterrupt/SystemExit, so catch Exception only.
            try:
                df['biomass_bin'] = pd.qcut(df['Dry_Total_g'], q=10, labels=False, duplicates='drop')
            except Exception:
                df['biomass_bin'] = pd.cut(df['Dry_Total_g'], bins=10, labels=False)
            df['biomass_bin'] = df['biomass_bin'].fillna(0).astype(int)

    print("Final df shape:", df.shape)

    # create folds
    # skf.split yields positional indices, so collect the assignment in a plain
    # array rather than writing through label-based .loc, which is only
    # correct when the frame has a default RangeIndex.
    skf = StratifiedKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed)
    fold_ids = np.full(len(df), -1, dtype=int)
    for fold, (_, val_idx) in enumerate(skf.split(df, df['biomass_bin'])):
        fold_ids[val_idx] = fold
    df['fold'] = fold_ids
    print("Fold counts:\n", df['fold'].value_counts().sort_index())

    # train folds
    fold_scores = [train_kfold(df, fold) for fold in range(CFG.n_folds)]
    print("\nCV results:", fold_scores)
    print("Mean CV:", np.mean(fold_scores), "Std:", np.std(fold_scores))

# Start training only when the file is executed as the main module, not on import.
if __name__ == '__main__':
    main()
