Skip to content

Commit

Permalink
Refactor type checker (#38)
Browse files Browse the repository at this point in the history
- Fix generalization for the top level declarations
- Add regression test to check the fix
- Clean up the code
  • Loading branch information
AzimMuradov committed Jun 23, 2024
1 parent 3e2df35 commit 25ee47d
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 185 deletions.
5 changes: 3 additions & 2 deletions cabal.project.freeze
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ constraints: any.Cabal ==3.8.1.0,
distributive +semigroups +tagged,
any.dlist ==1.0,
dlist -werror,
any.exceptions ==0.10.5,
any.exceptions ==0.10.8,
exceptions +transformers-0-4,
any.extra ==1.7.16,
any.file-embed ==0.0.16.0,
any.filepath ==1.4.2.2,
Expand All @@ -54,7 +55,7 @@ constraints: any.Cabal ==3.8.1.0,
any.megaparsec ==9.6.1,
megaparsec -dev,
any.mmorph ==1.2.0,
any.mtl ==2.2.2,
any.mtl ==2.2.2 || ==2.3.1,
any.optparse-applicative ==0.18.1.0,
optparse-applicative +process,
any.os-string ==2.0.3,
Expand Down
181 changes: 87 additions & 94 deletions lib/TypeChecker/HindleyMilner.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,25 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module TypeChecker.HindleyMilner where
module TypeChecker.HindleyMilner
( Infer,
TypeError (..),
UType,
Polytype,
applyBindings,
generalize,
toPolytype,
toUType,
withBinding,
fresh,
Poly (..),
UTerm (UTVar, UTUnit, UTBool, UTInt, UTFun),
(=:=),
lookup,
TypeF (..),
mkVarName,
)
where

import Control.Monad.Except
import Control.Monad.Reader
Expand All @@ -27,86 +45,78 @@ import Data.Set (Set, (\\))
import qualified Data.Set as S
import Data.Text (pack)
import GHC.Generics (Generic1)
import Trees.Common (Identifier, Type (..))
import qualified Trees.Common as L -- Lang
import Prelude hiding (lookup)

data HType a
= TyVarF Identifier
| TyUnitF
| TyBoolF
| TyIntF
| TyFunF a a
deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)

type TypeF = Fix HType

type UType = UTerm HType IntVar
-- * Type

data Poly t = Forall [Identifier] t
deriving (Eq, Show, Functor)
type Type = Fix TypeF

type Polytype = Poly TypeF
data TypeF a
= TVarF L.Identifier
| TUnitF
| TBoolF
| TIntF
| TFunF a a
deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)

type UPolytype = Poly UType
-- * UType

-- TypeF
type UType = UTerm TypeF IntVar

pattern TyVar :: Identifier -> TypeF
pattern TyVar v = Fix (TyVarF v)
pattern UTVar :: L.Identifier -> UType
pattern UTVar var = UTerm (TVarF var)

pattern TyUnit :: TypeF
pattern TyUnit = Fix TyUnitF
pattern UTUnit :: UType
pattern UTUnit = UTerm TUnitF

pattern TyBool :: TypeF
pattern TyBool = Fix TyBoolF
pattern UTBool :: UType
pattern UTBool = UTerm TBoolF

pattern TyInt :: TypeF
pattern TyInt = Fix TyIntF
pattern UTInt :: UType
pattern UTInt = UTerm TIntF

pattern TyFun :: TypeF -> TypeF -> TypeF
pattern TyFun t1 t2 = Fix (TyFunF t1 t2)
-- | Function type as a unification term: the first field is the argument
-- type and the second is the result type (@argT -> resT@).
pattern UTFun :: UType -> UType -> UType
pattern UTFun argT resT = UTerm (TFunF argT resT)

-- UType
-- * Polytype

pattern UTyVar :: Identifier -> UType
pattern UTyVar v = UTerm (TyVarF v)
data Poly t = Forall [L.Identifier] t
deriving (Eq, Show, Functor)

pattern UTyUnit :: UType
pattern UTyUnit = UTerm TyUnitF
type Polytype = Poly Type

pattern UTyBool :: UType
pattern UTyBool = UTerm TyBoolF
type UPolytype = Poly UType

pattern UTyInt :: UType
pattern UTyInt = UTerm TyIntF
-- * Converters

