In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from Environment import *
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
'''
Direction mapping:
0: left = [-1, 0]
1: right = [1, 0]
2: up = [0, -1]
3: down = [0, 1]
'''

class User_Agent:
    """Q-learning agent that chooses the direction of motion.

    The online network maps a state ``(curr_x, curr_y, target_x, target_y)``
    to four Q-values, one per direction index (0: left, 1: right, 2: up,
    3: down according to the direction mapping at the top of this cell).
    A cloned target network supplies the bootstrap values used in ``train``.
    """

    def __init__(self):
        """Build the online/target networks, optimizer and replay buffer."""
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (4,))    #curr_x, curr_y, target_x, target_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(16, activation = 'relu')(x)
        # One Q-value per direction index; order presumably follows the
        # direction mapping (left, right, up, down) -- confirm against Environment.
        x = Dense(4)(x)

        self.model = Model(inputs = input_A, outputs = x)
        # summary() prints the table itself and returns None, so it must not be
        # wrapped in print() (that emitted a stray "None").
        self.model.summary()
        #---------------------------------------------------

        # Target network starts as an exact copy of the online network.
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())

        self.loss_fn = tf.keras.losses.mean_squared_error
        # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1      # exploration probability used by exp_policy
        self.gamma = 0.9      # discount factor used in train

    def exp_policy(self, state):
        """Epsilon-greedy action selection.

        With probability ``self.epsilon`` return a uniformly random direction
        index in ``[0, 4)``; otherwise return the argmax of the online
        network's Q-values for ``state``.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(4)

        # Add a leading batch axis so the model receives shape (1, 4).
        batched_state = np.array(state)[np.newaxis]
        Q_values = self.model(batched_state)
        return np.argmax(Q_values[0])

    def sample_experience(self):
        """Draw ``self.batch_size`` transitions (with replacement) from the buffer.

        Returns:
            Tuple ``(states, actions, next_states, rewards, dones)`` of numpy
            arrays, each with ``batch_size`` entries along the first axis.
        """
        buffer = self.replay_buffer
        indices = np.random.randint(len(buffer.state_history), size = self.batch_size)

        states = np.array([buffer.state_history[i] for i in indices])
        actions = np.array([buffer.action_history[i] for i in indices])
        next_states = np.array([buffer.next_state_history[i] for i in indices])
        rewards = np.array([buffer.rewards_history[i] for i in indices])
        dones = np.array([buffer.done_history[i] for i in indices])

        return states, actions, next_states, rewards, dones

    def play_one_step(self, env, state, mod_agent):
        """Pick a direction, let ``mod_agent`` execute the step, and record it.

        Args:
            env: environment object, forwarded unchanged to ``mod_agent``.
            state: sequence ``[curr_x, curr_y, target_x, target_y]``.
            mod_agent: agent whose ``play_one_step`` performs the environment
                step and returns ``(new_loc, reward, done)``.

        Returns:
            ``(next_state, reward, done)`` where ``next_state`` is the list
            ``[new_x, new_y, target_x, target_y]``.
        """
        action_user = self.exp_policy(state)
        curr_loc = state[:2]
        target_loc = state[2:]

        # The modulating agent's state is the one-hot direction followed by
        # the current location (make_one_hot presumably returns a list).
        action_user_one_hot = make_one_hot(action_user, 4)
        action_user_one_hot.extend(curr_loc)
        mod_state = np.array(action_user_one_hot[:])

        new_loc, reward, done = mod_agent.play_one_step(env, mod_state, curr_loc, target_loc, self)
        next_state = [new_loc[0], new_loc[1], target_loc[0], target_loc[1]]
        self.replay_buffer.append(state, action_user, reward, next_state, done)

        return next_state, reward, done

    def train(self):
        """Run one gradient step on a sampled batch of transitions.

        Targets are ``reward + (1 - done) * gamma * max_a Q_target(next_state, a)``
        and the loss is the mean squared error against the online network's
        Q-value of the action that was actually taken.
        """
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis = 1)
        target_Q_values = rewards + (1 - dones) * self.gamma * max_next_Q_values
        # Q_values below has shape (batch, 1).  Without this reshape the (batch,)
        # targets broadcast against it to (batch, batch), so every prediction
        # would be regressed against every target in the batch.
        target_Q_values = np.reshape(target_Q_values, (-1, 1))

        # One-hot mask selecting the Q-value of the action taken in each row.
        mask = tf.one_hot(actions, 4)

        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values * mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

In [3]:
class Mod_Agent:
    """Q-learning agent that chooses how strongly to modulate the user's move.

    The online network maps a 6-element state (one-hot direction of motion
    followed by ``curr_x, curr_y``) to four Q-values, one per modulation
    level 1..4.  Actions are stored in the replay buffer as 0-based indices
    (``action_mod - 1``) so they line up with the network outputs.
    """

    def __init__(self):
        """Build the online/target networks, optimizer and replay buffer."""
        #model
        #-----------------------------------------------------
        input_A = Input(shape = (6,))   #direction of motion_one_hot(4), curr_x, curr_y
        x = Dense(32, activation = 'relu')(input_A)
        x = Dense(16, activation = 'relu')(x)
        x = Dense(4)(x) #modulate by 1,2,3,4

        self.model = Model(inputs = input_A, outputs = x)
        # summary() prints the table itself and returns None, so it must not be
        # wrapped in print() (that emitted a stray "None").
        self.model.summary()
        #---------------------------------------------------

        # Target network starts as an exact copy of the online network.
        self.target_model = tf.keras.models.clone_model(self.model)
        self.target_model.set_weights(self.model.get_weights())

        self.loss_fn = tf.keras.losses.mean_squared_error
        # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3)
        self.batch_size = 128
        self.replay_buffer_size = 1024
        self.replay_buffer = Replay_Buffer(self.replay_buffer_size)
        self.epsilon = 1            # exploration probability used by exp_policy
        self.steps_per_epoch = 1    # not read anywhere in this class
        self.gamma = 0.9            # discount factor used in train

    def exp_policy(self, state):
        """Epsilon-greedy choice of a modulation level in 1..4.

        With probability ``self.epsilon`` return a uniformly random integer in
        ``[1, 5)``; otherwise return ``argmax(Q) + 1`` from the online network.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(1, 5)

        # Add a leading batch axis so the model receives shape (1, 6).
        batched_state = np.array(state)[np.newaxis]
        Q_values = self.model(batched_state)
        # Network outputs are 0-based; modulation levels are 1-based.
        return np.argmax(Q_values[0]) + 1

    def sample_experience(self):
        """Draw ``self.batch_size`` transitions (with replacement) from the buffer.

        Returns:
            Tuple ``(states, actions, next_states, rewards, dones)`` of numpy
            arrays, each with ``batch_size`` entries along the first axis.
        """
        buffer = self.replay_buffer
        indices = np.random.randint(len(buffer.state_history), size = self.batch_size)

        states = np.array([buffer.state_history[i] for i in indices])
        actions = np.array([buffer.action_history[i] for i in indices])
        next_states = np.array([buffer.next_state_history[i] for i in indices])
        rewards = np.array([buffer.rewards_history[i] for i in indices])
        dones = np.array([buffer.done_history[i] for i in indices])

        return states, actions, next_states, rewards, dones

    def play_one_step(self, env, state, curr_loc, target_loc, user_agent):
        """Choose a modulation, step the environment, and record the transition.

        Args:
            env: environment exposing ``step(action_user, action_mod, target_loc, curr_loc)``
                which returns ``(new_loc, reward, done)``.
            state: array of one-hot user direction (4 entries) followed by the
                current location.
            curr_loc: current ``(x, y)`` location.
            target_loc: target ``(x, y)`` location; passed to ``env.step`` and
                to ``user_agent`` but not part of this agent's own state.
            user_agent: agent whose ``exp_policy`` supplies the direction used
                to build this agent's next state.

        Returns:
            ``(new_loc, reward, done)`` as returned by ``env.step``.
        """
        #Agent not aware of target location
        action_mod = self.exp_policy(state)
        # Recover the user's direction index from the one-hot prefix of the state.
        action_user = np.argmax(state[:4])
        new_loc, reward, done = env.step(action_user, action_mod, target_loc, curr_loc)

        # The next state needs the user's next direction, obtained by querying
        # the user's (epsilon-greedy) policy at the new location.
        next_dir = user_agent.exp_policy(np.array([new_loc[0], new_loc[1], target_loc[0], target_loc[1]]))
        next_dir_one_hot = make_one_hot(next_dir, 4)
        next_dir_one_hot.extend(new_loc)
        next_state = np.array(next_dir_one_hot[:])

        # Store the action as a 0-based index so it matches the network outputs.
        self.replay_buffer.append(state, action_mod-1, reward, next_state, done)

        return new_loc, reward, done

    def train(self):
        """Run one gradient step on a sampled batch of transitions.

        Targets are ``reward + (1 - done) * gamma * max_a Q_target(next_state, a)``
        and the loss is the mean squared error against the online network's
        Q-value of the action that was actually taken.
        """
        states, actions, next_states, rewards, dones = self.sample_experience()
        next_Q_values = self.target_model(next_states)
        max_next_Q_values = np.max(next_Q_values, axis = 1)
        target_Q_values = rewards + (1 - dones) * self.gamma * max_next_Q_values
        # Q_values below has shape (batch, 1).  Without this reshape the (batch,)
        # targets broadcast against it to (batch, batch), so every prediction
        # would be regressed against every target in the batch.
        target_Q_values = np.reshape(target_Q_values, (-1, 1))

        # One-hot mask selecting the Q-value of the action taken in each row.
        mask = tf.one_hot(actions, 4)

        with tf.GradientTape() as tape:
            all_Q_values = self.model(states)
            Q_values = tf.reduce_sum(all_Q_values * mask, axis = 1, keepdims = True)
            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            

