In [1]:
# Project modules (star imports kept first, as before, so the explicit
# imports below win any name clashes with what they re-export).
from Environment_one_hot import *
from Networks import *

import random

import numpy as np
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tqdm import tqdm

In [2]:
# Simulation area (LENGTH x WIDTH) and grid resolution, in the same units as the
# cell coordinates printed below (presumably metres — confirm in Environment).
LENGTH = 6000
WIDTH = 5000
DIVISION = 50
K = 6                  # number of candidate (strongest) cells the actor chooses among
NUM_BASE_STATIONS = 7  # 3 sectors per base station -> 3 * 7 = 21 cells in the one-hot encoding

# Seed every RNG the notebook may touch (python `random`, numpy, TensorFlow
# weight init) so that Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# NOTE(review): the positional arguments `1` and `0` are undocumented here —
# confirm their meaning against Environment's signature.
env = Environment(LENGTH, WIDTH, DIVISION, 1, 0, K)

# Added to the std when standardising returns, to avoid division by zero.
# Same value as the previous `10e-6` literal (which is 1e-5, not 1e-6).
eps = 1e-5

Location and directions of cells are: 
{1: [0, 0, 0], 2: [0, 0, 1], 3: [0, 0, 2], 4: [1500.0, 0, 0], 5: [1500.0, 0, 1], 6: [1500.0, 0, 2], 7: [750.0, 1250.0, 0], 8: [750.0, 1250.0, 1], 9: [750.0, 1250.0, 2], 10: [-750.0, 1250.0, 0], 11: [-750.0, 1250.0, 1], 12: [-750.0, 1250.0, 2], 13: [-1500.0, 0, 0], 14: [-1500.0, 0, 1], 15: [-1500.0, 0, 2], 16: [-750.0, -1250.0, 0], 17: [-750.0, -1250.0, 1], 18: [-750.0, -1250.0, 2], 19: [750.0, -1250.0, 0], 20: [750.0, -1250.0, 1], 21: [750.0, -1250.0, 2]} 
 

Strongest cells for sector (0, 0)
[[5.0, -69.68031402659005], [13.0, -71.07287859388722], [10.0, -77.36490164108055], [16.0, -77.39288922194086], [9.0, -77.7794686317523], [20.0, -79.71838870710408]]


Strongest cells for sector (0, 100)
[[13.0, -68.54875066828438], [5.0, -72.23390665442798], [2.0, -76.24066816535561], [17.0, -81.04742797484525], [20.0, -81.67563656386488], [10.0, -83.59618041497248]]


