Skip to content

Commit

Permalink
Added support of match expressions with bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
simpletonDL committed Feb 9, 2021
1 parent 36a7d89 commit 0098feb
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 47 deletions.
6 changes: 5 additions & 1 deletion src/Brahma.FSharp.OpenCL.AST/Types.fs
Expand Up @@ -131,10 +131,14 @@ type DiscriminatedUnionType<'lang>(name: string, fields: List<int * Field<'lang>
member this.Tag = this.Fields.[0]
member this.Data = this.Fields.[1]

member this.GetCase (tag: int) =
member this.GetCaseByTag (tag: int) =
List.tryFind (fun (id, _) -> id = tag) fields
|> Option.map snd

member this.GetCaseByName (case: string) =
List.tryFind (fun (_, f) -> f.Name = case) fields
|> Option.map snd

type TupleType<'lang>(baseStruct: StructType<'lang>, number:int)=
inherit Type<'lang>()

Expand Down
136 changes: 90 additions & 46 deletions src/Brahma.FSharp.OpenCL.Translator/Body.fs
Expand Up @@ -15,6 +15,7 @@

module Brahma.FSharp.OpenCL.Translator.Body

open System.Reflection
open Microsoft.FSharp.Quotations
open Brahma.FSharp.OpenCL.AST
open Microsoft.FSharp.Collections
Expand Down Expand Up @@ -155,46 +156,62 @@ and private itemHelper exprs hostVar tContext =
match exprs with
| hd::_ -> TranslateAsExpr hd tContext
| [] -> failwith "Array index missed!"
let hVar =
match hostVar with
| Some(v,_) -> v
| None -> failwith "Host var missed!"
idx,tContext,hVar

and private transletaPropGet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs targetContext =
let hostVar = exprOpt |> Option.map(fun e -> TranslateAsExpr e targetContext)
match propInfo.Name.ToLowerInvariant() with
| "globalid0i" | "globalid0" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext
| "globalid1i" | "globalid1" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext
| "globalid2i" | "globalid2" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext
| "localid0" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext
| "localid1" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext
| "localid2" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext
let (hVar, _) = hostVar

idx, tContext, hVar

and private translateSpecificPropGet expr propName exprs targetContext =
// TODO: Refactoring: Safe pattern matching by expr type.

let hostVar = TranslateAsExpr expr targetContext
match propName with
| "globalid0i" | "globalid0" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext
| "globalid1i" | "globalid1" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext
| "globalid2i" | "globalid2" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext
| "localid0" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext
| "localid1" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext
| "localid2" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext
| "item" ->
let idx,tContext,hVar = itemHelper exprs hostVar targetContext
new Item<_>(hVar,idx) :> Expression<_>, tContext
| x ->
match exprOpt with
| Some expr ->
let r,tContext = translateFieldGet expr propInfo.Name targetContext
r :> Expression<_>, tContext
| None -> failwithf "Unsupported property in kernel: %A" x

and private transletaPropSet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs newVal targetContext =
let hostVar = exprOpt |> Option.map(fun e -> TranslateAsExpr e targetContext)
let newVal,tContext = TranslateAsExpr newVal (match hostVar with Some(v,c) -> c | None -> targetContext)
match propInfo.Name.ToLowerInvariant() with
| "item" ->
let idx,tContext,hVar = itemHelper exprs hostVar tContext
let item = new Item<_>(hVar,idx)
new Assignment<_>(new Property<_>(PropertyType.Item(item)),newVal) :> Statement<_>
, tContext
| x ->
match exprOpt with
| Some e ->
let r,tContext = translateFieldSet e propInfo.Name exprs.[0] targetContext
Item<_>(hVar,idx) :> Expression<_>, tContext
| _ -> failwithf "Unsupported property in kernel: %A" propName

and private translatePropGet (exprOpt: Expr Option) (propInfo: PropertyInfo) exprs (targetContext: TargetContext<_, _>) =
let propName = propInfo.Name.ToLowerInvariant()

