In [None]:
# !pip install gym
# !apt install freeglut3-dev -y
# !pip3 install tensorflow
# !apt-get install -y ffmpeg
# !apt-get install -y python3-opengl
# !pip3 install box2d-py

In [None]:
import warnings
warnings.filterwarnings('ignore')

import random
import os
import gym
import numpy as np
from collections import deque

from DQN import DQN
import torch
import torch.nn.functional as F

In [None]:
# Gym environment id passed to gym.make() by the training loop.
# NOTE(review): DQNSolver.act emits integer (discrete) actions, so this environment
# must have a discrete action space (e.g. "CartPole-v1") — confirm before training.
ENV_NAME = "BipedalWalker-v3"

# Learning hyper-parameters.
GAMMA = 0.95           # discount applied to the bootstrapped next-state value
LEARNING_RATE = 0.001  # Adam step size for the policy network

# Experience replay.
MEMORY_SIZE = 1_000_000  # replay-buffer capacity; the oldest transitions are evicted
BATCH_SIZE = 150         # minibatch size, also the minimum buffer fill before training

# Epsilon-greedy exploration: start at MAX, multiply by DECAY after every
# optimisation step, and never go below MIN.
EXPLORATION_MAX, EXPLORATION_MIN, EXPLORATION_DECAY = 1.0, 0.01, 0.995

In [None]:
# env = gym.make(ENV_NAME)
# from gym import envs
# for i in envs.registry.all():
#     print(i)

