I am setting the action by modulator agent to be 1 always, effectively cutting it out of the equation to see the efffect of just the user agent in action.

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from Environment import *
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
'''
Direction mapping:
0: left = [-1, 0]
1: right = [1, 0]
2: up = [0, -1]
3: down = [0, 1]
'''

class User_Agent:
    def __init__(self):
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (4,))    #curr_x, curr_y, target_x, target_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(4, activation = 'softmax')(x) #left,right, down, up
        
        self.model = Model(inputs = input_A, outputs = x)
        print(self.model.summary())
        #---------------------------------------------------
        
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())
        
        self.loss_fn = tf.keras.losses.mean_squared_error
        self.optimizer = tf.keras.optimizers.Adam(lr = 0.005)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1
        self.gamma = 0.9
        
    def exp_policy(self, state):
        if np.random.rand()<self.epsilon:
            return np.random.randint(4)
        else:
            state = np.array(state)[np.newaxis]
            Q_values = self.model(state)
            return np.argmax(Q_values[0])
        
    def sample_experience(self):
        indices = np.random.randint(len(self.replay_buffer.state_history), size = self.batch_size)
        
        states = np.array([self.replay_buffer.state_history[i] for i in indices])
        actions = np.array([self.replay_buffer.action_history[i] for i in indices])
        next_states = np.array([self.replay_buffer.next_state_history[i] for i in indices])
        rewards = np.array([self.replay_buffer.rewards_history[i] for i in indices])
        dones = np.array([self.replay_buffer.done_history[i] for i in indices])
        
        return states, actions, next_states, rewards, dones
    

    def play_one_step(self, env, state, mod_agent):
        action_user = self.exp_policy(state)
        action_user_one_hot = make_one_hot(action_user, 4)
        curr_loc = state[:2]
        target_loc = state[2:]
        action_user_one_hot.extend(target_loc)
        mod_state = action_user_one_hot[:]
        mod_state = np.array(mod_state)
        new_loc, reward, done = mod_agent.play_one_step(env, mod_state, curr_loc, target_loc, self)
        next_state = [new_loc[0], new_loc[1], target_loc[0], target_loc[1]]
        self.replay_buffer.append(state, action_user, reward, next_state, done)
        
        return next_state, reward, done
    
    def train(self):
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis= 1)
        target_Q_values = rewards + (1-dones)*self.gamma*max_next_Q_values
        
        mask = tf.one_hot(actions, 4)
        
        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values*mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

In [3]:
class Mod_Agent:
    def __init__(self):
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (6,))   #direction of motion_one_hot(4), curr_x, curr_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(4, activation = 'softmax')(x) #modulate by 1,2,3,4 
        
        self.model = Model(inputs = input_A, outputs = x)
        print(self.model.summary())
        #---------------------------------------------------
        
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())
        
        self.loss_fn = tf.keras.losses.mean_squared_error
        self.optimizer = tf.keras.optimizers.Adam(lr = 0.005)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1
        self.steps_per_epoch = 1
        self.gamma = 0.9
    
    def exp_policy(self, state):
        if np.random.rand()<self.epsilon:
            return np.random.randint(4)
        else:
            state = np.array(state)[np.newaxis]
            Q_values = self.model(state)
            return np.argmax(Q_values[0])
        
        
    def sample_experience(self):
        indices = np.random.randint(len(self.replay_buffer.state_history), size = self.batch_size)
        
        states = np.array([self.replay_buffer.state_history[i] for i in indices])
        actions = np.array([self.replay_buffer.action_history[i] for i in indices])
        next_states = np.array([self.replay_buffer.next_state_history[i] for i in indices])
        rewards = np.array([self.replay_buffer.rewards_history[i] for i in indices])
        dones = np.array([self.replay_buffer.done_history[i] for i in indices])
        
        return states, actions, next_states, rewards, dones
    
    def play_one_step(self, env, state, curr_loc, target_loc, user_agent):
        #Agent not aware of target location
        action_mod = self.exp_policy(state)
        action_mod = 1
        action_user = state[0]
        
        new_loc, reward, done = env.step(action_user, action_mod, target_loc, curr_loc)
        next_dir = user_agent.exp_policy(np.array([[new_loc[0], new_loc[1], target_loc[0], target_loc[1]]]))
        
        next_dir_one_hot = make_one_hot(next_dir, 4)
        next_dir_one_hot.extend(target_loc)
        next_state = next_dir_one_hot[:]
        next_state = np.array(next_state)
        
        self.replay_buffer.append(state, action_mod, reward, next_state, done)
        
        
        return new_loc, reward, done
    
    def train(self):
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis= 1)
        target_Q_values = rewards + (1-dones)*self.gamma*max_next_Q_values
        
        mask = tf.one_hot(actions, 4)
        
        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values*mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            

