In [None]:
import import_ipynb

from keras.layers import Input, Dense, Concatenate, GlobalAveragePooling1D, Dropout
from IPython.display import display, Image, update_display, HTML
from tensorflow.keras.utils import plot_model
from keras.models import Model, Sequential
from transformer import TransformerBlock
from keras.regularizers import L1L2
from threading import Thread, Lock
from dataclasses import dataclass
import user_ad_interaction
from poibin import PoiBin
import keras_tuner as kt
import tensorflow as tf
from queue import Queue
import adgorithm_htuner
from tqdm import tqdm
import pandas as pd
import numpy as np
import scipy.stats
import warnings
import bisect
import tools
import users
import time
import os

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm


In [None]:
# The ADS16 benchmark ships in two parts; user and ad features are loaded from both.
root_directories = [
    "ADS16_Benchmark_part1",
    "ADS16_Benchmark_part2",
]

In [None]:
# Re-create the Hyperband tuner for the user/ad interaction model. No search is run here, so
# get_best_hyperparameters presumably reads back results already stored under `directory`.
tuner = tools.HyperbandWithBatchSize(
    user_ad_interaction.create_model,
    objective=kt.Objective(name="val_rating_mae", direction="min"),
    max_epochs=100,
    factor=3,
    hyperband_iterations=5,
    directory="user_ad_interaction_model_logs/hypertraining",
    project_name="user_ad_interaction",
)
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

# Build the model with the winning hyperparameters and restore a specific saved checkpoint.
user_model = user_ad_interaction.create_model(best_hps)
user_model.load_weights("user_ad_interaction_model_logs/checkpoints/model-000020-1.202267.hdf5")

In [None]:
# Load ad features, the category count and the real users' features from both benchmark parts.
# `ad_ftrs` is used as a pair below (ad_ftrs[0] / ad_ftrs[1]); the fourth return value is unused.
ad_ftrs, num_categories, real_user_ftrs, _ = user_ad_interaction.load_user_and_ad_ftrs(root_directories)
# Fit normal-distribution parameters to the (transposed) real user features; these are what
# users.generate_synthetic_pca_ftrs samples synthetic users from later on.
# NOTE(review): presumably one (mean, std)-style pair per feature dimension -- confirm in `users`.
user_pca_normal_params = users.approximate_normal_params(real_user_ftrs.T)

In [None]:
# Expected click-through rate for each rating key; User.get_all_true_exp_ctr maps the
# interaction model's predicted class (argmax + 1) through this table.
rating_to_exp_ctr = dict(zip(
    range(1, 6),
    (0.014925, 0.024786, 0.031071, 0.040562, 0.068341),
))

In [None]:
# Same reload pattern for the per-user RL ("adgorithm") model: re-create its Hyperband tuner and
# read back the best hyperparameters; every User builds its own model from `rl_best_hps`.
tuner = tools.HyperbandWithBatchSize(
    adgorithm_htuner.create_rl_model,
    objective=kt.Objective(name="val_output_1_loss", direction="min"),
    max_epochs=100,
    factor=3,
    hyperband_iterations=5,
    directory="adgorithm_logs/hypertraining",
    project_name="adgorithm_htuning",
)
rl_best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

# Build one instance just to inspect the architecture.
model = adgorithm_htuner.create_rl_model(rl_best_hps)
model.summary()

# ModelDisplayer looks like a Keras-style callback; it is driven by hand here (outside fit) to render the model once.
c = tools.ModelDisplayer()
c.model = model
c.on_epoch_end(0)

In [None]:
@dataclass
class RankedAd:
    """An ad id paired with the reward a user's RL model predicted for it.

    Ordering (``<``) looks only at ``value``, which is what ``bisect.insort`` relies on to keep
    ``User.sorted_ads`` ascending; equality is the dataclass default (both fields).
    """

    ad_id: int
    value: float

    def __lt__(self, other):
        return other.value > self.value

    def __repr__(self):
        # Renders as "<ad_id>{<value>}".
        return "{}{{{}}}".format(self.ad_id, self.value)

