# **Biomass Train Data Mutual Importance**

## Related notebooks for the 2nd approach

**Check mutual feature importance to plan the strategy for this competition (new, this notebook)** <br>
https://www.kaggle.com/code/stpeteishii/biomass-train-data-mutual-importance<br>

Fit a model to predict the target from tabular data (reuses the earlier notebook)<br>
https://www.kaggle.com/code/stpeteishii/biomass-train-data-visualize-importance<br>

Fit models to predict Pre_GSHH_NDVI / Height_Ave_cm from images (reuses the earlier notebooks)<br>
https://www.kaggle.com/code/stpeteishii/pre-gshh-ndvi-pytorch-lightning-cnn-regressor<br>
https://www.kaggle.com/code/stpeteishii/height-ave-cm-pytorch-lightning-cnn-regressor<br>

Fit a model to predict Species from tabular data (new, not yet published)<br>
https://www.kaggle.com/code/stpeteishii/biomass-train-data-wo-target-species-importance<br>

Predict Species and the target for the test set (new, not yet published) <br>
https://www.kaggle.com/code/stpeteishii/biomass-test-inference-the-2nd-approach<br>

In [None]:
# Mutual Importance Analysis for All Features
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from time import time
from tqdm import tqdm
import lightgbm as lgbm
import joblib
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder

# ============================================================================
# Timer Utility
# ============================================================================
class Timer:
    """Context manager that measures the wall-clock time of its ``with`` block.

    On exit the elapsed seconds are formatted with ``format_str`` and passed to
    ``logger.info`` when a logger is given, otherwise printed to stdout.

    Args:
        logger: Optional object exposing ``info(str)`` (e.g. a ``logging.Logger``).
        format_str: ``str.format`` template that receives the duration in seconds.
        prefix: Optional text placed before ``format_str``, joined with ``sep``.
        suffix: Optional text placed after ``format_str``, joined with ``sep``.
        sep: Separator used when attaching ``prefix`` / ``suffix``.
    """

    def __init__(self, logger=None, format_str='{:.3f}[s]', prefix=None, suffix=None, sep=' '):
        if prefix: format_str = str(prefix) + sep + format_str
        if suffix: format_str = format_str + sep + str(suffix)
        self.format_str = format_str
        self.logger = logger
        self.start = None
        self.end = None

    @property
    def duration(self):
        """Elapsed seconds of the last finished block, or 0 if none has finished."""
        if self.start is None or self.end is None:
            return 0
        return self.end - self.start

    def __enter__(self):
        self.start = time()
        # Clear any previous end time so a reused Timer never reports a stale
        # (or negative) duration.
        self.end = None
        # Returning self makes ``with Timer() as t: ...`` bind the timer itself.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time()
        out_str = self.format_str.format(self.duration)
        if self.logger:
            self.logger.info(out_str)
        else:
            print(out_str)
        # Never swallow exceptions raised inside the timed block.
        return False

# ============================================================================
# Data Preparation
# ============================================================================
data0 = pd.read_csv("/kaggle/input/csiro-biomass/train.csv")

# Drop identifier, path, date and State columns; they are not part of the
# feature set analysed below.
delete_cols = ['sample_id', 'image_path', 'Sampling_Date', 'State']
data0 = data0.drop(columns=delete_cols)

# Replace each target_name string by its position in the sorted list of
# distinct names, giving a deterministic integer code per name.
target_names = sorted(data0['target_name'].unique().tolist())
target_name_mapping = {name: code for code, name in enumerate(target_names)}
data0['target_name'] = data0['target_name'].map(target_name_mapping)

