Skip to content

Commit

Permalink
+GroupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
akarnokd committed Jul 27, 2019
1 parent 156105e commit 27c7415
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -28,6 +28,7 @@ Table of contents
- `timer`
- Intermediate Flow operators (`FlowExtensions`)
- `Flow.concatWith`
- `Flow.groupBy`
- `Flow.parallel`
- `Flow.publish`
- `Flow.replay`
Expand Down
31 changes: 31 additions & 0 deletions src/main/kotlin/hu/akarnokd/kotlin/flow/FlowExtensions.kt
Expand Up @@ -127,6 +127,37 @@ fun timer(timeout: Long, unit: TimeUnit) : Flow<Long> =
delay(unit.toMillis(timeout))
emit(0L)
}

/**
* Groups the upstream values into their own Flows keyed by the value returned
* by the [keySelector] function.
*/
@FlowPreview
fun <T, K> Flow<T>.groupBy(keySelector: suspend (T) -> K) : Flow<GroupedFlow<K, T>> =
FlowGroupBy(this, keySelector, { it })

/**
* Groups the mapped upstream values into their own Flows keyed by the value returned
* by the [keySelector] function.
*/
@FlowPreview
fun <T, K, V> Flow<T>.groupBy(keySelector: suspend (T) -> K, valueSelector: suspend (T) -> V) : Flow<GroupedFlow<K, V>> =
FlowGroupBy(this, keySelector, valueSelector)

/**
* Collects all items of the upstream into a list.
*/
fun <T> Flow<T>.toList() : Flow<List<T>> {
val self = this
return flow {
val list = ArrayList<T>()
self.collect {
list.add(it)
}
emit(list)
}
}

// -----------------------------------------------------------------------------------------
// Parallel Extensions
// -----------------------------------------------------------------------------------------
Expand Down
29 changes: 29 additions & 0 deletions src/main/kotlin/hu/akarnokd/kotlin/flow/GroupedFlow.kt
@@ -0,0 +1,29 @@
/*
* Copyright 2019 David Karnok
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hu.akarnokd.kotlin.flow

import kotlinx.coroutines.flow.Flow

/**
* Represents a Flow with a group key.
*/
interface GroupedFlow<K, V> : Flow<V> {
/**
* The key of the flow.
*/
val key : K
}
172 changes: 172 additions & 0 deletions src/main/kotlin/hu/akarnokd/kotlin/flow/impl/FlowGroupBy.kt
@@ -0,0 +1,172 @@
/*
* Copyright 2019 David Karnok
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hu.akarnokd.kotlin.flow.impl

import hu.akarnokd.kotlin.flow.GroupedFlow
import hu.akarnokd.kotlin.flow.Resumable
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.AbstractFlow
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.FlowCollector
import java.lang.IllegalStateException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
import java.util.concurrent.atomic.AtomicBoolean

/**
* Groups transformed values of the source flow based on a key selector
* function.
*/
@FlowPreview
internal class FlowGroupBy<T, K, V>(
private val source: Flow<T>,
private val keySelector: suspend (T) -> K,
private val valueSelector: suspend (T) -> V
) : AbstractFlow<GroupedFlow<K, V>>() {

override suspend fun collectSafely(collector: FlowCollector<GroupedFlow<K, V>>) {
val map = ConcurrentHashMap<K, FlowGroup<K, V>>()

val mainStopped = AtomicBoolean()

try {
source.collect {
val k = keySelector(it)

var group = map[k]

if (group != null) {
group.next(valueSelector(it))
} else {
if (!mainStopped.get()) {
group = FlowGroup(k, map)
map.put(k, group)

try {
collector.emit(group)
} catch (ex: CancellationException) {
mainStopped.set(true)
if (map.size == 0) {
throw CancellationException()
}
}

group.next(valueSelector(it))
} else {
if (map.size == 0) {
throw CancellationException()
}
}
}
}
for (group in map.values) {
group.complete()
}
} catch (ex: Throwable) {
for (group in map.values) {
group.error(ex)
}
}
}

class FlowGroup<K, V>(
override val key: K,
private val map : ConcurrentMap<K, FlowGroup<K, V>>
) : AbstractFlow<V>(), GroupedFlow<K, V> {

@Suppress("UNCHECKED_CAST")
private var value: V = null as V
@Volatile
private var hasValue: Boolean = false

private var error: Throwable? = null
@Volatile
private var done: Boolean = false

@Volatile
private var cancelled: Boolean = false

private val consumerReady = Resumable()

private val valueReady = Resumable()

private val once = AtomicBoolean()

override suspend fun collectSafely(collector: FlowCollector<V>) {
if (!once.compareAndSet(false, true)) {
throw IllegalStateException("A GroupedFlow can only be collected at most once.")
}

consumerReady.resume()

while (true) {
val d = done
val has = hasValue

if (d && !has) {
val ex = error
if (ex != null) {
throw ex
}
break
}

if (has) {
val v = value
@Suppress("UNCHECKED_CAST")
value = null as V
hasValue = false

try {
collector.emit(v)
} catch (ex: Throwable) {
map.remove(this.key)
cancelled = true
consumerReady.resume()
throw ex
}

consumerReady.resume()
continue
}

valueReady.await()
}
}

suspend fun next(value: V) {
if (!cancelled) {
consumerReady.await()
this.value = value
this.hasValue = true
valueReady.resume()
}
}

fun error(ex: Throwable) {
error = ex
done = true
valueReady.resume()
}

fun complete() {
done = true
valueReady.resume()
}
}
}
8 changes: 4 additions & 4 deletions src/test/kotlin/hu/akarnokd/kotlin/flow/TestSupport.kt
Expand Up @@ -57,15 +57,15 @@ suspend fun <T> Flow<T>.assertResult(vararg values: T) {
}