class Memory:
    """Fixed-size ring buffer of training rows for a user's RL model.

    Each row is (first ad-feature block + click count, second ad-feature block) -> reward. Once
    ``size`` rows have been written, new rows overwrite the oldest slots.
    """

    def __init__(self, size=100):
        self.size = size
        # Row layout mirrors the global `ad_ftrs` pair; the first block gets one extra trailing
        # column for the click count.
        first_width = ad_ftrs[0].shape[-1] + 1
        second_shape = tuple(ad_ftrs[1].shape[1:])
        self.ad_ftrs = (np.zeros((size, first_width)), np.zeros((size,) + second_shape))
        self.rewards = np.zeros(size)
        self.ptr = 0          # next slot to write
        self.full = False     # True once every slot has been written at least once

    def store(self, ad_ftrs, num_clicks, reward):
        """Write one row into the current slot and advance the write pointer (wrapping at ``size``)."""
        slot = self.ptr
        flat_row = self.ad_ftrs[0][slot]   # view into the buffer, so assignments write through
        flat_row[:-1] = ad_ftrs[0]
        flat_row[-1] = num_clicks
        self.ad_ftrs[1][slot] = ad_ftrs[1]
        self.rewards[slot] = reward

        self.ptr = (slot + 1) % self.size
        if self.ptr == 0:
            self.full = True

    def get_training_data(self):
        """Return ``((x0, x1), rewards)`` for every slot written so far (all slots once full)."""
        if not self.full:
            n = self.ptr
            return (self.ad_ftrs[0][:n], self.ad_ftrs[1][:n]), self.rewards[:n]
        return self.ad_ftrs, self.rewards

