## SARSA example of Frozen Lake (8x8, non-slippery)

In [1]:
// Kernel-level dependency declaration for this notebook: makes the kotlinrl modules
// imported in the next cell available on the classpath.
USE {
    repositories {
        mavenCentral()
        // Snapshot repository; presumably required because every kotlinrl artifact below
        // is a -SNAPSHOT version that is not published to Maven Central.
        maven("https://central.sonatype.com/repository/maven-snapshots/")
    }
    dependencies {
        // Provides io.github.kotlinrl.integration.gymnasium (gymnasium.make, FrozenLake_v1, ...).
        implementation("io.github.kotlinrl:integration:0.1.0-SNAPSHOT")
        // Provides io.github.kotlinrl.tabular (QTable, SARSA, greedy policies, ...).
        implementation("io.github.kotlinrl:tabular:0.1.0-SNAPSHOT")
        implementation("io.github.kotlinrl:envs:0.1.0-SNAPSHOT")
        // Provides io.github.kotlinrl.rendering (printQTable, plotPolicyActionValueGrid, ...).
        implementation("io.github.kotlinrl:rendering:0.1.0-SNAPSHOT")
    }
}

In [2]:
import io.github.kotlinrl.core.*
import io.github.kotlinrl.integration.gymnasium.*
import io.github.kotlinrl.integration.gymnasium.GymnasiumEnvs.*
import io.github.kotlinrl.rendering.*
import io.github.kotlinrl.tabular.*
import org.jetbrains.kotlinx.kandy.letsplot.export.*
import org.jetbrains.kotlinx.multik.api.*
import org.jetbrains.kotlinx.multik.api.io.*
import org.jetbrains.kotlinx.multik.ndarray.data.*
import java.io.*


In [3]:
// Episode budget for training and evaluation.
val maxStepsPerEpisode = 200
val trainingEpisodes = 50_000
val testEpisodes = 50

// Discount factor (also used by the shaping potential in the next cell).
val gamma = 0.99

// Exploration schedule: the per-decrement step is sized so that, with one decrement per
// episode, epsilon reaches minEpsilon after 90% of the training episodes.
val initialEpsilon = 0.6
val minEpsilon = 0.0
val epsilonDecayRate = (initialEpsilon - minEpsilon) / (0.9 * trainingEpisodes)

// Learning-rate schedules: `alpha` early in training, the smaller `minAlpha` later on.
val alpha = ParameterSchedule.constant(0.2)
val minAlpha = ParameterSchedule.constant(0.02)

// Where the trained Q-table is persisted between the training and testing cells.
val fileName = "FrozenLakeSARSA.npy"

// Display glyph for each action index: 0 = left, 1 = down, 2 = right, 3 = up.
val actionSymbols = listOf("←", "↓", "→", "↑")
    .withIndex()
    .associate { (action, symbol) -> action to symbol }


In [4]:
// Side length of the square FrozenLake map. The state count, the map name passed to the
// env, the Q-table shape and the shaping potential are all derived from it.
// Note: the built-in gymnasium map names are "4x4" and "8x8".
val gridSize = 8
val stateCount = gridSize * gridSize
// FrozenLake has four discrete actions (left, down, right, up).
val actionCount = 4

// Deterministic (non-slippery) FrozenLake with rendering enabled.
val env = gymnasium.make<FrozenLakeEnv>(FrozenLake_v1, render = true, options = mapOf(
    "is_slippery" to false,
    "map_name" to "${gridSize}x$gridSize"
))

// Q-table of shape (stateCount, actionCount), initialised with small near-uniform random
// values in [0.24, 0.26) so ties between actions are broken randomly.
val trainingQtable: QTable = mk.rand<Double, D2>(from = 0.24, until = 0.26, dims = intArrayOf(stateCount, actionCount))

// Linear epsilon decay. `epsilonDecrement` must be invoked once per decay step (the training
// cell calls it at the end of every episode); progress is logged every 1000 steps.
val (epsilonSchedule, epsilonDecrement) = ParameterSchedule.linearDecay(
    initialValue = initialEpsilon,
    minValue = minEpsilon,
    decayRate = epsilonDecayRate,
    callback = { episode, parameter ->
        if (episode % 1000 == 0) {
            println("Episode: $episode, Epsilon: $parameter")
        }
    })

