In [1]:
# Prediction runs to analyze. Each selected name is read from
# ``<WORKSPACE_PATH>/<name>.meta.parquet``; flip a flag to include a run.
_RUN_SELECTION = {
    "digestibility_custom2": False,
    "digestibility_custom3": False,
    "digestibility_custom5": False,
    "digestibility_0_8": True,
    "hardness_custom2": False,
    "hardness_custom4": False,
    "hardness_0_8": False,
    "tannin_custom2": False,
    "tannin_custom3": False,
    "tannin_0_8": False,
}
filenames = [run for run, selected in _RUN_SELECTION.items() if selected]

# Optional one-off step: merge copies of each run found in several workspace
# sub-folders into a single workspace/<run>.meta.parquet. It needs
# ``Predictions`` from the imports cell, so run it only after that cell.
#
# for file in filenames:
#     Predictions.merge_parquet_files(
#         input_files=[f"workspace/{folder}/{file}.meta.parquet"
#                      for folder in ("local", "_denis", "_wsl")],
#         output_file=f"workspace/{file}.meta.parquet",
#         deduplicate=True,  # Remove duplicate prediction IDs (default)
#     )

WORKSPACE_PATH = "workspace"

In [2]:
from pathlib import Path
import matplotlib.pyplot as plt
import os

from nirs4all.data.predictions import Predictions
from nirs4all.visualization.charts import ChartConfig
from nirs4all.visualization.predictions import PredictionAnalyzer

