We begin by importing all the necssary packages and modules. Next our task is to build the Machine_Replacement environment. For doing that, we create a class named as Machine_Replacement which accepts the number_of_states(nS), number_of_actions(nA) and replacement_cost(rep_cost) as input and generate the environment. Later we just need to cal the function gen_probability() and gen_expected_reward_function() to get the Probability distribution matrix and Reward matrix

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import numpy as np
import pandas as pd
import pickle
from itertools import product
import multiprocessing as mp
mp.set_start_method('spawn',True)
torch.multiprocessing.set_start_method('spawn',True)
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Machine_Replacement:
    def __init__(self,rep_cost=0.7,nS=6,nA=2):
        self.nS = nS;
        self.nA = nA;
        self.cost = np.linspace(0.1, 0.99,nS);
        self.rep_cost = rep_cost;
    def gen_probability(self):
        self.P = np.zeros((self.nA,self.nS,self.nS));
        for i in range(self.nS):
            for j in range(self.nS):
                if(i<=j):
                    self.P[0,i,j]=(i+1)*(j+1);
                else:
                    continue;
            self.P[0,i,:]=self.P[0,i,:]/np.sum(self.P[0,i,:])
            self.P[1,i,0]=1;
        return self.P;
    def gen_reward(self):
        self.R=np.zeros((self.nA,self.nS,self.nS));
        for i in range(self.nS):
            self.R[0,i,:] = self.cost[i];
            self.R[1,i,0] = self.rep_cost+self.cost[0];
        return self.R;
    def gen_expected_reward(self):
        self.R = np.zeros((self.nA,self.nS));
        for i in range(self.nS):
            self.R[0,i] = self.cost[i];
            self.R[1,i] = self.rep_cost + self.cost[0];
        return self.R;

Next we create another class specifying the hyperparameters that might be required in our algorithm. Later when required we can just call ret_hyperparameters() to get the hyperparameters.

In [16]:
class get_hyperparameters:
    def __init__(self):
        self.T = 50000;
        self.runs = 1;
        self.lr = 0.1;
        self.batch_size = 50;
        self.start = 0;
        self.nS = 20;
        self.nA = 2;
        self.rep_cost = 0.7
        self.alpha = 0.5
        self.gamma = 0.95
    
    def ret_hyperparameters(self):
        return (self.T,self.runs,self.lr,self.batch_size,self.start,self.nS,self.nA,self.rep_cost,self.alpha,self.gamma)

Let us now define the pytorch model. So for doing that, we create a class named weights. There are 3 parameters, input_size which defines the number of perceptrons in the input layer. input_size = number of states(nS). The output size is the number of perceptrons in the output_layer. output_size = 1(which gives us the state distribution ratio for a particular state).

Later when we want to update or find the state distribution ratio of any state, just pass that state to the forward(). First that particular state is converted into one_hot vector and then fed to the network. Finally the network returns the output value as the ratio of state distribution.

In [17]:
class weights(nn.Module):
    def __init__(self,input_size,output_size,hidden_size = 0):
        super(weights,self).__init__()
        self.input_size = input_size;
        self.hidden_size = hidden_size;
        self.output_size = output_size;
        if(hidden_size!=0):
            self.linear1 = nn.Linear(self.input_size, self.hidden_size, bias=False)
            self.linear2 = nn.Linear(self.hidden_size, self.output_size, bias=False)
        else:
            self.linear1 = nn.Linear(self.input_size, self.output_size, bias=False)
    '''
        forward(): We accept a state 's' as input. Then we convert this into one hot encoding which is accomplished by first two lines.
        Further we convert this one_hot vector 's' into pytorch tensor and then pass it through the network to obtain a output which is returned 
    '''
    def forward(self,state):
        s = np.zeros(self.input_size);
        #print(state,end='===>');
        s[state] = 1;
        state = torch.FloatTensor(s).to(device)
        #print(state);
        if(self.hidden_size == 0):
            output = torch.exp(self.linear1(state)) #To ensure that the outputs are always positive. giving Relu will cause problems.
        else:
            output = torch.exp(self.linear2(torch.exp(self.linear1(state))));
        return output

Now it is evident that the network written above will give me the ratio of state distribution. But the requirement is state distribution of the target_policy. So for doing that we need to to follow the below equation

state_distribution_of_target_policy = Normalize(state_distribution_ratio_obtained_from_network * behaviour_policy_state_distribution).

