# Ablation Study Setup

In this study, we investigate the sensitivity of our method to various hyperparameters. We maintain the same experimental setup as our main simulations, with a $d$-dimensional domain $W \subset \mathbb{R}^d$ and data set $\mathcal{D}_n := \{(X_1, Y_1), \ldots, (X_n,Y_n)\}$ consisting of i.i.d. samples from $(X,Y) \in W \times \mathbb{R}$ where $Y = f(X) + \varepsilon$.

We examine two key hyperparameters of our method:

1. **Number of Trees ($T$)**: The size of the Mondrian forest ensemble.
   - Tested values: $T \in \{5, 10, 25, 50\}$
   - Fixed parameters: $\lambda = 3$, step size $t = 0.1$

2. **Step Size ($t$)**: The parameter used in the finite difference method for gradient estimation.
   - Tested values: $t \in \{0.025, 0.1, 0.4, 1\}$
   - Fixed parameters: $\lambda = 3$, $T = 10$ trees

For each parameter configuration, we:
- Use the sample size $n = 3200$ and test size $m = 3200$
- Perform $R = 10$ repeated experiments to assess stability
- Measure three key metrics:
  1. Angle distance between estimated $\hat{H}$ and true $H$
  2. Mean Squared Error (MSE) on test set
  3. Computation time

All experiments are conducted across four scenarios that combine:
- Two different rotation matrices $A$: simple and orthogonal
- Two different link functions $g$: polynomial ($g(x) = x_1^4 + x_2^4$) and exponential ($g(x) = \exp(-0.25 \cdot \min(x_1^2, x_2^2))$)

The remaining parameters are kept constant:
- Total dimension: $d = 5$
- Active dimension: $s = 2$
- Each $X_i \sim U[-1,1]^d$
- Noise: $\varepsilon \sim \mathcal{N}(0, \sigma^2 = 0.01)$

This comprehensive ablation study allows us to:
1. Identify optimal parameter ranges for different scenarios
2. Understand the tradeoff between computational cost and accuracy
3. Assess the robustness of our method to parameter choices
4. Provide practical guidelines for parameter selection in real applications

In [None]:
#!/usr/bin/env python
# coding: utf-8

import numpy as np
import matplotlib.pyplot as plt
from utils import *
import time
from typing import List, Dict, Any
from tqdm import tqdm
import os

