In [357]:
%load_ext autoreload
%autoreload 2

import os

# If the current directory has no entry named `cachai`, step one level up.
# Presumably the notebook was launched from a subfolder of the project root
# and `cachai` is the package directory -- confirm.
_entries_here = os.listdir('.')
if 'cachai' not in _entries_here:
    os.chdir('../')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [358]:
import numpy as np
import pandas as pd
from sklearn.metrics import root_mean_squared_error, mean_absolute_error


class TTLSimulator:
    """Simulate a stream of target values and score a model's predictions.

    Every iteration draws a true value ``y`` from one of several normal
    distributions (``_target_params`` holds their ``(mean, std)`` pairs),
    builds a noisy feature row centred on ``y``, asks the model for a
    prediction and feeds the outcome back to the model.

    Args:
        iterations: Number of predict/observe rounds performed by ``run``.
        seed: Optional seed. When given, the simulator draws from its own
            ``np.random.RandomState(seed)`` so runs are repeatable. When
            ``None`` (the default) it keeps drawing from the global
            ``np.random`` state, exactly as before.
    """

    def __init__(self, iterations=1_000, seed=None):
        self._iterations = iterations
        # The module-level ``np.random`` functions and ``RandomState`` expose
        # the same method names (randint / normal / multivariate_normal), so
        # either object can serve as the random source.
        self._rng = np.random if seed is None else np.random.RandomState(seed)
        # Baseline (mean, std) pairs; ``update_target_params`` always derives
        # the current parameters from these so repeated calls do not drift.
        self._base_target_params = [
            (50, 5),
            (200, 10),
            (400, 30),
        ]
        self._target_params = list(self._base_target_params)

    def generate_features(self, target, num_features=1, correlation=0.8):
        """Draw one feature row scattered around ``target``.

        The row is sampled from a multivariate normal whose mean is ``target``
        for every feature, with unit variance on the diagonal and
        ``correlation`` as every off-diagonal covariance.

        Args:
            target: Value (or one-element array) the features are centred on.
            num_features: Number of features in the row.
            correlation: Covariance between any two distinct features.

        Returns:
            Array of shape ``(1, num_features)``.
        """
        cov_matrix = np.eye(num_features) * (1 - correlation) + np.ones((num_features, num_features)) * correlation
        features = self._rng.multivariate_normal(np.ones(num_features) * target, cov_matrix)
        return features.reshape(1, -1)

    def update_target_params(self, progress):
        """Shift every target mean as a function of ``progress``.

        Each mean becomes ``baseline_mean + 2 * sin(pi * progress)`` rounded to
        two decimals; standard deviations are left untouched.

        The previous expression ``m * sin(...) / (m / 2) + m`` is algebraically
        the same offset but produced NaN for a zero mean, and because it was
        applied to the already-shifted means, repeated calls accumulated.
        NOTE(review): presumably the means are meant to depend on ``progress``
        only (oscillating around the baseline) -- confirm.

        Args:
            progress: Position in the run; the offset is ``2*sin(pi*progress)``.
        """
        offset = 2 * np.sin(progress * np.pi)
        self._target_params = [
            (float(round(base_mean + offset, 2)), std)
            for base_mean, std in self._base_target_params
        ]

    def feedback(self, y_pred, y):
        """Compare a prediction with the true value.

        Only the first element of each argument is used.

        Args:
            y_pred: Indexable holding the predicted value at position 0.
            y: Indexable holding the true value at position 0.

        Returns:
            Tuple ``(observation_time, observation_type, hits)`` where
            ``observation_time`` is the smaller of the two values truncated to
            ``int``, ``observation_type`` is ``'MISS'`` (prediction below the
            true value), ``'STALE'`` (prediction above it) or ``'VALID'``
            (equal), and ``hits`` is ``observation_time - 1`` floored at 0.
        """
        predicted, actual = y_pred[0], y[0]
        # Compare scalars: builtin ``min`` on whole arrays only works when the
        # arrays hold exactly one element.
        observation_time = int(min(predicted, actual))
        hits = max(0, observation_time - 1)
        if predicted < actual:
            observation_type = 'MISS'
        elif predicted > actual:
            observation_type = 'STALE'
        else:
            observation_type = 'VALID'
        return observation_time, observation_type, hits

    def generate(self):
        """Draw one ``(X, y)`` sample.

        One ``(mean, std)`` pair is picked uniformly at random from the current
        target parameters, ``y`` is a single normal draw from it, and ``X`` is
        a feature row generated around ``y``.

        Returns:
            Tuple ``(X, y)`` with ``X`` of shape ``(1, 1)`` and ``y`` of shape
            ``(1,)``.
        """
        target_param_index = self._rng.randint(0, len(self._target_params))
        mean, std = self._target_params[target_param_index]
        y = self._rng.normal(mean, std, 1)
        X = self.generate_features(y)
        return X, y

    def run(self, model):
        """Run the simulation loop against ``model``.

        Args:
            model: Object providing ``predict(X)`` (returning an indexable
                whose first element is the prediction) and
                ``observe(observation_time, observation_type, hits, y_pred)``.

        Returns:
            ``pandas.DataFrame`` with one row per iteration and the columns
            ``iteration, observation_type, observation_time, y_pred, y, hits,
            mse, mae``.
        """
        records = []
        for i in range(self._iterations):
            X, y = self.generate()
            y_pred = model.predict(X)
            observation_time, observation_type, hits = self.feedback(y_pred, y)
            model.observe(observation_time, observation_type, hits, y_pred)

            # NOTE: this is the *root* mean squared error; it is stored under
            # the historical column name 'mse' so existing consumers of the
            # frame keep working.
            rmse = root_mean_squared_error(y, y_pred)
            mae = mean_absolute_error(y, y_pred)
            records.append([
                i, observation_type, observation_time, y_pred[0], y[0], hits, rmse, mae
            ])
        return pd.DataFrame(records, columns=[
            'iteration', 'observation_type', 'observation_time',
            'y_pred', 'y', 'hits', 'mse', 'mae'
        ])