class User:
    """A simulated user who is repeatedly shown ads and trains its own RL model to choose them.

    The user's *true* behaviour comes from the pretrained ``user_model``: its predicted class for an
    ad (argmax + 1) is mapped through ``rating_to_exp_ctr`` and damped by
    ``diminishing_returns_coeff`` once per previous click on that ad. ``rl_model`` is trained online
    on (ad features + click count at the time the ad was shown) -> observed reward, and
    ``sorted_ads`` keeps its predictions as ``RankedAd`` entries in ascending order (best last).
    """

    def __init__(self, ftrs=None):
        self.ftrs = ftrs
        if self.ftrs is None:
            # No features given: sample a single synthetic user from the fitted normal parameters.
            self.ftrs = users.generate_synthetic_pca_ftrs(1, user_pca_normal_params)[0]
        self.rl_model = adgorithm_htuner.create_rl_model(rl_best_hps)
        self.sorted_ads = []            # RankedAd entries, ascending by predicted reward
        self.memory = Memory()          # replay buffer the RL model is fitted on
        self.clicks = {}                # ad_id -> number of times this user clicked it
        self.interaction_history = []   # recent did_click outcomes
        self.diminishing_returns_coeff = 0.9
        self.ad_memory_size = 100       # max entries kept in sorted_ads
        self.random_mean_ctrs = []      # per interaction: mean true CTR over all available ads
        self.cheater_ctrs = []          # per interaction: max true CTR over all available ads
        self.interaction_history_max_size = 100
        self.last_interaction = None    # (ad_id, this user's clicks on it before that interaction)

    def _append_bounded(self, history, value):
        """Append ``value`` to ``history``, dropping the oldest entry beyond ``interaction_history_max_size``."""
        history.append(value)
        if len(history) > self.interaction_history_max_size:
            del history[0]

    def get_all_true_exp_ctr(self, ad_ids, ad_ftrs):
        """Return this user's true expected CTR (float32 array) for every ad in ``ad_ids`` / ``ad_ftrs``."""
        # Pair this user's features with each ad's first feature block.
        user_rows = np.broadcast_to(self.ftrs[np.newaxis], (len(ad_ftrs[0]), *self.ftrs.shape))
        ratings = np.argmax(user_model((
            np.concatenate([user_rows, ad_ftrs[0]], axis=-1),
            ad_ftrs[1]
        ), training=False), axis=-1) + 1   # class index -> 1-based key of rating_to_exp_ctr
        exp_ctrs = np.zeros_like(ratings, dtype="float32")
        for i, rating in enumerate(ratings):
            exp_ctrs[i] = rating_to_exp_ctr[rating]
            if (ad_id := ad_ids[i]) in self.clicks:
                # Each earlier click on the same ad makes another click less likely.
                exp_ctrs[i] *= self.diminishing_returns_coeff ** self.clicks[ad_id]

        return exp_ctrs

    def get_true_exp_ctr(self, ad_id, ad_ftrs):
        """Single-ad convenience wrapper around get_all_true_exp_ctr."""
        return self.get_all_true_exp_ctr([ad_id], (ad_ftrs[0][np.newaxis], ad_ftrs[1][np.newaxis]))[0]

    def predict_reward(self, ad_id, ad_ftrs):
        """Score ``ad_id`` with the RL model and (re)insert it into ``sorted_ads`` at its new rank."""
        num_clicks = self.clicks.get(ad_id, 0)
        model_input = (np.concatenate((ad_ftrs[0], [num_clicks]))[np.newaxis], ad_ftrs[1][np.newaxis])
        # float(): RankedAd.value is annotated float; keeping an EagerTensor there made every
        # comparison inside bisect a tensor op. training=False matches the other model calls.
        exp_reward = float(self.rl_model(model_input, training=False)[0, 0])
        for i, existing_ad in enumerate(self.sorted_ads):
            if existing_ad.ad_id == ad_id:
                del self.sorted_ads[i]
                break
        bisect.insort(self.sorted_ads, RankedAd(ad_id, exp_reward))

    def interact_with(self, ad_id, ad_ftrs):
        """Show ``ad_id`` to the user, sample whether they click, record the outcome, and return it.

        The significance baselines (mean and max true CTR over all available ads) are snapshotted
        *before* the click is sampled. A click damps this user's CTR on the ad and can retire the
        ad from the database, so computing them afterwards would bias the null distributions.
        """
        clicks_before = self.clicks.get(ad_id, 0)
        all_ctrs = self.get_all_true_exp_ctr(*AdDatabase.get_all_available_ads())
        exp_ctr = self.get_true_exp_ctr(ad_id, ad_ftrs)
        did_click = np.random.uniform() <= exp_ctr

        self.last_interaction = (ad_id, clicks_before)
        self._append_bounded(self.random_mean_ctrs, np.mean(all_ctrs))
        self._append_bounded(self.cheater_ctrs, np.max(all_ctrs))
        self._append_bounded(self.interaction_history, did_click)
        if did_click:
            self.clicks[ad_id] = clicks_before + 1
            AdDatabase.ad_was_clicked(ad_id)
        return did_click

    def select_important_ad(self, p_random=0.1):
        """Return a random available ad with probability ``p_random``, else a random remembered one."""
        if np.random.uniform() <= p_random or len(self.sorted_ads) == 0:
            return AdDatabase.get_random_ad()
        ad_id = self.sorted_ads[np.random.randint(0, len(self.sorted_ads))].ad_id
        return AdDatabase.get_ad(ad_id)

    def select_ad(self, p_random=0.1):
        """Epsilon-greedy choice: random ad with probability ``p_random``, else the top-ranked ad."""
        if np.random.uniform() <= p_random or len(self.sorted_ads) == 0:
            return AdDatabase.get_random_ad()

        return AdDatabase.get_ad(self.sorted_ads[-1].ad_id)

    def remember(self, ad_id, ad_ftrs, reward):
        """Store a training row (ad features, click count when the ad was shown, reward).

        The click count is the one from before the interaction. That is the feature predict_reward
        feeds the model, so training and prediction see the same quantity. The post-click count
        would be one higher on every clicked row.
        """
        num_clicks = self.clicks.get(ad_id, 0)
        last = self.last_interaction
        if last is not None and last[0] == ad_id:
            num_clicks = last[1]
        self.memory.store(ad_ftrs, num_clicks, reward)

    def learn(self, rl_epochs=10, n_ads_to_fetch=10, p_random=0.1):
        """Fit the RL model on memory, then re-score the current favourite plus other sampled ads.

        Users with ``p_random >= 1`` always pick randomly and never train.
        NOTE(review): the resampling loop never terminates if fewer than ``n_ads_to_fetch`` distinct
        ads are available.
        """
        if p_random >= 1:
            return
        self.rl_model.fit(
            *self.memory.get_training_data(),
            verbose=0,
            epochs=rl_epochs,
            batch_size=32
        )

        ads_resampled = []
        # Re-score the favourite first. If it has been retired, get_ad would substitute a random ad,
        # and the stale entry would stay pinned at sorted_ads[-1]; select_ad would then keep serving
        # random ads. Drop such entries instead.
        while self.sorted_ads:
            best_ad_id = self.sorted_ads[-1].ad_id
            found_id, found_ftrs = AdDatabase.get_ad(best_ad_id, random_if_unavailable=False)
            if found_id is None:
                self.sorted_ads.pop()
                continue
            self.predict_reward(found_id, found_ftrs)
            ads_resampled.append(found_id)
            break
        while len(ads_resampled) < n_ads_to_fetch:
            ad = self.select_important_ad(p_random=p_random)
            if (ad_id := ad[0]) in ads_resampled:
                continue
            self.predict_reward(*ad)
            ads_resampled.append(ad_id)
        self.sorted_ads = self.sorted_ads[-self.ad_memory_size:]

    def evaluate_prediction_accuracy_for_ads(self, ad_ids, ad_ftrs):
        """Mean squared error between the RL model's predicted reward and the true expected CTR."""
        ground_truth = self.get_all_true_exp_ctr(ad_ids, ad_ftrs)
        num_times_clicked = np.zeros((len(ad_ids), 1))
        for i, ad_id in enumerate(ad_ids):
            num_times_clicked[i] = self.clicks.get(ad_id, 0)
        pred = self.rl_model((
            np.concatenate([ad_ftrs[0], num_times_clicked], axis=-1),
            ad_ftrs[1]
        ), training=False)[:, 0].numpy()
        err = pred - ground_truth
        return np.mean(err * err)

    def evaluate_prediction_accuracy_for_available_ads(self):
        """Prediction MSE over every available ad, or NaN for users that never rank ads."""
        if len(self.sorted_ads) == 0:       # Quick and dirty way to detect whether user has p_random=1 or is Cheater
            return np.nan
        return self.evaluate_prediction_accuracy_for_ads(*AdDatabase.get_all_available_ads())

    def evaluate_prediction_accuracy_for_ads_in_memory(self):
        """Prediction MSE over the ads in ``sorted_ads`` that are still available (NaN if none)."""
        ad_ids = list(x.ad_id for x in self.sorted_ads)
        ad_ftrs = list(AdDatabase.get_ad(ad_id, random_if_unavailable=False) for ad_id in ad_ids)
        ad_ftrs = list(x for x in ad_ftrs if x[0] is not None)   # skip retired ads
        if len(ad_ftrs) == 0:
            return np.nan
        ad_ids, ad_ftrs = list(zip(*ad_ftrs))

        ad_ftrs = (
            np.stack([x[0] for x in ad_ftrs]),
            np.stack([x[1] for x in ad_ftrs])
        )

        return self.evaluate_prediction_accuracy_for_ads(ad_ids, ad_ftrs)

    def get_significance_levels(self):
        """Return ``(p_random, p_cheater)`` for the clicks in ``interaction_history``.

        Both model the click count as Poisson-binomial over the per-interaction baseline CTRs:
        ``p_random`` uses the mean CTR over available ads (``PoiBin.pval``), and ``p_cheater`` uses the
        max CTR (``PoiBin.cdf``, i.e. the left tail).
        """
        num_clicks = int(np.sum(self.interaction_history))
        p_random = PoiBin(self.random_mean_ctrs).pval(num_clicks)
        p_cheater = PoiBin(self.cheater_ctrs).cdf(num_clicks)
        return p_random, p_cheater

