In [1]:
import torch
from itertools import product
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Possible user states; the index of a state in this list is its row in the Q-table.
user_states = [-1, 0, 1]
# Values an individual item in a slate can take.
items = [-1, 1]
# Number of items per slate (previously a bare literal 4).
SLATE_SIZE = 4
# Every ordered combination of SLATE_SIZE items; each slate is one action.
slates = list(product(items, repeat=SLATE_SIZE))
num_actions = len(slates)

In [3]:
def reward(u, r):
    """Return the total reward of showing slate ``r`` to a user in state ``u``.

    Each item contributes ``u * item`` when ``u`` is non-zero and ``0.0``
    when ``u`` is zero, so a zero state always yields a zero reward.

    Args:
        u: Current user state.
        r: Slate, a sequence of item values.

    Returns:
        The sum of the per-item contributions (``0`` for an empty slate).
    """
    if u != 0:
        return sum(u * item for item in r)
    # Zero state: every item adds 0.0, which keeps the float result type.
    return sum(0.0 for _ in r)
        

In [4]:
def transition_function(current_state, action):
    """Return the state reached from ``current_state`` after showing ``action``.

    The item values of the slate are added to the current state and the
    result is clamped to the interval [-1, 1].

    Args:
        current_state: Current user state.
        action: Slate, an iterable of item values.

    Returns:
        The clamped next state.
    """
    shifted = current_state + sum(action)

    # Saturate at the boundary states instead of leaving the state space.
    if shifted > 1:
        return 1
    if shifted < -1:
        return -1
    return shifted

In [5]:
# Action-value table, initialised to zero: row i corresponds to
# user_states[i] and column j to slates[j].
Q_table = np.zeros(shape=(len(user_states), len(slates)), dtype=np.float64)

In [6]:
def update_q_value(Q_table, current_state, slate_action, reward_value, next_state, alpha, gamma):
    """Apply one Q-learning update for a single observed transition.

    The entry for (``current_state``, ``slate_action``) is moved a fraction
    ``alpha`` of the way towards the target
    ``reward_value + gamma * max_a Q(next_state, a)``.

    Args:
        Q_table: 2-D array indexed as [state index, slate index].
        current_state: State the action was taken in; must be in ``user_states``.
        slate_action: Slate that was shown; must be an element of ``slates``.
        reward_value: Reward observed for the transition.
        next_state: State reached after the action; must be in ``user_states``.
        alpha: Learning rate.
        gamma: Discount factor.

    Returns:
        The same ``Q_table`` object, which has been modified in place.

    Raises:
        ValueError: If a state or the slate is not found in its lookup list.
    """
    # Translate the state / slate values into table coordinates.
    col = slates.index(slate_action)
    row = user_states.index(current_state)
    old_estimate = Q_table[row, col]

    # Best value obtainable from the successor state.
    next_row = user_states.index(next_state)
    best_next = np.max(Q_table[next_row, :])

    # Standard temporal-difference step towards the bootstrapped target.
    td_target = reward_value + gamma * best_next
    Q_table[row, col] = old_estimate + alpha * (td_target - old_estimate)

    return Q_table

In [7]:
def epsilon_greedy_action(Q_table, current_state, epsilon):
    """Pick a slate for ``current_state`` with an epsilon-greedy policy.

    With probability ``epsilon`` a slate is drawn uniformly at random;
    otherwise the slate with the highest Q-value in the state's row is
    chosen (the first such slate when several are tied).

    Args:
        Q_table: 2-D array indexed as [state index, slate index].
        current_state: Current user state; must be in ``user_states``.
        epsilon: Exploration probability.

    Returns:
        The selected slate, an element of ``slates``.
    """
    row = user_states.index(current_state)
    explore = np.random.rand() < epsilon
    # Explore: uniform random slate index. Exploit: index of the best Q-value.
    slate_idx = np.random.choice(num_actions) if explore else np.argmax(Q_table[row])
    return slates[slate_idx]

