Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Node datatype #1501

Merged
merged 5 commits into from Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/Juvix/Compiler/Core/Data/InfoTable.hs
Expand Up @@ -9,7 +9,7 @@ data InfoTable = InfoTable
-- `_identMap` is needed only for REPL
_identMap :: HashMap Text (Either Symbol Tag),
_infoMain :: Maybe Symbol,
_infoIdents :: HashMap Symbol IdentInfo,
_infoIdentifiers :: HashMap Symbol IdentifierInfo,
_infoInductives :: HashMap Name InductiveInfo,
_infoConstructors :: HashMap Tag ConstructorInfo,
_infoAxioms :: HashMap Name AxiomInfo
Expand All @@ -21,20 +21,20 @@ emptyInfoTable =
{ _identContext = mempty,
_identMap = mempty,
_infoMain = Nothing,
_infoIdents = mempty,
_infoIdentifiers = mempty,
_infoInductives = mempty,
_infoConstructors = mempty,
_infoAxioms = mempty
}

data IdentInfo = IdentInfo
{ _identName :: Name,
_identSymbol :: Symbol,
_identType :: Type,
-- _identArgsNum will be used often enough to justify avoiding recomputation
_identArgsNum :: Int,
_identArgsInfo :: [ArgumentInfo],
_identIsExported :: Bool
data IdentifierInfo = IdentifierInfo
{ _identifierName :: Name,
_identifierSymbol :: Symbol,
_identifierType :: Type,
-- _identifierArgsNum will be used often enough to justify avoiding recomputation
_identifierArgsNum :: Int,
_identifierArgsInfo :: [ArgumentInfo],
_identifierIsExported :: Bool
}

data ArgumentInfo = ArgumentInfo
Expand Down Expand Up @@ -70,7 +70,7 @@ data AxiomInfo = AxiomInfo
}

