In [11]:
import os
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt; plt.rcdefaults()
from datetime import datetime
import numpy as np
import pickle
import glob

# Folder (relative to the working directory) where finished sessions are pickled.
BKP_FOLDER = os.path.join("bkp", "bandit_task")
# Create it up front; exist_ok=True keeps re-running this cell from failing.
os.makedirs(BKP_FOLDER, exist_ok=True)


class BanditApplication:
    """Interactive bandit task built from ipywidgets buttons.

    One button is shown per option. Clicking an option records the choice,
    draws a random success/failure using that option's probability and shows
    the outcome on the "next" button. Clicking the "next" button starts the
    following trial. Once ``n_iteration`` trials are completed the choices
    are pickled into ``BKP_FOLDER``.

    Attributes
    ----------
    labels : list of str
        Caption of each option button.
    prob_dist : sequence of float
        Success probability of each option (same order as ``labels``).
    n_iteration : int
        Total number of trials.
    hist : list of int
        Index of the option chosen on each trial so far.
    t : int
        Number of completed trials.
    s : int
        Number of successes so far.
    """

    def __init__(self, labels, prob_dist, n_iteration):
        """Create the widgets and display the task.

        Parameters
        ----------
        labels : list of str
            Caption of each option button.
        prob_dist : sequence of float
            Success probability of each option; must be as long as ``labels``.
        n_iteration : int
            Number of trials to run; must be positive.

        Raises
        ------
        ValueError
            If ``n_iteration`` is not positive (the progress percentage
            divides by it and the task could never finish).
        """
        assert len(prob_dist) == len(labels)
        if n_iteration <= 0:
            raise ValueError("n_iteration must be a positive integer")

        self.time_stamp = datetime.utcnow()

        self.labels = labels
        self.prob_dist = prob_dist
        self.hist = []

        self.n_iteration = n_iteration

        self.t = 0
        self.s = 0

        # The "next" button doubles as the outcome display. It starts
        # disabled so a trial cannot be advanced before a choice is made.
        self.next_button = widgets.Button(
            description='???',
            button_style='',
            disabled=True)

        self.counter = widgets.Label(self.str_counter_value())
        self.score = widgets.Label(self.str_score_value())

        self.output = widgets.Output()

        self.elements = self.get_elements()
        self.menu = widgets.VBox(self.elements)

        display(self.menu, self.output)

    def on_next_clicked(self, b):
        """Close the current trial and open the next one.

        Clicks that arrive while no choice is pending for the current trial
        are ignored, so ``hist`` always holds exactly one entry per completed
        trial.
        """
        # A choice is pending only when hist is one entry ahead of t.
        if len(self.hist) <= self.t:
            return

        self.t += 1

        self.next_button.disabled = True
        self.next_button.description = '???'
        self.next_button.button_style = ''

        self.enable_choice(True)

        self.update_counter()
        self.update_score()

        if self.t == self.n_iteration:
            self.end_task()

    def on_button_clicked(self, b):
        """Record the option whose button ``b`` was clicked and draw its outcome."""
        # Running inside the Output widget captures anything printed or
        # raised by the callback in the notebook.
        with self.output:
            idx_button = self.labels.index(b.description)
            self.hist.append(idx_button)
            self.enable_choice(False)
            self.next_button.disabled = False
            self.determine_success()

    def get_elements(self):
        """Return the widgets of the menu: option buttons, then the "next"
        button, the trial counter and the score label."""
        elements = []
        for label in self.labels:
            element = widgets.Button(description=label,
                                     button_style='info',
                                     disabled=False)
            element.on_click(self.on_button_clicked)
            elements.append(element)

        self.next_button.on_click(self.on_next_clicked)
        elements.append(self.next_button)
        elements.append(self.counter)
        elements.append(self.score)

        return elements

    def update_counter(self):
        """Refresh the trial counter label."""
        self.counter.value = self.str_counter_value()

    def update_score(self):
        """Refresh the score label."""
        self.score.value = self.str_score_value()

    def determine_success(self):
        """Draw the outcome of the latest choice and show it on the "next" button."""
        random_number = np.random.random()
        p_success = self.prob_dist[self.hist[-1]]

        # Success when the uniform draw in [0, 1) falls below the
        # option's probability.
        success = p_success > random_number

        if success:
            self.s += 1
            self.next_button.button_style = 'success'
            self.next_button.description = "SUCCESS"
        else:
            self.next_button.button_style = 'danger'
            self.next_button.description = "FAILURE"

    def str_counter_value(self):
        """Return the counter text: completed trials and progress percentage."""
        return f"Trial {self.t} [{(self.t/self.n_iteration)*100:.2f}%]"

    def str_score_value(self):
        """Return the score text."""
        return f"Score: {self.s} pt(s)"

    def end_task(self):
        """Lock the option buttons, show the closing messages and save the results."""
        self.enable_choice(False)
        self.next_button.description = 'Thank you!'
        self.score.value = self.str_score_value() + " ...CONGRATS!"
        self.counter.value = self.str_counter_value() + " Done!"
        self.backup_results()

    def backup_results(self):
        """Pickle the start time stamp, the choices and the probabilities.

        The file is written to ``BKP_FOLDER`` and named after the current
        UTC time, with spaces, colons and dots replaced so the name is a
        plain file name.
        """
        str_time_stamp = \
            str(datetime.utcnow())\
                .replace(" ", "_")\
                .replace(":", "-")\
                .replace(".", "-")

        bkp_file = os.path.join(BKP_FOLDER, f"{str_time_stamp}.p")

        results = {
            'time_stamp': self.time_stamp,
            'data': self.hist,
            'prob_dist': self.prob_dist
        }
        # Context manager so the file is flushed and closed right away.
        with open(bkp_file, 'wb') as f:
            pickle.dump(results, f)

    def enable_choice(self, value):
        """Enable (``value`` true) or disable (``value`` false) every option button."""
        # The option buttons are the first len(labels) entries of elements.
        for button in self.elements[:len(self.labels)]:
            button.disabled = not value

