In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
import math

from result_analysis.helper_functions import (
  process_csv_data,
  post_process_general_instances,
  analyze_learning_curves,
  add_seed_magnitude_column,
  add_generator_iter_column,
  add_seed_index_column
)

In [None]:
# Load the feature CSV and enrich it in a single expression. The helpers run
# innermost-first: seed magnitude, then seed index, then generator iteration.
# NOTE(review): the absolute path ties this cell to one machine; consider
# moving it into a configuration cell.
satzilla_features = add_generator_iter_column(
  add_seed_index_column(
    add_seed_magnitude_column(
      process_csv_data("/home/csoare/generated_instances/generated_instances_features_output.csv")
    )
  )
)

# --- Summary statistics: each value is reported right after it is computed ---

# Row count of the enriched table
total_instances = len(satzilla_features)
print(f"Total number of instances: {total_instances}")

# Mean number of rows per distinct 'generator' value, rounded up to an integer
avg_instances_per_generator = math.ceil(
  satzilla_features['generator'].value_counts().mean()
)
print(f"Average instances per generator: {avg_instances_per_generator}")

# Count of distinct 'seed_magnitude' values
num_unique_seeds = len(satzilla_features['seed_magnitude'].unique())
print(f"Number of unique seeds: {num_unique_seeds}")

# Count of distinct 'generator_iter_number' values
nr_of_instances_per_generator_per_seed = (
  satzilla_features['generator_iter_number'].nunique()
)
print(f"Instances per generator per seed: {nr_of_instances_per_generator_per_seed}")

# --- Learning-curve analysis ---

# The post-processing helper yields the feature set handed to the analysis
# below (presumably the numeric column names -- confirm in helper_functions).
numeric_features = post_process_general_instances(satzilla_features)

# Results are kept in `results_df` for plotting further down.
results_df = analyze_learning_curves(
  data=satzilla_features,
  numeric_features=numeric_features,
  target_column='generator_encoded',
)
def plot_learning_curves(
  results_df: pd.DataFrame,
  accuracy_threshold: float = 0.95,
  error_rate_threshold: float = 0.05,
  min_seeds: int = 4,
  min_iters_per_gen: int = 30,
  cmap_name: str = "viridis"
):
  """
  Plot accuracy and error rate against the number of remaining samples.

  Both curves are drawn on a single axes, together with dashed horizontal
  reference lines at the two thresholds and a small text box that reports
  ``min_seeds`` and ``min_iters_per_gen``. The figure is shown with
  ``plt.show()``.

  Args:
      results_df (pd.DataFrame): Learning curve results. Must contain the
          columns 'remaining_samples', 'accuracy' and 'error_rate'.
      accuracy_threshold (float): Height of the green dashed reference line.
      error_rate_threshold (float): Height of the red dashed reference line.
      min_seeds (int): Value displayed in the annotation box. It is only
          shown as text and does not filter the data.
      min_iters_per_gen (int): Value displayed in the annotation box. It is
          only shown as text and does not filter the data.
      cmap_name (str): Name of the Matplotlib colormap the two curve colors
          are sampled from (e.g., 'viridis').

  Returns:
      None
  """
  # Look the colormap up through pyplot: `matplotlib.cm.get_cmap` was
  # deprecated in Matplotlib 3.7 and removed in 3.9, whereas `plt.get_cmap`
  # is available in both old and new releases.
  colormap = plt.get_cmap(cmap_name)

  fig, ax = plt.subplots(figsize=(12, 6))

  # Line/marker styling shared by both curves
  curve_style = dict(
    marker='o',
    markersize=4,
    linestyle='-',
    linewidth=1.5,
    alpha=0.7,
  )

  # Accuracy and error rate take two different colors from the same colormap
  # (any fraction between 0 and 1 can be used)
  ax.plot(
    results_df['remaining_samples'],
    results_df['accuracy'],
    label='Accuracy',
    color=colormap(0.3),
    **curve_style,
  )
  ax.plot(
    results_df['remaining_samples'],
    results_df['error_rate'],
    label='Error Rate',
    color=colormap(0.7),
    **curve_style,
  )

  # Highlight thresholds
  ax.axhline(
    y=accuracy_threshold,
    color='green',
    linestyle='--',
    linewidth=1.2,
    label='Accuracy Threshold',
  )
  ax.axhline(
    y=error_rate_threshold,
    color='red',
    linestyle='--',
    linewidth=1.2,
    label='Error Rate Threshold',
  )

  # Annotation box; the position is given in data coordinates, so it may
  # need adjusting when the sample range of `results_df` changes
  ax.text(
    x=100,
    y=0.2,
    s=f"Min # seeds: {min_seeds}\nMin # iter/gen: {min_iters_per_gen}",
    fontsize=10,
    bbox=dict(facecolor='white', alpha=0.5, edgecolor='gray'),
  )

  # Set labels, title, and legend
  ax.set_xlabel('Nr of Samples', fontsize=12)
  ax.set_ylabel('Accuracy vs Error rate based on 3-fold crossvalidation', fontsize=12)
  ax.set_title('Learning Curve (each sample represents a different training set of a different size)', fontsize=14)
  ax.legend(loc='upper right', fontsize=10)
  ax.grid(visible=True, linestyle='--', alpha=0.6)

  # Optimize layout
  fig.tight_layout()
  plt.show()

# Render the learning curves from the analysis results; the thresholds and
# colormap are spelled out explicitly even though they equal the defaults.
plot_learning_curves(
  results_df,
  accuracy_threshold=0.95,
  error_rate_threshold=0.05,
  cmap_name="viridis",
)