## Model-Based Policy Optimization: MBPO in Pendulum

基于模型策略优化

In [43]:
import gym

# 定义环境
class MyWrapper(gym.Wrapper):
  def __init__(self):
    env = gym.make('Pendulum-v1', render_mode='rgb_array')
    super().__init__(env)
    self.env = env
    self.step_n = 0
  
  def reset(self):
    state, _ = self.env.reset()
    self.step_n = 0
    return state
    
  def step(self, action):
    state, reward, terminated, truncated, info = self.env.step(action)
    done = terminated or truncated
    self.step_n += 1
    if self.step_n >= 200:
      done = True
    return state, reward, done, info 
  
env = MyWrapper()
env.reset()

array([-0.9844762 , -0.17551816,  0.25879556], dtype=float32)

In [44]:
import torch
import random
from IPython import display
import math

# 基底模型使用SAC
class SAC:
	class ModelAction(torch.nn.Module):
		def __init__(self):
			super().__init__()
			# 定义模型
			self.fc_state = torch.nn.Sequential(
				torch.nn.Linear(3, 128),
				torch.nn.ReLU(),
			)

			self.fc_mu = torch.nn.Linear(128, 1)
			self.fc_std = torch.nn.Sequential(
				torch.nn.Linear(128, 1),
				torch.nn.Softplus(),
			)

		def forward(self, state):
			# [b, 3] -> [b, 128]
			state = self.fc_state(state)

			# [b, 128] -> [b, 1]
			mu = self.fc_mu(state)

			# [b, 128] -> [b, 1]
			std = self.fc_std(state)

			# 定义b个正态分布
			dist = torch.distributions.Normal(mu, std)

			# 采样b个样本
      # 这里用的是rsample,表示重采样,
      # 其实就是先从一个标准正态分布中采样,然后乘以标准差,加上均值
			sample = dist.rsample()

			# 样本压缩到-1， 1，求动作
			action = torch.tanh(sample)

			# 求概率对数
			log_prob = dist.log_prob(sample)
  
			# 动作熵
			entropy = log_prob - (1 - action.tanh() ** 2 + 1e-7).log()
			entropy = -entropy

			return action * 2, entropy

	class ModelValue(torch.nn.Module):
		
		def __init__(self):
			super().__init__()
			self.sequential = torch.nn.Sequential(
        torch.nn.Linear(4, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 1),
		)
    
		def forward(self, state, action):
			# [b, 3 + 1] -> [b, 4]
			state = torch.cat([state, action], dim=1)

			# [b, 4] -> [b, 1]
			return self.sequential(state)
  
	def __init__(self):
		self.model_action = self.ModelAction()
  
		self.model_value1 = self.ModelValue()
		self.model_value2 = self.ModelValue()

		self.model_value_next1 = self.ModelValue()
		self.model_value_next2 = self.ModelValue()

		self.model_value_next1.load_state_dict(self.model_value1.state_dict())
		self.model_value_next2.load_state_dict(self.model_value2.state_dict())

		# alpha 可学习参数
		self.alpha = torch.tensor(math.log(0.01))
		self.alpha.requires_grad = True
  
		self.optimizer_action = torch.optim.Adam(self.model_action.parameters(), lr=3e-4)
		self.optimizer_value1 = torch.optim.Adam(self.model_value1.parameters(), lr=3e-3)
		self.optimizer_value2 = torch.optim.Adam(self.model_value2.parameters(), lr=3e-3)
		self.optimizer_alpha = torch.optim.Adam([self.alpha], lr=3e-4)

		self.loss_fn = torch.nn.MSELoss()
  
	def get_action(self, state):
		state = torch.FloatTensor(state).reshape(1, 3)
		action, _ = self.model_action(state)
		return action.item()

	def _soft_update(self, model, model_next):
		for param, param_next in zip(model.parameters(), model_next.parameters()):
			# 小比例更新
			value = param_next.data * 0.995 + param.data * 0.005
			param_next.data.copy_(value)

	def _get_target(self, reward, next_state, over):
		# 首先使用model_action计算动作和动作的熵
		# [b, 3] -> [b, 1], [b, 1]
		action, entropy = self.model_action(next_state)


		# 评估next_state的价值
		# [b, 4], [b, 1] -> [b, 1]
		target1 = self.model_value_next1(next_state, action)
		target2 = self.model_value_next2(next_state, action)

		# 取价值小的，出于稳定性
		# [b, 1]
		target = torch.min(target1, target2)

		# 还原alpha
		# target 加上动作熵，alpha作为权重系数
		# [b, 1] - [b, 1] -> [b, 1]
		target += self.alpha.exp() * entropy

		# [b, 1] 
		target *= 0.99
		target *= (1 - over)
		target += reward

		return target

	def _get_loss_action(self, state):
		# 计算action和熵
		# [b, 3] -> [b, 1], [b, 1]
		action, entropy = self.model_action(state)

		# 使用两个value网络评估action的价值
		#[b, 3],[b, 1] -> [b, 1]
		value1 = self.model_value1(state, action)
		value2 = self.model_value2(state, action)

		# 取价值小的,出于稳定性考虑
		# [b, 1]
		value = torch.min(value1, value2)

		# alpha还原后乘以熵,这个值期望的是越大越好,但是这里是计算loss,所以符号取反
		# [1] - [b, 1] -> [b, 1]
		loss_action = -self.alpha.exp() * entropy
		
		# 减去value,所以value越大越好,这样loss就会越小
		loss_action -= value

		return loss_action.mean(), entropy

	def _get_loss_value(self, model_value, target, state, action, next_state):
		# 计算value
		value = model_value(state, action)
  
		# 计算loss, value, 贴近target
		loss_value = self.loss_fn(value, target)
		return loss_value

	def train(self, state, reward, action, next_state, over):
		# 对reward 偏移，便于训练
		reward = (reward + 8) / 8

		# 计算value和target, target已经考虑了动作和熵
		# [b, 1]
		target = self._get_target(reward, next_state, over)
		target = target.detach()
		
		# 计算两个value
		loss_value1 = self._get_loss_value(self.model_value1, target, state, action, next_state)
		loss_value2 = self._get_loss_value(self.model_value2, target, state, action, next_state)
		
		# 更新参数
		self.optimizer_value1.zero_grad()
		loss_value1.backward()
		self.optimizer_value1.step()
		
		self.optimizer_value2.zero_grad()
		loss_value2.backward()
		self.optimizer_value2.step()
		
		# 使用model_value计算model_action的loss, 更新参数
		loss_action, entropy = self._get_loss_action(state)
		self.optimizer_action.zero_grad()
		loss_action.backward()
		self.optimizer_action.step()
		
		# 熵乘以alpha就是alpha的loss
		# [b, 1] -> [1]
		loss_alpha = (entropy + 1).detach() * self.alpha.exp()
		loss_alpha = loss_alpha.mean()
		
		# 更新alpha值
		self.optimizer_alpha.zero_grad()
		loss_alpha.backward()
		self.optimizer_alpha.step()
		
		# 小比例更新
		self._soft_update(self.model_value1, self.model_value_next1)
		self._soft_update(self.model_value2, self.model_value_next2)
  