In [None]:
class DQNSolver:
    """Epsilon-greedy DQN agent with an experience-replay buffer and a target network.

    Hyper-parameters come from the module-level constants GAMMA, LEARNING_RATE,
    MEMORY_SIZE, BATCH_SIZE and EXPLORATION_MAX / _MIN / _DECAY.

    Args:
        observation_space: forwarded to ``DQN`` (the training loop passes
            ``env.observation_space.shape[0]``).
        action_space: number of discrete actions; ``act`` returns an int in
            ``range(action_space)``.
        base: forwarded unchanged to ``DQN``.
        target_update: number of ``optimize_model`` updates between copies of the
            policy weights into the target network. ``1`` reproduces the old
            behaviour of re-syncing after every update.

    Raises:
        ValueError: if ``target_update`` is smaller than 1.
    """

    def __init__(self, observation_space, action_space, base=True, target_update=100):
        if target_update < 1:
            raise ValueError("target_update must be a positive integer")
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        self.policy_net = DQN(observation_space, action_space, base)
        # Optional warm start from the best checkpoint ('best_<steps>.pth'):
        # files = [int(f.split('_')[-1].replace('.pth', '')) for f in os.listdir() if '.pth' in f]
        # if files:
        #     self.policy_net.load_state_dict(torch.load('best_' + str(max(files)) + '.pth'))
        #     print('Weights loaded!')

        self.target_net = DQN(observation_space, action_space, base)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.target_update = target_update
        self.optimize_steps = 0  # number of completed optimize_model() updates

        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            params=self.policy_net.parameters(), lr=LEARNING_RATE)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer.

        All five items must be tensors: ``optimize_model`` ``torch.stack``s them.
        ``done`` is expected to be a boolean tensor.
        """
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick an action epsilon-greedily.

        With probability ``exploration_rate`` a uniformly random action index is
        returned; otherwise the argmax of the policy network's Q-values for ``state``.
        """
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        # Acting needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            q_values = self.policy_net(state)
        return torch.argmax(q_values).item()

    def optimize_model(self):
        """Take one gradient step on a random minibatch of stored transitions.

        Does nothing until at least BATCH_SIZE transitions are stored. Each call
        also decays ``exploration_rate`` (floored at EXPLORATION_MIN). Every
        ``target_update`` calls, it copies the policy weights into the target network.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        state_batch = torch.stack(states)
        action_batch = torch.stack(actions).reshape(BATCH_SIZE, 1)
        reward_batch = torch.stack(rewards)
        non_final_mask = ~torch.stack(dones)

        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Terminal transitions keep a bootstrap value of 0. The guard avoids
        # torch.stack([]) when every sampled transition is terminal. no_grad keeps
        # the TD target out of the autograd graph, so no gradients pile up on target_net.
        next_state_values = torch.zeros(BATCH_SIZE)
        if non_final_mask.any():
            non_final_next_states = torch.stack(
                [s for s, done in zip(next_states, dones) if not done])
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_net(non_final_next_states).max(1)[0]

        expected_state_action_values = reward_batch + GAMMA * next_state_values
        loss = self.loss_fn(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Sync periodically, not every step, so the target network actually lags
        # the policy. Syncing every step made it identical to policy_net.
        self.optimize_steps += 1
        if self.optimize_steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        self.exploration_rate = max(EXPLORATION_MIN,
                                    self.exploration_rate * EXPLORATION_DECAY)

def _discrete_action_count(space):
    """Return ``space.n`` for a discrete gym action space.

    Raises:
        ValueError: for any other space (e.g. a continuous ``Box``), because
            ``DQNSolver.act`` can only produce integer action indices.
    """
    if not isinstance(space, gym.spaces.Discrete):
        raise ValueError(
            "DQNSolver needs a gym.spaces.Discrete action space, got %r; "
            "pick a discrete-action environment for ENV_NAME." % (space,))
    return int(space.n)


def _reset_env(env):
    """Reset ``env`` and return only the initial observation, for old and new gym APIs."""
    out = env.reset()
    # gym >= 0.26 returns (observation, info); older versions return the observation.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step_env(env, action):
    """Step ``env`` and return ``(observation, reward, done, info)``, for old and new gym APIs."""
    out = env.step(action)
    # gym >= 0.26 splits the old `done` flag into `terminated` and `truncated`.
    if len(out) == 5:
        observation, reward, terminated, truncated, info = out
        return observation, reward, terminated or truncated, info
    return out


def cartpole(max_runs=None, render=True, save_threshold=450):
    """Train a DQNSolver on ENV_NAME, printing each episode that ties or beats the best length.

    Despite the name, this trains on whatever ENV_NAME is configured. That
    environment must have a discrete action space. The environment is always
    closed on exit, including on KeyboardInterrupt.

    Args:
        max_runs: stop after this many episodes. ``None`` (default) trains forever.
        render: call ``env.render()`` on every step.
        save_threshold: a new best episode at least this many steps long is saved
            to ``best_<steps>.pth`` and copied into the target network.

    Returns:
        The trained DQNSolver (only reached when ``max_runs`` is set).

    Raises:
        ValueError: if the environment's action space is not discrete.
    """
    env = gym.make(ENV_NAME)
    try:
        observation_space = env.observation_space.shape[0]
        action_space = _discrete_action_count(env.action_space)
        dqn_solver = DQNSolver(observation_space, action_space, True)
        run = 0
        max_rez = 0

        while max_runs is None or run < max_runs:
            run += 1
            state = torch.tensor(data=_reset_env(env), dtype=torch.float32)
            step = 0
            while True:
                step += 1
                if render:
                    env.render()
                action = dqn_solver.act(state)
                state_next, reward, terminal, info = _step_env(env, action)
                # Penalise episode end. -abs() keeps CartPole's +1 -> -1 shaping, but no
                # longer turns an already-negative terminal reward into a bonus.
                if terminal:
                    reward = -abs(reward)

                action = torch.tensor(data=action)
                reward = torch.tensor(data=reward, dtype=torch.float32)
                state_next = torch.tensor(data=state_next, dtype=torch.float32)
                terminal = torch.tensor(data=terminal, dtype=torch.bool)
                dqn_solver.remember(state, action, reward, state_next, terminal)
                state = state_next

                if terminal:
                    if step >= max_rez:
                        print("Run: " + str(run) + ", exploration: " +
                              str(dqn_solver.exploration_rate) + ", score: " + str(step))
                        max_rez = step
                        if step >= save_threshold:
                            torch.save(dqn_solver.policy_net.state_dict(),
                                       'best_' + str(step) + '.pth')
                            dqn_solver.target_net.load_state_dict(
                                dqn_solver.policy_net.state_dict())
                    break
                dqn_solver.optimize_model()
        return dqn_solver
    finally:
        env.close()
            
# Entry point. In a notebook kernel __name__ is "__main__", so running this cell starts training.
if "__main__" == __name__:
    cartpole()

In [None]:
# cartpole() creates its own local `env`, which is not visible at module level.
# A module-level `env` exists only if the commented-out exploration cell was run,
# so guard the close instead of raising NameError on a fresh kernel.
if "env" in globals():
    env.close()