In [1]:
from collections import namedtuple
import random

In [2]:
# s_t, a_t => s_{t+1}
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

In [3]:
class ReplayMemory:
    
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0
        
    def push(self, state, action, next_state, reward):
        if len(self.memory) < self.capacity:
            # placeholder
            self.memory.append(None)
        self.memory[self.index] = Transition(state, action, next_state, reward)
        self.index = (self.index + 1) % self.capacity
        
    # list of transition
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)

In [4]:
from torch import nn
import torch
import torch.nn.functional as F
from torch import optim

In [5]:

# Q_function base NN
class DQN(nn.Module):
    def __init__(self, n_states, n_actions):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(n_states, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, n_actions)
    def forward(self, x):
        # x.shape: batch_size*n_states
        # output.shape: batch_size*n_actions, state_action_value
        return self.fc3(F.relu(self.fc2(F.relu(self.fc1(x)))))

In [6]:
trs = [Transition(1, 2, 3, 4), Transition(5, 6, 7, 8)]
trs = Transition(*zip(*trs))
trs

Transition(state=(1, 5), action=(2, 6), next_state=(3, 7), reward=(4, 8))

In [7]:
t = torch.rand(3, 2)
print(t)
print(t.max(dim=1)[0])
t.max(dim=1)[1]

tensor([[0.3459, 0.6515],
        [0.6835, 0.4760],
        [0.1795, 0.2964]])
tensor([0.6515, 0.6835, 0.2964])


tensor([1, 0, 1])

In [8]:
class Agent:
    def __init__(self, n_states, n_actions, eta=0.5, gamma=0.99, capacity=10000, batch_size=32):
        self.n_states = n_states
        self.n_actions = n_actions
        self.eta = eta
        self.gamma = gamma
        self.batch_size = batch_size
        
        self.memory = ReplayMemory(capacity)
        self.model = DQN(n_states, n_actions)
        
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
    
    def _replay(self):
        if len(self.memory) < self.batch_size:
            return
        # list of transition
        batch = self.memory.sample(self.batch_size)
        # Transition, column: len(tuple) == batch_size
        batch = Transition(*zip(*batch))
        
        # s_t.shape: batch_size * 4
        state_batch = torch.cat(batch.state)
        # a_t.shape: batch_size * 1
        action_batch = torch.cat(batch.action)
        # r_{t+1}.shape: batch_size * 1
        reward_batch = torch.cat(batch.reward)
        # < batch_size
        non_final_next_state_batch = torch.cat([s for s in batch.next_state if s is not None])
        
    
        # 构造模型训练用的输入和输出（true）
        # s_t, input
        
        # pred: Q(s_t, a_t)
        # true: R_{t+1} + \gamma*\max_aQ(s_t, a)
        
        # 开启 eval 模式
        self.model.eval()
        
        # pred, batch_size*1
        state_action_values = self.model(state_batch).gather(dim=1, index=action_batch)
        
        # true: R_{t+1} + \gamma*\max_aQ(s_t, a)
        # tuple(map(lambda s: s is not None, batch.next_state)): batch_size 长度的 0/1
        non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None, batch.next_state)))
        next_state_values = torch.zeros(self.batch_size)
        # Q(s_{t+1}, a)
        next_state_values[non_final_mask] = self.model(non_final_next_state_batch).max(dim=1)[0].detach()
        
        # (batch_size, )
        expected_state_action_values = reward_batch + self.gamma * next_state_values
        
        
        # 开启train mode
        self.model.train()
        
        # expected_state_action_values.unsqueeze(1): (batch_size, ) => (batch_size, 1)
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    
    def update_q_function(self):
        self._replay()
        
    def memorize(self, state, action, next_state, reward):
        self.memory.push(state, action, next_state, reward)
    
    # action policy
    # epsilon_greedy
    # double e: explore, exploit
    def choose_action(self, state, episode):
        eps = 0.5*1/(1+episode)
        if random.random() < eps:
            # explore
            action = torch.IntTensor([[random.randrange(self.n_actions)]])
        else:
            self.model.eval()
            with torch.no_grad():
                action = self.model(state).max(1)[1].view(1, 1)
        return action

In [9]:
import gymnasium as gym

env = gym.make('CartPole-v1')

