Skip to content

Commit

Permalink
Implement symmetric difference on PrimeIntSet
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodigrim committed Jul 11, 2020
1 parent fa26ca9 commit 856a986
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
6 changes: 4 additions & 2 deletions Math/NumberTheory/Primes/Factorisation/QuadraticSieve.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Math.NumberTheory.Roots
import Math.NumberTheory.Primes
import Math.NumberTheory.Moduli.Sqrt
import Math.NumberTheory.Utils.FromIntegral
import Unsafe.Coerce (unsafeCoerce)

data SignedPrimeIntSet = SignedPrimeIntSet
{ sign :: !Bool
Expand All @@ -44,7 +45,7 @@ member value (SignedPrimeIntSet s ps) = case value of
PrimeInt p -> p `PS.member` ps

xor :: SignedPrimeIntSet -> SignedPrimeIntSet -> SignedPrimeIntSet
xor (SignedPrimeIntSet s1 ps1) (SignedPrimeIntSet s2 ps2) = SignedPrimeIntSet (s1 /= s2) ((ps1 PS.\\ PS.unPrimeIntSet ps2) <> (ps2 PS.\\ PS.unPrimeIntSet ps1))
xor (SignedPrimeIntSet s1 ps1) (SignedPrimeIntSet s2 ps2) = SignedPrimeIntSet (s1 /= s2) (ps1 `PS.symmetricDifference` ps2)

-- | Given an odd positive composite Integer @n@ and Int parameters @b@ and @t@,
-- the Quadratic Sieve attempts to output @factor@, a factor of @n@. If it fails,
Expand Down Expand Up @@ -143,7 +144,8 @@ gaussianElimination (p@(indices, pivotFact) : xs) = case nonZero pivotFact of
Just pivot -> gaussianElimination (map (\q@(_, fact) -> if pivot `member` fact then add p q else q) xs)
Nothing -> indices : gaussianElimination xs
where
add (a, u) (b, v) = ((a S.\\ b) <> (b S.\\ a), u `xor` v)
-- Temporary, until Data.IntSet.symmetricDifference is provided.
add (a, u) (b, v) = (unsafeCoerce PS.symmetricDifference a b, u `xor` v)

-- Given a solution, the value of @f(x)@ is computed again. By construction,
-- the solution IntSet consists of values which correspond to columns in the
Expand Down
77 changes: 75 additions & 2 deletions Math/NumberTheory/Primes/IntSet.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
-- > import qualified Math.NumberTheory.Primes.IntSet as PrimeIntSet
--

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
Expand Down Expand Up @@ -51,6 +52,7 @@ module Math.NumberTheory.Primes.IntSet
, difference
, (\\)
, intersection
, symmetricDifference
-- * Filter
, filter
, partition
Expand All @@ -73,16 +75,20 @@ module Math.NumberTheory.Primes.IntSet
, toDescList
) where

import Prelude (Eq, Ord, Show, Monoid, Bool, Maybe(..), Int, otherwise)
import Prelude ((>), (/=), (==), (-), Eq, Ord, Show, Monoid, Bool, Maybe(..), Int, Word, otherwise)
import Control.DeepSeq (NFData)
import Data.Coerce (coerce)
import Data.Data (Data)
import Data.Function (on)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IS
import qualified Data.IntSet.Internal as IS
import Data.Semigroup (Semigroup)
import qualified GHC.Exts (IsList(..))

import Math.NumberTheory.Primes.Types (Prime(..))
import Math.NumberTheory.Utils.FromIntegral (wordToInt, intToWord)
import Data.Bits (Bits(..))
import Utils.Containers.Internal.BitUtil (highestBitMask)

