In [1]:
from nirs4all.data.predictions import Predictions
import os

# Base names of the prediction files to analyze, grouped by name prefix.
# Each one is read later as "<WORKSPACE_PATH>/<name>.meta.parquet".
_DIGESTIBILITY_FILES = (
    "digestibility_custom2",
    "digestibility_custom3",
    "digestibility_custom5",
    "digestibility_0_8",
)
_HARDNESS_FILES = (
    "hardness_custom2",
    "hardness_custom4",
    "hardness_0_8",
)
_TANNIN_FILES = (
    "tannin_custom2",
    "tannin_custom3",
    "tannin_0_8",
)

# Flattened in the same order as the groups above
filenames = [*_DIGESTIBILITY_FILES, *_HARDNESS_FILES, *_TANNIN_FILES]

# Disabled merge step, kept for reference. It gathers every existing
# "<folder>/<name>.meta.parquet" and merges them into "wk/<name>.meta.parquet".
#
# folders = ["wk_src", "wk_tabpfn", "wk_denis", "wk_denis_0", "wk_local"]
#
# for file in filenames:
#     input_files = [
#         f"{folder}/{file}.meta.parquet"
#         for folder in folders
#         if os.path.exists(f"{folder}/{file}.meta.parquet")
#     ]
#     Predictions.merge_parquet_files(
#         input_files=input_files,
#         output_file=f"wk/{file}.meta.parquet",
#         deduplicate=True,  # Remove duplicate prediction IDs (default)
#     )

# Folder that holds the ".meta.parquet" prediction files
WORKSPACE_PATH = "wk"


In [2]:
from pathlib import Path
import matplotlib.pyplot as plt
import os

from nirs4all.data.predictions import Predictions
from nirs4all.visualization.charts import ChartConfig
from nirs4all.visualization.predictions import PredictionAnalyzer

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
from nirs4all.data import Predictions
from nirs4all.visualization.predictions import PredictionAnalyzer



# Where to save charts
CHARTS_OUTPUT_DIR = "charts"

def save_figure(fig, name, prefix=None):
    """Save a figure as ``<prefix>_<name>.png`` inside ``CHARTS_OUTPUT_DIR``.

    Args:
        fig: Object exposing ``savefig(path, dpi=..., bbox_inches=...)``
            (a matplotlib figure in this notebook).
        name: Chart-specific part of the output file name.
        prefix: Leading part of the output file name. When omitted, the
            notebook-level ``filename`` variable is used, so existing
            two-argument calls made from the dataset loop keep working.

    Raises:
        NameError: If ``prefix`` is omitted and no global ``filename`` exists.
    """
    if prefix is None:
        # Backward-compatible fallback to the notebook global set by the
        # dataset loop; pass ``prefix`` explicitly to avoid this hidden state.
        prefix = filename  # noqa: F821

    output_dir = Path(CHARTS_OUTPUT_DIR)
    # parents=True so a nested CHARTS_OUTPUT_DIR (e.g. "out/charts") also works
    output_dir.mkdir(parents=True, exist_ok=True)

    # Only PNG is written today; add extensions here to export more formats
    for fmt in ['png']:
        path = output_dir / f"{prefix}_{name}.{fmt}"
        fig.savefig(path, dpi=150, bbox_inches='tight')
        print(f"  üìÅ Saved: {path}")

def _title_and_save(fig, title, chart_name):
    """Set `title` as the suptitle of the current pyplot figure, then save `fig`."""
    plt.suptitle(title, y=1.02)
    save_figure(fig, chart_name)


