In [10]:
%matplotlib widget

import math
import random
import time
import numpy as np
import matplotlib.pyplot as plt

# Import Tensorflow 2.0
import tensorflow as tf

from time import sleep

from ipycanvas import Canvas, RoughCanvas, hold_canvas

from scipy.stats import truncnorm

gpus = tf.config.experimental.list_physical_devices('GPU')

if gpus:
    for gpu in gpus:
        print(gpu)
        tf.config.experimental.set_memory_growth(gpu, True)
        
#!jupyter labextension list

PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [19]:
[(random.randrange(50,200)), (random.randrange(500,700))]

[54, 557]

In [25]:
class Env():
    
    def __init__(self, num_steps):
        self.num_steps = num_steps
        self.num_trees = 2
        self.reset()

    def reset(self):
        
        self.trees = [(random.randrange(100,300)), (random.randrange(400,600))]
        self.fruits = [(90.,20),(105.,10)] #array of position, remaining lifetime
        self.agent = random.randrange(0., 800.)
        self.countfruits = 0
        self.score = 0.
        self.timestep = 0
        return self.obs()

    def obs(self):
        o = [(xt - self.agent) for (xt) in self.trees]
        return o
    
        
    def step(self, action):

        self.timestep += 1
        
        #remove old fruits
        self.fruits = [(x,t-1) for (x,t) in self.fruits if t>0]

        #add new ones
        for (xt) in self.trees:
            if random.random()>.9:
                x = xt + random.randrange(-50,50,1)
                t = random.randrange(50,100)
                self.fruits.append((x,t))

        #move the agent
        direction = -1 if action==0 else +1
        distance = direction * random.randint(1,5)
        self.agent += distance  #change the position
        
        #calculate the reward
        
        self.score -= abs(distance)  #consume energy, negative reward, for moving
        
        found = [i for i in range(len(self.fruits)) if self.fruits[i][0]==self.agent]
        self.score += 100 * len(found) #it is possible that several fruits have the same location
        self.countfruits += len(found)
        self.fruits = [f for i,f in enumerate(self.fruits) if i not in found]
        
        o = self.obs()
        reward = self.score
        terminated = True if self.timestep>=self.num_steps else False
        if self.agent<0 or self.agent>1000:
            terminated = True
        
        return o, reward, terminated   

    
    def init_canvas(self):
        #canvas = Canvas(width=1000, height=200)
        self.canvas = RoughCanvas(width=1000, height=200)
        self.canvas.font = "10px serif"
        display(self.canvas)        
        
        
    def update_canvas(self, sleeptime=0.02):
        #draw the scene
        with hold_canvas():
            # Clear the old animation step
            self.canvas.clear()

            y = 100
            size = 5

            self.canvas.stroke_text("time:%d"%self.timestep, 10, 10)
            #self.canvas.stroke_text("#fruits:%d"%len(self.fruits), 10, 30)
            self.canvas.stroke_text("#score:%d"%self.score, 10, 30)
            self.canvas.stroke_text("#found:%d"%self.countfruits, 10, 50)

            self.canvas.stroke_style = "blue"
            for (x) in self.trees:
                self.canvas.stroke_rect(x, y, size, size)

            self.canvas.fill_style = "red"
            xs = [x for (x,t) in self.fruits]
            ys = [100]*len(xs)
            self.canvas.fill_circles(xs, ys, size) #use vectorized version

            self.canvas.stroke_style = "green"
            self.canvas.stroke_rect(self.agent, y, size, size)


        # Animation frequency ~50Hz = 1./50. seconds
        if sleeptime>0:
            sleep(sleeptime)


    def play(self, model):
        self.init_canvas()
        terminated=False
        obs = self.reset()
        while not terminated:
            action = choose_action(model, obs)
            obs, reward, terminated = self.step(action)
            self.update_canvas(sleeptime=0)        
            
e = Env(500)

e.init_canvas()
terminated=False
obs = e.reset()
while not terminated:
    
    #policy:
    xt = obs[0]
    if abs(xt)>50:
        action = 1 if xt>0 else 0
    else:
        action = random.randint(0,1)
    
    #action = random.randint(0,1)

    obs, reward, terminated = e.step(action)
    e.update_canvas()
     

RoughCanvas(height=200, width=1000)

In [26]:
def create_rl_model(n_actions):
    
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units=32, activation='relu'),
        tf.keras.layers.Dense(units=n_actions, activation=None)  #returns logits (un-normalized log-proba for each action)
    ])
    return model

#a logbook to remember observations, actions and rewards for an entire episode
class Memory:
    def __init__(self):
        self.clear()
    
    def clear(self):
        self.observations=[]
        self.actions=[]
        self.rewards=[]
        
    def add_to_memory(self, new_obs, new_action, new_reward):
        self.observations.append(new_obs)
        self.actions.append(new_action)
        self.rewards.append(new_reward)
        
    def __len__(self):
        return len(self.actions)

def choose_action(model, observation, single=True):

    # add batch dimension to the observation if only a single example was provided
    observation = np.expand_dims(observation, axis=0) if single else observation
    logits = model.predict(observation, verbose=0)
    action = tf.random.categorical(logits, num_samples=1)  #randomly pick an action - tf's categorical takes unornmalized log proba as input
    action = action.numpy().flatten()
    return action[0] if single else action

