In [1]:
import numpy as np
import gym
# !pip install -e gym-env

In [2]:
# Seed numpy's global RNG; whether the environment actually draws from it is
# not visible in this notebook.
SEED = 5
np.random.seed(SEED)
# The environment under test, provided by the local gym-env package (see the pip line above).
env = gym.make("gym_env:firp-v0")

In [3]:
# Sanity-check the space sizes the tests below rely on: one action per bush (0-7)
# and 9 observation values (presumably the 8 bushes plus start position 8 -- TODO confirm).
action_count = env.action_space.n
assert action_count == 8
observation_count = env.observation_space.n
assert observation_count == 9

In [4]:
# Bushes 0-3 start with 70 reward units and replenishment rate 2; bushes 4-7
# start empty with rate 0. The test cases below hard-code these exact values.
initial_bush_rewards = [70] * 4 + [0] * 4
repl_rate = [2] * 4 + [0] * 4

env.set_env(initial_bush_rewards, repl_rate)

In [5]:
# TEST CASE - 1 : Checking initial setup of environment is correct

def test_case_1(env):

    env.reset()

    assert env.get_curr_state()[0] == 8
    assert env.get_curr_state()[1] == 0
    assert np.all(env.get_curr_state()[2] == np.array(initial_bush_rewards, dtype=np.float32))
    assert env.get_curr_state()[3] == 0
    
    print('TEST CASE - 1 : PASSED')

In [6]:
#TEST CASE - 2 : Checking time taken for each movement

def test_case_2(env):
    env.reset()
    prev_time = 300
    for i in range(0,8):
        for j in range(0,8):
            if(i==j):
                continue
            env.reset()
            prev_time = 300
            
            s_new, r, t, _ = env.step(i)
            assert s_new[0]==i
            assert r == 0
            assert s_new[1] == prev_time-(1/(2*np.sin(np.pi/8)))
            prev_time = s_new[1]
            
            s_new, r, t, _ = env.step(j)
            assert s_new[0] == j
            assert r == 0 
            if(abs(i-j)==1 or 8-abs(i-j)==1):
                assert s_new[1] == prev_time-1
            elif(abs(i-j)==2 or 8-abs(i-j)==2):
                assert s_new[1] == prev_time-((2*np.sin(np.pi/4))/(2*np.sin(np.pi/8)))
            elif(abs(i-j)==3 or 8-abs(i-j)==3):
                assert s_new[1] == prev_time-((2*np.sin(3*np.pi/8))/(2*np.sin(np.pi/8)))
            elif(abs(i-j)==4 or 8-abs(i-j)==4):
                assert s_new[1] == prev_time-((2*np.sin(np.pi/2))/(2*np.sin(np.pi/8)))
            
    print('TEST CASE - 2 : PASSED')

In [7]:
#TEST CASE - 3 : Checking harvesting reward and replenishment on rewarding bush


def test_case_3(env):
    env.reset()
    
    _ = env.step(0)
    s_new,r,t,_ = env.step(0)
    assert r == 63
    assert s_new[0] == 0
    assert s_new[1] == 300-1-(1/(2*np.sin(np.pi/8)))
    assert np.all(env.get_curr_state()[2] == np.array([63,72,72,72,0,0,0,0], dtype=np.float32))
    
    s_new,r,t,_ = env.step(0)
    assert r == int(63*0.9)
    assert s_new[0] == 0
    assert s_new[1] == 300-2-(1/(2*np.sin(np.pi/8)))
    assert np.all(env.get_curr_state()[2] == np.array([56,74,74,74,0,0,0,0], dtype=np.float32))

    print('TEST CASE - 3 : PASSED')

In [8]:
#TEST CASE - 4 : Checking harvesting reward and replenishment on non-rewarding bush


def test_case_4(env):
    env.reset()
    
    _ = env.step(4)
    s_new,r,t,_ = env.step(4)
    assert r == 0
    assert s_new[0] == 4
    assert s_new[1] == 300-1-(1/(2*np.sin(np.pi/8)))
    assert np.all(env.get_curr_state()[2] == np.array([70,70,70,70,0,0,0,0], dtype=np.float32))
    
    print('TEST CASE - 4 : PASSED')

In [9]:
# TEST CASE - 5 : Checking termination and replenishment limits

def test_case_5(env):
    env.reset()
    
    s_new,r,t,_ = env.step(0)
    
    for i in range(299):
        s_new,r,t,_ = env.step(0)
        
        if(i >= 64):
            assert np.all(env.get_curr_state()[2][1:] == np.array([200,200,200,0,0,0,0], dtype=np.float32))
        
        if(i==298):
            assert t == True
            assert s_new[1]==0
        else:
            assert t==False
            assert s_new[1]>0
        assert s_new[0] == 0
    
    print('TEST CASE - 5 : PASSED')

In [10]:
# Run every test case in order against the configured environment; each one
# prints its own PASSED line and raises AssertionError on the first mismatch.
for run_case in (test_case_1, test_case_2, test_case_3, test_case_4, test_case_5):
    run_case(env)

TEST CASE - 1 : PASSED
TEST CASE - 2 : PASSED
TEST CASE - 3 : PASSED
TEST CASE - 4 : PASSED
TEST CASE - 5 : PASSED