In [4]:
env = Environment()
user_agent = User_Agent()
mod_agent = Mod_Agent()

Icon Locations:
[[0.  0.5]
 [0.6 0.7]
 [0.4 0.6]
 [0.1 0.4]
 [0.2 0. ]
 [0.6 0.1]]
Icon usage Probabilities
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 32)                160       
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 132       
Total params: 292
Trainable params: 292
Non-trainable params: 0
_________________________________________________________________
None
Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 6)]           

In [5]:
rewards = []
mean_rewards = []
max_steps = 200
reached = 0
for epoch in tqdm(range(1000)):
    done = False
    episode_reward = 0
    step = 0
    while not done and step<max_steps:
        start, dest = env.give_start_dest()
        state = [start[0], start[1], dest[0], dest[1]]
        state = np.array(state)
        next_state, reward, done = user_agent.play_one_step(env, state, mod_agent)
        state = next_state
        episode_reward+=reward
        step+=1
        if done:
            reached+=1
        if epoch>50:
            user_agent.train()
            user_agent.epsilon*=0.9
#             mod_agent.train()
    
    mean_rewards.append(episode_reward)
    if epoch%10==0:
        rewards.append(np.mean(mean_rewards))
        print(f'Mean Reward = {rewards[-1]}')
        print(reached)
        reached = 0
    

  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]

Mean Reward = 1.5
1


  1%|▊                                                                               | 10/1000 [00:00<00:10, 97.35it/s]

Mean Reward = -47.43636363636363
10


  2%|█▊                                                                              | 23/1000 [00:00<00:13, 72.15it/s]

Mean Reward = -60.86666666666667
7


  3%|██▏                                                                             | 28/1000 [00:00<00:15, 63.27it/s]

Mean Reward = -72.26451612903227
10


  4%|███▌                                                                            | 45/1000 [00:00<00:14, 66.71it/s]

Mean Reward = -71.4780487804878
10
Mean Reward = -69.08235294117648
9


  6%|████▉                                                                           | 61/1000 [00:02<01:43,  9.10it/s]

Mean Reward = -60.68196721311476
10


  7%|█████▊                                                                          | 72/1000 [00:09<07:48,  1.98it/s]

Mean Reward = -63.28309859154931
8


  8%|██████▍                                                                         | 81/1000 [00:17<09:34,  1.60it/s]

Mean Reward = -67.35185185185186
7


  9%|███████▎                                                                        | 91/1000 [00:23<12:31,  1.21it/s]

Mean Reward = -68.19010989010988
8


 10%|███████▉                                                                       | 101/1000 [00:29<09:04,  1.65it/s]

Mean Reward = -67.78118811881187
9


 11%|████████▊                                                                      | 112/1000 [00:37<07:00,  2.11it/s]

Mean Reward = -70.16756756756756
5


 12%|█████████▌                                                                     | 121/1000 [00:43<11:38,  1.26it/s]

Mean Reward = -70.39421487603306
9


 13%|██████████▎                                                                    | 131/1000 [00:50<12:48,  1.13it/s]

Mean Reward = -70.21908396946566
8


 14%|███████████▏                                                                   | 141/1000 [01:01<16:53,  1.18s/it]

Mean Reward = -73.35319148936172
7


 15%|████████████                                                                   | 152/1000 [01:10<09:44,  1.45it/s]

Mean Reward = -73.97549668874171
9


 16%|████████████▋                                                                  | 161/1000 [01:18<13:25,  1.04it/s]

Mean Reward = -74.42422360248447
8


 17%|█████████████▌                                                                 | 171/1000 [01:25<12:03,  1.15it/s]

Mean Reward = -74.0391812865497
9


 18%|██████████████▎                                                                | 181/1000 [01:31<08:40,  1.57it/s]