In [8]:
def run_episodes(Q_table, transition_function, reward, update_q_value, num_episodes, alpha, gamma, epsilon, max_steps=None):
    """Train ``Q_table`` with Q-learning over ``num_episodes`` episodes.

    Each episode starts from a state drawn uniformly from ``user_states`` and
    then repeats: pick a slate with ``epsilon_greedy_action``, compute the
    next state and the reward, and apply ``update_q_value``. An episode ends
    as soon as the state reached after a step is -1 or 1, so every episode
    takes at least one step even when it starts in -1 or 1.

    Args:
        Q_table: Table passed to the policy and to ``update_q_value``.
        transition_function: Callable ``(state, slate) -> next_state``.
        reward: Callable ``(state, slate) -> reward``.
        update_q_value: Callable
            ``(Q_table, state, slate, reward, next_state, alpha, gamma)``
            returning the updated table.
        num_episodes: Number of episodes to run.
        alpha: Learning rate forwarded to ``update_q_value``.
        gamma: Discount factor forwarded to ``update_q_value``.
        epsilon: Exploration probability forwarded to the policy.
        max_steps: Optional cap on the number of steps per episode. ``None``
            (the default) means no cap, i.e. the episode only ends when
            state -1 or 1 is reached.

    Returns:
        None. The caller only sees the learned values if ``update_q_value``
        modifies the table it is given in place.
    """
    for _ in range(num_episodes):
        # Randomly choose an initial state
        current_state = np.random.choice(user_states)
        steps_taken = 0

        # Without a cap this loops until an end state is reached; the cap
        # guards against policies/environments that never reach one.
        while max_steps is None or steps_taken < max_steps:
            # Choose a slate using the epsilon-greedy strategy
            selected_action = epsilon_greedy_action(Q_table, current_state, epsilon)

            # Observe the successor state and the reward for this step
            next_state = transition_function(current_state, selected_action)
            reward_value = reward(current_state, selected_action)

            # Update the Q-value based on the transition
            Q_table = update_q_value(Q_table, current_state, selected_action, reward_value, next_state, alpha, gamma)

            # Move to the next state
            current_state = next_state
            steps_taken += 1

            # The episode ends once a boundary state has been reached
            if current_state in [-1, 1]:
                break


In [9]:
# Q-learning hyperparameters.
num_episodes = 20000  # number of training episodes
alpha = 0.1           # learning rate
gamma = 0.9           # discount factor
epsilon = 0.1         # exploration probability

# Training draws start states and exploratory slates from NumPy's global RNG;
# seed it so that a fresh "Restart & Run All" reproduces the same Q-table.
SEED = 0
np.random.seed(SEED)

run_episodes(Q_table, transition_function, reward, update_q_value, num_episodes, alpha, gamma, epsilon)


In [10]:
# Rank the slates of every state from highest to lowest Q-value with a single
# stable sort on the negated table. Ties are broken by the lowest slate index
# (the same choice np.argmax makes), and because both indices come from one
# ranking the runner-up can never be the same slate as the best one.
ranking = np.argsort(-Q_table, axis=1, kind="stable")
first_max = ranking[:, 0]
second_max = ranking[:, 1]
# Print the result
for i in range(len(first_max)):
    print(first_max[i],second_max[i])
    print(f"For user state {user_states[i]}: optimal slate is {slates[first_max[i]]} and {slates[second_max[i]]} ")

0 4
For user state -1: optimal slate is (-1, -1, -1, -1) and (-1, 1, -1, -1) 
0 13
For user state 0: optimal slate is (-1, -1, -1, -1) and (1, 1, -1, 1) 
15 7
For user state 1: optimal slate is (1, 1, 1, 1) and (-1, 1, 1, 1) 


In [11]:
# Display the Q-table: one row per entry of user_states, one column per slate.
Q_table

array([[10.        ,  9.2482646 ,  9.35905099,  8.92934838,  9.44416439,
         8.58028066,  8.52216016,  8.49047049,  9.41622888,  8.73392176,
         8.74951174,  8.40197068,  8.9011353 ,  8.34902652,  8.45464298,
         7.91071345],
       [ 9.        ,  8.92265867,  8.92019606,  8.01080498,  8.90464382,
         7.92567025,  7.89626084,  8.92819625,  8.47663113,  8.04214741,
         8.05261731,  8.95683373,  8.03549523,  8.97365755,  8.91209385,
         8.96949394],
       [ 7.88963102,  8.463682  ,  8.33352851,  8.69011501,  8.433103  ,
         8.7189583 ,  8.83638412,  9.44715532,  8.41911709,  8.90637813,
         8.91155333,  9.4077317 ,  8.89048478,  9.22678888,  9.37855277,
        10.        ]])