Skip to content

Commit

Permalink
Add transformations to make QLearning learn faster (possibly buggy)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zomis committed Jan 2, 2020
1 parent 20dd038 commit dda5173
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 24 deletions.
41 changes: 34 additions & 7 deletions games-core/src/main/kotlin/net/zomis/games/Map2D.kt
Expand Up @@ -19,6 +19,10 @@ data class Position(val x: Int, val y: Int, val sizeX: Int, val sizeY: Int) {
}
return Position(this.x + 1, this.y, this.sizeX, this.sizeY)
}

/** Applies [transformation] to this position, delegating to [Transformation.transform]. */
fun transform(transformation: Transformation): Position = transformation.transform(this)
}

enum class TransformationType(val transforming: (Position) -> Position, val reverse: (Position) -> Position) {
Expand Down Expand Up @@ -60,7 +64,17 @@ enum class Transformation(private val transformations: List<TransformationType>)

}

class Map2D<T>(val sizeX: Int, val sizeY: Int, val getter: (x: Int, y: Int) -> T, val setter: (x: Int, y: Int, value: T) -> Unit) {
class Map2D<T>(val sizeX: Int, val sizeY: Int, val getter: (x: Int, y: Int) -> T, val setter: (x: Int, y: Int, value: T) -> Unit = {_,_,_->}) {

/**
 * The set of transformations applicable to this map's dimensions.
 *
 * Starts from every [Transformation] and drops the 90/270-degree rotations
 * when the map is not square, since those only preserve dimensions when
 * width and height are equal.
 */
private fun originalPossibleTransformations(): MutableSet<Transformation> {
    val transformations = Transformation.values().toMutableSet()
    if (sizeX != sizeY) {
        transformations.removeAll(setOf(Transformation.ROTATE_90, Transformation.ROTATE_270))
    }
    return transformations
}

fun standardizedTransformation(valueFunction: (T) -> Int): Transformation {
// keep a Set<Transformation>, start with all of them
Expand All @@ -69,12 +83,7 @@ class Map2D<T>(val sizeX: Int, val sizeY: Int, val getter: (x: Int, y: Int) -> T

// start in the possible fields for the target map upper-left corner
// then continue, line by line, beginning with increasing X and then increase Y
val possibleTransformations = Transformation.values().toMutableSet()
if (sizeX != sizeY) {
// Rotating 90 or 270 degrees only works if both width or height is the same
possibleTransformations.remove(Transformation.ROTATE_90)
possibleTransformations.remove(Transformation.ROTATE_270)
}
val possibleTransformations = originalPossibleTransformations()

var position: Position? = Position(0, 0, sizeX, sizeY)
while (possibleTransformations.size > 1 && position != null) {
Expand All @@ -95,6 +104,16 @@ class Map2D<T>(val sizeX: Int, val sizeY: Int, val getter: (x: Int, y: Int) -> T
return transformation
}

/**
 * Determines which transformations map this grid onto itself.
 *
 * A transformation is kept when every position's value is considered equal
 * (by [equalsFunction]) to the value at its transformed position.
 */
fun symmetryTransformations(equalsFunction: (T, T) -> Boolean): Set<Transformation> {
    val candidates = originalPossibleTransformations()
    return candidates.filterTo(mutableSetOf()) { candidate ->
        positions().all { position ->
            val mirrored = candidate.transform(position)
            equalsFunction(getter(position.x, position.y), getter(mirrored.x, mirrored.y))
        }
    }
}

fun transform(transformation: Transformation) {
val rotated = Map2DX(sizeX, sizeY) { x, y ->
val pos = Position(x, y, sizeX, sizeY)
Expand All @@ -109,6 +128,14 @@ class Map2D<T>(val sizeX: Int, val sizeY: Int, val getter: (x: Int, y: Int) -> T
}
}

/** Lazily yields every position of this map, row by row (y outer, x inner). */
fun positions(): Sequence<Position> = sequence {
    for (y in 0 until sizeY) {
        for (x in 0 until sizeX) {
            yield(Position(x, y, sizeX, sizeY))
        }
    }
}

/** Transforms this map in place into its canonical orientation, as chosen by [standardizedTransformation]. */
fun standardize(valueFunction: (T) -> Int) {
    transform(standardizedTransformation(valueFunction))
}
Expand Down
Expand Up @@ -132,7 +132,7 @@ class Server2(val events: EventSystem) {
val result = engine.eval(it.input.substring("kt ".length))
println(result)
})
events.with(TTTQLearn()::setup)
events.with(TTTQLearn(gameSystem)::setup)

events.listen("Stop Javalin", ShutdownEvent::class, {true}, {javalin.stop()})
events.listen("Start Javalin", StartupEvent::class, {true}, {javalin.start()})
Expand Down
Expand Up @@ -34,7 +34,7 @@ data class QAwaitingReward<S>(val state: S, val stateAction: S, val action: Int)
class MyQLearning<T, S>(val maxActions: Int,
private val stateFunction: (T) -> S,
private val actionPossible: ActionPossible<T>,
private val stateActionFunction: (S, Int) -> S,
private val stateActionFunction: (T, S, Int) -> S,
private val qTable: QStore<S>) {

private val logger = KLoggers.logger(this)
Expand Down Expand Up @@ -82,7 +82,7 @@ class MyQLearning<T, S>(val maxActions: Int,

fun prepareReward(environment: T, action: Int): QAwaitingReward<S> {
val state = stateFunction(environment)
val stateAction = stateActionFunction(state, action)
val stateAction = stateActionFunction(environment, state, action)
return QAwaitingReward(state, stateAction, action)
}

Expand All @@ -107,7 +107,7 @@ class MyQLearning<T, S>(val maxActions: Int,
val nextStateStr = stateFunction(nextState)
val estimateOfOptimalFutureValue = (0 until maxActions)
.filter { i -> actionPossible(nextState, i) }
.map { i -> stateActionFunction(nextStateStr, i) }
.map { i -> stateActionFunction(rewardedState.state, nextStateStr, i) }
.map { str -> qTable.getOrDefault(str, DEFAULT_QVALUE) }.max() ?: 0.0

val oldValue = qTable.getOrDefault(awaitReward.stateAction, DEFAULT_QVALUE)
Expand All @@ -123,7 +123,7 @@ class MyQLearning<T, S>(val maxActions: Int,
val result = DoubleArray(maxActions)
for (i in 0 until maxActions) {
if (actionPossible(environment, i)) {
val st = stateActionFunction(state, i)
val st = stateActionFunction(environment, state, i)
val value = qTable.getOrDefault(st, 0.0)
result[i] = value
}
Expand All @@ -146,7 +146,7 @@ class MyQLearning<T, S>(val maxActions: Int,
val scores = DoubleArray(possibleActions.size)
for (i in possibleActions.indices) {
val action = possibleActions[i]
val stateAction = stateActionFunction(state, action)
val stateAction = stateActionFunction(environment, state, action)
scores[i] = this.qTable.getOrDefault(stateAction, DEFAULT_QVALUE)
}
val min = scores.min() ?: 0.0
Expand Down Expand Up @@ -183,7 +183,7 @@ class MyQLearning<T, S>(val maxActions: Int,
return possibleActions[0]
}
for (i in possibleActions) {
val stateAction = stateActionFunction(state, i)
val stateAction = stateActionFunction(environment, state, i)
val value = qTable.getOrDefault(stateAction, DEFAULT_QVALUE)
val diff = Math.abs(value - bestValue)
val better = value > bestValue && diff >= EPSILON
Expand All @@ -199,7 +199,7 @@ class MyQLearning<T, S>(val maxActions: Int,
var pickedAction = random.nextInt(numBestActions)
logger.debug { "Pick best action chosed index $pickedAction of $possibleActions with value $bestValue" }
for (i in possibleActions) {
val stateAction = stateActionFunction(state, i)
val stateAction = stateActionFunction(environment, state, i)
val value = qTable.getOrDefault(stateAction, DEFAULT_QVALUE)
val diff = Math.abs(value - bestValue)

Expand Down
Expand Up @@ -5,25 +5,27 @@ import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import net.zomis.core.events.EventSystem
import net.zomis.games.Features
import net.zomis.games.Map2D
import net.zomis.games.Position
import net.zomis.games.Transformation
import net.zomis.games.dsl.PlayerIndex
import net.zomis.games.dsl.Point
import net.zomis.games.dsl.impl.GameImpl
import net.zomis.games.dsl.index
import net.zomis.games.server2.ClientJsonMessage
import net.zomis.games.server2.games.*
import net.zomis.games.server2.getTextOrDefault
import net.zomis.tttultimate.TTPlayer
import net.zomis.tttultimate.games.TTController

class TTTQLearn {
class TTTQLearn(val games: GameSystem) {
val gameType = "DSL-TTT"

val logger = KLoggers.logger(this)

val actionPossible: ActionPossible<TTController> = { tt, action ->
val columns = tt.game.sizeX
val x = action % columns
val y = action / columns
tt.isAllowedPlay(tt.game.getSub(x, y)!!)
val pos = actionToPosition(tt, action)
tt.isAllowedPlay(tt.game.getSub(pos.x, pos.y)!!)
}

fun newLearner(qStore: QStore<String>): MyQLearning<TTController, String> {
Expand All @@ -34,29 +36,59 @@ class TTTQLearn {
return newLearner(controller.game.sizeX * controller.game.sizeY, qStore)
}

/**
 * Computes the transformation that brings [controller]'s board into its
 * canonical orientation, keyed on the owner ([TTPlayer] ordinal) of each tile.
 */
private fun normalizeTransformation(controller: TTController): Transformation {
    val board = Map2D(controller.game.sizeX, controller.game.sizeY, { x, y ->
        controller.game.getSub(x, y)!!.wonBy
    })
    return board.standardizedTransformation { owner -> owner.ordinal }
}

private fun newLearner(maxActions: Int, qStore: QStore<String>): MyQLearning<TTController, String> {
val stateToString: (TTController) -> String = { g ->
val transformation = normalizeTransformation(g)
val sizeX = g.game.sizeX
val sizeY = g.game.sizeY
val str = StringBuilder()
for (y in 0 until sizeY) {
for (x in 0 until sizeX) {
val sub = g.game.getSub(x, y)!!
val p = Position(x, y, sizeX, sizeY).transform(transformation)
val sub = g.game.getSub(p.x, p.y)!!
str.append(if (sub.wonBy.isExactlyOnePlayer) sub.wonBy.name else "_")
}
str.append('-')
}
str.toString()
}
val learn = MyQLearning(maxActions, stateToString, actionPossible,
{ state, action -> state + action }, qStore)
val learn = MyQLearning(maxActions, stateToString, actionPossible, this::stateActionString, qStore)
// learn.setLearningRate(-0.01); // This leads to bad player moves. Like XOX-OXO-_X_ instead of XOX-OXO-X__
learn.discountFactor = -0.9
learn.learningRate = 1.0
learn.randomMoveProbability = 0.05
return learn
}

/**
 * Builds the Q-table key for taking [action] in [state].
 *
 * The action index is mapped through the board's normalizing transformation
 * so that symmetric game situations share the same table entries.
 */
private fun stateActionString(environment: TTController, state: String, action: Int): String {
    val transformation = normalizeTransformation(environment)
    val originalPoint = actionToPosition(environment, action)
    val canonicalPoint = originalPoint.transform(transformation)
    return state + positionToAction(environment, canonicalPoint)
}

/** Converts a 2D board [position] into its flat, row-major action index. */
private fun positionToAction(environment: TTController, position: Position): Int {
    val width = environment.game.sizeX
    return position.y * width + position.x
}

/** Converts a flat, row-major action index back into a 2D board position. */
private fun actionToPosition(environment: TTController, action: Int): Position {
    val width = environment.game.sizeX
    return Position(action % width, action / width, width, environment.game.sizeY)
}

fun isDraw(tt: TTController): Boolean {
for (yy in 0 until tt.game.sizeY) {
for (xx in 0 until tt.game.sizeX) {
Expand Down Expand Up @@ -108,6 +140,11 @@ class TTTQLearn {
return@ServerAI listOf()
}

// Always do actions based on the standardized state
// Find possible symmetry transformations
// Make move
// TODO: Learn the same value for all possible symmetries of action

val action = learn.pickWeightedBestAction(model)
val x = action % model.game.sizeX
val y = action / model.game.sizeX
Expand Down

0 comments on commit dda5173

Please sign in to comment.