Skip to content
Browse files

Math: Converted multimethods to protocols

  • Loading branch information...
1 parent c46df9d commit d4c298fd821258efe872861fe2dd5188267918d0 @Engelberg Engelberg committed May 21, 2011
Showing with 92 additions and 77 deletions.
  1. +92 −77 modules/math/src/main/clojure/clojure/contrib/math.clj
View
169 modules/math/src/main/clojure/clojure/contrib/math.clj
@@ -89,21 +89,6 @@ exact-integer-sqrt - Implements a math function from the R6RS Scheme
"}
clojure.contrib.math)
-(derive ::integer ::exact)
-(derive java.lang.Integer ::integer)
-(derive java.math.BigInteger ::integer)
-(derive clojure.lang.BigInt ::integer)
-(derive java.lang.Long ::integer)
-(derive java.math.BigDecimal ::exact)
-(derive clojure.lang.Ratio ::exact)
-(derive java.lang.Double ::inexact)
-(derive java.lang.Float ::inexact)
-
-(defmulti ^{:arglists '([base pow])
- :doc "(expt base pow) is base to the pow power.
-Returns an exact number if the base is an exact number and the power is an integer, otherwise returns a double."}
- expt (fn [x y] [(class x) (class y)]))
-
(defn- expt-int [base pow]
(loop [n pow, y (num 1), z base]
(let [t (even? n), n (quot n 2)]
@@ -112,54 +97,101 @@ Returns an exact number if the base is an exact number and the power is an integ
(zero? n) (*' z y)
:else (recur n (*' z y) (*' z z))))))
-(defmethod expt [::exact ::integer] [base pow]
- (cond
- (pos? pow) (expt-int base pow)
- (zero? pow) 1
- :else (/ 1 (expt-int base (-' pow)))))
-
-(defmethod expt :default [base pow] (Math/pow base pow))
-
+(defn expt
+ "(expt base pow) is base to the pow power.
+Returns an exact number if the base is an exact number and the power is an integer, otherwise returns a double."
+ [base pow]
+ (if (and (not (float? base)) (integer? pow))
+ (cond
+ (pos? pow) (expt-int base pow)
+ (zero? pow) 1
+ :else (/ 1 (expt-int base (-' pow))))
+ (Math/pow base pow)))
+
(defn abs "(abs n) is the absolute value of n" [n]
(cond
(not (number? n)) (throw (IllegalArgumentException.
"abs requires a number"))
(neg? n) (-' n)
:else n))
-(defmulti ^{:arglists '([n])
- :doc "(floor n) returns the greatest integer less than or equal to n.
-If n is an exact number, floor returns an integer, otherwise a double."}
- floor class)
-(defmethod floor ::integer [n] n)
-(defmethod floor java.math.BigDecimal [n] (bigint (.setScale n 0 BigDecimal/ROUND_FLOOR)))
-(defmethod floor clojure.lang.Ratio [n]
- (if (pos? n) (quot (. n numerator) (. n denominator))
- (dec' (quot (. n numerator) (. n denominator)))))
-(defmethod floor :default [n]
- (Math/floor n))
-
-(defmulti ^{:arglists '([n])
- :doc "(ceil n) returns the least integer greater than or equal to n.
-If n is an exact number, ceil returns an integer, otherwise a double."}
- ceil class)
-(defmethod ceil ::integer [n] n)
-(defmethod ceil java.math.BigDecimal [n] (bigint (.setScale n 0 BigDecimal/ROUND_CEILING)))
-(defmethod ceil clojure.lang.Ratio [n]
- (if (pos? n) (inc' (quot (. n numerator) (. n denominator)))
- (quot (. n numerator) (. n denominator))))
-(defmethod ceil :default [n]
- (Math/ceil n))
-
-(defmulti ^{:arglists '([n])
- :doc "(round n) rounds to the nearest integer.
-round always returns an integer. Rounds up for values exactly in between two integers."}
- round class)
-(defmethod round ::integer [n] n)
-(defmethod round java.math.BigDecimal [n] (floor (+ n 0.5M)))
-(defmethod round clojure.lang.Ratio [n] (floor (+ n 1/2)))
-;; Convert to bigdec in case float/double exceeds range of long
-(defmethod round :default [n] (round (bigdec n)))
+(defprotocol MathFunctions
+ (floor [n] "(floor n) returns the greatest integer less than or equal to n.
+If n is an exact number, floor returns an integer, otherwise a double.")
+ (ceil [n] "(ceil n) returns the least integer greater than or equal to n.
+If n is an exact number, ceil returns an integer, otherwise a double.")
+ (round [n] "(round n) rounds to the nearest integer.
+round always returns an integer. Rounds up for values exactly in between two integers.")
+ (integer-length [n] "Length of integer in binary")
+ (sqrt [n] "Square root, but returns exact number if possible."))
+
+(declare sqrt-integer)
+(declare sqrt-ratio)
+(declare sqrt-decimal)
+
+(extend-type
+ Integer MathFunctions
+ (floor [n] n)
+ (ceil [n] n)
+ (round [n] n)
+ (integer-length [n] (- 32 (Integer/numberOfLeadingZeros n)))
+ (sqrt [n] (sqrt-integer n)))
+
+(extend-type
+ Long MathFunctions
+ (floor [n] n)
+ (ceil [n] n)
+ (round [n] n)
+ (integer-length [n] (- 64 (Long/numberOfLeadingZeros n)))
+ (sqrt [n] (sqrt-integer n)))
+
+(extend-type
+ java.math.BigInteger MathFunctions
+ (floor [n] n)
+ (ceil [n] n)
+ (round [n] n)
+ (integer-length [n] (.bitLength n))
+ (sqrt [n] (sqrt-integer n)))
+
+(extend-type
+ clojure.lang.BigInt MathFunctions
+ (floor [n] n)
+ (ceil [n] n)
+ (round [n] n)
+ (integer-length [n] (.bitLength n))
+ (sqrt [n] (sqrt-integer n)))
+
+(extend-type
+ java.math.BigDecimal MathFunctions
+ (floor [n] (bigint (.setScale n 0 BigDecimal/ROUND_FLOOR)))
+ (ceil [n] (bigint (.setScale n 0 BigDecimal/ROUND_CEILING)))
+ (round [n] (floor (+ n 0.5M)))
+ (sqrt [n] (sqrt-decimal n)))
+
+(extend-type
+ clojure.lang.Ratio MathFunctions
+ (floor [n]
+ (if (pos? n) (quot (. n numerator) (. n denominator))
+ (dec' (quot (. n numerator) (. n denominator)))))
+ (ceil [n]
+ (if (pos? n) (inc' (quot (. n numerator) (. n denominator)))
+ (quot (. n numerator) (. n denominator))))
+ (round [n] (floor (+ n 1/2)))
+ (sqrt [n] (sqrt-ratio n)))
+
+(extend-type
+ Double MathFunctions
+ (floor [n] (Math/floor n))
+ (ceil [n] (Math/ceil n))
+ (round [n] (round (bigdec n)))
+ (sqrt [n] (Math/sqrt n)))
+
+(extend-type
+ Float MathFunctions
+ (floor [n] (Math/floor n))
+ (ceil [n] (Math/ceil n))
+ (round [n] (round (bigdec n)))
+ (sqrt [n] (Math/sqrt n)))
(defn gcd "(gcd a b) returns the greatest common divisor of a and b" [a b]
(if (or (not (integer? a)) (not (integer? b)))
@@ -177,17 +209,6 @@ round always returns an integer. Rounds up for values exactly in between two in
(zero? b) 0
:else (abs (*' b (quot a (gcd a b))))))
-; Length of integer in binary, used as helper function for sqrt.
-(defmulti ^{:private true} integer-length class)
-(defmethod integer-length java.lang.Integer [n]
- (count (Integer/toBinaryString n)))
-(defmethod integer-length java.lang.Long [n]
- (count (Long/toBinaryString n)))
-(defmethod integer-length java.math.BigInteger [n]
- (.bitLength n))
-(defmethod integer-length clojure.lang.BigInt [n]
- (.bitLength n))
-
;; Produces the largest integer less than or equal to the square root of n
;; Input n must be a non-negative integer
(defn- integer-sqrt [n]
@@ -216,17 +237,14 @@ For example, (exact-integer-sqrt 15) is [3 6] because 15 = 3^2+6."
error (-' n (*' isqrt isqrt))]
[isqrt error])))
-(defmulti ^{:arglists '([n])
- :doc "Square root, but returns exact number if possible."}
- sqrt class)
-(defmethod sqrt ::integer [n]
+(defn- sqrt-integer [n]
(if (neg? n) Double/NaN
(let [isqrt (integer-sqrt n),
error (-' n (*' isqrt isqrt))]
(if (zero? error) isqrt
(Math/sqrt n)))))
-(defmethod sqrt clojure.lang.Ratio [n]
+(defn- sqrt-ratio [n]
(if (neg? n) Double/NaN
(let [numerator (.numerator n),
denominator (.denominator n),
@@ -238,14 +256,11 @@ For example, (exact-integer-sqrt 15) is [3 6] because 15 = 3^2+6."
(Math/sqrt n)
(/ sqrtnum sqrtden)))))))
-(defmethod sqrt java.math.BigDecimal [n]
+(defn- sqrt-decimal [n]
(if (neg? n) Double/NaN
(let [frac (rationalize n),
sqrtfrac (sqrt frac)]
(if (ratio? sqrtfrac)
(/ (BigDecimal. (.numerator sqrtfrac))
(BigDecimal. (.denominator sqrtfrac)))
- sqrtfrac))))
-
-(defmethod sqrt :default [n]
- (Math/sqrt n))
+ sqrtfrac))))

0 comments on commit d4c298f

Please sign in to comment.
Something went wrong with that request. Please try again.