In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from Environment import *
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
'''
Direction mapping:
0: left = [-1, 0]
1: right = [1, 0]
2: up = [0, -1]
3: down = [0, 1]
'''

class User_Agent:
    def __init__(self):
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (4,))    #curr_x, curr_y, target_x, target_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(4, activation = 'softmax')(x) #left,right, down, up
        
        self.model = Model(inputs = input_A, outputs = x)
        print(self.model.summary())
        #---------------------------------------------------
        
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())
        
        self.loss_fn = tf.keras.losses.mean_squared_error
        self.optimizer = tf.keras.optimizers.Adam(lr = 0.005)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1
        self.gamma = 0.9
        
    def exp_policy(self, state):
        if np.random.rand()<self.epsilon:
            return np.random.randint(4)
        else:
            state = np.array(state)[np.newaxis]
            Q_values = self.model(state)
            return np.argmax(Q_values[0])
        
    def sample_experience(self):
        indices = np.random.randint(len(self.replay_buffer.state_history), size = self.batch_size)
        
        states = np.array([self.replay_buffer.state_history[i] for i in indices])
        actions = np.array([self.replay_buffer.action_history[i] for i in indices])
        next_states = np.array([self.replay_buffer.next_state_history[i] for i in indices])
        rewards = np.array([self.replay_buffer.rewards_history[i] for i in indices])
        dones = np.array([self.replay_buffer.done_history[i] for i in indices])
        
        return states, actions, next_states, rewards, dones
    

    def play_one_step(self, env, state, mod_agent):
        action_user = self.exp_policy(state)
        action_user_one_hot = make_one_hot(action_user, 4)
        curr_loc = state[:2]
        target_loc = state[2:]
        action_user_one_hot.extend(target_loc)
        mod_state = action_user_one_hot[:]
        mod_state = np.array(mod_state)
        new_loc, reward, done = mod_agent.play_one_step(env, mod_state, curr_loc, target_loc, self)
        next_state = [new_loc[0], new_loc[1], target_loc[0], target_loc[1]]
        self.replay_buffer.append(state, action_user, reward, next_state, done)
        
        return next_state, reward, done
    
    def train(self):
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis= 1)
        target_Q_values = rewards + (1-dones)*self.gamma*max_next_Q_values
        
        mask = tf.one_hot(actions, 4)
        
        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values*mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

In [3]:
class Mod_Agent:
    def __init__(self):
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (6,))   #direction of motion_one_hot(4), curr_x, curr_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(4, activation = 'softmax')(x) #modulate by 1,2,3,4 
        
        self.model = Model(inputs = input_A, outputs = x)
        print(self.model.summary())
        #---------------------------------------------------
        
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())
        
        self.loss_fn = tf.keras.losses.mean_squared_error
        self.optimizer = tf.keras.optimizers.Adam(lr = 0.005)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1
        self.steps_per_epoch = 1
        self.gamma = 0.9
    
    def exp_policy(self, state):
        if np.random.rand()<self.epsilon:
            return np.random.randint(4)
        else:
            state = np.array(state)[np.newaxis]
            Q_values = self.model(state)
            return np.argmax(Q_values[0])
        
        
    def sample_experience(self):
        indices = np.random.randint(len(self.replay_buffer.state_history), size = self.batch_size)
        
        states = np.array([self.replay_buffer.state_history[i] for i in indices])
        actions = np.array([self.replay_buffer.action_history[i] for i in indices])
        next_states = np.array([self.replay_buffer.next_state_history[i] for i in indices])
        rewards = np.array([self.replay_buffer.rewards_history[i] for i in indices])
        dones = np.array([self.replay_buffer.done_history[i] for i in indices])
        
        return states, actions, next_states, rewards, dones
    
    def play_one_step(self, env, state, curr_loc, target_loc, user_agent):
        #Agent not aware of target location
        action_mod = self.exp_policy(state)
        action_user = state[0]
        
        new_loc, reward, done = env.step(action_user, action_mod, target_loc, curr_loc)
        next_dir = user_agent.exp_policy(np.array([[new_loc[0], new_loc[1], target_loc[0], target_loc[1]]]))
        
        next_dir_one_hot = make_one_hot(next_dir, 4)
        next_dir_one_hot.extend(target_loc)
        next_state = next_dir_one_hot[:]
        next_state = np.array(next_state)
        
        self.replay_buffer.append(state, action_mod, reward, next_state, done)
        
        
        return new_loc, reward, done
    
    def train(self):
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis= 1)
        target_Q_values = rewards + (1-dones)*self.gamma*max_next_Q_values
        
        mask = tf.one_hot(actions, 4)
        
        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values*mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            