In [4]:
# Create the environment and the two cooperating agents.  Each agent's
# constructor builds its networks and prints the Keras model summary.
env = Environment()
user_agent = User_Agent()
mod_agent = Mod_Agent()

Icon Locations:
[[0.4 0.1]
 [0.2 0.5]
 [0.4 0.7]
 [0.5 0.6]
 [0.2 0.1]
 [0.8 0.8]]
Icon usage Probabilities
[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 32)                160       
_________________________________________________________________
dense_1 (Dense)              (None, 16)                528       
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 68        
Total params: 756
Trainable params: 756
Non-trainable params: 0
_________________________________________________________________
None
Model: "functional_3"
___________________________________________________

In [5]:
rewards = []        # mean episode reward of each reporting window (every 10th epoch)
mean_rewards = []   # per-episode rewards collected since the last report
max_steps = 40      # cap on environment steps per episode
reached = 0         # episodes since the last report that ended with done == True
for epoch in tqdm(range(10000)):
    done = False
    episode_reward = 0
    step = 0
    # NOTE(review): the sampled start/destination is immediately replaced by a
    # fixed pair, so every episode uses the same task.  The call is kept in
    # case give_start_dest() has side effects -- confirm the override is intended.
    start, dest = env.give_start_dest()
    start = np.array([0.1,0.1])
    dest = np.array([0.1,0.3])
    state = [start[0], start[1], dest[0], dest[1]]

    # Roll out one episode; the user agent drives and delegates the actual
    # environment step to the modulating agent.
    while not done and step<max_steps:
        state = np.array(state)
        next_state, reward, done = user_agent.play_one_step(env, state, mod_agent)
        state = next_state
        episode_reward+=reward
        step+=1
        if done:
            reached+=1

    # The first 51 epochs only fill the replay buffers; afterwards each epoch
    # performs one gradient step per agent.
    if epoch>50:
        user_agent.train()
        mod_agent.train()

        # Periodically sync the target networks with the online networks.
        if epoch%25==0:
            user_agent.target_model.set_weights(user_agent.model.get_weights())
            mod_agent.target_model.set_weights(mod_agent.model.get_weights())
            # tqdm.write keeps the progress bar intact, unlike a bare print().
            tqdm.write('Updated Weights')

        # Multiplicative epsilon decay for both agents.
        if epoch%50==0:
            mod_agent.epsilon*=0.9
            user_agent.epsilon*=0.9

    mean_rewards.append(episode_reward)
    if epoch%10==0:
        rewards.append(np.mean(mean_rewards))
        mean_rewards = []
        tqdm.write(f'Mean Reward = {rewards[-1]}')
        tqdm.write(str(reached))
        reached = 0
    

  0%|                                                                                        | 0/10000 [00:00<?, ?it/s]

