In [2]:
import random
import numpy as np
from utils.test_env import EnvTest


class LinearSchedule(object):
    """Linearly anneal ``epsilon`` from ``eps_begin`` to ``eps_end``.

    Attributes:
        epsilon: current value; starts at ``eps_begin`` and changes only
            when ``update`` is called.
        eps_begin: value at step 0.
        eps_end: value reached at (and held after) step ``nsteps``.
        nsteps: number of steps over which to interpolate; truncated to an
            int when ``update`` computes the schedule.
    """

    def __init__(self, eps_begin, eps_end, nsteps):
        self.epsilon = eps_begin
        self.eps_begin = eps_begin
        self.eps_end = eps_end
        self.nsteps = nsteps

    def update(self, t: int):
        """Set ``epsilon`` to its scheduled value at time step ``t``.

        - ``0 <= t < nsteps``: linear interpolation between ``eps_begin`` and
          ``eps_end``, the same value ``np.linspace(eps_begin, eps_end,
          nsteps + 1)[t]`` would give.
        - ``t >= nsteps`` (or ``nsteps <= 0``): held at ``eps_end``.
        - ``t < 0``: clamped to ``eps_begin`` (previously a negative index
          wrapped around to the end of the schedule).

        Args:
            t: current time step.
        """
        n = int(self.nsteps)
        # Closed form instead of materializing an nsteps+1 array on every
        # call: O(1) per update regardless of how long the schedule is.
        if t >= n:
            self.epsilon = float(self.eps_end)
        elif t <= 0:
            self.epsilon = float(self.eps_begin)
        else:
            # Same arithmetic as np.linspace (i * step + start), so in-range
            # values match the previous implementation.
            step = (self.eps_end - self.eps_begin) / n
            self.epsilon = float(t * step + self.eps_begin)


class LinearExploration(LinearSchedule):
    """Epsilon-greedy action selection driven by a linear epsilon schedule."""

    def __init__(self, env, eps_begin, eps_end, nsteps):
        """
        Args:
            env: environment; only its ``num_actions()`` method is used here.
            eps_begin: initial epsilon, forwarded to ``LinearSchedule``.
            eps_end: final epsilon, forwarded to ``LinearSchedule``.
            nsteps: length of the annealing schedule, forwarded to
                ``LinearSchedule``.
        """
        self.env = env
        super().__init__(eps_begin, eps_end, nsteps)

    def get_action(self, best_action: int) -> int:
        """Return a uniformly random action with probability ``epsilon``,
        otherwise ``best_action``.

        Args:
            best_action: the greedy action to fall back on.

        Returns:
            An action index in ``[0, env.num_actions())`` when exploring,
            else ``best_action`` unchanged.
        """
        # np.random.rand() samples [0, 1); strict '<' makes the exploration
        # probability exactly epsilon. The previous '<=' could still explore
        # when epsilon == 0.
        if np.random.rand() < self.epsilon:
            # randrange(n) draws the same values as randint(0, n - 1).
            return random.randrange(self.env.num_actions())
        return best_action


def test1():
    """With epsilon = 1, get_action always picks a random action, so at
    least one of ten samples should differ from the greedy action 0."""
    strategy = LinearExploration(EnvTest((5, 5, 1)), 1, 0, 10)
    # Draw all ten samples up front so the number of RNG calls is fixed.
    sampled = [strategy.get_action(0) for _ in range(10)]
    found_diff = any(action is not None and action != 0 for action in sampled)

    assert found_diff, "Test 1 failed."
    print("Test1: ok")


def test2():
    """Halfway through a 1 -> 0 schedule of 10 steps, epsilon must be 0.5."""
    strategy = LinearExploration(EnvTest((5, 5, 1)), 1, 0, 10)
    strategy.update(5)
    assert strategy.epsilon == 0.5, "Test 2 failed"
    print("Test2: ok")


def test3():
    """Past the end of a 1 -> 0.5 schedule, epsilon must stay at 0.5."""
    strategy = LinearExploration(EnvTest((5, 5, 1)), 1, 0.5, 10)
    strategy.update(20)
    assert strategy.epsilon == 0.5, "Test 3 failed"
    print("Test3: ok")


def your_test():
    """
    Optional hook for extra tests of your own; currently does nothing.
    """


if __name__ == "__main__":
    # Run every check in order; the first failing assertion stops the script.
    for check in (test1, test2, test3, your_test):
        check()


ModuleNotFoundError: No module named 'utils'