In [1]:
using Printf
using DataFrames
using CSV
using LinearAlgebra

### functions 

In [2]:
struct MDP 
    Œ≥   # discount rate 
    S   # state space
    A   # action space
    T   # transition function
    R   # reward function
    TR  # transition function 
end 

In [3]:
mutable struct IncrementalEstimate
    Œº  # mean estimate 
    Œ±  # learning rate (function) 
    m  # num of updates 
end 

function update!(model::IncrementalEstimate, x)
    model.m += 1 
    model.Œº += model.Œ±(model.m) * (x - model.Œº) 
    return model 
end 

update! (generic function with 1 method)

In [20]:
# for offline RL, sample a tuple (s,a,r,s_) to update a model
function sample_data(df)
    row = size(df,1)
    i = rand(1:row)
#     s, a, r, s_ = (df.s_i[i], df.s_j[i]), df.a[i], df.r[i], (df.sp_i[i], df.sp_j[i]) 
    s, a, r, s_ = df.s[i], df.a[i], df.r[i], df.sp[i] 
    return s, a, r, s_
end 

sample_data (generic function with 1 method)

In [21]:
# algorithm 15.9
function simulate(ùí´::MDP, model, œÄ, h, s) 
    for i in 1:h 
        a = œÄ(model, s) 
        s‚Ä≤, r = ùí´.TR(s, a) 
        update!(model, s, a, r, s‚Ä≤)  # update model from the sample (s,a,r,s_)
        s = s‚Ä≤ 
    end 
end

function train_offline(ùí´::MDP, model, df, h) 
    for i in 1:h 
        s, a, s‚Ä≤, r = sample_data(df)
        update!(model, s, a, r, s‚Ä≤)  # update model from the sample (s,a,r,s_)
        s = s‚Ä≤ 
    end 
end

train_offline (generic function with 2 methods)

In [22]:
mutable struct QLearning
    S 
    A
    Œ≥
    Q
    Œ±
end 

lookahead(model::QLearning, s,a) = model.Q[s,a]

function update!(model::QLearning, s,a,r,s_)
    Œ≥, Q, Œ± = model.Œ≥, model.Q, model.Œ±
    Q[s,a] += Œ±*(r + maximum(Q[s_, :]) - Q[s,a])  # update of Q-function
    return model 
end 

update! (generic function with 3 methods)

In [23]:
mutable struct Sarsa
    S
    A
    Œ≥
    Q  # action value function (initial)
    Œ±  # learning rate 
    l  # most recent experience tuble(s,a,r)
end

lookahead(model::Sarsa, s,a) = model.Q[s,a]

function update!(model::Sarsa, s,a,r,s_)
    if model.l != nothing 
        Œ≥, Q, Œ±, l = model.Œ≥, model.Q, model.Œ±, model.l 
        model.Q[l.s, l.a] += Œ± * (l.r + Œ≥*Q[s,a] - Q[l.s, l.a]) 
    end 
    model.l = (s=s, a=a, r=r) 
end 

update! (generic function with 3 methods)

In [24]:
function (œÄ::EpsilonGreedyExploration)(model, s)
    A, œµ = œÄ.A, œÄ.œµ
    if rand() < œµ
        return rand(A)
    end 
    Q(s,a) = lookahead(model, s, a)
    return argmax(a->Q(s,a), A) 
end

LoadError: UndefVarError: `EpsilonGreedyExploration` not defined

### implementation

In [25]:
infile = "data/small.csv"
df = CSV.File(infile) |> DataFrame
# data_mat = Matrix(df);

x = [mod(df.s[j], 10)!=0 ? mod(df.s[j], 10) : 10 for j in 1:size(df,1)] 
y = [mod(df.s[j], 10)!=0 ? df.s[j] √∑ 10 + 1 : df.s[j] √∑ 10  for j in 1:size(df,1)] 

df = insertcols!(df, 2, :s_i => [x[i] for i in 1:size(df,1)])
df = insertcols!(df, 3, :s_j => [y[j] for j in 1:size(df,1)]);