class AblationStudy:
    """Sweep Mondrian-forest hyperparameters on a fixed train/test split.

    Each configuration is trained through ``Simulation`` (imported from
    ``utils``) at a single training size. It is summarised by three numbers:
    - the mean and std of the angle distance between each trial forest's
      estimated ``H`` and ``true_H``;
    - the test MSE reported by ``Simulation``;
    - the elapsed time.

    Args:
        dim_in: Input dimension, forwarded to ``Simulation``.
        active: Number of active directions, forwarded to ``Simulation`` and
            ``get_angle_distance``.
        tries: Number of repeated trials per configuration.
        x_train, x_test, y_train, y_test: Data forwarded unchanged to
            ``Simulation``.
        true_H: Reference matrix the estimated ``H`` is compared against.
        sample_size: Training size used for every run. It is the only entry
            of the sample-size list handed to ``Simulation``. Defaults to
            3200, the value previously hard-coded.
    """

    def __init__(self, dim_in, active, tries, x_train, x_test, y_train, y_test, true_H,
                 sample_size=3200):
        self.dim_in = dim_in
        self.active = active
        self.tries = tries
        self.x_train = x_train
        self.x_test = x_test
        self.y_train = y_train
        self.y_test = y_test
        self.true_H = true_H
        # One value drives both the Simulation sample-size list and the key
        # used to look up the trained forests, so the two cannot drift apart.
        self.sample_size = sample_size

    def _get_distances(self, forests, sample_size):
        """Return the angle distance to ``true_H`` for every trial's forest.

        ``forests[sample_size][trial]`` must expose its estimated matrix as ``.H``.
        """
        return [
            get_angle_distance(forests[sample_size][trial].H, self.true_H, self.active)
            for trial in range(self.tries)
        ]

    def _run_simulation(self, n_estimators, lifetime, step_size) -> Dict:
        """Train and evaluate one configuration and return its summary metrics.

        Returns:
            A dict with these keys:
            - ``distance``: mean angle distance over trials.
            - ``std``: population standard deviation of those distances.
            - ``computation_time``: seconds spent on construction, training,
              evaluation and distance computation.
            - ``mse``: ``None`` when ``Simulation`` produced no evaluation
              results.
        """
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()

        simulation = Simulation(
            self.dim_in, self.active, n_estimators, self.tries,
            lifetime, step_size, [self.sample_size],
            self.x_train, self.x_test, self.y_train, self.y_test, self.true_H,
        )
        simulation.train()
        simulation.evaluation_comparison()

        distances = self._get_distances(simulation.forests, self.sample_size)
        computation_time = time.perf_counter() - start_time

        # NOTE(review): the [2][0]['mse'][-1] path depends on the layout of
        # Simulation.evaluation_results in utils. Confirm which method entry
        # index 2 refers to.
        mse = (
            simulation.evaluation_results[2][0]['mse'][-1]
            if simulation.evaluation_results else None
        )

        return {
            'distance': np.mean(distances),
            'std': np.std(distances),
            'computation_time': computation_time,
            'mse': mse,
        }

    def _sweep(self, param_name, values, desc, run):
        """Call ``run(value)`` for each value and tag the metrics with the parameter."""
        return [
            {'param_name': param_name, 'param_value': value, **run(value)}
            for value in tqdm(values, desc=desc)
        ]

    def study_n_estimators(self, lifetime=3, step_size=0.2, n_estimators_range=(5, 10, 25, 50)):
        """Sweep the number of trees while lifetime and step size stay fixed.

        Returns one record per tree count. Each record holds ``param_name``
        and ``param_value``, merged with the metrics from ``_run_simulation``.
        """
        return self._sweep(
            'n_estimators', n_estimators_range, "Studying number of trees",
            lambda n_est: self._run_simulation(n_est, lifetime, step_size),
        )

    def study_step_size(self, lifetime=3, n_estimators=10, step_sizes=(0.05, 0.2, 0.8, 2)):
        """Sweep the finite-difference step size while lifetime and tree count stay fixed.

        Returns one record per step size, in the same format as
        ``study_n_estimators``.
        """
        return self._sweep(
            'step_size', step_sizes, "Studying step size",
            lambda step: self._run_simulation(n_estimators, lifetime, step),
        )

def _style_ablation_axis(ax, xlabel, ylabel, title):
    """Apply the shared labels, title, grid and tick sizing to one ablation panel."""
    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_ylabel(ylabel, fontsize=15)
    ax.set_title(title)
    ax.grid(True)
    ax.tick_params(labelsize=12)