In [3]:
def analyze_predictions_file(predictions_path: str, save_dir: str = "charts", exclude_models: list = None):
    """Analyze a single predictions parquet file with all visualizations.

    Args:
        predictions_path: Path to the predictions parquet file.
        save_dir: Directory to save chart images.
        exclude_models: List of model names to exclude from analysis.
    """
    import polars as pl

    # Check if file exists
    if not Path(predictions_path).exists():
        print(f"‚ö†Ô∏è  File not found: {predictions_path}")
        return None

    # Load predictions
    predictions = Predictions()
    predictions.load_from_file(predictions_path)

    # Exclude specified models by filtering the internal DataFrame
    if exclude_models:
        original_count = len(predictions)
        df = predictions._storage._df
        predictions._storage._df = df.filter(~pl.col("model_name").is_in(exclude_models))
        excluded_count = original_count - len(predictions)
        if excluded_count > 0:
            print(f"üö´ Excluded {excluded_count} predictions from models: {exclude_models}")

    file_name = Path(predictions_path).stem.replace('.meta', '')

    # Create output directory for this file
    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*80}")
    print(f"üìä Analyzing: {file_name}")
    print(f"üìÅ Saving charts to: {output_dir}")
    print(f"{'='*80}")

    def save_figure(fig, name: str):
        """Helper to save figure with meaningful name."""
        filepath = output_dir / f"{file_name}_{name}.png"
        fig.savefig(filepath, dpi=150, bbox_inches='tight', facecolor='white')
        print(f"üíæ Saved: {filepath.name}")
        plt.close(fig)

    # =========================================================================
    # SYNTHETIC DATAVIZ: Summary statistics
    # =========================================================================
    print(f"\nüìà Summary Statistics:")
    print(f"   ‚Ä¢ Number of predictions: {len(predictions)}")
    print(f"   ‚Ä¢ Models: {len(predictions.get_models())} unique ({', '.join(predictions.get_models()[:5])}{'...' if len(predictions.get_models()) > 5 else ''})")
    print(f"   ‚Ä¢ Datasets: {predictions.get_datasets()}")
    print(f"   ‚Ä¢ Partitions: {predictions.get_partitions()}")
    print(f"   ‚Ä¢ Configs: {len(predictions.get_configs())} unique")
    print(f"   ‚Ä¢ Folds: {predictions.get_folds()}")

    # Determine task type
    sample_preds = predictions.filter_predictions(partition='test', load_arrays=False)
    if sample_preds:
        task_type = sample_preds[0].get('task_type', 'regression')
        is_classification = 'classif' in task_type.lower()
    else:
        task_type = 'regression'
        is_classification = False

    print(f"   ‚Ä¢ Task type: {task_type}")

    # Select metrics based on task type
    if is_classification:
        rank_metric = 'balanced_accuracy'
        display_metrics = ['accuracy', 'balanced_accuracy', 'f1']
    else:
        rank_metric = 'rmse'
        display_metrics = ['rmse', 'r2', 'mae']

    # =========================================================================
    # TOP MODELS
    # =========================================================================
    print(f"\nüèÜ Top 5 models by {rank_metric} (val):")
    top_models = predictions.top(n=5, rank_metric=rank_metric, rank_partition='val')
    for idx, model in enumerate(top_models, 1):
        summary = Predictions.pred_short_string(model, metrics=display_metrics, partition=['val', 'test'])
        print(f"   {idx}. {summary}")

    print(f"\nüèÜ Top 5 models by {rank_metric} (val) - Agg reps:")
    top_models = predictions.top(n=5, rank_metric=rank_metric, rank_partition='val', aggregate="ID")
    for idx, model in enumerate(top_models, 1):
        summary = Predictions.pred_short_string(model, metrics=display_metrics, partition=['val', 'test'])
        print(f"   {idx}. {summary}")

    # =========================================================================
    # VISUALIZATIONS
    # =========================================================================
    analyzer = PredictionAnalyzer(predictions, output_dir=None)

    # --- Classification: Top 6 Confusion Matrices (val and test) ---
    if is_classification:
        print(f"\nüìä Confusion Matrices (Top 6 by {rank_metric}):")

        cm_config = ChartConfig(
            # title_fontsize=24,        # Chart titles
            # label_fontsize=14,        # Axis labels (X/Y labels)
            # tick_fontsize=12,         # Tick labels on axes
            # legend_fontsize=11,       # Legend text
            annotation_fontsize=32    # Text inside charts (heatmap cells, etc.)
        )

        # Confusion matrices ranked by test, displayed on test, aggregated
        fig_cm = analyzer.plot_confusion_matrix(
            k=2,
            rank_metric='balanced_accuracy',
            rank_partition='test',
            display_partition='test',
            display_metric=['balanced_accuracy', 'accuracy'],
            aggregate="ID",
            config=cm_config
        )
        plt.suptitle(f"{file_name} - Confusion Matrices (ranked by test, display test) - Agg reps", y=1.02)
        save_figure(fig_cm, "confusion_matrix_rank_test_display_test_agg")

        # Confusion matrices ranked by val, displayed on test, aggregated
        fig_cm = analyzer.plot_confusion_matrix(
            k=6,
            rank_metric='balanced_accuracy',
            rank_partition='val',
            display_partition='test',
            display_metric=['balanced_accuracy', 'accuracy'],
            aggregate="ID",
            config=cm_config
        )
        plt.suptitle(f"{file_name} - Confusion Matrices (ranked by val, display test) - Agg reps", y=1.02)
        save_figure(fig_cm, "confusion_matrix_rank_val_display_test_agg")

        # Confusion matrices ranked by test, displayed on test
        fig_cm = analyzer.plot_confusion_matrix(
            k=2,
            rank_metric='balanced_accuracy',
            rank_partition='test',
            display_partition='test',
            display_metric=['balanced_accuracy', 'accuracy'],
            config=cm_config
        )
        plt.suptitle(f"{file_name} - Confusion Matrices (ranked by test, display test)", y=1.02)
        save_figure(fig_cm, "confusion_matrix_rank_test_display_test")

        # Confusion matrices ranked by val, displayed on test
        fig_cm = analyzer.plot_confusion_matrix(
            k=6,
            rank_metric='balanced_accuracy',
            rank_partition='val',
            display_partition='test',
            display_metric=['balanced_accuracy', 'accuracy'],
            config=cm_config
        )
        plt.suptitle(f"{file_name} - Confusion Matrices (ranked by val, display test)", y=1.02)
        save_figure(fig_cm, "confusion_matrix_rank_val_display_test")

    # --- Regression: Top 3 ---
    else:
        print(f"\nüìä Top 3 Model Comparison:")
        fig_top3 = analyzer.plot_top_k(
            k=3,
            rank_metric='rmse',
            rank_partition='val',
            aggregate="ID"
        )
        plt.suptitle(f"{file_name} - Top 3 Models (ranked by val) - Agg reps", y=1.02)
        save_figure(fig_top3, "top3_models_rank_val_agg")

        fig_top3 = analyzer.plot_top_k(
            k=3,
            rank_metric='rmse',
            rank_partition='val',
        )
        plt.suptitle(f"{file_name} - Top 3 Models (ranked by val)", y=1.02)
        save_figure(fig_top3, "top3_models_rank_val")




    # --- Heatmap ranked by VAL ---
    hm_config = ChartConfig(
        # title_fontsize=18,        # Chart titles
        # label_fontsize=14,        # Axis labels (X/Y labels)
        # tick_fontsize=12,         # Tick labels on axes
        # legend_fontsize=11,       # Legend text
        annotation_fontsize=15    # Text inside charts (heatmap cells, etc.)
    )


    print(f"\nüó∫Ô∏è  Heatmaps:")
    fig_heatmap = analyzer.plot_heatmap(
        x_var="partition",
        y_var="model_name",
        rank_metric=rank_metric,
        display_metric=rank_metric,
        show_counts=False,
        rank_partition='val',
        column_scale = True,
        top_k=20,
        sort_by='borda',
        config=hm_config
    )
    plt.suptitle(f"{file_name} - Heatmap (ranked by val)", y=1.02)
    save_figure(fig_heatmap, f"heatmap_{rank_metric}_rank_val")

    # --- Heatmap ranked by TEST ---
    fig_heatmap = analyzer.plot_heatmap(
        x_var="partition",
        y_var="model_name",
        rank_metric=rank_metric,
        display_metric=rank_metric,
        show_counts=False,
        rank_partition='test',
        column_scale = True,
        top_k=4,
        sort_by_value=True,
        config=hm_config
    )
    plt.suptitle(f"{file_name} - Heatmap (ranked by test)", y=1.02)
    save_figure(fig_heatmap, f"heatmap_{rank_metric}_rank_test")

    # --- Heatmap ranked by VAL - Aggregated ---
    fig_heatmap = analyzer.plot_heatmap(
        x_var="partition",
        y_var="model_name",
        rank_metric=rank_metric,
        display_metric=rank_metric,
        show_counts=False,
        rank_partition='val',
        aggregate="ID",
        column_scale = True,
        top_k=20,
        sort_by='borda',
        config=hm_config
    )
    plt.suptitle(f"{file_name} - Heatmap (ranked by val) - Agg reps", y=1.02)
    save_figure(fig_heatmap, f"heatmap_{rank_metric}_rank_val_agg")

    # --- Heatmap ranked by TEST - Aggregated ---
    fig_heatmap = analyzer.plot_heatmap(
        x_var="partition",
        y_var="model_name",
        rank_metric=rank_metric,
        display_metric=rank_metric,
        show_counts=False,
        rank_partition='test',
        aggregate="ID",
        column_scale = True,
        top_k=4,
        sort_by_value=True,
        config=hm_config
    )
    plt.suptitle(f"{file_name} - Heatmap (ranked by test) - Agg reps", y=1.02)
    save_figure(fig_heatmap, f"heatmap_{rank_metric}_rank_test_agg")

    # --- Candlestick plots ---
    print(f"\nüìä Candlestick plots:")
    # fig_candlestick = analyzer.plot_candlestick(
    #     variable="model_classname",
    #     display_metric=rank_metric,
    #     display_partition='test'
    # )
    # plt.suptitle(f"{file_name} - Score Distribution by Model", y=1.02)
    # save_figure(fig_candlestick, f"candlestick_{rank_metric}_test")

    fig_candlestick = analyzer.plot_candlestick(
        variable="model_classname",
        display_metric=rank_metric,
        display_partition='test',
        aggregate="ID"
    )
    plt.suptitle(f"{file_name} - Score Distribution by Model - Agg reps", y=1.02)
    save_figure(fig_candlestick, f"candlestick_{rank_metric}_test_agg")

    # --- Histograms ---
    # print(f"\nüìä Histograms:")
    # fig_histogram = analyzer.plot_histogram(
    #     display_metric=rank_metric,
    #     display_partition='test'
    # )
    # plt.suptitle(f"{file_name} - Score Histogram", y=1.02)
    # save_figure(fig_histogram, f"histogram_{rank_metric}_test")

    fig_histogram = analyzer.plot_histogram(
        display_metric=rank_metric,
        display_partition='test',
        aggregate="ID"
    )
    plt.suptitle(f"{file_name} - Score Histogram - Agg reps", y=1.02)
    save_figure(fig_histogram, f"histogram_{rank_metric}_test_agg")

    print(f"\n‚úÖ All charts saved to: {output_dir}")
    return predictions

