Skip to content

Commit

Permalink
Store all meta-data related to a top-level variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolay-egorov committed Aug 20, 2021
1 parent 6c85ff4 commit db437cc
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 53 deletions.
4 changes: 1 addition & 3 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,7 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
}
} else {
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
}
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = null))
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ interface ReplForJupyter {

suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)

suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)

suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String> = emptyList(),
callback: (SerializationReply) -> Unit)
Expand Down Expand Up @@ -535,8 +535,8 @@ class ReplForJupyterImpl(
}

private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
override suspend fun serializeVariables(cellId: Int, topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
}

override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, pathToDescriptor: List<String>,
Expand All @@ -552,7 +552,7 @@ class ReplForJupyterImpl(
finalAns
}
args.descriptorsState.forEach { (name, state) ->
resultMap[name] = variablesSerializer.doIncrementalSerialization(cellId - 1, name, state, args.pathToDescriptor)
resultMap[name] = variablesSerializer.doIncrementalSerialization(cellId - 1, args.topLevelVarName ,name, state, args.pathToDescriptor)
}
log.debug("Serialization cellID: $cellId")
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
Expand Down
63 changes: 29 additions & 34 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class ProcessedSerializedVarsState(
data class ProcessedDescriptorsState(
val processedSerializedVarsToJavaProperties: MutableMap<SerializedVariablesState, PropertiesData?> = mutableMapOf(),
val processedSerializedVarsToKTProperties: MutableMap<SerializedVariablesState, KPropertiesData?> = mutableMapOf(),
val instancesPerState: MutableMap<SerializedVariablesState, Any?> = mutableMapOf()
val instancesPerState: MutableMap<SerializedVariablesState, Any?> = mutableMapOf(),
val parent: ProcessedDescriptorsState? = null
)

data class RuntimeObjectWrapper(
Expand Down Expand Up @@ -276,16 +277,6 @@ class VariablesSerializer(
*/
if (descriptors.size == 1 && descriptors.entries.first().key == "size") {
descriptors.addDescriptor(value, "data")
/*
if (value is Collection<*>) {
value.forEach {
iterateThrough(descriptors, it)
}
} else if (value is Array<*>) {
value.forEach {
iterateThrough(descriptors, it)
}
}*/
}
}

Expand Down Expand Up @@ -319,9 +310,9 @@ class VariablesSerializer(
)

/**
* Stores info computed descriptors in a cell
* Stores info computed descriptors in a cell starting from the very variable as a root
*/
private val computedDescriptorsPerCell: MutableMap<Int, ProcessedDescriptorsState> = mutableMapOf()
private val computedDescriptorsPerCell: MutableMap<Int, MutableMap<String, ProcessedDescriptorsState>> = mutableMapOf()

private val isSerializationActive: Boolean = System.getProperty(serializationSystemProperty)?.toBooleanStrictOrNull() ?: true

Expand Down Expand Up @@ -409,7 +400,7 @@ class VariablesSerializer(
log.debug("Unchanged variables: ${unchangedVariables - neededEntries.keys}")

// remove previous data
computedDescriptorsPerCell[cellId]?.instancesPerState?.clear()
// computedDescriptorsPerCell[cellId]?.instancesPerState?.clear()
val serializedData = neededEntries.mapValues {
val actualCell = variablesCells[it.key] ?: cellId
serializeVariableState(actualCell, it.key, it.value)
Expand All @@ -424,14 +415,15 @@ class VariablesSerializer(

fun doIncrementalSerialization(
cellId: Int,
topLevelName: String,
propertyName: String,
serializedVariablesState: SerializedVariablesState,
pathToDescriptor: List<String> = emptyList()
): SerializedVariablesState {
if (!isSerializationActive) return serializedVariablesState

val cellDescriptors = computedDescriptorsPerCell[cellId] ?: return serializedVariablesState
return updateVariableState(cellId, propertyName, cellDescriptors, serializedVariablesState)
return updateVariableState(cellId, propertyName, cellDescriptors[topLevelName]!!, serializedVariablesState)
}

/**
Expand All @@ -456,38 +448,39 @@ class VariablesSerializer(
return serializeVariableState(cellId, propertyName, property, value, isRecursive = false, false)
}

private fun serializeVariableState(cellId: Int, name: String?, variableState: VariableState?, isOverride: Boolean = true): SerializedVariablesState {
if (!isSerializationActive || variableState == null || name == null) return SerializedVariablesState()
private fun serializeVariableState(cellId: Int, topLevelName: String?, variableState: VariableState?, isOverride: Boolean = true): SerializedVariablesState {
if (!isSerializationActive || variableState == null || topLevelName == null) return SerializedVariablesState()
// force recursive check
variableState.stringValue
return serializeVariableState(cellId, name, variableState.property, variableState.value.getOrNull(), variableState.isRecursive, isOverride)
return serializeVariableState(cellId, topLevelName, variableState.property, variableState.value.getOrNull(), variableState.isRecursive, isOverride)
}

private fun serializeVariableState(cellId: Int, name: String, property: Field?, value: Any?, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
private fun serializeVariableState(cellId: Int, topLevelName: String, property: Field?, value: Any?, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
val wrapper = value.toObjectWrapper(isRecursive)
val processedData = createSerializeVariableState(name, getSimpleTypeNameFrom(property, value), wrapper)
return doActualSerialization(cellId, processedData, wrapper, isRecursive, isOverride)
val processedData = createSerializeVariableState(topLevelName, getSimpleTypeNameFrom(property, value), wrapper)
return doActualSerialization(cellId, topLevelName, processedData, wrapper, isRecursive, isOverride)
}

private fun serializeVariableState(cellId: Int, name: String, property: KProperty<*>, value: Any?, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
private fun serializeVariableState(cellId: Int, topLevelName: String, property: KProperty<*>, value: Any?, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
val wrapper = value.toObjectWrapper(isRecursive)
val processedData = createSerializeVariableState(name, getSimpleTypeNameFrom(property, value), wrapper)
return doActualSerialization(cellId, processedData, wrapper, isRecursive, isOverride)
val processedData = createSerializeVariableState(topLevelName, getSimpleTypeNameFrom(property, value), wrapper)
return doActualSerialization(cellId, topLevelName, processedData, wrapper, isRecursive, isOverride)
}

private fun doActualSerialization(cellId: Int, processedData: ProcessedSerializedVarsState, value: RuntimeObjectWrapper, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
private fun doActualSerialization(cellId: Int, topLevelName:String, processedData: ProcessedSerializedVarsState, value: RuntimeObjectWrapper, isRecursive: Boolean, isOverride: Boolean = true): SerializedVariablesState {
val serializedVersion = processedData.serializedVariablesState

seenObjectsPerCell.putIfAbsent(cellId, mutableMapOf())
computedDescriptorsPerCell.putIfAbsent(cellId, mutableMapOf())

if (isOverride) {
val instances = computedDescriptorsPerCell[cellId]?.instancesPerState
computedDescriptorsPerCell[cellId] = ProcessedDescriptorsState()
val instances = computedDescriptorsPerCell[cellId]?.get(topLevelName)?.instancesPerState
computedDescriptorsPerCell[cellId]!![topLevelName] = ProcessedDescriptorsState()
if (instances != null) {
computedDescriptorsPerCell[cellId]!!.instancesPerState += instances
computedDescriptorsPerCell[cellId]!![topLevelName]!!.instancesPerState += instances
}
}
val currentCellDescriptors = computedDescriptorsPerCell[cellId]
val currentCellDescriptors = computedDescriptorsPerCell[cellId]?.get(topLevelName)
// TODO should we stack?
currentCellDescriptors!!.processedSerializedVarsToJavaProperties[serializedVersion] = processedData.propertiesData
currentCellDescriptors.processedSerializedVarsToKTProperties[serializedVersion] = processedData.kPropertiesData
Expand All @@ -507,9 +500,9 @@ class VariablesSerializer(
if (kProperties?.size == 1 && kProperties.first().name == "size") {
serializedVersion.fieldDescriptor.addDescriptor(value.objectInstance, "data")
}
iterateThroughContainerMembers(cellId, value.objectInstance, serializedVersion.fieldDescriptor, isRecursive = isRecursive, kProperties = currentCellDescriptors.processedSerializedVarsToKTProperties[serializedVersion])
iterateThroughContainerMembers(cellId, topLevelName, value.objectInstance, serializedVersion.fieldDescriptor, isRecursive = isRecursive, kProperties = currentCellDescriptors.processedSerializedVarsToKTProperties[serializedVersion])
} else {
iterateThroughContainerMembers(cellId, value.objectInstance, serializedVersion.fieldDescriptor, isRecursive = isRecursive, currentCellDescriptors.processedSerializedVarsToJavaProperties[serializedVersion])
iterateThroughContainerMembers(cellId, topLevelName, value.objectInstance, serializedVersion.fieldDescriptor, isRecursive = isRecursive, currentCellDescriptors.processedSerializedVarsToJavaProperties[serializedVersion])
}
}

Expand All @@ -518,6 +511,7 @@ class VariablesSerializer(

private fun iterateThroughContainerMembers(
cellId: Int,
topLevelName: String,
callInstance: Any?,
descriptor: MutableFieldDescriptor,
isRecursive: Boolean = false,
Expand All @@ -543,7 +537,7 @@ class VariablesSerializer(

seenObjectsPerCell.putIfAbsent(cellId, mutableMapOf())
val seenObjectsPerCell = seenObjectsPerCell[cellId]
val currentCellDescriptors = computedDescriptorsPerCell[cellId]!!
val currentCellDescriptors = computedDescriptorsPerCell[cellId]!![topLevelName]!!
// ok, it's a copy on the left for some reason
val instancesPerState = currentCellDescriptors.instancesPerState

Expand All @@ -570,7 +564,7 @@ class VariablesSerializer(
}

val isArrayType = checkForPossibleArray(callInstance)
computedDescriptorsPerCell[cellId]!!.instancesPerState += instancesPerState
computedDescriptorsPerCell[cellId]!![topLevelName]!!.instancesPerState += instancesPerState

if (descriptor.size == 2 && (descriptor.containsKey("data") || descriptor.containsKey("element"))) {
val singleElemMode = descriptor.containsKey("element")
Expand Down Expand Up @@ -606,9 +600,10 @@ class VariablesSerializer(
}
}.toObjectWrapper(isRecursive)

computedDescriptorsPerCell[cellId]!!.instancesPerState += instancesPerState
computedDescriptorsPerCell[cellId]!![topLevelName]!!.instancesPerState += instancesPerState
iterateThroughContainerMembers(
cellId,
topLevelName,
neededCallInstance.objectInstance,
serializedVariablesState.fieldDescriptor,
isRecursive = isRecursive,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ class ReplVarsTest : AbstractSingleReplTest() {
val serializer = repl.variablesSerializer
val descriptor = res.evaluatedVariablesState["l"]!!.fieldDescriptor
val innerList = descriptor["elementData"]!!.fieldDescriptor["data"]
val newData = serializer.doIncrementalSerialization(0, "data", innerList!!)
val newData = serializer.doIncrementalSerialization(0, "l", "data", innerList!!)
assertEquals(2, newData.fieldDescriptor.size)
}

Expand Down Expand Up @@ -808,7 +808,7 @@ class ReplVarsTest : AbstractSingleReplTest() {
jupyterId = 2
).metadata.evaluatedVariablesState
val innerList = res["l"]!!.fieldDescriptor["elementData"]!!.fieldDescriptor["data"]
val newData = serializer.doIncrementalSerialization(0, "data", innerList!!)
val newData = serializer.doIncrementalSerialization(0, "l","data", innerList!!)
assertTrue(newData.isContainer)
assertTrue(newData.fieldDescriptor.size > 4)
}
Expand Down Expand Up @@ -924,7 +924,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
assertEquals(listOf(1, 2, 3, 4).toString().substring(1, actualContainer.value!!.length + 1), actualContainer.value)

val serializer = repl.variablesSerializer
val newData = serializer.doIncrementalSerialization(0, "data", actualContainer)
val newData = serializer.doIncrementalSerialization(0, "x","data", actualContainer)
}

@Test
Expand Down Expand Up @@ -1013,7 +1013,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {

val serializer = repl.variablesSerializer

val newData = serializer.doIncrementalSerialization(0, "i", descriptor["i"]!!)
val newData = serializer.doIncrementalSerialization(0, "c", "i", descriptor["i"]!!)
}

@Test
Expand All @@ -1033,7 +1033,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
val actualContainer = listData.fieldDescriptor.entries.first().value!!
val serializer = repl.variablesSerializer

val newData = serializer.doIncrementalSerialization(0, listData.fieldDescriptor.entries.first().key, actualContainer)
val newData = serializer.doIncrementalSerialization(0, "x", listData.fieldDescriptor.entries.first().key, actualContainer)
val receivedDescriptor = newData.fieldDescriptor
assertEquals(4, receivedDescriptor.size)

Expand All @@ -1046,7 +1046,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
}

val depthMostNode = actualContainer.fieldDescriptor.entries.first { it.value!!.isContainer }
val serializationAns = serializer.doIncrementalSerialization(0, depthMostNode.key, depthMostNode.value!!)
val serializationAns = serializer.doIncrementalSerialization(0, "x", depthMostNode.key, depthMostNode.value!!)
}

@Test
Expand All @@ -1064,7 +1064,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
val serializer = repl.variablesSerializer
val path = listOf("x", "a")

val newData = serializer.doIncrementalSerialization(0, listData.fieldDescriptor.entries.first().key, actualContainer, path)
val newData = serializer.doIncrementalSerialization(0, "x", listData.fieldDescriptor.entries.first().key, actualContainer, path)
val receivedDescriptor = newData.fieldDescriptor
assertEquals(4, receivedDescriptor.size)

Expand Down Expand Up @@ -1105,7 +1105,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {

val serializer = repl.variablesSerializer

var newData = serializer.doIncrementalSerialization(0, "values", valuesDescriptor)
var newData = serializer.doIncrementalSerialization(0, "x", "values", valuesDescriptor)
var newDescriptor = newData.fieldDescriptor
assertEquals("4", newDescriptor["size"]!!.value)
assertEquals(3, newDescriptor["data"]!!.fieldDescriptor.size)
Expand All @@ -1120,7 +1120,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
val entriesDescriptor = listDescriptors["entries"]!!
assertEquals("4", valuesDescriptor.fieldDescriptor["size"]!!.value)
assertTrue(valuesDescriptor.fieldDescriptor["data"]!!.isContainer)
newData = serializer.doIncrementalSerialization(0, "entries", entriesDescriptor)
newData = serializer.doIncrementalSerialization(0, "x", "entries", entriesDescriptor)
newDescriptor = newData.fieldDescriptor
assertEquals("4", newDescriptor["size"]!!.value)
assertEquals(4, newDescriptor["data"]!!.fieldDescriptor.size)
Expand Down Expand Up @@ -1200,7 +1200,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
val propertyName = listData.fieldDescriptor.entries.first().key

runBlocking {
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
repl.serializeVariables(1, "x", mapOf(propertyName to actualContainer)) { result ->
val data = result.descriptorsState
assertTrue(data.isNotEmpty())

Expand Down Expand Up @@ -1261,7 +1261,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
val propertyName = listData.fieldDescriptor.entries.first().key

runBlocking {
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
repl.serializeVariables(1, "c", mapOf(propertyName to actualContainer)) { result ->
val data = result.descriptorsState
assertTrue(data.isNotEmpty())

Expand All @@ -1276,7 +1276,7 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {

val anotherI = originalClass.fieldDescriptor["i"]!!
runBlocking {
repl.serializeVariables(1, mapOf(propertyName to anotherI)) { res ->
repl.serializeVariables(1, "c", mapOf(propertyName to anotherI)) { res ->
val data = res.descriptorsState
val innerList = data.entries.last().value
assertTrue(innerList.isContainer)
Expand Down Expand Up @@ -1368,4 +1368,22 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
state = repl.notebook.unchangedVariables()
assertTrue(state.isEmpty())
}

@Test
fun testSerializationClearInfo() {
var res = eval(
"""
val x = listOf(1, 2, 3, 4)
""".trimIndent(),
jupyterId = 1
).metadata.evaluatedVariablesState
var state = repl.notebook.unchangedVariables()
res = eval(
"""
val x = listOf(1, 2, 3, 4)
""".trimIndent(),
jupyterId = 2
).metadata.evaluatedVariablesState
val a = 1
}
}

0 comments on commit db437cc

Please sign in to comment.