# In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import gc
import json
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import r2_score
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import Ridge
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

# HuggingFace login for DINOv3
# The access token is read from Kaggle's secret store under the key "HF_TOKEN",
# so this section only works inside a Kaggle notebook where that secret exists.
# NOTE(review): presumably required because the DINOv3 checkpoint is gated on
# the HuggingFace Hub — confirm.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
from huggingface_hub import login
login(token=HF_TOKEN)

# Global plotting defaults used by every figure produced below.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)

# ====================================================
# CONFIG
# ====================================================
class Config:
    """Static configuration namespace shared by the whole training script.

    Attributes are read directly from the class (``Config.X``); the class is
    never instantiated in this file.
    """
    # Input locations (Kaggle dataset mount points).
    TRAIN_CSV = '/kaggle/input/csiro-biomass/train.csv'
    IMG_DIR = '/kaggle/input/csiro-biomass/'
    
    # Model specifications
    # Each backbone has its own input resolution, passed to get_transforms().
    SWIN_MODEL = 'swin_base_patch4_window12_384'
    SWIN_IMG_SIZE = 384
    
    DINOV3_MODEL = 'facebook/dinov3-vitb16-pretrain-lvd1689m'
    DINOV3_IMG_SIZE = 224
    
    SIGLIP_MODEL = 'vit_so400m_patch14_siglip_384'
    SIGLIP_IMG_SIZE = 384
    
    # Training
    SEED = 42
    N_FOLDS = 5
    # Maximum epochs per backbone; also used as T_max of the cosine LR schedule.
    EPOCHS_STAGE1 = 35
    
    BATCH_SIZE = 8
    # Number of mini-batches whose gradients are accumulated per optimizer step.
    ACCUMULATION_STEPS = 2
    LR = 3e-5
    WEIGHT_DECAY = 1e-4
    
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Ensemble strategy
    # Values handled by train_ensemble_weights(): 'simple_average', 'ridge', 'learnable'.
    ENSEMBLE_METHOD = 'learnable'
    
    # Normalization
    # Per-target mean/std; left as None here and filled in by get_data().
    TARGET_MEAN = None
    TARGET_STD = None

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator (default 42).

    Side effects:
        Sets the ``PYTHONHASHSEED`` environment variable and switches
        ``torch.backends.cudnn.deterministic`` on.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The remaining generators all take the seed as their single argument.
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(Config.SEED)