Now to get the target_policy_state_distribution, we need to obtain the behaviour_policy_state_distribution. To find that we use the class below. Now using the 

In [18]:
class beh_pol_sd:
    def __init__(self,P,policy,nS,nA):
        self.P = P
        self.policy = policy
        self.nS = nS;
        self.nA = nA;
    
    def onehot(self):
        pol = np.zeros((self.nS,self.nA));
        for i in range(self.nS):
            pol[i][int(self.policy[i])]=1;
        return pol;
    def find_transition_matrix(self,onehot_encode=1):
        if(onehot_encode==1):
            self.policy = self.onehot()
        T_s_s_next = np.zeros((self.nS,self.nS));
        for s in range(self.nS):
            for s_next in range(self.nS):
                for a in range(self.nA):
                    #print(s,s_next,a);
                    #print(T[a,s,s_next]);
                    T_s_s_next[s,s_next]+=self.P[a,s,s_next]*self.policy[s,a];
        return T_s_s_next;
    def state_distribution_simulated(self,onehot_encode=1):
        P_policy = self.find_transition_matrix(onehot_encode)
        #print(P_policy);
        P_dash = np.append(P_policy - np.eye(self.nS),np.ones((self.nS,1)),axis=1);
        #print(P_dash);
        P_last = np.linalg.pinv(np.transpose(P_dash))[:,-1]
        return P_last;

In [19]:
def one_hot(target_policy,nS,nA):
    one_hot_tp = [];
    for i in range(len(target_policy)):
        policy = target_policy[i];
        print(policy);
        tp=np.zeros((nS,nA));
        for j in range(nS):
            tp[j][policy[j]] = 1;
        one_hot_tp.append(tp);
    return np.array(one_hot_tp);

Now that we are all ready got our ingedients, let us define a separate class where we will define the average_case technique to find the state_distribution_ratio. Any object of the average_case class must possess the environment details such as number of states(nS) and number of actions(nA). The behaviour_policy, learning_rate,list_of_beta_values, a weight_object to refer to the weights class. The batch_size and the optimizer to be used(Adam_optimzer). 

Since in this variant, we will be having only 1 Neural Network and update the network for the sample that is obtained from the theta values. For this reason, we need to store all the beta values. Hence, we define our first function
1) find_beta(): This function finds all the beta(importance sampling) values. So beta values = target_policy[s,a]/behaviour_policy[s,a]. Now there are nPOL policies so for each policies we need to find the beta values for the ith target_policy, beta value will be beta[i,s,a]=target_policy[i,s,a]/behaviour_policy[s,a]. Now it is certain that a(action) can be {0,1} only. So instead of creating a new loop, we manually define two lines once for a=0 when beta value becomes beta[i,s,0]=target_policy[i,s,0]/behaviour_policy[s,0] and beta[i,s,1] = target_policy[i,s,1]/behaviour_policy[s,1]. Finally return all the beta values.

2) set_batch(): This function is used to set a data batch which is sampled from the behaviour_policy. Now the batch is set to be used for updating the state_distribution_ratio.

3) get_batch(): This function is used to get a random batch of 50 samples from the set data to be used to update our state_distribution_ratio. We actually create a batch of 50 data samples 10000 times in order to reach to a good value of state_distribution_ratio(Like in Linear regression we use a batch for several times until our gradient converges)

4) get_w(): This function is used to find the numerator and denominator of the loss function as mentioned in the paper 'Breaking the curse of horizon'. Now for finding the numerator paarameter pair = 0. To find the denominator pair = 1. Now it is observed that the denominator value easily goes to 0. So, to avoid divide by zero error, we add a small noise value of 0.000000001. This makes sure that the denominator value never goes to zero.

5) get_state_distribution_ratio(): This function uses the set data in the self.set_data() to get batches of size 50. Then calculate the loss using the equation mentioned in the paper 'Breaking the curse of horizon'. We use this calculated loss to update our weights of the Neural network by using Adam optimizer.