Strongest cells for sector (50, 0)
[[5.0, -64.51934435908518], [13.0, -75.59831577353562],

In [3]:
class Agent:
    """Monte-Carlo actor-critic agent that picks one of ``K`` candidate cells per step.

    Each training episode follows one route from the module-level ``env``.
    The actor is trained with ``-log pi(a|s) * (G_t - V(s))``. The critic
    regresses ``V(s)`` toward the standardized discounted return ``G_t`` using
    a Huber loss.

    Relies on module-level names defined in earlier cells: ``env``, ``K``,
    ``NUM_BASE_STATIONS``, ``LENGTH``, ``WIDTH``, ``eps``, ``make_one_hot``,
    ``ACTOR_NET``, ``CRITIC_NET`` and ``tqdm``.
    """

    def __init__(self, gamma=0.3, actor_lr=0.0001, critic_lr=0.0002):
        """Build the actor/critic networks and their Adam optimizers.

        Args:
            gamma: discount factor used when computing returns.
            actor_lr: Adam learning rate for the actor network.
            critic_lr: Adam learning rate for the critic network.
        """
        self.loss_fn = tf.keras.losses.huber
        # `learning_rate` is the supported keyword; `lr` is deprecated and is
        # rejected outright by newer Keras releases.
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
        self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
        self.actor = ACTOR_NET(K, NUM_BASE_STATIONS)
        self.critic = CRITIC_NET(K, NUM_BASE_STATIONS)
        self.gamma = gamma

    @staticmethod
    def _initial_state(route, src):
        """Pop the first route point and expand it into the network's state vector.

        The last element of the popped entry is the direction index; it is
        replaced by ``one_hot(direction + 1, 8)``. Then ``one_hot(strongest
        cell at src, 3 * NUM_BASE_STATIONS)`` is appended, so the strongest
        cell acts as the initial serving cell. The remaining leading elements
        are presumably ``[x, y]``: 2 + 8 + 21 = 31 matches the model's input
        width.
        """
        state = route.popleft()
        depth = 3 * NUM_BASE_STATIONS
        one_hot_cell = make_one_hot(env.sector_cells[src][0][0], depth)
        one_hot_direction = make_one_hot(state[-1] + 1, 8)
        state = state[:-1]
        state.extend(one_hot_direction)
        state.extend(one_hot_cell)
        return state

    @staticmethod
    def _normalize_state(state, half_length, half_width):
        """Return ``state`` as a ``(1, len(state))`` numpy array with x/y scaled.

        The first two entries are divided by ``half_length`` and ``half_width``.
        All other entries are copied unchanged, and ``state`` itself is not
        modified.
        """
        norm_state = list(state)
        norm_state[0] = norm_state[0] / half_length
        norm_state[1] = norm_state[1] / half_width
        return np.array(norm_state)[np.newaxis]

    @staticmethod
    def _normalized_returns(rewards, gamma, epsilon):
        """Discounted returns ``G_t = r_t + gamma * G_{t+1}``, standardized.

        The returns are shifted to zero mean and divided by ``std + epsilon``.
        The epsilon prevents division by zero when all returns are equal, for
        example in a single-step episode. Returns a Python list.
        """
        returns = []
        discounted_sum = 0
        for r in reversed(rewards):
            discounted_sum = r + gamma * discounted_sum
            returns.append(discounted_sum)
        returns.reverse()
        returns = np.array(returns)
        returns = (returns - np.mean(returns)) / (np.std(returns) + epsilon)
        return returns.tolist()

    def learn(self, WRSRP, WHO, num_epochs=100000, log_every=1000):
        """Train actor and critic for ``num_epochs`` single-route episodes.

        Args:
            WRSRP: weight multiplied onto the per-step reward returned by ``env.step``.
            WHO: penalty, multiplied by the ``change`` flag from ``env.step``,
                that is subtracted from each step's reward. Each step with a
                truthy ``change`` is counted as a handover.
            num_epochs: number of training episodes.
            log_every: at every epoch > 0 that is a multiple of ``log_every``,
                the running (EMA) reward and handover count are printed. If the
                running reward is at least the best seen at previous logging
                points, the actor model is saved to ``model_{WRSRP}_{WHO}.h5``.
        """
        running_reward = 0
        running_handovers = 0
        max_reward = float('-inf')
        for epoch in tqdm(range(num_epochs)):
            action_probs_history = []
            critic_value_history = []
            rewards_history = []
            episode_reward = 0
            handovers = 0

            src, dest = env.give_src_dest()
            route = env.compute_route(src, dest)
            state = self._initial_state(route, src)
            done = False

            # Persistent so that two separate gradients (actor, critic) can be
            # taken from the same recorded episode.
            with tf.GradientTape(persistent=True) as tape:
                while not done:
                    norm_state = self._normalize_state(state, LENGTH // 2, WIDTH // 2)
                    action_probs = self.actor.model(norm_state)
                    critic_value = self.critic.model(norm_state)
                    critic_value_history.append(critic_value[0, 0])

                    action = np.random.choice(K, p=np.squeeze(action_probs))
                    action_probs_history.append(tf.math.log(action_probs[0, action]))

                    state, reward, done, change = env.step(state, route, action, dest)
                    if change:
                        handovers += 1
                    reward *= WRSRP
                    reward -= change * WHO
                    rewards_history.append(reward)
                    episode_reward += reward

                returns = self._normalized_returns(rewards_history, self.gamma, eps)

                actor_losses = []
                critic_losses = []
                for log_prob, value, ret in zip(action_probs_history, critic_value_history, returns):
                    # Advantage: how much better the return was than the critic predicted.
                    diff = ret - value
                    actor_losses.append(-log_prob * diff)
                    critic_losses.append(self.loss_fn(tf.expand_dims(value, 0), tf.expand_dims(ret, 0)))
                actor_loss_value = sum(actor_losses)
                critic_loss_value = sum(critic_losses)

            actor_vars = self.actor.model.trainable_variables
            grads = tape.gradient(actor_loss_value, actor_vars)
            self.actor_optimizer.apply_gradients(zip(grads, actor_vars))

            critic_vars = self.critic.model.trainable_variables
            grads = tape.gradient(critic_loss_value, critic_vars)
            self.critic_optimizer.apply_gradients(zip(grads, critic_vars))

            # A persistent tape keeps its recorded resources alive until released.
            del tape

            # Exponential moving averages (smoothing 0.05), seeded by the first episode.
            if epoch:
                running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
                running_handovers = 0.05 * handovers + (1 - 0.05) * running_handovers
            else:
                running_reward = episode_reward
                running_handovers = handovers

            if epoch and epoch % log_every == 0:
                print(running_reward)
                print(running_handovers)
                if running_reward >= max_reward:
                    max_reward = running_reward
                    tf.keras.models.save_model(self.actor.model, f'model_{WRSRP}_{WHO}.h5')

In [4]:
# Build the actor and critic networks and their Adam optimizers. The layer
# summary shown below is printed during construction (presumably by the
# ACTOR_NET / CRITIC_NET helpers from Networks).
agent = Agent()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 31)]              0         
_________________________________________________________________
dense (Dense)                (None, 32)                1024      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                2112      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 32)                2080      
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32)               