In [4]:
env = Environment()
user_agent = User_Agent()
mod_agent = Mod_Agent()

Icon Locations:
[[0.4 0.9]
 [0.  0.2]
 [0.9 0.4]
 [0.4 0.2]
 [0.4 0.7]
 [0.5 0. ]]
Icon usage Probabilities
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 32)                160       
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 132       
Total params: 292
Trainable params: 292
Non-trainable params: 0
_________________________________________________________________
None
Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 6)]           

In [5]:
rewards = []
mean_rewards = []
max_steps = 200
reached = 0
for epoch in tqdm(range(1000)):
    done = False
    episode_reward = 0
    step = 0
    while not done and step<max_steps:
        start, dest = env.give_start_dest()
        state = [start[0], start[1], dest[0], dest[1]]
        state = np.array(state)
        next_state, reward, done = user_agent.play_one_step(env, state, mod_agent)
        state = next_state
        episode_reward+=reward
        step+=1
        if done:
            reached+=1
        if epoch>50:
            user_agent.train()
            mod_agent.train()
    
    mean_rewards.append(episode_reward)
    if epoch%10==0:
        rewards.append(np.mean(mean_rewards))
        print(f'Mean Reward = {rewards[-1]}')
        print(reached)
        reached = 0
    

  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]

Mean Reward = -46.599999999999994
1


  2%|█▎                                                                              | 16/1000 [00:00<00:21, 46.18it/s]

Mean Reward = -109.10000000000004
5


  3%|██▏                                                                             | 27/1000 [00:00<00:20, 48.54it/s]

Mean Reward = -118.75238095238095
4
Mean Reward = -117.33225806451614
6


  5%|███▊                                                                            | 47/1000 [00:00<00:17, 54.26it/s]

Mean Reward = -110.29024390243906
7
Mean Reward = -108.46274509803922
7


  6%|████▊                                                                           | 60/1000 [00:12<09:26,  1.66it/s]

Mean Reward = -111.82295081967212
5


  7%|█████▋                                                                          | 71/1000 [00:27<19:43,  1.27s/it]

Mean Reward = -115.51830985915494
4


  8%|██████▍                                                                         | 81/1000 [00:37<13:52,  1.10it/s]

Mean Reward = -111.88888888888891
7


  9%|███████▏                                                                        | 90/1000 [00:47<18:02,  1.19s/it]

Mean Reward = -110.55274725274727
5


 10%|███████▉                                                                       | 101/1000 [00:55<09:29,  1.58it/s]

Mean Reward = -107.1653465346535
7


 11%|████████▊                                                                      | 111/1000 [01:07<18:00,  1.22s/it]

Mean Reward = -108.49549549549549
6


 12%|█████████▌                                                                     | 121/1000 [01:21<16:39,  1.14s/it]

Mean Reward = -109.8603305785124
4


 13%|██████████▎                                                                    | 131/1000 [01:33<16:34,  1.14s/it]

Mean Reward = -110.68931297709925
7


 14%|███████████▏                                                                   | 141/1000 [01:47<21:29,  1.50s/it]

Mean Reward = -113.38652482269504
3


 15%|███████████▉                                                                   | 151/1000 [01:59<19:29,  1.38s/it]

Mean Reward = -112.21920529801324
9


 16%|████████████▋                                                                  | 161/1000 [02:10<16:13,  1.16s/it]

Mean Reward = -111.6913043478261
6


 17%|█████████████▌                                                                 | 171/1000 [02:21<12:05,  1.14it/s]

Mean Reward = -111.2058479532164
7


 18%|██████████████▎                                                                | 181/1000 [02:31<15:30,  1.14s/it]

Mean Reward = -111.08342541436463
6


 19%|███████████████                                                                | 191/1000 [02:41<09:26,  1.43it/s]

Mean Reward = -109.7649214659686
7


 20%|███████████████▉                                                               | 201/1000 [02:53<18:16,  1.37s/it]