# Launch a ten-trial task with two options; the second one succeeds
# with probability 0.8, the first with probability 0.2.
BanditApplication(
    labels=["Option 1", "Option 2"],
    prob_dist=[0.2, 0.8],
    n_iteration=10,
)

VBox(children=(Button(button_style='info', description='Option 1', style=ButtonStyle()), Button(button_style='…

Output()

<__main__.BanditApplication at 0x11efc6310>

In [12]:
def show_data(folder=None):
    """Print the content of every ``*.p`` backup file found in a folder.

    For each file the stored keys are printed in reverse alphabetical
    order, each followed by its value, and a line of ten ``=`` closes the
    file's section.

    Parameters
    ----------
    folder : str, optional
        Folder to scan. Defaults to ``BKP_FOLDER``.
    """
    if folder is None:
        folder = BKP_FOLDER

    # glob gives no ordering guarantee; sorting makes the listing stable
    # from one run to the next.
    data_files = sorted(glob.glob(os.path.join(folder, '*.p')))

    for df in data_files:
        # NOTE: unpickling can execute arbitrary code; only use this on
        # backup files you created yourself.
        with open(df, 'rb') as f:
            r = pickle.load(f)

        for key in sorted(r.keys(), reverse=True):
            print(f'{key}\n', r[key], '\n')

        print("=" * 10)

# Print every saved session found in BKP_FOLDER.
show_data()

time_stamp
 2019-11-12 13:44:13.793439 

prob_dist
 [0.2, 0.8] 

data
 [0, 1, 1, 0, 1, 1, 1, 1, 1, 1] 

time_stamp
 2019-11-12 13:48:10.445325 

prob_dist
 [0.2, 0.8] 

data
 [0, 1, 1, 0, 1, 0, 0, 1, 0, 0] 

