diff --git a/core/api/kotlinx-io-core.api b/core/api/kotlinx-io-core.api index f8018144c..dfb090038 100644 --- a/core/api/kotlinx-io-core.api +++ b/core/api/kotlinx-io-core.api @@ -54,7 +54,9 @@ public final class kotlinx/io/BuffersKt { } public final class kotlinx/io/ByteStringsKt { + public static final fun indexOf (Lkotlinx/io/Buffer;Lkotlinx/io/bytestring/ByteString;J)J public static final fun indexOf (Lkotlinx/io/Source;Lkotlinx/io/bytestring/ByteString;J)J + public static synthetic fun indexOf$default (Lkotlinx/io/Buffer;Lkotlinx/io/bytestring/ByteString;JILjava/lang/Object;)J public static synthetic fun indexOf$default (Lkotlinx/io/Source;Lkotlinx/io/bytestring/ByteString;JILjava/lang/Object;)J public static final fun readByteString (Lkotlinx/io/Source;)Lkotlinx/io/bytestring/ByteString; public static final fun readByteString (Lkotlinx/io/Source;I)Lkotlinx/io/bytestring/ByteString; diff --git a/core/common/src/ByteStrings.kt b/core/common/src/ByteStrings.kt index 9634d5ea1..b2d1a6fba 100644 --- a/core/common/src/ByteStrings.kt +++ b/core/common/src/ByteStrings.kt @@ -102,38 +102,63 @@ public fun Source.indexOf(byteString: ByteString, startIndex: Long = 0): Long { } var offset = startIndex - val peek = peek() - if (!request(startIndex)) { - return -1L + while (request(offset + byteString.size)) { + val idx = buffer.indexOf(byteString, offset) + if (idx < 0) { + // The buffer does not contain the pattern, let's try fetching at least one extra byte + // and start a new search attempt so that the pattern would fit in the suffix of + // the current buffer + 1 extra byte. + offset = buffer.size - byteString.size + 1 + } else { + return idx + } } - peek.skip(offset) - var resultingIndex = -1L - UnsafeByteStringOperations.withByteArrayUnsafe(byteString) { data -> - while (!peek.exhausted()) { - val index = peek.indexOf(data[0]) - if (index == -1L) { - return@withByteArrayUnsafe - } - offset += index - peek.skip(index) - if (!peek.request(byteString.size.toLong())) { - return@withByteArrayUnsafe - } + return -1 +} - var matches = true - for (idx in data.indices) { - if (data[idx] != peek.buffer[idx.toLong()]) { - matches = false - offset++ - peek.skip(1) - break - } - } - if (matches) { - resultingIndex = offset - return@withByteArrayUnsafe +@OptIn(UnsafeByteStringApi::class) +public fun Buffer.indexOf(byteString: ByteString, startIndex: Long = 0): Long { + require(startIndex <= size) { + "startIndex ($startIndex) should not exceed size ($size)" + } + if (byteString.isEmpty()) return 0 + if (startIndex > size - byteString.size) return -1L + + UnsafeByteStringOperations.withByteArrayUnsafe(byteString) { byteStringData -> + seek(startIndex) { seg, o -> + if (o == -1L) { + return -1L } + var segment = seg!! + var offset = o + do { + // If start index within this segment, the diff will be positive and + // we'll scan the segment starting from the corresponding offset. + // Otherwise, the diff will be negative and we'll scan the segment from the beginning. + val startOffset = maxOf((startIndex - offset).toInt(), 0) + // Try to search the pattern within the current segment. + val idx = segment.indexOfBytesInbound(byteStringData, startOffset) + if (idx != -1) { + // The offset corresponds to the segment's start, idx - to offset within the segment. + return offset + idx.toLong() + } + // firstOutboundOffset corresponds to a first byte starting reading the pattern from which + // will result in running out of the current segment bounds. + val firstOutboundOffset = maxOf(startOffset, segment.size - byteStringData.size + 1) + // Try to find a pattern in all suffixes shorter than the pattern. These suffixes start + // in the current segment, but ends in the following segments; thus we're using outbound function. + val idx1 = segment.indexOfBytesOutbound(byteStringData, firstOutboundOffset, head) + if (idx1 != -1) { + // Offset corresponds to the segment's start, idx - to offset within the segment. + return offset + idx1.toLong() + } + + // We scanned the whole segment, so let's go to the next one + offset += segment.size + segment = segment.next!! + } while (segment !== head && offset + byteString.size <= size) + return -1L } } - return resultingIndex + return -1 } diff --git a/core/common/src/Segment.kt b/core/common/src/Segment.kt index f463aa0d0..08d87908c 100644 --- a/core/common/src/Segment.kt +++ b/core/common/src/Segment.kt @@ -201,7 +201,9 @@ internal fun Segment.indexOf(byte: Byte, startOffset: Int, endOffset: Int): Int require(startOffset in 0 until size) { "$startOffset" } - require(endOffset in startOffset..size) { "$endOffset" } + require(endOffset in startOffset..size) { + "$endOffset" + } val p = pos for (idx in startOffset until endOffset) { if (data[p + idx] == byte) { @@ -210,3 +212,75 @@ internal fun Segment.indexOf(byte: Byte, startOffset: Int, endOffset: Int): Int } return -1 } + +/** + * Searches for a `bytes` pattern within this segment starting at the offset `startOffset`. + * `startOffset` is relative and should be within `[0, size)`. + */ +internal fun Segment.indexOfBytesInbound(bytes: ByteArray, startOffset: Int): Int { + // require(startOffset in 0 until size) + var offset = startOffset + val limit = size - bytes.size + 1 + val firstByte = bytes[0] + while (offset < limit) { + val idx = indexOf(firstByte, offset, limit) + if (idx < 0) { + return -1 + } + var found = true + for (innerIdx in 1 until bytes.size) { + if (data[pos + idx + innerIdx] != bytes[innerIdx]) { + found = false + break + } + } + if (found) { + return idx + } else { + offset++ + } + } + return -1 +} + +/** + * Searches for a `bytes` pattern starting in between offset `startOffset` and `size` within this segment + * and continued in the following segments. + * `startOffset` is relative and should be within `[0, size)`. + */ +internal fun Segment.indexOfBytesOutbound(bytes: ByteArray, startOffset: Int, head: Segment?): Int { + var offset = startOffset + val firstByte = bytes[0] + + while (offset in 0 until size) { + val idx = indexOf(firstByte, offset, size) + if (idx < 0) { + return -1 + } + // The pattern should start in this segment + var seg = this + var scanOffset = offset + + var found = true + for (element in bytes) { + // We ran out of bytes in this segment, + // so let's take the next one and continue the scan there. + if (scanOffset == seg.size) { + val next = seg.next + if (next === head) return -1 + seg = next!! + scanOffset = 0 // we're scanning the next segment right from the beginning + } + if (element != seg.data[seg.pos + scanOffset]) { + found = false + break + } + scanOffset++ + } + if (found) { + return offset + } + offset++ + } + return -1 +} diff --git a/core/common/test/AbstractSourceTest.kt b/core/common/test/AbstractSourceTest.kt index da9dbe53c..9e980e518 100644 --- a/core/common/test/AbstractSourceTest.kt +++ b/core/common/test/AbstractSourceTest.kt @@ -1769,4 +1769,17 @@ abstract class AbstractBufferedSourceTest internal constructor( assertEquals((Segment.SIZE * 2 + 1).toLong(), source.indexOf("fg".encodeToByteString())) assertEquals((Segment.SIZE * 2 + 2).toLong(), source.indexOf("g".encodeToByteString())) } + + @Test + fun indexOfByteStringSpanningAcrossMultipleSegments() { + sink.writeString("a".repeat(SEGMENT_SIZE)) + sink.emit() + sink.writeString("bbbb") + sink.emit() + sink.write(Buffer().also { it.writeString("c".repeat(SEGMENT_SIZE)) }, SEGMENT_SIZE.toLong()) + sink.emit() + + source.skip(SEGMENT_SIZE - 10L) + assertEquals(9, source.indexOf("abbbbc".encodeToByteString())) + } }