In [5]:
import pandas as pd
import numpy as np
from obp.dataset import SyntheticBanditDataset
from sklearn.utils import check_random_state
import random
from tqdm import tqdm
from obp.dataset.synthetic import * 

In [53]:
def select_context(n_rounds: int, n_actions: int, context_file_path: str):
    """
    Seleciona o contexto dos usuários para as rodadas e ações especificadas.

    Args:
        n_rounds (int): O número de rodadas.
        n_actions (int): O número de ações.
        context_file_path (str): O caminho do arquivo de contexto.

    Returns:
        Um tuple contendo um array de IDs de contexto e um array de vetores de contexto para cada usuário.
    """
    print("Selecionando contexto...")

    # Carrega o arquivo de contexto
    df = pd.read_csv(context_file_path, delimiter='|', converters={'context': eval})
    
    # Cria uma lista de contexto para cada usuário com base na frequência de ocorrência
    u_ids = [u for u, freq in zip(df['user_id'], df['freq']) for i in range(freq)]
    
    # Embaralha as listas de IDs
    random.shuffle(u_ids)
    u_contexts = [ df[df['user_id'] == u]['context'].values[0] for u in u_ids]
    np.array(u_contexts).shape
    print("Seleção de contexto concluída.")

    return np.array(u_ids), np.array(u_contexts)

In [54]:
users,context = select_context(100, 100, '/home/labpi/yan/iRec-OBP/examples/context/user.csv')
context.shape

Selecionando contexto...
Seleção de contexto concluída.


(100, 4)

In [3]:
def sample_action_context(action_dist: np.ndarray, users: np.ndarray, random_state: int = None) -> np.ndarray:
    """
    Samples actions for each user according to a distribution.

    Parameters:
        action_dist (numpy.ndarray): distribuição de probabilidade das ações para cada usuário a serem selecionadas.
        users (numpy.ndarray): Contem o indice dos usuarios a serem selecionados.
        random_state (int, optional): Semente para o gerador de números aleatórios.

    Returns:
        numpy.ndarray: Contem as ações selecionadas para cada usuário.

    """
    random_ = check_random_state(random_state)
    n_actions, n_users = action_dist.shape
    print('select action ', action_dist.shape)
    chosen_actions = np.zeros(n_users, dtype=np.int)

    cum_action_dist = np.cumsum(action_dist, axis=1)
    uniform_rvs = random_.uniform(size=n_users)
    print(cum_action_dist.shape)
    for i in tqdm(range(n_users), desc="Selecting actions"):
        hist = set()
        for _ in range(n_actions):
            action = np.argmax(cum_action_dist[i] > uniform_rvs[i])
            if action not in hist:
                chosen_actions[i] = action
                hist.add(action)
                break
            cum_action_dist[i][action] = -1

In [6]:
dataset = SyntheticBanditDataset(
    n_actions=100, # number of actions; |A|
    dim_context=5, # number of dimensions of context vector
    reward_function=logistic_reward_function, # mean reward function; q(x,a)
    behavior_policy_function=linear_behavior_policy, # behavior policy; \pi_b
    random_state=12345,
)

In [7]:
training_bandit_data = dataset.obtain_batch_bandit_feedback(n_rounds=100)
test_bandit_data = dataset.obtain_batch_bandit_feedback(n_rounds=100)

In [10]:
pi_b = training_bandit_data['pi_b']
pi_b.shape

(100, 100, 1)

In [9]:
sample_action_context(pi_b, users: np.ndarray, random_state: int = None) 

array([[[0.00984519],
        [0.00677159],
        [0.00500839],
        ...,
        [0.0106929 ],
        [0.00531413],
        [0.01092079]],

       [[0.01128862],
        [0.00557664],
        [0.00454625],
        ...,
        [0.00850335],
        [0.00839168],
        [0.00866747]],

       [[0.01237064],
        [0.00450942],
        [0.00511151],
        ...,
        [0.01166402],
        [0.00963456],
        [0.00900317]],

       ...,

       [[0.00720934],
        [0.00641521],
        [0.01171338],
        ...,
        [0.01533758],
        [0.00899217],
        [0.00854998]],

       [[0.00765181],
        [0.00702797],
        [0.00710734],
        ...,
        [0.00904024],
        [0.00636136],
        [0.010292  ]],

       [[0.00689777],
        [0.0081101 ],
        [0.00498613],
        ...,
        [0.00574977],
        [0.00462273],
        [0.00897257]]])