In [1]:
import os
import sys
# Make the directory above the working directory importable; presumably this
# is where the local `benchmark_extension` module imported below lives.
current_dir = os.getcwd()

# Absolute path of the directory one level above the working directory.
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from benchmark_extension import Experiment

In [2]:
# Benchmark configurations keyed by difficulty level. Only 'easy' is filled
# in; 'medium' and 'hard' are None placeholders.
sample_difficulty_configurations = {
    'easy': {
        # --- Data set size ---
        'n_samples': 2000,
        'n_clusters': 3,

        # --- Feature layout: five features of each kind ---
        'n_numeric_features': 5,
        'n_categorical_features': 5,
        'categorical_cardinalities': [6] * 5,
        'n_multival_features': 5,
        # How many vocab items are associated to a cluster; presumably one
        # tuple per multi-valued feature with one entry per cluster.
        'multival_vocab_lens': [(3, 3, 3)] * 5,

        # --- Difficulty parameters ---
        'separability': 3.0,
        'multival_intersections': 1,
        'noise': 0.01,
        # NOTE(review): two weights are given for three clusters; presumably
        # the remaining weight is inferred downstream -- confirm.
        'class_weights': [0.33, 0.33],

        # --- Per-approach settings, keyed by approach name ---
        'approach_settings': {
            'naive': {'gamma': None},
            'one-hot': {'gamma': None, 'max_dummies': 100},
            'one-hot-pca': {'gamma': None, 'reduced_dimensions': 0.25},
            'extended': {'gamma_c': 0.33, 'gamma_m': 0.33, 'theta': 0.001},
        },
    },
    'medium': None,
    'hard': None,
}

In [3]:
# Set up the benchmark for the 'easy' configuration, requesting the four
# approaches that have an entry under its 'approach_settings'.
# random_state is pinned to 42, presumably for reproducible runs.
exp = Experiment(
    benchmarking_config=sample_difficulty_configurations['easy'],
    approaches=(
        'naive',
        'one-hot',
        'one-hot-pca',
        'extended',
    ),
    random_state=42,
)

In [4]:
# Keep the returned results bound to a name so later cells can reuse them
# without re-running the experiment or relying on `_` / `Out[4]`.
results = exp.run_experiment()

# Echo the results as the cell output (same display as the bare call).
results

{'naive': {'preprocess_time': 0.012549499981105328,
  'clustering_time': 4.353618000051938,
  'n_iter': 3,
  'MIS': 0.9698692316348873,
  'ARI': 0.9850825134727914,
  'centroids': array([['3.016293282833785', '-2.9659214650364696', '2.9648799685851577',
          '2.98451012463196', '2.971865234885399', '0', '1', '0', '0', '0',
          '{0, 1, 2, 9, 10}', '{0, 1, 2, 9, 10}', '{0, 1, 2, 9, 10}',
          '{0, 1, 2, 9, 10}', '{0, 1, 2, 9, 10}'],
         ['2.9457840200839347', '3.0208802544753515', '2.9393406917607687',
          '-2.9322378585132336', '2.9183526188934086', '3', '2', '2', '2',
          '3', '{3, 4, 5, 9, 11}', '{3, 4, 5, 9, 11}', '{3, 4, 5, 9, 11}',
          '{3, 4, 5, 9, 11}', '{3, 4, 5, 9, 11}'],
         ['-3.0268583751308813', '-3.0604115203343425',
          '2.9566655341127706', '2.999128700761383', '-3.016087121799718',
          '5', '4', '5', '4', '5', '{6, 7, 8, 10, 11}',
          '{6, 7, 8, 10, 11}', '{6, 7, 8, 10, 11}', '{6, 7, 8, 10, 11}',
          '{