Mean Reward = -73.03425414364641
10


 19%|███████████████                                                                | 191/1000 [01:38<12:55,  1.04it/s]

Mean Reward = -73.4958115183246
7


 20%|███████████████▉                                                               | 201/1000 [01:49<14:38,  1.10s/it]

Mean Reward = -75.85820895522389
7


 21%|████████████████▋                                                              | 211/1000 [01:54<08:58,  1.46it/s]

Mean Reward = -74.78056872037915
8


 22%|█████████████████▍                                                             | 221/1000 [02:07<16:35,  1.28s/it]

Mean Reward = -75.71719457013575
7


 23%|██████████████████▏                                                            | 231/1000 [02:22<19:25,  1.52s/it]

Mean Reward = -77.15800865800865
7


 24%|███████████████████                                                            | 241/1000 [02:32<13:13,  1.05s/it]

Mean Reward = -77.35726141078838
8


 25%|███████████████████▊                                                           | 250/1000 [02:42<10:51,  1.15it/s]

Mean Reward = -76.97450199203186
9


 26%|████████████████████▌                                                          | 261/1000 [02:48<06:19,  1.95it/s]

Mean Reward = -75.67164750957855
9


 27%|█████████████████████▍                                                         | 271/1000 [02:57<17:09,  1.41s/it]

Mean Reward = -75.64391143911438
8


 28%|██████████████████████▏                                                        | 281/1000 [03:04<07:32,  1.59it/s]

Mean Reward = -74.51921708185054
9


 29%|██████████████████████▉                                                        | 291/1000 [03:17<18:38,  1.58s/it]

Mean Reward = -75.57903780068727
6


 30%|███████████████████████▊                                                       | 301/1000 [03:30<17:08,  1.47s/it]

Mean Reward = -76.29667774086379
8


 31%|████████████████████████▌                                                      | 311/1000 [03:41<12:47,  1.11s/it]

Mean Reward = -77.36591639871382
6


 32%|█████████████████████████▎                                                     | 321/1000 [03:47<07:52,  1.44it/s]

Mean Reward = -76.98878504672898
7


 33%|██████████████████████████▏                                                    | 331/1000 [03:57<10:49,  1.03it/s]

Mean Reward = -77.9036253776435
6


 34%|██████████████████████████▉                                                    | 341/1000 [04:06<08:21,  1.31it/s]

Mean Reward = -78.58504398826979
7


 35%|███████████████████████████▋                                                   | 351/1000 [04:19<15:44,  1.45s/it]

Mean Reward = -79.35413105413105
8


 36%|████████████████████████████▌                                                  | 361/1000 [04:31<08:22,  1.27it/s]

Mean Reward = -79.79473684210527
9


 37%|█████████████████████████████▍                                                 | 372/1000 [04:40<05:44,  1.82it/s]

Mean Reward = -79.48167115902964
8


 38%|██████████████████████████████                                                 | 381/1000 [04:50<11:38,  1.13s/it]

Mean Reward = -79.59921259842518
8


 39%|██████████████████████████████▉                                                | 392/1000 [05:00<07:34,  1.34it/s]

Mean Reward = -80.10306905370844
7


 40%|███████████████████████████████▌                                               | 400/1000 [05:05<06:52,  1.45it/s]

Mean Reward = -79.4718204488778
8


 41%|████████████████████████████████▍                                              | 411/1000 [05:12<05:46,  1.70it/s]

Mean Reward = -79.35815085158151
9


 42%|█████████████████████████████████▎                                             | 421/1000 [05:18<05:06,  1.89it/s]

Mean Reward = -79.19263657957244
8


 43%|██████████████████████████████████                                             | 431/1000 [05:26<08:03,  1.18it/s]

Mean Reward = -79.28607888631089
7


 44%|██████████████████████████████████▉                                            | 442/1000 [05:32<04:57,  1.87it/s]

Mean Reward = -79.05238095238094
9


 45%|███████████████████████████████████▋                                           | 451/1000 [05:41<09:09,  1.00s/it]

Mean Reward = -79.73414634146341
5


 46%|████████████████████████████████████▍                                          | 461/1000 [05:50<06:49,  1.32it/s]

Mean Reward = -80.18481561822125
8


 47%|█████████████████████████████████████▏                                         | 471/1000 [05:58<06:16,  1.40it/s]

