From c0ce03ef7af071201ca541d941824a3b24b7c4fe Mon Sep 17 00:00:00 2001 From: Levent Erkok Date: Wed, 21 Feb 2024 15:19:16 -0800 Subject: [PATCH] Further encoding of e4m3. WIP --- src/CrackNum/Main.hs | 56 +++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/src/CrackNum/Main.hs b/src/CrackNum/Main.hs index 5e40543..aca9444 100644 --- a/src/CrackNum/Main.hs +++ b/src/CrackNum/Main.hs @@ -563,7 +563,7 @@ encodeLane debug lanes num rm inp ef E5M2 _ = ef (FP 5 3) True -- 3 is intentional; the format ignores the sign storage, but SBV doesn't, following SMTLib - ef E4M3 _ = encodeE4M3 debug inp + ef E4M3 _ = encodeE4M3 debug rm inp -- | Convert certain strings to more understandable format by read -- If first argument is True, then we're reading using reads, i.e., haskell syntax @@ -589,10 +589,10 @@ unrecognized inp = die [ "Input does not represent floating point number we reco ] -- Encoding E4M3 is tricky, because of deviation from IEEE. So, we do a case analysis, mostly -encodeE4M3 :: Bool -> String -> IO () -encodeE4M3 debug inp = case reads (fixup True inp) of - [(v :: Double, "")] -> analyze v - _ -> unrecognized inp +encodeE4M3 :: Bool -> RM -> String -> IO () +encodeE4M3 debug rm inp = case reads (fixup True inp) of + [(v :: Double, "")] -> analyze v + _ -> unrecognized inp where config = z3{ crackNum = True , verbose = debug } @@ -630,30 +630,45 @@ encodeE4M3 debug inp = case reads (fixup True inp) of | True = range v - extraVals :: [(Double, String)] - extraVals = [(-v, '1':s) | (v, s) <- reverse pos] - ++ [( v, '0':s) | (v, s) <- pos] - where pos = [ (256, "1111000") - , (288, "1111001") - , (320, "1111010") - , (352, "1111011") - , (384, "1111100") - , (416, "1111101") - , (448, "1111110") + -- This list is sorted on the first value. + -- Final bool is True if this value is considered "even" for rounding purposes + extraVals :: [(Double, String, Bool)] + extraVals = [(-v, '1':s, eo) | (v, s, eo) <- reverse pos] + ++ [( v, '0':s, eo) | (v, s, eo) <- pos] + where pos = [ (240, "1110111", False) + , (256, "1111000", True) + , (288, "1111001", False) + , (320, "1111010", True) + , (352, "1111011", False) + , (384, "1111100", True) + , (416, "1111101", False) + , (448, "1111110", True) ] -- Pick the value we land on pick v = case [p | (d, p) <- dists, d == minVal] of [x] -> x - [x, y] -> choose x y + [x, y] -> choose v x y -- The following two can't happen, but just in case: [] -> error $ "encodeE4M3: Empty list of candidates for " ++ show v -- Can't happen cands -> error $ "encodeE4M3: More than two candidates for " ++ show v ++ ": " ++ show cands - where dists = [(abs (v - ev), p) | p@(ev, _) <- extraVals] + where dists = [(abs (v - ev), p) | p@(ev, _, _) <- extraVals] minVal = minimum $ map fst dists - choose :: (Double, String) -> (Double, String) -> (Double, String) - choose p1 _ = p1 + -- choose is called if we're smack in between the two values given. Then, we pick + -- depending on the rounding mode. Note that p1 < p2 is guaranteed here. + choose :: Double -> (Double, String, Bool) -> (Double, String, Bool) -> (Double, String, Bool) + choose v p1@(_, _, eo1) p2@(_, _, eo2) = + let isNegative = v < 0 || isNegativeZero v + in case rm of + RNE -> case (eo1, eo2) of + (True, False) -> p1 + (False, True) -> p2 + _ -> error $ "encodeE4M3: RNE can't pick between values: " ++ show (v, p1, p2) + RNA -> if isNegative then p1 else p2 + RTP -> p2 + RTN -> p1 + RTZ -> if isNegative then p2 else p1 range v | v < -448 || v > 448 -- Out-of-bounds becomes NaN @@ -661,7 +676,7 @@ encodeE4M3 debug inp = case reads (fixup True inp) of putStrLn $ " Note: The input value " ++ show v ++ " is out of bounds, and hence becomes NaN" putStrLn " The representable range is [-448, 448]" - | v >= -240 || v <= 240 -- Fits into regular 4+4 format, so just decode + | v >= -240 && v <= 240 -- Fits into regular 4+4 format, so just decode = do res <- satWith config $ do x :: SFloatingPoint 4 4 <- sFloatingPoint "ENCODED" constrain $ x .== fromSDouble sRNE (literal v) putStrLn $ fixEncoded res @@ -670,4 +685,3 @@ encodeE4M3 debug inp = case reads (fixup True inp) of -- Pick the nearest and display that | True = print $ pick v -