## Expected SARSA on the Cliff Walking environment

In [1]:
// Notebook dependency setup: declares where artifacts are resolved from and
// adds the kotlinrl modules that the import cell below relies on.
USE {
    repositories {
        mavenCentral()
        // Sonatype snapshot repository; presumably required because every
        // kotlinrl artifact below is a -SNAPSHOT version — TODO confirm.
        maven("https://central.sonatype.com/repository/maven-snapshots/")
    }
    dependencies {
        // All four kotlinrl modules are pinned to the same 0.1.0-SNAPSHOT version.
        implementation("io.github.kotlinrl:integration:0.1.0-SNAPSHOT")
        implementation("io.github.kotlinrl:tabular:0.1.0-SNAPSHOT")
        implementation("io.github.kotlinrl:envs:0.1.0-SNAPSHOT")
        implementation("io.github.kotlinrl:rendering:0.1.0-SNAPSHOT")
    }
}

In [2]:
import io.github.kotlinrl.core.*
import io.github.kotlinrl.integration.gymnasium.*
import io.github.kotlinrl.integration.gymnasium.GymnasiumEnvs.*
import io.github.kotlinrl.rendering.*
import io.github.kotlinrl.tabular.*
import org.jetbrains.kotlinx.kandy.letsplot.export.*
import org.jetbrains.kotlinx.multik.api.*
import org.jetbrains.kotlinx.multik.api.io.*
import org.jetbrains.kotlinx.multik.ndarray.data.*
import java.io.*


In [3]:
// Run lengths: step cap per episode and number of training / test episodes.
val maxStepsPerEpisode: Int = 9000
val trainingEpisodes: Int = 500
val testEpisodes: Int = 50

// Exploration settings, passed to ParameterSchedule.geometricDecay in a later cell.
val initialEpsilon: Double = 1.0
val epsilonDecayRate: Double = 0.98
val minEpsilon: Double = 0.1

// Values handed to the learning algorithm as its alpha and gamma arguments.
val alpha: Double = 0.5
val gamma: Double = 0.99

// File the trained Q-table is written to after training and read back from for testing.
val fileName: String = "CliffWalkingExpectedSARSA.npy"

// Arrow glyph used for each action index when the policy is printed or plotted.
val actionSymbols: Map<Int, String> = mapOf(
    3 to "←",
    2 to "↓",
    1 to "→",
    0 to "↑",
)


In [4]:
// CliffWalking environment created through the Gymnasium integration, with
// rendering enabled and the "is_slippery" option switched off.
val env = gymnasium.make<CliffWalkingEnv>(
    CliffWalking_v0,
    render = true,
    options = mapOf("is_slippery" to false),
)

// Zero-initialised 48 x 4 Q-table. 48 matches the 4 x 12 grid dimensions passed to
// printQTable / plotPolicyActionValueGrid further down; presumably one row per
// grid cell and one column per action.
// Declared as `val`: the reference is never reassigned in this notebook (the
// array contents can still be updated in place by the learning algorithm).
val trainingQtable: QTable = mk.d2array(48, 4) { 0.0 }

// geometricDecay yields two values: the schedule passed to the algorithm as its
// epsilon, and a function that the training cell calls at the end of every episode.
val (epsilonSchedule, epsilonDecrement) = ParameterSchedule.geometricDecay(
    initialValue = initialEpsilon,
    minValue = minEpsilon,
    decayRate = epsilonDecayRate
)


2025-09-22T16:17:51.607942Z Execution of code 'val env = gymnasium....' ERROR Log4j2 could not find a logging implementation. Please add log4j-core to the classpath. Using SimpleLogger to log to the console...