In [None]:
# Train with weight 1 on the environment reward (WRSRP) and a penalty of 1
# per handover (WHO). NOTE: the default 100000 episodes at ~1-3 it/s take
# many hours (see the progress bar below).
agent.learn(WRSRP=1, WHO=1)

  1%|▋                                                                        | 1001/100000 [10:51<20:30:27,  1.34it/s]

2.7678878864705796
28.62145893554227


  2%|█▍                                                                       | 2001/100000 [23:03<21:40:48,  1.26it/s]

2.6764856894616416
23.720883905026152


  3%|██▏                                                                      | 3001/100000 [34:24<14:48:52,  1.82it/s]

3.5710719033475815
22.236378325404605


  4%|██▉                                                                      | 4001/100000 [46:18<18:01:13,  1.48it/s]

4.0672640177068144
23.352760161644404


  5%|███▋                                                                     | 5001/100000 [57:38<11:51:57,  2.22it/s]

2.9000304380007815
21.704124706369143


  6%|████▎                                                                  | 6001/100000 [1:07:55<22:13:40,  1.17it/s]

2.8804016126869216
26.404589067938286


  7%|████▉                                                                  | 7001/100000 [1:18:31<12:18:50,  2.10it/s]

3.469223347787159
22.064280759423784


  8%|█████▋                                                                 | 8001/100000 [1:29:12<19:46:50,  1.29it/s]

5.26184451178248
24.18628701909653


  9%|██████▍                                                                | 9001/100000 [1:39:47<16:24:48,  1.54it/s]

2.1415082035533346
25.88657889264978


 10%|███████                                                               | 10001/100000 [1:50:13<13:28:20,  1.86it/s]

2.324095608937157
22.191164355431532


 11%|███████▋                                                              | 11001/100000 [2:00:35<10:11:20,  2.43it/s]

3.0508484001035923
19.403856178637167


 12%|████████▌                                                              | 12002/100000 [2:10:57<7:01:46,  3.48it/s]

2.6647576624137446
22.730359797543528


 13%|█████████                                                             | 13001/100000 [2:21:25<15:04:05,  1.60it/s]