Mean Reward = -46.39999999999999
0
Mean Reward = -40.279999999999994
1
Mean Reward = -42.87
0


  0%|▏                                                                             | 27/10000 [00:00<00:37, 263.24it/s]

Mean Reward = -33.620000000000005
0
Mean Reward = -35.910000000000004
0
Mean Reward = -43.99
0


  1%|▍                                                                              | 59/10000 [00:03<18:40,  8.87it/s]

Mean Reward = -36.260000000000005
1


  1%|▌                                                                              | 71/10000 [00:07<48:49,  3.39it/s]

Mean Reward = -38.82000000000001
0


  1%|▌                                                                              | 75/10000 [00:09<48:10,  3.43it/s]

Updated Weights


  1%|▌                                                                            | 81/10000 [00:11<1:06:45,  2.48it/s]

Mean Reward = -38.160000000000004
1


  1%|▋                                                                            | 91/10000 [00:16<1:21:46,  2.02it/s]

Mean Reward = -37.50000000000001
0


  1%|▊                                                                           | 101/10000 [00:20<1:05:43,  2.51it/s]

Updated Weights
Mean Reward = -41.550000000000004
0


  1%|▊                                                                           | 111/10000 [00:25<1:15:14,  2.19it/s]

Mean Reward = -39.88
0


  1%|▉                                                                           | 121/10000 [00:30<1:09:18,  2.38it/s]

