In [None]:
from tabrepo import load_repository, EvaluationRepository, get_context
from collections import Counter

In [None]:
# Load the TabRepo evaluation repository for this context (with cache=True),
# the matching context object, and the config hyperparameters it exposes
# (presumably one entry per configuration — confirm against tabrepo docs).
context_name = "D244_F3_C1530_100"
repo: EvaluationRepository = load_repository(context_name, cache=True)
context = get_context(name=context_name)
all_config_hyperparameters = context.load_configs_hyperparameters()

# Shorthands used throughout the rest of the notebook.
datasets, configs, folds = repo.datasets(), repo.configs(), repo.folds

# Repository overview, followed by the dataset list, config list and folds,
# each on its own line.
repo.print_info()
print(datasets, configs, folds, sep="\n")

In [None]:
# Problem type of every dataset in the repository (e.g. 'binary', 'multiclass'),
# then a tally of how many datasets fall into each type.
problem_types = [repo.dataset_info(dataset=ds)['problem_type'] for ds in datasets]

Counter(problem_types)

In [None]:
# Number of classes of every multiclass dataset; binary and other problem types
# are skipped.
num_classes_list = []
# Metadata of the last multiclass dataset seen. It is initialised explicitly so
# this cell does not depend on a loop variable leaking out of the loop, or on a
# stale value left over from another cell.
dataset_metadata = None
for ds in datasets:
    dataset_info = repo.dataset_info(dataset=ds)
    if dataset_info['problem_type'] == 'multiclass':
        dataset_metadata = repo.dataset_metadata(dataset=ds)
        num_classes_list.append(dataset_metadata['NumberOfClasses'])

# Show which metadata fields are available (taken from the last multiclass dataset).
if dataset_metadata is not None:
    print(dataset_metadata.keys())
else:
    print("No multiclass datasets found; no metadata keys to show.")

Counter(num_classes_list)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.ticker import FuncFormatter, MaxNLocator

# Where the exported figures go. The directory is created before saving, so a
# fresh checkout without ../plots does not make savefig raise FileNotFoundError.
PLOTS_DIR = Path("../plots")
METADATA_COLUMNS = ['dataset', 'problem_type', 'n_samples', 'n_features', 'n_classes']

# Collect size metadata for every classification dataset
all_datasets_metadata = []
for ds in datasets:
    dataset_info = repo.dataset_info(dataset=ds)
    problem_type = dataset_info.get('problem_type')
    # We only focus on classification datasets
    if problem_type in ['binary', 'multiclass']:
        dataset_metadata = repo.dataset_metadata(dataset=ds)
        all_datasets_metadata.append({
            'dataset': ds,
            'problem_type': problem_type,
            'n_samples': dataset_metadata.get('NumberOfInstances'),
            'n_features': dataset_metadata.get('NumberOfFeatures'),
            'n_classes': dataset_metadata.get('NumberOfClasses'),
        })

# Explicit columns keep the frame well-formed even if nothing was collected
# (otherwise dropna(subset=...) below would raise KeyError).
cls_df = pd.DataFrame(all_datasets_metadata, columns=METADATA_COLUMNS)

# Drop datasets missing any plotted quantity (metadata .get() yields None for absent keys)
clean_data = cls_df.dropna(subset=['n_samples', 'n_features', 'n_classes'])
assert len(clean_data) > 0, "no classification datasets with complete size metadata to plot"

fig, ax = plt.subplots(figsize=(8, 4))

# Samples vs. features, coloured by number of classes
scatter = ax.scatter(clean_data['n_samples'], clean_data['n_features'],
                     c=clean_data['n_classes'], cmap='magma_r',
                     alpha=0.8, s=80, edgecolors='black', linewidth=0.8)

# Zoom on the main cluster: the upper limits are 2x the 99th percentile, so any
# extreme outliers beyond that fall outside the visible area.
x_percentile_99 = np.percentile(clean_data['n_samples'], 99)
y_percentile_99 = np.percentile(clean_data['n_features'], 99)

ax.set_xlim(clean_data['n_samples'].min() * 0.8, x_percentile_99 * 2)
ax.set_ylim(clean_data['n_features'].min() * 0.8, y_percentile_99 * 2)

ax.set_xscale('log')
ax.set_yscale('log')

# The class count is an integer, so force integer tick locations. Otherwise the
# int() formatter would truncate fractional ticks (e.g. 2.5 -> "2") into wrong or
# duplicate labels.
cbar = fig.colorbar(scatter, ax=ax, label='Number of Classes',
                    ticks=MaxNLocator(integer=True),
                    format=FuncFormatter(lambda x, pos: f'{int(x)}'))

# Axis labels (no title is set on this figure)
ax.set_xlabel('Number of Samples', fontsize=12)
ax.set_ylabel('Number of Features', fontsize=12)

# Light grid for readability
ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)

# Summary statistics of the plotted datasets (computed for reference; not drawn on the figure)
n_datasets = len(clean_data)
min_samples = clean_data['n_samples'].min()
max_samples = clean_data['n_samples'].max()
min_features = clean_data['n_features'].min()
max_features = clean_data['n_features'].max()
min_classes = clean_data['n_classes'].min()
max_classes = clean_data['n_classes'].max()

# Adjust layout, export, and show
fig.tight_layout()
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(PLOTS_DIR / "tabrepo-datasets.pdf", dpi=300)
fig.savefig(PLOTS_DIR / "tabrepo-datasets.png", dpi=300)
plt.show()