Mean Reward = -110.07910447761195
5


 21%|████████████████▋                                                              | 211/1000 [03:11<20:21,  1.55s/it]

Mean Reward = -111.05165876777252
6


 22%|█████████████████▍                                                             | 221/1000 [03:23<14:39,  1.13s/it]

Mean Reward = -110.85158371040724
6


 23%|██████████████████▏                                                            | 231/1000 [03:35<13:53,  1.08s/it]

Mean Reward = -110.78658008658009
6


 24%|███████████████████                                                            | 241/1000 [03:50<20:39,  1.63s/it]

Mean Reward = -111.36348547717841
4


 25%|███████████████████▊                                                           | 251/1000 [04:01<08:33,  1.46it/s]

Mean Reward = -111.12549800796813
6


 26%|████████████████████▌                                                          | 261/1000 [04:16<15:23,  1.25s/it]

Mean Reward = -112.05632183908047
4


 27%|█████████████████████▍                                                         | 271/1000 [04:28<13:22,  1.10s/it]

Mean Reward = -111.67785977859779
10


 28%|██████████████████████▏                                                        | 281/1000 [04:37<08:54,  1.35it/s]

Mean Reward = -110.53024911032028
8


 29%|██████████████████████▉                                                        | 291/1000 [04:47<12:13,  1.03s/it]

Mean Reward = -109.81890034364262
8


 30%|███████████████████████▊                                                       | 301/1000 [04:54<12:03,  1.03s/it]

Mean Reward = -108.38139534883722
9


 31%|████████████████████████▌                                                      | 311/1000 [05:06<10:06,  1.14it/s]

Mean Reward = -108.45176848874598
6


 32%|█████████████████████████▎                                                     | 321/1000 [05:20<15:06,  1.34s/it]

Mean Reward = -109.40498442367601
3


 33%|██████████████████████████▏                                                    | 332/1000 [05:34<10:44,  1.04it/s]

Mean Reward = -109.72416918429005
5


 34%|██████████████████████████▉                                                    | 341/1000 [05:46<12:17,  1.12s/it]

Mean Reward = -109.7108504398827
6


 35%|███████████████████████████▋                                                   | 351/1000 [05:56<12:15,  1.13s/it]

Mean Reward = -109.26011396011397
6


 36%|████████████████████████████▌                                                  | 361/1000 [06:10<17:08,  1.61s/it]

Mean Reward = -110.02022160664819
4


 37%|█████████████████████████████▎                                                 | 371/1000 [06:24<16:06,  1.54s/it]

Mean Reward = -110.59433962264151
4


 38%|██████████████████████████████                                                 | 381/1000 [06:38<12:48,  1.24s/it]

Mean Reward = -111.1233595800525
4


 39%|██████████████████████████████▉                                                | 391/1000 [06:53<13:56,  1.37s/it]

Mean Reward = -111.82915601023016
5


 40%|███████████████████████████████▋                                               | 401/1000 [07:07<14:04,  1.41s/it]

Mean Reward = -112.31720698254364
7


 41%|████████████████████████████████▍                                              | 411/1000 [07:20<10:26,  1.06s/it]

Mean Reward = -112.54768856447689
5


 42%|█████████████████████████████████▎                                             | 421/1000 [07:33<11:19,  1.17s/it]

Mean Reward = -112.5874109263658
7


 43%|██████████████████████████████████                                             | 431/1000 [07:40<10:20,  1.09s/it]

Mean Reward = -111.51067285382832
9


 44%|██████████████████████████████████▊                                            | 441/1000 [07:51<11:36,  1.25s/it]

Mean Reward = -111.30385487528345
7


 45%|███████████████████████████████████▋                                           | 451/1000 [08:03<10:28,  1.15s/it]

Mean Reward = -111.4490022172949
6


 46%|████████████████████████████████████▍                                          | 461/1000 [08:15<12:02,  1.34s/it]

Mean Reward = -111.39783080260304
5


 47%|█████████████████████████████████████▏                                         | 471/1000 [08:23<07:04,  1.25it/s]

Mean Reward = -110.71125265392782
8


 48%|█████████████████████████████████████▉                                         | 481/1000 [08:36<12:23,  1.43s/it]

