Skip to content

Commit

Permalink
Expand to d4rl-pybullet (#416)
Browse files Browse the repository at this point in the history
* Expand to d4rl-pybullet

Support d4rl-pybullet and make a few other changes.

* Update NEWS.md and ci.yml

Co-authored-by: Jun Tian <find_my_way@foxmail.com>
  • Loading branch information
Mobius1D and findmyway committed Jul 28, 2021
1 parent 625304f commit 6e60a19
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .cspell/cspell.json
Expand Up @@ -97,7 +97,8 @@
"Thibaut",
"boxoban",
"DATADEPS",
"umaze"
"umaze",
"pybullet"
],
"ignoreWords": [],
"minWordLength": 5,
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Expand Up @@ -19,7 +19,8 @@

#### v0.1.0

- Add functionality for fetching d4rl datasets as an iterable D4RLDataSet. Credits: https://arxiv.org/abs/2004.07219
- Add functionality for fetching d4rl datasets as an iterable DataSet. Credits: https://arxiv.org/abs/2004.07219
- This supports d4rl and d4rl-pybullet datasets.
- Uses DataDeps for data dependency management.

## ReinforcementLearning.jl@v0.9.0
Expand Down
4 changes: 3 additions & 1 deletion src/ReinforcementLearningDatasets/README.md
Expand Up @@ -2,6 +2,8 @@

A package to create, manage, store and retrieve datasets for Offline Reinforcement Learning using ReinforcementLearning.jl package.

- This package uses DataDeps.jl to fetch and manage datasets.

### Note:

The package is under active development and for now it supports only d4rl datasets.
The package is under active development and for now it supports d4rl and d4rl-pybullet datasets.
Expand Up @@ -6,6 +6,8 @@ export RLDatasets
using DataDeps

include("d4rl/register.jl")
include("d4rl/d4rl_dataset.jl")
include("d4rl_pybullet/register.jl")
include("init.jl")
include("dataset.jl")

end
26 changes: 13 additions & 13 deletions src/ReinforcementLearningDatasets/src/d4rl/register.jl
@@ -1,13 +1,13 @@
export DATASET_URLS
export REF_MAX_SCORE
export REF_MIN_SCORE
export D4RL_DATASET_URLS
export D4RL_REF_MAX_SCORE
export D4RL_REF_MIN_SCORE

"""
This file holds the registration information for d4rl datasets.
It also registers the information in DataDeps for further use in this package.
"""

