In [1]:
import numpy as np

import itertools

from distributions.distribution_by_sequence import DistributionBySequence
from distributions.distribution import Distribution
from distributions.sequence import Sequence
from algorithms.semi_bandit_exp3 import SemiBanditExp3
from algorithms.full_bandit_exp3 import FullBanditExp3
from algorithms.semi_bandit_ftrl import SemiBanditFTRL
from algorithms.uniform_random import UniformRandom
from algorithms.non_contextual_exp3 import NonContextualExp3
from experiment_manager.experiment_manager import ExperimentManager

from distributions.actionsets.msets import MSets

from distributions.contexts.binary_context import BinaryContext
from distributions.thetas.single_hole import SingleHole
from distributions.thetas.independent_bernoulli import IndependentBernoulli

# Fixed seed so that "Restart Kernel -> Run All" regenerates the same random
# sequence and the same algorithm trajectory; change SEED to draw a new run.
SEED = 0
rng = np.random.default_rng(SEED)

In [2]:
# Run FullBanditExp3 once on a sequence drawn from the "lower bound"
# distribution, printing the constants the algorithm derived from it.
algo = FullBanditExp3()

# Problem size: sequence length, context dimension d, and K (first MSets argument).
length = 1000
d = 2
K = 2
actionset = MSets(K, 1)

# Gap parameter: a quarter of sqrt(K / length), capped at 1/4.
epsilon = 0.25 * np.minimum(np.sqrt(K / length), 1)
print("epsilon: ", epsilon)

# p is 0.5 everywhere except column 0, which is lowered by epsilon in every row.
# NOTE(review): presumably p[i, k] is the Bernoulli parameter IndependentBernoulli
# uses for context coordinate i and action k -- confirm against that class.
p = np.full((d, K), 0.5)
p[:, 0] -= epsilon

dist_lower_bound = Distribution(
    BinaryContext(d),
    IndependentBernoulli(d, K, p),
    actionset,
)
# Not used further in this cell; kept so the name stays available in the kernel.
dist_holes = Distribution(
    BinaryContext(d),
    SingleHole(d, K, np.array([0.7, 0.3])),
    actionset,
)

# The same generator is passed for both rng arguments of generate().
seq = dist_lower_bound.generate(length, rng, rng)
algo.set_constants(rng, seq)
print(seq.sigma, seq.m, algo.beta)
# Trailing semicolon: the cell echoes no value after the run.
algo.run_on_sequence(rng, seq);

epsilon:  0.011180339887498949
1.0 1 0.5


In [5]:
# Brute-force search for the einsum output layouts that reproduce
# np.kron(I/d, action_matrix) once the 4-axis result is reshaped to (d*K, d*K).
# Relies on `algo`, `d` and `K` left in the kernel by the cell that ran the algorithm.

# Plain Python ints (range, not np.arange): the printed tuples then read
# (0, 1, 2, 3) on every NumPy version instead of (np.int64(0), ...) on NumPy >= 2.
perms = list(itertools.permutations(range(4)))

letters = np.array(["a", "b", "c", "d"])

# Average over the d one-hot contexts of the outer product of the
# policy-weighted action with itself.
action_matrix = np.zeros((K, K))
for i in range(d):
    context = np.zeros(d)
    context[i] = 1

    # NOTE(review): presumably one probability per row of algo.actionset.actionset -- confirm.
    probabilities = algo.get_policy(context)
    # sum_a actionset[a, b] * probabilities[a]
    weighted_action = np.einsum("ab,a->b", algo.actionset.actionset, probabilities)
    action_matrix += np.outer(weighted_action, weighted_action) / d

# Loop-invariant left factor, built once and shared by the reference and every candidate.
context_matrix = np.identity(d) / d
answer = np.kron(context_matrix, action_matrix)

for perm in perms:
    # Output subscripts for this permutation, e.g. (0, 2, 1, 3) -> "acbd".
    string = "".join(letters[list(perm)])
    candidate = np.einsum(f"ab,cd->{string}", context_matrix, action_matrix)
    # Element-wise exact comparison against the Kronecker reference.
    correct = candidate.reshape((d * K, d * K)) == answer
    print("trying perm", string, perm, np.all(correct))


