Skip to content

Commit

Permalink
finish impl of discrimination tree instance search
Browse files Browse the repository at this point in the history
  • Loading branch information
plt-amy committed Feb 23, 2024
1 parent a681f71 commit 2aecf36
Show file tree
Hide file tree
Showing 20 changed files with 428 additions and 145 deletions.
8 changes: 8 additions & 0 deletions src/full/Agda/Benchmarking.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ data Phase
-- ^ Subphase for 'Typing': generalizing over `variable`s
| InstanceSearch
-- ^ Subphase for 'Typing': solving instance goals
| Reflection
-- ^ Subphase for 'Typing': evaluating elaborator reflection
| InitialCandidates
-- ^ Subphase for 'InstanceSearch': collecting initial candidates
| FilterCandidates
-- ^ Subphase for 'InstanceSearch': checking candidates for validity
| OrderCandidates
-- ^ Subphase for 'InstanceSearch': ordering candidates for specificity
| UnifyIndices
-- ^ Subphase for 'CheckLHS': unification of the indices
| InverseScopeLookup
Expand Down
6 changes: 5 additions & 1 deletion src/full/Agda/Interaction/Imports.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import Agda.Syntax.Scope.Base
import Agda.Syntax.TopLevelModuleName
import Agda.Syntax.Translation.ConcreteToAbstract as CToA

import Agda.TypeChecking.InstanceArguments
import Agda.TypeChecking.Errors
import Agda.TypeChecking.Warnings hiding (warnings)
import Agda.TypeChecking.Reduce
Expand Down Expand Up @@ -277,7 +278,6 @@ addImportedThings isig metas ibuiltin patsyns display userwarn
stTCWarnings `modifyTCLens` \ imp -> imp `List.union` warnings
stOpaqueBlocks `modifyTCLens` \ imp -> imp `Map.union` oblock
stOpaqueIds `modifyTCLens` \ imp -> imp `Map.union` oid
addImportedInstances isig

-- | Scope checks the given module. A proper version of the module
-- name (with correct definition sites) is returned.
Expand Down Expand Up @@ -1058,6 +1058,10 @@ createInterface mname file isMain msrc = do

unfreezeMetas

-- Remove any instances that now have visible arguments from the
-- instance tree before serialising.
pruneTemporaryInstances

-- Profiling: Count number of metas.
whenProfile Profile.Metas $ do
m <- fresh
Expand Down
74 changes: 60 additions & 14 deletions src/full/Agda/TypeChecking/DiscrimTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
module Agda.TypeChecking.DiscrimTree
( insertDT
, lookupDT
, deleteFromDT
)
where

Expand Down Expand Up @@ -33,6 +34,8 @@ import qualified Agda.Utils.ProfileOptions as Profile
import Agda.Utils.Impossible
import Agda.Utils.Trie (Trie(..))

-- | Dummy term to use as a stand-in for expanded eta-records while
-- building instance trees.
etaExpansionDummy :: Term
etaExpansionDummy = Dummy "eta-record argument in instance head" []

Expand All @@ -44,12 +47,16 @@ termKeyElims
-> TCM Type -- ^ Continuation to compute the type of the arguments in the spine.
-> Elims -- ^ The spine.
-> MaybeT TCM (Int, [Term])
termKeyElims precise _ elims | not precise = do
es <- hoistMaybe (allApplyElims elims)

-- Since the case tree was generated with wildcards everywhere an eta
-- record appeared, if we're *looking up* an instance, we don't have to
-- do the censorship again.
termKeyElims False _ elims = do
es <- MaybeT $ pure (allApplyElims elims)
pure (length es, map unArg es)

termKeyElims precise ty elims = do
args <- hoistMaybe (allApplyElims elims)
args <- MaybeT $ pure (allApplyElims elims)

let
go ty (Arg _ a:as) = flip (ifPiTypeB ty) (const mzero) \dom ty' -> do
Expand All @@ -73,6 +80,10 @@ termKeyElims precise ty elims = do
ty <- lift ty
go ty args

-- | Ticky profiling for the reason behind "inexactness" in instance
-- search. If at some point while narrowing the set of candidates we had
-- to go through all the possibilities, one of these counters is
-- incremented.
tickExplore :: Term -> TCM ()
tickExplore tm = whenProfile Profile.Instances do
tick "flex term blocking instance"
Expand All @@ -81,7 +92,13 @@ tickExplore tm = whenProfile Profile.Instances do
Def{} -> tick "explore: Def"
Var{} -> tick "explore: Var"
Lam _ v
-- These two are a hunch: just like FunK, it might be worth
-- optimising for the case where a lambda is constant (which is
-- easy to handle, by just pretending the term is something else).
-- These would come up in e.g. Dec (PathP (λ i → Nat) x y)
| NoAbs{} <- v -> tick "explore: constant function"
| Abs _ b <- v, not (0 `freeIn` b) -> tick "explore: constant function"

| otherwise -> tick "explore: Lam"
Lit{} -> tick "explore: Lit"
Sort{} -> tick "explore: Sort"
Expand Down Expand Up @@ -117,7 +134,7 @@ splitTermKey precise local tm = fmap (fromMaybe (FlexK, [])) . runMaybeT $ do
Var i as | i >= local -> do
let ty = unDom <$> domOfBV i
(arity, as) <- termKeyElims precise ty as
pure (LocalK i arity, as)
pure (LocalK (i - local) arity, as)

