In [135]:
%load_ext autoreload
%autoreload 2

import os

if 'cachai' not in os.listdir('.'):
    os.chdir('../')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
import numpy as np
import pandas as pd
from sklearn.metrics import root_mean_squared_error, mean_absolute_error


class TTLSimulator:
    COLUMNS = [
        'iteration', 'observation_time', 'y', 'y_pred',
        'observation_type', 'hits', 'target', 'mse', 'mae'
    ]

    def __init__(self, iterations=1_000):
        self._iterations = iterations
        self._target_params = [
            # mean, std
            (50, 5),
            (200, 10),
            (400, 30),
        ]
        means = np.linspace(10, 500, 10).astype(int)
        std = np.arange(1, len(means) + 1)**2
        # self._target_params = np.array([means, std]).T

    def generate_features(self, target, num_features=1, correlation=0.8):
        cov_matrix = np.eye(num_features) * (1 - correlation) + np.ones((num_features, num_features)) * correlation
        features = np.random.multivariate_normal(np.ones(num_features) * target, cov_matrix)
        return features.reshape(1, -1)
        # return np.full((1, num_features), fill_value=target)

    def get_feedback(self, observation_time, prediction, target):
        hits = max(0, observation_time-1)
        observation_type = None
        if prediction[0] < target[0]:
            observation_type = 'MISS'
        elif prediction[0] > target[0]:
            observation_type = 'STALE'
        else:
            observation_type = 'VALID'
        return observation_type, hits

    def update_target_params(self, progress):
        target_params = []
        for param in self._target_params:
            mean = float(
                round(param[0]*np.sin(progress*2*np.pi/2)/(param[0]/2) + param[0], 2)
            )
            std = param[1]
            target_params.append((mean, std))
        self._target_params = target_params

    def run(self, model):
        history = []
        for i in range(self._iterations):
            target_param_index = np.random.randint(0, len(self._target_params))
            # self.update_target_params(i/self._iterations)
            target_params = self._target_params[target_param_index]
            y = np.random.normal(target_params[0], target_params[1], 1)
            X = self.generate_features(y)
            y_pred = model.predict(X)
            observation_time = int(min(y_pred, y)[0])
            observation_type, hits = self.get_feedback(observation_time, y_pred, y)
            model.observe(observation_time, observation_type, hits, y_pred)

            mse = root_mean_squared_error(y, y_pred)
            mae = mean_absolute_error(y, y_pred)
            history.append([
                i, observation_time, y[0], y_pred[0], observation_type, hits, target_params[0], mse, mae
            ])
        return pd.DataFrame(history, columns=self.COLUMNS)

In [137]:
class Evaluator():
    def __init__():
        pass

    def evaluate(self, df):
        return df.groupby('indicator').agg({
            'mse': 'mean',
            'mae': 'mean',
            'hits': 'mean'
        })

In [138]:
from abc import ABC, abstractmethod
from enum import Enum


class ObservationType(Enum):
    HIT = "HIT"
    MISS = "MISS"
    STALE = "STALE"
    VALID_TTL = "VALID_TTL"


class BaseModel(ABC):
    @abstractmethod
    def predict(self, X: np.array) -> np.array:
        pass

    @abstractmethod
    def observe(
        self,
        observation_time: int,
        observation_type: ObservationType,
        hits: int,
        prev_prediction: float
    ) -> None:
        pass

In [139]:
class Model(BaseModel):
    def __init__(self):
        pass

    def predict(self, X):
        output = X.mean()
        return np.array([output])

    def observe(self, observation_time, observation_type, hits, prev_prediction):
        a = 2+2

In [140]:
simulator = TTLSimulator()
model = Model()
simulator.run(model)

Unnamed: 0,iteration,observation_time,y,y_pred,observation_type,hits,target,mse,mae
0,0,45.356058,46.675461,45.356058,MISS,44.356058,50,1.319403,1.319403
1,1,390.903932,392.055430,390.903932,MISS,389.903932,400,1.151498,1.151498
2,2,45.588329,45.588329,45.710504,STALE,44.588329,50,0.122175,0.122175
3,3,220.797625,221.424303,220.797625,MISS,219.797625,200,0.626678,0.626678
4,4,43.970290,43.970290,44.717996,STALE,42.970290,50,0.747706,0.747706
...,...,...,...,...,...,...,...,...,...
995,995,180.009410,180.009410,180.374085,STALE,179.009410,200,0.364675,0.364675
996,996,392.735833,392.735833,393.040461,STALE,391.735833,400,0.304628,0.304628
997,997,374.397595,374.880008,374.397595,MISS,373.397595,400,0.482413,0.482413
998,998,428.927832,429.389330,428.927832,MISS,427.927832,400,0.461497,0.461497