In [4]:
# Process every selected predictions file; keep the loaded Predictions per run name.
results = {}

# Output directory for all charts
CHARTS_OUTPUT_DIR = "charts"
Path(CHARTS_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Models to exclude from analysis (e.g., buggy models)
EXCLUDE_MODELS = ["KernelPLS"]

for filename in filenames:
    # Build the path portably; pass it on as str, as the analysis helper expects.
    predictions_path = str(Path(WORKSPACE_PATH) / f"{filename}.meta.parquet")
    result = analyze_predictions_file(
        predictions_path,
        save_dir=CHARTS_OUTPUT_DIR,
        exclude_models=EXCLUDE_MODELS
    )
    if result is not None:  # None means the file was missing
        results[filename] = result

separator = '=' * 80
print(f"\n{separator}")
print(f"✅ Analysis complete! Processed {len(results)}/{len(filenames)} files.")
print(f"📁 All charts saved to: {CHARTS_OUTPUT_DIR}/")
print(separator)

 Excluded 15 predictions from models: ['KernelPLS']

 Analyzing: digestibility_0_8
 Saving charts to: charts

 Summary Statistics:
    Number of predictions: 3585
    Models: 40 unique (opls_2_6, FCKPLS, ridge, opls_1_5, pls_14...)
    Datasets: ['digestibility_0_8']
    Partitions: ['val', 'test', 'train']
    Configs: 54 unique
    Folds: ['1', 'avg', '2', '0', 'w_avg']
    Task type: regression

 Top 5 models by rmse (val):
   1. LWPLS_5_components - mse [test: 218.0761], [val: 31.6120],  [val]: [rmse:14.7674], [r2:0.2427], [mae:11.8765] [test]: [rmse:14.7674], [r2:0.2427], [mae:11.8765], (fold: w_avg, id: 5, step: 8) - [3d0428a77da45101]
   2. LWPLS_5_components - mse [test: 218.0648], [val: 31.8298],  [val]: [rmse:14.7670], [r2:0.2428], [mae:11.8784] [test]: [rmse:14.7670], [r2:0.2428], [mae:11.8784], (fold: avg, id: 4, step: 8) - [1484c8c70075df4d]
   3. LWPLS_5_components - mse [test: 206.1026], [val: 32.1629],  [val]: [rmse:14.3563], [r2:0.2843], [mae:11.4103] [test]: [rmse:14.