diff --git a/src/Database/LSMTree/Internal/ByteString.hs b/src/Database/LSMTree/Internal/ByteString.hs index 72f203ac..c2edf2e2 100644 --- a/src/Database/LSMTree/Internal/ByteString.hs +++ b/src/Database/LSMTree/Internal/ByteString.hs @@ -8,6 +8,7 @@ module Database.LSMTree.Internal.ByteString ( tryGetByteArray, shortByteStringFromTo, byteArrayFromTo, + byteArrayToSBS, ) where import qualified Data.ByteString as BS @@ -94,3 +95,10 @@ shortByteStringCopyStepFromTo !ip0 !ipe0 !sbs k = where outRemaining = ope `minusPtr` op inpRemaining = ipe - ip + +byteArrayToSBS :: ByteArray -> ShortByteString +#if MIN_VERSION_bytestring(0,12,0) +byteArrayToSBS ba = SBS.ShortByteString ba +#else +byteArrayToSBS (ByteArray ba) = SBS.SBS ba +#endif diff --git a/src/Database/LSMTree/Internal/RawBytes.hs b/src/Database/LSMTree/Internal/RawBytes.hs index 598d56ac..686656cf 100644 --- a/src/Database/LSMTree/Internal/RawBytes.hs +++ b/src/Database/LSMTree/Internal/RawBytes.hs @@ -35,6 +35,7 @@ module Database.LSMTree.Internal.RawBytes ( -- | Use 'Semigroup' and 'Monoid' operations -- ** Restricting memory usage , copy + , force -- * Conversions , fromVector , fromByteArray @@ -58,6 +59,7 @@ import qualified Data.ByteString.Builder as BB import Data.ByteString.Short (ShortByteString (SBS)) import qualified Data.ByteString.Short as SBS import Data.Primitive.ByteArray (ByteArray (..), compareByteArrays) +import qualified Data.Primitive.ByteArray as BA import qualified Data.Vector.Primitive as PV import Database.LSMTree.Internal.ByteString (shortByteStringFromTo, tryGetByteArray) @@ -218,6 +220,19 @@ instance Monoid RawBytes where copy :: RawBytes -> RawBytes copy (RawBytes pvec) = RawBytes (PV.force pvec) +-- | Force 'RawBytes' to not retain any extra memory. This may copy the contents. +force :: RawBytes -> ByteArray +force (RawBytes (PV.Vector off len ba)) + | off == 0 + , BA.sizeofByteArray ba == len + = ba + + | otherwise + = BA.runByteArray $ do + mba <- BA.newByteArray len + BA.copyByteArray mba 0 ba off len + return mba + {------------------------------------------------------------------------------- Conversions -------------------------------------------------------------------------------} diff --git a/src/utils/Database/LSMTree/Orphans.hs b/src/utils/Database/LSMTree/Orphans.hs index 0220494b..28d9b511 100644 --- a/src/utils/Database/LSMTree/Orphans.hs +++ b/src/utils/Database/LSMTree/Orphans.hs @@ -18,6 +18,7 @@ import qualified Data.ByteString.Short.Internal as SBS import qualified Data.Primitive as P import Data.WideWord.Word256 (Word256 (..)) import Data.Word (Word64, byteSwap64) +import Database.LSMTree.Internal.ByteString (byteArrayToSBS) import Database.LSMTree.Internal.Entry (NumEntries (..)) import Database.LSMTree.Internal.IndexCompact (IndexCompact (..), PageNo (..), PageSpan (..)) @@ -106,3 +107,29 @@ instance SerialiseValue BS.ByteString where serialiseValue = RB.fromShortByteString . SBS.toShort deserialiseValue = deserialiseValueN . pure deserialiseValueN = LBS.toStrict . deserialiseValueN + +{------------------------------------------------------------------------------- + ShortByteString +-------------------------------------------------------------------------------} + +instance SerialiseKey SBS.ShortByteString where + serialiseKey = RB.fromShortByteString + deserialiseKey = byteArrayToSBS . RB.force + +instance SerialiseValue SBS.ShortByteString where + serialiseValue = RB.fromShortByteString + deserialiseValue = byteArrayToSBS . RB.force + deserialiseValueN = byteArrayToSBS . foldMap RB.force + +{------------------------------------------------------------------------------- + ByteArray +-------------------------------------------------------------------------------} + +instance SerialiseKey P.ByteArray where + serialiseKey ba = RB.fromByteArray 0 (P.sizeofByteArray ba) ba + deserialiseKey = RB.force + +instance SerialiseValue P.ByteArray where + serialiseValue ba = RB.fromByteArray 0 (P.sizeofByteArray ba) ba + deserialiseValue = RB.force + deserialiseValueN = foldMap RB.force