pattern UTyFun :: UType -> UType -> UType
pattern UTyFun t1 t2 = UTerm (TyFunF t1 t2)
-- | Embed a surface-language type ('L.Type') into the unification-term
-- representation, constructor by constructor.
toUType :: L.Type -> UType
toUType ty = case ty of
  L.TUnit -> UTUnit
  L.TBool -> UTBool
  L.TInt -> UTInt
  -- Both components are converted recursively and keep their positions.
  L.TFun argT resT -> UTFun (toUType argT) (toUType resT)

toTypeF :: Type -> TypeF
toTypeF = \case
TUnit -> TyUnit
TBool -> TyBool
TInt -> TyInt
TFun t1 t2 -> TyFun (toTypeF t1) (toTypeF t2)
-- | Convert a unification polytype into a closed 'Polytype' by freezing the
-- underlying term.
--
-- 'freeze' returns 'Nothing' when it cannot produce a closed term (i.e. the
-- type still mentions a unification variable). Callers are expected to pass
-- only fully generalized types; if that invariant is broken this function
-- aborts with a located message instead of the anonymous
-- @Maybe.fromJust: Nothing@.
toPolytype :: UPolytype -> Polytype
toPolytype =
  fmap
    ( fromMaybe
        ( error
            "TypeChecker.HindleyMilner.toPolytype: type still contains unification variables"
        )
        . freeze
    )

fromTypeToUType :: Type -> UType
fromTypeToUType = \case
TUnit -> UTyUnit
TBool -> UTyBool
TInt -> UTyInt
TFun t1 t2 -> UTyFun (fromTypeToUType t1) (fromTypeToUType t2)
-- * Infer

type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT HType Identity))
type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT TypeF Identity))

type Ctx = Map Identifier UPolytype
type Ctx = Map L.Identifier UPolytype

lookup :: LookUpType -> Infer UType
lookup (Var v) = do
ctx <- ask
maybe (throwError $ UnboundVar v) instantiate (M.lookup v ctx)
-- | Fetch the type of a variable from the typing context.
--
-- The stored polytype is instantiated before being returned. If the
-- identifier is absent from the context, 'UnboundVar' is thrown.
lookup :: L.Identifier -> Infer UType
lookup var =
  asks (M.lookup var) >>= \case
    Nothing -> throwError (UnboundVar var)
    Just polytype -> instantiate polytype
  where
    -- Replace every quantified name with its own fresh unification variable.
    instantiate :: UPolytype -> Infer UType
    instantiate (Forall names body) = do
      freshVars <- traverse (const fresh) names
      let substitution = M.fromList (zip (fmap Left names) freshVars)
      pure (substU substitution body)

withBinding :: (MonadReader Ctx m) => Identifier -> UPolytype -> m a -> m a
withBinding :: (MonadReader Ctx m) => L.Identifier -> UPolytype -> m a -> m a
withBinding x ty = local (M.insert x ty)

ucata :: (Functor t) => (v -> a) -> (t a -> a) -> UTerm t v -> a
Expand All @@ -115,16 +125,18 @@ ucata f g (UTerm t) = g (fmap (ucata f g) t)

deriving instance Ord IntVar

-- * FreeVars

class FreeVars a where
freeVars :: a -> Infer (Set (Either Identifier IntVar))
freeVars :: a -> Infer (Set (Either L.Identifier IntVar))

instance FreeVars UType where
freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift $ getFreeVars ut
let ftvs =
ucata
(const S.empty)
(\case TyVarF x -> S.singleton (Left x); f -> fold f)
(\case TVarF x -> S.singleton (Left x); f -> fold f)
ut
return $ fuvs `S.union` ftvs

Expand All @@ -134,67 +146,48 @@ instance FreeVars UPolytype where
instance FreeVars Ctx where
freeVars = fmap S.unions . mapM freeVars . M.elems

newtype LookUpType = Var Identifier
-- | Produce a new unification variable wrapped as a 'UType'.
--
-- 'freeVar' runs in the innermost binding monad, so it is lifted twice:
-- once through 'ExceptT' and once through 'ReaderT'.
fresh :: Infer UType
fresh = fmap UVar (lift (lift freeVar))

-- * Errors

