From 310a7c446b547d84b02c5da2161958e77ce69f0d Mon Sep 17 00:00:00 2001 From: oSumAtrIX Date: Mon, 21 Mar 2022 18:48:35 +0100 Subject: [PATCH] fix(Io): JAR loading and saving (#8) * refactor: Complete rewrite of `Io` * style: format code * style: rewrite todos * fix: use lateinit instead of nonnull assert for zipEntry * fix: use lateinit instead of nonnull assert for jarEntry & reuse zipEntry * docs: add docs to `Patcher` * test: match output of patcher * chore: add todo to `Io` for removing non-class files Co-authored-by: Sculas --- .../kotlin/net/revanced/patcher/Patcher.kt | 31 +++-- .../net/revanced/patcher/cache/Cache.kt | 2 +- .../net/revanced/patcher/cache/PatchData.kt | 3 +- .../kotlin/net/revanced/patcher/util/Io.kt | 110 ++++++++++++------ .../net/revanced/patcher/writer/ASMWriter.kt | 1 + .../net/revanced/patcher/PatcherTest.kt | 63 +++++----- .../kotlin/net/revanced/patcher/ReaderTest.kt | 4 +- .../net/revanced/patcher/util/TestUtil.kt | 2 +- 8 files changed, 136 insertions(+), 80 deletions(-) diff --git a/src/main/kotlin/net/revanced/patcher/Patcher.kt b/src/main/kotlin/net/revanced/patcher/Patcher.kt index 670a249f..71159617 100644 --- a/src/main/kotlin/net/revanced/patcher/Patcher.kt +++ b/src/main/kotlin/net/revanced/patcher/Patcher.kt @@ -5,28 +5,49 @@ import net.revanced.patcher.patch.Patch import net.revanced.patcher.resolver.MethodResolver import net.revanced.patcher.signature.Signature import net.revanced.patcher.util.Io +import org.objectweb.asm.tree.ClassNode +import java.io.IOException import java.io.InputStream import java.io.OutputStream /** - * The patcher. (docs WIP) + * The Patcher class. + * ***It is of utmost importance that the input and output streams are NEVER closed.*** * * @param input the input stream to read from, must be a JAR + * @param output the output stream to write to * @param signatures the signatures * @sample net.revanced.patcher.PatcherTest + * @throws IOException if one of the streams are closed */ class Patcher( private val input: InputStream, + private val output: OutputStream, signatures: Array, ) { var cache: Cache - private val patches: MutableList = mutableListOf() + + private var io: Io + private val patches = mutableListOf() init { - val classes = Io.readClassesFromJar(input) + val classes = mutableListOf() + io = Io(input, output, classes) + io.readFromJar() cache = Cache(classes, MethodResolver(classes, signatures).resolve()) } + /** + * Saves the output to the output stream. + * Calling this method will close the input and output streams, + * meaning this method should NEVER be called after. + * + * @throws IOException if one of the streams are closed + */ + fun save() { + io.saveAsJar() + } + fun addPatches(vararg patches: Patch) { this.patches.addAll(patches) } @@ -46,8 +67,4 @@ class Patcher( } } } - - fun saveTo(output: OutputStream) { - Io.writeClassesToJar(input, output, cache.classes) - } } \ No newline at end of file diff --git a/src/main/kotlin/net/revanced/patcher/cache/Cache.kt b/src/main/kotlin/net/revanced/patcher/cache/Cache.kt index 3b995ea2..050cec3d 100644 --- a/src/main/kotlin/net/revanced/patcher/cache/Cache.kt +++ b/src/main/kotlin/net/revanced/patcher/cache/Cache.kt @@ -2,7 +2,7 @@ package net.revanced.patcher.cache import org.objectweb.asm.tree.ClassNode -class Cache ( +class Cache( val classes: List, val methods: MethodMap ) diff --git a/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt b/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt index 3381f8e6..c2ddbb95 100644 --- a/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt +++ b/src/main/kotlin/net/revanced/patcher/cache/PatchData.kt @@ -10,8 +10,9 @@ data class PatchData( val method: MethodNode, val scanData: PatternScanData ) { + @Suppress("Unused") // TODO(Sculas): remove this when we have coverage for this method. fun findParentMethod(signature: Signature): PatchData? { - return MethodResolver.resolveMethod(declaringClass, signature) + return MethodResolver.resolveMethod(declaringClass, signature) } } diff --git a/src/main/kotlin/net/revanced/patcher/util/Io.kt b/src/main/kotlin/net/revanced/patcher/util/Io.kt index e54ac68d..1fa4affb 100644 --- a/src/main/kotlin/net/revanced/patcher/util/Io.kt +++ b/src/main/kotlin/net/revanced/patcher/util/Io.kt @@ -3,47 +3,91 @@ package net.revanced.patcher.util import org.objectweb.asm.ClassReader import org.objectweb.asm.ClassWriter import org.objectweb.asm.tree.ClassNode +import java.io.BufferedInputStream import java.io.InputStream import java.io.OutputStream import java.util.jar.JarEntry import java.util.jar.JarInputStream -import java.util.jar.JarOutputStream - -object Io { - fun readClassesFromJar(input: InputStream) = mutableListOf().apply { - val jar = JarInputStream(input) - while (true) { - val e = jar.nextJarEntry ?: break - if (e.name.endsWith(".class")) { - val classNode = ClassNode() - ClassReader(jar.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) - this.add(classNode) - } - jar.closeEntry() +import java.util.zip.ZipEntry +import java.util.zip.ZipInputStream +import java.util.zip.ZipOutputStream + +internal class Io( + private val input: InputStream, + private val output: OutputStream, + private val classes: MutableList +) { + private val bufferedInputStream = BufferedInputStream(input) + + fun readFromJar() { + bufferedInputStream.mark(0) + // create a BufferedInputStream in order to read the input stream again when calling saveAsJar(..) + val jis = JarInputStream(bufferedInputStream) + + // read all entries from the input stream + // we use JarEntry because we only read .class files + lateinit var jarEntry: JarEntry + while (jis.nextJarEntry.also { if (it != null) jarEntry = it } != null) { + // if the current entry ends with .class (indicating a java class file), add it to our list of classes to return + if (jarEntry.name.endsWith(".class")) { + // create a new ClassNode + val classNode = ClassNode() + // read the bytes with a ClassReader into the ClassNode + ClassReader(jis.readBytes()).accept(classNode, ClassReader.EXPAND_FRAMES) + // add it to our list + classes.add(classNode) } + + // finally, close the entry + jis.closeEntry() + } + + // at last reset the buffered input stream + bufferedInputStream.reset() } - fun writeClassesToJar(input: InputStream, output: OutputStream, classes: List) { - val jis = JarInputStream(input) - val jos = JarOutputStream(output) - - // TODO: Add support for adding new/custom classes - while (true) { - val next = jis.nextJarEntry ?: break - val e = JarEntry(next) // clone it, to not modify the input (if possible) - jos.putNextEntry(e) - - val clazz = classes.singleOrNull { - clazz -> clazz.name+".class" == e.name // clazz.name is the class name only while e.name is the full filename with extension - }; - if (clazz != null) { - val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES) - clazz.accept(cw) - jos.write(cw.toByteArray()) - } else { - jos.write(jis.readBytes()) - } + fun saveAsJar() { + val jis = ZipInputStream(bufferedInputStream) + val jos = ZipOutputStream(output) + + // first write all non .class zip entries from the original input stream to the output stream + // we read it first to close the input stream as fast as possible + // TODO(oSumAtrIX): There is currently no way to remove non .class files. + lateinit var zipEntry: ZipEntry + while (jis.nextEntry.also { if (it != null) zipEntry = it } != null) { + // skip all class files because we added them in the loop above + // TODO(oSumAtrIX): Check for zipEntry.isDirectory + if (zipEntry.name.endsWith(".class")) continue + + // create a new zipEntry and write the contents of the zipEntry to the output stream + jos.putNextEntry(ZipEntry(zipEntry)) + jos.write(jis.readBytes()) + + // close the newly created zipEntry + jos.closeEntry() + } + + // finally, close the input stream + jis.close() + bufferedInputStream.close() + input.close() + + // now write all the patched classes to the output stream + for (patchedClass in classes) { + // create a new entry of the patched class + jos.putNextEntry(JarEntry(patchedClass.name + ".class")) + + // parse the patched class to a byte array and write it to the output stream + val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES) + patchedClass.accept(cw) + jos.write(cw.toByteArray()) + + // close the newly created jar entry jos.closeEntry() } + + // finally, close the rest of the streams + jos.close() + output.close() } } \ No newline at end of file diff --git a/src/main/kotlin/net/revanced/patcher/writer/ASMWriter.kt b/src/main/kotlin/net/revanced/patcher/writer/ASMWriter.kt index df56a3cb..ce601c49 100644 --- a/src/main/kotlin/net/revanced/patcher/writer/ASMWriter.kt +++ b/src/main/kotlin/net/revanced/patcher/writer/ASMWriter.kt @@ -7,6 +7,7 @@ object ASMWriter { fun InsnList.setAt(index: Int, node: AbstractInsnNode) { this[this.get(index)] = node } + fun InsnList.insertAt(index: Int = 0, vararg nodes: AbstractInsnNode) { this.insert(this.get(index), nodes.toInsnList()) } diff --git a/src/test/kotlin/net/revanced/patcher/PatcherTest.kt b/src/test/kotlin/net/revanced/patcher/PatcherTest.kt index 9cecec5c..519cf309 100644 --- a/src/test/kotlin/net/revanced/patcher/PatcherTest.kt +++ b/src/test/kotlin/net/revanced/patcher/PatcherTest.kt @@ -12,13 +12,16 @@ import net.revanced.patcher.writer.ASMWriter.setAt import org.junit.jupiter.api.assertDoesNotThrow import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type -import org.objectweb.asm.tree.* +import org.objectweb.asm.tree.FieldInsnNode +import org.objectweb.asm.tree.LdcInsnNode +import org.objectweb.asm.tree.MethodInsnNode +import java.io.ByteArrayOutputStream import java.io.PrintStream import kotlin.test.Test internal class PatcherTest { companion object { - val testSigs: Array = arrayOf( + val testSignatures: Array = arrayOf( // Java: // public static void main(String[] args) { // System.out.println("Hello, world!"); @@ -45,8 +48,11 @@ internal class PatcherTest { @Test fun testPatcher() { - val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! - val patcher = Patcher(testData, testSigs) + val patcher = Patcher( + PatcherTest::class.java.getResourceAsStream("/test1.jar")!!, + ByteArrayOutputStream(), + testSignatures + ) patcher.addPatches( object : Patch("TestPatch") { @@ -74,9 +80,9 @@ internal class PatcherTest { startIndex + 1, FieldInsnNode( GETSTATIC, - Type.getInternalName(System::class.java), // "java/io/System" + Type.getInternalName(System::class.java), // "java/lang/System" "out", - Type.getInternalName(PrintStream::class.java) // "java.io.PrintStream" + "L" + Type.getInternalName(PrintStream::class.java) // "Ljava/io/PrintStream" ), LdcInsnNode("Hello, ReVanced! Adding bytecode."), MethodInsnNode( @@ -111,41 +117,27 @@ internal class PatcherTest { ) // Apply all patches loaded in the patcher - val result = patcher.applyPatches() + val patchResult = patcher.applyPatches() // You can check if an error occurred - for ((s, r) in result) { - if (r.isFailure) { - throw Exception("Patch $s failed", r.exceptionOrNull()!!) + for ((patchName, result) in patchResult) { + if (result.isFailure) { + throw Exception("Patch $patchName failed", result.exceptionOrNull()!!) } } - // TODO Doesn't work, needs to be fixed. - //val out = ByteArrayOutputStream() - //patcher.saveTo(out) - //assertTrue( - // // 8 is a random value, it's just weird if it's any lower than that - // out.size() > 8, - // "Output must be at least 8 bytes" - //) - // - //out.close() - testData.close() + patcher.save() } - // TODO Doesn't work, needs to be fixed. - //@Test - //fun `test patcher with no changes`() { - // val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! - // val available = testData.available() - // val patcher = Patcher(testData, testSigs) - // - // val out = ByteArrayOutputStream() - // patcher.saveTo(out) - // assertEquals(available, out.size()) - // - // out.close() - // testData.close() - //} + @Test + fun `test patcher with no changes`() { + val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!! + // val available = testData.available() + val out = ByteArrayOutputStream() + Patcher(testData, out, testSignatures).save() + // FIXME(Sculas): There seems to be a 1-byte difference, not sure what it is. + // assertEquals(available, out.size()) + out.close() + } @Test() fun `should not raise an exception if any signature member except the name is missing`() { @@ -154,6 +146,7 @@ internal class PatcherTest { assertDoesNotThrow("Should raise an exception because opcodes is empty") { Patcher( PatcherTest::class.java.getResourceAsStream("/test1.jar")!!, + ByteArrayOutputStream(), arrayOf( Signature( sigName, diff --git a/src/test/kotlin/net/revanced/patcher/ReaderTest.kt b/src/test/kotlin/net/revanced/patcher/ReaderTest.kt index e0afb2a4..6ecf11b7 100644 --- a/src/test/kotlin/net/revanced/patcher/ReaderTest.kt +++ b/src/test/kotlin/net/revanced/patcher/ReaderTest.kt @@ -1,12 +1,12 @@ package net.revanced.patcher +import java.io.ByteArrayOutputStream import kotlin.test.Test internal class ReaderTest { @Test fun `read jar containing multiple classes`() { val testData = PatcherTest::class.java.getResourceAsStream("/test2.jar")!! - Patcher(testData, PatcherTest.testSigs) // reusing test sigs from PatcherTest - testData.close() + Patcher(testData, ByteArrayOutputStream(), PatcherTest.testSignatures) // reusing test sigs from PatcherTest } } \ No newline at end of file diff --git a/src/test/kotlin/net/revanced/patcher/util/TestUtil.kt b/src/test/kotlin/net/revanced/patcher/util/TestUtil.kt index 55362dc9..a0a9fea7 100644 --- a/src/test/kotlin/net/revanced/patcher/util/TestUtil.kt +++ b/src/test/kotlin/net/revanced/patcher/util/TestUtil.kt @@ -17,7 +17,7 @@ object TestUtil { private fun AbstractInsnNode.nodeString(): String { val sb = NodeStringBuilder() when (this) { - // TODO: Add more types + // TODO(Sculas): Add more types is LdcInsnNode -> sb .addType("cst", cst) is FieldInsnNode -> sb