sac = SAC()

sac.train(
	torch.randn(5, 3),
	torch.randn(5, 1),
	torch.randn(5, 1),
	torch.randn(5, 3),
	torch.zeros(5, 1).long(),
)

sac.get_action([1, 2, 3])

-0.678047239780426

In [45]:
import numpy as np

class Pool:
  def __init__(self, limit):
    # 样本池
    self.datas = []
    self.limit = limit
  
	# 向样本池中添加数据
  def add(self, state, action, reward, next_state, over):
    if isinstance(state, np.ndarray) or isinstance(state, torch.Tensor):
      state = state.reshape(3).tolist()

    action = float(action)
    reward = float(reward)
    
    if isinstance(next_state, np.ndarray) or isinstance(next_state, torch.Tensor):
      next_state = next_state.reshape(3).tolist()

    over = bool(over)

    self.datas.append((state, action, reward, next_state, over))

    # 数据上限，超出时从最古老的开始删除
    while len(self.datas) > self.limit:
      self.datas.pop(0)

  # 获取一批数据样本
  def get_sample(self, size=None):
    if size is None:
      size = len(self)
      
    size = min(size, len(self))
    
    # 从样本池中采样
    samples = random.sample(self.datas, size)

    # [b, 3]
    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
    # [b, 1]
    action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
    # [b, 1]
    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
    # [b, 4]
    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)
    # [b, 1]
    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)

    return state, action, reward, next_state, over

  def __len__(self):
    return len(self.datas)

env_pool = Pool(100000)
model_pool = Pool(1000)