def normalize(x):
    x -= np.mean(x)
    x /= np.std(x)
    return x.astype(np.float32)

def discount_rewards(rewards, gamma=0.95):
    discounted_rewards = np.zeros_like(rewards)
    R=0
    for t in reversed(range(0, len(rewards))):
        R = R*gamma + rewards[t]
        discounted_rewards[t]=R
    return normalize(discounted_rewards)

def compute_loss(logits, actions, rewards):
    neg_logprob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=actions)
    loss = tf.reduce_mean(neg_logprob*rewards)
    return loss

def train_step(model, loss_function, optimizer, observations, actions, discounted_rewards):

    with tf.GradientTape() as tape:
        prediction = model(observations)
        loss = loss_function(prediction, actions, discounted_rewards)
    
    grads = tape.gradient(loss, model.trainable_variables)
    
    grads, _ = tf.clip_by_global_norm(grads, 2.0)
    
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

In [27]:

env = Env(500) #each episode lasts 500 time steps
model = create_rl_model(n_actions=2)

env.play(model)


RoughCanvas(height=200, width=1000)

In [28]:
memory = Memory()

learning_rate = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate)

start = time.time()

for i_episode in range(500):
    obs = env.reset()
    memory.clear()
    terminated = False
    
    #run the episode, keeping the model constant
    while not terminated:
        action = choose_action(model, obs)
        next_obs, reward, terminated = env.step(action)
        memory.add_to_memory(obs, action, reward)
        obs = next_obs
        
    train_step(model, compute_loss, optimizer, 
               observations=np.vstack(memory.observations),
               actions=np.array(memory.actions),
               discounted_rewards = discount_rewards(memory.rewards))
    
    end = time.time()
    print(i_episode, "time remaining:", (100-i_episode+1)*(end-start)/(i_episode+1))


0 time remaining: 2386.06121134758
1 time remaining: 2338.997185230255
2 time remaining: 2311.944427728653
3 time remaining: 2285.207857966423
4 time remaining: 2265.1178436279297
5 time remaining: 2240.4429054260254
6 time remaining: 2212.91021500315
7 time remaining: 2188.6616246700287
8 time remaining: 2163.8304739793143
9 time remaining: 2142.252578449249
10 time remaining: 2119.2137414325366
11 time remaining: 2094.7735887765884
12 time remaining: 2071.2750774163464
13 time remaining: 2047.0651490347725
14 time remaining: 2026.1409093856812
15 time remaining: 2003.024432182312
16 time remaining: 1979.08771276474
17 time remaining: 1955.931899547577
18 time remaining: 1932.1411768135272
19 time remaining: 1910.6754984617232
20 time remaining: 1886.751897607531
21 time remaining: 1864.0048937364058
22 time remaining: 1841.8855294455652
23 time remaining: 1819.1103928685188
24 time remaining: 1800.37176197052
25 time remaining: 1777.8491342984712
26 time remaining: 1756.376940674252


214 time remaining: -2715.092967311726
215 time remaining: -2739.4723413520387
216 time remaining: -2763.926355119125
217 time remaining: -2788.353953466503
218 time remaining: -2812.8838807915986
219 time remaining: -2837.904654669762
220 time remaining: -2862.5375966108763
221 time remaining: -2887.21699907973
222 time remaining: -2911.6722617320415
223 time remaining: -2936.512944259814
224 time remaining: -2961.214999748866
225 time remaining: -2986.02282229989
226 time remaining: -3012.571178212565
227 time remaining: -3038.0485525131226
228 time remaining: -3064.4187342391783
229 time remaining: -3090.0831113068953
230 time remaining: -3115.528146508452
231 time remaining: -3141.4117237514465
232 time remaining: -3167.0571638678275
233 time remaining: -3193.8203303019204
234 time remaining: -3220.182088727139
235 time remaining: -3246.732448680926
236 time remaining: -3272.759222954134
237 time remaining: -3298.845045498439
238 time remaining: -3324.725176512946
239 time remainin

424 time remaining: -8304.90921954155
425 time remaining: -8332.718985376223
426 time remaining: -8361.22258851344
427 time remaining: -8388.275043352742
428 time remaining: -8415.851487044687
429 time remaining: -8443.259472208245
430 time remaining: -8470.635639399221
431 time remaining: -8498.354035814604
432 time remaining: -8526.432738894953
433 time remaining: -8555.851794306584
434 time remaining: -8586.247493485747
435 time remaining: -8615.304719343098
436 time remaining: -8644.434294056839
437 time remaining: -8671.485271218706
438 time remaining: -8698.574691367312
439 time remaining: -8725.73242525946
440 time remaining: -8753.519071854702
441 time remaining: -8781.055355805616
442 time remaining: -8809.652688213719
443 time remaining: -8838.760991276922
444 time remaining: -8868.757143055454
445 time remaining: -8898.613271580683
446 time remaining: -8928.136695975425
447 time remaining: -8956.058671708617
448 time remaining: -8983.667656706277
449 time remaining: -9011.39

In [33]:
model_file = 'rl_'+time.strftime("%Y%m%d-%H%M%S")
print(model_file)
model.save(model_file)

rl_20221218-161557
INFO:tensorflow:Assets written to: rl_20221218-161557\assets


In [32]:
env = Env(500)
env.play(model)

RoughCanvas(height=200, width=1000)