-- | A set of 'Prime' integers.
newtype PrimeIntSet = PrimeIntSet {
Expand Down Expand Up @@ -187,6 +193,10 @@ disjoint (PrimeIntSet x) (PrimeIntSet y) = IS.null (IS.intersection x y)
difference :: PrimeIntSet -> IntSet -> PrimeIntSet
difference = coerce IS.difference

-- | Symmetric difference of two sets of primes.
symmetricDifference :: PrimeIntSet -> PrimeIntSet -> PrimeIntSet
symmetricDifference = coerce symmDiff

-- | An alias to 'difference'.
(\\) :: PrimeIntSet -> IntSet -> PrimeIntSet
(\\) = coerce (IS.\\)
Expand Down Expand Up @@ -263,3 +273,66 @@ toAscList = coerce IS.toAscList
-- | Convert the set to a list of descending primes.
toDescList :: PrimeIntSet -> [Prime Int]
toDescList = coerce IS.toDescList

-------------------------------------------------------------------------------
-- IntSet helpers

-- | Symmetric difference of two sets.
-- Implementation is inspired by 'Data.IntSet.union'
-- and 'Data.IntSet.difference'.
symmDiff :: IntSet -> IntSet -> IntSet
symmDiff t1 t2 = case t1 of
IS.Bin p1 m1 l1 r1 -> case t2 of
IS.Bin p2 m2 l2 r2
| shorter m1 m2 -> symmDiff1
| shorter m2 m1 -> symmDiff2
| p1 == p2 -> bin p1 m1 (symmDiff l1 l2) (symmDiff r1 r2)
| otherwise -> link p1 t1 p2 t2
where
symmDiff1
| mask p2 m1 /= p1 = link p1 t1 p2 t2
| p2 .&. m1 == 0 = bin p1 m1 (symmDiff l1 t2) r1
| otherwise = bin p1 m1 l1 (symmDiff r1 t2)
symmDiff2
| mask p1 m2 /= p2 = link p1 t1 p2 t2
| p1 .&. m2 == 0 = bin p2 m2 (symmDiff t1 l2) r2
| otherwise = bin p2 m2 l2 (symmDiff t1 r2)
IS.Tip kx bm -> symmDiffBM kx bm t1
IS.Nil -> t1
IS.Tip kx bm -> symmDiffBM kx bm t2
IS.Nil -> t2

shorter :: Int -> Int -> Bool
shorter = (>) `on` intToWord

symmDiffBM :: Int -> Word -> IntSet -> IntSet
symmDiffBM !kx !bm t = case t of
IS.Bin p m l r
| mask kx m /= p -> link kx (IS.Tip kx bm) p t
| kx .&. m == 0 -> bin p m (symmDiffBM kx bm l) r
| otherwise -> bin p m l (symmDiffBM kx bm r)
IS.Tip kx' bm'
| kx' == kx -> if bm' == bm then IS.Nil else IS.Tip kx (bm' `xor` bm)
| otherwise -> link kx (IS.Tip kx bm) kx' t
IS.Nil -> IS.Tip kx bm

link :: Int -> IntSet -> Int -> IntSet -> IntSet
link p1 t1 p2 t2
| p1 .&. m == 0 = IS.Bin p m t1 t2
| otherwise = IS.Bin p m t2 t1
where
m = wordToInt (highestBitMask (intToWord p1 `xor` intToWord p2))
p = mask p1 m
{-# INLINE link #-}

bin :: Int -> Int -> IntSet -> IntSet -> IntSet
bin p m l r = case r of
IS.Nil -> l
_ -> case l of
IS.Nil -> r
_ -> IS.Bin p m l r
{-# INLINE bin #-}

mask :: Int -> Int -> Int
mask i m = i .&. (complement (m - 1) `xor` m)
{-# INLINE mask #-}
17 changes: 16 additions & 1 deletion test-suite/Math/NumberTheory/PrimesTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
-- Tests for Math.NumberTheory.Primes
--

{-# LANGUAGE CPP #-}

{-# OPTIONS_GHC -fno-warn-type-defaults #-}

module Math.NumberTheory.PrimesTests
Expand All @@ -15,7 +17,12 @@ module Math.NumberTheory.PrimesTests

import Test.Tasty

import Math.NumberTheory.Primes (primes, unPrime, nextPrime, precPrime)
#if __GLASGOW_HASKELL__ < 803
import Data.Semigroup
#endif

import Math.NumberTheory.Primes
import qualified Math.NumberTheory.Primes.IntSet as PS
import Math.NumberTheory.TestUtils

primesSumWonk :: Int -> Int
Expand All @@ -27,8 +34,16 @@ primesSum upto = sum . takeWhile (<= upto) . map unPrime $ primes
primesSumProperty :: NonNegative Int -> Bool
primesSumProperty (NonNegative n) = n < 2 || primesSumWonk n == primesSum n

symmetricDifferenceProperty :: [Prime Int] -> [Prime Int] -> Bool
symmetricDifferenceProperty xs ys = z1 == z2
where
x = PS.fromList xs
y = PS.fromList ys
z1 = (x PS.\\ PS.unPrimeIntSet y) <> (y PS.\\ PS.unPrimeIntSet x)
z2 = PS.symmetricDifference x y

testSuite :: TestTree
testSuite = testGroup "Primes"
[ testSmallAndQuick "primesSum" primesSumProperty
, testSmallAndQuick "symmetricDifference" symmetricDifferenceProperty
]

0 comments on commit 856a986

Please sign in to comment.