Skip to content

Commit

Permalink
Merge pull request #695 from zhelenskiy/add_jvm_bytecode
Browse files Browse the repository at this point in the history
Add JVM bytecode generation for Kotlin/JVM
  • Loading branch information
nikpachoo committed May 16, 2024
2 parents 46cafc0 + e02cead commit 93bc274
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.compiler.server.compiler.components

import com.compiler.server.executor.CommandLineArgument
import com.compiler.server.executor.JavaExecutor
import com.compiler.server.model.ExecutionResult
import com.compiler.server.model.JvmExecutionResult
import com.compiler.server.model.OutputDirectory
import com.compiler.server.model.bean.LibrariesFile
import com.compiler.server.model.toExceptionDescriptor
Expand All @@ -16,9 +16,12 @@ import org.jetbrains.org.objectweb.asm.ClassReader.*
import org.jetbrains.org.objectweb.asm.ClassVisitor
import org.jetbrains.org.objectweb.asm.MethodVisitor
import org.jetbrains.org.objectweb.asm.Opcodes.*
import org.jetbrains.org.objectweb.asm.util.TraceClassVisitor
import org.springframework.beans.factory.annotation.Value
import org.springframework.stereotype.Component
import java.io.File
import java.io.PrintWriter
import java.io.StringWriter
import java.nio.file.FileVisitResult
import java.nio.file.Files
import java.nio.file.Path
Expand All @@ -38,16 +41,34 @@ class KotlinCompiler(
val mainClasses: Set<String> = emptySet()
)

