Skip to content

Commit

Permalink
[SampleCountMeasurement improvement #52][introduced Measured interfac…
Browse files Browse the repository at this point in the history
…e for measured objects]
  • Loading branch information
asubb committed May 6, 2020
1 parent fc62c67 commit 68b498f
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 45 deletions.
1 change: 1 addition & 0 deletions .release/sample-count-measurement-improvements.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* [ [#52](https://github.com/WaveBeans/wavebeans/issues/52) ] Custom class that requires measurement needs to implement `Measured` interface. See [updated section of documentation](/docs/user/api/operations/projection-operation.md#working-with-different-types)
27 changes: 22 additions & 5 deletions docs/user/api/operations/projection-operation.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,36 @@ Working with different types

Projection operation is defined for `Sample` and `Window<Sample>` types out of the box, but it's not limited to them. Only thing you need to keep in mind, that projection calculates time when the stream is being executed and the sample rate is provided, so it needs a way to convert the size of your type to samples to correctly calculate time markers, i.e. for `Sample` the size is always 1, for windowed samples the size is th size of the window step.

To define your own type you need to register it before it's being executed:
To use your own type you need to define how to measure it

One way is to implement the `io.wavebeans.lib.stream.Measured` interface for you class:

```kotlin
data class DoubleSample(val one: Sample, val two: Sample) : Measured {
override fun measure(): Int = 2
}
```

data class DoubleSample(val one: Sample, val two: Sample)
Another way is to register it before it's being executed, preferrably to be used for the classes you can't extend like SDK classes:

```kotlin
data class DoubleSample(val one: Sample, val two: Sample)
SampleCountMeasurement.registerType(DoubleSample::class) { 2 }
```

And now you can use it:

```kotlin
440.sine().window(2)
.map { DoubleSample(it.elements.first(), it.elements.drop(1).first()) }
.map { DoubleSample(it.elements[0], it.elements[1]) }
.rangeProjection(100, 200)
```

If you won't register the type, during execution you'll have an exception like `class my.wavebeans.DoubleSample is not registered within SampleCountMeasurement, use registerType() function`
If you won't register the type, during execution you'll have an exception like `class my.wavebeans.DoubleSample is not registered within SampleCountMeasurement, use registerType() function or extend your class with Measured interface`

If you use windowed type, you just need to define the calculation of your type, the Window will be calculated automatically, by applying the formula `sizeOfTheSample * window.step`.
The following types are built-in:
* `Number` -- always return 1
* `Sample` -- always return 1
* `FftSample` -- measured as the `window.step` it is built on top of.
* `List<T>` -- measured as a sum of length of all corresponding elements of type `T`. Doesn't support nullable elements, will throw an exception.
* `Window<T>` -- measured as `sizeOfTheSample * window.step`, where `sizeOfTheSample` is measure of the first element.
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ internal fun FftSample.encode(buf: ByteArray, at: Int): Int {
writeLong(this.index, buf, pointer); pointer += 8
writeInt(this.binCount, buf, pointer); pointer += 4
writeInt(this.samplesCount, buf, pointer); pointer += 4
writeInt(this.samplesLength, buf, pointer); pointer += 4
writeInt(this.sampleRate.toBits(), buf, pointer); pointer += 4
writeInt(this.fft.size, buf, pointer); pointer += 4
this.fft.forEach {
Expand All @@ -221,6 +222,7 @@ internal fun decodeFftSample(buf: ByteArray, from: Int): Pair<FftSample, Int> {
val index = readLong(buf, pointer); pointer += 8
val binCount = readInt(buf, pointer); pointer += 4
val samplesCount = readInt(buf, pointer); pointer += 4
val samplesLength = readInt(buf, pointer); pointer += 4
val sampleRate = Float.fromBits(readInt(buf, pointer)); pointer += 4
val fftSize = readInt(buf, pointer); pointer += 4
val fft = ArrayList<ComplexNumber>(fftSize)
Expand All @@ -231,7 +233,7 @@ internal fun decodeFftSample(buf: ByteArray, from: Int): Pair<FftSample, Int> {
}

return Pair(
FftSample(index, binCount, samplesCount, sampleRate, fft),
FftSample(index, binCount, samplesCount, samplesLength, sampleRate, fft),
pointer
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ object PodCallResultSpec : Spek({

describe("Wrapping List<FftSampleArray>") {
val obj = listOf(
createFftSampleArray(2) { i -> FftSample(i.toLong(), i, i, i * 123.0f, listOf(i.r, i.i)) },
createFftSampleArray(2) { i -> FftSample(i.toLong(), i, i, i * 123.0f, listOf(i.r, i.i)) }
createFftSampleArray(2) { i -> FftSample(i.toLong(), i, i, i, i * 123.0f, listOf(i.r, i.i)) },
createFftSampleArray(2) { i -> FftSample(i.toLong(), i, i, i, i * 123.0f, listOf(i.r, i.i)) }
)
val result = result(obj)

Expand Down
12 changes: 12 additions & 0 deletions lib/src/main/kotlin/io/wavebeans/lib/stream/Measured.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.wavebeans.lib.stream

/**
* Classes implementing that interface are marked as measurable by time-to-sample functions like projection() and trim().
* Should implement the method that measures internal state in samples.
*/
interface Measured {
/**
* Returns number of samples in the object, to be used in conjunction with [SampleCountMeasurement]
*/
fun measure(): Int
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
package io.wavebeans.lib.stream

import com.sun.media.sound.FFT
import io.wavebeans.lib.*
import io.wavebeans.lib.stream.fft.FftSample
import io.wavebeans.lib.stream.window.Window
import io.wavebeans.lib.stream.window.WindowStream
import io.wavebeans.lib.stream.window.WindowStreamParams
import kotlinx.serialization.Serializable
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.isSupertypeOf

@Serializable
data class ProjectionBeanStreamParams(
Expand Down Expand Up @@ -74,28 +65,3 @@ fun <T : Any> BeanStream<T>.rangeProjection(start: Long, end: Long? = null, time
return ProjectionBeanStream(this, ProjectionBeanStreamParams(start, end, timeUnit))
}

@Suppress("UNCHECKED_CAST")
object SampleCountMeasurement {

private val types = mutableMapOf<KClass<*>, (Any) -> Int>()

init {
registerType(Number::class) { 1 }
registerType(Sample::class) { 1 }
registerType(Window::class) { window -> window.step * samplesInObject(window.elements.first()) }
registerType(FftSample::class) { it.samplesCount }
registerType(List::class) { it.size * samplesInObject(it.first() ?: throw IllegalStateException("List of nullable types is not supported")) }
}

fun <T : Any> registerType(clazz: KClass<T>, measurer: (T) -> Int) {
check(types.put(clazz, measurer as (Any) -> Int) == null) { "$clazz is already registered" }
}

fun samplesInObject(obj: Any): Int {
return types.filterKeys { obj::class.isSubclassOf(it) }
.map { it.value }
.firstOrNull()
?.invoke(obj)
?: throw IllegalStateException("${obj::class} is not registered within ${SampleCountMeasurement::class.simpleName}, use registerType() function")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.wavebeans.lib.stream

import io.wavebeans.lib.Sample
import kotlin.reflect.KClass
import kotlin.reflect.full.isSubclassOf

@Suppress("UNCHECKED_CAST")
object SampleCountMeasurement {

private val types = mutableMapOf<KClass<*>, (Any) -> Int>()

init {
registerType(Number::class) { 1 }
registerType(Sample::class) { 1 }
registerType(List::class) { l ->
l.map { samplesInObject(it ?: throw IllegalStateException("List of nullable types is not supported")) }
.sum()
}
}

fun <T : Any> registerType(clazz: KClass<T>, measurer: (T) -> Int) {
check(types.put(clazz, measurer as (Any) -> Int) == null) { "$clazz is already registered" }
}

fun samplesInObject(obj: Any): Int {
return if (obj is Measured)
obj.measure()
else types.filterKeys { obj::class.isSubclassOf(it) }
.map { it.value }
.firstOrNull()
?.invoke(obj)
?: throw IllegalStateException("${obj::class} is not registered within ${SampleCountMeasurement::class.simpleName}, use registerType() function " +
"or extend your class with ${Measured::class.simpleName} interface")
}
}
10 changes: 9 additions & 1 deletion lib/src/main/kotlin/io/wavebeans/lib/stream/fft/FftStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.wavebeans.lib.stream.fft
import io.wavebeans.lib.*
import io.wavebeans.lib.math.ComplexNumber
import io.wavebeans.lib.math.r
import io.wavebeans.lib.stream.Measured
import io.wavebeans.lib.stream.window.Window
import kotlinx.serialization.Serializable
import kotlin.math.PI
Expand All @@ -25,6 +26,10 @@ data class FftSample(
* Number of samples the FFT is calculated based on.
*/
val samplesCount: Int,
/**
* The actual length of the sample it is built on. Calculated based on the underlying window.
*/
val samplesLength: Int,
/**
* Sample rate which was used to calculate the FFT
*/
Expand All @@ -33,7 +38,9 @@ data class FftSample(
* The list of [ComplexNumber]s which is calculated FFT. Use [magnitude] and [phase] methods to extract magnitude and phase respectively.
*/
val fft: List<ComplexNumber>
) {
) : Measured {

override fun measure(): Int = samplesLength

/**
* Gets the magnitude values for this FFT calculation. It is returned in logarithmic scale, using only first half of the FFT.
Expand Down Expand Up @@ -107,6 +114,7 @@ class FftStream(
index = idx++,
binCount = parameters.n,
samplesCount = m,
samplesLength = window.step,
fft = fft.toList(),
sampleRate = sampleRate
)
Expand Down
8 changes: 6 additions & 2 deletions lib/src/main/kotlin/io/wavebeans/lib/stream/window/Window.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package io.wavebeans.lib.stream.window

import io.wavebeans.lib.Sample
import io.wavebeans.lib.ZeroSample
import io.wavebeans.lib.stream.Measured
import io.wavebeans.lib.stream.SampleCountMeasurement

data class Window<T: Any>(
data class Window<T : Any>(
/**
* The size of the window it was created with.
*/
Expand All @@ -25,7 +27,9 @@ data class Window<T: Any>(
* by this function.
*/
val zeroEl: () -> T
) {
) : Measured {

override fun measure(): Int = step * SampleCountMeasurement.samplesInObject(elements.first())

init {
require(size >= 1) { "Size should be more than 0" }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package io.wavebeans.lib.stream

import assertk.assertThat
import assertk.assertions.isEqualTo
import io.wavebeans.lib.Sample
import io.wavebeans.lib.math.i
import io.wavebeans.lib.sampleOf
import io.wavebeans.lib.stream.SampleCountMeasurement.samplesInObject
import io.wavebeans.lib.stream.fft.FftSample
import io.wavebeans.lib.stream.window.Window
import org.spekframework.spek2.Spek
import org.spekframework.spek2.style.specification.describe

object SampleCountMeasurementSpec : Spek({
describe("Measuring builtin types") {
it("should measure samples") {
val obj = sampleOf(1)
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of samples") {
val obj = listOf(sampleOf(1), sampleOf(1))
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure window of samples") {
val obj = Window.ofSamples(3, 2, listOf(sampleOf(1), sampleOf(1), sampleOf(1), sampleOf(1)))
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure numbers: ints") {
val obj = 1
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of numbers: ints") {
val obj = listOf(1, 2, 3, 4)
assertThat(samplesInObject(obj)).isEqualTo(4)
}
it("should measure window of ints") {
val obj = Window(3, 2, listOf(1, 2, 3, 4)) { 0 }
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure numbers: doubles") {
val obj = 1.0
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of numbers: doubles") {
val obj = listOf(1.0, 2.0)
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure window of doubles") {
val obj = Window(3, 2, listOf(1.0, 2.0, 3.0, 4.0)) { 0.0 }
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure numbers: floats") {
val obj = 1.0f
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of numbers: floats") {
val obj = listOf(1.0f, 2.0f, 3.0f)
assertThat(samplesInObject(obj)).isEqualTo(3)
}
it("should measure window of floats") {
val obj = Window(3, 2, listOf(1.0f, 2.0f, 3.0f, 4.0f)) { 0.0f }
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure FFT samples") {
val obj = FftSample(0, 4, 4, 2, 1.0f, listOf(0.i, 0.i, 0.i, 0.i))
assertThat(samplesInObject(obj)).isEqualTo(2)
}
it("should measure list of FFT samples") {
val obj = listOf(
FftSample(0, 4, 4, 2, 1.0f, listOf(0.i, 0.i, 0.i, 0.i)),
FftSample(0, 4, 4, 4, 1.0f, listOf(0.i, 0.i, 0.i, 0.i))
)
assertThat(samplesInObject(obj)).isEqualTo(6)
}
}

describe("Measuring custom types via interface") {

data class MySample(val v: Sample) : Measured {
override fun measure(): Int = samplesInObject(this.v)
}

it("should measure samples") {
val obj = MySample(sampleOf(1))
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of samples") {
val obj = listOf(MySample(sampleOf(1)), MySample(sampleOf(1)))
assertThat(samplesInObject(obj)).isEqualTo(2)
}
}

describe("Measuring custom types via registerType") {

data class MySample(val v: Sample)

SampleCountMeasurement.registerType(MySample::class) { samplesInObject(it.v) }

it("should measure samples") {
val obj = MySample(sampleOf(1))
assertThat(samplesInObject(obj)).isEqualTo(1)
}
it("should measure list of samples") {
val obj = listOf(MySample(sampleOf(1)), MySample(sampleOf(1)))
assertThat(samplesInObject(obj)).isEqualTo(2)
}
}
})

0 comments on commit 68b498f

Please sign in to comment.