In [1]:
import gym
from gym.utils import play 
from gym import wrappers
from gym.wrappers import GrayScaleObservation, RecordEpisodeStatistics, TimeLimit, ResizeObservation, FrameStack

import numpy as np

import matplotlib.pyplot as plt
import time
from hashlib import md5
import os
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from tensordict import TensorDict

import torchrl
from torchrl.data.replay_buffers import PrioritizedReplayBuffer, ReplayBuffer, PrioritizedSampler
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyMemmapStorage, TensorDictReplayBuffer

In [5]:
torch.cuda.is_available()

True

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class MaskVelocityWrapper(gym.ObservationWrapper):
    """
    Gym environment observation wrapper used to mask velocity terms in
    observations. The intention is the make the MDP partially observatiable.
    """
    def __init__(self, env):
        super(MaskVelocityWrapper, self).__init__(env)
        if ENV == "CartPole-v1":
            self.mask = np.array([1., 0., 1., 0.])
        elif ENV == "Pendulum-v0":
            self.mask = np.array([1., 1., 0.])
        elif ENV == "LunarLander-v2":
            self.mask = np.array([1., 1., 0., 0., 1., 0., 1., 1,])
        elif ENV == "LunarLanderContinuous-v2":
            self.mask = np.array([1., 1., 0., 0., 1., 0., 1., 1,])
        else:
            raise NotImplementedError

    def observation(self, observation):
        return  observation * self.mask

In [38]:
# Успешный вариант. Дошел до отметки 400.
# class DRQNAgent(nn.Module):
#     def __init__(self, input_shape, action_n, lr=1e-3, gamma=0.95, batch_size=5, period=10, N=20, M=0, episode_n=1000):
#         super().__init__()
# class RQfunction(nn.Module):
#     def __init__(self, state_dim, action_dim, hidden_size=16):

In [34]:
class RQfunction(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=32, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(state_dim, hidden_size=hidden_size, num_layers=num_layers)
        self.q = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            # nn.LayerNorm(hidden_size), # перед RELU или ELU нет смысла ставить Norm так как градиенты пройдут нормально и без этого через эти функции активации
            nn.ELU(),
            nn.Linear(hidden_size, action_dim)
        )
        self.prev_hidden_state = None
        self.hidden_state = None
        self.current_batch_size = 0

    def initialize(self, batch_size):
        self.current_batch_size = batch_size
        if batch_size == 0:
            self.hidden_state = (torch.zeros(self.num_layers, self.hidden_size).to(device), torch.zeros(self.num_layers, self.hidden_size).to(device))
        # else:
        #     self.hidden_state = (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device), torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))
        # self.prev_hidden_state = None

    def forward(self, x, batch_size=None):
        if len(x.shape) == 3:
            batch_size = x.shape[0]
        if len(x.shape) == 2:
            batch_size = 0
        if self.hidden_state is None or batch_size != self.current_batch_size:
            self.initialize(batch_size)

        self.prev_hidden_state = self.hidden_state
        # with torch.no_grad():
        output, self.hidden_state = self.rnn(x, self.hidden_state)
        # if self.current_batch_size == 0:
        #     print(f"obs: {x}")
        #     print(f"state: {output}")
        q_values = self.q(output)
        return q_values


# class RecurrentActor(nn.Module):
#     def __init__(self, state_dim, action_dim, hidden_size=64, batch_size=64):
        
        
        

