## Two ways of interacting with (PO)MDPs: state objects and indices

In [32]:
using DMUStudent: HW2
using POMDPs

In [33]:
m = HW2.grid_world

POMDPModels.SimpleGridWorld
  size: Tuple{Int64, Int64}
  rewards: Dict{StaticArrays.SVector{2, Int64}, Float64}
  terminate_from: Set{StaticArrays.SVector{2, Int64}}
  tprob: Float64 0.7
  discount: Float64 0.95


In [34]:
states(m) # state objects

101-element Vector{StaticArrays.SVector{2, Int64}}:
 [1, 1]
 [2, 1]
 [3, 1]
 [4, 1]
 [5, 1]
 [6, 1]
 [7, 1]
 [8, 1]
 [9, 1]
 [10, 1]
 [1, 2]
 [2, 2]
 [3, 2]
 ⋮
 [10, 9]
 [1, 10]
 [2, 10]
 [3, 10]
 [4, 10]
 [5, 10]
 [6, 10]
 [7, 10]
 [8, 10]
 [9, 10]
 [10, 10]
 [-1, -1]

In [35]:
1:length(states(m)) # state indices

1:101

In [37]:
stateindex(m, [6,8])

76
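The ordering printed by `states(m)` suggests how indices map to grid positions: `x` varies fastest, then `y`, and the terminal state `[-1, -1]` comes last. The sketch below re-creates that layout without the homework package, assuming a 10×10 grid. `toy_states` and `toy_stateindex` are hypothetical stand-ins for `states(m)` and `stateindex(m, s)`; they are not part of POMDPs.jl.

```julia
# Hypothetical stand-ins for states(m) / stateindex(m, s) on a 10×10 grid,
# assuming the ordering printed above: x varies fastest, terminal state last.
const NX = 10
const NY = 10

toy_states = vcat([[x, y] for y in 1:NY for x in 1:NX], [[-1, -1]])

function toy_stateindex(s)
    s == [-1, -1] && return NX * NY + 1          # terminal state goes in slot 101
    return LinearIndices((NX, NY))[s[1], s[2]]   # column-major: x + NX * (y - 1)
end

toy_stateindex([6, 8])                          # 76, same as stateindex(m, [6,8]) above
toy_states[toy_stateindex([6, 8])] == [6, 8]    # true: index and object round-trip
```

This only shows that the formula reproduces the index printed above; how `HW2` computes indices internally is not shown in this notebook.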

In [38]:
transition(m, [1,1], :left) # state objects

                      SparseCat distribution
            ┌                                        ┐
     [1, 1] ┤■■■■■■■■■■■■■■■■■■■■ 0.7999999999999999
     [1, 2] ┤■■■ 0.10000000000000002
   [-1, -1] ┤ 0.0
   [-1, -1] ┤ 0.0
     [2, 1] ┤■■■ 0.10000000000000002
            └                                        ┘
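The two `[-1, -1]` entries carry zero probability, and the remaining masses add up to 1 (roughly 0.8 + 0.1 + 0.1). One way to move from this distribution over state objects to the index-based form is to accumulate each probability into a length-101 vector at the outcome's index. This is a sketch, not necessarily how `HW2.transition_matrices` is implemented; `toy_index` is a hypothetical helper that assumes the grid ordering printed by `states(m)`.

```julia
# Outcomes and probabilities copied from the SparseCat printed above
# (transition from [1, 1] under :left); [-1, -1] appears twice with zero mass.
vals  = [[1, 1], [1, 2], [-1, -1], [-1, -1], [2, 1]]
probs = [0.8, 0.1, 0.0, 0.0, 0.1]

# Hypothetical index function matching the 10×10 ordering of states(m).
toy_index(s) = s == [-1, -1] ? 101 : s[1] + 10 * (s[2] - 1)

row = zeros(101)                 # one entry per state index
for (sp, p) in zip(vals, probs)
    row[toy_index(sp)] += p      # accumulate, so repeated outcomes add up
end

row[2]      # 0.1: probability of reaching [2, 1] (index 2)
sum(row)    # ≈ 1.0
```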

In [39]:
T = HW2.transition_matrices(m) # state indices

Dict{Symbol, Matrix{Float64}} with 4 entries:
  :left  => [0.8 0.1 … 0.0 0.0; 0.7 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.2 0.0; 0.0 0…
  :right => [0.2 0.7 … 0.0 0.0; 0.1 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.8 0.0; 0.0 0…
  :up    => [0.2 0.1 … 0.0 0.0; 0.1 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.8 0.0; 0.0 0…
  :down  => [0.8 0.1 … 0.0 0.0; 0.1 0.7 … 0.0 0.0; … ; 0.0 0.0 … 0.2 0.0; 0.0 0…

In [40]:
T[:left][1, 2]

0.10000000000000002

In [41]:
T[:right][stateindex(m, [1,1]), stateindex(m, [2,1])]

0.7
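In the index-based representation, `T[a][s, sp]` is the probability of landing in state `sp` after taking action `a` in state `s`; for example, `T[:right][1, 2]` above is the 0.7 chance of moving right from `[1, 1]` to `[2, 1]`. A self-contained sketch with a made-up 3-state model in the same `Dict{Symbol, Matrix{Float64}}` layout (the model itself is an assumption, not the homework grid world):

```julia
# Made-up 3-state transition matrices: T[a][s, sp] = P(sp | s, a)
T = Dict(
    :stay => [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0],
    :go   => [0.2 0.8 0.0; 0.0 0.2 0.8; 0.0 0.0 1.0],
)

# Every row is a probability distribution over next states.
for (a, Ta) in T
    @assert all(sum(Ta, dims=2) .≈ 1.0)
end

# Row s is the next-state distribution after taking :go in state s.
T[:go][1, :]            # [0.2, 0.8, 0.0]

# Propagating a distribution p over current states one step: p' = T' * p
p  = [1.0, 0.0, 0.0]
p2 = T[:go]' * p        # [0.2, 0.8, 0.0]
```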

In [42]:
R = HW2.reward_vectors(m) # state indices

Dict{Symbol, Vector{Float64}} with 4 entries:
  :left  => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :right => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :up    => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :down  => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
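With `T` and `R` stored as dictionaries of matrices and vectors, a Bellman backup for all states at once reduces to a few matrix-vector operations. The sketch below uses a made-up 2-state model rather than the homework grid world, and assumes `R[a][s]` is the reward for taking action `a` in state `s`; `bellman_backup` is a hypothetical helper name.

```julia
# Made-up 2-state MDP in the same layout as HW2.transition_matrices / reward_vectors
T = Dict(:a => [0.9 0.1; 0.0 1.0],
         :b => [0.5 0.5; 0.0 1.0])
R = Dict(:a => [0.0, 0.0],
         :b => [1.0, 0.0])
γ = 0.95

# V'(s) = max_a ( R[a][s] + γ * Σ_sp T[a][s, sp] * V(sp) ), computed for every s at once
function bellman_backup(T, R, γ, V)
    Q = [R[a] .+ γ .* (T[a] * V) for a in keys(T)]   # one value vector per action
    return reduce((u, w) -> max.(u, w), Q)           # elementwise max over actions
end

V1 = bellman_backup(T, R, γ, zeros(2))   # [1.0, 0.0]
V2 = bellman_backup(T, R, γ, V1)         # ≈ [1.475, 0.0]
```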

In [43]:
using POMDPModelTools: ordered_states

In [44]:
ordered_states(m)

101-element Vector{StaticArrays.SVector{2, Int64}}:
 [1, 1]
 [2, 1]
 [3, 1]
 [4, 1]
 [5, 1]
 [6, 1]
 [7, 1]
 [8, 1]
 [9, 1]
 [10, 1]
 [1, 2]
 [2, 2]
 [3, 2]
 ⋮
 [10, 9]
 [1, 10]
 [2, 10]
 [3, 10]
 [4, 10]
 [5, 10]
 [6, 10]
 [7, 10]
 [8, 10]
 [9, 10]
 [10, 10]
 [-1, -1]
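`ordered_states(m)` returns the states arranged so that entry `i` is the state with index `i`, so going from an index to a state object is a plain array lookup. For the reverse direction, one option is to build a `Dict` once instead of calling `stateindex` repeatedly. Sketch with a hypothetical small grid standing in for `ordered_states(m)`:

```julia
# Hypothetical stand-in for ordered_states(m): entry i is the state whose index is i
# (a 3×2 grid with x varying fastest, followed by a terminal state).
ordered = vcat([[x, y] for y in 1:2 for x in 1:3], [[-1, -1]])

# Build the reverse lookup (state object => index) once, then reuse it.
index_of = Dict(s => i for (i, s) in enumerate(ordered))

index_of[[2, 2]]                      # 5
ordered[index_of[[2, 2]]] == [2, 2]   # true
```

Because the keys here are mutable `Vector`s, they must not be modified after insertion; the immutable `SVector` states used by `SimpleGridWorld` avoid that pitfall.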

## For Homework 2, you can work entirely with state indices
(if you prefer them to state objects)

# Linear Algebra

In [45]:
x = [1,2]

2-element Vector{Int64}:
 1
 2

In [46]:
A = [1 2; 3 4]

2×2 Matrix{Int64}:
 1  2
 3  4

In [47]:
Vector{Int64}

Vector{Int64} (alias for Array{Int64, 1})

In [48]:
Matrix{Int64}

Matrix{Int64} (alias for Array{Int64, 2})
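Since `Vector` and `Matrix` are aliases for one- and two-dimensional `Array`s, the usual constructors and indexing apply to both. A few quick checks that may be handy when building index-based matrices:

```julia
Vector{Int64} === Array{Int64, 1}      # true: the names are aliases
Matrix{Float64} === Array{Float64, 2}  # true

v = zeros(3)       # Vector{Float64} of length 3
M = zeros(2, 3)    # Matrix{Float64} with 2 rows and 3 columns
size(M)            # (2, 3)
M[:, 1]            # first column, returned as a Vector{Float64}
```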

In [49]:
inv(A)

2×2 Matrix{Float64}:
 -2.0   1.0
  1.5  -0.5
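When the inverse is only needed to solve a linear system `A * x = b`, the backslash operator solves it directly through a factorization, which is generally cheaper and more accurate than forming `inv(A)` first. A short comparison (the right-hand side `b` is an arbitrary example):

```julia
A = [1 2; 3 4]
b = [5, 6]

x_inv   = inv(A) * b   # forms the inverse explicitly, then multiplies
x_solve = A \ b        # solves A * x = b directly

x_solve                # ≈ [-4.0, 4.5]
x_inv ≈ x_solve        # true, up to floating-point rounding
A * x_solve ≈ b        # true
```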

In [50]:
norm(x)

LoadError: UndefVarError: norm not defined

In [22]:
using LinearAlgebra: norm

In [23]:
norm(x)

2.23606797749979
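`norm` takes an optional second argument `p` that selects which norm to compute. The max-norm is a common choice for checking whether an iterative method has converged; the tolerance below is an arbitrary assumption.

```julia
using LinearAlgebra: norm

x = [1, 2]
norm(x)          # Euclidean (2-)norm: sqrt(1^2 + 2^2) ≈ 2.236
norm(x, Inf)     # max-norm: largest absolute entry, 2
norm(x, 1)       # 1-norm: sum of absolute entries, 3

# Stopping test for an iterative algorithm (tolerance 1e-6 is an assumption)
V_old = [1.0, 2.0]
V_new = [1.0, 2.0 + 1e-9]
norm(V_new - V_old, Inf) < 1e-6   # true
```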

## Use the "dot syntax" to broadcast operations

In [51]:
x^2

LoadError: MethodError: no method matching ^(::Vector{Int64}, ::Int64)
Closest candidates are:
  ^(::Union{AbstractChar, AbstractString}, ::Integer) at /opt/julia-1.7.1/share/julia/base/strings/basic.jl:721
  ^(::LinearAlgebra.Diagonal, ::Integer) at /opt/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/diagonal.jl:196
  ^(::LinearAlgebra.Diagonal, ::Real) at /opt/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/diagonal.jl:195
  ...

In [52]:
x.^2

2-element Vector{Int64}:
 1
 4
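Dots can be chained in one expression, and the `@.` macro adds them to every call and operator for you. Writing into an existing array with `.=` avoids allocating a new one.

```julia
x = [1, 2]

y1 = 2 .* x .+ 1      # [3, 5]: each dotted operator applies elementwise
y2 = @. 2 * x + 1     # same result; @. inserts all the dots
y3 = x .+ [10, 20]    # elementwise on two equal-length vectors: [11, 22]

z = similar(x, Float64)
z .= sqrt.(x)         # fills the existing array z in place
```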

## Broadcasting works for any function (including your own)

In [53]:
f(a) = sqrt(a) + 1

f (generic function with 1 method)

In [54]:
f(x)

LoadError: MethodError: no method matching sqrt(::Vector{Int64})
Closest candidates are:
  sqrt(::Union{Float32, Float64}) at /opt/julia-1.7.1/share/julia/base/math.jl:566
  sqrt(::StridedMatrix{T}) where T<:Union{Real, Complex} at /opt/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/dense.jl:836
  sqrt(::LinearAlgebra.Diagonal) at /opt/julia-1.7.1/share/julia/stdlib/v1.7/LinearAlgebra/src/diagonal.jl:592
  ...

In [55]:
f.(x)

2-element Vector{Float64}:
 2.0
 2.414213562373095
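Broadcasting also works for functions with several arguments: scalars are reused for every element, and equal-length vectors are paired up. Wrapping a container in `Ref` makes broadcasting treat it as a single value. The functions `g` and `count_in` are illustrative examples.

```julia
g(a, b) = a * b + 1

g.([1, 2], 3)          # scalar 3 reused for every element: [4, 7]
g.([1, 2], [10, 20])   # paired elementwise: [11, 41]

# Ref keeps the whole vector fixed while broadcasting over the items
count_in(v, item) = count(==(item), v)
count_in.(Ref([1, 2, 2, 3]), [2, 3, 4])   # [2, 1, 0]
```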

In [56]:
max(1,2)

2

In [57]:
max([1,2,3])

LoadError: MethodError: no method matching max(::Vector{Int64})
Closest candidates are:
  max(::Any, ::Missing) at /opt/julia-1.7.1/share/julia/base/missing.jl:137
  max(::Any, ::Any) at /opt/julia-1.7.1/share/julia/base/operators.jl:492
  max(::Any, ::Any, ::Any, ::Any...) at /opt/julia-1.7.1/share/julia/base/operators.jl:655
  ...

In [58]:
maximum([1,2,3])

3
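In short, `max` compares its arguments, `max.` compares elementwise, `maximum` reduces a collection, and `argmax` returns where the maximum occurs. The `argmax(f, domain)` form (Julia 1.7 and later) is convenient for picking the best action from a table of values; the Q-values below are made up.

```julia
u = [3, 7, 5]
w = [4, 1, 6]

max.(u, w)              # elementwise maximum of two vectors: [4, 7, 6]
maximum(u)              # largest element of one collection: 7
argmax(u)               # index of the largest element: 2
maximum(abs, [-4, 2])   # apply a function first, then take the max: 4

# Hypothetical action values; pick the action with the highest value
q = Dict(:left => 0.1, :right => 0.7, :up => 0.3)
best = argmax(act -> q[act], keys(q))   # :right
```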

# Dictionaries

Dictionaries are hash maps: they associate keys of any (hashable) type with values of any type.

In [59]:
d = Dict("a"=>1, 2=>3)

Dict{Any, Int64} with 2 entries:
  2   => 3
  "a" => 1

In [60]:
d[2]

3

In [61]:
d[1]

LoadError: KeyError: key 1 not found

In [62]:
d[4] = 5

5

In [63]:
d

Dict{Any, Int64} with 3 entries:
  4   => 5
  2   => 3
  "a" => 1