# Label encode categorical columns
def labelencoder(df):
    """Return a copy of ``df`` with every text column replaced by integer codes.

    Missing values in those columns are first filled with the placeholder 'N'.
    The distinct values are then sorted, and each value is replaced by its
    position in that sorted order (the same sorted-class coding used by
    sklearn's ``LabelEncoder``). All other columns are left unchanged.

    Args:
        df: Input DataFrame. It is NOT modified.

    Returns:
        A new DataFrame with the text columns encoded as integers.
    """
    # Work on a copy so the caller's frame is not silently rewritten.
    out = df.copy()
    for c in out.columns:
        col = out[c]
        # Treat both classic object columns and pandas' dedicated string dtype
        # (used for text by newer pandas versions) as categorical text.
        if col.dtype == object or pd.api.types.is_string_dtype(col.dtype):
            codes, _ = pd.factorize(col.fillna('N'), sort=True)
            out[c] = codes
    return out

# Integer-encode the remaining object (string) columns so every column can be
# fed to LightGBM; data1 is the frame the analysis in __main__ runs on.
data1 = labelencoder(data0)

# ============================================================================
# Model Training Function
# ============================================================================
def fit_lgbm(X, y, cv, params: dict = None, verbose: int = 50):
    """Cross-validate an LGBMRegressor and return out-of-fold predictions.

    For every ``(train_idx, valid_idx)`` pair in ``cv``:
      1. A fresh regressor is built from ``params``.
      2. It is fitted with early stopping (50 rounds) monitored on the validation fold.
      3. It predicts that fold.
    Per-fold and overall RMSE are printed.

    Args:
        X: Feature array indexable by integer index arrays (a NumPy array here).
        y: Target array aligned with ``X``.
        cv: Iterable of ``(train_idx, valid_idx)`` index-array pairs.
        params: Keyword arguments for ``lgbm.LGBMRegressor``; ``None`` -> defaults.
        verbose: Accepted for interface compatibility; not used.

    Returns:
        tuple: ``(oof_pred, models)``.
            - ``oof_pred``: float array shaped like ``y``, filled fold by fold
              with validation predictions.
            - ``models``: the fitted regressors, in fold order.
    """
    regressor_kwargs = params if params is not None else {}
    fitted = []
    oof_pred = np.zeros_like(y, dtype=float)

    for fold_no, (train_idx, valid_idx) in enumerate(cv):
        X_fit, y_fit = X[train_idx], y[train_idx]
        X_hold, y_hold = X[valid_idx], y[valid_idx]

        regressor = lgbm.LGBMRegressor(**regressor_kwargs)
        stop_early = [lgbm.early_stopping(stopping_rounds=50, verbose=False)]
        with Timer(prefix=f'Fit fold={fold_no} '):
            regressor.fit(X_fit, y_fit, eval_set=[(X_hold, y_hold)], callbacks=stop_early)

        fold_pred = regressor.predict(X_hold)
        oof_pred[valid_idx] = fold_pred
        fitted.append(regressor)
        fold_rmse = mean_squared_error(y_hold, fold_pred) ** .5
        print(f'Fold {fold_no} RMSE: {fold_rmse:.4f}')

    overall_rmse = mean_squared_error(y, oof_pred) ** .5
    print(f'Overall RMSE: {overall_rmse:.4f}\n')
    return oof_pred, fitted