Mean Reward = -110.87318087318087
7


 49%|██████████████████████████████████████▊                                        | 491/1000 [08:49<12:18,  1.45s/it]

Mean Reward = -111.19959266802444
6


 50%|███████████████████████████████████████▌                                       | 501/1000 [09:03<13:41,  1.65s/it]

Mean Reward = -111.7572854291417
4


 51%|████████████████████████████████████████▎                                      | 510/1000 [09:14<10:23,  1.27s/it]

Mean Reward = -111.53424657534246
6


 52%|█████████████████████████████████████████▏                                     | 521/1000 [09:28<11:15,  1.41s/it]

Mean Reward = -112.0397312859885
4


 53%|█████████████████████████████████████████▉                                     | 531/1000 [09:43<11:14,  1.44s/it]

Mean Reward = -112.59114877589454
4


 54%|██████████████████████████████████████████▋                                    | 541/1000 [09:55<08:50,  1.16s/it]

Mean Reward = -112.54584103512015
6


 55%|███████████████████████████████████████████▌                                   | 551/1000 [10:08<09:49,  1.31s/it]

Mean Reward = -112.70344827586207
5


 56%|████████████████████████████████████████████▍                                  | 562/1000 [10:22<08:09,  1.12s/it]

Mean Reward = -113.11586452762924
5


 57%|█████████████████████████████████████████████                                  | 571/1000 [10:31<05:22,  1.33it/s]

Mean Reward = -112.54203152364273
7


 58%|█████████████████████████████████████████████▉                                 | 581/1000 [10:44<08:32,  1.22s/it]

Mean Reward = -112.89259896729777
5


 59%|██████████████████████████████████████████████▋                                | 591/1000 [10:57<09:19,  1.37s/it]

Mean Reward = -113.0181049069374
5


 60%|███████████████████████████████████████████████▍                               | 601/1000 [11:12<10:46,  1.62s/it]

Mean Reward = -113.52945091514145
4


 61%|████████████████████████████████████████████████▎                              | 611/1000 [11:25<08:27,  1.31s/it]

Mean Reward = -113.69869067103112
7


 62%|█████████████████████████████████████████████████                              | 621/1000 [11:38<07:59,  1.27s/it]

Mean Reward = -113.86264090177136
6


 63%|█████████████████████████████████████████████████▊                             | 631/1000 [11:50<07:42,  1.25s/it]

Mean Reward = -113.7824088748019
7


 64%|██████████████████████████████████████████████████▋                            | 641/1000 [12:01<08:20,  1.39s/it]

Mean Reward = -113.61466458658346
5


 65%|███████████████████████████████████████████████████▍                           | 651/1000 [12:12<06:49,  1.17s/it]

Mean Reward = -113.4403993855607
6


 66%|████████████████████████████████████████████████████▏                          | 661/1000 [12:25<07:47,  1.38s/it]

Mean Reward = -113.63706505295008
4


 67%|█████████████████████████████████████████████████████                          | 671/1000 [12:39<07:43,  1.41s/it]

Mean Reward = -113.96482861400895
5


 68%|█████████████████████████████████████████████████████▉                         | 682/1000 [12:53<05:43,  1.08s/it]

Mean Reward = -114.20616740088107
6


 69%|██████████████████████████████████████████████████████▌                        | 691/1000 [13:04<05:13,  1.02s/it]

Mean Reward = -114.02387843704776
7


 70%|███████████████████████████████████████████████████████▍                       | 701/1000 [13:11<04:20,  1.15it/s]

Mean Reward = -113.28131241084166
9


 71%|████████████████████████████████████████████████████████▏                      | 711/1000 [13:26<07:20,  1.52s/it]

Mean Reward = -113.6748241912799
4


 72%|████████████████████████████████████████████████████████▉                      | 721/1000 [13:37<06:24,  1.38s/it]

Mean Reward = -113.50748959778085
6


 73%|█████████████████████████████████████████████████████████▋                     | 731/1000 [13:51<07:22,  1.64s/it]

Mean Reward = -113.96990424076607
3


 74%|██████████████████████████████████████████████████████████▌                    | 741/1000 [14:06<06:03,  1.40s/it]