class Cheater(User):
    """Oracle baseline: always shows the available ad with the highest true expected CTR and never trains."""

    def select_ad(self, *args, **kwargs):
        ids, (ftrs_a, ftrs_b) = AdDatabase.get_all_available_ads()
        best = np.argmax(self.get_all_true_exp_ctr(ids, (ftrs_a, ftrs_b)))
        return ids[best], (ftrs_a[best], ftrs_b[best])

    def learn(self, *args, **kwargs):
        # The oracle already knows the true CTRs, so there is nothing to learn.
        return None
    

In [None]:
# Live ad table shared by every Simulation thread: ad_id -> (first feature block, second feature block),
# plus a global (all-users) click counter per ad. Ids 0..N-1 are the original ads; an ad that reaches
# `click_thresh` clicks is re-listed under the next id from `ad_id_ctr` (see AdDatabase.ad_was_clicked).
# A comprehension keeps the old loop's temporaries (i, af0, af1) out of the notebook namespace.
available_ads = {ad_idx: (ad_ftrs[0][ad_idx], ad_ftrs[1][ad_idx]) for ad_idx in range(len(ad_ftrs[0]))}
ad_clicks = dict.fromkeys(available_ads, 0)
ad_id_ctr = len(ad_ftrs[0])
click_thresh = 100

class NullLock:
    """A do-nothing context manager, handed out in place of a real lock when locking is bypassed."""

    def __enter__(self, *args, **kwargs):
        return None

    def __exit__(self, *args, **kwargs):
        # Falsy return: exceptions raised inside the `with` body still propagate.
        return None

class BypassableLock:
    """A ``threading.Lock`` wrapper whose call form can return a no-op stand-in instead.

    ``with lock:`` always acquires. ``with lock(use_lock):`` acquires only when ``use_lock`` is truthy.
    Code that already holds the (non-reentrant) lock passes ``use_lock=False`` to the helpers it calls
    so they don't deadlock.
    """

    def __init__(self):
        self.lock = Lock()

    def __call__(self, use_lock=True):
        if use_lock:
            return self.lock
        return NullLock()

    def __enter__(self, *args, **kwargs):
        return self.lock.__enter__(*args, **kwargs)

    def __exit__(self, *args, **kwargs):
        return self.lock.__exit__(*args, **kwargs)