trying perm abcd (0, 1, 2, 3) False
trying perm abdc (0, 1, 3, 2) False
trying perm acbd (0, 2, 1, 3) True
trying perm acdb (0, 2, 3, 1) False
trying perm adbc (0, 3, 1, 2) True
trying perm adcb (0, 3, 2, 1) False
trying perm bacd (1, 0, 2, 3) False
trying perm badc (1, 0, 3, 2) False
trying perm bcad (1, 2, 0, 3) True
trying perm bcda (1, 2, 3, 0) False
trying perm bdac (1, 3, 0, 2) True
trying perm bdca (1, 3, 2, 0) False
trying perm cabd (2, 0, 1, 3) False
trying perm cadb (2, 0, 3, 1) False
trying perm cbad (2, 1, 0, 3) False
trying perm cbda (2, 1, 3, 0) False
trying perm cdab (2, 3, 0, 1) False
trying perm cdba (2, 3, 1, 0) False
trying perm dabc (3, 0, 1, 2) False
trying perm dacb (3, 0, 2, 1) False
trying perm dbac (3, 1, 0, 2) False
trying perm dbca (3, 1, 2, 0) False
trying perm dcab (3, 2, 0, 1) False
trying perm dcba (3, 2, 1, 0) False


In [None]:
# Brute-force search for the einsum output layouts under which the 4-axis
# tensor context (x) context (x) action (x) action, reshaped to (d*K, d*K), has
# its single 1 on the diagonal at position context_index + action_index * d.
d = 10
K = 20

# Row i of each identity matrix is the i-th one-hot context / action.
contexts = np.eye(d)
actions = np.eye(K)

# Reference matrices keyed by (context_index, action_index). Integer keys are
# used instead of str(ndarray): the string form depends on NumPy print options
# and is abbreviated with "..." for long arrays, which would make keys collide.
answer_dict = {}
for context_index in range(d):
    for action_index in range(K):
        hot = context_index + action_index * d
        expected = np.zeros((d * K, d * K))
        expected[hot, hot] = 1
        answer_dict[(context_index, action_index)] = expected

# Plain Python ints (range, not np.arange): the printed tuples then read
# (0, 1, 2, 3) on every NumPy version instead of (np.int64(0), ...) on NumPy >= 2.
perms = list(itertools.permutations(range(4)))

letters = np.array(["a", "b", "c", "d"])


def works(string):
    """Return True if einsum output layout `string` matches every reference matrix.

    Parameters
    ----------
    string : str
        Output subscripts for ``np.einsum("a,b,c,d->...")``: a permutation of
        "abcd", where a and b index the context and c and d index the action.

    Returns
    -------
    bool
        True when, for every one-hot (context, action) pair, the einsum result
        reshaped to (d*K, d*K) equals the matching entry of ``answer_dict``;
        False at the first mismatch.
    """
    for context_index, context in enumerate(contexts):
        for action_index, action in enumerate(actions):
            tensor = np.einsum(f"a,b,c,d->{string}", context, context, action, action)
            reference = answer_dict[(context_index, action_index)]
            if not np.array_equal(tensor.reshape(d * K, d * K), reference):
                return False
    return True


for perm in perms:
    # Output subscripts for this permutation, e.g. (2, 0, 3, 1) -> "cadb".
    string = "".join(letters[list(perm)])
    print("trying perm", string, perm, works(string))


trying perm abcd (0, 1, 2, 3) False
trying perm abdc (0, 1, 3, 2) False
trying perm acbd (0, 2, 1, 3) False
trying perm acdb (0, 2, 3, 1) False
trying perm adbc (0, 3, 1, 2) False
trying perm adcb (0, 3, 2, 1) False
trying perm bacd (1, 0, 2, 3) False
trying perm badc (1, 0, 3, 2) False
trying perm bcad (1, 2, 0, 3) False
trying perm bcda (1, 2, 3, 0) False
trying perm bdac (1, 3, 0, 2) False
trying perm bdca (1, 3, 2, 0) False
trying perm cabd (2, 0, 1, 3) False
trying perm cadb (2, 0, 3, 1) True
trying perm cbad (2, 1, 0, 3) False
trying perm cbda (2, 1, 3, 0) True
trying perm cdab (2, 3, 0, 1) False
trying perm cdba (2, 3, 1, 0) False
trying perm dabc (3, 0, 1, 2) False
trying perm dacb (3, 0, 2, 1) True
trying perm dbac (3, 1, 0, 2) False
trying perm dbca (3, 1, 2, 0) True
trying perm dcab (3, 2, 0, 1) False
trying perm dcba (3, 2, 1, 0) False