// Shaping potential per state: -(1 - gamma) * Manhattan distance to the bottom-right cell
// (where the standard map places the goal — assumed). States are row-major: state = row * gridSize + col.
val goalIndex = gridSize - 1
val phi = mk.d1array(stateCount) { state ->
    val row = state / gridSize
    val col = state % gridSize
    val distanceToGoal = abs(goalIndex - row) + abs(goalIndex - col)
    (1.0 - gamma) * -distanceToGoal.toDouble()
}


2025-09-19T22:44:14.284812Z Execution of code 'val env = gymnasium....' ERROR Log4j2 could not find a logging implementation. Please add log4j-core to the classpath. Using SimpleLogger to log to the console...


In [5]:
// State the agent occupied before the current step. The reward transform needs it to compute
// the shaping term gamma * phi(s') - phi(s), because the transition it receives exposes only
// the resulting state. The reset logic below assumes every episode starts in state 0.
var currentState = 0
// Number of epsilon decay steps (one per episode) after which SARSA switches from `alpha`
// to the smaller `minAlpha`.
val alphaSwitchStep = 8000
val trainer = episodicTrainer(
    // Reward shaping on top of the raw FrozenLake reward:
    //   goal reached (terminated, raw reward 1.0) ->  1.0
    //   hole (terminated, raw reward 0.0)         -> -2.0
    //   no movement (e.g. bumped into a wall)     ->  0.0
    //   otherwise: scaled potential difference, clipped to [-0.5, 0.5]
    env = TransformReward(env, transform = {
        val state = currentState
        val nextState = it.state
        currentState = if (it.terminated || it.truncated) 0 else nextState
        val isGoal = it.terminated && it.reward == 1.0
        val isHole = it.terminated && it.reward == 0.0
        val isSameState = state == nextState
        when {
            isGoal -> 1.0
            isHole -> -2.0
            isSameState -> 0.0
            else -> (50 * (gamma * phi[nextState] - phi[state])).coerceIn(-0.5, 0.5)
        }
    }
    ),
    agent = learningAgent(
        id = "training",
        algorithm = SARSA(
            Q = trainingQtable,
            epsilon = epsilonSchedule,
            alpha = { if (epsilonSchedule().decayStep > alphaSwitchStep) minAlpha() else alpha() },
            gamma = gamma,
        )
    ),
    maxStepsPerEpisode = maxStepsPerEpisode,
    warnOnTruncationOrMax = false,
    // With the shaping above, a reward of exactly 1.0 is only produced on reaching the goal.
    successfulTermination = { it.reward == 1.0 },
    callbacks = listOf(
        printEpisodeStart(1000),
        onEpisodeEnd {
            // The transform resets currentState only when the env reports terminated/truncated.
            // The trainer can also end an episode on its own at maxStepsPerEpisode, so reset here
            // too. Otherwise the next episode's first shaped reward would use a stale state.
            currentState = 0
            epsilonDecrement()
            if (it.totalEpisodes % 1_000 == 0) {
                val goalSuccessCount = TrainingResult(it.episodeStats.takeLast(1_000)).totalGoalSuccessCount
                println("Current goal success count: $goalSuccessCount, over the last 1000 episodes")
            }
        })
)
println("Starting training")
// Stop at the episode budget, or early once the last 1000 episodes all reached the goal.
val training = trainer.train(maxEpisodes(trainingEpisodes).or {
    it.totalEpisodes >= 1000 && it.takeLast(1000).totalGoalSuccessCount == 1000
})
// Persist the learned Q-table so the testing cells can reload it from disk.
mk.writeNPY(fileName, trainingQtable)


