In [None]:
import os
from typing import List, Optional

import numpy as np
import pandas as pd

In [None]:

class Task:
    """A named task whose progress is a linear function of the sample count."""

    def __init__(self, name: str, progress_start: float, progress_slope: float):
        # progress_start is the progress at sample 0; progress_slope is the
        # progress gained per sample.
        self.name, self.progress_start, self.progress_slope = (
            name,
            progress_start,
            progress_slope,
        )

    def get_progress(self, sample: float) -> float:
        """Return the progress reached after `sample` samples.

        Computed as ``progress_start + sample * progress_slope``; the result is
        not clipped to any range.
        """
        gained = sample * self.progress_slope
        return self.progress_start + gained


class Stage:
    """A training stage: a sample budget plus a task-probability schedule.

    The probability of drawing each task moves linearly from
    `task_probability_start` to `task_probability_end` over the stage's
    `total_sample` samples.
    """

    def __init__(self, name: str, tasks: "List[Task]", total_sample: int, task_probability_start: List[float], task_probability_end: List[float]):
        self.name = name
        self.tasks = tasks
        # Stored because Experience.get_stage() reads stage.total_sample.
        self.total_sample = total_sample
        self.task_probability_start = task_probability_start
        self.task_probability_end = task_probability_end

    def draw_task(self, sample: int) -> "Task":
        """Randomly draw the task to execute at `sample`.

        Args:
            sample: number of samples already ingested.
                NOTE(review): treated here as relative to the start of this
                stage; confirm against the caller. Values outside
                [0, total_sample] are clamped to the stage boundaries.

        Returns:
            One element of `self.tasks`, drawn with probabilities linearly
            interpolated between the start and end distributions. If
            `self.tasks` is a single object rather than a list/tuple it is
            returned as-is.

        Raises:
            ValueError: propagated from `np.random.choice` when the
                probabilities do not match the number of tasks or do not sum
                to 1.
        """
        if not isinstance(self.tasks, (list, tuple)):
            return self.tasks

        # Fraction of the stage completed, clamped to [0, 1]. A stage with no
        # sample budget uses its end distribution.
        if self.total_sample > 0:
            fraction = min(max(sample / self.total_sample, 0.0), 1.0)
        else:
            fraction = 1.0

        start = np.asarray(self.task_probability_start, dtype=float)
        end = np.asarray(self.task_probability_end, dtype=float)
        probabilities = start + fraction * (end - start)

        # Draw an index rather than the object itself so that arbitrary task
        # objects never have to be coerced into a numpy array.
        chosen = np.random.choice(len(self.tasks), p=probabilities)
        return self.tasks[chosen]




class Experience:
    """An ordered sequence of stages that are consumed back to back."""

    def __init__(self, stages):
        # Iterable of objects exposing a `total_sample` attribute.
        self.stages = stages

    def get_stage(self, sample: int) -> "Optional[Stage]":
        """Get the current stage by no. of samples ingested.

        Stages are laid end to end; stage k covers the half-open sample range
        starting where stage k-1 ended and spanning its `total_sample`.

        Args:
            sample: total number of samples ingested so far.

        Returns:
            The first stage whose cumulative end is greater than `sample`, or
            None when `sample` lies at or beyond the end of the last stage
            (or when there are no stages).
        """
        stage_end = 0
        for stage in self.stages:
            stage_end += stage.total_sample
            if sample < stage_end:
                return stage
        # Past the final stage: make the fall-through explicit.
        return None



In [None]:
# Example task: starts at 0 progress and gains 1/10000 per sample.
triangle = Task(name="triangle", progress_start=0, progress_slope=1 / 10000)

In [None]:
# Progress of the example task after 100,000 samples.
triangle.get_progress(sample=100000)

In [None]:

class Sampler:
    """v2 Sampler for tf model v4
    Features: 
    1) Smooth non-stationary environment
    2) Visualizing environment change with plot_env()
    3) Visualizing relative testset exposure with plot_testset_on_env()
    """

    def __init__(self, cfg, data):

        # Get necessary environment config from cfg object
        self.cfg = cfg
        self.tasks = cfg.tasks
        self.wf_clip_low = cfg.wf_clip_low
        self.wf_clip_high = cfg.wf_clip_high
        self.wf_compression = cfg.wf_compression
        self.oral_start_pct = cfg.oral_start_pct
        self.oral_end_pct = cfg.oral_end_pct 
        self.oral_sample = cfg.oral_sample
        self.oral_tasks_ps = cfg.oral_tasks_ps
        self.transition_sample = cfg.transition_sample
        self.reading_sample = cfg.reading_sample
        self.reading_tasks_ps = cfg.reading_tasks_ps
        self.batch_size = cfg.batch_size
        self.n_timesteps = cfg.n_timesteps
        self.inject_error_ticks = cfg.inject_error_ticks
        
        np.random.seed(cfg.rng_seed)
        
        self.data = data
        self.current_sample = 0

        # Basic convienient variables
        self._calculate_aux_variables()

        # Task probability dictionary (batch, ps)
        self.ps = {}
        self._calculate_task_ps()

        # Progress dictionary 
        self.progress = {}
        self._calculate_progress_dict()

    def plot_env(self, ax=None):

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(7,5))

        """Plot an easy to understand environment progression figure"""
        ax.plot(self.progress['oral'], label='oral corpus')
        ax.plot(self.progress['reading'], label='reading corpus')

        reading_p = [self.task_ps[i][-1] if type(self.task_ps[i]) is list or tuple else self.task_ps[i] for i in range(self.total_batches)]
        # print(reading_p)
        ax.plot(reading_p, label='reading_p', linestyle='dashdot', color='black')

        ax.axvline(x=self.oral_batches, ymin=0, ymax=1, linestyle = 'dotted', color='red', label='transition start')
        ax.axvline(x=self.oral_batches + self.transition_batches, ymin=0, ymax=1, linestyle = 'dotted', color='green', label = 'transition end')
        ax.text(x=10, y=0.8, s=f"oral phase task ps \n{self.oral_tasks_ps}")
        ax.text(x=self.total_batches*0.5, y=0.8, s=f"reading phase task ps \n(after transition) \n{self.reading_tasks_ps}")
        ax.set_xlabel('batch')
        ax.set_ylabel('percetile rank of word frequency')

        ax.legend(loc="lower right")
        ax.set_title('Corpus opening progression (%)')


    def _join_testset_pct(self, testset_name):
        """Build a temporary DataFrame for plotting a testset on the env.

        Loads `<cfg.tf_root>/dataset/testsets/<testset_name>.pkl.gz` through
        `load_testset`, keeps its 'item' and 'cond' entries, and adds a 'pct'
        column holding each item's value from `self.rank_pct_wf_dict`.

        Raises:
            KeyError: if an item is missing from `self.rank_pct_wf_dict`.
        """
        testset_path = os.path.join(self.cfg.tf_root, "dataset", "testsets", f"{testset_name}.pkl.gz")
        testset = load_testset(testset_path)

        columns = {'item': testset['item'], 'cond': testset['cond']}
        columns['pct'] = [self.rank_pct_wf_dict[word] for word in columns['item']]
        return pd.DataFrame(columns)

    def _calculate_aux_variables(self):
        self.total_sample = self.oral_sample + self.reading_sample

        self.total_batches = self.sample_to_batch(self.total_sample)
        self.oral_batches = self.sample_to_batch(self.oral_sample)
        self.transition_batches = self.sample_to_batch(self.transition_sample)
        self.remaining_batches = self.total_batches - self.oral_batches - self.transition_batches

        # Word frequency related
        if self.wf_clip_low is None:
            self.wf_clip_low = 0
        
        if self.wf_clip_high is None:
            self.wf_clip_high = 999999999            
            
        clip_wf = self.data.df_train.wf.clip(self.wf_clip_low, self.wf_clip_high).copy()
        self.rank_pct_wf = clip_wf.rank(pct=True, ascending=False)
        self.rank_pct_wf_dict = dict(zip(self.data.df_train.word, self.rank_pct_wf))

        assert self.wf_compression in ('log', 'root')
        self.compressed_wf = np.log(clip_wf + 1) if self.wf_compression == 'log' else np.sqrt(clip_wf)

    def _calculate_task_ps(self):
        """Task probabililty store all task probabilities in a single dictionary (epoch, ps) """

        # Oral
        self.task_ps = {i:self.oral_tasks_ps for i in range(self.oral_batches)}

        # Transition
        tps = np.linspace(self.oral_tasks_ps, self.reading_tasks_ps, self.transition_batches)
        tran_ps = {i + self.oral_batches: tuple(tps[i]) for i in range(self.transition_batches)}
        self.task_ps.update(tran_ps)

        # Remaining
        remain_ps = {i + self.oral_batches + self.transition_batches: self.reading_tasks_ps for i in range(self.remaining_batches)}
        self.task_ps.update(remain_ps)

    def _calculate_progress_dict(self):
        """Progress dictionary storing all %max progress at each epoch in oral and reading stage"""
        # Oral progress
        opg = np.linspace(self.oral_start_pct, self.oral_end_pct, self.oral_batches)
        self.progress['oral'] = self._extrapolate_ps(opg)

        # Reading progress
        self.progress['reading'] = self._right_shift_ps(self.progress['oral'])         

    def _extrapolate_ps(self, array):
        """Extrapolate an array to the number of batches long"""
        d = array[1] - array[0]
        e = array[-1]
        n = len(array)
        ex_array = [array[i] if i < n  else e + (i-n) * d for i in range(self.total_batches)]
        return np.clip(ex_array, 0, 1)

    def _right_shift_ps(self, oral_progress):
        """Convert oral ps to reading ps
        Since reading lag behide oral task by a constant, we just need to shift oral progress to the right
        to get reading progress
        """
        n = self.oral_batches
        beginning = np.zeros(n)
        return np.concatenate([beginning, oral_progress[:-n]])

    def wf_to_ps(self, wf):
        """Normalise (compressed) word frequencies into float32 probabilities.

        Each value is divided by the sum of all values; no guard is applied
        when that sum is zero.
        """
        total = np.sum(wf)
        normalised = wf / total
        return np.array(normalised, dtype="float32")

    def sample_to_batch(self, sample):
        """Convert sample to batch in 0-indexing format"""
        return int(sample/self.batch_size)

    def get_sampling_p(self, current_sample, task):
        current_batch = self.sample_to_batch(current_sample)

        if task == 'triangle':
            progress = self.progress['reading'][current_batch]
        else:
            progress = self.progress['oral'][current_batch]

        # Create selection mask
        mask = self.rank_pct_wf < progress
    
        return self.wf_to_ps(self.compressed_wf * mask)

    def generator(self):
        """Endless generator that draws a task and a batch of sample indices.

        On every iteration it:
        1) picks a task (randomly from `self.tasks` with the current batch's
           probabilities, or `self.tasks` itself when it is a single task
           name given as a string),
        2) samples `batch_size` indices from `self.data.df_train.index` with
           the probabilities from `get_sampling_p`,
        3) builds the input/output batches from `self.data.np_representations`
           using the (input, output) names in `modeling.IN_OUT[task]`,
        4) advances `self.current_sample` by `batch_size`.

        Yields:
            (task, idx, batch_x, batch_y) where `batch_x` is a list repeating
            the input array `n_timesteps` times and `batch_y` is either such a
            list of length `inject_error_ticks` or, for multiple outputs, a
            dict mapping each output name to one.

        Note:
            `self.task_ps` and the progress arrays are indexed by batch, so
            iterating past the last scheduled batch raises a lookup error.
            `modeling` must be available at module scope.
        """
        x_ticks = self.n_timesteps
        y_ticks = self.inject_error_ticks

        while True:
            current_batch = self.sample_to_batch(self.current_sample)

            # A bare string is a single fixed task; any other value is a
            # collection to draw from. (The former `type(...) is tuple or
            # list` test was always true, so the single-task path was dead.)
            if isinstance(self.tasks, str):
                task = self.tasks
            else:
                task = np.random.choice(self.tasks, p=self.task_ps[current_batch])

            idx = np.random.choice(
                self.data.df_train.index,
                self.batch_size,
                p=self.get_sampling_p(self.current_sample, task),
            )

            x, y = modeling.IN_OUT[task]
            batch_x = [self.data.np_representations[x][idx]] * x_ticks

            if isinstance(y, list):
                # Multiple outputs: one repeated target list per output name.
                batch_y = {yi: [self.data.np_representations[yi][idx]] * y_ticks for yi in y}
            else:
                # Single output
                batch_y = [self.data.np_representations[y][idx]] * y_ticks

            self.current_sample += self.batch_size
            yield task, idx, batch_x, batch_y
     
