In [None]:
import torch 
import numpy as np

import flappy_bird_gymnasium
import gymnasium

from deepq_agent import DQNAgent_pytorch

from gymnasium.wrappers import FlattenObservation

In [None]:
# --- Configuration -----------------------------------------------------------
# Target-network sync period; not used in this (evaluation-only) notebook.
# NOTE(review): presumably kept for parity with the training notebook — confirm.
TARGET_UPDATE = 10
# Fall back to the CPU so the notebook also runs on machines without a GPU.
# Kept as a string because it is passed both to torch and to DQNAgent_pytorch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LR = 1e-4                # optimizer learning rate (learn_rate)
GAMMA = 0.99             # discount factor
EPS = 0                  # initial exploration rate; 0 here since we only watch the trained agent
EPS_DECAY = 0.999        # multiplicative exploration decay (eps_decay_rate)
EPS_END = 0              # lower bound on the exploration rate (eps_floor)
BATCH_SIZE = 128         # training batch size
PLAY_MEMORY = 10000      # replay-buffer capacity (pmem_buffer_size)
LAYERS_SIZES = [32, 32]  # hidden-layer widths of the Q-network (network_shape)


# Number of episodes played by the rendering loop (each loop iteration is one episode).
EPOCHS = 1000000

In [None]:
# Un-rendered probe environment: the next cell reads its spaces to size the network.
env = gymnasium.make(id="FlappyBird-v0")
reset_output = env.reset()  # gymnasium's reset returns (observation, info)
state, _ = reset_output

In [None]:
obs_space = gymnasium.spaces.flatdim(env.observation_space) - 1
act_space = env.action_space.n

obs_space, act_space

(11, 2)

In [None]:
# Construct the DQN agent from the configuration constants above.
# The next cell calls agent.load(...) to restore previously trained weights.
agent_config = dict(
    device=DEVICE,
    act_space=act_space,
    obs_space=obs_space,
    training_batch_size=BATCH_SIZE,
    learn_rate=LR,
    gamma=GAMMA,
    eps=EPS,                      # exploration rate
    eps_decay_rate=EPS_DECAY,
    eps_floor=EPS_END,
    network_shape=LAYERS_SIZES,
    pmem_buffer_size=PLAY_MEMORY,
)
agent = DQNAgent_pytorch(**agent_config)

In [None]:
# Restore trained weights; the path is relative to the notebook's working directory.
CHECKPOINT_PATH = "model.pt"
agent.load(CHECKPOINT_PATH)

In [None]:
# Watch the trained agent play with on-screen rendering. No learning happens here.
env = gymnasium.make("FlappyBird-v0", render_mode='human')
env = FlattenObservation(env)

# A flap requested sooner than this many steps after the previous flap
# is replaced by "do nothing" (action 0).
MIN_STEPS_BETWEEN_JUMPS = 5


def obs_to_tensor(obs):
    """Convert a flattened observation into the agent's input tensor.

    Drops the last feature (the network is sized for obs_space = len(obs) - 1).
    Wraps the rest as a float32 tensor of shape (1, len(obs) - 1) on DEVICE.
    """
    return torch.tensor(obs[:-1], dtype=torch.float32, device=DEVICE).unsqueeze(0)


try:
    for i in range(EPOCHS):
        total_reward = 0
        # The throttle is per-episode: after a reset the bird respawns, so the
        # previous episode's last flap must not influence this one.
        steps_since_jump = 0
        obs, _ = env.reset()
        state = obs_to_tensor(obs)
        done = False

        while not done:
            action = agent.get_action(state)
            real_action = action.item()
            if real_action == 1:
                if steps_since_jump < MIN_STEPS_BETWEEN_JUMPS:
                    real_action = 0  # too soon after the last flap: suppress it
                    steps_since_jump += 1
                else:
                    steps_since_jump = 0
            else:
                steps_since_jump += 1

            new_obs, reward, terminated, truncated, _info = env.step(real_action)
            # End the episode on a crash (terminated) or a time limit (truncated).
            # The outer loop performs the single reset for the next episode.
            done = terminated or truncated
            total_reward += reward
            state = obs_to_tensor(new_obs)
        print("Iteration: ", i, "Total reward: ", total_reward)
finally:
    # Release the render window even when the (very long) loop is interrupted.
    env.close()





Iteration:  0 Total reward:  8.99999999999998
Iteration:  1 Total reward:  8.99999999999998
Iteration:  2 Total reward:  8.99999999999998
Iteration:  3 Total reward:  8.99999999999998
Iteration:  4 Total reward:  8.99999999999998
Iteration:  5 Total reward:  8.99999999999998
Iteration:  6 Total reward:  2.1000000000000014
Iteration:  7 Total reward:  2.1000000000000014
Iteration:  8 Total reward:  2.1000000000000014
Iteration:  9 Total reward:  2.1000000000000014
Iteration:  10 Total reward:  2.1000000000000014
Iteration:  11 Total reward:  2.3000000000000016
Iteration:  12 Total reward:  2.3000000000000016
Iteration:  13 Total reward:  2.3000000000000016
Iteration:  14 Total reward:  2.3000000000000016
Iteration:  15 Total reward:  8.99999999999998
Iteration:  16 Total reward:  8.99999999999998
Iteration:  17 Total reward:  8.99999999999998
Iteration:  18 Total reward:  8.99999999999998
Iteration:  19 Total reward:  8.99999999999998
Iteration:  20 Total reward:  8.99999999999998
Itera