Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
tree: c340888358
Fetching contributors…

Cannot retrieve contributors at this time

117 lines (90 sloc) 2.717 kB
package stream
/**
 * Micro-benchmark that compares a hand-written `while` loop with an
 * equivalent map -> filter -> fold pipeline built from `@specialized`
 * iterators over an `Array[Int]`.
 */
object SpecializedIterators2 {

  /**
   * Minimal iterator interface, specialized on its element type.
   *
   * The `dummy` arguments taken by `filter` and `map` are never read.
   * NOTE(review): they presumably exist so that `T` / `U` occur as plain
   * parameter types, which lets scalac emit specialized variants of these
   * methods -- confirm before removing them.
   */
  trait SIterator[@specialized T] {

    /** True while another element can be obtained from `next()`. */
    def hasNext: Boolean

    /** Returns the current element and advances the iterator. */
    def next(): T

    /**
     * Wraps this iterator so that only elements satisfying `pred` are produced.
     * Constructing the wrapper already pulls elements from this iterator
     * (see [[FilterIterator]]).
     */
    def filter(pred: T => Boolean, dummy: T): FilterIterator[T] =
      new FilterIterator[T](this, pred)

    /** Wraps this iterator so that every element is transformed by `fn`. */
    def map[@specialized U](fn: T => U, dummy: T, dummy2: U): MapIterator[T, U] =
      new MapIterator[T, U](this, fn)
  }

  /**
   * Iterates over `a` from position `index` (inclusive) up to `endIndex`
   * (exclusive). `index` is mutable and is advanced by every `next()` call.
   */
  final class ArrayIterator[@specialized T](a: Array[T], var index: Int, endIndex: Int)
      extends SIterator[T] {

    def next(): T = {
      val r = a(index)
      index += 1
      r
    }

    def hasNext: Boolean = index < endIndex
  }

  /**
   * Produces only those elements of `iter` for which `pred` holds.
   *
   * The iterator looks one element ahead: the constructor and every `next()`
   * call advance `iter` up to the following match. Calling `next()` when
   * `hasNext` is false does not throw; it returns whatever value was read last.
   */
  final class FilterIterator[@specialized T](iter: SIterator[T], pred: T => Boolean)
      extends SIterator[T] {

    // Declaration order matters: `hasElem` must be initialized before `elem`,
    // because the `findNext()` call in `elem`'s initializer assigns `hasElem`.
    private var hasElem = false
    private var elem: T = findNext()

    def hasNext: Boolean = hasElem

    def next(): T = {
      val r = elem
      findNext()
      r
    }

    /**
     * Advances `iter` to the next element accepted by `pred`, stores it as the
     * look-ahead element and returns it. When `iter` runs out, `hasNext`
     * becomes false and the last value read is returned.
     */
    def findNext(): T = {
      // Plain `while` + `return` is kept on purpose: this is the measured hot path.
      while (iter.hasNext) {
        elem = iter.next()
        if (pred(elem)) {
          hasElem = true
          return elem
        }
      }
      hasElem = false
      elem
    }
  }

  /** Applies `fn` to every element of `iter`. */
  final class MapIterator[@specialized T, @specialized U](iter: SIterator[T], fn: T => U)
      extends SIterator[U] {

    def next(): U = fn(iter.next())

    def hasNext: Boolean = iter.hasNext
  }

  /**
   * Left fold: starts from `v` and combines it with every element of `iter`
   * using `fn`. `dummy` is never read (see the note on [[SIterator]]).
   */
  def fold[@specialized T, @specialized U](iter: SIterator[T], fn: (U, T) => U, v: U, dummy: T): U = {
    var r = v
    while (iter.hasNext) {
      r = fn(r, iter.next())
    }
    r
  }

  /**
   * Sums all values `x * 3 + 7` (for `x` in `a`) that are divisible by 10,
   * wiring the iterator classes together explicitly.
   */
  def mapFilterSum(a: Array[Int]): Int = {
    val ai = new ArrayIterator(a, 0, a.length)
    val s = new FilterIterator[Int](new MapIterator[Int, Int](ai, _ * 3 + 7), _ % 10 == 0)
    fold[Int, Int](s, _ + _, 0, 0)
  }

  /** Same computation as `mapFilterSum`, expressed through `SIterator.map` / `filter`. */
  def mapFilterSum2(a: Array[Int]): Int = {
    val ai = new ArrayIterator(a, 0, a.length)
    val s = ai.map(_ * 3 + 7, 0, 0).filter(_ % 10 == 0, 0) // Doesn't specialize properly
    fold[Int, Int](s, _ + _, 0, 0)
  }

  /** Same computation as `mapFilterSum`, written as a plain `while` loop (the baseline). */
  def mapFilterSumLoop(a: Array[Int]): Int = {
    var i = 0
    var r = 0
    while (i < a.length) {
      val v = a(i) * 3 + 7
      if ((v % 10) == 0)
        r += v
      i += 1
    }
    r
  }

  /**
   * Builds the benchmark input: element 0 is 0 and every later element is
   * `previous * 13 + 1947`, computed with wrapping `Int` arithmetic.
   *
   * @param count number of elements to generate
   */
  def createArray(count: Int): Array[Int] = {
    val a = new Array[Int](count)
    for (i <- 1 until a.length)
      a(i) = a(i - 1) * 13 + 1947
    a
  }

  /**
   * Benchmark function: runs `proc` repeatedly on one generated array and
   * reports the fastest run.
   *
   * @param proc      computation to time
   * @param arraySize length of the generated input array
   * @param runs      how many times `proc` is executed; must be positive
   * @return the shortest observed run time in microseconds, paired with the
   *         result of the last `proc` call
   */
  def bench(proc: Array[Int] => Int, arraySize: Int = 1000000, runs: Int = 201): (Long, Int) = {
    require(runs > 0, "runs must be positive")
    val a = createArray(arraySize)
    var m = java.lang.Long.MAX_VALUE
    var r = 0
    for (_ <- 1 to runs) {
      val t = System.nanoTime
      r = proc(a)
      m = math.min(m, System.nanoTime - t)
    }
    (m / 1000, r) // nanoseconds -> microseconds
  }

  /** Prints the JVM version followed by the timing of the loop and iterator variants. */
  def main(args: Array[String]): Unit = {
    println("Java version: " + System.getProperty("java.version"))
    println("Loop: " + bench(mapFilterSumLoop))
    println("Specialized iterators: " + bench(mapFilterSum))
  }
}
Jump to Line
Something went wrong with that request. Please try again.