## Two ways of interacting with (PO)MDPs: state objects and indices

In [1]:
using DMUStudent: HW2
using POMDPs

In [2]:
m = HW2.grid_world

POMDPModels.SimpleGridWorld
  size: Tuple{Int64, Int64}
  rewards: Dict{StaticArraysCore.SVector{2, Int64}, Float64}
  terminate_from: Set{StaticArraysCore.SVector{2, Int64}}
  tprob: Float64 0.7
  discount: Float64 0.95


In [3]:
states(m) # state objects

101-element Vector{StaticArraysCore.SVector{2, Int64}}:
 [1, 1]
 [2, 1]
 [3, 1]
 [4, 1]
 [5, 1]
 [6, 1]
 [7, 1]
 [8, 1]
 [9, 1]
 [10, 1]
 ⋮
 [3, 10]
 [4, 10]
 [5, 10]
 [6, 10]
 [7, 10]
 [8, 10]
 [9, 10]
 [10, 10]
 [-1, -1]

In [4]:
1:length(states(m)) # state indices

1:101

In [5]:
stateindex(m, [6,8])

76
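The printed order of `states(m)` suggests how indices are assigned. The x coordinate varies fastest, then y, and the terminal state `[-1, -1]` comes last, which gives `[6, 8]` the index (8 - 1) * 10 + 6 = 76. Below is a minimal sketch of that mapping. It assumes a 10×10 grid, and `grid_index` is a hypothetical helper that is not part of POMDPs.jl, so real code should keep calling `stateindex(m, s)`:

```julia
# Hypothetical helper that mimics the ordering printed by states(m) above
# (assumes a 10×10 grid, x varying fastest, terminal state [-1, -1] last).
function grid_index(s::AbstractVector{<:Integer}; width::Int = 10, height::Int = 10)
    if s == [-1, -1]
        return width * height + 1                 # terminal state gets the last index
    end
    x, y = s
    return LinearIndices((width, height))[x, y]   # column-major: x + (y - 1) * width
end

grid_index([6, 8])    # 76, the same value stateindex(m, [6,8]) returned
grid_index([1, 1])    # 1
grid_index([-1, -1])  # 101
```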

In [6]:
transition(m, [1,1], :left) # state objects

                      SparseCat distribution
            ┌                                        ┐
     [1, 1] ┤■■■■■■■■■■■■■■■■■■■■ 0.7999999999999999
     [1, 2] ┤■■■ 0.10000000000000002
   [-1, -1] ┤ 0.0
   [-1, -1] ┤ 0.0
     [2, 1] ┤■■■ 0.10000000000000002
            └                                        ┘

In [7]:
T = HW2.transition_matrices(m) # state indices

Dict{Symbol, Matrix{Float64}} with 4 entries:
  :left  => [0.8 0.1 … 0.0 0.0; 0.7 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.2 0.0; 0.0 0…
  :right => [0.2 0.7 … 0.0 0.0; 0.1 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.8 0.0; 0.0 0…
  :up    => [0.2 0.1 … 0.0 0.0; 0.1 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.8 0.0; 0.0 0…
  :down  => [0.8 0.1 … 0.0 0.0; 0.1 0.7 … 0.0 0.0; … ; 0.0 0.0 … 0.2 0.0; 0.0 0…

In [8]:
T[:left][1, 2]

0.10000000000000002
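Each row of a transition matrix holds one distribution like the `SparseCat` above, written out as a dense vector over state indices. That distribution lists `[-1, -1]` twice, both times with probability 0.0, so a conversion should add repeated outcomes rather than overwrite them. The sketch below is illustrative only. `vals`, `probs`, `idx`, and `to_dense_row` are hypothetical stand-ins for the real distribution and `stateindex`, and `HW2.transition_matrices` is not necessarily implemented this way:

```julia
# Hypothetical copy of the distribution above as parallel vectors.
vals  = [[1, 1], [1, 2], [-1, -1], [-1, -1], [2, 1]]
probs = [0.8, 0.1, 0.0, 0.0, 0.1]

# Assumed index rule for a 10×10 grid plus one terminal state (stand-in for stateindex).
idx(s) = s == [-1, -1] ? 101 : s[1] + (s[2] - 1) * 10

function to_dense_row(vals, probs, index; n::Int = 101)
    row = zeros(n)
    for (s, p) in zip(vals, probs)
        row[index(s)] += p    # accumulate, since a state may be listed more than once
    end
    return row
end

row = to_dense_row(vals, probs, idx)
row[1], row[2], row[11]   # (0.8, 0.1, 0.1): stay at [1,1], move to [2,1], move to [1,2]
```

The entry at index 2 is the 0.1 that `T[:left][1, 2]` returned.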

In [9]:
T[:right][stateindex(m, [1,1]), stateindex(m, [2,1])]

0.7
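Both lookups put the current state in the row and the next state in the column. For example, `T[:right]` from `[1, 1]` to `[2, 1]` is 0.7, the model's `tprob`. Under that layout, every row of each `T[a]` should sum to 1. Below is a hypothetical check, shown on a made-up two-state dictionary so that it runs on its own:

```julia
# Hypothetical check: with T[a][i, j] = P(next state j | current state i, action a),
# every row of every matrix should sum to 1.
function rows_sum_to_one(T::Dict{Symbol, Matrix{Float64}}; atol = 1e-10)
    return all(all(isapprox.(sum(Ta, dims = 2), 1.0; atol = atol)) for Ta in values(T))
end

T_toy = Dict(:stay => [1.0 0.0; 0.0 1.0],
             :go   => [0.2 0.8; 0.8 0.2])
rows_sum_to_one(T_toy)   # true
```

Because the homework matrices have the same `Dict{Symbol, Matrix{Float64}}` type, the same call can be made on `T`.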

In [10]:
R = HW2.reward_vectors(m) # state indices

Dict{Symbol, Vector{Float64}} with 4 entries:
  :left  => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :right => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :up    => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
  :down  => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0…
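`T[a]` and `R[a]` use the same state indexing, so quantities over all states reduce to matrix-vector arithmetic. One example is a one-step lookahead (a Bellman backup). The sketch below uses a hypothetical two-state, two-action problem with the same `Dict` layout. The numbers are made up, and γ = 0.95 matches the discount printed for `m`:

```julia
# Hypothetical two-state, two-action problem in the same layout as HW2's dictionaries:
# T_toy[a][i, j] = P(next state j | state i, action a), R_toy[a][i] = reward for a in state i.
T_toy = Dict(:stay => [1.0 0.0; 0.0 1.0],
             :go   => [0.2 0.8; 0.8 0.2])
R_toy = Dict(:stay => [0.0, 1.0],
             :go   => [0.5, 0.0])
γ = 0.95

V = zeros(2)                                                  # one value per state index
Q = Dict(a => R_toy[a] + γ * T_toy[a] * V for a in keys(T_toy))  # lookahead value of each action
V_new = max.(Q[:stay], Q[:go])                                # best action per state, elementwise
```

`V` starts at zero, so `V_new` is just the larger immediate reward in each state, `[0.5, 1.0]`. Repeating the update feeds the new values back through `T_toy`.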

In [13]:
using POMDPTools: ordered_states

In [14]:
ordered_states(m)

101-element Vector{StaticArraysCore.SVector{2, Int64}}:
 [1, 1]
 [2, 1]
 [3, 1]
 [4, 1]
 [5, 1]
 [6, 1]
 [7, 1]
 [8, 1]
 [9, 1]
 [10, 1]
 ⋮
 [3, 10]
 [4, 10]
 [5, 10]
 [6, 10]
 [7, 10]
 [8, 10]
 [9, 10]
 [10, 10]
 [-1, -1]
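`ordered_states` is meant to return the states sorted by `stateindex`, so that `ordered_states(m)[stateindex(m, s)] == s`. Here it prints the same list as `states(m)` because that list is already in index order. The sketch below illustrates the idea and is not POMDPTools' implementation. The index rule `idx` is a hypothetical stand-in for `stateindex`:

```julia
# Assumed index rule for a 10×10 grid plus one terminal state (stand-in for stateindex).
idx(s) = s == [-1, -1] ? 101 : s[1] + (s[2] - 1) * 10

# Place every state at the position given by its index.
function order_by_index(states_list, index)
    out = similar(states_list)
    for s in states_list
        out[index(s)] = s
    end
    return out
end

grid = [[x, y] for y in 1:10 for x in 1:10]   # x varies fastest, as in states(m)
push!(grid, [-1, -1])                         # terminal state last

scrambled = reverse(grid)                     # same states, wrong order
reordered = order_by_index(scrambled, idx)
reordered == grid                             # true
```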

## For Homework 2, you can work with state indices only (if you prefer)

# Linear Algebra

In [15]:
x = [1,2]

2-element Vector{Int64}:
 1
 2

In [16]:
A = [1 2; 3 4]

2×2 Matrix{Int64}:
 1  2
 3  4
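In these literals, commas build a 1-D `Vector`. In a 2-D `Matrix`, spaces separate columns and semicolons separate rows. The short sketch below shows the resulting types, shapes, and a matrix-vector product. `Int` is `Int64` on a 64-bit machine, as in this notebook:

```julia
x = [1, 2]          # commas: a 1-D Vector
A = [1 2; 3 4]      # spaces separate columns, semicolons separate rows
r = [1 2]           # spaces only: a 1×2 Matrix, not a Vector

typeof(x)           # Vector{Int64} on a 64-bit machine
size(A)             # (2, 2)
size(r)             # (1, 2)
A * x               # matrix-vector product: [1*1 + 2*2, 3*1 + 4*2] == [5, 11]
```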

In [17]:
Vector{Int64}

Vector{Int64} (alias for Array{Int64, 1})

In [18]:
Matrix{Int64}

Matrix{Int64} (alias for Array{Int64, 2})

In [19]:
inv(A)

2×2 Matrix{Float64}:
 -2.0   1.0
  1.5  -0.5
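`inv(A)` is convenient for small examples. To solve a linear system A x = b, however, the backslash operator `A \ b` is usually preferable, because it solves the system directly through a factorization instead of forming the inverse. The sketch below compares the two approaches. The right-hand side `b` is made up so that the solution is `[1, 2]`:

```julia
using LinearAlgebra   # for I (identity) and the matrix methods used below

A = [1 2; 3 4]
b = [5, 11]           # hypothetical right-hand side, chosen so that A * [1, 2] == b

x_inv   = inv(A) * b  # form the inverse explicitly, then multiply
x_solve = A \ b       # solve A * x = b directly (LU factorization for a square matrix)

x_inv ≈ [1.0, 2.0] && x_solve ≈ [1.0, 2.0]   # true
A * inv(A) ≈ Matrix{Float64}(I, 2, 2)        # true, up to rounding
```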

In [20]:
norm(x)

UndefVarError: norm not defined

In [21]:
using LinearAlgebra: norm

In [22]:
norm(x)

2.23606797749979
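`norm` lives in the `LinearAlgebra` standard library, so it must be imported before use. `using LinearAlgebra: norm` brings in only that one name, while `using LinearAlgebra` brings in all of the library's exported names. By default `norm` computes the Euclidean norm, here sqrt(1^2 + 2^2) = sqrt(5) ≈ 2.236. A second argument selects a different p-norm:

```julia
using LinearAlgebra       # imports every exported name: norm, dot, I, ...

x = [1, 2]
norm(x)                   # Euclidean (2-)norm: sqrt(1^2 + 2^2) = sqrt(5)
norm(x, 1)                # 1-norm: |1| + |2|
norm(x, Inf)              # max-norm: the largest absolute entry
```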

## Use the "dot syntax" to broadcast operations

In [23]:
x^2

MethodError: no method matching ^(::Vector{Int64}, ::Int64)
Closest candidates are:
  ^(!Matched::Union{AbstractChar, AbstractString}, ::Integer) at strings/basic.jl:730
  ^(!Matched::LinearAlgebra.Symmetric{var"#s884", S} where {var"#s884"<:Real, S<:(AbstractMatrix{<:var"#s884"})}, ::Integer) at /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/symmetric.jl:674
  ^(!Matched::LinearAlgebra.Symmetric{var"#s884", S} where {var"#s884"<:Complex, S<:(AbstractMatrix{<:var"#s884"})}, ::Integer) at /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/symmetric.jl:675
  ...

In [24]:
x.^2

2-element Vector{Int64}:
 1
 4
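`x^2` fails because `^` on an array means a matrix power, which is defined for square matrices but not for vectors. Adding the dot applies the operation to each element instead. The same rule covers other operators and combinations of scalars, vectors, and matrices. A few forms, all in Base Julia:

```julia
x = [1, 2]
y = [10, 20]
A = [1 2; 3 4]

x .^ 2                  # elementwise power: [1, 4]
x .* y                  # elementwise product: [10, 40] (x * y is a MethodError)
x .+ 1                  # the scalar is applied to every element: [2, 3]
@. 2 * x^2 + 1          # @. puts a dot on every call and operator: [3, 9]
A .+ [10, 20]           # the vector is added to each column: [11 12; 23 24]

z = similar(x)
z .= x .+ y             # .= writes the result into the existing array z: [11, 22]
```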

## Dot syntax works automatically for any function (including your own)

In [25]:
f(a) = sqrt(a) + 1

f (generic function with 1 method)

In [26]:
f(x)

MethodError: no method matching sqrt(::Vector{Int64})
Closest candidates are:
  sqrt(!Matched::Union{Float32, Float64}) at math.jl:590
  sqrt(!Matched::StridedMatrix{T}) where T<:Union{Real, Complex} at /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/dense.jl:853
  sqrt(!Matched::PDMats.ScalMat) at ~/.julia/packages/PDMats/ZW0lN/src/scalmat.jl:68
  ...

In [27]:
f.(x)

2-element Vector{Float64}:
 2.0
 2.414213562373095
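A dotted call works the same way for functions with several arguments. Scalars are reused for every element, and vectors of equal length are paired up element by element. Nested dotted calls also fuse into a single loop, so `f.(x)` and `sqrt.(x) .+ 1` give identical results. In the sketch below, `g` is a made-up two-argument function:

```julia
f(a) = sqrt(a) + 1
g(a, b) = a + 2b                 # hypothetical two-argument function

x = [1, 2]
f.(x) == sqrt.(x) .+ 1           # true: both apply sqrt(a) + 1 to each element
g.(x, 10)                        # [21, 22]: the scalar 10 is reused for every element
g.(x, [10, 100])                 # [21, 202]: equal-length vectors pair up elementwise
```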

In [28]:
max(1,2)

2

In [29]:
max([1,2,3])

MethodError: no method matching max(::Vector{Int64})
Closest candidates are:
  max(::Any, !Matched::Missing) at missing.jl:137
  max(::Any, !Matched::Any) at operators.jl:480
  max(::Any, !Matched::Any, !Matched::Any, !Matched::Any...) at operators.jl:591
  ...

In [30]:
maximum([1,2,3])

3
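`max` compares its arguments, while `maximum` reduces over a collection. Related Base functions report where the maximum occurs, and broadcasting `max.` takes the elementwise maximum of two arrays. The `Dict` of action values at the end is hypothetical, and the `argmax(f, itr)` form requires Julia 1.7 or later:

```julia
v = [3, 7, 5]

max(3, 7)                     # 7: largest of the arguments
maximum(v)                    # 7: largest element of a collection
argmax(v)                     # 2: index of the largest element
findmax(v)                    # (7, 2): value and index together
max.([1, 5], [3, 2])          # [3, 5]: elementwise maximum of two arrays

# Hypothetical action values: pick the key whose value is largest.
q = Dict(:left => 0.1, :right => 0.9, :up => 0.4)
best = argmax(a -> q[a], keys(q))   # :right
```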

# Dictionaries

Dictionaries are hash maps that associate keys with values. Both keys and values can be objects of any type.

In [33]:
d = Dict("a"=>1, 2=>3)

Dict{Any, Int64} with 2 entries:
  2   => 3
  "a" => 1

In [34]:
d[2]

3
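As the output above shows, a `Dict` does not preserve insertion order. Iterating over a `Dict` yields `key => value` pairs. Declaring the key and value types explicitly lets Julia check and convert entries; the HW2 dictionaries above, for example, are `Dict{Symbol, Matrix{Float64}}` and `Dict{Symbol, Vector{Float64}}`. A sketch with made-up values:

```julia
r = Dict{Symbol, Float64}(:left => 0.0, :right => 1.0)
r[:up] = 2          # the Int 2 is converted to the Float64 2.0 on insertion

for (k, v) in r     # each element is a key => value pair; order is not guaranteed
    println(k, " => ", v)
end

sum(values(r))      # 3.0
```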

In [35]:
d[1]

KeyError: key 1 not found

In [36]:
d[4] = 5

5
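Looking up a missing key with `d[key]` throws a `KeyError`, as the `d[1]` cell shows. There are several alternatives. `haskey` tests whether a key is present, `get` returns a fallback value instead of throwing, and `get!` also stores that fallback in the dictionary:

```julia
d = Dict("a" => 1, 2 => 3)

haskey(d, 2)        # true
haskey(d, 1)        # false
get(d, 1, 0)        # 0: the default; d is unchanged
get!(d, 1, 0)       # 0: the default, which is now stored as d[1]
d[1]                # 0
```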

In [37]:
d

Dict{Any, Int64} with 3 entries:
  4   => 5
  2   => 3
  "a" => 1