In [5]:
// Trains a learning agent on `env` for `trainingEpisodes` episodes, then saves
// the resulting Q-table to `fileName`.
//
// NOTE(review): the notebook title and the output file name say "Expected SARSA",
// but the algorithm constructed here is plain SARSA — confirm which one is intended.
val trainer = episodicTrainer(
    env = env,
    agent = learningAgent(
        id = "training",
        algorithm = SARSA(
            Q = trainingQtable,
            epsilon = epsilonSchedule,
            alpha = ParameterSchedule.constant(alpha),
            gamma = gamma,
        )
    ),
    maxStepsPerEpisode = maxStepsPerEpisode,
    // An episode counts as a success when its final step reports `terminated`.
    successfulTermination = { transition -> transition.terminated },
    callbacks = listOf(
        printEpisodeStart(100),
        onEpisodeEnd { progress ->
            // Advance the epsilon schedule once per finished episode.
            epsilonDecrement()
            // Every 100 episodes, report the success count over the most recent 100.
            if (progress.totalEpisodes % 100 == 0) {
                val lastHundred = progress.episodeStats.takeLast(100)
                val goalSuccessCount = TrainingResult(lastHundred).totalGoalSuccessCount
                println("Current goal success count: $goalSuccessCount, over the last 100 episodes")
            }
        }
    )
)
println("Starting training")
val training = trainer.train(maxEpisodes(trainingEpisodes))
// Persist the Q-table so the test cells can reload it from disk.
mk.writeNPY(fileName, trainingQtable)


Starting training
Starting episode 100
Current goal success count: 98, over the last 100 episodes
Starting episode 200
Current goal success count: 100, over the last 100 episodes
Starting episode 300
Current goal success count: 100, over the last 100 episodes
Starting episode 400
Current goal success count: 100, over the last 100 episodes
Starting episode 500
Current goal success count: 100, over the last 100 episodes
Max episodes reached: 500


In [6]:
// Reload the Q-table saved by the training cell and view it as a 2-D array.
val savedQValues = mk.readNPY<Double, D2>(fileName)
val testingQtable = savedQValues.asD2Array()


In [7]:
// Wrap the environment in RecordVideo, targeting the given output folder.
// NOTE(review): the third argument is testEpisodes / 3 (integer division, 16 with
// the current settings); its meaning is not visible in this notebook — confirm
// against the RecordVideo signature.
val recordEnv = RecordVideo(env = env, folder = "videos/cliff_walking_expected_sarsa", testEpisodes / 3)

// Evaluation reuses episodicTrainer, but with an agent that only follows the
// greedy policy derived from the reloaded Q-table.
val greedyPolicy = testingQtable.greedy()
val tester = episodicTrainer(
    env = recordEnv,
    agent = policyAgent(id = "testing", policy = greedyPolicy),
    maxStepsPerEpisode = maxStepsPerEpisode,
    successfulTermination = { transition -> transition.terminated },
    callbacks = listOf(printEpisodeStart(10)),
)
println("Starting testing")
val test = tester.train(maxEpisodes(testEpisodes))


Starting testing
Starting episode 10
Starting episode 20
Starting episode 30
Starting episode 40
Starting episode 50
Max episodes reached: 50


In [8]:
// Compare the average reward of the training run with that of the test run.
val trainingAverageReward = training.totalAverageReward
println("Training average reward: $trainingAverageReward")
val testAverageReward = test.totalAverageReward
println("Test average reward: $testAverageReward")

// Print the Q-table laid out as a 4-row by 12-column grid, using the arrow glyphs.
printQTable(testingQtable, 4, 12, actionSymbols = actionSymbols)
// Must stay the last expression so the notebook displays its result.
displayVideos(recordEnv.folder)


Training average reward: -672.376
Test average reward: -17.0
Action Value Function:
-65.45 -58.98 -41.17 -30.20 -17.85 -13.65 -10.27  -9.26 -11.87  -5.48  -4.35  -3.28 
-36.52 -37.80 -40.23 -54.44 -38.21 -47.73 -66.65 -22.52 -14.61 -10.78  -7.86  -2.08 
-39.54 -75.42 -139.21 -107.59 -103.10 -136.76 -114.37 -66.05 -55.56 -74.46  -2.00  -1.00 
-44.95   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00 
Policy Table:
  →   →   →   →   →   →   →   →   →   →   →   ↓ 
  →   →   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↓ 
  ↑   ←   →   ↑   ↑   ←   ↑   ↑   ↑   ↑   →   ↓ 
  ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑ 


In [9]:
// Plot the reloaded Q-table on the same 4 x 12 grid, labelled with the arrow glyphs.
plotPolicyActionValueGrid(testingQtable, 4, 12, actionSymbols)