# 初始化一局游戏的数据
def _():
	# 初始化游戏
	state = env.reset()

	# 玩到游戏结束为止
	over = False
	while not over:
		# 根据当前状态得到一个动作
		action = sac.get_action(state)
  
		# 执行动作得到反馈
		next_state, reward, over, _ = env.step([action])
  
		# 记录数据样本
		env_pool.add(state, action, reward, next_state, over)
  
		# 更新游戏状态，开始下一个动作
		state = next_state
  

_()

len(env_pool), env_pool.datas[0], env_pool.get_sample(2)

(200,
 ([0.48654261231422424, 0.8736568689346313, 0.72088223695755],
  -1.7668370008468628,
  -1.1843527750986558,
  [0.4372809827327728, 0.8993249535560608, 1.1110992431640625],
  False),
 (tensor([[-0.1987,  0.9801, -7.1283],
          [ 0.9632,  0.2687, -2.3261]]),
  tensor([[-1.0721],
          [ 1.0418]]),
  tensor([[-8.2182],
          [-0.6162]]),
  tensor([[ 0.1273,  0.9919, -6.5541],
          [ 0.9850,  0.1728, -1.9683]]),
  tensor([[0],
          [0]])))

In [46]:
# 定义主模型
class Model(torch.nn.Module):
  
  # swish 激活函数
  class Swish(torch.nn.Module):
    def __init__(self):
      super().__init__()

    def forward(self, x):
      return x * torch.sigmoid(x)
    
  # 定义工具层
  class FCLayer(torch.nn.Module):
    def __init__(self, in_size, out_size):
      super().__init__()
      self.in_size = in_size
      
      # 初始化参数
      std = in_size ** 0.5
      std *= 2
      std = 1 / std
      
      weight = torch.empty(5, in_size, out_size)
      torch.nn.init.normal_(weight, mean=0.0, std=std)

      
      # [5, in, out]
      self.weight = torch.nn.Parameter(weight)
      
      # [5, 1, out]
      self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))
    
    def forward(self, x):
      # x -> [5, b, in]
      
      # [5, b, in] * [5, in, out] -> [5, b, out]
      x = torch.bmm(x, self.weight)
      
      # [5, b, out] + [5, 1, out] -> [5, b, out]
      x = x + self.bias
      
      return x
    
    
  def __init__(self):
    super().__init__()
    
    self.sequential = torch.nn.Sequential(
      self.FCLayer(4, 200),
      self.Swish(),
      self.FCLayer(200, 200),
      self.Swish(),
      self.FCLayer(200, 200),
      self.Swish(),
      self.FCLayer(200, 200),
      self.Swish(),
      self.FCLayer(200, 8),
      torch.nn.Identity(),
    )
    
    self.softplus = torch.nn.Softplus()
    self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        
  def forward(self, x):
    # x -> [5, b, 4]
    
    # [5, b, 4] -> [5, b, 8]
    x = self.sequential(x)
    
    # [5, b, 8] -> [5, b, 4]
    mean = x[..., :4]
    
    # [5, b, 8] -> [5, b, 4]
    logvar = x[..., 4:]
    
    # [1, 1, 4] - [5, b, 4] -> [5, b, 4]
    logvar = 0.5 - logvar
    
    # [1, 1, 4] - [5, b, 4] -> [5, b, 4]
    logvar = 0.5 - self.softplus(logvar)
    
    # [5, b, 4] - [1, 1, 4] -> [5, b, 4]
    logvar = logvar + 10
    
    # [5, b, 4] + [1, 1, 4] -> [5, b, 4]
    logvar = self.softplus(logvar) - 10
    
    #[5, b, 4],[5, b, 4]
    return mean, logvar
  
  def train(self):
    state, action, reward, next_state, _ = env_pool.get_sample()
    # input -> [b, 4]
    # label -> [b, 4]
    input = torch.cat([state, action], dim=1)
    label = torch.cat([reward, next_state - state], dim=1)
    
    # 反复训练N次
    for _ in range(len(input) // 64 * 20):
      #从全量数据中抽样64个,反复抽5遍,形成5份数据
      #[5, 64]
      select = [torch.randperm(len(input))[:64] for _ in range(5)]
      select = torch.stack(select)
      # [5, b, 4], [5, b, 4]
      input_select = input[select]
      label_select = label[select]
      del select
      
      # 模型计算
      # [5, b, 4] -> [5, b, 4], [5, b, 4]
      mean, logvar = model(input_select)
      
      # 计算loss
      # [b, 4] - [b, 4] * [b, 4] -> [b, 4]
      mse_loss = (mean - label_select) ** 2 * (-logvar).exp()
      
      # [b, 4] -> [b] -> scala
      mse_loss = mse_loss.mean(dim=1).mean()
      
      # [b, 4] -> [b] -> scala
      var_loss = logvar.mean(dim=1).mean()
      
      loss = mse_loss + var_loss
      
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      
model = Model()
a, b = model(torch.randn(5, 64, 4))
a.shape, b.shape

(torch.Size([5, 64, 4]), torch.Size([5, 64, 4]))

In [47]:
class MBPO():
  def _fake_step(self, state, action):
    state = torch.FloatTensor(state).reshape(-1, 3)
    action = torch.FloatTensor([action]).reshape(-1, 1)
    # state -> [b, 3]
    # action -> [b, 1]
    
    # [b, 4]
    input = torch.cat([state, action], dim=1)
    
    # 重复5遍
    # [b, 4] -> [1, b, 4] -> [5, b, 4]
    input = input.unsqueeze(dim=0).repeat([5, 1, 1])
    
    # 模型计算
    # [5, b, 4] -> [5, b, 4],[5, b, 4]
    with torch.no_grad():
      mean, std = model(input)
    std = std.exp().sqrt()
    del input
    
    # means的后三列加上环境数据
    mean[:, :, 1:] += state
    
    # 重采样
    # [5, b, 4]
    sample = torch.distributions.Normal(0, 1).sample(mean.shape)
    sample = mean + sample * std
    
    # 0-4采样b个元素
    # [4, 4, 2, 4, 3, 4, 1, 3, 3, 0, 2, ......]
    select = [random.choice(range(5)) for _ in range(mean.shape[1])]
    
    # 重采样结果，0d: 0-4随机选择，2d: 0-b随机选择
    # [5, b, 4] -> [b, 4]
    sample = sample[select, range(mean.shape[1])]
    
    # 切分
    reward, next_state = sample[:, :1], sample[:, 1:]

    return reward, next_state
    
  def rollout(self):
    states, _, _, _, _ = env_pool.get_sample(1000)
    for state in states:
      action = sac.get_action(state)
      reward, next_state = self._fake_step(state, action)
      
      model_pool.add(state, action, reward, next_state, False)
      state = next_state
  
  
mbpo = MBPO()
a, b, = mbpo._fake_step([1, 2, 3], 1)
print(a.shape, b.shape)
    

torch.Size([1, 1]) torch.Size([1, 3])


In [49]:
for i in range(20):
  reward_sum = 0
  state = env.reset()
  over = False
  
  step = 0
  while not over:
    # 每隔50个step, 训练一次模型
    if step % 50 == 0:
      model.train()
      mbpo.rollout()
      
    step += 1
  
    # 使用sac获取一个动作
    action = sac.get_action(state)

    # 执行动作得到反馈
    next_state, reward, over, _ = env.step([action])

    # 累加reward
    reward_sum += reward
    
    # 记录数据样本
    env_pool.add(state, action, reward, next_state, over)

    # 更新游戏状态，开始下一个动作
    state = next_state
    
    # 更新模型
    for _ in range(10):
      sample = []
      sample_env = env_pool.get_sample(32)
      sample_model = model_pool.get_sample(32)
      
      for (i1, i2) in zip(sample_env, sample_model):
        i3 = torch.cat([i1, i2], dim=0)
        sample.append(i3)
        
      sac.train(*sample)
      
  
  print(i, len(env_pool), len(model_pool), reward_sum)

0 4400 1000 -1501.6687599952068
1 4600 1000 -1648.8052691166113
2 4800 1000 -1658.6450614823418
3 5000 1000 -1423.7590339890028
4 5200 1000 -1601.9574772193978
5 5400 1000 -1497.145118584554
6 5600 1000 -1549.1217911468475
7 5800 1000 -1659.984712824094
8 6000 1000 -1493.9186326216739
9 6200 1000 -1664.9796748770852
10 6400 1000 -1439.3199169855939
11 6600 1000 -1494.0383525766676
12 6800 1000 -1680.307633952045
13 7000 1000 -1363.6668641944696
14 7200 1000 -1556.745940425559
15 7400 1000 -1634.7677005951002
16 7600 1000 -1622.1893034855896
17 7800 1000 -1598.9566343098154
18 8000 1000 -1624.661010674327
19 8200 1000 -1612.9941092380554
