Permalink
Fetching contributors…
Cannot retrieve contributors at this time
301 lines (273 sloc) 8.08 KB
(* From the SML/NJ benchmark suite. *)
fun print _ = ()
signature BMARK =
sig
val doit : int -> unit
end;
structure Main: BMARK = struct
local
open Array Math
val printr = print o (Real.fmt (StringCvt.SCI(SOME 14)))
val printi = print o Int.toString
in
val PI = 3.14159265358979323846
val tpi = 2.0 * PI
(*
fun trace(name,f) x =
(print name ;
print "(" ;
printr x ;
print ") = " ;
let val y = f x
in printr y ;
print "\n" ;
y
end)
fun trace2(name,f) (x,y) =
(printr x ;
print " ";
print name ;
print " ";
printr y ;
print " = " ;
let val z = f(x,y)
in printr z ;
print "\n" ;
z
end)
fun trace2(_,f) = f
val op * = trace2("*", Real.* )
val op - = trace2("-", Real.-)
val op + = trace2("+", Real.+)
local
nonfix * + -
in
_overload * : ('a * 'a -> 'a)
as Int.*
and *
_overload + : ('a * 'a -> 'a)
as Int.+
and +
_overload - : ('a * 'a -> 'a)
as Int.-
and -
end
val sin = trace("sin", sin)
val cos = trace("cos", cos)
val sub =
fn (a,i) =>
let val x = sub(a,i)
in print "sub(_, " ;
printi i ;
print ") = ";
printr x ;
print "\n" ;
x
end
val update =
fn (a,i,x) =>
(update(a,i,x);
print "update(_, " ;
printi i ;
print ", " ;
printr x ;
print ")\n")
*)
fun fft px py np =
let
fun find_num_points i m =
if i < np
then find_num_points (i+i) (m+1)
else (i,m)
val (n,m) = find_num_points 2 1
(* val _ = (printi n ;
print "\n" ;
printi m ;
print "\n") *)
in
if n <> np then
let
fun loop i =
if i > n then ()
else (update(px, i, 0.0);
update(py, i, 0.0);
loop (i+1))
in
loop (np+1);
print "Use "; printi n; print " point fft\n"
end
else ();
let
fun loop_k k n2 =
if k >= m then ()
else
let
val n4 = n2 div 4
val e = tpi / (real n2)
fun loop_j j a =
if j > n4 then ()
else
let val a3 = 3.0 * a
val cc1 = cos(a)
val ss1 = sin(a)
val cc3 = cos(a3)
val ss3 = sin(a3)
fun loop_is is id =
if is >= n
then ()
else
let
fun loop_i0 i0 =
if i0 >= n
then ()
else
let val i1 = i0 + n4
val i2 = i1 + n4
val i3 = i2 + n4
val r1 = sub(px, i0) - sub(px, i2)
val _ = update(px, i0, sub(px, i0) + sub(px, i2))
val r2 = sub(px, i1) - sub(px, i3)
val _ = update(px, i1, sub(px, i1) + sub(px, i3))
val s1 = sub(py, i0) - sub(py, i2)
val _ = update(py, i0, sub(py, i0) + sub(py, i2))
val s2 = sub(py, i1) - sub(py, i3)
val _ = update(py, i1, sub(py, i1) + sub(py, i3))
val s3 = r1 - s2
val r1 = r1 + s2
val s2 = r2 - s1
val r2 = r2 + s1
val _ = update(px, i2, r1*cc1 - s2*ss1)
val _ = update(py, i2, ~s2*cc1 - r1*ss1)
val _ = update(px, i3, s3*cc3 + r2*ss3)
val _ = update(py, i3, r2*cc3 - s3*ss3)
in
loop_i0 (i0 + id)
end
in
loop_i0 is;
loop_is (2 * id - n2 + j) (4 * id)
end
in
loop_is j (2 * n2);
loop_j (j+1) (e * real j)
end
in
loop_j 1 0.0;
loop_k (k+1) (n2 div 2)
end
in
loop_k 1 n
end;
(************************************)
(* Last stage, length=2 butterfly *)
(************************************)
let fun loop_is is id = if is >= n then () else
let fun loop_i0 i0 = if i0 > n then () else
let val i1 = i0 + 1
val r1 = sub(px, i0)
val _ = update(px, i0, r1 + sub(px, i1))
val _ = update(px, i1, r1 - sub(px, i1))
val r1 = sub(py, i0)
val _ = update(py, i0, r1 + sub(py, i1))
val _ = update(py, i1, r1 - sub(py, i1))
in
loop_i0 (i0 + id)
end
in
loop_i0 is;
loop_is (2*id - 1) (4 * id)
end
in
loop_is 1 4
end;
(*************************)
(* Bit reverse counter *)
(*************************)
let
fun loop_i i j =
if i >= n
then ()
else
(if i < j
then (let val xt = sub(px, j)
in update(px, j, sub(px, i)); update(px, i, xt)
end;
let val xt = sub(py, j)
in update(py, j, sub(py, i)); update(py, i, xt)
end)
else ();
let
fun loop_k k j =
if k < j then loop_k (k div 2) (j-k) else j+k
val j' = loop_k (n div 2) j
in
loop_i (i+1) j'
end)
in
loop_i 1 1
end;
n
end
fun abs x = if x >= 0.0 then x else ~x
fun test np =
let val _ = (printi np; print "... ")
val enp = real np
val npm = (np div 2) - 1
val pxr = array (np+2, 0.0)
val pxi = array (np+2, 0.0)
val t = PI / enp
val _ = update(pxr, 1, (enp - 1.0) * 0.5)
val _ = update(pxi, 1, 0.0)
val n2 = np div 2
val _ = update(pxr, n2+1, ~0.5)
val _ = update(pxi, n2+1, 0.0)
fun loop_i i = if i > npm then () else
let val j = np - i
val _ = update(pxr, i+1, ~0.5)
val _ = update(pxr, j+1, ~0.5)
val z = t * real i
val y = ~0.5*(cos(z)/sin(z))
val _ = update(pxi, i+1, y)
val _ = update(pxi, j+1, ~y)
in
loop_i (i+1)
end
val _ = loop_i 1
(* val _ = print "\n"
fun loop_i i = if i > 15 then () else
(printi i; print "\t";
printr (sub(pxr, i+1)); print "\t";
printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
val _ = loop_i 0
*)
val _ = fft pxr pxi np
(*
fun loop_i i = if i > 15 then () else
(printi i; print "\t";
printr (sub(pxr, i+1)); print "\t";
printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
val _ = loop_i 0
*)
fun loop_i i zr zi kr ki = if i >= np then (zr,zi) else
let val a = abs(sub(pxr, i+1) - real i)
val (zr, kr) =
if zr < a then (a, i) else (zr, kr)
val a = abs(sub(pxi, i+1))
val (zi, ki) =
if zi < a then (a, i) else (zi, ki)
in
loop_i (i+1) zr zi kr ki
end
val (zr, zi) = loop_i 0 0.0 0.0 0 0
val zm = if abs zr < abs zi then zi else zr
in
printr zm; print "\n"
end
fun loop_np i np = if i > 15 then () else
(test np; loop_np (i+1) (np*2))
fun doit n =
if n = 0
then ()
else (loop_np 1 256; doit (n - 1))
end
end;