Skip to content

Commit

Permalink
refactor: Begin work refactoring lambda generation - keep track of pa…
Browse files Browse the repository at this point in the history
…rameter names _eerywhere_
  • Loading branch information
bristermitten committed May 24, 2024
1 parent 4654676 commit 9f8ed70
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 29 deletions.
1 change: 1 addition & 0 deletions elara.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ library
Elara.Emit.Expr
Elara.Emit.Lambda
Elara.Emit.Method
Elara.Emit.Method.Descriptor
Elara.Emit.Monad
Elara.Emit.Operator
Elara.Emit.Params
Expand Down
6 changes: 3 additions & 3 deletions src/Elara/Emit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Elara.Emit.Error
import Elara.Emit.Expr
import Elara.Emit.Lambda (etaExpandN)
import Elara.Emit.Method (createMethod, createMethodWith, createMethodWithCodeBuilder, etaExpandNIntoMethod)
import Elara.Emit.Method.Descriptor
import Elara.Emit.Monad
import Elara.Emit.Operator (translateOperatorName)
import Elara.Emit.Params
Expand Down Expand Up @@ -123,18 +124,18 @@ addDeclaration declBody = case declBody of
let e' = transformTopLevelLambdas e
addStaticFieldInitialiser field e'
else do
descriptor <- generateNamedMethodDescriptor type' e
-- Whenever we have a function declaration we do 2 things
-- Turn it into a hidden method _name, doing the actual logic
-- Create a getter method name, which just returns _name wrapped into a Func
case stripForAll type' of
FuncTy{} -> do
let descriptor = generateMethodDescriptor type'
Log.debug $ "Creating method " <> showPretty declName <> " with signature " <> showPretty descriptor <> "..."
thisName <- ask @QualifiedClassName
y <- transformTopLevelJVMLambdas <$> etaExpandNIntoMethod e type' thisName
Log.debug $ "Transformed lambda expression: " <> showPretty y
createMethod thisName descriptor ("_" <> declName) y
let getterDescriptor = MethodDescriptor [] (TypeReturn (ObjectFieldType "Elara.Func"))
let getterDescriptor = NamedMethodDescriptor [] (TypeReturn (ObjectFieldType "Elara.Func"))
Log.debug $ "Creating getter method " <> showPretty declName <> " with signature " <> showPretty getterDescriptor <> "..."
createMethodWithCodeBuilder thisName getterDescriptor [MPublic, MStatic] declName $ do
Log.debug $ "Getting static field " <> showPretty declName <> "..."
Expand All @@ -143,7 +144,6 @@ addDeclaration declBody = case declBody of
Log.debug $ "Returning static field " <> showPretty declName <> "..."
Log.debug "=="
_ -> do
let descriptor = generateMethodDescriptor type'
Log.debug $ "Creating method " <> showPretty declName <> " with signature " <> showPretty descriptor <> "..."
let y = transformTopLevelLambdas e
Log.debug $ "Transformed lambda expression: " <> showPretty y
Expand Down
24 changes: 15 additions & 9 deletions src/Elara/Emit/ADT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ module Elara.Emit.ADT where

import Elara.Core.Module (CoreTypeDecl (..), CoreTypeDeclBody (..))