match exprOpt with
| Some expr ->
let exprType = expr.Type
if targetContext.UserDefinedTypes.Contains exprType
then
let exprTypeName = expr.Type.Name.ToLowerInvariant()
if targetContext.UserDefinedStructsOpenCLDeclaration.ContainsKey exprTypeName
then
translateStructFieldGet expr propInfo.Name targetContext
else
translateUnionFieldGet expr propInfo targetContext
else
translateSpecificPropGet expr propName exprs targetContext
| None -> failwithf "Unsupported static property get in kernel: %A" propName

and private translatePropSet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs newVal targetContext =
// Todo: Safe pattern matching (item) by expr type
let propName = propInfo.Name.ToLowerInvariant()

match exprOpt with
| Some expr ->
let hostVar = TranslateAsExpr expr targetContext
let newVal,tContext = TranslateAsExpr newVal (match hostVar with (v,c) -> c)
match propInfo.Name.ToLowerInvariant() with
| "item" ->
let idx,tContext,hVar = itemHelper exprs hostVar tContext
let item = new Item<_>(hVar,idx)
new Assignment<_>(new Property<_>(PropertyType.Item(item)),newVal) :> Statement<_>
, tContext
| _ ->
let r,tContext = translateFieldSet expr propInfo.Name exprs.[0] targetContext
r :> Statement<_>,tContext
| None -> failwithf "Unsupported property in kernel: %A" x
| None -> failwithf "Unsupported static property set in kernel: %A" propName

and TranslateAsExpr expr (targetContext:TargetContext<_,_>) =
let (r:Node<_>),tc = Translate expr (targetContext:TargetContext<_,_>)
Expand Down Expand Up @@ -376,12 +393,39 @@ and translateFieldSet host (*fldInfo:System.Reflection.FieldInfo*) name _val con



and translateFieldGet host (*fldInfo:System.Reflection.FieldInfo*)name context =
and translateStructFieldGet host (*fldInfo:System.Reflection.FieldInfo*) name context =
let hostE, tc = TranslateAsExpr host context
let field = name//fldInfo.Name
let res = new FieldGet<_>(hostE,field)
let field = name //fldInfo.Name
let res = FieldGet<_>(hostE,field) :> Expression<_>
res, tc

and translateUnionFieldGet expr (propInfo: PropertyInfo) targetContext =
let exprTypeName = expr.Type.Name.ToLowerInvariant()
let unionType = targetContext.UserDefinedUnionsOpenCLDeclaration.[exprTypeName]

let unionValueExpr, targetContext = TranslateAsExpr expr targetContext

let caseName = propInfo.DeclaringType.Name
let unionCaseField = unionType.GetCaseByName caseName

match unionCaseField with
| None -> failwithf "Union field get translation error:
union %A doesn't have case %A" unionType.Name caseName
| Some unionCaseField ->
let r =
FieldGet<_> (
FieldGet<_> (
FieldGet<_> (
unionValueExpr,
unionType.Data.Name
),
unionCaseField.Name
),
propInfo.Name
)
:> Expression<_>
r, targetContext

and Translate expr (targetContext:TargetContext<_,_>) =
//printfn "%A" expr
match expr with
Expand All @@ -402,7 +446,7 @@ and Translate expr (targetContext:TargetContext<_,_>) =
| Patterns.FieldGet (exprOpt,fldInfo) ->
match exprOpt with
| Some expr ->
let r,tContext = translateFieldGet expr fldInfo.Name targetContext
let r,tContext = translateStructFieldGet expr fldInfo.Name targetContext
r :> Node<_>,tContext
| None -> failwithf "FieldGet for empty host is not suported. Field: %A" fldInfo.Name
| Patterns.FieldSet (exprOpt,fldInfo,expr) ->
Expand Down Expand Up @@ -469,7 +513,7 @@ and Translate expr (targetContext:TargetContext<_,_>) =

