/
MillerRabin.hs
79 lines (60 loc) · 2.48 KB
/
MillerRabin.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
{-# LANGUAGE ConstraintKinds,
DataKinds,
KindSignatures,
TypeOperators,
GADTs,
TypeFamilies,
StandaloneDeriving,
RankNTypes,
ScopedTypeVariables #-}
module MillerRabin where
import GHC.TypeLits
import Data.Proxy(Proxy(..))
import Control.Arrow
import System.Random
-- | A value of representation type @i@ interpreted modulo the type-level
-- natural @n@.
--
-- The modulus lives only in the type, so residues with different moduli
-- cannot be mixed. The raw 'Mod' constructor does not reduce its argument;
-- 'toMod' reduces with 'mod', giving a representative in @[0, n)@.
--
-- NOTE(review): the derived 'Eq' / 'Ord' compare the stored representative
-- directly, so a value built with the raw constructor and left unreduced
-- (e.g. @Mod 7@ at modulus 5) will not compare equal to its reduced form.
newtype Mod i (n :: Nat) = Mod i
  deriving (Eq, Ord, Show)
-- | Reduce a value modulo the type-level modulus @n@ and wrap it as a residue.
toMod :: forall n i. (Integral i, KnownNat n) => i -> Mod i n
toMod x = Mod (x `mod` modulus)
  where
    -- The modulus is recovered from the type-level natural @n@ and
    -- converted to the representation type @i@.
    modulus = fromInteger (natVal (Proxy :: Proxy n))
-- | Extract the stored representative from a residue.
unMod :: Mod i n -> i
unMod m = case m of
  Mod x -> x
-- | Add two residues modulo @n@.
--
-- The sum is formed in 'Integer' and only the reduced result is converted
-- back to @i@. For a bounded representation such as 'Int', the direct sum
-- @x + y@ can overflow once the modulus exceeds roughly @maxBound `div` 2@,
-- which would silently give a wrong residue. For @i ~ Integer@ the result is
-- exactly what the direct computation gives.
modAdd :: forall i n. (Integral i, KnownNat n) => Mod i n -> Mod i n -> Mod i n
modAdd (Mod x) (Mod y) =
  Mod (fromInteger ((toInteger x + toInteger y) `mod` modulus))
  where
    modulus = natVal (Proxy :: Proxy n)
-- | Additive inverse modulo @n@: the residue of @n - x@.
modNegate :: forall i n. (Integral i, KnownNat n) => Mod i n -> Mod i n
modNegate (Mod x) = toMod (modulus - x)
  where
    -- Subtracting from the modulus (rather than negating) keeps the
    -- intermediate non-negative for reduced inputs.
    modulus = fromInteger (natVal (Proxy :: Proxy n))
-- | Subtraction modulo @n@, defined as adding the additive inverse.
modMinus :: forall i n. (Integral i, KnownNat n) => Mod i n -> Mod i n -> Mod i n
modMinus x y = modAdd x (modNegate y)
-- | Multiply two residues modulo @n@.
--
-- The product is formed in 'Integer' and only the reduced result is
-- converted back to @i@. For a bounded representation such as 'Int', the
-- direct product @x * y@ of two reduced values can reach @(n - 1)^2@. It
-- therefore overflows once the modulus exceeds about
-- @sqrt (maxBound :: Int)@, silently giving a wrong residue. For
-- @i ~ Integer@ the result is exactly what the direct computation gives.
modMult :: forall i n. (Integral i , KnownNat n) => Mod i n -> Mod i n -> Mod i n
modMult (Mod x) (Mod y) =
  Mod (fromInteger ((toInteger x * toInteger y) `mod` modulus))
  where
    modulus = natVal (Proxy :: Proxy n)
-- | Ring operations on residues modulo @n@.
--
-- 'abs' and 'signum' have no natural meaning for residues. They are chosen
-- as 'id' and @const 1@ so that @abs x * signum x == x@ holds for reduced
-- residues.
instance (Integral i, KnownNat n) => Num (Mod i n) where
  -- Reduce in 'Integer' before converting to @i@. Converting first (as
  -- @toMod . fromInteger@ did) wraps large literals for bounded @i@ such as
  -- 'Int' and yields the wrong residue.
  fromInteger k = Mod (fromInteger (k `mod` natVal (Proxy :: Proxy n)))
  (+) = modAdd
  (-) = modMinus
  (*) = modMult
  negate = modNegate
  abs = id
  signum = const 1
-- | Probabilistic Miller–Rabin primality test.
--
-- @millerRabin rounds p@ runs up to @rounds@ rounds on the candidate @p@.
-- Each round uses a freshly drawn random witness in @[2, p - 2]@ (1 and
-- @p - 1@ always pass and prove nothing). The test stops with 'False' as
-- soon as a witness shows @p@ is composite. A 'True' result means
-- "probably prime".
--
-- Trivial candidates are decided without randomness:
--
-- * values below 2 are not prime (previously these crashed or looped);
-- * 2 and 3 are prime;
-- * other even numbers are composite.
--
-- A round count of zero or less runs no rounds and returns 'True' for odd
-- @p >= 5@.
millerRabin :: Integer -> Integer -> IO Bool
millerRabin rounds p
  | p < 2 = return False
  | p < 4 = return True
  | even p = return False
  | otherwise =
      case someNatVal p of
        -- Unreachable: p >= 5 here, and someNatVal only fails on negatives.
        Nothing -> return False
        Just (SomeNat (_ :: Proxy m)) ->
          let loop :: Integer -> IO Bool
              loop k
                | k <= 0 = return True
                | otherwise = do
                    -- randomRIO reads AND advances the global generator,
                    -- so every round gets a different witness. The old
                    -- getStdGen never advanced it, so all rounds repeated
                    -- the same witness.
                    a <- randomRIO (2, p - 2)
                    if millerRabinCheck (toMod a :: Mod Integer m)
                      then loop (k - 1)
                      else return False
           in loop rounds
-- | One Miller–Rabin round for witness @a@ against the type-level modulus @n@.
--
-- Writes @n - 1 = 2^r * d@ (via 'decomp'). Returns 'True' (witness does not
-- prove compositeness) when @a^d == 1@, or when @a^(2^j * d) == -1@ for some
-- @0 <= j < r@, all computed modulo @n@.
millerRabinCheck :: forall n. KnownNat n => (Mod Integer n) -> Bool
millerRabinCheck a = x0 == 1 || any (== (-1)) squarings
  where
    (r, d) = decomp (natVal (Proxy :: Proxy n) - 1)
    -- a^d via Prelude (^); every product goes through the Num instance,
    -- which reduces modulo n.
    x0 = a ^ d
    -- a^d, a^(2d), a^(4d), ..., a^(2^(r-1) * d): r successive squarings.
    squarings :: [Mod Integer n]
    squarings = take (fromInteger r) (iterate (\y -> y * y) x0)
-- | 2-adic valuation: the exponent of the largest power of two dividing @n@.
--
-- >>> twoAdicVal 12
-- 2
--
-- Negative inputs work (@twoAdicVal (-8) == 3@). The valuation of 0 is
-- infinite. The previous definition recursed forever on 0; it now fails
-- with an explicit error instead.
twoAdicVal :: Integer -> Integer
twoAdicVal 0 = error "MillerRabin.twoAdicVal: the 2-adic valuation of 0 is infinite"
twoAdicVal n = go 0 n
  where
    -- Accumulating counter keeps the loop tail-recursive.
    go :: Integer -> Integer -> Integer
    go acc m
      | even m = go (acc + 1) (m `div` 2)
      | otherwise = acc
-- | The odd part of @n@: @n@ with every factor of two divided out.
--
-- >>> oddComponent 40
-- 5
--
-- Negative inputs keep their sign (@oddComponent (-8) == -1@). Zero has no
-- odd part. The previous definition diverged on 0; it now fails with an
-- explicit error instead.
oddComponent :: Integer -> Integer
oddComponent 0 = error "MillerRabin.oddComponent: 0 has no odd component"
oddComponent n
  -- Halve directly instead of computing 2^(valuation) and dividing.
  | even n = oddComponent (n `div` 2)
  | otherwise = n
-- | Split @n@ as @2^r * d@ with @d@ odd, returning @(r, d)@.
--
-- >>> decomp 40
-- (3,5)
--
-- Computed in a single halving pass; the previous version derived the
-- valuation twice. Zero has no such decomposition. The previous definition
-- diverged on 0; it now fails with an explicit error instead.
decomp :: Integer -> (Integer,Integer)
decomp 0 = error "MillerRabin.decomp: 0 cannot be written as 2^r * d with d odd"
decomp n = go 0 n
  where
    go :: Integer -> Integer -> (Integer, Integer)
    go r d
      | even d = go (r + 1) (d `div` 2)
      | otherwise = (r, d)