import Elara.AST.Name (unqualified)
import Data.Traversable (for)
import Elara.AST.Name (NameLike (nameText), unqualified)
import Elara.Core (DataCon (..), functionTypeArgs)
import Elara.Data.Unique (makeUnique)
import Elara.Emit.Lambda (lambdaTypeName)
import Elara.Emit.Method (createMethodWithCodeBuilder)
import Elara.Emit.Method.Descriptor
import Elara.Emit.Monad (InnerEmit, addClass, addInnerClass)
import Elara.Emit.Params (GenParams)
import Elara.Emit.State (MethodCreationState (maxLocalVariables), findLocalVariable)
Expand All @@ -87,26 +89,30 @@ generateADTClasses (CoreTypeDecl name kind tvs (CoreDataDecl ctors)) = do
addClass typeClassName $ do
addAccessFlag Public
addAccessFlag Abstract
let conArities = ctors <&> (\(DataCon _ ty _) -> length $ functionTypeArgs ty)
let matchSig =
MethodDescriptor (ObjectFieldType . lambdaTypeName <$> conArities) (TypeReturn $ ObjectFieldType "java/lang/Object")
let conArities = ctors <&> (\(DataCon conName ty _) -> (conName ^. unqualified, length $ functionTypeArgs ty))
matchSig <- do
conArities <- for conArities (\(conName, arity) -> (,ObjectFieldType $ lambdaTypeName arity) <$> makeUnique conName)
pure $ NamedMethodDescriptor conArities (TypeReturn $ ObjectFieldType "java/lang/Object")

-- add boring empty constructor
createMethodWithCodeBuilder typeClassName (MethodDescriptor [] VoidReturn) [MProtected] "<init>" $ do
createMethodWithCodeBuilder typeClassName (NamedMethodDescriptor [] VoidReturn) [MProtected] "<init>" $ do
emit $ ALoad 0
emit $ InvokeSpecial (ClassInfoType "java/lang/Object") "<init>" (MethodDescriptor [] VoidReturn)