data TypeError where
Unreachable :: TypeError
UnboundVar :: Identifier -> TypeError
UnboundVar :: L.Identifier -> TypeError
Infinite :: IntVar -> UType -> TypeError
ImpossibleBinOpApplication :: UType -> UType -> TypeError
ImpossibleUnOpApplication :: UType -> TypeError
Mismatch :: HType UType -> HType UType -> TypeError
Mismatch :: TypeF UType -> TypeF UType -> TypeError
deriving (Show)

instance Fallible HType IntVar TypeError where
instance Fallible TypeF IntVar TypeError where
occursFailure = Infinite
mismatchFailure = Mismatch

fresh :: Infer UType
fresh = UVar <$> lift (lift freeVar)

-- | Unify two types inside 'Infer' by delegating to the library operator
-- @U.=:=@ and lifting the result through the reader layer.
(=:=) :: UType -> UType -> Infer UType
lhs =:= rhs = lift (lhs U.=:= rhs)

-- | Run the library's @U.applyBindings@ on a type and lift the result into
-- 'Infer'.
applyBindings :: UType -> Infer UType
applyBindings uty = lift (U.applyBindings uty)

instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty

substU :: Map (Either Identifier IntVar) UType -> UType -> UType
substU :: Map (Either L.Identifier IntVar) UType -> UType -> UType
substU m =
ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
( \case
TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m)
TVarF v -> fromMaybe (UTVar v) (M.lookup (Left v) m)
f -> UTerm f
)

skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
where
toSkolem (UVar v) = UTyVar (mkVarName "s" v)
toSkolem _ = undefined -- We can't reach another situation, because we previously give `fresh` variable

mkVarName :: String -> IntVar -> Identifier
mkVarName nm (IntVar v) = pack (nm ++ show (v + (maxBound :: Int) + 1))
-- | Build an identifier from a textual prefix and the number carried by a
-- unification variable.
--
-- The number is shifted by @maxBound + 1@ before being rendered.
-- NOTE(review): presumably this maps variable ids that start at 'minBound'
-- onto small non-negative suffixes -- confirm against the binding monad's
-- numbering scheme.
mkVarName :: String -> IntVar -> L.Identifier
mkVarName prefix (IntVar n) = pack (prefix <> show shifted)
  where
    shifted = n + (maxBound :: Int) + 1

generalize :: UType -> Infer UPolytype
generalize uty = do
uty' <- applyBindings uty
ctx <- ask
tmfvs <- freeVars uty'
ctxfvs <- freeVars ctx
let fvs = S.toList $ tmfvs \\ ctxfvs
tmFreeVars <- freeVars uty'
ctxFreeVars <- freeVars ctx
let fvs = S.toList $ tmFreeVars \\ ctxFreeVars
xs = map (either id (mkVarName "a")) fvs
return $ Forall xs (substU (M.fromList (zip fvs (map UTyVar xs))) uty')

toUPolytype :: Polytype -> UPolytype
toUPolytype = fmap unfreeze

fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze)
return $ Forall xs (substU (M.fromList (zip fvs (map UTVar xs))) uty')
12 changes: 6 additions & 6 deletions lib/TypeChecker/PrettyPrinter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ type Prec = Int
instance (Pretty (t (Fix t))) => Pretty (Fix t) where
prettyPrec p = prettyPrec p . unFix

instance (Pretty t) => Pretty (HType t) where
prettyPrec _ (TyVarF x) = unpack x
prettyPrec _ TyUnitF = "unit"
prettyPrec _ TyBoolF = "bool"
prettyPrec _ TyIntF = "int"
prettyPrec p (TyFunF ty1 ty2) =
instance (Pretty t) => Pretty (TypeF t) where
prettyPrec _ (TVarF x) = unpack x
prettyPrec _ TUnitF = "unit"
prettyPrec _ TBoolF = "bool"
prettyPrec _ TIntF = "int"
prettyPrec p (TFunF ty1 ty2) =
mparens (p > 0) $ prettyPrec 1 ty1 ++ " -> " ++ prettyPrec 0 ty2

instance (Pretty (t (UTerm t v)), Pretty v) => Pretty (UTerm t v) where
Expand Down
Loading

0 comments on commit 25ee47d

Please sign in to comment.