In [20]:
class average_case_version_2:
  def __init__(self,nS,nA,behaviour_policy,state,lr,target_policy,batch_size,data_used):
    self.nS = nS;
    self.nPOL = nS
    self.nA = nA;
    self.behaviour_policy = behaviour_policy;
    self.lr = lr;
    self.beta = self.find_beta(target_policy,behaviour_policy)
    self.W_loss = 0
    self.weight_obj = weights(nS,1).to(device);
    self.batch_size = batch_size
    self.optimizerW = optim.Adam(self.weight_obj.parameters(),lr = self.lr);
    self.loss = [];
    self.Z = [];
    self.data_used = data_used
  def find_beta(self,target_policy,behaviour_policy):
    beta = np.zeros((self.nPOL,self.nS,self.nA))
    for i in range(self.nPOL):
      for s in range(self.nS):
        beta[i][s][0] = target_policy[i,s,0]/behaviour_policy[s,0];
        beta[i,s,1] = target_policy[i,s,1]/behaviour_policy[s,1];
    return beta;
  def set_batch(self,data):
        self.data = data;
        self.T = len(data);
  def get_batch(self):
      if(self.T<=50):
          return self.data
      else:
          i = 1;
          j = np.random.choice(self.T);
          batch = [];
          while(i<=self.batch_size):
              if(np.random.random()<=0.5):
                  batch.append([self.data[j][0],self.data[j][1],self.data[j][2]])
                  j = (j+1)%self.T;
                  i+=1;
          return batch; 
  def get_w(self,data,weight_obj,m,pair=0):
        if(pair == 1):
            Z_w_state = 0;
            for i in range(len(data)):
                val = weight_obj(data[i][0]);
                #print(val);
                Z_w_state+=val;
            #print(Z_w_state.detach().numpy()[0]/self.batch_size);
            Z_w_state = Z_w_state.cpu().detach().numpy()[0]/self.batch_size;
            if(Z_w_state<0.00000000000005):
                Z_w_state+=0.000000000001;
            return Z_w_state;
        else:
            state1,state2,w_state1,w_state2,w_next_state1,w_next_state2,beta1,beta2 = list(),list(),list(),list(),list(),list(),list(),list();
            K = list();
            for i in range(len(data)):
                sample1 = data[i][0];
                sample2 = data[i][1];
                state1.append(sample1[0]);
                #print(sample1);
                w_state1.append(weight_obj(sample1[0]));
                w_next_state1.append(weight_obj(sample1[2]));
                state2.append(sample2[0]);
                w_state2.append(weight_obj(sample2[0]));
                w_next_state2.append(weight_obj(sample2[2]));
                #beta1.append(self.target_policy[sample1[0],sample1[1]]/self.behaviour_policy[sample1[0],sample1[1]]);
                beta1.append(self.beta[self.selected_policy,sample1[0],sample1[1]]);
                #beta2.append(self.target_policy[sample2[0],sample2[1]]/self.behaviour_policy[sample2[0],sample2[1]]);
                beta2.append(self.beta[self.selected_policy,sample2[0],sample2[1]])
                K.append(sample1[2]==sample2[2]);
            return (state1,state2,w_state1,w_state2,w_next_state1,w_next_state2,beta1,beta2,K);
    
  def get_state_distribution_ratio(self,selected_policy,run,t):
        batch = self.get_batch();
        eps = 0.04;
        self.data_used[run] =self.data_used[run]+batch;
        self.selected_policy = selected_policy
        pairs = list(product(batch,repeat=2));
        self.loss_episode = [];
        for _ in range(50):
            batch = self.get_batch();
            state1,state2,w_state1,w_state2,w_next_state1,w_next_state2,beta1,beta2,K = self.get_w(pairs, self.weight_obj, len(batch));
            Z_w_state = self.get_w(batch, self.weight_obj, len(batch),1);
            self.w_loss = 0
            for i in range(len(state1)):
                self.w_loss+=(beta1[i]*(w_state1[i]/Z_w_state) - (w_next_state1[i]/Z_w_state))*(beta2[i]*(w_state2[i]/Z_w_state)-(w_next_state2[i]/Z_w_state))*K[i];
            self.w_loss/=(2*self.batch_size);
            self.optimizerW.zero_grad();
            self.w_loss.backward();
            self.optimizerW.step();
            self.optimizerW.zero_grad();
            self.Z.append(Z_w_state)
        self.loss.append(self.w_loss.cpu().detach().numpy()[0]);
        state_dist=[];
        for i in range(self.nS):
            w_state = self.weight_obj(i);
            w_state = w_state.cpu().detach().numpy()[0];
            state_dist.append(w_state);
        return np.array(state_dist);

Instead of sampling state, action, next_state values on the go, we do it before hand and store it in a list named 'data'. So, that when required we can simply pass this data and save some time by before hand sampling the data.