Mean Reward = -43.7
0


  1%|▉                                                                           | 126/10000 [00:32<1:15:34,  2.18it/s]

Updated Weights


  1%|▉                                                                           | 131/10000 [00:34<1:15:40,  2.17it/s]

Mean Reward = -43.92
0


  1%|█                                                                           | 141/10000 [00:39<1:09:06,  2.38it/s]

Mean Reward = -44.64
0


  2%|█▏                                                                          | 151/10000 [00:43<1:10:02,  2.34it/s]

Updated Weights
Mean Reward = -39.029999999999994
0


  2%|█▏                                                                          | 161/10000 [00:48<1:09:31,  2.36it/s]

Mean Reward = -46.980000000000004
0


  2%|█▎                                                                          | 171/10000 [00:52<1:09:18,  2.36it/s]

Mean Reward = -47.39
0


  2%|█▎                                                                          | 176/10000 [00:54<1:13:48,  2.22it/s]

Updated Weights


  2%|█▍                                                                          | 181/10000 [00:56<1:17:09,  2.12it/s]

Mean Reward = -46.73
0


  2%|█▍                                                                          | 191/10000 [01:01<1:08:58,  2.37it/s]

Mean Reward = -39.75
0


  2%|█▌                                                                          | 201/10000 [01:05<1:08:57,  2.37it/s]

Updated Weights
Mean Reward = -38.529999999999994
1


  2%|█▌                                                                          | 211/10000 [01:09<1:10:04,  2.33it/s]

Mean Reward = -41.4
0


  2%|█▋                                                                          | 221/10000 [01:14<1:09:52,  2.33it/s]

Mean Reward = -36.779999999999994
0


  2%|█▋                                                                          | 226/10000 [01:16<1:16:40,  2.12it/s]

Updated Weights


  2%|█▊                                                                          | 231/10000 [01:18<1:09:34,  2.34it/s]

Mean Reward = -48.129999999999995
0


  2%|█▊                                                                          | 241/10000 [01:22<1:07:33,  2.41it/s]

Mean Reward = -38.96
1


  3%|█▉                                                                          | 251/10000 [01:27<1:12:51,  2.23it/s]

Updated Weights
Mean Reward = -41.18
1


  3%|█▉                                                                          | 261/10000 [01:31<1:11:22,  2.27it/s]

Mean Reward = -44.67
0


  3%|██                                                                          | 271/10000 [01:35<1:10:04,  2.31it/s]