# Shared lock guarding the module-level ad tables (`available_ads`, `ad_clicks`).
ad_db_lock = BypassableLock()

class AdDatabase:
    """Static facade over the module-level ad tables ``available_ads`` and ``ad_clicks``.

    All access goes through ``ad_db_lock``. Methods that call each other while already holding the
    lock pass ``use_lock=False``, because the underlying ``threading.Lock`` is not re-entrant.
    """

    @staticmethod
    def get_ad(ad_id, use_lock=True, random_if_unavailable=True):
        """Return ``(ad_id, ftrs)`` for an available ad.

        If ``ad_id`` is no longer available, return a random available ad instead, or
        ``(None, None)`` when ``random_if_unavailable`` is False.
        """
        with ad_db_lock(use_lock):
            if ad_id in available_ads:
                return ad_id, available_ads[ad_id]
            if random_if_unavailable:
                return AdDatabase.get_random_ad(use_lock=False)
            return None, None

    @staticmethod
    def get_random_ad(use_lock=True):
        """Return ``(ad_id, ftrs)`` for an ad drawn uniformly from the available ones."""
        with ad_db_lock(use_lock):
            ad_id = np.random.choice(list(available_ads.keys()))
            return AdDatabase.get_ad(ad_id, use_lock=False)

    @staticmethod
    def ad_was_clicked(ad_id, use_lock=True):
        """Count one click on ``ad_id``. At ``click_thresh`` clicks, re-list the ad under a fresh id.

        A click on an id that is already retired is ignored instead of raising KeyError. This can
        happen when another thread's click retires the ad between this user's selection and click;
        the exception would have killed the calling Simulation thread.
        """
        # ad_id_ctr is rebound below. Without this declaration it would be local to the function,
        # and the first read would raise UnboundLocalError.
        global ad_id_ctr
        with ad_db_lock(use_lock):
            if ad_id not in ad_clicks:
                return
            ad_clicks[ad_id] += 1
            if ad_clicks[ad_id] >= click_thresh:
                new_id = ad_id_ctr
                ad_id_ctr += 1
                available_ads[new_id] = available_ads.pop(ad_id)
                ad_clicks[new_id] = 0
                del ad_clicks[ad_id]

    @staticmethod
    def get_all_available_ads(use_lock=True):
        """Return ``(ad_ids, (ftrs0, ftrs1))`` for every available ad, in the table's insertion order."""
        with ad_db_lock(use_lock):
            # Snapshot under the lock: other Simulation threads mutate the table concurrently.
            # Reading keys() and values() separately could pair ids with the wrong features or raise
            # "dictionary changed size during iteration".
            items = list(available_ads.items())
        ad_ids = [ad_id for ad_id, _ in items]
        ad_ftrs = (
            np.stack([ftrs[0] for _, ftrs in items]),
            np.stack([ftrs[1] for _, ftrs in items]),
        )
        return ad_ids, ad_ftrs



In [None]:
class IterationDataPrinter(Thread):
    """Daemon thread that keeps the "iteration_data" HTML display refreshed with progress counters.

    The driver calls ``new_iteration()`` at the start of each round and ``sim_completed()`` for every
    EVENT_ITERATION_COMPLETE. ``run()`` renders those counters every ``refresh_interval`` seconds
    until ``kill()``.
    """

    def __init__(self, num_sims, num_trainable_sims, refresh_interval=0.5):
        super().__init__(daemon=True)
        self.iteration = -1
        self.killed = False
        self.num_sims = num_sims
        self.num_trainable_sims = num_trainable_sims
        # Pause between redraws. Without it, run() spins a CPU core (and holds the GIL) nonstop.
        self.refresh_interval = refresh_interval
        # Initialised up front because run() polls these from another thread. Previously they only
        # appeared inside new_iteration, after `iteration` had already become >= 0.
        self.num_sims_completed = 0
        self.num_trainable_sims_completed = 0
        self.last_iteration_time_string = "<b>Last iteration:</b> n/a"
        self.iteration_start_time = time.time()

    @staticmethod
    def format_time(time):
        """Format a duration in seconds as ``MM:SS``."""
        return f"{str(int(time//60)).rjust(2, '0')}:{str(int(time)%60).rjust(2, '0')}"

    @staticmethod
    def format_proportion(x, y):
        """Format ``x`` out of ``y`` as ``[x/y] (pct%)``; the percentage is ``n/a`` when ``y`` is 0."""
        if y == 0:
            return f"[{x}/{y}] (n/a)"
        return f"[{x}/{y}] ({(100*(x/y)):.2f}%)"

    def new_iteration(self):
        """Start a new round: reset the counters and record how long the previous round took."""
        now = time.time()
        label = "<b>Last iteration:</b> "
        if self.iteration >= 0:
            label += self.format_time(now - self.iteration_start_time)
        else:
            label += "n/a"
        self.num_sims_completed = 0
        self.num_trainable_sims_completed = 0
        self.last_iteration_time_string = label
        self.iteration_start_time = now
        # Publish the new iteration number last, so run() never renders it with stale fields.
        self.iteration += 1

    def sim_completed(self, is_trainable):
        """Count one finished simulation (and one trainable one if ``is_trainable``)."""
        if is_trainable:
            self.num_trainable_sims_completed += 1
        self.num_sims_completed += 1

    def kill(self):
        """Ask run() to exit after its current refresh."""
        self.killed = True

    def run(self):
        while not self.killed:
            if self.iteration >= 0:
                lines = [
                    f"<h1>Iteration {self.iteration}</h1>",
                    self.last_iteration_time_string,
                    f"<b>This iteration:</b> {self.format_time(time.time() - self.iteration_start_time)}",
                    f"<b>Sims completed:</b> {self.format_proportion(self.num_sims_completed, self.num_sims)}",
                    f"<b>Trainable sims completed:</b> {self.format_proportion(self.num_trainable_sims_completed, self.num_trainable_sims)}",
                ]
                update_display(HTML("<br/>".join(lines)), display_id="iteration_data")
            time.sleep(self.refresh_interval)