In [21]:
def simulate_episode(T,state,behaviour_policy,P,batch_size):
  #global P,behaviour_policy,batch_size;
  data={};temp=[];
  for t in range(1,T+1):
    action = np.argmax(np.random.multinomial(1,behaviour_policy[state,:]))
    next_state = np.argmax(np.random.multinomial(1,P[action,state,:]));
    state = next_state;
    temp.append([state,action,next_state]);
    if(t%batch_size==0):
      data[int(t/batch_size)-1]=temp[:];
  return data;

For our help we create a softmax function.

In [22]:
def softmax(theta):
  theta = np.exp(theta);
  sum = np.sum(theta)
  return theta/sum;

We create a function naemd as preprocessing() which is responsible to perform the theta update. After every val=50 times, our learning rate is divided by 10. So if we start with lr = 1. After 50 instances, lr = 0.1. After 100 steps lr = 0.01 and so on...
Now at each instant, we found the softmax of the theta values. Now from the given probability values we sample a policy that will be updated next. 

In [23]:
def processing(nPOL,run,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val,val2,behaviour_policy_state_distribution,beta,batch_size,alpha):
    lr = 1
    nS = nPOL;
    theta = np.ones(nPOL);
    c=0;
    rew = np.zeros(nPOL);
    S = np.ones(nPOL);
    F = np.ones(nPOL);
    n = np.ones(nPOL);
    for t in tqdm(range(1,T_update+1)):
      sampled_policy = np.argmin([np.random.beta(S[j],F[j]) for j in range(nPOL)])
      #policy_selected[t,run]=selected_policy;
      if sampled_policy ==1:
          policy_sampled[t-1,run] = 10000;
      else:
          policy_sampled[t-1,run] = sampled_policy;
      #loss[t,run] = w_obj.w_loss.cpu().detach().numpy()[0];
      w_obj.set_batch(data[t-1]);
      sd = w_obj.get_state_distribution_ratio(sampled_policy,run,t-1);
      sd = sd * behaviour_policy_state_distribution;
      sd = sd/np.sum(sd)
      rho_i = sum([sd[s]*R[target_policy[sampled_policy,s],s] for s in range(nS)]);
      rew[sampled_policy] = (1-alpha)*rew[sampled_policy] + alpha*rho_i
      S[sampled_policy]+=rew[sampled_policy];
      F[sampled_policy]=F[sampled_policy]+t-rew[sampled_policy]
      estimated_value[t-1,run] = rho_i;
    #print(data)

Finally call the hyperparameters, create the environment and call the pre-processing function.