xp = [mod(df.sp[j], 10)!=0 ? mod(df.sp[j], 10) : 10 for j in 1:size(df,1)] 
yp = [mod(df.sp[j], 10)!=0 ? df.sp[j] √∑ 10 + 1 : df.sp[j] √∑ 10  for j in 1:size(df,1)] 

df = insertcols!(df, 7, :sp_i => [xp[i] for i in 1:size(df,1)])
df = insertcols!(df, 8, :sp_j => [yp[j] for j in 1:size(df,1)])

Row,s,s_i,s_j,a,r,sp,sp_i,sp_j
Unnamed: 0_level_1,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64
1,85,5,9,3,0,86,6,9
2,86,6,9,2,0,87,7,9
3,87,7,9,3,0,97,7,10
4,97,7,10,2,0,87,7,9
5,87,7,9,1,0,86,6,9
6,86,6,9,3,0,76,6,8
7,76,6,8,4,0,66,6,7
8,66,6,7,1,0,65,5,7
9,65,5,7,2,0,66,6,7
10,66,6,7,3,0,76,6,8


In [49]:
# S = [[[x,y] for x in 1:10] for y in 1:10] 
S = [i for i in 1:100]
A = [1,2,3,4]
Œ≥ = 0.95
T = NaN
R = NaN
TR = NaN

prob = MDP(Œ≥,S,A,T,R,TR)

Q = zeros(length(prob.S), length(prob.A))
Œ± = 0.0005
l = (s=(df.s_i[1], df.s_i[2]), a=df.a[1], r=df.r[1])
l = (s=df.s[1], a=df.a[1], r=df.r[1])

sarsa = Sarsa(S,A,Œ≥,Q,Œ±,l)

(100, 4)

In [46]:
train_offline(prob, sarsa, df, 10000) 

In [47]:
sarsa.Q

100√ó4 Matrix{Float64}:
 0.0160077  0.0901616  0.0692234  0.0162046
 0.0466351  0.0640014  0.0733448  0.0264176
 0.0627234  0.044337   0.0769889  0.027605
 0.075845   0.0737455  0.141346   0.0702024
 0.0597101  0.0997327  0.170632   0.0660614
 0.093392   0.104624   0.212059   0.0664765
 0.124725   0.0920299  0.150576   0.0756119
 0.138391   0.136785   0.137362   0.0827373
 0.0987519  0.192301   0.17399    0.125147
 0.103768   0.0932607  0.169519   0.111674
 0.112827   0.147472   0.243815   0.0860841
 0.0884431  0.145849   0.185457   0.0782137
 0.133094   0.134249   0.265673   0.144993
 ‚ãÆ                                
 1.2397     1.39094    1.46355    1.17586
 1.19376    1.03278    1.41384    0.962089
 1.22279    1.22874    1.26887    1.10418
 0.765642   1.17634    1.55719    0.866994
 1.23502    1.14515    1.43547    0.877541
 0.741831   1.18999    1.17106    0.983015
 1.51254    0.886861   0.712154   1.01992
 1.26471    1.5121     1.29245    1.09109
 1.31736    1.14352    1.16056 

In [48]:
maxQ, maxQid = findmax(sarsa.Q[1,:])
actions = [findmax(sarsa.Q[i,:])[2] for i in 1:size(sarsa.Q, 1)]

100-element Vector{Int64}:
 2
 3
 3
 3
 3
 3
 3
 1
 2
 3
 3
 3
 3
 ‚ãÆ
 3
 3
 3
 3
 3
 2
 1
 2
 4
 1
 2
 2

In [None]:
Œ≤(s,a) = [s,s^2,a,a^2,1]   # basis funciton 
Q(Œ∏,s,a) = dot(Œ∏,Œ≤(s,a)) 
‚àáQ(Œ∏,s,a) = Œ≤(s,a) 
Œ∏ = [0.1,0.2,0.3,0.4,0.5] # initial parameter vector