<a href="https://colab.research.google.com/github/JFHwang/deepmind-research/blob/master/rl_unplugged/atari_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

# RL Unplugged: Offline DQN - Atari
## Guide to training an Acme DQN agent on Atari data.
# <a href="https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/atari_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>




## Installation

In [None]:
!pip install dm-acme==0.2.0
!pip install dm-acme[reverb]==0.2.0
!pip install dm-acme[tf]==0.2.0
!pip install dm-sonnet
!pip install dopamine-rl==3.1.2
!pip install atari-py
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research

In [None]:
!wget http://www.atarimania.com/roms/Roms.rar
!mkdir roms
!unrar e Roms.rar roms
!python -m atari_py.import_roms roms/

## Imports

In [None]:
%cd /content/deepmind-research
import copy

import acme
import rl_unplugged
from acme.agents.tf import actors
from acme.agents.tf.dqn import learning as dqn
from acme.tf import utils as acme_utils
from acme.utils import loggers
from rl_unplugged import atari
import sonnet as snt
import tensorflow as tf

/content/deepmind-research




## Setting Parameters




In [None]:
game = 'BattleZone' #@param
shardcount = 100#@param
run =   1#@param
tmp_path = '/tmp/dataset'
gs_path = 'gs://rl_unplugged/atari'
!mkdir -p {tmp_path}/{game}
environment = atari.environment(game=game)

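## Data

Copy the RL Unplugged Atari shards for the selected `game` and `run` from the public `gs://rl_unplugged/atari` bucket to local disk. Set `shardcount` below 100 to copy only the first `shardcount` shards when disk space is limited.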

In [None]:
if shardcount < 100:

  # If not enough disk space, only copy subset of samples
  for i in range(shardcount/10):
    !gsutil -m cp -R gs://rl_unplugged/atari/{game}/run_{run}-00{i}* /tmp/dataset/{game}

  for shard in range(shardcount):
    fshard = "{:05d}".format(shard)
    fshardcount = "{:05d}".format(shardcount)
    !mv /tmp/dataset/{game}/run_{run}-{fshard}-of-00100 /tmp/dataset/{game}/run_{run}-{fshard}-of-{fshardcount}

else:
  # Copy all 100 shards
  !gsutil -m cp -R gs://rl_unplugged/atari/{game}/run_{run}* /tmp/dataset/{game}

## Create Agent

In [None]:
# Get total number of actions.
num_actions = environment.action_spec().num_values

# Create the Q network.
network = snt.Sequential([
    lambda x: tf.image.convert_image_dtype(x, tf.float32),
    snt.Conv2D(32, [8, 8], [4, 4]),
    tf.nn.relu,
    snt.Conv2D(64, [4, 4], [2, 2]),
    tf.nn.relu,
    snt.Conv2D(64, [3, 3], [1, 1]),
    tf.nn.relu,
    snt.Flatten(),
    snt.nets.MLP([512, num_actions])
])
acme_utils.create_variables(network, [environment.observation_spec()])

TensorSpec(shape=(18,), dtype=tf.float32, name=None)

## Dataset

Load the copied shards as a `tf.data` pipeline of batched transitions for the learner.

In [None]:
batch_size = 32#@param

def discard_extras(sample):
  """Returns `sample` with `data` truncated to its first five entries.

  Any fields after the fifth (extras) are dropped. The five kept fields are
  presumably (observation, action, reward, discount, next_observation) --
  confirm against rl_unplugged.atari.
  """
  kept_fields = sample.data[:5]
  return sample._replace(data=kept_fields)

# Read the local shards, trim each sample, and group into batches of
# `batch_size` samples.
raw_dataset = atari.dataset(path=tmp_path, game=game, run=run, num_shards=shardcount)
dataset = raw_dataset.map(discard_extras).batch(batch_size)

## DQN learner

In [None]:
# Create a logger.
logger = loggers.TerminalLogger(label='learner', time_delta=60.)

# Create the DQN learner.
learner = dqn.DQNLearner(
    network=network,
    target_network=copy.deepcopy(network),
    discount=0.99,
    learning_rate=25e-5,
    importance_sampling_exponent=0.2,
    target_update_period=8000,
    dataset=dataset,
    logger=logger)

## Training and Eval Loop

In [None]:
# Outer iterations, learner gradient steps per iteration, and evaluation
# episodes run after each iteration.
iterations = 20#@param
num_learner_steps = 250000#@param
num_eval_episodes = 20#@param

# Logger for evaluation episodes, writing at most once per second.
logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)

# Greedy policy: pick the argmax-Q action. The learner was built with this same
# `network` object, so the policy sees the learner's latest weights without
# being rebuilt.
# The policy, actor and environment loop are loop-invariant, so they are created
# once here rather than every iteration; rebuilding them forced the actor's
# policy function to be retraced each time. As a consequence, the episode/step
# counts in the evaluation log accumulate across iterations.
policy_network = snt.Sequential([
    network,
    lambda q: tf.argmax(q, axis=-1),
])
loop = acme.EnvironmentLoop(
    environment=environment,
    actor=actors.FeedForwardActor(policy_network=policy_network),
    logger=logger)

for iteration in range(iterations):
  # Offline training: gradient steps on the fixed dataset only.
  for _ in range(num_learner_steps):
    learner.step()

  # Online evaluation of the current greedy policy.
  loop.run(num_eval_episodes)