Mean Reward = -80.43036093418259
8


 48%|█████████████████████████████████████▉                                         | 481/1000 [06:05<06:27,  1.34it/s]

Mean Reward = -80.44989604989605
9


 49%|██████████████████████████████████████▊                                        | 491/1000 [06:12<05:47,  1.46it/s]

Mean Reward = -80.23849287169044
8


 50%|███████████████████████████████████████▌                                       | 501/1000 [06:18<03:10,  2.62it/s]

Mean Reward = -79.90798403193612
10


 51%|████████████████████████████████████████▎                                      | 511/1000 [06:26<07:36,  1.07it/s]

Mean Reward = -80.2119373776908
9


 52%|█████████████████████████████████████████▏                                     | 521/1000 [06:35<07:05,  1.13it/s]

Mean Reward = -80.52245681381957
8


 53%|█████████████████████████████████████████▉                                     | 531/1000 [06:39<04:47,  1.63it/s]

Mean Reward = -79.7367231638418
10


 54%|██████████████████████████████████████████▋                                    | 541/1000 [06:52<10:01,  1.31s/it]

Mean Reward = -80.25785582255084
7


 55%|███████████████████████████████████████████▌                                   | 551/1000 [07:01<09:15,  1.24s/it]

Mean Reward = -80.12014519056261
8


 56%|████████████████████████████████████████████▎                                  | 561/1000 [07:10<04:58,  1.47it/s]

Mean Reward = -80.08146167557932
8


 57%|█████████████████████████████████████████████                                  | 571/1000 [07:19<06:38,  1.08it/s]

Mean Reward = -79.91978984238177
9


 58%|█████████████████████████████████████████████▉                                 | 581/1000 [07:27<04:44,  1.47it/s]

Mean Reward = -79.55318416523235
8


 59%|██████████████████████████████████████████████▋                                | 591/1000 [07:38<06:56,  1.02s/it]

Mean Reward = -79.88172588832485
8


 60%|███████████████████████████████████████████████▍                               | 601/1000 [07:46<04:52,  1.37it/s]

Mean Reward = -79.75374376039933
8


 61%|████████████████████████████████████████████████▎                              | 611/1000 [07:54<05:06,  1.27it/s]

Mean Reward = -79.96743044189854
8


 62%|█████████████████████████████████████████████████                              | 621/1000 [08:00<04:14,  1.49it/s]

Mean Reward = -79.78438003220613
9


 63%|█████████████████████████████████████████████████▊                             | 631/1000 [08:08<05:08,  1.20it/s]

Mean Reward = -79.96640253565769
5


 64%|██████████████████████████████████████████████████▋                            | 642/1000 [08:15<03:32,  1.68it/s]

Mean Reward = -79.91778471138845
8


 65%|███████████████████████████████████████████████████▍                           | 651/1000 [08:21<03:12,  1.81it/s]

Mean Reward = -79.55529953917049
10


 66%|████████████████████████████████████████████████████▏                          | 661/1000 [08:25<02:40,  2.11it/s]

Mean Reward = -78.92995461422088
10


 67%|█████████████████████████████████████████████████████                          | 671/1000 [08:31<02:40,  2.05it/s]

Mean Reward = -78.69701937406855
8


 68%|█████████████████████████████████████████████████████▊                         | 681/1000 [08:35<02:54,  1.83it/s]

Mean Reward = -78.11967694566813
10


 69%|██████████████████████████████████████████████████████▌                        | 691/1000 [08:41<03:15,  1.58it/s]

Mean Reward = -77.78654124457309
9


 70%|███████████████████████████████████████████████████████▍                       | 701/1000 [08:46<02:52,  1.74it/s]

Mean Reward = -77.49714693295293
9


 71%|████████████████████████████████████████████████████████▏                      | 711/1000 [08:53<03:29,  1.38it/s]

Mean Reward = -77.41181434599156
9


 72%|████████████████████████████████████████████████████████▉                      | 721/1000 [09:01<04:38,  1.00it/s]

Mean Reward = -77.22510402219139
9


 73%|█████████████████████████████████████████████████████████▋                     | 731/1000 [09:11<05:16,  1.18s/it]

Mean Reward = -77.27373461012311
7


 74%|██████████████████████████████████████████████████████████▌                    | 741/1000 [09:18<03:14,  1.33it/s]