suspend fun <T> Flow<T>.assertResultSet(vararg values: T) {
val list = HashSet<T>()
val set = HashSet<T>()

this.collect {
list.add(it)
set.add(it)
}

assertEquals(values.size, list.size)
assertEquals("Number of values differ", values.size, set.size)

values.forEach { assertTrue("" + it, list.contains(it)) }
values.forEach { assertTrue("Missing: " + it, set.contains(it)) }
}

suspend fun <T, E : Throwable> Flow<T>.assertFailure(errorClazz: Class<E>, vararg values: T) {
Expand Down
85 changes: 85 additions & 0 deletions src/test/kotlin/hu/akarnokd/kotlin/flow/impl/FlowGroupByTest.kt
@@ -0,0 +1,85 @@
/*
* Copyright 2019 David Karnok
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hu.akarnokd.kotlin.flow.impl

import hu.akarnokd.kotlin.flow.*
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking
import org.junit.Ignore
import org.junit.Test

@ExperimentalCoroutinesApi
@FlowPreview
class FlowGroupByTest {

@Test
fun basic() = runBlocking {
range(1, 10)
.groupBy { it % 2 }
.flatMapMerge { it.toList() }
.assertResultSet(listOf(1, 3, 5, 7, 9), listOf(2, 4, 6, 8, 10))
}

@Test
fun basicValueSelector() = runBlocking {
range(1, 10)
.groupBy({ it % 2 }) { it + 1}
.flatMapMerge { it.toList() }
.assertResultSet(listOf(2, 4, 6, 8, 10), listOf(3, 5, 7, 9, 11))
}

@Test
fun oneOfEach() = runBlocking {
range(1, 10)
.groupBy { it % 2 }
.flatMapMerge { it.take(1) }
.assertResultSet(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}

@Test
fun maxGroups() = runBlocking {
range(1, 10)
.groupBy { it % 3 }
.take(2)
.flatMapMerge { it.toList() }
.assertResultSet(listOf(1, 4, 7, 10), listOf(2, 5, 8))
}

@Test
@Ignore("Hangs for some reason")
fun takeItems() = runBlocking {
range(1, 10)
.groupBy { it % 2 }
.flatMapMerge { it }
.take(2)
.assertResultSet(1, 2)
}

@Test
fun takeGroupsAndItems() = runBlocking {
range(1, 10)
.groupBy { it % 3 }
.take(2)
.flatMapMerge { it }
.take(2)
.assertResultSet(1, 2)
}

}

0 comments on commit 27c7415

Please sign in to comment.