# Contextual Bandits

The notion of a state in reinforcement learning is (more or less) the same as the notion of a context in contextual bandits. The main difference is that, in reinforcement learning, an action $a$ taken in state $s$ affects not only the reward $r$ that the agent receives but also the next state $s'$ the agent ends up in, while in contextual bandits (also known as associative search problems) an action $a$ taken in state $s$ affects only the reward $r$.

In [57]:
# Imports
import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy import stats
import plotly.graph_objects as go
import plotly.offline
from plotly.subplots import make_subplots
import cufflinks as cf
import regex as re

In [58]:
# Configure cufflinks for this notebook session.
# NOTE(review): go_offline() presumably switches cufflinks/plotly to offline
# rendering -- confirm against the cufflinks documentation.
cf.go_offline()
# Store the cufflinks defaults used here: world_readable=True and the 'white' theme.
cf.set_config_file(world_readable=True, theme='white')

# Using the logistic function to calculate probability



```
p(x) = 1/(1 + e^(-f(x)))
```



In [59]:
# Creating user objects with specified Betas
class UserGenerator(object):
  """Simulate users (contexts) and their click responses to five ads.

  Each ad 'A'..'E' has a coefficient vector ``beta``. The click probability
  for a user is the logistic function of the dot product between the ad's
  ``beta`` and the user's context ``[1, device, location, age]``.
  """

  # Initialize
  def __init__(self):
    # Per-ad coefficients, aligned with the context layout
    # [intercept, device, location, age] built in generate_user_with_context.
    self.beta = {}
    self.beta['A'] = np.array([-4, -0.1, -3, 0.1])
    self.beta['B'] = np.array([-6, -0.1, 1, 0.1])
    self.beta['C'] = np.array([2, 0.1, 1, -0.1])
    self.beta['D'] = np.array([4, 0.1, -3, -0.2])
    self.beta['E'] = np.array([-0.1, 0, 0.5, -0.01])
    # Context of the most recently generated user; None until one is generated.
    self.context = None

  # Generate context
  def generate_user_with_context(self):
    """Sample a random user, store its context on the instance and return it.

    Returns:
      list: ``[1, device, location, age]`` where the leading 1 is the
      intercept term.
    """
    # 0: Int'l 1: US
    location = np.random.binomial(n=1, p=0.6)
    # 0: Desktop 1: Mobile
    device = np.random.binomial(n=1, p=0.8)
    # User Age
    age = 10 + int(np.random.beta(2, 3) * 60)
    # Add context
    self.context = [1, device, location, age]
    return self.context

  # Use function to generate clicks
  def logistic(self, beta, context):
    """Return the click probability ``1 / (1 + exp(-beta . context))``.

    Args:
      beta: coefficient vector of one ad.
      context: user context vector of the same length as ``beta``.

    Returns:
      float: probability in (0, 1).
    """
    # Use the arguments, not self.beta (the whole dict) / self.context.
    f = np.dot(beta, context)
    return 1.0 / (1.0 + np.exp(-f))

  # display_ad
  def display_ad(self, ad):
    """Show ``ad`` to the current user and sample a click (1) or no click (0).

    Args:
      ad: key of the ad in ``self.beta``.

    Returns:
      int: Bernoulli reward drawn with the logistic click probability.

    Raises:
      ValueError: if ``ad`` is not a known ad.
      RuntimeError: if no user context has been generated yet.
    """
    if ad not in self.beta:
      raise ValueError('Unknown Ad!')
    if self.context is None:
      raise RuntimeError('No user context; call generate_user_with_context() first.')

    p = self.logistic(self.beta[ad], self.context)
    return np.random.binomial(n=1, p=p)

In [60]:
#Visualise
def get_scatter(x: np.ndarray, y: np.ndarray, name: str, showlegend: bool) -> "go.Scatter":
  """Build the line trace for one ad.

  Args:
    x: x-axis values.
    y: y-axis values, one per entry of ``x``.
    name: ad name; must be one of 'A'..'E'. Used as the trace name, the
      legend group and the key selecting the dash style.
    showlegend: whether this trace gets its own legend entry.

  Returns:
    go.Scatter: a blue line trace with an ad-specific dash pattern.

  Raises:
    KeyError: if ``name`` has no dash style defined.
  """
  # One dash pattern per ad so the (all-blue) lines remain distinguishable.
  dashmap = {"A": "solid", "B": "dot", "C": "dash", "D": "dashdot", "E": "longdash"}
  return go.Scatter(
      x=x,
      y=y,
      name=name,
      legendgroup=name,
      showlegend=showlegend,
      line=dict(color='blue', dash=dashmap[name]),
  )

In [61]:
def visualise_bandits(ug: "UserGenerator") -> "go.Figure":
  """Plot click probability versus age for every ad in a 2x2 grid.

  Rows are devices (desktop, mobile) and columns are locations
  (international, U.S.). Each subplot holds one line per ad.

  Args:
    ug: object whose ``beta`` dict maps ad names 'A'..'E' to coefficient
      vectors ordered as ``[intercept, device, location, age]``.

  Returns:
    go.Figure: the populated subplot figure.
  """
  ad_list = 'ABCDE'
  ages = np.linspace(10, 70)
  fig = make_subplots(rows=2, cols=2, subplot_titles=("Desktop, International","Desktop, U. S.", "Mobile, International", "Mobile, U. S."))
  ones = np.ones_like(ages)
  for device in [0, 1]:
    for loc in [0, 1]:
      # Only the first subplot contributes legend entries; traces in the
      # other subplots share them through their legend group.
      showlegend = device == 0 and loc == 0
      # One context row [1, device, loc, age] per age value.
      contexts = np.column_stack([ones, device * ones, loc * ones, ages])
      for ad in ad_list:
        # Logistic click probability for each age, computed in one shot.
        probs = 1.0 / (1.0 + np.exp(-(contexts @ ug.beta[ad])))
        fig.add_trace(get_scatter(ages, probs, ad, showlegend), row=device + 1, col=loc + 1)

  fig.update_layout(template="presentation")
  return fig

In [63]:
# Simulate: create a user generator, build the figure with visualise_bandits
# and display it.
ug = UserGenerator()
x = visualise_bandits(ug)
x.show()