const DATASET_URLS = Dict{String, String}(
const D4RL_DATASET_URLS = Dict{String, String}(
"maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5",
"maze2d-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5",
"maze2d-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5",
Expand Down Expand Up @@ -100,7 +100,7 @@ const DATASET_URLS = Dict{String, String}(
)


const REF_MIN_SCORE = Dict{String, Float32}(
const D4RL_REF_MIN_SCORE = Dict{String, Float32}(
"maze2d-open-v0" => 0.01 ,
"maze2d-umaze-v1" => 23.85 ,
"maze2d-medium-v1" => 13.13 ,
Expand Down Expand Up @@ -184,7 +184,7 @@ const REF_MIN_SCORE = Dict{String, Float32}(
"bullet-maze2d-large-v0"=> 1.820000,
)

const REF_MAX_SCORE = Dict{String, Float32}(
const D4RL_REF_MAX_SCORE = Dict{String, Float32}(
"maze2d-open-v0" => 20.66 ,
"maze2d-umaze-v1" => 161.86 ,
"maze2d-medium-v1" => 277.39 ,
Expand Down Expand Up @@ -269,24 +269,24 @@ const REF_MAX_SCORE = Dict{String, Float32}(
)

# give a prompt for flow and carla tasks
# add checksums

function __init__()
for ds in keys(DATASET_URLS)
function d4rl_init()
repo = "d4rl"
for ds in keys(D4RL_DATASET_URLS)
register(
DataDep(
"d4rl-" * ds,
repo*"-"* ds,
"""
Credits: https://arxiv.org/abs/2004.07219
The following dataset is fetched from the d4rl.
The dataset is fetched and modified in a form that is useful for RL.jl package.
Dataset information:
Name: $(ds)
$(if ds in keys(REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(REF_MAX_SCORE[ds]) end)
$(if ds in keys(REF_MIN_SCORE) "MINIMUM_SCORE: " * string(REF_MIN_SCORE[ds]) end)
$(if ds in keys(D4RL_REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(D4RL_REF_MAX_SCORE[ds]) end)
$(if ds in keys(D4RL_REF_MIN_SCORE) "MINIMUM_SCORE: " * string(D4RL_REF_MIN_SCORE[ds]) end)
""", #check if the MAX and MIN score part is even necessary and make the log file prettier
DATASET_URLS[ds],
D4RL_DATASET_URLS[ds],
)
)
end
Expand Down
33 changes: 33 additions & 0 deletions src/ReinforcementLearningDatasets/src/d4rl_pybullet/register.jl
@@ -0,0 +1,33 @@
export D4RL_PYBULLET_URLS

"""
    D4RL_PYBULLET_URLS

Lookup table from d4rl-pybullet dataset name to the URL of its HDF5 file.
"""
const D4RL_PYBULLET_URLS = Dict{String, String}(
    "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1",
    "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1",
    "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1",
    "walker2d-bullet-mixed-v0" => "https://www.dropbox.com/s/i4u2ii0d85iblou/walker2d-bullet-mixed-v0.hdf5?dl=1",
    "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1",
    "halfcheetah-bullet-random-v0" => "https://www.dropbox.com/s/jnvpb1hp60zt2ak/halfcheetah-bullet-random-v0.hdf5?dl=1",
    "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1",
    "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1",
    "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1",
    "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1",
    "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1",
    "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1",
)

"""
    d4rl_pybullet_init()

Register one `DataDep` per entry of `D4RL_PYBULLET_URLS`, named
`"d4rl-pybullet-<dataset name>"` and pointing at the entry's URL.
Returns `nothing`.
"""
function d4rl_pybullet_init()
    prefix = "d4rl-pybullet"
    # The description is the same for every dataset, so build it once.
    description = """
    Credits: https://github.com/takuseno/d4rl-pybullet
    The following dataset is fetched from the d4rl-pybullet.
    """
    for (name, url) in D4RL_PYBULLET_URLS
        register(DataDep(string(prefix, "-", name), description, url))
    end
    return nothing
end
Expand Up @@ -7,24 +7,26 @@ import Base: iterate, length, IteratorEltype
export dataset
export SARTS
export SART
export D4RLDataSet
export DataSet

const SARTS = (:state, :action, :reward, :terminals, :next_state)
const SART = (:state, :action, :reward, :terminals)
const SARTS = (:state, :action, :reward, :terminal, :next_state)
const SART = (:state, :action, :reward, :terminal)

"""
Represents a iterable dataset from d4rl with the following fields:
Represents an iterable dataset with the following fields:
`dataset`: Dict{Symbol, Any}, representation of the dataset as a Dictionary with style as `style`
`repo`: String, the repository from which the dataset is taken
`size`: Integer, the size of the dataset
`batch_size`: Integer, the size of the batches returned by `iterate`.
`style`: Tuple, the type of the NamedTuple, for now SARTS and SART is supported.
`rng`<: AbstractRNG.
`meta`: Dict, the metadata provided along with the dataset
`is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled.
"""
struct D4RLDataSet{T<:AbstractRNG}
struct DataSet{T<:AbstractRNG}
dataset::Dict{Symbol, Any}
repo::String
size::Integer
batch_size::Integer
style::Tuple
Expand All @@ -39,11 +41,12 @@ end
"""
dataset(dataset::String; style::Tuple, rng<:AbstractRNG, is_shuffle::Bool, max_iters::Int64, batch_size::Int64)
Creates a dataset of enclosed in a D4RLDataSet type and other related metadata for the `dataset` that is passed.
The dataset type is an iterable that fetches batches when used in a for loop for convenience during offline training.
Creates a dataset enclosed in a `DataSet` type, together with related metadata, for the `dataset` that is passed.
The `DataSet` type is an iterable that fetches batches when used in a for loop for convenience during offline training.
`dataset`: Name of the D4RLDataSet dataset.
`style`: the style of the iterator and the Dict inside D4RLDataSet that is returned.
`dataset`: String, name of the dataset.
`repo`: Name of the repository of the dataset.
`style`: the style of the iterator and the Dict inside DataSet that is returned.
`rng`: StableRNG
`max_iters`: maximum number of iterations for the iterator.
`is_shuffle`: whether the dataset is shuffled or not. `true` by default.
Expand All @@ -52,19 +55,20 @@ The dataset type is an iterable that fetches batches when used in a for loop for
The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset.
"""
function dataset(dataset::String;
style=SARTS,
style=SARTS,
repo = "d4rl",
rng = StableRNG(123),
is_shuffle = true,
batch_size=256
)

try
@datadep_str "d4rl-"*dataset
@datadep_str repo*"-"*dataset
catch
throw("The provided dataset is not available")
end

path = @datadep_str "d4rl-"*dataset
path = @datadep_str repo*"-"*dataset

@assert length(readdir(path)) == 1
file_name = readdir(path)[1]
Expand All @@ -79,7 +83,7 @@ function dataset(dataset::String;
dataset = Dict{Symbol, Any}()
meta = Dict{String, Any}()

N_samples = size(data["terminals"])[1]
N_samples = size(data["observations"])[2]

for (key, d_key) in zip(["observations", "actions", "rewards", "terminals"], Symbol.(["state", "action", "reward", "terminal"]))
dataset[d_key] = data[key]
Expand All @@ -91,11 +95,11 @@ function dataset(dataset::String;
end
end

return D4RLDataSet(dataset, N_samples, batch_size, style, rng, meta, is_shuffle)
return DataSet(dataset, repo, N_samples, batch_size, style, rng, meta, is_shuffle)

end

function iterate(ds::D4RLDataSet, state = 0)
function iterate(ds::DataSet, state = 0)
rng = ds.rng
batch_size = ds.batch_size
size = ds.size
Expand Down Expand Up @@ -127,16 +131,16 @@ function iterate(ds::D4RLDataSet, state = 0)
end


take(ds::D4RLDataSet, n::Integer) = take(ds.dataset, n)
length(ds::D4RLDataSet) = ds.size
IteratorEltype(::Type{D4RLDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit)
# Forwards `take` to the wrapped dictionary.
# NOTE(review): no `take` method accepting a `Dict` is visible in this view — confirm
# which `take` this is meant to extend (e.g. `Base.Iterators.take`).
take(ds::DataSet, n::Integer) = take(ds.dataset, n)

# `length` reports the value stored in the `size` field of the dataset.
length(ds::DataSet) = ds.size

# `DataSet` is parametric, so the trait must be declared on `Type{<:DataSet}`:
# `Type{DataSet}` only matches the bare `UnionAll` and never a concrete `DataSet{T}`,
# which meant the trait was silently ignored. `EltypeUnknown` is not exported from
# Base, hence the qualified name.
IteratorEltype(::Type{<:DataSet}) = Base.EltypeUnknown() # see if eltype can be known (not sure about carla and adroit)


"""
    verify(data::Dict{String, Any})

Sanity-check a raw dataset dictionary.

Asserts that the keys `"observations"`, `"actions"`, `"rewards"` and `"terminals"`
are all present, and that `"rewards"` and `"terminals"` have size `(N,)` or
`(1, N)`, where `N` is the second dimension of `"observations"`. Both layouts
are accepted so that a `(1, N)` row layout passes as well as a flat `(N,)` one.

Returns `nothing`; throws an `AssertionError` when a check fails.
"""
function verify(data::Dict{String, Any})
    for key in ("observations", "actions", "rewards", "terminals")
        @assert haskey(data, key) "Expected keys not present in data"
    end
    N_samples = size(data["observations"])[2]
    # Only the relaxed checks are kept; a strict `(N_samples,)` check would reject
    # the `(1, N_samples)` layout before the relaxed one is reached.
    @assert size(data["rewards"]) == (N_samples,) || size(data["rewards"]) == (1, N_samples)
    @assert size(data["terminals"]) == (N_samples,) || size(data["terminals"]) == (1, N_samples)
    return nothing
end
4 changes: 4 additions & 0 deletions src/ReinforcementLearningDatasets/src/init.jl
@@ -0,0 +1,4 @@
"""
    __init__()

Module initializer: runs the d4rl registration step followed by the
d4rl-pybullet registration step.
"""
function __init__()
    RLDatasets.d4rl_init()
    RLDatasets.d4rl_pybullet_init()
    return nothing
end
32 changes: 32 additions & 0 deletions src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl
@@ -0,0 +1,32 @@
# Smoke test for the d4rl-pybullet integration: loads one dataset and checks the
# array shapes stored in the `DataSet` and the shapes of the batches it yields.
@testset "d4rl_pybullet" begin
    # Local fixtures, so this file does not rely on globals assigned by another
    # test file that happens to be included first.
    # NOTE(review): `SARTS` and `StableRNG` are assumed to be loaded by runtests.jl — confirm.
    batch_size = 256
    style = SARTS
    rng = StableRNG(123)

    ds = dataset(
        "hopper-bullet-mixed-v0";
        repo = "d4rl-pybullet",
        style = style,
        rng = rng,
        is_shuffle = true,
        batch_size = batch_size,
    )

    # Expected dimensions for this dataset.
    n_s = 15
    n_a = 3
    N_samples = 59345

    data_dict = ds.dataset

    @test size(data_dict[:state]) == (n_s, N_samples)
    @test size(data_dict[:action]) == (n_a, N_samples)
    @test size(data_dict[:reward]) == (1, N_samples)
    @test size(data_dict[:terminal]) == (1, N_samples)

    for sample in Iterators.take(ds, 3)
        @test typeof(sample) <: NamedTuple{SARTS}
        @test size(sample[:state]) == (n_s, batch_size)
        @test size(sample[:action]) == (n_a, batch_size)
        # Batched rewards/terminals may come back as a row matrix or a flat vector.
        @test size(sample[:reward]) == (1, batch_size) || size(sample[:reward]) == (batch_size,)
        @test size(sample[:terminal]) == (1, batch_size) || size(sample[:terminal]) == (batch_size,)
    end
end
@@ -1,32 +1,30 @@
n_s = 11
n_a = 3
N_samples = 200919
batch_size = 256
style = SARTS
rng = StableRNG(123)

@testset "dataset_d4rl_shuffle" begin
# TO-DO make functions to make tests modular and more widely applicable
@testset "dataset_shuffle" begin
ds = dataset(
"hopper-medium-replay-v0";
repo="d4rl",
style = style,
rng = rng,
is_shuffle = true,
batch_size = batch_size
)

data_dict = ds.dataset
N_samples = size(data_dict[:state])[2]

@test size(data_dict[:state]) == (n_s, N_samples)
@test size(data_dict[:action]) == (n_a, N_samples)
@test size(data_dict[:reward]) == (N_samples,)
@test size(data_dict[:terminal]) == (N_samples,)

i = 1

for sample in ds
if i > 5 break end
for sample in Iterators.take(ds, 3)
@test typeof(sample) <: NamedTuple
i += 1
end

sample1 = iterate(ds)
Expand All @@ -42,7 +40,7 @@ rng = StableRNG(123)
@test length(iters) == 2

for iter in iters
@test typeof(iter) <: NamedTuple
@test typeof(iter) <: NamedTuple{SARTS}
end

@test iter1 != iter2
Expand All @@ -54,7 +52,7 @@ rng = StableRNG(123)

end

@testset "dataset_d4rl" begin
@testset "dataset" begin
ds = dataset(
"hopper-medium-replay-v0";
style = style,
Expand All @@ -63,20 +61,16 @@ end
batch_size = batch_size
)


data_dict = ds.dataset
N_samples = size(data_dict[:state])[2]

@test size(data_dict[:state]) == (n_s, N_samples)
@test size(data_dict[:action]) == (n_a, N_samples)
@test size(data_dict[:reward]) == (N_samples,)
@test size(data_dict[:terminal]) == (N_samples,)

i = 1

for sample in ds
if i > 5 break end
@test typeof(sample) <: NamedTuple
i += 1
for sample in Iterators.take(ds, 3)
@test typeof(sample) <: NamedTuple{SARTS}
end

sample1 = iterate(ds)
Expand Down
3 changes: 2 additions & 1 deletion src/ReinforcementLearningDatasets/test/runtests.jl
Expand Up @@ -6,5 +6,6 @@ using Test
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

@testset "ReinforcementLearningDatasets.jl" begin
include("d4rl/d4rl_dataset.jl")
include("dataset.jl")
include("d4rl_pybullet.jl")
end

0 comments on commit 6e60a19

Please sign in to comment.