In [None]:
#!pip install -q "gymnasium[atari, accept-rom-license]"

**Импорт библиотек**

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import glob
import io
import base64
from IPython import display as ipythondisplay
from IPython.display import HTML
import matplotlib.pyplot as pl
import torch
import torch.nn as nn
from collections import deque
import numpy as np
from IPython.display import clear_output

In [2]:
# Build the Pong environment and wrap it with the standard Atari preprocessing:
# 84x84 grayscale frames, 4-frame skip, up to 30 no-op steps on reset, and
# observation scaling turned on (scale_obs=True).
env = gym.make('PongNoFrameskip-v4')
env = gym.wrappers.AtariPreprocessing(
    env,
    noop_max=30,
    frame_skip=4,
    screen_size=84,
    terminal_on_life_loss=False,
    grayscale_obs=True,
    grayscale_newaxis=False,
    scale_obs=True,
)
# Stack the 4 most recent frames into one observation.
env = gym.wrappers.FrameStack(env, 4)

# NOTE: despite its name, n_states holds the observation *shape*, not a count.
n_states = env.observation_space.shape
n_actions = env.action_space.n
print(f"состояний: {n_states} действий: {n_actions}")

состояний: (4, 84, 84) действий: 6


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [12]:
def show_progress(rewards_batch, log):
    """
    Record the mean of ``rewards_batch`` and redraw the training curve.

    The mean reward is appended to ``log`` (the list is mutated in place),
    the current cell output is cleared, and the whole ``log`` history is
    plotted in the left half of an 8x4 figure.
    """
    batch_mean = np.mean(rewards_batch)
    log.append(batch_mean)

    # wait=True: the old output is only replaced once the new one is ready.
    clear_output(wait=True)
    plt.figure(figsize=[8, 4])
    plt.subplot(1, 2, 1)
    plt.plot(log, label='Mean rewards')
    plt.legend(loc=4)  # location code 4 is the lower-right corner
    plt.show()

**Нейронная сеть**

In [4]:
class DQN(nn.Module):
    """Convolutional Q-network: three Conv-BatchNorm-ReLU stages and a two-layer head.

    Args:
        input_shape: shape of a single observation without the batch axis;
            ``input_shape[0]`` is used as the number of input channels.
        num_of_actions: width of the output layer (one Q-value per action).
    """

    def __init__(self, input_shape, num_of_actions):
        super(DQN, self).__init__()
        self.conv_nn = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        cnn_output_size = self._conv_output_size(input_shape)

        self.linear_nn = nn.Sequential(
            nn.Linear(cnn_output_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_of_actions),
        )

    def _conv_output_size(self, input_shape):
        """Return the number of features the conv stack emits for one observation.

        The probe runs in eval mode and without autograd. A training-mode
        forward pass on the all-zero dummy input would update the BatchNorm
        running statistics, so every new network would start with polluted
        buffers.
        """
        was_training = self.conv_nn.training
        self.conv_nn.eval()
        try:
            with torch.no_grad():
                probe = self.conv_nn(torch.zeros(1, *input_shape))
        finally:
            # Restore whatever mode the stack was in before probing.
            self.conv_nn.train(was_training)
        return int(np.prod(probe.size()))

    def forward(self, x):
        """Map a batch of observations ``x`` to Q-values, one row per sample."""
        features = self.conv_nn(x)
        # Keep the batch axis and flatten everything else for the linear head.
        return self.linear_nn(torch.flatten(features, start_dim=1))