fun run(files: List<KtFile>, args: String): ExecutionResult {
return execute(files) { output, compiled ->
private fun ByteArray.asHumanReadable(): String {
val classReader = ClassReader(this)
val stringWriter = StringWriter()
val printWriter = PrintWriter(stringWriter)
val traceClassVisitor = TraceClassVisitor(printWriter)

classReader.accept(traceClassVisitor, 0)

return stringWriter.toString()
}

private fun JvmExecutionResult.addByteCode(compiled: JvmClasses) {
jvmByteCode = compiled.files
.mapNotNull { (_, bytes) -> runCatching { bytes.asHumanReadable() }.getOrNull() }
.takeUnless { it.isEmpty() }
?.joinToString("\n\n")
}

fun run(files: List<KtFile>, addByteCode: Boolean, args: String): JvmExecutionResult {
return execute(files, addByteCode) { output, compiled ->
val mainClass = JavaRunnerExecutor::class.java.name
val compiledMainClass = when (compiled.mainClasses.size) {
0 -> return@execute ExecutionResult(
0 -> return@execute JvmExecutionResult(
exception = IllegalArgumentException("No main method found in project").toExceptionDescriptor()
)

1 -> compiled.mainClasses.single()
else -> return@execute ExecutionResult(
else -> return@execute JvmExecutionResult(
exception = IllegalArgumentException(
"Multiple classes in project contain main methods found: ${compiled.mainClasses.joinToString()}"
).toExceptionDescriptor()
Expand All @@ -59,8 +80,8 @@ class KotlinCompiler(
}
}

fun test(files: List<KtFile>): ExecutionResult {
return execute(files) { output, _ ->
fun test(files: List<KtFile>, addByteCode: Boolean): JvmExecutionResult {
return execute(files, addByteCode) { output, _ ->
val mainClass = JUnitExecutors::class.java.name
javaExecutor.execute(argsFrom(mainClass, output, listOf(output.path.toString())))
.asJUnitExecutionResult()
Expand Down Expand Up @@ -117,22 +138,26 @@ class KotlinCompiler(

private fun execute(
files: List<KtFile>,
block: (output: OutputDirectory, compilation: JvmClasses) -> ExecutionResult
): ExecutionResult = try {
addByteCode: Boolean,
block: (output: OutputDirectory, compilation: JvmClasses) -> JvmExecutionResult
): JvmExecutionResult = try {
when (val compilationResult = compile(files)) {
is Compiled<JvmClasses> -> {
usingTempDirectory { outputDir ->
val output = write(compilationResult.result, outputDir)
block(output, compilationResult.result).also {
it.addWarnings(compilationResult.compilerDiagnostics)
if (addByteCode) {
it.addByteCode(compilationResult.result)
}
}
}
}

is NotCompiled -> ExecutionResult(compilerDiagnostics = compilationResult.compilerDiagnostics)
is NotCompiled -> JvmExecutionResult(compilerDiagnostics = compilationResult.compilerDiagnostics)
}
} catch (e: Exception) {
ExecutionResult(exception = e.toExceptionDescriptor())
JvmExecutionResult(exception = e.toExceptionDescriptor())
}

private fun write(classes: JvmClasses, outputDir: Path): OutputDirectory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ import org.springframework.web.bind.annotation.*
@RequestMapping(value = ["/api/compiler", "/api/**/compiler"])
class CompilerRestController(private val kotlinProjectExecutor: KotlinProjectExecutor) {
@PostMapping("/run")
fun executeKotlinProjectEndpoint(@RequestBody project: Project): ExecutionResult {
return kotlinProjectExecutor.run(project)
fun executeKotlinProjectEndpoint(
@RequestBody project: Project,
@RequestParam(defaultValue = "false") addByteCode: Boolean,
): ExecutionResult {
return kotlinProjectExecutor.run(project, addByteCode)
}

@PostMapping("/test")
fun testKotlinProjectEndpoint(@RequestBody project: Project): ExecutionResult {
return kotlinProjectExecutor.test(project)
fun testKotlinProjectEndpoint(
@RequestBody project: Project,
@RequestParam(defaultValue = "false") addByteCode: Boolean,
): ExecutionResult {
return kotlinProjectExecutor.test(project, addByteCode)
}

@PostMapping("/translate")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class KotlinPlaygroundRestController(private val kotlinProjectExecutor: KotlinPr
@RequestParam type: String,
@RequestParam(required = false) line: Int?,
@RequestParam(required = false) ch: Int?,
@RequestParam(required = false) project: Project?
@RequestParam(required = false) project: Project?,
@RequestParam(defaultValue = "false") addByteCode: Boolean,
): ResponseEntity<*> {
val result = when (type) {
"getKotlinVersions" -> listOf(kotlinProjectExecutor.getVersion())
Expand All @@ -39,7 +40,7 @@ class KotlinPlaygroundRestController(private val kotlinProjectExecutor: KotlinPr
when (type) {
"run" -> {
when (project.confType) {
ProjectType.JAVA -> kotlinProjectExecutor.run(project)
ProjectType.JAVA -> kotlinProjectExecutor.run(project, addByteCode)
ProjectType.JS -> throw LegacyJsException()
ProjectType.JS_IR, ProjectType.CANVAS ->
kotlinProjectExecutor.convertToJsIr(
Expand All @@ -49,7 +50,7 @@ class KotlinPlaygroundRestController(private val kotlinProjectExecutor: KotlinPr
project,
debugInfo = false,
)
ProjectType.JUNIT -> kotlinProjectExecutor.test(project)
ProjectType.JUNIT -> kotlinProjectExecutor.test(project, addByteCode)
}
}

Expand Down
13 changes: 10 additions & 3 deletions src/main/kotlin/com/compiler/server/model/ExecutionResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.databind.annotation.JsonSerialize

open class ExecutionResult(
sealed class ExecutionResult(
@field:JsonProperty("errors")
open var compilerDiagnostics: CompilerDiagnostics = CompilerDiagnostics(),
open var exception: ExceptionDescriptor? = null
Expand Down Expand Up @@ -51,6 +51,12 @@ data class CompilerDiagnostics(
val map: Map<String, List<ErrorDescriptor>> = mapOf()
): List<ErrorDescriptor> by map.values.flatten()

open class JvmExecutionResult(
compilerDiagnostics: CompilerDiagnostics = CompilerDiagnostics(),
exception: ExceptionDescriptor? = null,
var jvmByteCode: String? = null,
): ExecutionResult(compilerDiagnostics, exception)

abstract class TranslationResultWithJsCode(
open val jsCode: String?,
compilerDiagnostics: CompilerDiagnostics,
Expand Down Expand Up @@ -79,8 +85,9 @@ class JunitExecutionResult(
val testResults: Map<String, List<TestDescription>> = emptyMap(),
override var exception: ExceptionDescriptor? = null,
@field:JsonProperty("errors")
override var compilerDiagnostics: CompilerDiagnostics = CompilerDiagnostics()
) : ExecutionResult(compilerDiagnostics, exception)
override var compilerDiagnostics: CompilerDiagnostics = CompilerDiagnostics(),
jvmBytecode: String? = null,
) : JvmExecutionResult(compilerDiagnostics, exception, jvmBytecode)

private fun unEscapeOutput(value: String) = value.replace("&amp;lt;".toRegex(), "<")
.replace("&amp;gt;".toRegex(), ">")
Expand Down
13 changes: 7 additions & 6 deletions src/main/kotlin/com/compiler/server/model/ProgramOutput.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@ const val ERROR_STREAM_END = "</errStream>"

data class ProgramOutput(
val standardOutput: String = "",
val jvmByteCode: String? = null,
val exception: Exception? = null,
val restriction: String? = null
) {
fun asExecutionResult(): ExecutionResult {
fun asExecutionResult(): JvmExecutionResult {
return when {
restriction != null -> ExecutionResult().apply { text = buildRestriction(restriction) }
exception != null -> ExecutionResult(exception = exception.toExceptionDescriptor())
standardOutput.isBlank() -> ExecutionResult()
restriction != null -> JvmExecutionResult().apply { text = buildRestriction(restriction) }
exception != null -> JvmExecutionResult(exception = exception.toExceptionDescriptor())
standardOutput.isBlank() -> JvmExecutionResult()
else -> {
try {
outputMapper.readValue(standardOutput, ExecutionResult::class.java)
outputMapper.readValue(standardOutput, JvmExecutionResult::class.java)
} catch (e: Exception) {
ExecutionResult(exception = e.toExceptionDescriptor())
JvmExecutionResult(exception = e.toExceptionDescriptor())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ class KotlinProjectExecutor(

private val log = LoggerFactory.getLogger(KotlinProjectExecutor::class.java)

fun run(project: Project): ExecutionResult {
fun run(project: Project, addByteCode: Boolean): ExecutionResult {
return kotlinEnvironment.environment { environment ->
val files = getFilesFrom(project, environment).map { it.kotlinFile }
kotlinCompiler.run(files, project.args)
kotlinCompiler.run(files, addByteCode, project.args)
}.also { logExecutionResult(project, it) }
}

fun test(project: Project): ExecutionResult {
fun test(project: Project, addByteCode: Boolean): ExecutionResult {
return kotlinEnvironment.environment { environment ->
val files = getFilesFrom(project, environment).map { it.kotlinFile }
kotlinCompiler.test(files)
kotlinCompiler.test(files, addByteCode)
}.also { logExecutionResult(project, it) }
}

Expand Down
4 changes: 2 additions & 2 deletions src/test/kotlin/com/compiler/server/CompilerAPITest.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.compiler.server

import com.compiler.server.generator.generateSingleProject
import com.compiler.server.model.ExecutionResult
import com.compiler.server.model.JvmExecutionResult
import com.compiler.server.model.bean.VersionInfo
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -51,7 +51,7 @@ class CompilerAPITest {
),
headers
),
ExecutionResult::class.java
JvmExecutionResult::class.java
)
assertNotNull(response, "Empty response!")
assertContains(
Expand Down
59 changes: 53 additions & 6 deletions src/test/kotlin/com/compiler/server/JUnitTestsRunnerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package com.compiler.server
import com.compiler.server.base.BaseJUnitTest
import com.compiler.server.executor.ExecutorMessages
import com.compiler.server.model.TestStatus
import org.intellij.lang.annotations.Language
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import kotlin.test.assertContains
import kotlin.test.assertNull

class JUnitTestsRunnerTest : BaseJUnitTest() {

Expand All @@ -25,11 +28,20 @@ class JUnitTestsRunnerTest : BaseJUnitTest() {

@Test
fun `base fail junit test`() {
val test = test(
"fun start(): String = \"OP\"",
"import org.junit.Assert\nimport org.junit.Test\n\nclass TestStart {\n @Test fun testOk() {\n Assert.assertEquals(\"OK\", start())\n }\n}",
koansUtilsFile
)
@Language("kotlin")
val testCode = """
import org.junit.Assert
import org.junit.Test
class TestStart {
@Test fun testOk() {
Assert.assertEquals("OK", start())
}
}
""".trimIndent()
val sourceCode = """fun start(): String = "OP""""

val test = test(sourceCode, testCode, koansUtilsFile)
val fail = test.first()
Assertions.assertTrue(fail.status == TestStatus.FAIL)
Assertions.assertNotNull(fail.comparisonFailure, "comparisonFailure should not be a null")
Expand All @@ -38,4 +50,39 @@ class JUnitTestsRunnerTest : BaseJUnitTest() {
Assertions.assertTrue(it.expected == "OK")
}
}
}

@Test
fun `no bytecode`() {
@Language("kotlin")
val testCode = """
import org.junit.Assert
import org.junit.Test
class TestStart {
@Test fun testOk() {
Assert.assertEquals("OK", "OK")
}
}
""".trimIndent()
val testResults = testRaw(testCode, addByteCode = false)
assertNull(testResults!!.jvmByteCode, "Bytecode should not be generated")
}

@Test
fun `with bytecode`() {
@Language("kotlin")
val testCode = """
import org.junit.Assert
import org.junit.Test
class TestStart {
@Test fun testOk() {
Assert.assertEquals("OK", "OK")
}
}
""".trimIndent()
val testResults = testRaw(testCode, addByteCode = true)
val byteCode = testResults!!.jvmByteCode!!
assertContains(byteCode, "public final testOk()V")
}
}
24 changes: 22 additions & 2 deletions src/test/kotlin/com/compiler/server/JvmRunnerTest.kt
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
package com.compiler.server

import com.compiler.server.base.BaseExecutorTest
import com.compiler.server.model.JvmExecutionResult
import org.junit.jupiter.api.Test
import kotlin.test.assertContains
import kotlin.test.assertEquals
import kotlin.test.assertNull

class JvmRunnerTest : BaseExecutorTest() {

@Test
fun `base execute test JVM`() {
run(
val executionResult = run(
code = "fun main() {\n println(\"Hello, world!!!\")\n}",
contains = "Hello, world!!!"
contains = "Hello, world!!!",
addByteCode = false,
)
assertNull((executionResult as JvmExecutionResult).jvmByteCode, "Bytecode should not be generated")
}

@Test
fun `jvm bytecode`() {
val executionResult = run(
code = "fun main() {\n println(\"Hello, world!!!\")\n}",
contains = "Hello, world!!!",
addByteCode = true,
)

val byteCode = (executionResult as JvmExecutionResult).jvmByteCode!!
assertContains(byteCode, "public static synthetic main([Ljava/lang/String;)V", message = byteCode)
assertContains(byteCode, "public final static main()V", message = byteCode)
assertContains(byteCode, "LDC \"Hello, world!!!\"", message = byteCode)
assertContains(byteCode, "INVOKEVIRTUAL java/io/PrintStream.println (Ljava/lang/Object;)V", message = byteCode)
}

@Test
Expand Down
Loading

0 comments on commit 93bc274

Please sign in to comment.