let tag = Const(unionInfo.Tag.Type, string unionCaseInfo.Tag) :> Expression<_>
let args =
match unionInfo.GetCase unionCaseInfo.Tag with
match unionInfo.GetCaseByTag unionCaseInfo.Tag with
| None -> []
| Some field ->
let structArgs =
Expand All @@ -486,19 +530,19 @@ and Translate expr (targetContext:TargetContext<_,_>) =
[data :> Expression<_>]

NewStruct(unionInfo, tag :: args) :> Node<_>, targetContext
| Patterns.PropertyGet(exprOpt,propInfo,exprs) ->
let res, tContext = transletaPropGet exprOpt propInfo exprs targetContext
| Patterns.PropertyGet(exprOpt, propInfo, exprs) ->
let res, tContext = translatePropGet exprOpt propInfo exprs targetContext
(res :> Node<_>), tContext
| Patterns.PropertySet(exprOpt,propInfo,exprs,expr) ->
let res,tContext = transletaPropSet exprOpt propInfo exprs expr targetContext
let res,tContext = translatePropSet exprOpt propInfo exprs expr targetContext
res :> Node<_>,tContext
| Patterns.Sequential(expr1,expr2) ->
let res,tContext = translateSeq expr1 expr2 targetContext
res :> Node<_>,tContext
| Patterns.TryFinally(tryExpr,finallyExpr) -> "TryFinally is not suported:" + string expr|> failwith
| Patterns.TryWith(expr1,var1,expr2,var2,expr3) -> "TryWith is not suported:" + string expr|> failwith
| Patterns.TupleGet(expr,i) ->
let r,tContext = translateFieldGet expr ("_" + (string (i + 1))) targetContext
let r,tContext = translateStructFieldGet expr ("_" + (string (i + 1))) targetContext
r :> Node<_>,tContext
| Patterns.TypeTest(expr, sType) -> "TypeTest is not suported:" + string expr|> failwith
| Patterns.UnionCaseTest(expr, unionCaseInfo) ->
Expand Down
3 changes: 3 additions & 0 deletions tests/Brahma.FSharp.Tests/Brahma.FSharp.Tests.fsproj
Expand Up @@ -199,6 +199,9 @@
<Content Include="UnionExpected\Union.Compile.Test6.cl">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</Content>
<Content Include="UnionExpected\Union.Compile.Test7.cl">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</Content>
</ItemGroup>
<Import Project="..\..\.paket\Paket.Restore.targets" />
</Project>
17 changes: 17 additions & 0 deletions tests/Brahma.FSharp.Tests/Union.fs
Expand Up @@ -144,11 +144,28 @@ let unionTestCases =
@>
]

let unionPropertyGetTestLists =
testList "UnionPropertyGet" [
testGen testCase "Test 1: simple pattern matching bindings" "Union.Compile.Test7.gen" "Union.Compile.Test7.cl"
<@
fun (range: _1D) ->
let t = Case1
let mutable m = 5

match t with
| Case1 -> m <- 5
| Case2(x) -> m <- x
| Case3(y, z) -> m <- y + z
@>

]


testList "Union Compile tests"
[
newUnionTestList
testUnionCaseTestLists
unionPropertyGetTestLists
]

testList "Tests for translator"
Expand Down
25 changes: 25 additions & 0 deletions tests/Brahma.FSharp.Tests/UnionExpected/Union.Compile.Test7.cl
@@ -0,0 +1,25 @@
typedef struct TranslateMatchTestUnion {int tag ;
union TranslateMatchTestUnion_Data {struct Case2Type {int
Item ;}
Case2 ;
struct Case3Type {int
Item1
;
int
Item2
;}
Case3 ;} data ;} TranslateMatchTestUnion
;
__kernel void brahmaKernel ()
{TranslateMatchTestUnion t = { 0 } ;
int m = 5 ;
if (((t) . tag == 1))
{int x = (((t) . data) . Case2) . Item ;
m = x ;}
else
{if (((t) . tag == 2))
{int z = (((t) . data) . Case3) . Item2 ;
int y = (((t) . data) . Case3) . Item1 ;
m = (y + z) ;}
else
{m = 5 ;} ;} ;}

0 comments on commit 0098feb

Please sign in to comment.