Mean Reward = -76.95222672064776
9


 75%|███████████████████████████████████████████████████████████▎                   | 751/1000 [09:28<05:06,  1.23s/it]

Mean Reward = -77.10772303595206
7


 76%|████████████████████████████████████████████████████████████                   | 761/1000 [09:33<02:35,  1.54it/s]

Mean Reward = -76.53337713534822
10


 77%|████████████████████████████████████████████████████████████▉                  | 771/1000 [09:39<02:17,  1.66it/s]

Mean Reward = -76.1396887159533
10


 78%|█████████████████████████████████████████████████████████████▋                 | 781/1000 [09:46<02:06,  1.73it/s]

Mean Reward = -75.90025608194622
9


 79%|██████████████████████████████████████████████████████████████▍                | 791/1000 [09:55<02:42,  1.29it/s]

Mean Reward = -75.91630847029077
8


 80%|███████████████████████████████████████████████████████████████▎               | 801/1000 [10:05<03:47,  1.14s/it]

Mean Reward = -75.95280898876405
7


 81%|████████████████████████████████████████████████████████████████▏              | 812/1000 [10:10<01:23,  2.26it/s]

Mean Reward = -75.66905055487052
9


 82%|████████████████████████████████████████████████████████████████▊              | 821/1000 [10:16<01:40,  1.79it/s]

Mean Reward = -75.54470158343483
8


 83%|█████████████████████████████████████████████████████████████████▋             | 831/1000 [10:19<00:45,  3.67it/s]

Mean Reward = -74.92912154031288
10


 84%|██████████████████████████████████████████████████████████████████▎            | 840/1000 [10:24<01:53,  1.42it/s]

Mean Reward = -74.71724137931034
8


 85%|███████████████████████████████████████████████████████████████████▏           | 851/1000 [10:28<00:48,  3.04it/s]

Mean Reward = -74.25734430082257
10


 86%|████████████████████████████████████████████████████████████████████           | 861/1000 [10:34<01:14,  1.88it/s]

Mean Reward = -74.2040650406504
8


 87%|████████████████████████████████████████████████████████████████████▊          | 871/1000 [10:41<01:15,  1.72it/s]

Mean Reward = -74.14569460390355
8


 88%|█████████████████████████████████████████████████████████████████████▌         | 881/1000 [10:47<01:09,  1.70it/s]

Mean Reward = -74.0872871736663
9


 89%|██████████████████████████████████████████████████████████████████████▍        | 891/1000 [10:53<01:17,  1.41it/s]

Mean Reward = -73.97755331088665
9


 90%|███████████████████████████████████████████████████████████████████████▏       | 901/1000 [11:01<01:06,  1.50it/s]

Mean Reward = -74.15782463928969
7


 91%|███████████████████████████████████████████████████████████████████████▉       | 911/1000 [11:06<00:37,  2.40it/s]

Mean Reward = -73.92250274423711
10


 92%|████████████████████████████████████████████████████████████████████████▊      | 921/1000 [11:12<00:53,  1.48it/s]

Mean Reward = -73.7399565689468
8


 93%|█████████████████████████████████████████████████████████████████████████▌     | 931/1000 [11:20<00:46,  1.48it/s]

Mean Reward = -73.91074113856068
7


 94%|██████████████████████████████████████████████████████████████████████████▎    | 941/1000 [11:27<00:43,  1.34it/s]

Mean Reward = -74.02337938363443
7


 95%|███████████████████████████████████████████████████████████████████████████    | 950/1000 [11:31<00:18,  2.67it/s]

Mean Reward = -73.64710830704522
10


 96%|███████████████████████████████████████████████████████████████████████████▉   | 961/1000 [11:38<00:22,  1.71it/s]

Mean Reward = -73.65234131113422
9


 97%|████████████████████████████████████████████████████████████████████████████▋  | 971/1000 [11:46<00:19,  1.48it/s]

Mean Reward = -73.7437693099897
8


 98%|█████████████████████████████████████████████████████████████████████████████▍ | 981/1000 [11:54<00:11,  1.59it/s]

Mean Reward = -73.98776758409785
8


 99%|██████████████████████████████████████████████████████████████████████████████▎| 991/1000 [12:04<00:13,  1.49s/it]

Mean Reward = -74.23784056508578
6


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:16<00:00,  1.36it/s]