Starting training
Starting episode 1000
Episode: 1000, Epsilon: Parameter(current=0.5866666666666903, previous=0.5866800000000236, minValue=0.0, decayStep=1000)
Current goal success count: 299, over the last 1000 episodes
Starting episode 2000
Episode: 2000, Epsilon: Parameter(current=0.5733333333333807, previous=0.573346666666714, minValue=0.0, decayStep=2000)
Current goal success count: 429, over the last 1000 episodes
Starting episode 3000
Episode: 3000, Epsilon: Parameter(current=0.560000000000071, previous=0.5600133333334043, minValue=0.0, decayStep=3000)
Current goal success count: 410, over the last 1000 episodes
Starting episode 4000
Episode: 4000, Epsilon: Parameter(current=0.5466666666667613, previous=0.5466800000000946, minValue=0.0, decayStep=4000)
Current goal success count: 477, over the last 1000 episodes
Starting episode 5000
Episode: 5000, Epsilon: Parameter(current=0.5333333333334517, previous=0.533346666666785, minValue=0.0, decayStep=5000)
Current goal success count

In [6]:
// Reload the Q-table written by the training cell, so testing exercises the persisted
// artifact rather than the in-memory training table.
val testingQtable = mk
    .readNPY<Double, D2>(fileName)
    .asD2Array()


In [7]:
// Wrap the (unshaped) env so test episodes are recorded to `folder`. The third argument
// presumably controls which episodes get recorded — confirm against RecordVideo.
val recordEnv = RecordVideo(env = env, folder = "videos/frozen_lake_sarsa", testEpisodes / 3)
// Evaluate the greedy policy derived from the reloaded Q-table (no learning).
val tester = episodicTrainer(
    env = recordEnv,
    agent = policyAgent(
        id = "testing",
        policy = testingQtable.greedy()
    ),
    maxStepsPerEpisode = maxStepsPerEpisode,
    // Only reaching the goal (raw reward 1.0) counts as success, matching the training cell.
    // The previous `it.done` also counted episodes that ended in a hole.
    successfulTermination = { it.reward == 1.0 },
    callbacks = listOf(
        printEpisodeStart(10)
    )
)
println("Starting testing")
val test = tester.train(maxEpisodes(testEpisodes))


Starting testing
Starting episode 10
Starting episode 20
Starting episode 30
Starting episode 40
Starting episode 50
Max episodes reached: 50


In [8]:
// Average per-episode reward for both phases. Training rewards are the shaped ones from
// TransformReward; test rewards are the raw env rewards.
listOf("Training" to training, "Test" to test).forEach { (phase, result) ->
    println("$phase average reward: ${result.totalAverageReward}")
}

// Dump the learned values and greedy policy for the 8x8 grid, then show the recorded test videos.
printQTable(testingQtable, 8, 8, actionSymbols = actionSymbols)
displayVideos(recordEnv.folder)


Training average reward: 6.637258541537993
Test average reward: 1.0
Action Value Function:
  6.12   5.68   5.23   4.78   4.32   3.86   3.40   2.93 
  5.59   5.15   4.71   4.26   3.82   3.38   2.92   2.45 
  4.87   4.43   3.93   0.25   3.15   2.84   2.43   1.97 
  3.85   3.42   2.38   1.63   1.99   0.25   1.96   1.49 
  2.93   2.13   1.65   0.25   1.66   1.78   1.47   0.99 
  0.52   0.26   0.26   0.49   1.05   1.11   0.25   0.50 
  0.25   0.26   0.26   0.25   0.26   0.87   0.25   0.26 
  0.26   0.25   0.26   0.26   0.80   0.53   0.25   0.26 
Policy Table:
  →   →   →   →   →   →   →   ↓ 
  ↑   ↑   ↑   ↑   →   →   →   ↓ 
  ↑   ↑   ↑   ↓   ↑   →   →   ↓ 
  ↑   ↑   ←   →   ↑   ↓   →   ↓ 
  ↑   ↑   ↑   ↓   →   →   →   ↓ 
  ↓   ↑   →   ↓   →   ↓   ↑   ↓ 
  →   ↑   ↑   →   ←   ↓   →   ↓ 
  ↓   ↑   ↓   ↓   →   →   →   ↑ 


In [9]:
// Plot the greedy policy and its action values on the 8x8 grid, labelling actions with
// the arrow glyphs from `actionSymbols`.
plotPolicyActionValueGrid(
    testingQtable,
    8,
    8,
    actionSymbols,
)