# ============================================================================
# Mutual Importance Calculation
# ============================================================================
def calculate_mutual_importance(data, target_col=None, n_splits=5, random_state=42,
                                normalize=False):
    """
    Calculate the mutual importance matrix for all features.

    Each retained column is treated in turn as a regression target. It is
    predicted from every other retained column with cross-validated LightGBM
    (via ``fit_lgbm``), and the fold-averaged feature importances (gain) fill
    that column's row.

    Args:
        data: DataFrame of features; assumed to be all-numeric (e.g. already
            label-encoded) since its values are passed straight to LightGBM.
        target_col: Optional column name to exclude entirely (neither predicted
            nor used as a predictor). Ignored if it is not a column of ``data``.
        n_splits: Number of KFold folds.
        random_state: Seed for KFold shuffling and for LightGBM.
        normalize: If True, scale each row to sum to 1. This makes rows
            comparable even when their targets have very different variances,
            which raw gain is not. Default False keeps raw gain values.

    Returns:
        importance_matrix: DataFrame where [i, j] is the importance of feature j
        when predicting feature i. The diagonal is always 0.
    """

    # LightGBM parameters
    params = {
        'objective': 'rmse',
        'learning_rate': 0.1,
        'reg_lambda': 1.0,
        'reg_alpha': 0.1,
        'max_depth': 5,
        'n_estimators': 500,
        'colsample_bytree': 0.5,
        'min_child_samples': 10,
        'subsample_freq': 3,
        'subsample': 0.9,
        'importance_type': 'gain',
        'random_state': random_state,
        'num_leaves': 31,
        'verbose': -1
    }

    # Prepare columns
    if target_col and target_col in data.columns:
        columns = [col for col in data.columns if col != target_col]
    else:
        columns = data.columns.tolist()

    # Initialize importance matrix
    importance_matrix = pd.DataFrame(
        np.zeros((len(columns), len(columns))),
        index=columns,
        columns=columns
    )

    # KFold splits depend only on the number of rows, which is the same for
    # every target, so compute them once instead of once per column.
    fold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    cv = list(fold.split(data))

    # Calculate importance for each feature as target
    for target_feature in tqdm(columns, desc="Calculating mutual importance"):
        # Use all other retained features to predict target_feature
        predictor_cols = [col for col in columns if col != target_feature]
        if not predictor_cols:
            # Single-column input: nothing to predict from, leave the row at 0.
            continue
        X = data[predictor_cols].values
        y = data[target_feature].values

        # Train model
        print(f"\n{'='*60}")
        print(f"Target Feature: {target_feature}")
        print(f"{'='*60}")
        _, models = fit_lgbm(X, y, cv, params=params)

        # Average feature importance across folds
        avg_importance = np.mean([model.feature_importances_ for model in models], axis=0)

        # Store in matrix (the diagonal entry is never written and stays 0)
        importance_matrix.loc[target_feature, predictor_cols] = avg_importance

    if normalize:
        row_sums = importance_matrix.sum(axis=1)
        # Rows that are entirely zero stay zero instead of becoming NaN.
        importance_matrix = importance_matrix.div(row_sums.where(row_sums > 0, 1.0), axis=0)

    return importance_matrix

# ============================================================================
# Visualization Functions
# ============================================================================
def visualize_importance_heatmap(importance_matrix, figsize=(14, 12), cmap='viridis'):
    """Draw the mutual-importance matrix as a seaborn heatmap.

    Rows are the predicted (target) features; columns are the predictor features.

    Args:
        importance_matrix: Matrix-shaped DataFrame, such as the one returned by
            ``calculate_mutual_importance``.
        figsize: Matplotlib figure size in inches.
        cmap: Colormap name forwarded to ``sns.heatmap``.

    Returns:
        (fig, ax): the newly created figure and its single axes.
    """
    fig, ax = plt.subplots(figsize=figsize)

    heatmap_kwargs = dict(
        annot=False,
        fmt='.2f',  # only takes effect if annot is switched on
        cmap=cmap,
        cbar_kws={'label': 'Feature Importance'},
        ax=ax,
    )
    sns.heatmap(importance_matrix, **heatmap_kwargs)

    title = 'Mutual Feature Importance Matrix\n(Row: Target Feature, Column: Predictor Feature)'
    ax.set_title(title, fontsize=14, pad=20)
    ax.set_xlabel('Predictor Features', fontsize=12)
    ax.set_ylabel('Target Features', fontsize=12)

    # Slant the column labels so long feature names do not collide; keep row
    # labels horizontal.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()

    return fig, ax