def clear_gpu_memory():
    """Run the garbage collector and release cached CUDA memory.

    When CUDA is available the call also waits for outstanding GPU work via
    ``torch.cuda.synchronize()``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

# ====================================================
# PLOTTING FUNCTION (SIMPLIFIED)
# ====================================================
def plot_training_metrics(history, model_name, fold):
    """Draw, save and show the loss and validation-R² curves for one fold.

    Args:
        history: Dict with equally long lists under 'train_loss', 'val_loss'
            and 'val_r2' (one entry per completed epoch).
        model_name: Name used in the titles (upper-cased) and the file name.
        fold: Zero-based fold index; titles show ``fold + 1`` while the file
            name uses the raw index.

    Side effects:
        Writes ``{model_name}_fold{fold}_metrics.png`` to the working
        directory, shows the figure and closes it.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    loss_ax, r2_ax = axes

    train_curve = history['train_loss']
    epochs = range(1, len(train_curve) + 1)
    axis_font = {'fontsize': 13, 'fontweight': 'bold'}
    title_font = {'fontsize': 15, 'fontweight': 'bold'}
    header = f'{model_name.upper()} - Fold {fold+1}'

    # Left panel: train vs. validation loss.
    loss_ax.plot(epochs, train_curve, 'b-', label='Train Loss',
                 linewidth=2.5, marker='o', markersize=6)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss',
                 linewidth=2.5, marker='s', markersize=6)
    loss_ax.set_xlabel('Epoch', **axis_font)
    loss_ax.set_ylabel('Loss', **axis_font)
    loss_ax.set_title(f'{header} | Loss', **title_font)
    loss_ax.legend(fontsize=12, loc='upper right')
    loss_ax.grid(True, alpha=0.3)

    # Right panel: validation R² with the best epoch highlighted.
    r2_curve = history['val_r2']
    r2_ax.plot(epochs, r2_curve, 'g-', label='Val R²',
               linewidth=3, marker='D', markersize=7)
    best_r2 = max(r2_curve)
    # list.index returns the first epoch that reached the maximum.
    best_epoch = r2_curve.index(best_r2) + 1
    r2_ax.axhline(y=best_r2, color='red', linestyle='--', linewidth=2, alpha=0.7)
    r2_ax.axvline(x=best_epoch, color='orange', linestyle='--', linewidth=2, alpha=0.7)
    r2_ax.scatter([best_epoch], [best_r2], color='red', s=200, zorder=5,
                  edgecolors='black', linewidths=2)
    r2_ax.text(best_epoch, best_r2,
               f'  Best: {best_r2:.4f}\n  Epoch: {best_epoch}',
               fontsize=11, fontweight='bold', va='bottom')
    r2_ax.set_xlabel('Epoch', **axis_font)
    r2_ax.set_ylabel('R² Score', **axis_font)
    r2_ax.set_title(f'{header} | R²', **title_font)
    r2_ax.legend(fontsize=12, loc='lower right')
    r2_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out_file = f'{model_name}_fold{fold}_metrics.png'
    plt.savefig(out_file, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

# ====================================================
# DATA PREPARATION
# ====================================================
def get_data():
    """Load the training CSV and reshape it to one row per image.

    The CSV is pivoted so that each ``target_name`` becomes its own column,
    the first metadata value per image is attached, and every remaining NaN
    in the merged frame is replaced with 0. A label-encoded ``state_idx``
    column is added.

    Side effects:
        Stores the per-target mean and standard deviation (float32) in
        ``Config.TARGET_MEAN`` / ``Config.TARGET_STD`` and prints them.

    Returns:
        Tuple ``(train_df, target_cols, state_le)``: the wide DataFrame, the
        list of target column names, and the fitted ``LabelEncoder``.
    """
    raw = pd.read_csv(Config.TRAIN_CSV)

    # Long -> wide: one column per target name, one row per image.
    wide_targets = raw.pivot(
        index='image_path', columns='target_name', values='target'
    ).reset_index()

    # Metadata is taken from the first row seen for each image.
    meta_cols = ['State', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm', 'Sampling_Date']
    per_image_meta = (
        raw.groupby('image_path')
        .agg({col: 'first' for col in meta_cols})
        .reset_index()
    )

    train_df = wide_targets.merge(per_image_meta, on='image_path', how='left').fillna(0)

    state_le = LabelEncoder()
    train_df['state_idx'] = state_le.fit_transform(train_df['State'].astype(str))

    target_cols = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']

    # Normalisation statistics are computed over every row of train_df.
    target_frame = train_df[target_cols]
    Config.TARGET_MEAN = target_frame.mean().values.astype(np.float32)
    Config.TARGET_STD = target_frame.std().values.astype(np.float32)

    print(f"TARGET_MEAN: {Config.TARGET_MEAN}")
    print(f"TARGET_STD: {Config.TARGET_STD}")

    return train_df, target_cols, state_le

# ====================================================
# DATASET
# ====================================================
class BiomassDataset(Dataset):
    """Dataset yielding ``(image, normalised_targets, image_path)`` triples.

    Args:
        df: DataFrame with an ``image_path`` column plus the target columns.
        target_cols: Names of the regression target columns.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``['image']``.
    """

    def __init__(self, df, target_cols, transform=None):
        self.df = df.reset_index(drop=True)
        self.target_cols = target_cols
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        rel_path = record['image_path']

        bgr = cv2.imread(os.path.join(Config.IMG_DIR, rel_path))
        if bgr is not None:
            image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        else:
            # Unreadable file: fall back to an all-black 384x384 RGB image.
            image = np.zeros((384, 384, 3), dtype=np.uint8)

        if self.transform:
            image = self.transform(image=image)['image']

        # Standardise targets with the global statistics held on Config.
        raw_targets = record[self.target_cols].values.astype(np.float32)
        scaled = (raw_targets - Config.TARGET_MEAN) / (Config.TARGET_STD + 1e-6)

        return image, torch.tensor(scaled), rel_path

# ====================================================
# TRANSFORMS
# ====================================================
def get_transforms(img_size, data='train'):
    """Build the albumentations pipeline for a given square input size.

    Args:
        img_size: Side length passed to ``A.Resize`` for both dimensions.
        data: ``'train'`` adds random flips, 90-degree rotation and
            brightness/contrast jitter; any other value yields the plain
            resize + normalise + to-tensor pipeline.

    Returns:
        An ``A.Compose`` instance.
    """
    steps = [A.Resize(img_size, img_size)]

    if data == 'train':
        # Augmentations are applied only on the training split.
        steps += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        ]

    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)