In [6]:
def select_greedy_action(moving_nn, obs):
    """Return the index of the action with the highest predicted Q-value.

    Args:
        moving_nn: network that maps a batch of observations to Q-values.
        obs: a single observation (no batch axis); anything ``np.array``
            can convert.

    Returns:
        The greedy action as a plain Python ``int``.
    """
    # Put the input where the network lives instead of relying on a global;
    # a network without parameters gets a CPU tensor.
    first_param = next(moving_nn.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # A batch axis is added by wrapping obs in a list. float32 matches the
    # network weights even when the observation arrives as float64.
    tensor_obs = torch.tensor(np.array([obs], dtype=np.float32), device=target_device)

    # Action selection never needs gradients; skipping the graph saves memory
    # and time on every environment step.
    # NOTE(review): the network is left in its current train/eval mode, so
    # BatchNorm layers in training mode still use single-sample batch
    # statistics here, as before. Confirm whether eval mode is intended.
    with torch.no_grad():
        all_actions = moving_nn(tensor_obs)
    return all_actions.argmax(dim=1).item()

def select_action_eps_greedy(env, network, obs, epsilon):
    """Epsilon-greedy policy.

    Draws one uniform number from ``np.random``. If ``epsilon`` exceeds it, a
    random action is sampled from ``env.action_space``. Otherwise the greedy
    action of ``network`` for ``obs`` is returned.
    """
    draw = np.random.rand()
    explore = epsilon > draw
    if not explore:
        return select_greedy_action(network, obs)
    return env.action_space.sample()

In [7]:
def compute_td_loss(
        network, states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False, regularizer=.1
):
    """Compute the one-step TD loss for a batch of transitions using torch ops only.

    Args:
        network: module mapping a batch of states to per-action Q-values.
        states: batch of states (array-like, batch axis first).
        actions: chosen action index for each transition.
        rewards: reward for each transition.
        next_states: batch of successor states.
        is_done: boolean flag per transition; where True the target is just
            the reward.
        gamma: discount factor applied to the bootstrapped value.
        check_shapes: when True, assert the ranks of the intermediate tensors.
        regularizer: weight of the penalty on the mean predicted Q(s, a).

    Returns:
        ``(loss, losses)``. ``loss`` is the scalar
        ``mean(losses) + regularizer * mean(Q(s, a))`` and ``losses`` holds
        the per-sample squared TD errors.
    """
    # Create every tensor on the network's own device. Previously `actions`
    # stayed on the CPU and the rest followed a global.
    first_param = next(network.parameters(), None)
    net_device = first_param.device if first_param is not None else torch.device("cpu")

    # convert the inputs to tensors
    states = torch.tensor(np.array(states), dtype=torch.float32, device=net_device)
    actions = torch.tensor(actions, dtype=torch.long, device=net_device)    # shape: [batch_size]
    rewards = torch.tensor(rewards, dtype=torch.float32, device=net_device)  # shape: [batch_size]
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=net_device)
    is_done = torch.tensor(is_done, dtype=torch.bool, device=net_device)    # shape: [batch_size]

    # Q-values for every action in the current states
    predicted_qvalues = network(states)

    # Q-values of the actions that were actually taken
    predicted_qvalues_for_actions = predicted_qvalues.gather(1, actions.unsqueeze(1)).squeeze(1)

    # The target is treated as a constant, so build it without autograd
    # instead of recording a graph and detaching it afterwards.
    with torch.no_grad():
        # Q-values for the next states
        predicted_next_qvalues = network(next_states)

        # V*(next_states) = max_{a'} Q(s', a')
        next_state_values = torch.max(predicted_next_qvalues, dim=-1)[0]
        assert next_state_values.dtype == torch.float32

        # The last transition of an episode has no successor, so its target
        # is simply the reward: Q(s, a) = r(s, a).
        target_qvalues_for_actions = torch.where(
            is_done, rewards, rewards + gamma * next_state_values
        )

    losses = (predicted_qvalues_for_actions - target_qvalues_for_actions) ** 2

    # Mean squared TD error plus the regularization term on the Q-values.
    loss = torch.mean(losses) + regularizer * predicted_qvalues_for_actions.mean()

    if check_shapes:
        assert predicted_next_qvalues.dim(
        ) == 2, "убедитесь, что вы предсказали q-значения для всех действий в следующем состоянии"
        assert next_state_values.dim(
        ) == 1, "убедитесь, что вы вычислили V (s ') как максимум только по оси действий, а не по всем осям"
        assert target_qvalues_for_actions.dim(
        ) == 1, "что-то не так с целевыми q-значениями, они должны быть вектором"

    return loss, losses

In [8]:
def sample_batch(replay_buffer, n_samples):
    """Draw ``n_samples`` transitions from ``replay_buffer`` uniformly at random.

    Indices come from ``np.random.choice`` with its default settings, so the
    same transition may be picked more than once. Each stored transition is
    unpacked as ``(state, action, reward, next_state, done)``.

    Returns:
        A 5-tuple of numpy arrays:
        ``(states, actions, rewards, next_states, dones)``.
    """
    picked = np.random.choice(len(replay_buffer), n_samples)

    # One accumulator list per transition field, in storage order.
    columns = ([], [], [], [], [])
    for idx in picked:
        state, action, reward, next_state, done = replay_buffer[idx]
        for column, value in zip(columns, (state, action, reward, next_state, done)):
            column.append(value)

    return tuple(np.array(column) for column in columns)