addMethod $
ClassFileMethod
[MPublic, MAbstract]
"match"
matchSig
(toMethodDescriptor matchSig)
mempty
for_ (zip ctors [1 ..]) $ \(DataCon ctorName ctorType conTy, i) -> do
let innerConClassName = createQualifiedInnerClassName (ctorName ^. unqualified) typeClassName
let fields = functionTypeArgs ctorType
-- Create static factory method
createMethodWithCodeBuilder typeClassName (MethodDescriptor (generateFieldType <$> fields) (TypeReturn $ ObjectFieldType typeClassName)) [MPublic, MStatic] ("_" <> ctorName ^. unqualified) $ do
fields' <- for fields $ \field -> do
field' <- makeUnique "param"
pure (field', generateFieldType field)
createMethodWithCodeBuilder typeClassName (NamedMethodDescriptor fields' (TypeReturn $ ObjectFieldType typeClassName)) [MPublic, MStatic] ("_" <> ctorName ^. unqualified) $ do
emit $ New (ClassInfoType innerConClassName)
emit Dup
for_ (zip fields [0 ..]) $ \(_, i) -> do
Expand All @@ -124,7 +130,7 @@ generateADTClasses (CoreTypeDecl name kind tvs (CoreDataDecl ctors)) = do
addField $ ClassFileField [FPrivate, FFinal] ("field" <> show i) (generateFieldType field) []

thisName <- getName
createMethodWithCodeBuilder thisName (MethodDescriptor (generateFieldType <$> fields) VoidReturn) [MPublic] "<init>" $ do
createMethodWithCodeBuilder thisName (NamedMethodDescriptor fields' VoidReturn) [MPublic] "<init>" $ do
-- call super constructor
emit $ ALoad 0
emit $ InvokeSpecial (ClassInfoType typeClassName) "<init>" (MethodDescriptor [] VoidReturn)
Expand All @@ -134,7 +140,7 @@ generateADTClasses (CoreTypeDecl name kind tvs (CoreDataDecl ctors)) = do
emit $ PutField (ClassInfoType thisName) ("field" <> show i) (generateFieldType field)

-- generate toString
createMethodWithCodeBuilder thisName (MethodDescriptor [] (TypeReturn $ ObjectFieldType "java/lang/String")) [MPublic] "toString" $ do
createMethodWithCodeBuilder thisName (NamedMethodDescriptor [] (TypeReturn $ ObjectFieldType "java/lang/String")) [MPublic] "toString" $ do
emit $ New (ClassInfoType "java/lang/StringBuilder")
emit Dup
emit $ LDC (LDCString $ ctorName ^. unqualified)
Expand Down
3 changes: 2 additions & 1 deletion src/Elara/Emit/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ generateCaseInstructions scrutinee (Just bind) alts = do
JVMLocal i t -> error "Not a local variable"
Normal (Id (Local' v) t _) -> (v, generateFieldType t)
other -> error $ "Not a local variable: " <> showPretty other
inst <- createLambda binders' returnType cName altBody
x <- get @MethodCreationState
(s, inst) <- generateLambda x binders' returnType cName altBody
emit' inst
-- Emit the lambdas in order
generateInstructions scrutinee
Expand Down
29 changes: 27 additions & 2 deletions src/Elara/Emit/Lambda.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ import Print (debugPretty, showPretty)
import Data.Generics.Sum (AsAny (_As))
import Data.Traversable (for)
import Elara.Core qualified as Core
import {-# SOURCE #-} Elara.Emit.Expr (generateInstructions)
import Elara.Emit.Method.Descriptor (NamedMethodDescriptor (..), toMethodDescriptor)
import Elara.Emit.Params
import Elara.Emit.State (MethodCreationState, createMethodCreationStateOf, initialMethodCreationState)
import JVM.Data.Abstract.Builder.Code (runCodeBuilder)
import Optics (filtered)
import Polysemy.Reader
import Polysemy.State

-- | etaExpand takes a function @f@, its type @a -> b@, and generates a lambda expression @\(x : a) -> f x@
etaExpand ::
Expand Down Expand Up @@ -86,6 +91,26 @@ elaraFuncDescriptor returnType baseParams = case baseParams of
[t1, t2, t3] -> ("Elara/Func3", "run", MethodDescriptor (replicate 3 (ObjectFieldType "java/lang/Object")) (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t1, t2, t3] (TypeReturn returnType))
other -> error $ "createLambda: " <> show other <> " parameters not supported"

generateLambda ::
( Member (Error EmitError) r
, Member ClassBuilder r
, Member UniqueGen r
, Member (Reader GenParams) r
, Member Log r
) =>
MethodCreationState ->
[(Unique Text, FieldType)] ->
FieldType ->
QualifiedClassName ->
JVMExpr ->
Sem r (MethodCreationState, [Instruction])
generateLambda oldState explicitParams returnType thisClassName body = do
(state, (_, a, b)) <- runState (createMethodCreationStateOf oldState (fst <$> explicitParams)) $ do
debugPretty ("generateLambda" :: Text, body)
runCodeBuilder $ generateInstructions body

error (showPretty (state, body))

{- | Creates the bytecode for a lambda expression
This involves a few steps:
1. Create a method that implements the lambda's body
Expand All @@ -110,7 +135,7 @@ createLambda baseParams returnType thisClassName body = do
let lambdaMethodName = "lambda$" <> show lamSuffix
Log.debug $ "Creating lambda " <> showPretty lambdaMethodName <> " which captures: " <> showPretty captureParams

let lambdaMethodDescriptor = MethodDescriptor (toList $ snd <$> params) (TypeReturn returnType)
let lambdaMethodDescriptor = NamedMethodDescriptor (toList params) (TypeReturn returnType)
Log.debug $
"Creating lambda method "
<> showPretty lambdaMethodName
Expand Down Expand Up @@ -163,7 +188,7 @@ createLambda baseParams returnType thisClassName body = do
( MethodRef
(ClassInfoType thisClassName)
lambdaMethodName
lambdaMethodDescriptor
(toMethodDescriptor lambdaMethodDescriptor)
)
)
, BMMethodArg methodDescriptor
Expand Down
20 changes: 11 additions & 9 deletions src/Elara/Emit/Method.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import JVM.Data.Abstract.Instruction
import Data.List.NonEmpty qualified as NE
import Elara.Core (CoreExpr, Expr (..), Type, functionTypeArgs)
import Elara.Core.Analysis (declaredLambdaArity, estimateArity)
import Elara.Data.Pretty (Pretty)
import Elara.Data.Unique
import Elara.Emit.Error
import Elara.Emit.Method.Descriptor
import Elara.Emit.Params
import JVM.Data.Abstract.Name
import JVM.Data.Abstract.Type (fieldTypeToClassInfoType)
import JVM.Data.Abstract.Type (FieldType, fieldTypeToClassInfoType)
import Polysemy
import Polysemy.Error
import Polysemy.Log (Log)
Expand All @@ -42,11 +44,11 @@ createMethod ::
, Member (Error EmitError) r
) =>
QualifiedClassName ->
MethodDescriptor ->
NamedMethodDescriptor ->
Text ->
JVMExpr ->
Sem r ()
createMethod thisClassName descriptor@(MethodDescriptor args _) name body = do
createMethod thisClassName descriptor@(NamedMethodDescriptor args _) name body = do
Log.debug $
"Creating method "
<> showPretty thisClassName
Expand All @@ -56,31 +58,31 @@ createMethod thisClassName descriptor@(MethodDescriptor args _) name body = do
<> showPretty descriptor
<> " and body "
<> showPretty body
let initialState = createMethodCreationState (length args) thisClassName
let initialState = createMethodCreationState (fst <$> args) thisClassName
((mcState, _), codeAttrs, instructions) <-
runCodeBuilder $
runState initialState $
generateInstructions body
createMethodWith descriptor [MPublic, MStatic] name codeAttrs mcState instructions
createMethodWith (toMethodDescriptor descriptor) [MPublic, MStatic] name codeAttrs mcState instructions

createMethodWithCodeBuilder ::
( Member ClassBuilder r
, Member Log r
, Member (Reader GenParams) r
) =>
QualifiedClassName ->
MethodDescriptor ->
NamedMethodDescriptor ->
_ ->
Text ->
Sem (CodeBuilder : r) () ->
Sem r ()
createMethodWithCodeBuilder thisClassName descriptor@(MethodDescriptor args _) methodAttrs name codeBuilder = do
createMethodWithCodeBuilder thisClassName descriptor@(NamedMethodDescriptor args _) methodAttrs name codeBuilder = do
Log.debug $ "Creating method " <> showPretty thisClassName <> "." <> showPretty name <> " with descriptor " <> showPretty descriptor
let initialState = createMethodCreationState (length args) thisClassName
let initialState = createMethodCreationState (fst <$> args) thisClassName
(_, codeAttrs, instructions) <-
subsume_ $
runCodeBuilder codeBuilder
createMethodWith descriptor methodAttrs name codeAttrs initialState instructions
createMethodWith (toMethodDescriptor descriptor) methodAttrs name codeAttrs initialState instructions

createMethodWith ::
(Member ClassBuilder r, Member (Reader GenParams) r) =>
Expand Down
12 changes: 12 additions & 0 deletions src/Elara/Emit/Method/Descriptor.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Elara.Emit.Method.Descriptor where

import Elara.Data.Pretty
import Elara.Data.Unique
import JVM.Data.Abstract.Descriptor
import JVM.Data.Abstract.Type

data NamedMethodDescriptor = NamedMethodDescriptor [(Unique Text, FieldType)] ReturnDescriptor
deriving (Show)
instance Pretty NamedMethodDescriptor
toMethodDescriptor :: NamedMethodDescriptor -> MethodDescriptor
toMethodDescriptor (NamedMethodDescriptor args ret) = MethodDescriptor (snd <$> args) ret
16 changes: 12 additions & 4 deletions src/Elara/Emit/State.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Data.Map qualified as Map
import Elara.Data.Pretty
import Elara.Data.Unique (Unique)
import JVM.Data.Abstract.Name
import JVM.Data.Abstract.Type (FieldType)
import JVM.Data.Raw.Types
import Polysemy (Member, Sem)
import Polysemy.State
Expand Down Expand Up @@ -41,11 +42,18 @@ instance Pretty LVKey where
initialMethodCreationState :: QualifiedClassName -> MethodCreationState
initialMethodCreationState = MethodCreationState Map.empty 0

createMethodCreationState :: Int -> QualifiedClassName -> MethodCreationState
createMethodCreationState argsCount =
-- | creates a "nested" method creation state, taking the existing state and appending the given arguments
createMethodCreationStateOf :: MethodCreationState -> [Unique Text] -> MethodCreationState
createMethodCreationStateOf copy args =
let newLvs = copy.localVariables <> Map.fromList (zip (KnownName <$> args) [maxLocalVariables copy ..])
newMax = maxLocalVariables copy + fromIntegral (length args)
in copy{localVariables = newLvs, maxLocalVariables = newMax}

createMethodCreationState :: [Unique Text] -> QualifiedClassName -> MethodCreationState
createMethodCreationState args =
MethodCreationState
(Map.fromList $ zip (UnknownName <$> [0 .. argsCount - 1]) [0 ..])
(fromIntegral argsCount)
(Map.fromList $ zip (KnownName <$> args) [0 ..])
(fromIntegral $ length args)

findLocalVariable :: Member (State MethodCreationState) r => Unique Text -> Sem r U1
findLocalVariable v = do
Expand Down
25 changes: 24 additions & 1 deletion src/Elara/Emit/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ module Elara.Emit.Utils where

import Data.List.NonEmpty ((<|))
import Elara.AST.Name
import Elara.AST.VarRef
import Elara.Core
import Elara.Core.Analysis (findTyCon)
import Elara.Data.Unique
import Elara.Emit.Method.Descriptor
import Elara.Prim.Core
import JVM.Data.Abstract.Descriptor
import JVM.Data.Abstract.Name
import JVM.Data.Abstract.Type
import Polysemy

createModuleName :: ModuleName -> QualifiedClassName
createModuleName (ModuleName name) = QualifiedClassName (PackageName $ init name) (ClassName $ last name)
Expand All @@ -27,6 +31,26 @@ generateMethodDescriptor x = case generateMethodDescriptor' x of
Just y -> y
Nothing -> error $ "generateMethodDescriptor: " <> show x

generateNamedMethodDescriptor :: (HasCallStack, Member UniqueGen r) => Type -> CoreExpr -> Sem r NamedMethodDescriptor
generateNamedMethodDescriptor t e = do
-- collect as many known names as possible, looking at lambda params
-- note due to eta reduction this might not have the same length as the type signature
let collectLambdaBinders :: CoreExpr -> [(Unique Text, Type)]
collectLambdaBinders (Lam (Id (Local (Identity t')) ty _) e') = (t', ty) : collectLambdaBinders e'
collectLambdaBinders _ = []
let lambdaBinders :: [(Unique Text, FieldType)] = second generateFieldType <$> collectLambdaBinders e
let (MethodDescriptor params ret) = generateMethodDescriptor t
actualBinders <-
if length params == length lambdaBinders
then pure lambdaBinders
else do
-- pad out the lambda binders with fresh names, we'll just call it "paramN"

freshNames <- traverse (\p -> (,p) <$> makeUnique "param") (drop (length lambdaBinders) params)
pure $ lambdaBinders <> freshNames

pure $ NamedMethodDescriptor actualBinders ret

{- |
Attempts to generate a method descriptor for a given type, returning `Nothing` if the type should not compile to a method.
Expand All @@ -53,7 +77,6 @@ generateMethodDescriptor' t@(TyVarTy{}) = Just $ MethodDescriptor [] (TypeReturn
-- Should probably refactor ConTy to take a DataCon instead
generateMethodDescriptor' (ConTy (TyCon dc _)) = Just $ MethodDescriptor [] (TypeReturn $ ObjectFieldType $ createQualifiedClassName dc)
generateMethodDescriptor' (AppTy t _) = generateMethodDescriptor' t -- type erasure
generateMethodDescriptor' _ = Nothing

-- | Returns either the JVM type of the argument, or the JVM type of the return type, if it would compile to a method
generateReturnType :: HasCallStack => Type -> ReturnDescriptor
Expand Down

0 comments on commit 9f8ed70

Please sign in to comment.