In [30]:
class DRQNAgent(nn.Module):
    def __init__(self, input_shape, action_n, lr=1e-3, gamma=0.97, batch_size=8, period=10, N=20, M=0, episode_n=1000):
        super().__init__()
        self.N = N
        self.M = M
        self.q_function = RQfunction(input_shape, action_n).to(device)
        self.target_q_function = RQfunction(input_shape, action_n).to(device)
        self.update_weights()

        self.epsilon_min = 0.02
        self.epsilon_decay = 0.5 * 1.0 / (episode_n)
        self.epsilon = 0.9

        self.episode_n = episode_n
        self.current_episode = 0

        self.input_shape = input_shape
        self.action_n = action_n
        
        self.lr = lr
        self.gamma = gamma
        self.batch_size = batch_size
        self.period = period
        self.counter = 1
        
        self.optimizer = torch.optim.Adam(self.q_function.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[int(0.7 * self.episode_n)], gamma=0.5)

        # self.capacity = 100_000
        # storage = LazyMemmapStorage(self.capacity, scratch_dir='/home/artem/atari_games/tmp/')
        # self.sampler = PrioritizedSampler(self.capacity, alpha=self.alpha, beta=self.beta)
        # self.tdrb = TensorDictReplayBuffer(storage=storage, sampler=self.sampler, priority_key='td_error')
        self.states_count = 0
        self.rb = []

        self.states = None
        self.hidden_states =  None
        self.actions = None
        self.rewards = None
        self.dones = None
        self.next_states = None
        self.next_hidden_states = None

    def save_model(self, path=f'/home/artem/atari_games/models/DRQN_{md5(str(time.time()).encode()).hexdigest()}.pth'):
        state = {
            'model_dict': self.q_function.state_dict(),
            'optimizer_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }
        torch.save(state, path)

    def load_model(self,path=f'/home/artem/atari_games/models/DRQN_{md5(str(time.time()).encode()).hexdigest()}'):
        if os.path.exists(path):
            state = torch.load(path)
            self.q_function.rnn.load_state_dict(state['model_dict'])

    def decay_epsilon(self):
        self.epsilon = max((self.epsilon - self.epsilon_decay), self.epsilon_min)

    def decay_all(self):
        self.decay_epsilon()

    def e_greedy_action(self, q_values):
        probs = np.ones(self.action_n) * self.epsilon / self.action_n
        probs[np.argmax(q_values.cpu().numpy())] += 1 - self.epsilon
        action = np.random.choice(np.arange(self.action_n), p=probs)
        return action
        
    def get_action(self, obs: np.ndarray) -> torch.Tensor:
        with torch.no_grad():
            obs = torch.tensor(obs, dtype=torch.float).unsqueeze(dim=0).to(device)
            q_values = self.q_function(obs).squeeze()
            if self.counter % 2500 == 0:
                print(q_values)
            action = self.e_greedy_action(q_values)
            return action

    def add_sample(self, state, action, reward, next_state, done):
        # print(f"stacked {torch.stack((self.q_function.prev_hidden_state[0], self.q_function.prev_hidden_state[1]), dim=0)}")
        # print(f"hidden_states {0 if self.hidden_states is None else self.hidden_states.shape}")

        stacked_hidden_state = torch.stack((self.q_function.prev_hidden_state[0], self.q_function.prev_hidden_state[1]), dim=0)
        stacked_next_hidden_state = torch.stack((self.q_function.hidden_state[0], self.q_function.hidden_state[1]), dim=0)
        

        self.states = torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0) if self.states is None else torch.cat((self.states, torch.tensor(state).unsqueeze(dim=0)), dim=0)
        self.hidden_states = stacked_hidden_state.unsqueeze(dim=0) if self.hidden_states is None else torch.cat((self.hidden_states, stacked_hidden_state.unsqueeze(dim=0)), dim=0)
        self.actions = torch.tensor(action).unsqueeze(dim=0) if self.actions is None else torch.cat((self.actions, torch.tensor(action).unsqueeze(dim=0)), dim=0)
        self.rewards = torch.tensor(reward).unsqueeze(dim=0) if self.rewards is None else torch.cat((self.rewards, torch.tensor(reward).unsqueeze(dim=0)), dim=0)
        self.dones = torch.tensor(int(done)).unsqueeze(dim=0) if self.dones is None else torch.cat((self.dones, torch.tensor(int(done)).unsqueeze(dim=0)), dim=0)
        self.next_states = torch.tensor(next_state, dtype=torch.float32).unsqueeze(dim=0) if self.next_states is None else torch.cat((self.next_states, torch.tensor(next_state).unsqueeze(dim=0)), dim=0)
        self.next_hidden_states = stacked_next_hidden_state.unsqueeze(dim=0) if self.next_hidden_states is None else torch.cat((self.next_hidden_states, stacked_next_hidden_state.unsqueeze(dim=0)), dim=0)
        
        self.states_count += 1

        if done or self.states_count % self.N == 0:
            # print(self.states.shape)
            self.rb.append(
                {
                    'state': self.states.to(device),#.clone().detach().requires_grad_(False) ,
                    'hidden_state': self.hidden_states.to(device),
                    'action': self.actions.to(device),
                    'reward': self.rewards.to(device),
                    'next_state': self.next_states.to(device),
                    'next_hidden_state': self.next_hidden_states.to(device),
                    'done': self.dones.to(device)
                        }
            )

            self.states = None
            self.hidden_states =  None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.next_states = None
            self.next_hidden_states = None

        if done:
            self.current_episode += 1
            self.decay_all()
            self.q_function.prev_hidden_state = None
            self.q_function.hidden_state = None

    def update_weights(self):
        for parameter_freeze, parameter in zip(self.target_q_function.rnn.parameters(), self.q_function.rnn.parameters()):
            with torch.no_grad():
                parameter_freeze.data.copy_(parameter.data)
        for parameter_freeze, parameter in zip(self.target_q_function.q.parameters(), self.q_function.q.parameters()):
            with torch.no_grad():
                parameter_freeze.data.copy_(parameter.data)

    def fit(self):
        if self.batch_size < len(self.rb):
            if self.counter % self.period == 0:
                # print('weights change')
                # не уверен, что LSTM можно так легко скопировать
                self.update_weights()
                
            self.counter += 1

            sample = random.sample(self.rb, self.batch_size)

            flag = True
            for rollout in sample:
                # print(rollout['state'])
                # вычисляем таргеты
                self.q_function.hidden_state = (rollout['hidden_state'][0][0], rollout['hidden_state'][0][1])
                self.target_q_function.hidden_state = (rollout['hidden_state'][0][0], rollout['hidden_state'][0][1])
                # попробую инициализировать нулями hidden_state не очень пошло
                # self.q_function.initialize(1)
                # self.target_q_function.initialize(1)

                # print(f"q_func: {self.target_q_function(rollout['next_state'])}")
                targets = rollout['reward'].unsqueeze(dim=1) + (1 - rollout['done'].unsqueeze(dim=1)) * self.gamma * self.target_q_function(rollout['next_state'].float())\
                .gather(1, torch.argmax(self.q_function(rollout['next_state'].float()), dim=1).unsqueeze(dim=1))

                
                # print(f"done: {(1 - rollout['done']).shape}")
                # print(f"shape check: {torch.argmax(self.q_function(rollout['next_state']), dim=1).unsqueeze(dim=1).shape}")
                # print(f"target q: {self.target_q_function(rollout['next_state']).shape}")
                # print(f"the whole: {self.target_q_function(rollout['next_state']).gather(1, torch.argmax(self.q_function(rollout['next_state']), dim=1).unsqueeze(dim=1)).shape}")
                # print(f"targets check: {targets.shape}")
                # вычисляем loss 
                q_values = self.q_function(rollout['state'].float()).gather(1, rollout['action'].unsqueeze(dim=1))
                print(rollout['state'].float().shape)

                # получаем последние  N - M состояний, так как их отображение q_values более правдоподобно из-за прогрева hidden_state
                
                td = (q_values - targets.detach()) ** 2
                if td.shape[0] > self.M and self.M != 0:
                    zero_mask = torch.zeros((self.M, 1))
                    ones_mask = torch.ones((td.shape[0] - self.M, 1))
                    mask = torch.cat((zero_mask, ones_mask), dim=0).to(device)
                    td *= mask
                loss = torch.mean(td)
                loss.backward()

                # nn.utils.clip_grad_norm_(self.q_function.q.parameters(), max_norm=100.0)
                # nn.utils.clip_grad_norm_(self.q_function.rnn.parameters(), max_norm=100.0)

                if self.counter % 2500 == 0 and flag:
                    print(f"targets: {targets}")
                    print(f"q_values: {q_values}")
                    print(f"loss: {loss}")
                    flag = False

                self.optimizer.step()
                self.optimizer.zero_grad()

            self.scheduler.step()
            self.q_function.initialize(1)
            