In [9]:
def generate_session_rb(
        env, network, opt, replay_buffer, glob_step,
        train_schedule, batch_size,
        t_max = 3000, epsilon=0, train=False
):
    """Play one episode and, when ``train`` is set, learn from the replay buffer.

    Args:
        env: environment exposing ``reset()`` and a 5-value ``step()``.
        network: Q-network used for acting and updated during training.
        opt: optimizer stepping ``network``'s parameters.
        replay_buffer: container of ``(s, a, r, next_s, terminated)`` tuples.
        glob_step: step counter carried across episodes; it decides which
            steps trigger a gradient update.
        train_schedule: an update runs when ``glob_step % train_schedule == 0``.
        batch_size: number of transitions sampled per update.
        t_max: maximum number of environment steps in this episode.
        epsilon: exploration rate; forced to 0 when ``train`` is False.
        train: when False the episode is played without storing transitions
            or updating the network.

    Returns:
        ``(total_reward, glob_step)``: the episode return and the advanced
        step counter.
    """
    if not train:
        # Evaluation episodes are always fully greedy.
        epsilon = 0.

    obs, _ = env.reset()
    total_reward = 0

    for _ in range(t_max):
        action = select_action_eps_greedy(env, network, obs, epsilon=epsilon)
        next_obs, reward, terminated, truncated, _ = env.step(action)

        if train:
            # Only `terminated` is stored as the done flag.
            replay_buffer.append((obs, action, reward, next_obs, terminated))

            # NOTE(review): the buffer is non-empty right after the append
            # above, so in practice only the schedule check decides here.
            if replay_buffer and glob_step % train_schedule == 0:
                states, actions, rewards, next_states, is_done = sample_batch(
                    replay_buffer, batch_size
                )

                opt.zero_grad()
                loss, _ = compute_td_loss(
                    network, states, actions, rewards, next_states, is_done
                )
                loss.backward()
                opt.step()

        total_reward += reward
        glob_step += 1
        obs = next_obs
        if terminated or truncated:
            break

    return total_reward, glob_step

In [None]:
# --- Hyperparameters -------------------------------------------------------
lr = 0.0001

# Exploration rate and its per-episode multiplicative decay.
eps, eps_decay = 0.001, 0.999
# Maximum number of training episodes, and how often (in episodes) to evaluate.
train_ep_len, eval_schedule = 250000, 5
# One gradient update every `train_schedule` environment steps, on `batch_size` samples.
train_schedule, batch_size = 4, 32

# --- Training state --------------------------------------------------------
replay_buffer = deque(maxlen=4000)
eval_rewards = deque(maxlen=5)   # the running average covers the last 5 evaluations
glob_step = 0
rewards_batch = []
log = []
rrrr = []                        # history of the running-average evaluation reward
env.reset()

# Training resumes from a previously saved checkpoint.
network = DQN(env.observation_space.shape, env.action_space.n).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and the other way round). Without it torch.load tries to restore the
# tensors onto the device they were saved from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
network.load_state_dict(torch.load('new_pongstat2.pt', map_location=device))
opt = torch.optim.Adam(network.parameters(), lr=lr)

for ep in range(train_ep_len):
    # One training episode; the global step counter carries over between episodes.
    _, glob_step = generate_session_rb(
        env, network, opt, replay_buffer, glob_step, train_schedule, batch_size, epsilon=eps, train=True
    )

    if (ep + 1) % eval_schedule == 0:
        # Evaluation episode: train=False makes it greedy and skips all updates.
        ep_rew, _ = generate_session_rb(
            env, network, opt, replay_buffer, 0, train_schedule, batch_size, epsilon=eps, train=False
        )
        eval_rewards.append(ep_rew)
        running_avg_rew = np.mean(eval_rewards)
        print("Epoch: #{}\tmean reward = {:.3f}\tepsilon = {:.3f}".format(ep, running_avg_rew, eps))
        # Checkpointing is currently disabled; to re-enable it:
        # torch.save(network.state_dict(), './new_pongstat.pt')
        if eval_rewards and running_avg_rew >= 19:
            print("Принято!")
            break
        # rewards_batch is reset after every plot, so each plotted point is a
        # single evaluation reward.
        rewards_batch.append(ep_rew)
        show_progress(rewards_batch, log)
        rewards_batch = []
        rrrr.append(running_avg_rew)
    eps *= eps_decay