Skip to content

Commit

Permalink
Make MLockedSizedBytes type honest
Browse files Browse the repository at this point in the history
Previously, MLockedSizedBytes was defined in terms of PinnedSizedBytes,
but never actually used as such - the memory it handles doesn't actually
hold a PinnedSizedBytes object, it's just a blob of raw memory, not an
actual Haskell value in any capacity, and you're not supposed to use it
as such ever.

Hence, we introduce a `SizedVoid` type, which is similar to `Void`
(i.e., it has no inhabitants), but additionally carries a type-level
size, which allows us to express the notion of "a pointer to a block of
memory of a particular size that we will never use as a Haskell value,
only manipulate directly through the pointer".
  • Loading branch information
tdammers committed Nov 30, 2022
1 parent 96b5834 commit d30c878
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 40 deletions.
14 changes: 7 additions & 7 deletions cardano-crypto-class/memory-example/Main.hs
Expand Up @@ -20,7 +20,7 @@ import System.Posix.Process (getProcessID)
import qualified Data.ByteString as SB

import Cardano.Crypto.Libsodium
import Cardano.Crypto.Libsodium.MLockedBytes.Internal (MLockedSizedBytes (..))
import Cardano.Crypto.Libsodium.MLockedBytes (traceMLSB)
import Cardano.Crypto.Hash (SHA256, Blake2b_256, digest)

main :: IO ()
Expand All @@ -43,15 +43,15 @@ main = do
-- example SHA256 hash
do
let input = SB.pack [0..255]
MLSB hash <- digestMLockedBS (Proxy @SHA256) input
traceMLockedForeignPtr hash
hash <- digestMLockedBS (Proxy @SHA256) input
traceMLSB hash
print (digest (Proxy @SHA256) input)

-- example Blake2b_256 hash
do
let input = SB.pack [0..255]
MLSB hash <- digestMLockedBS (Proxy @Blake2b_256) input
traceMLockedForeignPtr hash
hash <- digestMLockedBS (Proxy @Blake2b_256) input
traceMLSB hash
print (digest (Proxy @Blake2b_256) input)

example
Expand All @@ -72,9 +72,9 @@ example args alloc = do
traceMLockedForeignPtr fptr

-- smoke test that hashing works
MLSB hash <- withMLockedForeignPtr fptr $ \ptr ->
hash <- withMLockedForeignPtr fptr $ \ptr ->
digestMLockedStorable (Proxy @SHA256) ptr
traceMLockedForeignPtr hash
traceMLSB hash

-- force finalizers
finalizeMLockedForeignPtr fptr
Expand Down
1 change: 0 additions & 1 deletion cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs
Expand Up @@ -29,7 +29,6 @@ import Foreign.Ptr (castPtr, nullPtr)
import qualified Data.ByteString as BS
-- import qualified Data.ByteString.Unsafe as BS
import Data.Proxy
import Control.Exception (bracket)

import Cardano.Binary (FromCBOR (..), ToCBOR (..))

Expand Down
1 change: 0 additions & 1 deletion cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs
Expand Up @@ -32,7 +32,6 @@ import Cardano.Crypto.Hash
import Cardano.Crypto.Seed
import Cardano.Crypto.KES.Class
import Cardano.Crypto.Util
import Cardano.Crypto.DirectSerialise
import Cardano.Crypto.MonadSodium (mlsbAsByteString)

data MockKES (t :: Nat)
Expand Down
17 changes: 7 additions & 10 deletions cardano-crypto-class/src/Cardano/Crypto/Libsodium/Hash.hs
Expand Up @@ -28,7 +28,6 @@ import qualified Data.ByteString as BS