In [116]:
# class DRQNAgent(nn.Module):
#     def __init__(self, input_shape, action_n, lr=1e-3, gamma=0.95, batch_size=2, period=500, N=50, M=10, episode_n=1000):
# рабочий вариант на 3к операциях почти сошелся

In [55]:
# делаем пошаговые эксперименты.
# уменьшил период до 5 
# маленький период не способствует улучшению

# увеличил длину траектории с 20 до 30
# показал примерно такой же результат

# еще хочется потом уменьшить hidden size до 8 так как и такого размера по моему мнению должно быть достаточно
# не сработало

# оставил только те роллауты что по длине достаточно длинные

# сейчас попробую сделать мултилэйер LSTM

In [35]:
 # Environment parameters
ENV = "CartPole-v1"
env = gym.make('CartPole-v1')

# ENV = "LunarLander-v2"
# env = gym.make('LunarLander-v2')
# убираем скорость из состояний
env = MaskVelocityWrapper(env)
# torch.manual_seed(42)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

episode_n = 10_000

agent = DRQNAgent(state_dim, action_dim, episode_n=episode_n)
# agent.load_model("pretrained_lstm.ptx")

total_rewards = []
loss1 = []
grads1 = []
loss2 = []
grads2 = []
counter = 0
for episode in range(episode_n):

    total_reward = 0
    state, info = env.reset()
    
    for i in range(1000):
        action = agent.get_action(state)
        
        next_state, reward, terminated, truncated, info = env.step(action)
        
        agent.add_sample(state, action, reward, next_state, terminated or truncated)
        counter += 1

        if counter % 2 == 0:
            agent.fit()
    
        total_reward += reward
        state = next_state

        if terminated or truncated:
            break

    
        
    total_rewards.append(total_reward)
    if episode % 10 == 0:
        print(f"episode: {episode}  mean last ten: {np.mean(total_rewards[-10:])}")