In [None]:
class Simulation(Thread):
    """Runs one user's ad-selection loop on its own daemon thread, in lock-step with the driver.

    Each item put on ``restart_queue`` releases one iteration (``False`` stops the thread instead).
    An iteration selects an ad, lets the user interact with it, trains the user, and posts the
    results to the shared event queue ``q``. It ends with an EVENT_ITERATION_COMPLETE event.
    """

    def __init__(self, sim_id, user=None, p_random=0.1):
        super().__init__(daemon=True)
        self.sim_id = sim_id
        self.user = user
        if self.user is None:
            self.user = User()
        self.killed = False
        self.p_random = p_random
        self.lock = BypassableLock()
        self.restart_queue = Queue()
        self.last_error = None  # most recent exception raised inside an iteration, if any

    def kill(self, blocking=True):
        """Stop the thread. With ``blocking=False`` the lock is bypassed, so an in-flight iteration isn't awaited."""
        with self.lock(blocking):
            self.killed = True
            self.restart_queue.put(False)

    def run(self):
        while not self.killed:
            if self.restart_queue.get() is False:
                break
            with self.lock:
                try:
                    self._run_iteration()
                except Exception as exc:
                    # Keep the thread alive and still report completion. The driver counts
                    # EVENT_ITERATION_COMPLETE events, and a dead thread would leave it blocked
                    # forever on q.get().
                    import traceback
                    self.last_error = exc
                    traceback.print_exc()
                if not self.killed:
                    q.put((EVENT_ITERATION_COMPLETE, self.sim_id))

    def _run_iteration(self):
        """Show one ad, then (unless killed meanwhile) learn and post the interaction, accuracy and significance events."""
        ad = self.user.select_ad(p_random=self.p_random)
        did_click = self.user.interact_with(*ad)
        if self.killed:
            return
        q.put((EVENT_INTERACTION, self.sim_id, did_click))
        self.user.remember(*ad, did_click)
        self.user.learn(p_random=self.p_random)
        q.put((
            EVENT_PREDICTION_ACCURACY,
            self.sim_id,
            self.user.evaluate_prediction_accuracy_for_available_ads(),
            self.user.evaluate_prediction_accuracy_for_ads_in_memory()
        ))
        q.put((
            EVENT_SIGNIFICANCE_LEVEL,
            self.sim_id,
            *self.user.get_significance_levels()
        ))

# Event queue from the Simulation threads to the driver loop, and the event tags posted on it.
q = Queue()
EVENT_INTERACTION = 0
EVENT_PREDICTION_ACCURACY = 1
EVENT_SIGNIFICANCE_LEVEL = 2
EVENT_ITERATION_COMPLETE = 3

# Per-simulation result buffers, all indexed by sim_id:
#   clicks[sim_id]              -> list of did_click outcomes
#   ctr_dfs[sim_id]             -> pd.Series of the windowed CTR, rebuilt on every interaction
#   accs[sim_id]                -> (MSE over all available ads, MSE over ads in user memory)
#   significance_levels[sim_id] -> (p-value vs. random choice, p-value vs. the Cheater oracle)
sims = []
clicks = []
ctr_dfs = []
accs = []
significance_levels = []

# Window length (in iterations) of the CTR curve.
rolling_window_size = 100

