Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Templated input #7

Merged
merged 3 commits into from Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle.kts
Expand Up @@ -6,7 +6,7 @@ plugins {
}

group = "com.asdacap"
version = "0.0.5-alpha"
version = "0.0.6-alpha"

java {
sourceCompatibility = JavaVersion.VERSION_1_8
Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/com/asdacap/ea4k/gp/FromFuncTreeNode.kt
Expand Up @@ -21,7 +21,7 @@ class FromFuncTreeNode <R>(
* Create a tree node factory from a kotlin KCallable. This will use reflection to detect the callable's
* argument and return type.
*/
inline fun <reified R> factoryFromFunction(
inline fun <reified R> fromFunction(
func: KCallable<R>,
type: NodeType = KotlinNodeType(typeOf<R>()),
parameters: List<NodeType> = func.parameters.map { KotlinNodeType(it.type) }
Expand All @@ -34,7 +34,7 @@ class FromFuncTreeNode <R>(
/**
* Create a tree node factory that always return the same result.
*/
inline fun <reified R> factoryFromConstant(
inline fun <reified R> fromConstant(
constant: R,
type: NodeType = KotlinNodeType(typeOf<R>())
): Factory<R> {
Expand Down
50 changes: 42 additions & 8 deletions src/main/kotlin/com/asdacap/ea4k/gp/Generator.kt
Expand Up @@ -22,23 +22,57 @@ object Generator {
var recurGen: ((Int, NodeType) -> TreeNode<*>)? = null
recurGen = { depth, ret ->
if (condition(height, depth)) {
val terminalOpts = pset.getTerminalAssignableTo(ret)
if (terminalOpts.isEmpty()) {
val terminalOpts = pset.selectTerminalAssignableTo(ret)
if (terminalOpts == null) {
throw Exception("The ea4k.gp.generate function tried to add " +
"a terminal of type $ret, but there is " +
"none available.")
}
val pickedTerminal = Utils.randomChoice(terminalOpts)
pickedTerminal.createNode(listOf())
terminalOpts.createNode(listOf())
} else {
val primitiveOpts = pset.getPrimitiveAssignableTo(ret)
if (primitiveOpts.isEmpty()) {
val primitiveOpts = pset.selectPrimitiveAssignableTo(ret)
if (primitiveOpts == null) {
throw Exception("The ea4k.gp.generate function tried to add " +
"a primitive of type '$ret', but there is " +
"none available.".format(ret))
}
val primitive = Utils.randomChoice(primitiveOpts)
primitive.createNode(primitive.args.map { recurGen!!.invoke(depth+1, it) }.toList())
primitiveOpts.createNode(primitiveOpts.args.map { recurGen!!.invoke(depth+1, it) }.toList())
}
}

return recurGen(0, type)
}

/**
* Like generate, but if either primitive or terminal is not found, it will use what is available and only
* throw an exception if both are not available
*/
fun <R> safeGenerate(pset: PSet<R>, min: Int, max: Int, condition: (Int, Int) -> Boolean, type: NodeType): TreeNode<*> {
val height = Random.nextInt(min, max)

var recurGen: ((Int, NodeType) -> TreeNode<*>)? = null
recurGen = { depth, ret ->
val terminalOpts = pset.selectTerminalAssignableTo(ret)
val primitiveOpts = pset.selectPrimitiveAssignableTo(ret)

if (terminalOpts == null && primitiveOpts == null) {
throw Exception("The ea4k.gp.generate function tried to add " +
"a node of type '$ret', but there is " +
"none available.".format(ret))
}

if (condition(height, depth)) {
if (terminalOpts == null) {
primitiveOpts!!.createNode(primitiveOpts.args.map { recurGen!!.invoke(depth+1, it) }.toList())
} else {
terminalOpts.createNode(listOf())
}
} else {
if (primitiveOpts == null) {
terminalOpts!!.createNode(listOf())
} else {
primitiveOpts.createNode(primitiveOpts.args.map { recurGen!!.invoke(depth+1, it) }.toList())
}
}
}

Expand Down
15 changes: 7 additions & 8 deletions src/main/kotlin/com/asdacap/ea4k/gp/GeneratorTreeNode.kt
@@ -1,6 +1,5 @@
package com.asdacap.ea4k.gp

import com.asdacap.ea4k.gp.functional.NodeFunction
import com.fasterxml.jackson.databind.JsonNode
import kotlin.reflect.jvm.jvmErasure
import kotlin.reflect.jvm.reflect
Expand All @@ -11,28 +10,28 @@ import kotlin.reflect.jvm.reflect
*/
class GeneratorTreeNode<R>(
val constant: R,
override val factory: TreeNodeFactory<NodeFunction<R>>,
): TreeNode<NodeFunction<R>>() {
override val factory: TreeNodeFactory<R>,
): TreeNode<R>() {
override val state: JsonNode by lazy {
val value = Utils.objectMapper.createObjectNode()
value.set<JsonNode>("constant", Utils.objectMapper.valueToTree(constant))
value
}

override fun evaluate(): NodeFunction<R> {
return NodeFunction { constant }
override fun evaluate(): R {
return constant
}

override fun replaceChildren(newChildren: List<TreeNode<*>>): TreeNode<NodeFunction<R>> {
override fun replaceChildren(newChildren: List<TreeNode<*>>): TreeNode<R> {
return GeneratorTreeNode(constant, factory)
}

class Factory<R: Any>(
val generator: () -> R,
val kotlinReturnType: Class<*> = generator.reflect()!!.returnType.jvmErasure.java,
override val returnType: NodeType = KotlinNodeType(generator.reflect()!!.returnType)
) : TreeNodeFactory<NodeFunction<R>> {
override fun createNode(children: List<TreeNode<*>>, state: JsonNode?): TreeNode<NodeFunction<R>> {
) : TreeNodeFactory<R> {
override fun createNode(children: List<TreeNode<*>>, state: JsonNode?): TreeNode<R> {
val constant = if (state == null) {
generator()
} else {
Expand Down
5 changes: 5 additions & 0 deletions src/main/kotlin/com/asdacap/ea4k/gp/NodeType.kt
Expand Up @@ -2,6 +2,7 @@ package com.asdacap.ea4k.gp

import kotlin.reflect.KType
import kotlin.reflect.full.isSupertypeOf
import kotlin.reflect.typeOf

/**
* A NodeType is a type that represent a type of a base tree node. Specifically, it is used by the generator to
Expand All @@ -14,6 +15,10 @@ interface NodeType {
fun fromKotlinNodeType(kType: KType): KotlinNodeType {
return KotlinNodeType(kType)
}

inline fun <reified T> fromKotlinNodeType(): KotlinNodeType {
return KotlinNodeType(typeOf<T>())
}
}
}

Expand Down
88 changes: 51 additions & 37 deletions src/main/kotlin/com/asdacap/ea4k/gp/PSet.kt
Expand Up @@ -3,48 +3,32 @@ package com.asdacap.ea4k.gp
import com.asdacap.ea4k.gp.Utils.objectMapper
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ObjectNode
import kotlin.random.Random.Default.nextDouble

/**
* Stores a set of terminal and primitives
*/
class PSet<R>(val returnType: NodeType) {

data class FactoryEntry(val name: String, val weight: Double, val factory: TreeNodeFactory<*>)

val terminalRatio: Double
get() {
val terminalCount = terminals.map { it.value.size }.sum()
val primitiveCount = primitives.map { it.value.size }.sum()
val terminalCount = terminals.size
val primitiveCount = primitives.size
return terminalCount.toDouble() / (primitiveCount + terminalCount).toDouble()
}
private val terminals: MutableMap<NodeType, MutableList<TreeNodeFactory<*>>> = mutableMapOf()
private val primitives: MutableMap<NodeType, MutableList<TreeNodeFactory<*>>> = mutableMapOf()
private val serializers: MutableList<Pair<String, TreeNodeFactory<*>>> = mutableListOf()

private fun <R> addTerminal(name: String, terminal: TreeNodeFactory<R>) {
if (terminals[terminal.returnType] == null) {
terminals[terminal.returnType] = mutableListOf()
}
terminals[terminal.returnType]?.add(terminal)
serializers.add(0, name to terminal)
}

private fun <R> addPrimitive(name: String, primitive: TreeNodeFactory<R>) {
if (primitives[primitive.returnType] == null) {
primitives[primitive.returnType] = mutableListOf()
}
primitives[primitive.returnType]?.add(primitive)
serializers.add(0, name to primitive)
}
private val factories: MutableList<FactoryEntry> = mutableListOf()
private val terminals: List<FactoryEntry> get() = factories.filter { it.factory.args.size == 0 }
private val primitives: List<FactoryEntry> get() = factories.filter { it.factory.args.size != 0 }

fun <R> addTreeNodeFactory(name: String, primitive: TreeNodeFactory<R>) {
if (primitive.args.size == 0) {
addTerminal(name, primitive)
} else {
addPrimitive(name, primitive)
}
fun <R> addTreeNodeFactory(name: String, primitive: TreeNodeFactory<R>, weight: Double = 1.0) {
factories.add(0, FactoryEntry(name, weight, primitive))
}

fun <R> serialize(tree: TreeNode<R>): JsonNode {
val factory = serializers.find {
it.second == tree.factory
val factory = factories.find {
it.factory == tree.factory
}

if (factory == null) {
Expand All @@ -59,7 +43,7 @@ class PSet<R>(val returnType: NodeType) {
childs.forEach {
childArray.add(it)
}
json.put("factory", factory.first)
json.put("factory", factory.name)
if (parent != objectMapper.nullNode()) {
json.set<ObjectNode>("node", parent)
}
Expand All @@ -72,8 +56,8 @@ class PSet<R>(val returnType: NodeType) {
fun deserialize(jsonNode: JsonNode): TreeNode<*> {
val factoryName = jsonNode.get("factory").asText()!!

val factory = serializers.find {
it.first == factoryName
val factory = factories.find {
it.name == factoryName
}
if (factory == null) {
throw Exception("Unknown factory $factoryName")
Expand All @@ -88,14 +72,44 @@ class PSet<R>(val returnType: NodeType) {
}

val nodeInfo = jsonNode.get("node") ?: null
return factory.second.createNode(children, nodeInfo)
return factory.factory.createNode(children, nodeInfo)
}

fun getTerminalAssignableTo(ret: NodeType): List<TreeNodeFactory<*>> {
return terminals.filter { it.key.isAssignableTo(ret) } .flatMap { it.value }
fun getTerminalsAssignableTo(ret: NodeType): List<TreeNodeFactory<*>> {
return terminals.filter { it.factory.returnType.isAssignableTo(ret) } .map { it.factory }
}

fun getPrimitiveAssignableTo(ret: NodeType): List<TreeNodeFactory<*>> {
return primitives.filter { it.key.isAssignableTo(ret) } .flatMap { it.value }
fun getPrimitivesAssignableTo(ret: NodeType): List<TreeNodeFactory<*>> {
return primitives.filter { it.factory.returnType.isAssignableTo(ret) } .map { it.factory }
}

fun selectTerminalAssignableTo(ret: NodeType): TreeNodeFactory<*>? {
val terminals = this.terminals.toList()
val totalWeight = terminals.map { it.weight }.sum()

var randomNumber = nextDouble() * totalWeight
var cumulative = 0.0;
terminals.forEach {
cumulative = cumulative + it.weight
if (cumulative > randomNumber) {
return it.factory
}
}
return null
}

fun selectPrimitiveAssignableTo(ret: NodeType): TreeNodeFactory<*>? {
val primitives = this.primitives.toList()
val totalWeight = primitives.map { it.weight }.sum()

var randomNumber = nextDouble() * totalWeight
var cumulative = 0.0;
primitives.forEach {
cumulative = cumulative + it.weight
if (cumulative > randomNumber) {
return it.factory
}
}
return null
}
}