Mean Reward = -46.88
0


  3%|██                                                                          | 276/10000 [01:38<1:22:56,  1.95it/s]

Updated Weights


  3%|██▏                                                                         | 281/10000 [01:41<1:20:35,  2.01it/s]

Mean Reward = -45.21
0


  3%|██▏                                                                         | 291/10000 [01:46<1:29:22,  1.81it/s]

Mean Reward = -47.47
0


  3%|██▎                                                                         | 301/10000 [01:51<1:18:19,  2.06it/s]

Updated Weights
Mean Reward = -46.33
0


  3%|██▎                                                                         | 311/10000 [01:55<1:10:31,  2.29it/s]

Mean Reward = -42.6
1


  3%|██▍                                                                         | 321/10000 [01:59<1:12:00,  2.24it/s]

Mean Reward = -47.36
0


  3%|██▍                                                                         | 326/10000 [02:01<1:09:50,  2.31it/s]

Updated Weights


  3%|██▌                                                                         | 330/10000 [02:03<1:16:42,  2.10it/s]

Mean Reward = -34.28
2


  3%|██▌                                                                         | 341/10000 [02:08<1:10:47,  2.27it/s]

Mean Reward = -45.3
0


  4%|██▋                                                                         | 351/10000 [02:12<1:11:16,  2.26it/s]

Updated Weights
Mean Reward = -47.05
0


  4%|██▋                                                                         | 361/10000 [02:17<1:12:24,  2.22it/s]

Mean Reward = -43.42
0


  4%|██▊                                                                         | 371/10000 [02:22<1:12:10,  2.22it/s]

Mean Reward = -46.9
0


  4%|██▊                                                                         | 376/10000 [02:24<1:11:03,  2.26it/s]

Updated Weights


  4%|██▉                                                                         | 381/10000 [02:26<1:25:31,  1.87it/s]

Mean Reward = -46.9
0


  4%|██▉                                                                         | 391/10000 [02:31<1:11:28,  2.24it/s]

Mean Reward = -46.839999999999996
0


  4%|███                                                                         | 401/10000 [02:36<1:11:19,  2.24it/s]

Updated Weights
Mean Reward = -52.720000000000006
0


  4%|███                                                                         | 411/10000 [02:41<1:12:11,  2.21it/s]

Mean Reward = -45.980000000000004
0


  4%|███▏                                                                        | 421/10000 [02:45<1:20:16,  1.99it/s]

Mean Reward = -49.75000000000001
0


  4%|███▏                                                                        | 426/10000 [02:48<1:16:55,  2.07it/s]

Updated Weights


  4%|███▎                                                                        | 431/10000 [02:50<1:14:57,  2.13it/s]

Mean Reward = -46.95
0


  4%|███▎                                                                        | 441/10000 [02:55<1:13:49,  2.16it/s]

Mean Reward = -59.1
0


  5%|███▍                                                                        | 451/10000 [02:59<1:12:59,  2.18it/s]

Updated Weights
Mean Reward = -51.749999999999986
0


  5%|███▌                                                                        | 461/10000 [03:04<1:12:53,  2.18it/s]

Mean Reward = -47.529999999999994
0


  5%|███▌                                                                        | 471/10000 [03:09<1:25:00,  1.87it/s]

Mean Reward = -51.370000000000005
0


  5%|███▌                                                                        | 474/10000 [03:11<1:04:13,  2.47it/s]


KeyboardInterrupt: 

In [None]:
# Inspect the user agent's Q-values for the state (curr_x, curr_y, target_x, target_y)
# = (0.1, 0.1, 0.1, 0.3), i.e. the fixed start/destination pair used in the training loop.
user_agent.model(np.array([[0.1, 0.1, 0.1 , 0.3]]))

In [None]:
# Inspect the modulating agent's Q-values for an input of (one-hot direction, curr_x, curr_y):
# direction index 1 ("right" in the direction mapping) at location (1, 0.3).
# NOTE(review): curr_x = 1 differs from the 0.1 used elsewhere in this notebook -- confirm it is intended.
mod_agent.model(np.array([[0, 1, 0, 0, 1, 0.3]]))