import Cardano.Foreign
import Cardano.Crypto.Hash (HashAlgorithm(SizeHash), SHA256, Blake2b_256)
import Cardano.Crypto.PinnedSizedBytes (ptrPsbToSizedPtr)
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.MLockedBytes.Internal
Expand Down Expand Up @@ -98,14 +97,13 @@ expandHash h (MLSB sfptr) = do
instance SodiumHashAlgorithm SHA256 where
naclDigestPtr :: forall proxy a. proxy SHA256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash SHA256))
naclDigestPtr _ input inputlen = do
output <- allocMLockedForeignPtr
withMLockedForeignPtr output $ \output' -> do
res <- c_crypto_hash_sha256 (ptrPsbToSizedPtr output') (castPtr input) (fromIntegral inputlen)
output <- mlsbNew
mlsbUseAsSizedPtr output $ \output' -> do
res <- c_crypto_hash_sha256 output' (castPtr input) (fromIntegral inputlen)
unless (res == 0) $ do
errno <- getErrno
ioException $ errnoToIOError "digestMLocked @SHA256: c_crypto_hash_sha256" errno Nothing Nothing

return (MLSB output)
return output

-- Test that manually written numbers are the same as in libsodium
_testSHA256 :: SizeHash SHA256 :~: CRYPTO_SHA256_BYTES
Expand All @@ -114,17 +112,16 @@ _testSHA256 = Refl
instance SodiumHashAlgorithm Blake2b_256 where
naclDigestPtr :: forall proxy a. proxy Blake2b_256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash Blake2b_256))
naclDigestPtr _ input inputlen = do
output <- allocMLockedForeignPtr
withMLockedForeignPtr output $ \output' -> do
output <- mlsbNew
mlsbUseAsCPtr output $ \output' -> do
res <- c_crypto_generichash_blake2b
output' (fromInteger $ natVal (Proxy @CRYPTO_BLAKE2B_256_BYTES)) -- output
(castPtr input) (fromIntegral inputlen) -- input
nullPtr 0 -- key, unused
unless (res == 0) $ do
errno <- getErrno
ioException $ errnoToIOError "digestMLocked @Blake2b_256: c_crypto_hash_sha256" errno Nothing Nothing

return (MLSB output)
return output