def rank_relationships(importance_matrix, top_n=20):
    """Return the ``top_n`` strongest off-diagonal entries of ``importance_matrix``.

    Args:
        importance_matrix: DataFrame where [i, j] is the importance of column j
            when predicting row i.
        top_n: Maximum number of relationships to keep.

    Returns:
        DataFrame with columns ``target``, ``predictor``, ``importance`` and
        ``pair`` (the text ``"predictor → target"``). It is sorted by
        descending importance, with ties kept in row-major order, and
        re-indexed from 0. It is empty, but still has these columns, when the
        matrix has no off-diagonal entries.
    """
    records = [
        (target, predictor, importance_matrix.loc[target, predictor])
        for target in importance_matrix.index
        for predictor in importance_matrix.columns
        if target != predictor
    ]
    # Passing explicit columns keeps the frame well-formed even when there are
    # no records, so sort_values cannot fail with a KeyError.
    ranked = (
        pd.DataFrame(records, columns=['target', 'predictor', 'importance'])
        .sort_values('importance', ascending=False, kind='stable')
        .head(top_n)
    )
    # astype(str) lets non-string labels (e.g. integer column names) form labels too.
    ranked = ranked.assign(
        pair=ranked['predictor'].astype(str) + ' → ' + ranked['target'].astype(str)
    )
    return ranked.reset_index(drop=True)


def visualize_top_relationships(importance_matrix, top_n=20):
    """Plot the ``top_n`` strongest predictor → target relationships as bars.

    Args:
        importance_matrix: DataFrame where [i, j] is the importance of column j
            when predicting row i (e.g. from ``calculate_mutual_importance``).
        top_n: Maximum number of relationships to show.

    Returns:
        (fig, ax, df_relationships): the figure, its axes, and the ranked
        relationships (output of ``rank_relationships``) that were plotted.
    """
    df_relationships = rank_relationships(importance_matrix, top_n=top_n)

    # Scale the figure height with the number of requested bars (at least 6 in).
    fig, ax = plt.subplots(figsize=(10, max(6, top_n * 0.3)))

    if not df_relationships.empty:
        # Nothing to draw when the matrix has no off-diagonal entries.
        sns.barplot(data=df_relationships,
                    y='pair',
                    x='importance',
                    palette='rocket',
                    ax=ax)

    ax.set_title(f'Top {top_n} Feature Relationships by Importance', fontsize=14, pad=15)
    ax.set_xlabel('Feature Importance', fontsize=12)
    ax.set_ylabel('Feature Relationship', fontsize=12)
    ax.grid(axis='x', alpha=0.3)

    plt.tight_layout()

    return fig, ax, df_relationships

def analyze_feature_connectivity(importance_matrix, threshold=None):
    """Count how many strong links each feature has, as predictor and as target.

    An entry counts as a strong link when it is strictly greater than
    ``threshold``.

    Args:
        importance_matrix: DataFrame where [i, j] is the importance of column j
            when predicting row i.
        threshold: Cut-off for a strong link. Defaults to the mean of the
            strictly positive entries, or 0 when there are none.

    Returns:
        DataFrame with columns ``Feature``, ``As_Predictor`` (rows in which the
        feature's column is strong), ``As_Target`` (strong entries in the
        feature's row) and ``Total_Connections``. It is sorted by descending
        ``Total_Connections``, with ties kept in index order.
    """
    if threshold is None:
        values = importance_matrix.to_numpy()
        positive = values[values > 0]
        # The mean of an empty array would be NaN (plus a RuntimeWarning), so
        # fall back to 0 for an all-zero matrix.
        threshold = positive.mean() if positive.size else 0.0

    strong = importance_matrix > threshold
    # Column sums: in how many target rows a feature is a strong predictor.
    # Align by label with the row index rather than relying on identical order.
    as_predictor = strong.sum(axis=0).reindex(importance_matrix.index, fill_value=0)
    # Row sums: how many strong predictors a feature has as a target.
    as_target = strong.sum(axis=1)

    pred_counts = as_predictor.to_numpy()
    target_counts = as_target.to_numpy()
    connectivity_df = pd.DataFrame({
        'Feature': importance_matrix.index,
        'As_Predictor': pred_counts,
        'As_Target': target_counts,
        'Total_Connections': pred_counts + target_counts
    }).sort_values('Total_Connections', ascending=False, kind='stable')

    return connectivity_df