makeLenses ''InfoTable
makeLenses ''IdentInfo
makeLenses ''IdentifierInfo
makeLenses ''ArgumentInfo
makeLenses ''InductiveInfo
makeLenses ''ConstructorInfo
Expand Down
12 changes: 6 additions & 6 deletions src/Juvix/Compiler/Core/Data/InfoTableBuilder.hs
Expand Up @@ -7,7 +7,7 @@ import Juvix.Compiler.Core.Language
data InfoTableBuilder m a where
FreshSymbol :: InfoTableBuilder m Symbol
FreshTag :: InfoTableBuilder m Tag
RegisterIdent :: IdentInfo -> InfoTableBuilder m ()
RegisterIdent :: IdentifierInfo -> InfoTableBuilder m ()
RegisterConstructor :: ConstructorInfo -> InfoTableBuilder m ()
RegisterIdentNode :: Symbol -> Node -> InfoTableBuilder m ()
SetIdentArgsInfo :: Symbol -> [ArgumentInfo] -> InfoTableBuilder m ()
Expand Down Expand Up @@ -44,7 +44,7 @@ makeLenses ''BuilderState
initBuilderState :: InfoTable -> BuilderState
initBuilderState tab =
BuilderState
{ _stateNextSymbol = fromIntegral $ HashMap.size (tab ^. infoIdents),
{ _stateNextSymbol = fromIntegral $ HashMap.size (tab ^. infoIdentifiers),
_stateNextUserTag = fromIntegral $ HashMap.size (tab ^. infoConstructors),
_stateInfoTable = tab
}
Expand All @@ -66,16 +66,16 @@ runInfoTableBuilder tab =
s <- get
return (UserTag (s ^. stateNextUserTag - 1))
RegisterIdent ii -> do
modify' (over stateInfoTable (over infoIdents (HashMap.insert (ii ^. identSymbol) ii)))
modify' (over stateInfoTable (over identMap (HashMap.insert (ii ^. (identName . nameText)) (Left (ii ^. identSymbol)))))
modify' (over stateInfoTable (over infoIdentifiers (HashMap.insert (ii ^. identifierSymbol) ii)))
modify' (over stateInfoTable (over identMap (HashMap.insert (ii ^. (identifierName . nameText)) (Left (ii ^. identifierSymbol)))))
RegisterConstructor ci -> do
modify' (over stateInfoTable (over infoConstructors (HashMap.insert (ci ^. constructorTag) ci)))
modify' (over stateInfoTable (over identMap (HashMap.insert (ci ^. (constructorName . nameText)) (Right (ci ^. constructorTag)))))
RegisterIdentNode sym node ->
modify' (over stateInfoTable (over identContext (HashMap.insert sym node)))
SetIdentArgsInfo sym argsInfo -> do
modify' (over stateInfoTable (over infoIdents (HashMap.adjust (set identArgsInfo argsInfo) sym)))
modify' (over stateInfoTable (over infoIdents (HashMap.adjust (set identArgsNum (length argsInfo)) sym)))
modify' (over stateInfoTable (over infoIdentifiers (HashMap.adjust (set identifierArgsInfo argsInfo) sym)))
modify' (over stateInfoTable (over infoIdentifiers (HashMap.adjust (set identifierArgsNum (length argsInfo)) sym)))
GetIdent txt -> do
s <- get
return $ HashMap.lookup txt (s ^. (stateInfoTable . identMap))
Expand Down
64 changes: 32 additions & 32 deletions src/Juvix/Compiler/Core/Evaluator.hs
Expand Up @@ -43,7 +43,7 @@ instance Show EvalError where

instance Exception.Exception EvalError

-- `eval ctx env n` evalues a node `n` whose all free variables point into
-- | `eval ctx env n` evalues a node `n` whose all free variables point into
-- `env`. All nodes in `ctx` must be closed. All nodes in `env` must be values.
-- Invariant for values v: eval ctx env v = v
eval :: IdentContext -> Env -> Node -> Node
Expand All @@ -54,24 +54,24 @@ eval !ctx !env0 = convertRuntimeNodes . eval' env0

eval' :: Env -> Node -> Node
eval' !env !n = case n of
Var _ idx -> env !! idx
Ident _ sym -> eval' [] (lookupContext n sym)
Constant {} -> n
App i l r ->
NVar (Var _ idx) -> env !! idx
NIdt (Ident _ sym) -> eval' [] (lookupContext n sym)
NCst {} -> n
NApp (App i l r) ->
case eval' env l of
Closure _ env' b -> let !v = eval' env r in eval' (v : env') b
v -> evalError "invalid application" (App i v (substEnv env r))
BuiltinApp _ op args -> applyBuiltin n env op args
Constr i tag args -> Constr i tag (map (eval' env) args)
Lambda i b -> Closure i env b
Let _ v b -> let !v' = eval' env v in eval' (v' : env) b
Case i v bs def ->
Closure env' (Lambda _ b) -> let !v = eval' env r in eval' (v : env') b
v -> evalError "invalid application" (mkApp i v (substEnv env r))
NBlt (BuiltinApp _ op args) -> applyBuiltin n env op args
NCtr (Constr i tag args) -> mkConstr i tag (map (eval' env) args)
NLam l@Lambda {} -> Closure env l
NLet (Let _ v b) -> let !v' = eval' env v in eval' (v' : env) b
NCase (Case i v bs def) ->
case eval' env v of
Constr _ tag args -> branch n env args tag def bs
v' -> evalError "matching on non-data" (substEnv env (Case i v' bs def))
Pi {} -> substEnv env n
Univ {} -> n
TypeConstr i sym args -> TypeConstr i sym (map (eval' env) args)
NCtr (Constr _ tag args) -> branch n env args tag def bs
v' -> evalError "matching on non-data" (substEnv env (mkCase i v' bs def))
NPi {} -> substEnv env n
NUniv {} -> n
NTyp (TypeConstr i sym args) -> mkTypeConstr i sym (map (eval' env) args)
Closure {} -> n

branch :: Node -> Env -> [Node] -> Tag -> Maybe Node -> [CaseBranch] -> Node
Expand All @@ -96,27 +96,27 @@ eval !ctx !env0 = convertRuntimeNodes . eval' env0
k -> nodeFromInteger (mod (integerFromNode (eval' env l)) k)
applyBuiltin _ env OpIntLt [l, r] = nodeFromBool (integerFromNode (eval' env l) < integerFromNode (eval' env r))
applyBuiltin _ env OpIntLe [l, r] = nodeFromBool (integerFromNode (eval' env l) <= integerFromNode (eval' env r))
applyBuiltin _ env OpEq [l, r] = nodeFromBool (eval' env l == eval' env r)
applyBuiltin _ env OpEq [l, r] = nodeFromBool (structEq (eval' env l) (eval' env r))
applyBuiltin _ env OpTrace [msg, x] = Debug.trace (printNode (eval' env msg)) (eval' env x)
applyBuiltin _ env OpFail [msg] =
Exception.throw (EvalError (fromString ("failure: " ++ printNode (eval' env msg))) Nothing)
applyBuiltin n env _ _ = evalError "invalid builtin application" (substEnv env n)

nodeFromInteger :: Integer -> Node
nodeFromInteger !int = Constant Info.empty (ConstInteger int)
nodeFromInteger !int = mkConstant' (ConstInteger int)

nodeFromBool :: Bool -> Node
nodeFromBool True = Constr Info.empty (BuiltinTag TagTrue) []
nodeFromBool False = Constr Info.empty (BuiltinTag TagFalse) []
nodeFromBool True = mkConstr' (BuiltinTag TagTrue) []
nodeFromBool False = mkConstr' (BuiltinTag TagFalse) []

integerFromNode :: Node -> Integer
integerFromNode = \case
Constant _ (ConstInteger int) -> int
NCst (Constant _ (ConstInteger int)) -> int
v -> evalError "not an integer" v

printNode :: Node -> String
printNode = \case
Constant _ (ConstString s) -> fromText s
NCst (Constant _ (ConstString s)) -> fromText s
v -> fromText $ ppPrint v

lookupContext :: Node -> Symbol -> Node
Expand All @@ -130,29 +130,29 @@ hEvalIO :: Handle -> Handle -> IdentContext -> Env -> Node -> IO Node
hEvalIO hin hout ctx env node =
let node' = eval ctx env node
in case node' of
Constr _ (BuiltinTag TagReturn) [x] ->
NCtr (Constr _ (BuiltinTag TagReturn) [x]) ->
return x
Constr _ (BuiltinTag TagBind) [x, f] -> do
NCtr (Constr _ (BuiltinTag TagBind) [x, f]) -> do
x' <- hEvalIO hin hout ctx env x
hEvalIO hin hout ctx env (App Info.empty f x')
Constr _ (BuiltinTag TagWrite) [Constant _ (ConstString s)] -> do
hEvalIO hin hout ctx env (mkApp Info.empty f x')
NCtr (Constr _ (BuiltinTag TagWrite) [NCst (Constant _ (ConstString s))]) -> do
hPutStr hout s
return unitNode
Constr _ (BuiltinTag TagWrite) [arg] -> do
NCtr (Constr _ (BuiltinTag TagWrite) [arg]) -> do
hPutStr hout (ppPrint arg)
return unitNode
Constr _ (BuiltinTag TagReadLn) [] -> do
NCtr (Constr _ (BuiltinTag TagReadLn) []) -> do
hFlush hout
Constant Info.empty . ConstString <$> hGetLine hin
mkConstant Info.empty . ConstString <$> hGetLine hin
_ ->
return node'
where
unitNode = Constr (Info.singleton (NoDisplayInfo ())) (BuiltinTag TagTrue) []
unitNode = mkConstr (Info.singleton (NoDisplayInfo ())) (BuiltinTag TagTrue) []

evalIO :: IdentContext -> Env -> Node -> IO Node
evalIO = hEvalIO stdin stdout

-- Catch EvalError and convert it to CoreError. Needs a default location in case
-- | Catch EvalError and convert it to CoreError. Needs a default location in case
-- no location is available in EvalError.
catchEvalError :: Location -> a -> IO (Either CoreError a)
catchEvalError loc a =
Expand Down
25 changes: 13 additions & 12 deletions src/Juvix/Compiler/Core/Extra.hs
Expand Up @@ -3,14 +3,15 @@ module Juvix.Compiler.Core.Extra
module Juvix.Compiler.Core.Extra.Base,
module Juvix.Compiler.Core.Extra.Recursors,
module Juvix.Compiler.Core.Extra.Info,
module Juvix.Compiler.Core.Extra.Equality,
)
where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Equality
import Juvix.Compiler.Core.Extra.Info
import Juvix.Compiler.Core.Extra.Recursors
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Language

isClosed :: Node -> Bool
Expand All @@ -23,8 +24,8 @@ freeVars :: SimpleFold Node Index
freeVars f = ufoldAN' reassemble go
where
go k = \case
Var i idx
| idx >= k -> Var i <$> f (idx - k)
NVar (Var i idx)
| idx >= k -> mkVar i <$> f (idx - k)
n -> pure n

getIdents :: Node -> HashSet Symbol
Expand All @@ -34,14 +35,14 @@ nodeIdents :: Traversal' Node Symbol
nodeIdents f = umapLeaves go
where
go = \case
Ident i d -> Ident i <$> f d
NIdt (Ident i d) -> mkIdent i <$> f d
n -> pure n

countFreeVarOccurrences :: Index -> Node -> Int
countFreeVarOccurrences idx = gatherN go 0
where
go k acc = \case
Var _ idx' | idx' == idx + k -> acc + 1
NVar (Var _ idx') | idx' == idx + k -> acc + 1
_ -> acc

-- | increase all free variable indices by a given value
Expand All @@ -50,7 +51,7 @@ shift 0 = id
shift m = umapN go
where
go k n = case n of
Var i idx | idx >= k -> Var i (idx + m)
NVar (Var i idx) | idx >= k -> mkVar i (idx + m)
_ -> n

-- | substitute a term t for the free variable with de Bruijn index 0, avoiding
Expand All @@ -60,8 +61,8 @@ subst :: Node -> Node -> Node
subst t = umapN go
where
go k n = case n of
Var _ idx | idx == k -> shift k t
Var i idx | idx > k -> Var i (idx - 1)
NVar (Var _ idx) | idx == k -> shift k t
NVar (Var i idx) | idx > k -> mkVar i (idx - 1)
_ -> n

-- | reduce all beta redexes present in a term and the ones created immediately
Expand All @@ -71,12 +72,12 @@ developBeta = umap go
where
go :: Node -> Node
go n = case n of
App _ (Lambda _ body) arg -> subst arg body
NApp (App _ (NLam (Lambda _ body)) arg) -> subst arg body
_ -> n

etaExpand :: Int -> Node -> Node
etaExpand 0 n = n
etaExpand k n = mkLambdas k (mkApp (shift k n) (map (Var Info.empty) (reverse [0 .. k - 1])))
etaExpand k n = mkLambdas' k (mkApps' (shift k n) (map mkVar' (reverse [0 .. k - 1])))

-- | substitution of all free variables for values in an environment
substEnv :: Env -> Node -> Node
Expand All @@ -85,15 +86,15 @@ substEnv env
| otherwise = umapN go
where
go k n = case n of
Var _ idx | idx >= k -> env !! (idx - k)
NVar (Var _ idx) | idx >= k -> env !! (idx - k)
_ -> n

convertClosures :: Node -> Node
convertClosures = umap go
where
go :: Node -> Node
go n = case n of
Closure i env b -> substEnv env (Lambda i b)
Closure env (Lambda i b) -> substEnv env (mkLambda i b)
_ -> n

convertRuntimeNodes :: Node -> Node
Expand Down