4.1393581005942695
25.916237703459764


 14%|█████████▊                                                            | 14001/100000 [2:32:09<18:29:34,  1.29it/s]

4.71405460108324
24.07767479934738


 15%|██████████▌                                                           | 15001/100000 [2:42:59<15:37:55,  1.51it/s]

3.370730579081626
25.278404282478423


 16%|███████████▏                                                          | 16001/100000 [2:53:23<13:18:02,  1.75it/s]

4.741149987903912
23.458414603790605


 17%|███████████▉                                                          | 17001/100000 [3:03:57<14:06:19,  1.63it/s]

2.743635951339218
25.402632233373378


 18%|████████████▌                                                         | 18001/100000 [3:14:27<12:15:29,  1.86it/s]

4.194158013240865
25.887032340812787


 19%|█████████████▍                                                         | 19002/100000 [3:24:56<9:55:40,  2.27it/s]

3.904613964940162
24.025142678401952


 20%|██████████████                                                        | 20001/100000 [3:35:28<11:46:39,  1.89it/s]

2.3775945698320538
25.13160817673673


 21%|██████████████▉                                                        | 21001/100000 [3:46:27<9:55:50,  2.21it/s]

4.508965943191136
23.469059815958268


 22%|███████████████▍                                                      | 22001/100000 [3:57:47<15:26:34,  1.40it/s]

2.7913998379277416
25.29174205363754


 23%|████████████████                                                      | 23001/100000 [4:09:12<16:11:17,  1.32it/s]

3.565091407294359
25.89772030877489


 24%|█████████████████                                                      | 24002/100000 [4:20:16<8:13:49,  2.56it/s]

3.9389997774644256
20.28984081031139


 25%|█████████████████▌                                                    | 25001/100000 [4:31:15<15:45:21,  1.32it/s]

2.681709469839523
25.653232738187274


 26%|██████████████████▏                                                   | 26001/100000 [4:41:59<12:33:17,  1.64it/s]

4.410775733358083
21.820647815397784


 27%|██████████████████▉                                                   | 27000/100000 [4:53:02<15:40:25,  1.29it/s]

5.8391814002757565
25.435149460417353


 28%|███████████████████▌                                                  | 28001/100000 [5:03:40<14:51:45,  1.35it/s]

3.9109556107213552
23.31773803885557


 29%|████████████████████▎                                                 | 29001/100000 [5:14:41<11:47:58,  1.67it/s]

2.3888759795066172
21.421753715409068


 30%|█████████████████████                                                 | 30001/100000 [5:25:33<12:38:32,  1.54it/s]

5.459143749691947
22.368266206743538


 31%|█████████████████████▋                                                | 31001/100000 [5:36:43<17:56:17,  1.07it/s]

4.196447584557267
27.89557215508347


 32%|██████████████████████▍                                               | 32000/100000 [5:48:01<18:07:38,  1.04it/s]

3.359363404938451
26.547514960790426


 33%|███████████████████████                                               | 33001/100000 [5:58:50<10:55:04,  1.70it/s]

2.8435949034031287
22.673671791850253


 34%|███████████████████████▊                                              | 34001/100000 [6:09:57<11:31:39,  1.59it/s]

2.7877782493078542
25.516372344409998


 35%|████████████████████████▌                                             | 35001/100000 [6:20:47<12:24:02,  1.46it/s]

2.186353660195376
25.478619677366765


 36%|█████████████████████████▏                                            | 36001/100000 [6:31:27<13:29:40,  1.32it/s]

3.9255786658636516
26.13923819047111


 37%|█████████████████████████▉                                            | 37001/100000 [6:42:27<12:25:18,  1.41it/s]

4.611538186109028
27.497380416090053


 38%|██████████████████████████▉                                            | 38001/100000 [6:53:30<9:54:18,  1.74it/s]

2.6776906481702025
24.96668536669774


 39%|███████████████████████████▏                                          | 38829/100000 [7:02:42<11:05:56,  1.53it/s]