# ============================================================================
# Main Analysis
# ============================================================================
if __name__ == "__main__":
    # End-to-end run on data1:
    #   1. compute the mutual-importance matrix;
    #   2. write CSV/PNG outputs under mutual_importance_results/;
    #   3. print summaries.
    results_dir = 'mutual_importance_results'

    # Calculate mutual importance INCLUDING 'target'
    print("\n" + "="*70)
    print("STARTING MUTUAL IMPORTANCE ANALYSIS (INCLUDING 'target')")
    print("="*70 + "\n")

    print("⚠ NOTE: 'target' is included in the analysis.")
    print("This will reveal:")
    print("  1. How features predict 'target' (forward direction)")
    print("  2. How 'target' predicts features (reverse direction)")
    print("  3. Hidden relationships between all variables\n")

    importance_matrix = calculate_mutual_importance(
        data1,
        target_col=None,  # Include ALL columns including 'target'
        n_splits=3,  # Use fewer splits for faster computation
        random_state=42
    )

    # Save results
    os.makedirs(results_dir, exist_ok=True)
    importance_matrix.to_csv(f'{results_dir}/importance_matrix.csv')
    print(f"\n✓ Importance matrix saved to '{results_dir}/importance_matrix.csv'")

    # Visualize heatmap
    print("\n" + "="*70)
    print("CREATING VISUALIZATIONS")
    print("="*70 + "\n")

    fig1, ax1 = visualize_importance_heatmap(importance_matrix)
    fig1.savefig(f'{results_dir}/importance_heatmap.png', dpi=150, bbox_inches='tight')
    print(f"✓ Heatmap saved to '{results_dir}/importance_heatmap.png'")

    # Visualize top relationships
    fig2, ax2, top_rels = visualize_top_relationships(importance_matrix, top_n=30)
    fig2.savefig(f'{results_dir}/top_relationships.png', dpi=150, bbox_inches='tight')
    print(f"✓ Top relationships saved to '{results_dir}/top_relationships.png'")

    # Analyze connectivity
    connectivity = analyze_feature_connectivity(importance_matrix)
    connectivity.to_csv(f'{results_dir}/feature_connectivity.csv', index=False)
    print(f"✓ Connectivity analysis saved to '{results_dir}/feature_connectivity.csv'")

    # Summary statistics over feature pairs only. The diagonal (a feature
    # "predicting itself") is never fitted and stays 0, so including it would
    # bias the mean and std downward.
    matrix_values = importance_matrix.to_numpy()
    off_diagonal = matrix_values[~np.eye(matrix_values.shape[0], dtype=bool)]

    print("\n" + "="*70)
    print("SUMMARY STATISTICS")
    print("="*70)
    if off_diagonal.size:
        print(f"\nMean importance (off-diagonal): {off_diagonal.mean():.4f}")
        print(f"Max importance (off-diagonal): {off_diagonal.max():.4f}")
        print(f"Std importance (off-diagonal): {off_diagonal.std():.4f}")
    else:
        print("\nNo off-diagonal entries (fewer than two features).")

    print("\n" + "="*70)
    print("TOP 10 MOST CONNECTED FEATURES")
    print("="*70)
    print(connectivity.head(10).to_string(index=False))

    print("\n" + "="*70)
    print("TOP 10 FEATURE RELATIONSHIPS")
    print("="*70)
    print(top_rels.head(10).to_string(index=False))

    plt.show()

### **Based on the results above, we are planning a different strategy: to improve the accuracy of Species prediction, we will use the predicted values of Pre_GSHH_NDVI and Height_Ave_cm rather than the images themselves.**
<br>

#### **Images → Pre_GSHH_NDVI, Height_Ave_cm**
####         ↓
#### **Tabular (with Pre_GSHH_NDVI, Height_Ave_cm) → Species**
####         ↓
#### **Tabular (with Species, Pre_GSHH_NDVI, Height_Ave_cm) → target**