Mean Reward = -114.25910931174089
6


 75%|███████████████████████████████████████████████████████████▎                   | 751/1000 [14:18<04:17,  1.04s/it]

Mean Reward = -114.24993342210388
6


 76%|████████████████████████████████████████████████████████████                   | 761/1000 [14:33<06:03,  1.52s/it]

Mean Reward = -114.65072273324574
5


 77%|████████████████████████████████████████████████████████████▉                  | 771/1000 [14:43<04:04,  1.07s/it]

Mean Reward = -114.31232166018158
8


 78%|█████████████████████████████████████████████████████████████▋                 | 781/1000 [14:54<05:01,  1.38s/it]

Mean Reward = -114.12650448143407
7


 79%|██████████████████████████████████████████████████████████████▍                | 791/1000 [15:06<03:57,  1.14s/it]

Mean Reward = -114.15461441213654
6


 80%|███████████████████████████████████████████████████████████████▎               | 801/1000 [15:20<04:00,  1.21s/it]

Mean Reward = -114.43021223470663
4


 81%|████████████████████████████████████████████████████████████████               | 811/1000 [15:35<05:14,  1.66s/it]

Mean Reward = -114.86658446362516
3


 82%|████████████████████████████████████████████████████████████████▊              | 821/1000 [15:48<03:32,  1.19s/it]

Mean Reward = -114.97612667478685
6


 83%|█████████████████████████████████████████████████████████████████▋             | 831/1000 [16:02<03:46,  1.34s/it]

Mean Reward = -115.14861612515043
5


 84%|██████████████████████████████████████████████████████████████████▍            | 841/1000 [16:14<03:12,  1.21s/it]

Mean Reward = -115.16884661117717
6


 85%|███████████████████████████████████████████████████████████████████▏           | 851/1000 [16:25<02:37,  1.06s/it]

Mean Reward = -115.06615746180964
6


 86%|████████████████████████████████████████████████████████████████████           | 861/1000 [16:40<03:50,  1.66s/it]

Mean Reward = -115.3150987224158
6


 87%|████████████████████████████████████████████████████████████████████▊          | 871/1000 [16:55<03:39,  1.70s/it]

Mean Reward = -115.75820895522388
4


 88%|█████████████████████████████████████████████████████████████████████▌         | 881/1000 [17:09<02:21,  1.19s/it]

Mean Reward = -115.9950056753689
3


 89%|██████████████████████████████████████████████████████████████████████▍        | 891/1000 [17:21<02:10,  1.20s/it]

Mean Reward = -115.9132435465769
6


 90%|███████████████████████████████████████████████████████████████████████▏       | 901/1000 [17:36<02:41,  1.63s/it]

Mean Reward = -116.18512763596004
5


 91%|███████████████████████████████████████████████████████████████████████▉       | 911/1000 [17:49<01:59,  1.34s/it]

Mean Reward = -116.30263446761802
5


 92%|████████████████████████████████████████████████████████████████████████▊      | 922/1000 [18:03<01:02,  1.25it/s]

Mean Reward = -116.50097719869706
4


 93%|█████████████████████████████████████████████████████████████████████████▋     | 932/1000 [18:16<01:14,  1.10s/it]

Mean Reward = -116.56659505907626
4


 94%|██████████████████████████████████████████████████████████████████████████▎    | 941/1000 [18:26<00:55,  1.06it/s]

Mean Reward = -116.30456960680128
7


 95%|███████████████████████████████████████████████████████████████████████████▏   | 952/1000 [18:41<00:57,  1.20s/it]

Mean Reward = -116.57234490010516
5


 96%|███████████████████████████████████████████████████████████████████████████▉   | 961/1000 [18:53<00:43,  1.12s/it]

Mean Reward = -116.65119667013528
5


 97%|████████████████████████████████████████████████████████████████████████████▋  | 971/1000 [19:07<00:40,  1.41s/it]

Mean Reward = -116.77857878475798
6


 98%|█████████████████████████████████████████████████████████████████████████████▍ | 981/1000 [19:18<00:21,  1.14s/it]

Mean Reward = -116.67186544342508
6


 99%|██████████████████████████████████████████████████████████████████████████████▎| 991/1000 [19:35<00:15,  1.74s/it]

Mean Reward = -117.14732593340061
2


100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [19:46<00:00,  1.19s/it]