_testBlake2b256 :: SizeHash Blake2b_256 :~: CRYPTO_BLAKE2B_256_BYTES
_testBlake2b256 = Refl
@@ -1,6 +1,8 @@
module Cardano.Crypto.Libsodium.MLockedBytes (
MLockedSizedBytes,
mlsbNew,
mlsbNewZero,
mlsbZero,
mlsbFromByteString,
mlsbFromByteStringCheck,
mlsbAsByteString,
Expand All @@ -9,6 +11,7 @@ module Cardano.Crypto.Libsodium.MLockedBytes (
mlsbUseAsSizedPtr,
mlsbFinalize,
mlsbCopy,
traceMLSB,
) where

import Cardano.Crypto.Libsodium.MLockedBytes.Internal
Expand Up @@ -4,9 +4,12 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE KindSignatures #-}
module Cardano.Crypto.Libsodium.MLockedBytes.Internal (
MLockedSizedBytes (..),
mlsbNew,
mlsbNewZero,
mlsbZero,
mlsbFromByteString,
mlsbFromByteStringCheck,
mlsbAsByteString,
Expand All @@ -15,14 +18,15 @@ module Cardano.Crypto.Libsodium.MLockedBytes.Internal (
mlsbUseAsSizedPtr,
mlsbCopy,
mlsbFinalize,
traceMLSB,
) where

import Control.DeepSeq (NFData (..))
import Data.Proxy (Proxy (..))
import Foreign.C.Types (CSize (..))
import Foreign.ForeignPtr (castForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import GHC.TypeLits (KnownNat, natVal)
import GHC.TypeLits (KnownNat, Nat, natVal)
import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..))
import System.IO.Unsafe (unsafeDupablePerformIO)
import Data.Word (Word8)
Expand All @@ -32,12 +36,38 @@ import Text.Printf
import Cardano.Foreign
import Cardano.Crypto.Libsodium.Memory.Internal
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.PinnedSizedBytes

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Foreign.Storable (Storable (..))
import Data.Bits (shiftL)

-- | A void type with a type-level size attached to it. We need this in order
-- to express \"pointer to a block of memory of a particular size that can be
-- manipulated through the pointer, but not as a plain Haskell value\" as
-- @Ptr (SizedVoid n)@, or @ForeignPtr (SizedVoid n)@, or
-- @MLockedForeignPtr (SizedVoid n)@.
data SizedVoid (n :: Nat)

-- | Storable instance is necessary for 'allocMLockedForeignPtr'; 'peek' and
-- 'poke' error out, but cannot actually be used due to 'SizedVoid' not having
-- any inhabitants.
instance KnownNat n => Storable (SizedVoid n) where
sizeOf _ = fromIntegral (natVal (Proxy @n))
alignment _ = nextPowerOf2 (fromIntegral (natVal (Proxy @n)))
peek _ = error "Do not peek SizedVoid"
poke _ _ = error "Do not poke SizedVoid"

nextPowerOf2 :: Int -> Int
nextPowerOf2 i =
go 1
where
go :: Int -> Int
go c =
let c' = c `shiftL` 1
in if c' > i then c else go c'

newtype MLockedSizedBytes n = MLSB (MLockedForeignPtr (PinnedSizedBytes n))
newtype MLockedSizedBytes (n :: Nat) = MLSB (MLockedForeignPtr (SizedVoid n))
deriving NoThunks via OnlyCheckWhnfNamed "MLockedSizedBytes" (MLockedSizedBytes n)
deriving newtype NFData

Expand All @@ -54,29 +84,32 @@ instance KnownNat n => Ord (MLockedSizedBytes n) where
size = natVal (Proxy @n)

instance KnownNat n => Show (MLockedSizedBytes n) where
-- showsPrec d _ = showParen (d > 10)
-- $ showString "_ :: MLockedSizedBytes "
-- . showsPrec 11 (natVal (Proxy @n))
show mlsb =
let bytes = BS.unpack $ mlsbAsByteString mlsb
hexstr = concatMap (printf "%02x") bytes
in "MLSB " ++ hexstr

withMLSB :: forall a b n. MLockedSizedBytes n -> (Ptr a -> IO b) -> IO b
withMLSB (MLSB fptr) action = withMLockedForeignPtr fptr (action . castPtr)
traceMLSB :: KnownNat n => MLockedSizedBytes n -> IO ()
traceMLSB = print
{-# DEPRECATED traceMLSB "Don't leave traceMLockedForeignPtr in production" #-}

-- | Note: this doesn't need to allocate mlocked memory,
-- but we do that for consistency
-- mlsbZero :: forall n. KnownNat n => MLockedSizedBytes n
-- mlsbZero = unsafeDupablePerformIO mlsbNew
withMLSB :: forall b n. MLockedSizedBytes n -> (Ptr (SizedVoid n) -> IO b) -> IO b
withMLSB (MLSB fptr) action = withMLockedForeignPtr fptr action

mlsbNew :: forall n. KnownNat n => IO (MLockedSizedBytes n)
mlsbNew = do
fptr <- allocMLockedForeignPtr
withMLockedForeignPtr fptr $ \ptr -> do
_ <- c_memset (castPtr ptr) 0 size
return ()
return (MLSB fptr)
mlsbNew = MLSB <$> allocMLockedForeignPtr

mlsbNewZero :: forall n. KnownNat n => IO (MLockedSizedBytes n)
mlsbNewZero = do
mlsb <- mlsbNew
mlsbZero mlsb
return mlsb

mlsbZero :: forall n. KnownNat n => MLockedSizedBytes n -> IO ()
mlsbZero mlsb = do
withMLSB mlsb $ \ptr -> do
_ <- c_memset (castPtr ptr) 0 size
return ()
where
size :: CSize
size = fromInteger (natVal (Proxy @n))
Expand Down Expand Up @@ -138,10 +171,12 @@ mlsbToByteString mlsb =
size = fromInteger (natVal (Proxy @n))

mlsbUseAsCPtr :: MLockedSizedBytes n -> (Ptr Word8 -> IO r) -> IO r
mlsbUseAsCPtr (MLSB x) k = withMLockedForeignPtr x (k . castPtr)
mlsbUseAsCPtr (MLSB x) k =
withMLockedForeignPtr x (k . castPtr)

mlsbUseAsSizedPtr :: MLockedSizedBytes n -> (SizedPtr n -> IO r) -> IO r
mlsbUseAsSizedPtr (MLSB x) k = withMLockedForeignPtr x (k . ptrPsbToSizedPtr)
mlsbUseAsSizedPtr :: forall n r. MLockedSizedBytes n -> (SizedPtr n -> IO r) -> IO r
mlsbUseAsSizedPtr (MLSB x) k =
withMLockedForeignPtr x (k . SizedPtr . castPtr)

-- | Calls 'finalizeMLockedForeignPtr' on underlying pointer.
-- This function invalidates argument.
Expand Down

0 comments on commit d30c878

Please sign in to comment.