Skip to content

Commit

Permalink
refactor: Slightly tidy up generation of lambdas
Browse files Browse the repository at this point in the history
  • Loading branch information
bristermitten committed May 24, 2024
1 parent 1508631 commit 4654676
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions src/Elara/Emit/Lambda.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Polysemy hiding (transform)
import Polysemy.Error
import Polysemy.Log (Log)
import Polysemy.Log qualified as Log
import Print (showPretty)
import Print (debugPretty, showPretty)

import Data.Generics.Sum (AsAny (_As))
import Data.Traversable (for)
Expand Down Expand Up @@ -79,7 +79,12 @@ etaExpandN funcCall exprType thisClassName = do
(NE.zip paramTypes [0 ..])
createLambda (toList paramTypes) (generateFieldType $ functionTypeResult exprType) thisClassName body

-- createGeneratedLambda :: [JVMBinder] -> JVMExpr ->
elaraFuncDescriptor returnType baseParams = case baseParams of
[] -> ("Elara/Func0", "run", MethodDescriptor [] (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [] (TypeReturn returnType))
[t] -> ("Elara/Func", "run", MethodDescriptor [ObjectFieldType "java/lang/Object"] (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t] (TypeReturn returnType))
[t1, t2] -> ("Elara/Func2", "run", MethodDescriptor (replicate 2 (ObjectFieldType "java/lang/Object")) (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t1, t2] (TypeReturn returnType))
[t1, t2, t3] -> ("Elara/Func3", "run", MethodDescriptor (replicate 3 (ObjectFieldType "java/lang/Object")) (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t1, t2, t3] (TypeReturn returnType))
other -> error $ "createLambda: " <> show other <> " parameters not supported"

{- | Creates the bytecode for a lambda expression
This involves a few steps:
Expand Down Expand Up @@ -131,13 +136,7 @@ createLambda baseParams returnType thisClassName body = do
Log.debug $ "Body: " <> showPretty body <> " -> " <> showPretty body'

createMethod thisClassName lambdaMethodDescriptor lambdaMethodName body'
let (functionalInterface, invoke, baseMethodDescriptor, methodDescriptor) =
case baseParams of
[] -> ("Elara/Func0", "run", MethodDescriptor [] (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [] (TypeReturn returnType))
[(_, t)] -> ("Elara/Func", "run", MethodDescriptor [ObjectFieldType "java/lang/Object"] (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t] (TypeReturn returnType))
[(_, t1), (_, t2)] -> ("Elara/Func2", "run", MethodDescriptor (replicate 2 (ObjectFieldType "java/lang/Object")) (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t1, t2] (TypeReturn returnType))
[(_, t1), (_, t2), (_, t3)] -> ("Elara/Func3", "run", MethodDescriptor (replicate 3 (ObjectFieldType "java/lang/Object")) (TypeReturn (ObjectFieldType "java/lang/Object")), MethodDescriptor [t1, t2, t3] (TypeReturn returnType))
other -> error $ "createLambda: " <> show other <> " parameters not supported"
let (functionalInterface, invoke, baseMethodDescriptor, methodDescriptor) = elaraFuncDescriptor returnType (snd <$> baseParams)

let inst =
InvokeDynamic
Expand Down Expand Up @@ -183,22 +182,29 @@ lambdaTypeName 2 = "Elara/Func2"
lambdaTypeName 3 = "Elara/Func3"
lambdaTypeName n = error $ "lambdaTypeName: " <> show n <> " not supported"

{- | Inspects a lambda expression to determine which local variables are captured, returning a generated name for each, and its corresponding type
For example, if we have \x -> local_1, we know that local_1 must be captured from an outer scope (as it is not an argument to the lambda)
-}
getCapturedParams :: Member UniqueGen r => [(Unique Text, FieldType)] -> JVMExpr -> Sem r [(Unique Text, FieldType)]
getCapturedParams params expr =
do
let len = fromIntegral $ length params
-- Get all locals_<n> where n >= length params
let locals =
expr
^.. cosmos
% _As @"Var"
% _As @"JVMLocal"
% filtered (\(a, _) -> a >= len)

let locals' = locals ^.. traversed % _2 % _Just
for locals' $ \t -> do
n <- makeUnique "local"
let t' = case t of
JVMLFieldType ft -> ft
JVMLType t -> generateFieldType t
pure (n, t')
getCapturedParams params expr = do
debugPretty ("params" :: Text, params, expr)
let len = fromIntegral $ length params
-- Get all locals_<n> where n >= length params
let locals =
expr
^.. cosmos
% _As @"Var"
% _As @"JVMLocal"
% filtered (\(a, _) -> a >= len)

let namedLocals = expr ^.. cosmos % _As @"Var" % _As @"Normal"

debugPretty (locals, namedLocals)

let locals' = locals ^.. traversed % _2 % _Just
for locals' $ \t -> do
n <- makeUnique "local"
let t' = case t of
JVMLFieldType ft -> ft
JVMLType t -> generateFieldType t
pure (n, t')

0 comments on commit 4654676

Please sign in to comment.