print('end')

episode: 0  mean last ten: 12.0
episode: 10  mean last ten: 24.8
episode: 20  mean last ten: 29.2
episode: 30  mean last ten: 19.2
episode: 40  mean last ten: 24.2
episode: 50  mean last ten: 14.6
episode: 60  mean last ten: 23.1
episode: 70  mean last ten: 25.3
episode: 80  mean last ten: 22.3
episode: 90  mean last ten: 28.0
episode: 100  mean last ten: 27.1
episode: 110  mean last ten: 23.9
episode: 120  mean last ten: 23.8
episode: 130  mean last ten: 22.8
episode: 140  mean last ten: 20.9
episode: 150  mean last ten: 20.9
episode: 160  mean last ten: 19.5
episode: 170  mean last ten: 19.1
episode: 180  mean last ten: 18.5
episode: 190  mean last ten: 22.7
episode: 200  mean last ten: 18.0
episode: 210  mean last ten: 19.0
episode: 220  mean last ten: 23.1
episode: 230  mean last ten: 20.0
targets: tensor([[4.3312],
        [1.8377],
        [2.8917],
        [1.0000]], grad_fn=<AddBackward0>)
q_values: tensor([[4.2395],
        [1.6095],
        [1.7319],
        [0.5307]], grad_f

KeyboardInterrupt: 

In [38]:
t = torch.tensor([1, 2, 4]).unsqueeze(dim=0).unsqueeze(dim=0)
t.shape

torch.Size([1, 1, 3])

In [115]:
rnn = nn.LSTM(10, 20, 1)
input = torch.randn(1, 10)
h0 = torch.randn(1, 20)
c0 = torch.randn(1, 20)
output, (hn, cn) = rnn(input, (h0, c0))

In [116]:
output

tensor([[-0.1831,  0.2089,  0.4142, -0.0557, -0.0669,  0.3513, -0.1570, -0.2058,
          0.1245,  0.1779,  0.2874, -0.1077, -0.2030, -0.0157, -0.0557,  0.1683,
          0.0930,  0.0799, -0.3804,  0.1875]], grad_fn=<SqueezeBackward1>)

In [117]:
hn

tensor([[-0.1831,  0.2089,  0.4142, -0.0557, -0.0669,  0.3513, -0.1570, -0.2058,
          0.1245,  0.1779,  0.2874, -0.1077, -0.2030, -0.0157, -0.0557,  0.1683,
          0.0930,  0.0799, -0.3804,  0.1875]], grad_fn=<SqueezeBackward1>)

In [118]:
cn

tensor([[-0.3891,  0.4474,  1.2947, -0.1222, -0.1394,  1.9529, -0.2809, -0.3401,
          0.3209,  0.8142,  0.6584, -0.2212, -0.4863, -0.0372, -0.0857,  0.6757,
          0.2002,  0.1816, -0.9577,  0.4648]], grad_fn=<SqueezeBackward1>)

In [186]:
agent.save_model()