In [24]:
if __name__ =='__main__':
    T,runs,lr,batch_size,state,nS,nA,rep_cost,alpha,gamma = get_hyperparameters().ret_hyperparameters();
    print(nS,nA);
    nPOL = nS;
    T_update = int(T/batch_size);
    mr_obj = Machine_Replacement(rep_cost,nS,nA);
    P,R = mr_obj.gen_probability(),mr_obj.gen_expected_reward()
    theta = np.ones(nPOL);
    behaviour_policy = np.ones((nS,nA))*0.5;
    print(behaviour_policy,P);
    behaviour_policy_state_distribution = beh_pol_sd(P,behaviour_policy,nS,nA).state_distribution_simulated(0);
    data = simulate_episode(T,state,behaviour_policy,P,batch_size);
    with open('data_used','wb') as f:
      pickle.dump(data,f);
    target_policy = np.ones((nPOL,nS),dtype = np.int8)
    data_dict={0:[],1:[],2:[],3:[],4:[]};
    val2 = []
    for i in range(nPOL-1,0,-1):
        target_policy[nPOL-i-1][0:i] = 0;
    for t in target_policy:
        val2.append(sum([R[t[s],s] for s in range(nS)]))
    val2 = np.array(val2);
    one_hot_target_policy = one_hot(target_policy,nS,nA)
    w_obj = average_case_version_2(nS,nA,behaviour_policy,state,lr,one_hot_target_policy,batch_size,data_dict);
    '''for i in range(nPOL):
      w_obj[i].set_target_policy(one_hot_target_policy[i]);'''
    policy_selected = np.zeros((T_update,runs))
    policy_sampled = np.zeros((T_update,runs))
    theta_change = []
    estimated_value = np.zeros((T_update,runs))
    loss = np.zeros((T_update,runs))
    val = 50;
    lr = 1;
    beta = 1;
    #######################################################################################
    '''mp.set_start_method('spawn')
    p1 = mp.Process(target = processing ,args=(nPOL,0,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val))
    p2 = mp.Process(target = processing ,args=(nPOL,1,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val))
    p3 = mp.Process(target = processing ,args=(nPOL,2,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val))
    p4 = mp.Process(target = processing ,args=(nPOL,3,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val))
    p5 = mp.Process(target = processing ,args=(nPOL,4,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val))
    #p1 = mp.Process(target = processing ,args=(0,policy_sampled,policy_selected,loss,estimated_value,T_update,val))
    #p1 = mp.Process(target = processing ,args=(0,policy_sampled,policy_selected,loss,estimated_value,T_update,val))
    #p1 = mp.Process(target = processing ,args=(0,policy_sampled,policy_selected,loss,estimated_value,T_update,val))
    #p1 = mp.Process(target = processing ,args=(0,policy_sampled,policy_selected,loss,estimated_value,T_update,val))
    #p1 = mp.Process(target = processing ,args=(0,policy_sampled,policy_selected,loss,estimated_value,T_update,val))
    p1.start();
    p2.start();
    p3.start();
    p4.start();
    p5.start();
    p1.join();
    p2.join();
    p3.join();
    p4.join();
    p5.join();'''
    #######################################################################################
    for run in tqdm(range(runs)):
        processing(nPOL,run,policy_sampled,policy_selected,loss,estimated_value,T_update,P,R,val,val2,behaviour_policy_state_distribution,beta,batch_size,alpha)
        print("One run completed");
    #######################################################################################
    #pd.DataFrame(policy_selected).to_excel("Policy_selection.xlsx");
    #pd.DataFrame(np.array(theta_change)).to_excel("Theta_values_non_estimate.xlsx");
    pd.DataFrame(policy_sampled).to_excel("Policy_sampling_Thompson_Sampling_variant_2_20_states.xlsx");
    pd.DataFrame(data[T_update-1]).to_excel("Data_Used_Thompson_Sampling_variant_2_20_states.xlsx");
    pd.DataFrame(estimated_value).to_excel("Estimated_Value_functions_Thompson_Sampling_variant_2_20_states.xlsx");