Con ch _ as -> do
let
Expand All @@ -130,11 +147,11 @@ splitTermKey precise local tm = fmap (fromMaybe (FlexK, [])) . runMaybeT $ do
-- For slightly more accurate matching, we decompose non-dependent
-- 'Pi's into a distinguished key.
| NoAbs _ b <- ret -> do
whenProfile Profile.Conversion $ tick "funk: non-dependent function"
pure (FunK, [unEl (unDom dom), raise 1 (unEl b)])
whenProfile Profile.Instances $ tick "funk: non-dependent function"
pure (FunK, [unEl (unDom dom), unEl b])

| otherwise -> do
whenProfile Profile.Conversion $ tick "funk: genuine pi"
whenProfile Profile.Instances $ tick "funk: genuine pi"
pure (PiK, [])

_ -> pure (FlexK, [])
Expand All @@ -149,7 +166,15 @@ termPath local acc (tm:todo) = do
]
termPath local (k:acc) (as <> todo)

insertDT :: (Ord a, PrettyTCM a) => Int -> Term -> a -> DiscrimTree a -> TCM (DiscrimTree a)
-- | Insert a value into the discrimination tree, turning variables into
-- rigid locals or wildcards depending on the given scope.
insertDT
:: (Ord a, PrettyTCM a)
=> Int -- ^ Number of variables to consider wildcards, e.g. the number of leading invisible pis in an instance type.
-> Term -- ^ The term to use as a key
-> a
-> DiscrimTree a
-> TCM (DiscrimTree a)
insertDT local key val tree = ignoreAbstractMode do
path <- termPath local [] [key]
let it = singletonDT path val
Expand Down Expand Up @@ -270,10 +295,6 @@ lookupDT term tree = ignoreAbstractMode (match [term] tree) where
let sp' = sp0 ++ args ++ sp1

-- Actually take the branch corresponding to our rigid head.
--
-- TODO (Amy, 2024-02-12): Need to handle eta equality. I guess
-- singletonDT can be made type-directed and we can add an EtaDT
-- to to the discrimination tree type??
branch <- visit k sp'

-- Function values get unpacked to their components on the
Expand All @@ -297,6 +318,31 @@ lookupDT term tree = ignoreAbstractMode (match [term] tree) where
[ "IMPOSSIBLE match" <+> prettyTCM ts
, prettyTCM tree
]
-- This really is impossible: since each branch is annotated with
-- its arity, we only take branches corresponding to neutrals which
-- exploded into enough arguments.
__IMPOSSIBLE__
-- TODO (Amy, 2024-02-12): Is it really? Padding the argument list
-- when exploring might not be enough.

-- | Smart constructor for a leaf node.
doneDT :: Set.Set a -> DiscrimTree a
doneDT s | Set.null s = EmptyDT
doneDT s = DoneDT s

-- | Remove a set of values from the discrimination tree. The tree is
-- rebuilt so that cases with no leaves are removed.
deleteFromDT :: Ord a => DiscrimTree a -> Set.Set a -> DiscrimTree a
deleteFromDT dt gone = case dt of
EmptyDT -> EmptyDT
DoneDT s ->
let s' = Set.difference s gone
in doneDT s'
CaseDT i s k ->
let
del x = case deleteFromDT x gone of
EmptyDT -> Nothing
dt' -> Just dt'

s' = Map.mapMaybe del s
k' = deleteFromDT k gone
in if | Map.null s' -> k'
| otherwise -> CaseDT i s' k'
16 changes: 14 additions & 2 deletions src/full/Agda/TypeChecking/DiscrimTree/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import Data.Set (Set)
import GHC.Generics (Generic)

import Agda.Syntax.Internal
import Agda.Syntax.Position

import Agda.Utils.Trie
import Agda.Utils.Impossible
import Agda.Utils.Null

Expand Down Expand Up @@ -57,7 +57,7 @@ data DiscrimTree a
{-# UNPACK #-} !Int -- ^ The variable to case on.
(Map Key (DiscrimTree a)) -- ^ The proper branches.
(DiscrimTree a) -- ^ A further tree, which should always be explored.
deriving (Generic, Eq)
deriving (Generic, Eq, Show)

{-
The extra continuation to CaseDT is used to represent instance tables
Expand Down Expand Up @@ -85,6 +85,12 @@ and the extra continuation would be empty.

instance NFData a => NFData (DiscrimTree a)

instance (KillRange a, Ord a) => KillRange (DiscrimTree a) where
killRange = \case
EmptyDT -> EmptyDT
DoneDT s -> killRangeN DoneDT s
CaseDT i k o -> killRangeN CaseDT i k o

instance Null (DiscrimTree a) where
empty = EmptyDT
null = \case
Expand All @@ -108,6 +114,12 @@ mergeDT (CaseDT i bs els) x = case x of
| j < i -> CaseDT j bs' (mergeDT els' (CaseDT i bs els))
| otherwise -> __IMPOSSIBLE__

instance Ord a => Semigroup (DiscrimTree a) where
(<>) = mergeDT

instance Ord a => Monoid (DiscrimTree a) where
mempty = EmptyDT

-- | Construct the case tree corresponding to only performing proper
-- matches on the given key. In this context, a "proper match" is any
-- 'Key' that is not 'FlexK'.
Expand Down

0 comments on commit 2aecf36

Please sign in to comment.