# ====================================================
# MODELS
# ====================================================
class SwinRegressor(nn.Module):
    """timm Swin backbone followed by an MLP head with five outputs.

    The backbone is created with ``num_classes=0`` and ``global_pool='avg'``;
    its ``num_features`` value sizes the first linear layer of the head.
    """

    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
            Config.SWIN_MODEL,
            pretrained=True,
            num_classes=0,
            global_pool='avg',
        )
        width = self.backbone.num_features
        print(f"Swin feature dimension: {width}")
        # Layer order fixes the state_dict keys (head.0 ... head.8).
        layers = [
            nn.Linear(width, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

class DINOv3Regressor(nn.Module):
    """HuggingFace DINOv3 backbone followed by an MLP head with five outputs.

    The head consumes the backbone's ``pooler_output``; its input width is
    taken from ``backbone.config.hidden_size``.
    """

    def __init__(self):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(
            Config.DINOV3_MODEL,
            trust_remote_code=True,
        )
        width = self.backbone.config.hidden_size
        print(f"DINOv3 feature dimension: {width}")
        # Layer order fixes the state_dict keys (head.0 ... head.8).
        layers = [
            nn.Linear(width, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.backbone(pixel_values=x).pooler_output
        return self.head(pooled)

class SigLIPRegressor(nn.Module):
    """timm SigLIP ViT backbone followed by an MLP head with five outputs.

    The backbone is created with ``num_classes=0`` and ``global_pool='avg'``;
    its ``num_features`` value sizes the first linear layer of the head.
    """

    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
            Config.SIGLIP_MODEL,
            pretrained=True,
            num_classes=0,
            global_pool='avg',
        )
        width = self.backbone.num_features
        print(f"SigLIP feature dimension: {width}")
        # Layer order fixes the state_dict keys (head.0 ... head.8).
        layers = [
            nn.Linear(width, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(256, 5),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

class LearnableEnsemble(nn.Module):
    """Per-target convex combination of several models' predictions.

    Args:
        num_models: Number of prediction sources to blend.
        num_targets: Number of regression targets per sample.

    The raw parameter has shape ``(num_models, num_targets)`` and starts at
    ``1 / num_models``; a softmax over the model axis turns it into weights
    that sum to one for every target.
    """

    def __init__(self, num_models=3, num_targets=5):
        super().__init__()
        initial = torch.ones(num_models, num_targets) / num_models
        self.weights = nn.Parameter(initial)

    def forward(self, predictions_list):
        """Blend predictions.

        Args:
            predictions_list: Sequence of equally shaped ``(batch, targets)``
                tensors, one per model.

        Returns:
            Tuple ``(blended, mixing)`` where ``blended`` is ``(batch,
            targets)`` and ``mixing`` is the softmax-normalised weight matrix.
        """
        mixing = F.softmax(self.weights, dim=0)
        per_model = torch.stack(predictions_list, dim=0)  # (models, batch, targets)
        blended = torch.einsum('mbt,mt->bt', per_model, mixing)
        return blended, mixing

# ====================================================
# EVALUATION
# ====================================================
def evaluate_r2(preds, targets):
    """Compute a globally weighted R² plus the unweighted per-target R².

    Args:
        preds: Array of predictions, one column per target (5 columns).
        targets: Array of ground-truth values with the same shape.

    Returns:
        Tuple ``(weighted_r2, per_target_r2)``. ``weighted_r2`` pools all
        columns with fixed column weights 0.1/0.1/0.1/0.2/0.5 around a single
        weighted mean; ``per_target_r2`` comes from sklearn's ``r2_score``
        with ``multioutput='raw_values'``.
    """
    column_weights = np.array([0.1, 0.1, 0.1, 0.2, 0.5])
    # Replicate the column weights on every row so they align with targets.
    weight_grid = np.repeat(column_weights[np.newaxis, :], len(targets), axis=0)

    weighted_mean = np.sum(weight_grid * targets) / np.sum(weight_grid)
    residual_ss = np.sum(weight_grid * (targets - preds)**2)
    total_ss = np.sum(weight_grid * (targets - weighted_mean)**2)

    per_target_r2 = r2_score(targets, preds, multioutput='raw_values')
    return 1 - (residual_ss / total_ss), per_target_r2

# ====================================================
# TRAINING (WITH HISTORY TRACKING)
# ====================================================
def train_single_model(model, train_loader, val_loader, model_name, fold):
    """Train one regressor on a fold with AMP, accumulation and early stopping.

    Args:
        model: ``nn.Module`` already moved to ``Config.DEVICE``.
        train_loader: Loader yielding ``(images, normalised_targets, paths)``.
        val_loader: Loader with the same batch structure for validation.
        model_name: Prefix for the checkpoint and plot file names.
        fold: Fold index used in log lines and file names.

    Returns:
        Tuple ``(model, best_r2)``: the model with the best checkpoint
        reloaded, and the best weighted validation R² (computed on
        denormalised values clipped at zero).

    Side effects:
        Writes ``{model_name}_fold{fold}_best.pth`` whenever validation R²
        improves and produces the metrics plot through
        ``plot_training_metrics``.
    """
    print(f"\n{'='*70}")
    print(f"Training {model_name} - Fold {fold}")
    print(f"{'='*70}")

    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS_STAGE1)
    criterion = nn.SmoothL1Loss()
    scaler = torch.cuda.amp.GradScaler()

    best_r2 = -np.inf
    patience = 0
    max_patience = 5

    # Track history
    history = {'train_loss': [], 'val_loss': [], 'val_r2': []}

    num_batches = len(train_loader)

    for epoch in range(Config.EPOCHS_STAGE1):
        # Train
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS_STAGE1}")
        for batch_idx, (imgs, targets_norm, _) in enumerate(pbar):
            imgs = imgs.to(Config.DEVICE)
            targets_norm = targets_norm.to(Config.DEVICE)

            with torch.cuda.amp.autocast():
                preds = model(imgs)
                loss = criterion(preds, targets_norm)
                # Scale so that accumulated gradients average over the group.
                loss = loss / Config.ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            # Step on every full accumulation group AND on the final batch.
            # Without the second condition a trailing partial group's
            # gradients would be wiped by the next epoch's zero_grad() and
            # never applied.
            group_complete = (batch_idx + 1) % Config.ACCUMULATION_STEPS == 0
            last_batch = (batch_idx + 1) == num_batches
            if group_complete or last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Undo the accumulation scaling for reporting.
            batch_loss = loss.item() * Config.ACCUMULATION_STEPS
            train_loss += batch_loss
            pbar.set_postfix({'loss': batch_loss})

        train_loss /= len(train_loader)
        scheduler.step()

        # Validate
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for imgs, targets_norm, _ in val_loader:
                imgs = imgs.to(Config.DEVICE)
                targets_norm_gpu = targets_norm.to(Config.DEVICE)

                with torch.cuda.amp.autocast():
                    preds = model(imgs)
                    loss = criterion(preds, targets_norm_gpu)
                    val_loss += loss.item()
                    preds = preds.cpu().numpy()

                val_preds.append(preds)
                val_targets.append(targets_norm.numpy())

        val_loss /= len(val_loader)
        val_preds = np.vstack(val_preds)
        val_targets = np.vstack(val_targets)

        # Denormalize
        val_preds_denorm = val_preds * Config.TARGET_STD + Config.TARGET_MEAN
        val_targets_denorm = val_targets * Config.TARGET_STD + Config.TARGET_MEAN
        # Predictions are clipped at zero before scoring.
        val_preds_denorm = np.maximum(0, val_preds_denorm)

        val_r2, per_target = evaluate_r2(val_preds_denorm, val_targets_denorm)

        # Store history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_r2'].append(val_r2)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val R²={val_r2:.4f}")

        if val_r2 > best_r2:
            best_r2 = val_r2
            patience = 0
            torch.save(model.state_dict(), f'{model_name}_fold{fold}_best.pth')
            print(f"  ✓ Saved best model (R²={best_r2:.4f})")
        else:
            patience += 1
            if patience >= max_patience:
                print(f"  Early stopping at epoch {epoch+1}")
                break

    # Plot training metrics
    plot_training_metrics(history, model_name, fold)

    # Load best weights
    model.load_state_dict(torch.load(f'{model_name}_fold{fold}_best.pth'))
    return model, best_r2

# ====================================================
# EXTRACT PREDICTIONS
# ====================================================
def extract_predictions_from_model(model, loader, model_name):
    """Run a model over a loader in eval mode and gather its raw outputs.

    Args:
        model: Trained ``nn.Module`` on ``Config.DEVICE``.
        loader: Loader yielding ``(images, targets, paths)``; targets are ignored.
        model_name: Label shown in the progress bar.

    Returns:
        Tuple ``(predictions, paths)``: the vertically stacked model outputs
        (still in normalised target space) and the matching image paths.
    """
    model.eval()
    pred_chunks, path_list = [], []

    progress = tqdm(loader, desc=f"Extracting {model_name}")
    with torch.no_grad():
        for imgs, _, paths in progress:
            batch = imgs.to(Config.DEVICE)
            with torch.cuda.amp.autocast():
                batch_preds = model(batch).cpu().numpy()
            pred_chunks.append(batch_preds)
            path_list.extend(paths)

    return np.vstack(pred_chunks), path_list

# ====================================================
# ENSEMBLE WEIGHTS
# ====================================================
def train_ensemble_weights(swin_preds, dinov3_preds, siglip_preds, targets, method='learnable'):
    """Fit blending weights for the three models and return blended predictions.

    Args:
        swin_preds: ``(n_samples, n_targets)`` array of Swin predictions.
        dinov3_preds: Same-shaped array of DINOv3 predictions.
        siglip_preds: Same-shaped array of SigLIP predictions.
        targets: ``(n_samples, n_targets)`` ground-truth array.
        method: ``'simple_average'``, ``'ridge'`` or ``'learnable'``.

    Returns:
        Tuple ``(weights, ensemble_preds)``. ``weights`` holds three values,
        one per model (for ``'ridge'`` a normalised mean absolute
        coefficient, for ``'learnable'`` the softmax weights averaged over
        targets). ``ensemble_preds`` has the shape of ``targets``; the ridge
        and learnable variants clip it at zero.

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    Notes:
        The blender is fitted and scored on the same arrays, so the printed
        R² is an in-sample figure. ``'ridge'`` and ``'learnable'`` write
        ``ridge_ensemble.pkl`` / ``learnable_ensemble.pth`` under fixed
        names, so repeated calls overwrite the previous file.
    """
    valid_methods = ('simple_average', 'ridge', 'learnable')
    if method not in valid_methods:
        # Previously an unknown name fell through and returned None, which
        # only surfaced later as an unpacking error in the caller.
        raise ValueError(
            f"Unknown ensemble method {method!r}; expected one of {valid_methods}"
        )

    print(f"\n{'='*70}")
    print(f"Training Ensemble Weights ({method})")
    print(f"{'='*70}")

    if method == 'simple_average':
        weights = np.array([1/3, 1/3, 1/3])
        ensemble_preds = (swin_preds + dinov3_preds + siglip_preds) / 3
        ensemble_r2, _ = evaluate_r2(ensemble_preds, targets)
        print(f"Simple Average R²: {ensemble_r2:.4f}")
        return weights, ensemble_preds

    n_targets = targets.shape[1]

    if method == 'ridge':
        # Features: all models' predictions for all targets, side by side.
        X = np.hstack([swin_preds, dinov3_preds, siglip_preds])
        ensemble_preds = np.zeros_like(targets)
        ridge_models = []

        # One independent ridge regressor per target column.
        for target_idx in range(n_targets):
            ridge = Ridge(alpha=1.0)
            ridge.fit(X, targets[:, target_idx])
            ensemble_preds[:, target_idx] = ridge.predict(X)
            ridge_models.append(ridge)

        ensemble_preds = np.maximum(0, ensemble_preds)
        ensemble_r2, _ = evaluate_r2(ensemble_preds, targets)
        print(f"Ridge Ensemble R²: {ensemble_r2:.4f}")

        # Summarise each model's influence as its mean absolute coefficient,
        # normalised so the three values sum to one.
        all_coefs = np.array([m.coef_ for m in ridge_models])
        avg_weights = np.abs(all_coefs).mean(axis=0)
        swin_weight = avg_weights[:n_targets].mean()
        dinov3_weight = avg_weights[n_targets:2 * n_targets].mean()
        siglip_weight = avg_weights[2 * n_targets:3 * n_targets].mean()
        total = swin_weight + dinov3_weight + siglip_weight
        weights = np.array([swin_weight, dinov3_weight, siglip_weight]) / total

        with open('ridge_ensemble.pkl', 'wb') as f:
            pickle.dump(ridge_models, f)

        return weights, ensemble_preds

    # method == 'learnable'
    dataset = torch.utils.data.TensorDataset(
        torch.tensor(swin_preds, dtype=torch.float32),
        torch.tensor(dinov3_preds, dtype=torch.float32),
        torch.tensor(siglip_preds, dtype=torch.float32),
        torch.tensor(targets, dtype=torch.float32)
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    ensemble_model = LearnableEnsemble(num_models=3, num_targets=n_targets).to(Config.DEVICE)
    optimizer = optim.Adam(ensemble_model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(100):
        ensemble_model.train()
        total_loss = 0

        for swin_b, dinov3_b, siglip_b, targets_b in loader:
            swin_b = swin_b.to(Config.DEVICE)
            dinov3_b = dinov3_b.to(Config.DEVICE)
            siglip_b = siglip_b.to(Config.DEVICE)
            targets_b = targets_b.to(Config.DEVICE)

            optimizer.zero_grad()
            ensemble_pred, _ = ensemble_model([swin_b, dinov3_b, siglip_b])
            loss = criterion(ensemble_pred, targets_b)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}: Loss={total_loss/len(loader):.4f}")

    ensemble_model.eval()
    with torch.no_grad():
        # Explicit float32 mirrors the training tensors; without it a float64
        # input array would clash with the float32 weight parameter.
        ensemble_preds, final_weights = ensemble_model([
            torch.tensor(swin_preds, dtype=torch.float32).to(Config.DEVICE),
            torch.tensor(dinov3_preds, dtype=torch.float32).to(Config.DEVICE),
            torch.tensor(siglip_preds, dtype=torch.float32).to(Config.DEVICE)
        ])
        ensemble_preds = ensemble_preds.cpu().numpy()
        final_weights = final_weights.cpu().numpy()

    ensemble_preds = np.maximum(0, ensemble_preds)
    ensemble_r2, _ = evaluate_r2(ensemble_preds, targets)

    print(f"\nLearnable Ensemble R²: {ensemble_r2:.4f}")
    # Collapse the per-target weights to one number per model for reporting.
    avg_weights = final_weights.mean(axis=1)
    print(f"Avg Weights: Swin={avg_weights[0]:.4f}, DINOv3={avg_weights[1]:.4f}, SigLIP={avg_weights[2]:.4f}")

    torch.save(ensemble_model.state_dict(), 'learnable_ensemble.pth')
    return avg_weights, ensemble_preds

# ====================================================
# MAIN
# ====================================================
def _run_backbone_stage(stage_no, model_cls, img_size, run_name, display_name,
                        df_train, df_val, target_cols, fold):
    """Train one backbone on a fold and return its denormalised val predictions.

    Args:
        stage_no: 1-based stage number shown in the banner.
        model_cls: Zero-argument model class to instantiate.
        img_size: Square input size handed to ``get_transforms``.
        run_name: Lower-case name used for checkpoints/plots; its upper-case
            form is printed in the banner.
        display_name: Label for the prediction-extraction progress bar.
        df_train: Training rows of the fold.
        df_val: Validation rows of the fold.
        target_cols: Names of the target columns.
        fold: Zero-based fold index.

    Returns:
        Tuple ``(val_preds, best_r2)``: validation predictions mapped back to
        the original target scale and clipped at zero, and the best
        validation R² reported by ``train_single_model``.
    """
    print(f"\n{'='*70}\nSTAGE {stage_no}: {run_name.upper()}\n{'='*70}")
    train_ds = BiomassDataset(df_train, target_cols, get_transforms(img_size, 'train'))
    val_ds = BiomassDataset(df_val, target_cols, get_transforms(img_size, 'val'))
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)

    model = model_cls().to(Config.DEVICE)
    model, best_r2 = train_single_model(model, train_loader, val_loader, run_name, fold)
    val_preds, _ = extract_predictions_from_model(model, val_loader, display_name)
    val_preds = np.maximum(0, val_preds * Config.TARGET_STD + Config.TARGET_MEAN)

    # Drop every reference before the next backbone is built to free GPU memory.
    del model, train_loader, val_loader, train_ds, val_ds
    clear_gpu_memory()
    return val_preds, best_r2


def main():
    """Run the K-fold pipeline: train three backbones per fold and blend them.

    For each fold the Swin, DINOv3 and SigLIP regressors are trained one after
    another, their validation predictions are blended with
    ``train_ensemble_weights`` and the scores are collected. At the end the
    mean and standard deviation of the ensemble R² are printed and
    ``ensemble_config.json`` is written to the working directory.
    """
    print("="*70)
    print("SWIN + DINOV3 + SIGLIP DYNAMIC ENSEMBLE")
    print("="*70)

    train_df, target_cols, _ = get_data()
    # One image is excluded by ID.
    # NOTE(review): the reason is not visible here — presumably a known bad sample.
    train_df = train_df[~train_df['image_path'].str.contains('ID230058600')].reset_index(drop=True)
    print(f"\n✓ Training samples: {len(train_df)}")

    # Stratify on state while keeping each sampling date inside a single fold.
    sgkf = StratifiedGroupKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=Config.SEED)
    splits = list(sgkf.split(train_df, y=train_df['state_idx'], groups=train_df['Sampling_Date']))

    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(splits):
        print(f"\n\n{'#'*70}")
        print(f"FOLD {fold+1}/{Config.N_FOLDS}")
        print(f"{'#'*70}\n")

        df_train = train_df.iloc[train_idx].reset_index(drop=True)
        df_val = train_df.iloc[val_idx].reset_index(drop=True)
        val_targets = df_val[target_cols].values

        # Backbones are trained sequentially so only one is on the GPU at a time.
        swin_val_preds, swin_r2 = _run_backbone_stage(
            1, SwinRegressor, Config.SWIN_IMG_SIZE, 'swin', 'Swin',
            df_train, df_val, target_cols, fold,
        )
        dinov3_val_preds, dinov3_r2 = _run_backbone_stage(
            2, DINOv3Regressor, Config.DINOV3_IMG_SIZE, 'dinov3', 'DINOv3',
            df_train, df_val, target_cols, fold,
        )
        siglip_val_preds, siglip_r2 = _run_backbone_stage(
            3, SigLIPRegressor, Config.SIGLIP_IMG_SIZE, 'siglip', 'SigLIP',
            df_train, df_val, target_cols, fold,
        )

        # Ensemble
        ensemble_weights, ensemble_preds = train_ensemble_weights(
            swin_val_preds, dinov3_val_preds, siglip_val_preds, val_targets, method=Config.ENSEMBLE_METHOD
        )
        ensemble_r2, per_target_r2 = evaluate_r2(ensemble_preds, val_targets)

        fold_results.append({
            'fold': fold,
            'swin_r2': swin_r2,
            'dinov3_r2': dinov3_r2,
            'siglip_r2': siglip_r2,
            'ensemble_r2': ensemble_r2,
            'ensemble_weights': ensemble_weights,
            'per_target_r2': per_target_r2
        })

        print(f"\n{'='*70}")
        print(f"FOLD {fold+1} SUMMARY")
        print(f"{'='*70}")
        print(f"Swin: {swin_r2:.4f} | DINOv3: {dinov3_r2:.4f} | SigLIP: {siglip_r2:.4f} | Ensemble: {ensemble_r2:.4f}")

        clear_gpu_memory()

    # Final summary
    print(f"\n\n{'#'*70}\nFINAL RESULTS\n{'#'*70}\n")
    ensemble_scores = [r['ensemble_r2'] for r in fold_results]
    avg_ensemble = np.mean(ensemble_scores)
    print(f"Average Ensemble R²: {avg_ensemble:.4f} ± {np.std(ensemble_scores):.4f}")

    # Save config
    config_dict = {
        'n_folds': Config.N_FOLDS,
        'target_mean': Config.TARGET_MEAN.tolist(),
        'target_std': Config.TARGET_STD.tolist(),
        'avg_ensemble_r2': float(avg_ensemble)
    }

    with open('ensemble_config.json', 'w') as f:
        json.dump(config_dict, f, indent=4)

    print(f"\n✓ Saved ensemble_config.json")
    print(f"✓ Training complete!")

if __name__ == '__main__':
    main()