for filename in filenames:
    predictions_path = f"{WORKSPACE_PATH}/{filename}.meta.parquet"
    print(predictions_path)

    # Read the stored predictions for this file
    predictions = Predictions.load(path=predictions_path)
    print(f"\nüìÇ Loaded {len(predictions)} predictions from {filename}")

    # Nothing to chart when the file holds no predictions
    if not len(predictions):
        print(f"  ‚ö†Ô∏è Skipping {filename} - no predictions found")
        continue

    # Author's note: Predictions has no filter method returning a new
    # Predictions object, so excluding models would have to happen at the
    # chart/top() level through **filters. Everything is used as-is for now.
    print(f"  ‚û°Ô∏è Using all {len(predictions)} predictions (model exclusion not implemented)")

    # One analyzer per predictions file
    # (author's hint: pass show_figures=True to auto-display figures)
    analyzer = PredictionAnalyzer(predictions, output_dir=CHARTS_OUTPUT_DIR)

    # Dataset name: first dataset reported by the predictions, else the file name
    datasets = predictions.get_datasets()
    if datasets:
        dataset_name = datasets[0]
    else:
        dataset_name = filename
    file_name = Path(filename).stem

    # Metric family: regression metrics when an rmse-ranked top() query yields
    # a truthy result, classification metrics otherwise.
    sample_pred = predictions.top(n=1, rank_metric='rmse', display_metrics=['rmse'], display_partition='test')
    regression_setup = ('rmse', ['rmse', 'mse', 'mae', 'r2'])
    classification_setup = ('balanced_accuracy', ['balanced_accuracy', 'accuracy', 'f1'])
    rank_metric, display_metrics = regression_setup if sample_pred else classification_setup

    # Five best models ranked on the validation partition, aggregated by "ID"
    print(f"\nüìä Top 5 models for {file_name}:")
    top_models = predictions.top(n=5, rank_metric=rank_metric, rank_partition='val', aggregate="ID")
    for i, model in enumerate(top_models, start=1):
        val_scores = model.get('partitions', {}).get('val', {})
        score = val_scores.get(rank_metric, 'N/A')
        model_name = model.get('model_name', 'Unknown')
        if isinstance(score, float):
            print(f"  {i}. {model_name}: {rank_metric}={score:.4f}")
        else:
            print(f"  {i}. {model_name}: {rank_metric}={score}")

    # Shared heatmap config. Only the annotation size (text inside heatmap
    # cells) is overridden; the author left fig/title/legend/axis/tick font
    # size keys disabled.
    hm_config = {'annotation_fontsize': 28}

    # ============================
    # CHARTS
    # ============================
    # Note: the author disabled several variants that used to be generated
    # here (top-3 ranked on test, non-aggregated heatmaps, candlestick and
    # histogram). Only the variants below are produced.
    print(f"\n=== Charts for {file_name} ===")

    # --- Confusion matrix: ranked on test, then ranked on val ---
    print("\nüî≤ Confusion matrix:")
    for cm_rank_partition, cm_chart_name in (
        ('test', f"confusion_matrix_{rank_metric}"),
        ('val', f"confusion_matrix_{rank_metric}_rank_val"),
    ):
        fig_cm = analyzer.plot_confusion_matrix(
            rank_metric=rank_metric,
            display_metric=rank_metric,
            display_partition='test',
            rank_partition=cm_rank_partition,
            aggregate="ID",
        )
        _title_and_save(fig_cm, f"{file_name} - Confusion Matrix", cm_chart_name)

    # --- Top-K comparison: top 3 ranked on val, displayed on test ---
    print("\nüìà Top-K comparisons:")
    fig_top3 = analyzer.plot_top_k(
        k=3,
        rank_metric=rank_metric,
        rank_partition='val',
        display_partition='test',
        aggregate="ID",
    )
    _title_and_save(fig_top3, f"{file_name} - Top 3 Models on Val", "top3_test_rank_val")

    # --- Aggregated heatmaps: (rank partition, top_k, title, chart name) ---
    heatmap_variants = (
        ('val', 20,
         f"{file_name} - Heatmap (ranked by val) - Agg reps",
         f"heatmap_{rank_metric}_rank_val_agg"),
        ('test', 8,
         f"{file_name} - Heatmap (ranked by test) - Agg reps",
         f"heatmap_{rank_metric}_rank_test_agg"),
    )
    for hm_rank_partition, hm_top_k, hm_title, hm_chart_name in heatmap_variants:
        fig_heatmap = analyzer.plot_heatmap(
            x_var="partition",
            y_var="model_name",
            rank_metric=rank_metric,
            display_metric=rank_metric,
            show_counts=False,
            rank_partition=hm_rank_partition,
            aggregate="ID",
            column_scale=True,
            top_k=hm_top_k,
            sort_by='value',
            config=hm_config,
        )
        _title_and_save(fig_heatmap, hm_title, hm_chart_name)

    # --- Candlestick plot (aggregated) ---
    print("\nüìä Candlestick plots:")
    fig_candlestick = analyzer.plot_candlestick(
        variable="model_classname",
        display_metric=rank_metric,
        display_partition='test',
        aggregate="ID",
    )
    _title_and_save(
        fig_candlestick,
        f"{file_name} - Score Distribution by Model - Agg reps",
        f"candlestick_{rank_metric}_test_agg",
    )

    # --- Histogram (aggregated); no section banner is printed for it ---
    fig_histogram = analyzer.plot_histogram(
        display_metric=rank_metric,
        display_partition='test',
        aggregate="ID",
    )
    _title_and_save(
        fig_histogram,
        f"{file_name} - Score Distribution - Agg reps",
        f"histogram_{rank_metric}_test_agg",
    )

    # Add more dataset comparisons here...
    print(f"\n‚úÖ All charts generated for {file_name}")

print("\nüéâ All datasets processed!")

wk/digestibility_custom2.meta.parquet

 Loaded 0 predictions from digestibility_custom2
   Skipping digestibility_custom2 - no predictions found
wk/digestibility_custom3.meta.parquet

 Loaded 0 predictions from digestibility_custom3
   Skipping digestibility_custom3 - no predictions found
wk/digestibility_custom5.meta.parquet

 Loaded 0 predictions from digestibility_custom5
   Skipping digestibility_custom5 - no predictions found
wk/digestibility_0_8.meta.parquet

 Loaded 0 predictions from digestibility_0_8
   Skipping digestibility_0_8 - no predictions found
wk/hardness_custom2.meta.parquet

 Loaded 0 predictions from hardness_custom2
   Skipping hardness_custom2 - no predictions found
wk/hardness_custom4.meta.parquet

 Loaded 0 predictions from hardness_custom4
   Skipping hardness_custom4 - no predictions found
wk/hardness_0_8.meta.parquet

 Loaded 0 predictions from hardness_0_8
   Skipping hardness_0_8 - no predictions found
wk/tannin_custom2.meta.parquet

 Loaded 0 predictions 