# One set of synthetic users, reused for every exploration rate, so the rates are compared on the
# same user population.
num_users_per_p = 1000
synthetic_user_ftrs = users.generate_synthetic_pca_ftrs(num_users_per_p, user_pca_normal_params)

# Exploration rates (ε) to compare. 1.0 means every ad is chosen at random and the user never
# trains (User.learn returns early). A negative value builds a Cheater oracle instead of a User.
p_randoms = [
    0.01,
    0.05,
    0.10,
    1.00,
    -1.0
]

# Give each ε a position in [0, 1] on the rainbow colormap that plot_stat draws with.
ps = sorted(p_randoms)
colors = {}
if len(ps) > 1:
    for i, v in enumerate(ps):
        colors[v] = i/(len(ps)-1)
elif len(ps) == 1:
    colors[ps[0]] = 1.

# Placeholder outputs; the driver loop replaces them in place via their display_id.
display(HTML("<h1>Iteration Data</h1>"), display_id="iteration_data")
display(HTML("<h1>CTR Graph</h1>"), display_id="ctr_graph")
display(HTML("<h1>Prediction Accuracy Graph</h1>"), display_id="pred_acc_graph")
display(HTML("<h1>Significance Levels Graph</h1>"), display_id="sig_graph")

# "Trainable" sims are those with 0 <= ε < 1; random (ε >= 1) and Cheater (ε < 0) users never train.
total_num_sims = len(p_randoms) * num_users_per_p
total_num_trainable_sims = len(list(filter(lambda x: 0 <= x < 1, p_randoms))) * num_users_per_p

idp = IterationDataPrinter(total_num_sims, total_num_trainable_sims)

# One Simulation per (ε, synthetic user). The running counter i is both the sim_id and the index into
# the buffers above, so they must be appended in lock-step.
# NOTE(review): this creates total_num_sims threads, each with its own User and RL model (Cheaters
# inherit User.__init__ and build one too), so expect heavy memory and CPU use.
i = 0
for p_random in p_randoms:
    for j in range(num_users_per_p):
        sims.append(Simulation(i, user=(Cheater if p_random < 0 else User)(synthetic_user_ftrs[j]), p_random=p_random))
        clicks.append([])
        accs.append(([], []))
        significance_levels.append(([], []))
        ctr_dfs.append(pd.Series(dtype="float32"))
        i += 1

def plot_stat(ax, generator, min_size=0.01):
    """Plot a statistic for every simulation on ``ax``, plus a bold mean curve per exploration rate.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    generator : callable(sim_id) -> 1-D array with that sim's history of the statistic (NaN = no value).
    min_size : lower bound for the top of the y-range, so near-empty plots stay readable.

    Sims whose history is empty or all-NaN are skipped. Within each ε group, shorter histories are
    NaN-padded to the longest one, and the group mean ignores NaNs. Sims report asynchronously, so
    their histories usually differ in length by one; a plain mean would blank out the tail of
    every curve. The y-range tops out at the largest of the 90th percentile of all values, each
    group's latest mean and ``min_size``, plus 10% headroom.
    """
    histories = []
    mean_ys = []

    by_p = {}
    for s in sims:
        ys = generator(s.sim_id)
        if np.all(np.isnan(ys)):
            continue
        p_random = round(s.p_random, 5)
        by_p.setdefault(p_random, []).append(ys)
        ax.plot(ys, alpha=0.1, linestyle="dashed", color=cm.rainbow(colors[p_random]))
        histories.append(ys)

    for p_random in sorted(by_p.keys()):
        group = by_p[p_random]
        max_len = max(len(x) for x in group)
        padded = np.full((len(group), max_len), np.nan)
        for row, x in zip(padded, group):
            row[:len(x)] = x
        with warnings.catch_warnings():
            # Columns where every sim is NaN legitimately yield NaN ("Mean of empty slice").
            warnings.simplefilter("ignore", category=RuntimeWarning)
            m = np.nanmean(padded, axis=0)
        ax.plot(m, linewidth=2, color=cm.rainbow(colors[p_random]), label=((f"ε={p_random}" + (" (Random)" if p_random >= 1 else "")) if p_random >= 0 else "Cheater"))
        finite = m[~np.isnan(m)]
        if len(finite):
            mean_ys.append(finite[-1])

    if len(by_p):
        ax.legend()

    # Concatenate once instead of growing an array inside the loop (quadratic in the number of sims).
    all_ys = np.concatenate(histories) if histories else np.array([])
    all_non_nan_ys = all_ys[~np.isnan(all_ys)]
    y_max = max([np.percentile(all_non_nan_ys, 90) if len(all_non_nan_ys) else min_size, *mean_ys, min_size])
    dy = y_max * 0.1
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ax.set_ylim(top=y_max + dy, bottom=-dy)