def plot_ablation_results(results: List[Dict[str, Any]], save_dir: str):
    """Plot one ablation sweep and save it as ``<save_dir>/ablation_<param_name>.png``.

    Draws side-by-side panels against the swept parameter:
    - mean angle distance to the true ``H``, with std error bars;
    - computation time;
    - test MSE, only when at least one run reported it.

    Args:
        results: Records as produced by ``AblationStudy.study_*``. Each record
            holds ``param_name``, ``param_value``, ``distance``, ``std``,
            ``computation_time`` and ``mse`` (possibly ``None``). The first
            record's ``param_name`` labels the x axis and the file name.
        save_dir: Existing directory that receives the PNG.

    Raises:
        ValueError: If ``results`` is empty.
    """
    # Local import: ScalarFormatter was previously used without a visible import.
    from matplotlib.ticker import ScalarFormatter

    if not results:
        raise ValueError("results must contain at least one record")

    param_name = results[0]['param_name']
    param_values = [r['param_value'] for r in results]
    if param_name == 'step_size':
        # Step sizes are displayed at half their raw value, e.g. the default
        # 0.2 is plotted as 0.1. This used to be triggered only when 0.05 was
        # among the values; keying on the parameter name rescales custom
        # sweeps consistently.
        param_values = [p / 2 for p in param_values]

    distances = [r['distance'] for r in results]
    stds = [r['std'] for r in results]
    times = [r['computation_time'] for r in results]
    # Keep each MSE paired with its own x value so that runs without an MSE
    # neither shift the remaining points nor cause a length mismatch.
    mse_points = [(x, r['mse']) for x, r in zip(param_values, results) if r['mse'] is not None]

    n_plots = 3 if mse_points else 2
    fig, axes = plt.subplots(1, n_plots, figsize=(5 * n_plots, 6))

    axes[0].errorbar(param_values, distances, yerr=stds, marker='o', capsize=5)
    _style_ablation_axis(axes[0], param_name, 'Angle Distance from True H',
                         'Impact on Distance to True H')

    axes[1].plot(param_values, times, marker='o')
    _style_ablation_axis(axes[1], param_name, 'Computation Time (seconds)',
                         'Impact on Computation Time')

    if mse_points:
        mse_x = [x for x, _ in mse_points]
        mse_y = [m for _, m in mse_points]
        axes[2].plot(mse_x, mse_y, marker='o')
        _style_ablation_axis(axes[2], param_name, 'MSE', 'Impact on MSE')
        # Force scientific notation without an offset; MSE values are small.
        y_formatter = ScalarFormatter(useOffset=False)
        y_formatter.set_scientific(True)
        y_formatter.set_powerlimits((-1, 1))
        axes[2].yaxis.set_major_formatter(y_formatter)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"ablation_{param_name}.png"))
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)

def run_all_scenarios():
    """Run both hyperparameter sweeps for every rotation/link-function scenario.

    All scenarios share one seeded pair of designs drawn uniformly from
    [-1, 1]^5: 3200 training rows and 3200 test rows. For each scenario:
    - labels and ``true_H`` come from ``SimulatedData``;
    - both sweeps are run and plotted;
    - the raw records are written with ``np.save`` into
      ``simulation/ablation_<scenario>/``.
    """
    dim_in, active, tries = 5, 2, 10
    np.random.seed(0)

    # Training rows are drawn before test rows, so the seeded stream is
    # consumed in the same order on every run.
    x_train = np.random.rand(3200, dim_in) * 2 - 1
    x_test = np.random.rand(3200, dim_in) * 2 - 1

    # Each entry is (directory suffix, rotation, link function).
    specs = [
        ("simple_poly", "simple", "poly"),
        ("simple_max", "simple", "max"),
        ("orth_poly", "orth", "poly"),
        ("orth_max", "orth", "max"),
    ]
    scenarios = [
        (tag, SimulatedData(dim_in, active, rotation=rotation, fun=link))
        for tag, rotation, link in specs
    ]

    for tag, generator in scenarios:
        print(f"\nRunning ablation study for scenario: {tag}")

        batched_fun = vmap(generator.fun, in_axes=0, out_axes=0)
        y_train = batched_fun(x_train)
        y_test = batched_fun(x_test)
        true_H = generator.get_true_H(x_test)

        study = AblationStudy(dim_in, active, tries, x_train, x_test, y_train, y_test, true_H)

        out_dir = f"simulation/ablation_{tag}"
        os.makedirs(out_dir, exist_ok=True)

        # Run each sweep, plot it immediately, and collect the records in
        # the order they are saved.
        collected = {}
        sweeps = (
            ("n_estimators", study.study_n_estimators),
            ("step_size", study.study_step_size),
        )
        for key, sweep in sweeps:
            collected[key] = sweep()
            plot_ablation_results(collected[key], out_dir)

        np.save(f"{out_dir}/numerical_results.npy", collected)

#if __name__ == "__main__":
#    run_all_scenarios()