In [1]:
# Import necessary libraries
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from tqdm import tqdm
import time

In [3]:
class ArmedBanditsEnv():
    """
    Channel environment where each slot is available according to a Bernoulli draw.

    Attributes:
        num_expts -> number of independent experiments (rows of p_values)
        num_slots -> number of slots that may be available to transmit data
        p_values  -> num_expts x num_slots matrix; entry [i, j] is the probability
                     that slot j+1 is available in experiment i
        np_random -> random source used by step(); the global np.random module
                     until _seed() installs a seeded Generator

    An action is a num_expts x num_slots integer array giving, for each experiment,
    the 1-based order in which slots are checked for availability.
    """

    def __init__(self, p_values):
        assert len(p_values.shape) == 2

        self.num_slots = p_values.shape[1]
        self.num_expts = p_values.shape[0]

        self.p_values = p_values

        # Default to the legacy global RNG so unseeded sampling is exactly
        # np.random.binomial; _seed() swaps in a reproducible Generator.
        self.np_random = np.random

    def step(self, action):
        """
        Sample slot availability and score the checking order given in `action`.

        Parameters:
            action: num_expts x num_slots array of 1-based slot numbers giving the
                    order in which slots are checked in each experiment.

        Returns:
            (observation, cost, done, info) where observation is always 0, done is
            always False, info is an empty dict, and cost is a num_expts x 1 array
            holding, per experiment, the number of slots checked before the first
            available one, or num_slots if no checked slot was available.

        Raises:
            AssertionError: if action does not have shape (num_expts, num_slots).
            ValueError: if any slot number is outside 1..num_slots.
        """
        assert (action.shape == (self.num_expts, self.num_slots))

        # Convert 1-based slot numbers to 0-based column indices and reject values
        # that would otherwise wrap around (0 -> -1) or index out of bounds.
        order = np.asarray(action, dtype=np.intp) - 1
        if order.size and (order.min() < 0 or order.max() >= self.num_slots):
            raise ValueError("action entries must be slot numbers in 1..%d" % self.num_slots)

        # Draw one Bernoulli availability sample per (experiment, slot) using its p-value
        sampled_state = self.np_random.binomial(n=1, p=self.p_values)

        # Rearrange each experiment's samples into the order the slots are checked
        relevant_states = np.take_along_axis(sampled_state, order, axis=1)

        # Index of the first available slot in each row. argmax alone returns 0 for
        # a row with no available slot, so fall back to the row length there.
        available = relevant_states == 1
        cost = np.where(available.any(axis=1),
                        available.argmax(axis=1),
                        available.shape[1]).reshape(-1, 1)

        # Return a constant state of 0. Our environment has no terminal state
        observation, done, info = 0, False, dict()
        return observation, cost, done, info

    def reset(self):
        """Reset the environment; it is stateless, so the observation is always 0."""
        return 0

    def render(self, mode='human', close=False):
        """No-op: this environment has nothing to render."""
        pass

    def _seed(self, seed=None):
        """
        Install a seeded numpy Generator as the random source used by step().

        Returns a one-element list with the seed entropy actually used
        (a freshly generated value when seed is None).
        """
        seed_seq = np.random.SeedSequence(seed)
        self.np_random = np.random.default_rng(seed_seq)
        return [seed_seq.entropy]

    def close(self):
        """No-op: the environment holds no resources."""
        pass
    
    
class ArmedBanditsBernoulli(ArmedBanditsEnv):
    """
    ArmedBanditsEnv whose slot availability probabilities are drawn at random.

    Every entry of the num_expts x num_slots p_values matrix is sampled
    independently with np.random.uniform over [0, 1).
    """

    def __init__(self, num_expts=1, num_slots=5):
        shape = (num_expts, num_slots)
        # The base class stores the drawn matrix as self.p_values and derives sizes
        super().__init__(np.random.uniform(low=0, high=1, size=shape))

In [6]:
p_values = np.array([[1.0, 1.0, 1.0, 1.0]]) # The p_values for a four-slot channel. Single experiment
NUM_TRIALS = 4 # How many random checking orders to try

env = ArmedBanditsEnv(p_values) # Create the environment

for trial in range(NUM_TRIALS):
    # One random checking order (a permutation of 1..num_slots) per experiment.
    # The shape is derived from env, so it always matches what step() asserts,
    # and it stays valid when p_values has more than one experiment row.
    action = np.array([np.random.permutation(env.num_slots) + 1
                       for _expt in range(env.num_expts)])
    _, cost, _, _ = env.step(action)
    print("Order:", action, " gave a cost of:", cost[0])

Order: [[2 4 1 3]]  gave a cost of: [0]
Order: [[1 2 4 3]]  gave a cost of: [0]
Order: [[2 3 1 4]]  gave a cost of: [0]
Order: [[2 3 1 4]]  gave a cost of: [0]