# Driver loop. Each round releases one iteration on every Simulation thread, then consumes events
# from `q` until every sim has posted EVENT_ITERATION_COMPLETE. Events are appended to the per-sim
# buffers and the live figures are redrawn. Interrupt the kernel to stop; `finally` shuts the
# threads down.
# NOTE(review): every event redraws a figure over all sims, so a round costs roughly
# O(total_num_sims ** 2) plot calls. Redrawing once per round would be much cheaper if this is too slow.
try:
    for s in sims:
        s.start()
    idp.start()

    while True:
        for s in sims:
            s.restart_queue.put(None)

        sims_to_complete_iteration = total_num_sims
        idp.new_iteration()
        while sims_to_complete_iteration > 0:
            event_type, *event = q.get()
            if event_type == EVENT_INTERACTION:
                sim_id, did_click = event
                clicks[sim_id].append(did_click)
                # min_periods=1 gives an expanding mean until the window fills and a true rolling
                # mean afterwards. A bare .rolling(window) would make the first window-1 points NaN
                # once the history outgrows the window, erasing the start of every curve.
                ctr_dfs[sim_id] = pd.Series(clicks[sim_id]).rolling(rolling_window_size, min_periods=1).mean()

                fig, ax = plt.subplots(figsize=(20, 10))
                fig.patch.set_facecolor("white")

                ax.set_xlabel("Iterations")
                ax.set_ylabel("CTR")
                ax.set_title(f"Average CTR over last {rolling_window_size} iterations.")
                plot_stat(ax, lambda sim_id: ctr_dfs[sim_id].values)

                update_display(fig, display_id="ctr_graph")
                plt.close(fig)

            elif event_type == EVENT_PREDICTION_ACCURACY:
                sim_id, all_ads, ads_in_mem = event

                accs[sim_id][0].append(all_ads)
                accs[sim_id][1].append(ads_in_mem)

                fig, (ax_all, ax_mem) = plt.subplots(2, figsize=(20, 20))
                fig.patch.set_facecolor("white")

                ax_all.set_title("Prediction MSE for All Available Ads in Database")
                ax_all.set_xlabel("Iterations")
                ax_all.set_ylabel("MSE")
                ax_all.set_yscale('log')
                plot_stat(ax_all, lambda sim_id: np.array(accs[sim_id][0]), min_size=0)

                ax_mem.set_title("Prediction MSE for All Ads in User Memory")
                ax_mem.set_xlabel("Iterations")
                ax_mem.set_ylabel("MSE")
                ax_mem.set_yscale('log')
                plot_stat(ax_mem, lambda sim_id: np.array(accs[sim_id][1]), min_size=0)

                update_display(fig, display_id="pred_acc_graph")
                plt.close(fig)

            elif event_type == EVENT_SIGNIFICANCE_LEVEL:
                sim_id, sig_random, sig_cheater = event

                significance_levels[sim_id][0].append(sig_random)
                significance_levels[sim_id][1].append(sig_cheater)

                fig, (ax_rand, ax_cheat) = plt.subplots(2, figsize=(20, 20))
                fig.patch.set_facecolor("white")

                # Raw strings: "\m" is not a valid escape sequence (DeprecationWarning/SyntaxWarning).
                ax_rand.set_title(r"Right-Tailed P-Values ($\mathregular{H_0}$=Ads Are Chosen Randomly)")
                ax_rand.set_xlabel("Iterations")
                ax_rand.set_ylabel("p")
                ax_rand.axhline(0.05, c="black", label="5%", linestyle="dashed")
                plot_stat(ax_rand, lambda sim_id: np.array(significance_levels[sim_id][0]))

                ax_cheat.set_title(r"Left-Tailed P-Values ($\mathregular{H_0}$=Ads Are Chosen Like Cheater)")
                ax_cheat.set_xlabel("Iterations")
                ax_cheat.set_ylabel("p")
                ax_cheat.axhline(0.05, c="black", label="5%", linestyle="dashed")
                plot_stat(ax_cheat, lambda sim_id: np.array(significance_levels[sim_id][1]))

                update_display(fig, display_id="sig_graph")
                plt.close(fig)

            elif event_type == EVENT_ITERATION_COMPLETE:
                sim_id, = event
                sims_to_complete_iteration -= 1
                idp.sim_completed(0 <= sims[sim_id].p_random < 1)


except KeyboardInterrupt:
    pass
finally:
    plt.close("all")
    for s in sims:
        s.kill(False)
    idp.kill()