In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [9]:
device

device(type='cpu')

In [79]:
import numpy as np
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import json
import torch.nn.functional as F



class Qfunction(nn.Module):
    def __init__(self, state_dim, action_dim, n_atoms):
        super().__init__()
        self.model = nn.Sequential(
                nn.Linear(state_dim + action_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, n_atoms),
                nn.Softmax(dim=1)
        )

    def forward(self, states, actions):
        concat_input = torch.cat((states, actions), dim=1)
        reward_distr = self.model(concat_input)
        return reward_distr


class DQN_double:
    def __init__(self,
                 state_dim,
                 action_dim,
                 gamma=0.95,
                 lr=1e-3,
                 batch_size=64,
                 epsilon_decrease=0.002,
                 epsilon_min=0.0001,
                 period=100,
                 n_atoms=51,
                 v_min=-10,
                 v_max=10
                ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.q_function = Qfunction(state_dim, action_dim, n_atoms).to(device)
        self.freezing_q_function = Qfunction(state_dim, action_dim, n_atoms).to(device)
        self.freezing_q_function.load_state_dict(self.q_function.state_dict())
        self.gamma = gamma
        self.batch_size = batch_size
        self.epsilon = 1.0
        self.epsilon_decrease = epsilon_decrease
        self.epsilon_min = epsilon_min
        self.memory = []
        self.optimzaer = torch.optim.Adam(self.q_function.parameters(), lr=lr)
        self.counter = 1
        self.period = period
        
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.atoms = self.get_atoms()

    def get_atoms(self):
        atoms = np.zeros((self.n_atoms,))
        for i in range(self.n_atoms):
            atoms[i] = self.v_min + i * (self.v_max - self.v_min) / (self.n_atoms - 1.0)
        return atoms


    def get_action(self, obs, show=False):
        with torch.no_grad():
            one_hot_action = np.zeros((self.action_dim,))
            max_action = 0
            max_q_value = -np.inf
            for i in range(self.action_dim):
                one_hot_action[i] = 1
                state = torch.FloatTensor(obs).unsqueeze(dim=0).to(device)
                action = torch.FloatTensor(one_hot_action).unsqueeze(dim=0).to(device)
                # print(f"state: {state.shape}")
                # print(f"action: {action.shape}")
                reward_distr = self.q_function(state, action)
                reward_distr = reward_distr.squeeze().cpu().numpy()
                q_value = np.sum(self.atoms * reward_distr)
                if show:
                    print(f"state: {obs}")
                    print(f"action: {i}")
                    print(f"q_value: {q_value}")
                if q_value > max_q_value:
                    max_q_value = q_value
                    max_action = i
                one_hot_action[i] = 0
                
            probs = self.epsilon * np.ones(self.action_dim) / self.action_dim
            probs[max_action] += 1.0 - self.epsilon
            if show:
                print(f"probs: {probs}")
            if self.counter % 500 == 0:
                print(probs)
            action = np.random.choice(np.arange(self.action_dim), p=probs)
            return action

    def project(self, values, distr):
        j = 0
        new_target_distr = np.zeros((self.n_atoms, ))
        for i in range(self.n_atoms):
            if values[i] <= self.atoms[0]:
                new_target_distr[0] += distr[i]
            elif values[i] >= self.atoms[-1]:
                new_target_distr[-1] += distr[i]
            else:
                while values[i] > self.atoms[j]:
                    j += 1
                new_target_distr[j - 1] += distr[i] * (self.atoms[j] - values[i]) / (self.atoms[j] - self.atoms[j - 1]) 
                new_target_distr[j] += distr[i] * (values[i] - self.atoms[j - 1]) / (self.atoms[j] - self.atoms[j - 1])

        return new_target_distr
        
    
    def fit(self, state, action, reward, done, next_state, show=False):
        self.memory.append([state, action, reward, int(done), next_state])

        if len(self.memory) > 10 * self.batch_size:

            if self.counter % self.period == 0:
                self.freezing_q_function.load_state_dict(self.q_function.state_dict())
            
            self.counter += 1

            batch = random.sample(self.memory, self.batch_size)

            states, actions, rewards, dones, next_states = map(torch.FloatTensor, list(zip(*batch)))
            states, actions, rewards, dones, next_states = states.to(device), actions.unsqueeze(1).to(device), rewards.unsqueeze(dim=1), dones.unsqueeze(dim=1) , next_states.to(device)

            with torch.no_grad():
                one_hot_actions = np.zeros((self.batch_size, self.action_dim))
                q_values = None
                for i in range(self.action_dim):
                    one_hot_actions[:, i] = 1
                    next_actions = torch.FloatTensor(one_hot_actions).to(device)
                    reward_distrs = self.freezing_q_function(next_states, next_actions)
                    reward_distrs = reward_distrs
                    q_value = torch.sum(torch.tensor(self.atoms) * reward_distrs, dim=1, keepdim=True)
                    q_values = q_value if q_values is None else torch.cat((q_values, q_value), dim=1)
                    one_hot_actions[:, i] = 0

                    
                argmax_actions = torch.argmax(q_values, dim=1)
                if show:
                    print(f"q_values: {q_values}")
                    print(f"argmax_actions: {argmax_actions}")
                
                one_hot_actions[np.arange(self.batch_size), argmax_actions] = 1
                max_one_hot_actions = torch.FloatTensor(one_hot_actions)
                reward_distrs = self.freezing_q_function(next_states.float(), max_one_hot_actions).cpu().numpy()
                atoms = torch.tensor(self.atoms).repeat(self.batch_size, 1)
                targets = rewards + self.gamma * (1 - dones) * atoms
                # print(f"reward_distrs: {reward_distrs}")
                # print(f"targets: {targets}")
                # print(reward_distrs.shape)
                projected_targets = [self.project(targets[k], reward_distrs[k]) for k in range(self.batch_size)]
                projected_targets = torch.FloatTensor(projected_targets).to(device)
                if show:
                    print(f"atoms: {atoms}")
                    print(f"targets: {targets}")
                    print(f"projected_targets: {projected_targets}")
                # print(f"projected: {projected_targets}")
                

            one_hot_actions = np.zeros((self.batch_size, self.action_dim))
            # была ошибка здесь. Отдавал вместо actions - argmax_actions.
            one_hot_actions[np.arange(self.batch_size), actions.long()] = 1
            one_hot_actions = torch.FloatTensor(one_hot_actions)
            distrs = self.q_function(states.float(), one_hot_actions)
            loss = - torch.mean(torch.sum(projected_targets.detach() * torch.log(distrs + 1e-8), dim=1, keepdim=True))

            self.optimzaer.zero_grad()
            loss.backward()
            self.optimzaer.step()

            if show:
                print(f"cross-entorpys: {projected_targets.detach() * torch.log(distrs + 1e-8)}")
                print(f"loss: {loss}")

    def decrease_params(self):
        self.epsilon = max(self.epsilon - self.epsilon_decrease, self.epsilon_min)
            


In [75]:
def DQN_learning(env, agent, episode_n = 100, t_max = 500):
    # agent.epsilon_decrease = 1.0 / (0.75*episode_n)

    total_rewards = []
    for episode in range(episode_n):
        total_reward = 0

        show = False
        if episode % 50 == 0:
            show = False
                
        state, info = env.reset()
        for t in range(t_max):
                
            action = agent.get_action(state, show)

            next_state, reward, terminated, truncated, info = env.step(action)
    
            total_reward += reward

            if show:
                print(f"iteration: {t}")
            
            agent.fit(state, action, reward, terminated or truncated, next_state, show)
            
    
            state = next_state
    
            if terminated or truncated:
                break
                
        total_rewards.append(total_reward)
        if episode % 10 == 0:
            print(f'episode: {episode}, total_reward: {np.mean(total_rewards[-10:])}')
            
        agent.decrease_params()
    
    return total_rewards

In [21]:
### Проблема при реализации была в том, что в модели добавились новые слои, а я эти слои не копировал, а копировал только общий позвоночник модели как в старой версии.

In [58]:
agent.epsilon

-8.81239525796218e-16

In [87]:
import gym
from gym.wrappers import TransformReward

torch.manual_seed(43)

v_min = 1
v_max = 100
n_atoms = 51

env = gym.make('CartPole-v1')
# env = TransformReward(env, lambda r: v_min + r * (v_max - v_min) / 500.0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
print(action_dim)

episode_n=3000


agent = DQN_double(state_dim, action_dim, n_atoms=n_atoms, v_min=v_min, v_max=v_max)

total_rewards = DQN_learning(env, agent, episode_n=episode_n)


2
episode: 0, total_reward: 10.0
episode: 10, total_reward: 24.4
episode: 20, total_reward: 21.4
episode: 30, total_reward: 26.2
[0.465 0.535]
episode: 40, total_reward: 25.3
[0.45 0.55]
episode: 50, total_reward: 35.2
episode: 60, total_reward: 24.0
[0.568 0.432]
episode: 70, total_reward: 41.4
episode: 80, total_reward: 29.8
[0.418 0.582]
episode: 90, total_reward: 35.2
[0.405 0.595]
episode: 100, total_reward: 32.1
episode: 110, total_reward: 33.4
[0.611 0.389]
episode: 120, total_reward: 50.8
[0.621 0.379]
episode: 130, total_reward: 42.3
[0.632 0.368]
episode: 140, total_reward: 54.3
[0.642 0.358]
[0.35 0.65]
episode: 150, total_reward: 56.9
[0.34 0.66]
episode: 160, total_reward: 50.5
episode: 170, total_reward: 28.6
[0.675 0.325]
episode: 180, total_reward: 37.4
episode: 190, total_reward: 24.0
[0.693 0.307]
episode: 200, total_reward: 35.8
[0.292 0.708]
episode: 210, total_reward: 32.4
[0.719 0.281]
episode: 220, total_reward: 42.8
[0.729 0.271]
episode: 230, total_reward: 62.8

KeyboardInterrupt: 