In [1]:
import numpy as np
import scipy.signal
import gym, random, pickle, os.path, math, glob
from IPython.core.debugger import set_trace
from gym.wrappers import Monitor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.distributions import Categorical

import pdb

import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import clear_output

from atari_wrappers import make_atari, wrap_deepmind
from tensorboardX import SummaryWriter

USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
torch.cuda.set_device(7)


In [2]:
class soft_DQN(nn.Module):
    def __init__(self, in_channels=4, num_actions=5, REWARD_SCALE = 1):
        """
        Initialize a deep Q-learning network as described in
        https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
        Arguments:
            in_channels: number of channel of input.
                i.e The number of most recent frames stacked together as describe in the paper
            num_actions: number of action-value to output, one-to-one correspondence to action in game.
        """
        super(soft_DQN, self).__init__()        
        self.REWARD_SCALE = REWARD_SCALE
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc4 = nn.Linear(7 * 7 * 64, 512)
        self.fc5 = nn.Linear(512, num_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        return self.fc5(x)
    
    def get_action(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        action_probs = F.softmax(self.fc5(x)/self.REWARD_SCALE,-1)
        action_dist = Categorical(action_probs)
        actions = action_dist.sample().view(-1, 1)
        return actions
    
class error_net(nn.Module):
    def __init__(self, in_channels=4, num_actions=5):
        """
        Initialize a deep Q-learning network as described in
        https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
        Arguments:
            in_channels: number of channel of input.
                i.e The number of most recent frames stacked together as describe in the paper
            num_actions: number of action-value to output, one-to-one correspondence to action in game.
        """
        super(error_net, self).__init__()        
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc4 = nn.Linear(7 * 7 * 64, 512)
        self.fc5 = nn.Linear(512, num_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        return self.fc5(x)
    
def disable_gradients(network):
    for param in network.parameters():
        param.requires_grad = False


In [3]:
class Memory_Buffer(object):
    def __init__(self, memory_size=1000):
        self.buffer = []
        self.memory_size = memory_size
        self.next_idx = 0
        
    def push(self, state, action, reward, next_state, done):
        data = (state, action, reward, next_state, done)
        if len(self.buffer) <= self.memory_size: # buffer not full
            self.buffer.append(data)
        else: # buffer is full
            self.buffer[self.next_idx] = data
        self.next_idx = (self.next_idx + 1) % self.memory_size

    def sample(self, batch_size):
        states, actions, rewards, next_states, dones = [], [], [], [], []
        for i in range(batch_size):
            idx = random.randint(0, self.size() - 1)
            data = self.buffer[idx]
            state, action, reward, next_state, done= data
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)
            
            
        return np.concatenate(states), actions, rewards, np.concatenate(next_states), dones
    
    def size(self):
        return len(self.buffer)


In [4]:
class softDQN_DisCor_Agent: 
    def __init__(self, in_channels = 1, action_space = [], USE_CUDA = False, memory_size = 10000, lr = 1e-4, reward_scale = 1, tau_init = 10, gamma = 0.99):
        self.action_space = action_space
        self.memory_buffer = Memory_Buffer(memory_size)
        self.alpha = reward_scale
        self._tau = torch.tensor(tau_init, requires_grad=False)
        self._gamma = gamma
        self.DQN = soft_DQN(in_channels = in_channels, num_actions = action_space.n,REWARD_SCALE = reward_scale)
        self.DQN_target = soft_DQN(in_channels = in_channels, num_actions = action_space.n,REWARD_SCALE = reward_scale)
        self.DQN_target.load_state_dict(self.DQN.state_dict())
        # error net
        self.errnet = error_net(in_channels = in_channels, num_actions = action_space.n)
        self.errnet_target = error_net(in_channels = in_channels, num_actions = action_space.n)
        self.errnet_target.load_state_dict(self.errnet.state_dict())

        self.USE_CUDA = USE_CUDA
        if USE_CUDA:
            self.DQN = self.DQN.cuda()
            self.DQN_target = self.DQN_target.cuda()
            self.errnet = self.errnet.cuda()
            self.errnet_target = self.errnet_target.cuda()
            self._tau = self._tau.cuda()
        
        # disble target network gradient
        disable_gradients(self.errnet_target)        
        disable_gradients(self.DQN_target)
        # optimizer Adam
        self.optimizer = optim.Adam(list(self.DQN.parameters()) + list(self.errnet.parameters()),lr=lr)

    def observe(self, lazyframe):
        # from Lazy frame to tensor
        state =  torch.from_numpy(lazyframe._force().transpose(2,0,1)[None]/255).float()
        if self.USE_CUDA:
            state = state.cuda()
        return state

    def value(self, state):
        q_values = self.DQN(state)
        return q_values
    
    def act(self, state, t=0, explore_time=0):
        """
        random policy first, 
        then sample action according to softmax policy
        """
        if t < explore_time:
            action = self.action_space.sample()
        else:
            action = int(self.DQN.get_action(state).squeeze().cpu().detach().numpy())
        return action
        
    def compute_td_loss(self, states, actions, rewards, next_states, is_done, gamma=0.99):
        """ Compute td loss using torch operations only. Use the formula above. """
        actions = torch.tensor(actions).long()    # shape: [batch_size]
        rewards = torch.tensor(rewards, dtype =torch.float)  # shape: [batch_size]
        is_done = torch.tensor(is_done).type(torch.bool)  # shape: [batch_size]
        
        if self.USE_CUDA:
            actions = actions.cuda()
            rewards = rewards.cuda()
            is_done = is_done.cuda()
        
        imp_ws = self.calc_importance_weights(next_states, is_done)

        # get q-values for all actions in current states
        predicted_qvalues = self.DQN(states)

        # select q-values for chosen actions
        predicted_qvalues_for_actions = predicted_qvalues[
          range(states.shape[0]), actions
        ]

        # compute q-values for all actions in next states
        predicted_next_qvalues = self.DQN_target(next_states) # YOUR CODE
        # compute V*(next_states) using predicted next q-values
        next_state_values =  self.alpha*torch.logsumexp(predicted_next_qvalues/self.alpha, dim = -1) # YOUR CODE        

        # compute "target q-values" for loss - it's what's inside square parentheses in the above formula.
        target_qvalues_for_actions = rewards + gamma*next_state_values # YOUR CODE
        # at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist
        target_qvalues_for_actions = torch.where(
            is_done, rewards, target_qvalues_for_actions)
        # pdb.set_trace()
        # mean squared error loss to minimize with importance sampling from error model
        q_loss = (F.smooth_l1_loss(predicted_qvalues_for_actions, target_qvalues_for_actions.detach(), reduction='none')*imp_ws).sum()
        
        ## Calculate current and target errors, as well as importance weights.
        curr_errs_all = self.errnet(states)
        curr_errs = curr_errs_all[range(states.shape[0]), actions]
        
        target_errs = self.calc_target_errors(
            next_states, is_done, predicted_qvalues_for_actions, target_qvalues_for_actions)

        # calculate error loss
        err_loss = torch.mean((curr_errs - target_errs).pow(2))
        # total loss
        loss = q_loss + err_loss
        # update tau
        self._tau = (1.0 - 0.005)*self._tau + 0.005*curr_errs.detach().mean()
        logger.store(QVals = predicted_qvalues.detach().cpu().numpy(), LossQ = q_loss.item(), LossE = err_loss.item(), 
                     Current_errs = curr_errs.detach().cpu().numpy(), tau = self._tau.detach().cpu().numpy())
        return loss

    def calc_importance_weights(self, next_states, is_done):
        with torch.no_grad():
            next_actions = self.DQN_target.get_action(next_states)
            next_errs_all = self.errnet_target(next_states)
            next_errs = next_errs_all[range(next_states.shape[0]), next_actions.squeeze()]
        # Terms inside the exponent of importance weights.
        x = -(1.0 - is_done.float()) * self._gamma * next_errs / self._tau
        # Calculate self-normalized importance weights.
        imp_ws = F.softmax(x, dim=0)
        return imp_ws
    
    def calc_target_errors(self, next_states, dones, curr_qs, target_qs):
        # Calculate targets of the cumulative sum of discounted Bellman errors,
        # which is 'Delta' in the paper.
        with torch.no_grad():
            next_actions = self.DQN_target.get_action(next_states)
            next_errs_all = self.errnet_target(next_states)
            next_errs = next_errs_all[range(next_states.shape[0]), next_actions.squeeze()]

            target_errs = (curr_qs - target_qs).abs() + \
                (1.0 - dones.float()) * self._gamma * next_errs

        return target_errs

    def sample_from_buffer(self, batch_size):
        states, actions, rewards, next_states, dones = [], [], [], [], []
        for i in range(batch_size):
            idx = random.randint(0, self.memory_buffer.size() - 1)
            data = self.memory_buffer.buffer[idx]
            frame, action, reward, next_frame, done= data
            states.append(self.observe(frame))
            actions.append(action)
            rewards.append(reward)
            next_states.append(self.observe(next_frame))
            dones.append(done)
        return torch.cat(states), actions, rewards, torch.cat(next_states), dones
    
    def learn_from_experience(self, batch_size):
        if self.memory_buffer.size() > batch_size:
            states, actions, rewards, next_states, dones = self.sample_from_buffer(batch_size)
            td_loss = self.compute_td_loss(states, actions, rewards, next_states, dones)
            self.optimizer.zero_grad()
            td_loss.backward()
            for param in self.DQN.parameters():
                param.grad.data.clamp_(-1, 1)

            self.optimizer.step()
            return(td_loss.item())
        else:
            return(0)


In [None]:
from run_utils import setup_logger_kwargs
import itertools
import time
from logx import EpochLogger
import pdb

# Training DQN in PongNoFrameskip-v4 
# pdb.set_trace()
env_id = 'SeaquestNoFrameskip-v4'
experiment_name = "softDQN_DisCor_" + env_id
logger_kwargs = setup_logger_kwargs(experiment_name, 0)
logger = EpochLogger(**logger_kwargs)

experiment_dir = os.path.abspath(experiment_name)
monitor_path = os.path.join(experiment_dir, "monitor")
eval_monitor_path = os.path.join(experiment_dir, "eval_monitor")

log_dir = os.path.join(experiment_dir, "log")
model_path = os.path.join(experiment_dir, experiment_name+"_dict.pth.tar")
checkpoint_path = os.path.join(experiment_dir, "check_point")
env = make_atari(env_id)
env = wrap_deepmind(env, scale = False, frame_stack=True , clip_rewards= False, episode_life = True)
env_eval = make_atari(env_id)
env_eval = wrap_deepmind(env_eval, scale = False, frame_stack=True , clip_rewards= False, episode_life = False)

gamma = 0.99
steps_per_epoch = 100000
epochs = 100 # 1000
frames = steps_per_epoch * epochs# 10000000 timestamp/ 
USE_CUDA = True
learning_rate = 1e-4
max_buff = 1000000
prio_a = 0.6
prio_beta = 0.4
tau_init = 10 # initial parameter for error net

update_tar_interval = 10000
batch_size = 32
learning_start = 50000 # 50000
update_current_step = 4 # update current model every 4 steps
beta_increment_per_sampling = (1-prio_beta)/(frames/update_current_step)
record_video = True
record_video_every = 500 # video every 1000
eval_every = steps_per_epoch
save_freq = 1
num_test_episodes = 8

action_space = env.action_space
action_dim = env.action_space.n
state_dim = env.observation_space.shape[0]
state_channel = env.observation_space.shape[2]
reward_scale = 0.05
# logger.save_config(locals())
agent = softDQN_DisCor_Agent(in_channels = state_channel, action_space= action_space, USE_CUDA = USE_CUDA, lr = learning_rate, memory_size = max_buff, reward_scale = reward_scale, 
                     tau_init = tau_init, gamma = gamma)
# Set up model saving
logger.setup_pytorch_saver(agent.DQN)

def eval_episode(agent, env_eval):
    with torch.no_grad():
        for j in range(num_test_episodes):       
            frame, done, ep_ret, ep_len  = env_eval.reset(), False,0,0
            while not done:
                state_tensor = agent.observe(frame)
                action = agent.act(state_tensor, 0,0)
                frame, reward, done, _ = env_eval.step(action)
                ep_ret += reward
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)


if record_video:
    env = Monitor(env,
                 directory=monitor_path,
                 resume=True, mode = "training",
                 video_callable=lambda count: count % record_video_every == 0)
    env_eval = Monitor(env_eval,
                 directory=eval_monitor_path,
                 resume=True,
                 video_callable=lambda count: count % num_test_episodes == 0,
                      mode = "evaluation")

frame, ep_ret, ep_num,ep_len = env.reset(),0, 0,0
loss = 0

# tensorboard
summary_writer = SummaryWriter(log_dir = log_dir, comment= "good_makeatari")


start_time = time.time()
for i in range(frames):
    state_tensor = agent.observe(frame)
    action = agent.act(state_tensor, i, learning_start)
    
    next_frame, reward, done, _ = env.step(action)
    
    ep_ret += reward
    ep_len += 1
    agent.memory_buffer.push(frame, action, np.sign(reward), next_frame, done) # !! Clip reward by its sign
    frame = next_frame
    
    if agent.memory_buffer.size() >= learning_start:
        if i % update_current_step == 0:
            loss = agent.learn_from_experience(batch_size)
         
    if i % update_tar_interval == 0:
        agent.DQN_target.load_state_dict(agent.DQN.state_dict())
        agent.errnet_target.load_state_dict(agent.errnet.state_dict())
    
    if done:
        logger.store(EpRet=ep_ret, EpLen=ep_len)
        frame, ep_ret, ep_len = env.reset(), 0, 0
    
    if (i+1) % steps_per_epoch == 0:  
        epoch = (i+1) // steps_per_epoch

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs):
            logger.save_state({'env': env}, None)

        # Test the performance of the deterministic version of the agent.
        eval_episode(agent, env_eval)
        
        # Write to tensorboard
        summary_writer.add_scalar("Train Reward",logger.get_stats("EpRet")[0], i)
        summary_writer.add_scalar("Test Reward",logger.get_stats("TestEpRet")[0], i)
        summary_writer.add_scalar("Loss Q",logger.get_stats("LossQ")[0], i)
        summary_writer.add_scalar("Loss Error",logger.get_stats("LossE")[0], i)
        summary_writer.add_scalar("Train EpLen",logger.get_stats("EpLen")[0], i)
        summary_writer.add_scalar("Test EpLen",logger.get_stats("TestEpLen")[0], i)
        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('TestEpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('TestEpLen', average_only=True)
        logger.log_tabular('TotalEnvInteracts', i)
        logger.log_tabular('QVals', with_min_and_max=True)
        logger.log_tabular('Current_errs', average_only=True)
        logger.log_tabular('LossQ', average_only=True)
        logger.log_tabular('LossE', average_only=True)
        logger.log_tabular('tau', average_only=True)
        logger.log_tabular('Time', time.time()-start_time)


        logger.dump_tabular()
        
summary_writer.close()
torch.save(agent.DQN.state_dict(), model_path)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


[32;1mLogging data to /home/liang/RL/data/softDQN_DisCor_SeaquestNoFrameskip-v4/softDQN_DisCor_SeaquestNoFrameskip-v4_s0/progress.txt[0m
---------------------------------------
|             Epoch |               1 |
|      AverageEpRet |            22.5 |
|          StdEpRet |            34.7 |
|          MaxEpRet |             220 |
|          MinEpRet |               0 |
|  AverageTestEpRet |            52.5 |
|      StdTestEpRet |            28.2 |
|      MaxTestEpRet |              80 |
|      MinTestEpRet |               0 |
|             EpLen |             132 |
|         TestEpLen |             471 |
| TotalEnvInteracts |           1e+05 |
|      AverageQVals |            8.08 |
|          StdQVals |            32.4 |
|          MaxQVals |            1.34 |
|          MinQVals |         -0.0377 |
|      Current_errs |          0.0972 |
|             LossQ |         0.00401 |
|             LossE |         0.00732 |
|               tau |          0.0945 |
|              Time |

---------------------------------------
|             Epoch |              10 |
|      AverageEpRet |             113 |
|          StdEpRet |            62.5 |
|          MaxEpRet |             260 |
|          MinEpRet |               0 |
|  AverageTestEpRet |             395 |
|      StdTestEpRet |             143 |
|      MaxTestEpRet |             560 |
|      MinTestEpRet |             140 |
|             EpLen |             351 |
|         TestEpLen |        1.31e+03 |
| TotalEnvInteracts |           1e+06 |
|      AverageQVals |             166 |
|          StdQVals |             665 |
|          MaxQVals |            12.5 |
|          MinQVals |          -0.932 |
|      Current_errs |            5.52 |
|             LossQ |          0.0419 |
|             LossE |           0.154 |
|               tau |            5.51 |
|              Time |        2.73e+04 |
---------------------------------------
---------------------------------------
|             Epoch |              11 |


---------------------------------------
|             Epoch |              19 |
|      AverageEpRet |             201 |
|          StdEpRet |            83.1 |
|          MaxEpRet |             340 |
|          MinEpRet |              20 |
|  AverageTestEpRet |        1.06e+03 |
|      StdTestEpRet |             117 |
|      MaxTestEpRet |        1.26e+03 |
|      MinTestEpRet |             860 |
|             EpLen |             474 |
|         TestEpLen |        2.11e+03 |
| TotalEnvInteracts |         1.9e+06 |
|      AverageQVals |             231 |
|          StdQVals |             927 |
|          MaxQVals |              16 |
|          MinQVals |           0.185 |
|      Current_errs |            9.01 |
|             LossQ |          0.0342 |
|             LossE |           0.178 |
|               tau |            9.01 |
|              Time |        5.32e+04 |
---------------------------------------
---------------------------------------
|             Epoch |              20 |


---------------------------------------
|             Epoch |              28 |
|      AverageEpRet |             282 |
|          StdEpRet |            95.9 |
|          MaxEpRet |             400 |
|          MinEpRet |              20 |
|  AverageTestEpRet |        1.18e+03 |
|      StdTestEpRet |             167 |
|      MaxTestEpRet |         1.4e+03 |
|      MinTestEpRet |             960 |
|             EpLen |             490 |
|         TestEpLen |        1.96e+03 |
| TotalEnvInteracts |         2.8e+06 |
|      AverageQVals |             250 |
|          StdQVals |           1e+03 |
|          MaxQVals |            17.7 |
|          MinQVals |           0.131 |
|      Current_errs |            8.87 |
|             LossQ |          0.0216 |
|             LossE |          0.0787 |
|               tau |            8.87 |
|              Time |        7.86e+04 |
---------------------------------------
---------------------------------------
|             Epoch |              29 |


---------------------------------------
|             Epoch |              37 |
|      AverageEpRet |             326 |
|          StdEpRet |            95.4 |
|          MaxEpRet |             420 |
|          MinEpRet |              20 |
|  AverageTestEpRet |        1.06e+03 |
|      StdTestEpRet |             174 |
|      MaxTestEpRet |        1.34e+03 |
|      MinTestEpRet |             700 |
|             EpLen |             514 |
|         TestEpLen |        1.83e+03 |
| TotalEnvInteracts |         3.7e+06 |
|      AverageQVals |             256 |
|          StdQVals |        1.03e+03 |
|          MaxQVals |            18.5 |
|          MinQVals |           0.126 |
|      Current_errs |            8.61 |
|             LossQ |          0.0164 |
|             LossE |          0.0735 |
|               tau |            8.61 |
|              Time |        1.07e+05 |
---------------------------------------
---------------------------------------
|             Epoch |              38 |


In [31]:
agent.DQN(state_tensor).cpu().detach().numpy().argmax(1)[0]

2