In [1]:
import sys

sys.path.append('../src')

In [None]:
import numpy as np
import pandas as pd

from typing import Optional, Union
from itertools import product

from Utils.utils import GetMeasurements
from Utils.cherrypick_simulations import CherryPickEquilibria

In [None]:
class AlternationIndex:
    '''Estimates the alternation index from simulated data.

    On construction, ``num_points`` distinct parameter configurations
    ``(num_agents, threshold, epsilon)`` are drawn without replacement from a
    grid. Each configuration is later simulated with ``CherryPickEquilibria``
    and summarised with ``GetMeasurements``.

    Parameters
    ----------
    num_points : int
        Number of distinct configurations to sample from the grid.
    num_episodes : int
        Episodes simulated per configuration for 'alternation' data. The
        'segmentation' and 'random' kinds use ``num_episodes // 2``.
    max_agents : int
        Largest number of agents on the grid (inclusive). The grid covers
        every integer from 2 to ``max_agents``.
    max_epsilon : float
        Upper end of the epsilon grid (100 evenly spaced values from 0).
    seed : int or None
        Seed for the sampling RNG. It is also forwarded to every simulator.

    Raises
    ------
    ValueError
        If ``max_agents`` is smaller than 2, or if ``num_points`` exceeds
        the number of grid configurations (raised by ``Generator.choice``).
    '''

    def __init__(
        self,
        num_points: int = 10,
        num_episodes: int = 10,
        max_agents: int = 16,
        max_epsilon: float = 0.1,
        seed: Optional[int] = None
    ) -> None:
        if max_agents < 2:
            raise ValueError(f"max_agents must be at least 2, got {max_agents}")
        # The number of agents is a count, so its grid holds integers only
        # (np.linspace would produce fractional agent counts).
        range_num_agents = np.arange(2, int(max_agents) + 1)
        range_threshold = np.linspace(0, 1, 100)
        range_epsilon = np.linspace(0, max_epsilon, 100)
        grid_shape = (
            len(range_num_agents), len(range_threshold), len(range_epsilon)
        )
        self.num_episodes = num_episodes
        self.seed = seed
        self.rng = np.random.default_rng(seed=seed)
        # Sample flat grid indices rather than the configurations themselves.
        # Generator.choice cannot consume an itertools.product iterator, and
        # sampling indices avoids materialising the whole grid.
        flat_indices = self.rng.choice(
            int(np.prod(grid_shape)), size=num_points, replace=False
        )
        agent_idx, threshold_idx, epsilon_idx = np.unravel_index(
            flat_indices, grid_shape
        )
        # Each point is (num_agents: int, threshold: float, epsilon: float).
        self.configuration_points = [
            (int(range_num_agents[a]), float(range_threshold[t]), float(range_epsilon[e]))
            for a, t, e in zip(agent_idx, threshold_idx, epsilon_idx)
        ]
        self.measures = ['efficiency', 'entropy', 'conditional_entropy', 'inequality']
        self.debug = True

    def simulate_data(self) -> pd.DataFrame:
        '''Return the measurements for the 'alternation' data kind.

        This is a thin wrapper around ``simulate_data_kind('alternation')``.
        '''
        # NOTE(review): the original body only set ``data_type`` and returned
        # None, despite the DataFrame annotation. Delegating is the evident
        # intent. Confirm whether 'segmentation' and 'random' data should
        # also be combined here.
        data_type = 'alternation'
        return self.simulate_data_kind(data_type)

    def simulate_data_kind(self, data_type: str) -> pd.DataFrame:
        '''Simulate every sampled configuration for one kind of data.

        Parameters
        ----------
        data_type : str
            One of 'segmentation', 'alternation' or 'random'. It is passed to
            ``CherryPickEquilibria.generate_data``.

        Returns
        -------
        pd.DataFrame
            The per-configuration results of
            ``GetMeasurements.get_measurements``, concatenated with a fresh
            index. The frame is empty when no configuration was sampled.

        Raises
        ------
        ValueError
            If ``data_type`` is not one of the supported kinds.
        '''
        valid_types = ('segmentation', 'alternation', 'random')
        if data_type not in valid_types:
            raise ValueError(
                f"data_type must be one of {valid_types}, got {data_type!r}"
            )
        num_episodes = self.num_episodes
        if data_type in ('segmentation', 'random'):
            # Half as many episodes for these kinds. Presumably this is so the
            # two together match the alternation sample size; TODO confirm.
            num_episodes = self.num_episodes // 2
        df_list = []
        for num_agents, threshold, epsilon in self.configuration_points:
            eq_generator = CherryPickEquilibria(
                num_agents=num_agents,
                threshold=threshold,
                epsilon=epsilon,
                num_episodes=num_episodes,
                # NOTE(review): every configuration receives the same seed.
                # Presumably they all start from the same random state.
                # Confirm this correlation is intended.
                seed=self.seed
            )
            eq_generator.debug = False
            df_raw = eq_generator.generate_data(data_type)
            get_m = GetMeasurements(df_raw, self.measures)
            df_list.append(get_m.get_measurements())
        if not df_list:
            # pd.concat raises ValueError on an empty list.
            return pd.DataFrame()
        return pd.concat(df_list, ignore_index=True)