20 2
[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]] [[[0.0047619  0.00952381 0.01428571 0.01904762 0.02380952 0.02857143
   0.03333333 0.03809524 0.04285714 0.04761905 0.05238095 0.05714286
   0.06190476 0.06666667 0.07142857 0.07619048 0.08095238 0.08571429
   0.09047619 0.0952381 ]
  [0.         0.00956938 0.01435407 0.01913876 0.02392344 0.02870813
   0.03349282 0.03827751 0.0430622  0.04784689 0.05263158 0.05741627
   0.06220096 0.06698565 0.07177033 0.07655502 0.08133971 0.0861244
   0.09090909 0.09569378]
  [0.         0.         0.01449275 0.01932367 0.02415459 0.02898551
   0.03381643 0.03864734 0.04347826 0.04830918 0.0531401  0.05797101
   0.06280193 0.06763285 0.07246377 0.07729469 0.0821256  0.08695652
   0.09178744 0.09661836]
  [0.         0.         0.         0.01960784 0.0245098  0.02941176
   0.03

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                  | 0/1000 [00:00<?, ?it/s][A
  0%|                                      | 1/1000 [01:59<33:08:32, 119.43s/it][A
  0%|                                      | 2/1000 [04:00<33:21:20, 120.32s/it][A
  0%|                                      | 3/1000 [06:00<33:14:08, 120.01s/it][A
  0%|▏                                     | 4/1000 [08:00<33:14:12, 120.13s/it][A
  0%|▏                                     | 5/1000 [09:59<33:07:32, 119.85s/it][A
  1%|▏                                     | 6/1000 [12:00<33:09:47, 120.11s/it][A
  1%|▎                                     | 7/1000 [14:01<33:12:25, 120.39s/it][A
  1%|▎                                     | 8/1000 [16:00<33:06:19, 120.14s/it][A
  1%|▎                                     | 9/1000 [18:02<33:09:53, 120.48s/it][A
  1%|▎                                    | 10/1000 [20:02<33:06:16, 120.38s/it

 10%|███▎                               | 96/1000 [3:13:52<30:18:11, 120.68s/it][A
 10%|███▍                               | 97/1000 [3:15:53<30:15:57, 120.66s/it][A
 10%|███▍                               | 98/1000 [3:17:56<30:26:18, 121.48s/it][A
 10%|███▍                               | 99/1000 [3:20:00<30:33:14, 122.08s/it][A
 10%|███▍                              | 100/1000 [3:22:04<30:39:03, 122.60s/it][A
 10%|███▍                              | 101/1000 [3:24:08<30:43:47, 123.06s/it][A
 10%|███▍                              | 102/1000 [3:26:09<30:35:17, 122.63s/it][A
 10%|███▌                              | 103/1000 [3:28:10<30:23:44, 121.99s/it][A
 10%|███▌                              | 104/1000 [3:30:12<30:21:00, 121.94s/it][A
 10%|███▌                              | 105/1000 [3:32:14<30:21:59, 122.14s/it][A
 11%|███▌                              | 106/1000 [3:34:17<30:21:59, 122.28s/it][A
 11%|███▋                              | 107/1000 [3:36:18<30:12:54, 121.81s

 19%|██████▌                           | 193/1000 [6:29:02<26:59:28, 120.41s/it][A
 19%|██████▌                           | 194/1000 [6:31:01<26:54:47, 120.21s/it][A
 20%|██████▋                           | 195/1000 [6:33:02<26:54:04, 120.30s/it][A
 20%|██████▋                           | 196/1000 [6:35:05<27:04:23, 121.22s/it][A
 20%|██████▋                           | 197/1000 [6:37:07<27:02:04, 121.20s/it][A
 20%|██████▋                           | 198/1000 [6:39:08<27:00:37, 121.24s/it][A
 20%|██████▊                           | 199/1000 [6:41:08<26:53:16, 120.84s/it][A
 20%|██████▊                           | 200/1000 [6:43:09<26:53:29, 121.01s/it][A
 20%|██████▊                           | 201/1000 [6:45:09<26:46:21, 120.63s/it][A
 20%|██████▊                           | 202/1000 [6:47:09<26:44:10, 120.61s/it][A
 20%|██████▉                           | 203/1000 [6:49:10<26:42:18, 120.63s/it][A
 20%|██████▉                           | 204/1000 [6:51:11<26:39:26, 120.56s

 29%|██████████▏                        | 290/1000 [9:32:02<15:24:56, 78.16s/it][A
 29%|██████████▏                        | 291/1000 [9:33:20<15:22:41, 78.08s/it][A
 29%|██████████▏                        | 292/1000 [9:34:38<15:21:27, 78.09s/it][A
 29%|██████████▎                        | 293/1000 [9:35:56<15:19:32, 78.04s/it][A
 29%|██████████▎                        | 294/1000 [9:37:13<15:17:40, 77.99s/it][A
 30%|██████████▎                        | 295/1000 [9:38:31<15:16:20, 77.99s/it][A
 30%|██████████▎                        | 296/1000 [9:39:50<15:15:31, 78.03s/it][A
 30%|██████████▍                        | 297/1000 [9:41:08<15:15:50, 78.17s/it][A
 30%|██████████▍                        | 298/1000 [9:42:26<15:13:55, 78.11s/it][A
 30%|██████████▍                        | 299/1000 [9:43:44<15:11:56, 78.06s/it][A
 30%|██████████▌                        | 300/1000 [9:45:02<15:09:36, 77.97s/it][A
 30%|██████████▌                        | 301/1000 [9:46:20<15:11:03, 78.20s

 39%|█████████████▏                    | 387/1000 [11:38:10<13:16:40, 77.98s/it][A
 39%|█████████████▏                    | 388/1000 [11:39:27<13:14:37, 77.90s/it][A
 39%|█████████████▏                    | 389/1000 [11:40:46<13:14:29, 78.02s/it][A
 39%|█████████████▎                    | 390/1000 [11:42:04<13:13:27, 78.04s/it][A
 39%|█████████████▎                    | 391/1000 [11:43:22<13:12:17, 78.06s/it][A
 39%|█████████████▎                    | 392/1000 [11:44:40<13:11:51, 78.14s/it][A
 39%|█████████████▎                    | 393/1000 [11:45:58<13:10:32, 78.14s/it][A
 39%|█████████████▍                    | 394/1000 [11:47:17<13:09:43, 78.19s/it][A
 40%|█████████████▍                    | 395/1000 [11:48:35<13:09:24, 78.29s/it][A
 40%|█████████████▍                    | 396/1000 [11:49:53<13:07:51, 78.26s/it][A
 40%|█████████████▍                    | 397/1000 [11:51:11<13:06:20, 78.24s/it][A
 40%|█████████████▌                    | 398/1000 [11:52:29<13:04:07, 78.15s

 48%|████████████████▍                 | 484/1000 [13:44:12<11:11:12, 78.05s/it][A
 48%|████████████████▍                 | 485/1000 [13:45:31<11:11:25, 78.22s/it][A
 49%|████████████████▌                 | 486/1000 [13:46:50<11:11:23, 78.37s/it][A
 49%|████████████████▌                 | 487/1000 [13:48:08<11:08:47, 78.22s/it][A
 49%|████████████████▌                 | 488/1000 [13:49:26<11:06:49, 78.14s/it][A
 49%|████████████████▋                 | 489/1000 [13:50:43<11:04:16, 78.00s/it][A
 49%|████████████████▋                 | 490/1000 [13:52:01<11:02:06, 77.90s/it][A
 49%|████████████████▋                 | 491/1000 [13:53:19<11:01:05, 77.93s/it][A
 49%|████████████████▋                 | 492/1000 [13:54:36<10:59:01, 77.84s/it][A
 49%|████████████████▊                 | 493/1000 [13:55:54<10:56:50, 77.73s/it][A
 49%|████████████████▊                 | 494/1000 [13:57:12<10:56:39, 77.86s/it][A
 50%|████████████████▊                 | 495/1000 [13:58:31<10:56:41, 78.02s

 58%|████████████████████▎              | 581/1000 [15:50:10<9:06:49, 78.30s/it][A
 58%|████████████████████▎              | 582/1000 [15:51:28<9:05:09, 78.25s/it][A
 58%|████████████████████▍              | 583/1000 [15:52:46<9:03:17, 78.17s/it][A
 58%|████████████████████▍              | 584/1000 [15:54:04<9:02:11, 78.20s/it][A
 58%|████████████████████▍              | 585/1000 [15:55:22<9:00:55, 78.21s/it][A
 59%|████████████████████▌              | 586/1000 [15:56:41<8:59:28, 78.18s/it][A
 59%|████████████████████▌              | 587/1000 [15:57:59<8:58:04, 78.17s/it][A
 59%|████████████████████▌              | 588/1000 [15:59:17<8:56:29, 78.13s/it][A
 59%|████████████████████▌              | 589/1000 [16:00:35<8:54:20, 78.01s/it][A
 59%|████████████████████▋              | 590/1000 [16:01:52<8:52:47, 77.97s/it][A
 59%|████████████████████▋              | 591/1000 [16:03:11<8:52:19, 78.09s/it][A
 59%|████████████████████▋              | 592/1000 [16:04:29<8:51:26, 78.15s

 68%|███████████████████████▋           | 678/1000 [17:56:16<6:57:52, 77.87s/it][A
 68%|███████████████████████▊           | 679/1000 [17:57:35<6:57:00, 77.95s/it][A
 68%|███████████████████████▊           | 680/1000 [17:58:52<6:55:13, 77.86s/it][A
 68%|███████████████████████▊           | 681/1000 [18:00:10<6:54:09, 77.90s/it][A
 68%|███████████████████████▊           | 682/1000 [18:01:28<6:53:18, 77.98s/it][A
 68%|███████████████████████▉           | 683/1000 [18:02:46<6:51:24, 77.87s/it][A
 68%|███████████████████████▉           | 684/1000 [18:04:04<6:50:04, 77.86s/it][A
 68%|███████████████████████▉           | 685/1000 [18:05:21<6:48:21, 77.78s/it][A
 69%|████████████████████████           | 686/1000 [18:06:39<6:46:51, 77.74s/it][A
 69%|████████████████████████           | 687/1000 [18:07:57<6:46:15, 77.88s/it][A
 69%|████████████████████████           | 688/1000 [18:09:15<6:44:28, 77.78s/it][A
 69%|████████████████████████           | 689/1000 [18:10:33<6:44:01, 77.95s

 78%|███████████████████████████▏       | 775/1000 [20:02:12<4:51:57, 77.85s/it][A
 78%|███████████████████████████▏       | 776/1000 [20:03:30<4:50:47, 77.89s/it][A
 78%|███████████████████████████▏       | 777/1000 [20:04:49<4:50:33, 78.18s/it][A
 78%|███████████████████████████▏       | 778/1000 [20:06:07<4:49:09, 78.15s/it][A
 78%|███████████████████████████▎       | 779/1000 [20:07:25<4:47:40, 78.10s/it][A
 78%|███████████████████████████▎       | 780/1000 [20:08:43<4:46:22, 78.10s/it][A
 78%|███████████████████████████▎       | 781/1000 [20:10:01<4:44:49, 78.04s/it][A
 78%|███████████████████████████▎       | 782/1000 [20:11:20<4:43:51, 78.13s/it][A
 78%|███████████████████████████▍       | 783/1000 [20:12:38<4:42:18, 78.06s/it][A
 78%|███████████████████████████▍       | 784/1000 [20:13:56<4:41:20, 78.15s/it][A
 78%|███████████████████████████▍       | 785/1000 [20:15:14<4:40:08, 78.18s/it][A
 79%|███████████████████████████▌       | 786/1000 [20:16:33<4:39:15, 78.30s

 87%|██████████████████████████████▌    | 872/1000 [22:08:31<2:46:52, 78.22s/it][A
 87%|██████████████████████████████▌    | 873/1000 [22:09:49<2:45:21, 78.12s/it][A
 87%|██████████████████████████████▌    | 874/1000 [22:11:07<2:44:06, 78.15s/it][A
 88%|██████████████████████████████▋    | 875/1000 [22:12:26<2:42:53, 78.18s/it][A
 88%|██████████████████████████████▋    | 876/1000 [22:13:44<2:41:31, 78.16s/it][A
 88%|██████████████████████████████▋    | 877/1000 [22:15:02<2:40:16, 78.19s/it][A
 88%|██████████████████████████████▋    | 878/1000 [22:16:20<2:38:54, 78.15s/it][A
 88%|██████████████████████████████▊    | 879/1000 [22:17:38<2:37:30, 78.10s/it][A
 88%|██████████████████████████████▊    | 880/1000 [22:18:56<2:36:12, 78.10s/it][A
 88%|██████████████████████████████▊    | 881/1000 [22:20:14<2:34:53, 78.10s/it][A
 88%|██████████████████████████████▊    | 882/1000 [22:21:32<2:33:38, 78.12s/it][A
 88%|██████████████████████████████▉    | 883/1000 [22:22:51<2:32:29, 78.20s

 97%|███████████████████████████████████▊ | 969/1000 [24:14:46<40:13, 77.87s/it][A
 97%|███████████████████████████████████▉ | 970/1000 [24:16:04<38:55, 77.84s/it][A
 97%|███████████████████████████████████▉ | 971/1000 [24:17:22<37:40, 77.96s/it][A
 97%|███████████████████████████████████▉ | 972/1000 [24:18:42<36:37, 78.47s/it][A
 97%|████████████████████████████████████ | 973/1000 [24:20:00<35:16, 78.39s/it][A
 97%|████████████████████████████████████ | 974/1000 [24:21:18<33:57, 78.36s/it][A
 98%|████████████████████████████████████ | 975/1000 [24:22:36<32:35, 78.24s/it][A
 98%|████████████████████████████████████ | 976/1000 [24:23:54<31:17, 78.24s/it][A
 98%|████████████████████████████████████▏| 977/1000 [24:25:13<29:59, 78.26s/it][A
 98%|████████████████████████████████████▏| 978/1000 [24:26:30<28:38, 78.09s/it][A
 98%|████████████████████████████████████▏| 979/1000 [24:27:48<27:19, 78.05s/it][A
 98%|████████████████████████████████████▎| 980/1000 [24:29:07<26:02, 78.15s

One run completed


In [25]:
state_dist=[]
for s in range(nS):
    c=w_obj.weight_obj(s).detach().cpu().numpy()[0]
    state_dist.append(c)
c=np.array(state_dist)*behaviour_policy_state_distribution

In [26]:
c/np.sum(c)

array([1.18992480e-03, 3.28779329e-01, 1.57603352e-02, 2.67999149e-02,
       6.66894095e-04, 5.42729050e-04, 5.08046257e-01, 8.37924615e-04,
       7.66831928e-02, 1.58677661e-02, 6.75073729e-04, 1.92126133e-02,
       4.45931477e-04, 5.44204063e-04, 6.76297072e-04, 7.03075342e-04,
       4.47978685e-04, 7.33454595e-04, 6.33998275e-04, 7.53105872e-04])