In [None]:
import matplotlib.pyplot as plt
import random
import sys

sys.path.append('../')
from utils import prep_shoe_data
from data_loader import get_data

from lin_ucb.lin_ucb import LinearUCB
from lin_ucb.bandit import ContextualBandit
from lin_ucb.dataset.ucb_dataset import ShoeDataset

In [None]:
def _accuracy_from_regrets(regrets):
    """Return ``1 - mean(regrets)`` rounded to three decimal places."""
    return round(1 - sum(regrets) / len(regrets), 3)


def train_lin_ucb(seed, alpha, num_train, task='shoe', num_test=100, num_arms=6, num_features=7):
    """Train a LinearUCB model once and report its train and test accuracy.

    Args:
        seed: Seeds Python's ``random`` module and is forwarded to both
            ``ContextualBandit`` instances.
        alpha: Forwarded to ``LinearUCB`` as ``alpha``.
        num_train: Forwarded to ``get_data`` and used as the training
            bandit's ``T``.
        task: Forwarded to ``get_data``. NOTE(review): the returned data is
            always run through ``prep_shoe_data`` / ``ShoeDataset``, so values
            other than ``'shoe'`` are presumably unsupported -- confirm.
        num_test: Forwarded to ``get_data`` and used as the test bandit's ``T``.
        num_arms: Forwarded to ``ContextualBandit`` as ``n_arms``.
        num_features: Forwarded to ``LinearUCB`` as ``num_features``.

    Returns:
        Tuple ``(train_acc, test_acc)``. Each value is
        ``1 - mean(linucb.regrets)`` rounded to three decimals, read right
        after ``linucb.train()`` and ``linucb.test(...)`` respectively.
    """
    random.seed(seed)

    # Training prep
    train_data, test_data = get_data(task, num_train, num_test)
    train_x, train_y = prep_shoe_data(train_data)
    train_dataset = ShoeDataset(x=train_x, y=train_y)
    train_bandit = ContextualBandit(T=num_train, n_arms=num_arms, seed=seed, dataset=train_dataset)

    # Training
    linucb = LinearUCB(bandit=train_bandit, num_features=num_features, alpha=alpha)
    linucb.train()
    train_acc = _accuracy_from_regrets(linucb.regrets)

    # Testing prep
    test_x, test_y = prep_shoe_data(test_data)
    test_dataset = ShoeDataset(x=test_x, y=test_y)
    test_bandit = ContextualBandit(T=num_test, n_arms=num_arms, seed=seed, dataset=test_dataset)

    # Testing
    linucb.test(test_bandit)
    # NOTE(review): assumes test() leaves only test-phase regrets in
    # linucb.regrets; if it appends to the training regrets instead, test_acc
    # would mix both phases -- confirm against LinearUCB.test.
    test_acc = _accuracy_from_regrets(linucb.regrets)

    return train_acc, test_acc

In [None]:
# Sweep alpha over a fixed grid, repeating every setting with several seeds,
# and collect the per-run accuracies in lists keyed by alpha.
train_acc_by_alpha = {}
test_acc_by_alpha = {}
for alpha in [0.0, 0.01, 0.05, 0.1, 0.3, 0.5, 1.0, 2.0, 4.0]:
    for seed in [0, 1, 2, 3, 4]:
        train_acc, test_acc = train_lin_ucb(seed, alpha, num_train=100)
        # setdefault creates the list on the first run for this alpha.
        train_acc_by_alpha.setdefault(alpha, []).append(train_acc)
        test_acc_by_alpha.setdefault(alpha, []).append(test_acc)

# Report the raw per-seed values, then the mean per alpha (rounded to 3 decimals).
print(f'train_acc_by_alpha: {train_acc_by_alpha}')
for alpha, accs in train_acc_by_alpha.items():
    mean_acc = round(sum(accs) / len(accs), 3)
    print(f'alpha: {alpha}, average train_acc: {mean_acc}')

print(f'test_acc_by_alpha: {test_acc_by_alpha}')
for alpha, accs in test_acc_by_alpha.items():
    mean_acc = round(sum(accs) / len(accs), 3)
    print(f'alpha: {alpha}, average test_acc: {mean_acc}')