In [None]:
import numpy as np
import os, json
from tqdm import tqdm
from collections import deque
# self-defined packages
from codebase import utility
from codebase.probmats import ProbMat
from codebase.bcienv import SimulatedEnv

In [None]:
class BBTS:
    """Beta-Bernoulli Thompson sampling over the rows and columns of a grid.

    Keeps an independent Beta(alpha, beta) posterior per arm, starting from
    Beta(1, 1). ``update`` lays the arms out as the row sums of a 2-D score
    matrix followed by its column sums, so the default of 12 arms matches a
    matrix with rows + cols == 12 (e.g. 6x6 -- presumably the speller grid
    used by SimulatedEnv; confirm there).

    Parameters
    ----------
    subset_size : int
        Number of arm indices returned by ``select_indices``.
    n_arms : int
        Number of arms; must equal ``P.shape[0] + P.shape[1]`` for every
        matrix passed to ``update``.
    rng : numpy.random.Generator or None
        Source of the Beta draws. ``None`` (default) uses the legacy global
        ``np.random`` state, exactly as before; pass a seeded Generator for
        reproducible runs.
    """

    def __init__(self, subset_size=5, n_arms=12, rng=None):
        self.subset_size = subset_size
        self.alphas, self.betas = np.ones(n_arms), np.ones(n_arms)
        # Store None rather than the np.random module so instances stay
        # copy/pickle friendly; resolved at sampling time.
        self.rng = rng

    def select_indices(self):
        """Draw one Thompson sample per arm and return the top ``subset_size``.

        Returns
        -------
        list of int
            Arm indices ordered by descending sampled value.
        """
        sampler = np.random if self.rng is None else self.rng
        thetas = sampler.beta(self.alphas, self.betas)
        return np.argsort(-thetas)[:self.subset_size].tolist()

    def update(self, P: np.ndarray):
        """Fold a 2-D score matrix into the posteriors as fractional rewards.

        The reward vector ``r`` is the row sums of ``P`` followed by its
        column sums; every arm then gets ``alpha += r`` and ``beta += 1 - r``.

        Raises
        ------
        ValueError
            If ``P.shape[0] + P.shape[1]`` differs from the number of arms.
        """
        r = np.concatenate([P.sum(axis=1), P.sum(axis=0)])
        if r.shape != self.alphas.shape:
            raise ValueError(
                f"expected P with rows + cols == {self.alphas.size}, "
                f"got shape {np.shape(P)}"
            )
        # Entries of r are expected to lie in [0, 1]; values above 1 can push
        # beta to <= 0, which makes the next Beta draw raise.
        self.alphas += r
        self.betas += 1 - r

In [None]:
############
# Second score mean; choose one of the candidate settings 0.9 / 1.2 / 1.5
# (index 2 selects 1.5).
mu1 = (0.9, 1.2, 1.5)[2]
############
# Two-component score model handed to ProbMat / SimulatedEnv:
# component 0 ~ N(0, 1), component 1 ~ N(mu1, 1)
# (presumably non-target vs. target scores -- confirm in codebase.bcienv).
mu, sigma = [0, mu1], [1, 1]

# Simulation 1: BBTS with immediate feedback (the bandit is updated with the scores of the sequence just shown)

In [None]:
# 50 independent repetitions of a 1000-character session. For every
# character a fresh BBTS bandit chooses which indices to flash; it is updated
# with the certainty scores right after each sequence, and flashing stops
# early once any score exceeds 0.9 (otherwise after 45 sequences).
result = {'accu': [], 'util': [], 'time': []}
for rep_id in tqdm(range(50)):
    env = SimulatedEnv(
        ProbMat(mu, sigma),
        t_flash=0.8/60,
        t_chr=5/60,
        score_mu=mu,
        score_sigma=sigma,
        p0=0,
    )
    obs, info = env.reset()
    rewards = []
    for _ in range(1000):
        bbts = BBTS()
        for n_seq in range(45):
            actions = bbts.select_indices()
            for action in actions:
                obs, reward, _, _, info = env.step(action)
            bbts.update(obs['certainty_scores'])
            if obs['certainty_scores'].max() > 0.9:
                break
        # Action 12 lies outside the 0-11 range BBTS picks from; it is sent
        # once per character after flashing (presumably "commit selection" --
        # confirm in SimulatedEnv) and its reward is the one that is scored.
        obs, reward, _, _, info = env.step(12)
        rewards.append(reward)
    # Per-repetition metrics: accuracy is the share of positive rewards,
    # utility comes from utility.cal_utility, and time is env.time divided by
    # env.current_chr_id (presumably average time per character).
    n_correct = np.sum(np.array(rewards) > 0)
    n_wrong = np.sum(np.array(rewards) < 0)
    result['accu'].append(n_correct / len(rewards))
    result['util'].append(utility.cal_utility(n_correct, n_wrong, env.time))
    result['time'].append(env.time / env.current_chr_id)
# Mean and standard deviation of each metric across repetitions.
{metric: (np.mean(values), np.std(values)) for metric, values in result.items()}

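# Simulation 2: BBTS with a one-sequence feedback delay (the bandit and the stopping rule use the previous sequence's scores)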

In [None]:
# Same protocol as Simulation 1, but with a one-sequence feedback delay: the
# scores produced by sequence n reach the bandit (and the 0.9 stopping rule)
# only after sequence n + 1 has been flashed, so both read the
# second-most-recent score matrix.
result = {'accu': [], 'util': [], 'time': []}
for rep_id in tqdm(range(50)):
    env = SimulatedEnv(ProbMat(mu, sigma), t_flash=0.2/60, t_chr=5/60, score_mu=mu, score_sigma=sigma, p0=0.5)
    obs, info = env.reset()
    rewards = []
    for _ in range(1000):
        bbts = BBTS()
        # Only the last two score matrices are ever read, so keep a bounded
        # window instead of the whole per-character history.
        P_hist = deque(maxlen=2)
        for n_seq in range(45):
            actions = bbts.select_indices()
            for action in actions:
                obs, reward, _, _, info = env.step(action)
            # Snapshot the scores: if the env hands back the same array object
            # and updates it in place, storing the bare reference would make
            # P_hist[-2] alias P_hist[-1] and silently remove the delay.
            P_hist.append(obs['certainty_scores'].copy())
            if n_seq >= 1:
                bbts.update(P_hist[-2])
                if P_hist[-2].max() > 0.9:
                    break
        # Action 12 lies outside the 0-11 range BBTS picks from; it is sent
        # once per character after flashing (presumably "commit selection" --
        # confirm in SimulatedEnv) and its reward is the one that is scored.
        obs, reward, _, _, info = env.step(12)
        rewards.append(reward)
    # Per-repetition metrics: accuracy is the share of positive rewards,
    # utility comes from utility.cal_utility, and time is env.time divided by
    # env.current_chr_id (presumably average time per character).
    n_correct, n_wrong = np.sum(np.array(rewards)>0), np.sum(np.array(rewards)<0)
    result['accu'].append(n_correct/len(rewards))
    result['util'].append(utility.cal_utility(n_correct, n_wrong, env.time))
    result['time'].append(env.time/env.current_chr_id)
# Mean and standard deviation of each metric across repetitions.
{k: (np.mean(v), np.std(v)) for k, v in result.items()}