print(env.observation_space.shape[0])
print(env.action_space.n)

4
2


In [10]:
env.reset()

(array([-0.00781699,  0.03912979,  0.00963907, -0.01852011], dtype=float32),
 {})

In [11]:
from JSAnimation.IPython_display import display_animation
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from matplotlib import animation

In [12]:
def display_frames_as_gif(frames, output):
    """
    Displays a list of frames as a gif, with controls
    以gif格式显示关键帧列，带有控件
    """
    
    fig = plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0),dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')
    
    def animate(i):
        img = patch.set_data(frames[i])
        return img   ## *** return是必须要有的 ***
        
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    
    anim.save(output)
    return HTML(anim.to_jshtml())  ## *** 返回一个HTML对象，以便被调用者显示。 ***
    # display(display_animation(anim, default_mode='loop'))

In [16]:
env = gym.make('CartPole-v1')
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n

max_episodes = 500
max_steps = 200

complete_episodes = 0
finished_flag = False

agent = Agent(n_states, n_actions)
frames = []

for episode in range(max_episodes):
    state = env.reset()[0]
    state = torch.from_numpy(state).type(torch.FloatTensor).unsqueeze(0)
    for step in range(max_steps):
        if finished_flag == True:
            frames.append(env.render(mode='rgb_array'))
        # IntTensor of 1*1
        action = agent.choose_action(state, episode)[0][0].numpy()
        # action = random.randrange(2)
        # print(f'state: {state}, action: {action}')
        # transition on env
        # cartpole v0: next_state, _, done, _ = env.step(action.item())
        next_state, reward, terminated, truncated, info = env.step(action)
        
        if terminated:
            next_state = None
            
            if step < 195:
                # 1d
                reward = torch.FloatTensor([-1.])
                complete_episodes = 0
            else:
                reward =  torch.FloatTensor([1.])
                complete_episodes += 1
        else:
            reward = torch.FloatTensor([0])
            # (4, )
            next_state = torch.from_numpy(next_state).type(torch.FloatTensor)
            # (4, ) ==> (1, 4)，便于后续的 torch.cat => (1, 4) => (32, 4)
            next_state = next_state.unsqueeze(0)
            
        # 和示例相比多了一些类型转换。
        agent.memorize(state, torch.tensor([[int(action)]], dtype=torch.int64), next_state, reward)
        agent.update_q_function()
        state = next_state
        
        if terminated:
            print(f'episode: {episode}, steps: {step}')
            break
        
    if finished_flag == True:
        break
        
    if complete_episodes >= 10:
        finished_flag = True
        print('连续成功10轮')

episode: 0, steps: 18
episode: 1, steps: 27
episode: 2, steps: 15
episode: 3, steps: 13
episode: 4, steps: 17
episode: 5, steps: 40
episode: 6, steps: 18


  next_state_values[non_final_mask] = self.model(non_final_next_state_batch).max(dim=1)[0].detach()


episode: 7, steps: 54
episode: 8, steps: 12
episode: 9, steps: 31
episode: 10, steps: 14
episode: 11, steps: 15
episode: 12, steps: 26
episode: 13, steps: 15
episode: 14, steps: 29
episode: 15, steps: 28
episode: 16, steps: 26
episode: 17, steps: 11
episode: 18, steps: 12
episode: 19, steps: 12
episode: 20, steps: 41
episode: 21, steps: 25
episode: 22, steps: 42
episode: 23, steps: 35
episode: 24, steps: 19
episode: 25, steps: 59
episode: 26, steps: 12
episode: 27, steps: 20
episode: 28, steps: 9
episode: 29, steps: 42
episode: 30, steps: 18
episode: 31, steps: 16
episode: 32, steps: 12
episode: 33, steps: 14
episode: 34, steps: 35
episode: 35, steps: 28
episode: 36, steps: 33
episode: 37, steps: 14
episode: 38, steps: 36
episode: 39, steps: 14
episode: 40, steps: 16
episode: 41, steps: 20
episode: 42, steps: 23
episode: 43, steps: 15
episode: 44, steps: 19
episode: 45, steps: 19
episode: 46, steps: 32
episode: 47, steps: 10
episode: 48, steps: 17
episode: 49, steps: 12
episode: 50, st