In [359]:
class Experiment:
    """Placeholder experiment driver; it only stores the iteration count."""

    def __init__(self, iterations=1_000):
        # Stored for later use; nothing in this class reads it yet.
        self._iterations = iterations

    def run(self, model):
        """Not implemented yet: accepts ``model`` and returns ``None``."""
        return None

In [360]:
from enum import Enum, auto


class ObservationType(Enum):
    """Kinds of observation a model can be told about.

    Members get consecutive automatic values in declaration order, and
    ``str(member)`` yields the bare member name (e.g. ``'MISS'``) rather than
    the default ``'ObservationType.MISS'``.
    """

    HIT = auto()
    MISS = auto()
    STALE = auto()
    VALID_TTL = auto()

    def __str__(self) -> str:
        """Return only the member name, without the class prefix."""
        return self.name

In [361]:
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Abstract interface for models driven by the simulator.

    Subclasses must implement both ``predict`` and ``observe``; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return a prediction for the feature array ``X``.

        Annotated with ``np.ndarray`` (the array type); ``np.array`` is the
        factory function and is not a valid type annotation.
        """

    @abstractmethod
    def observe(
        self,
        observation_time: int,
        # Quoted so the class can be defined regardless of the order in which
        # the notebook cells defining ObservationType and BaseModel are run.
        observation_type: "ObservationType",
        hits: int,
        prev_prediction: float
    ) -> None:
        """Receive feedback about a previous prediction.

        Args:
            observation_time: Time at which the observation was made.
            observation_type: Outcome category of the observation.
            hits: Number of hits counted for the observation.
            prev_prediction: The prediction this feedback refers to.
                NOTE(review): ``TTLSimulator.run`` in this notebook passes the
                array returned by ``predict`` and plain strings for
                ``observation_type`` -- confirm which contract is intended.
        """

In [362]:
class Model(BaseModel):
    """Baseline model: predicts the mean of its input and ignores feedback."""

    def __init__(self):
        """Nothing to initialise; the model keeps no state."""

    def predict(self, X):
        """Return the mean of all entries of ``X`` as a one-element array."""
        return np.array([X.mean()])

    def observe(self, observation_time, observation_type, hits, prev_prediction):
        """Accept feedback and discard it (this baseline does not learn)."""

In [363]:
# Run the simulator with its default settings against the baseline model.
# NOTE(review): no random seed is set in the visible cells, so the table
# displayed below differs on every execution.
simulator = TTLSimulator()
model = Model()

# Last expression of the cell: the per-iteration results frame is displayed.
simulator.run(model)

Unnamed: 0,iteration,observation_type,observation_time,y,y_pred,hits,mse,mae
0,0,STALE,192,192.736619,193.203678,191,0.467059,0.467059
1,1,STALE,41,41.426386,41.876811,40,0.450425,0.450425
2,2,MISS,378,379.637934,378.673398,377,0.964536,0.964536
3,3,STALE,380,380.291930,381.668454,379,1.376524,1.376524
4,4,MISS,46,46.813887,46.430545,45,0.383342,0.383342
...,...,...,...,...,...,...,...,...
995,995,STALE,196,196.078853,196.427166,195,0.348313,0.348313
996,996,STALE,44,44.314977,44.396400,43,0.081422,0.081422
997,997,MISS,207,208.805001,207.164657,206,1.640344,1.640344
998,998,MISS,367,369.823634,367.108932,366,2.714703,2.714703
