diff --git a/src/Brahma.FSharp.OpenCL.AST/Types.fs b/src/Brahma.FSharp.OpenCL.AST/Types.fs index 7b63a686..da86c86d 100644 --- a/src/Brahma.FSharp.OpenCL.AST/Types.fs +++ b/src/Brahma.FSharp.OpenCL.AST/Types.fs @@ -131,10 +131,14 @@ type DiscriminatedUnionType<'lang>(name: string, fields: List member this.Tag = this.Fields.[0] member this.Data = this.Fields.[1] - member this.GetCase (tag: int) = + member this.GetCaseByTag (tag: int) = List.tryFind (fun (id, _) -> id = tag) fields |> Option.map snd + member this.GetCaseByName (case: string) = + List.tryFind (fun (_, f) -> f.Name = case) fields + |> Option.map snd + type TupleType<'lang>(baseStruct: StructType<'lang>, number:int)= inherit Type<'lang>() diff --git a/src/Brahma.FSharp.OpenCL.Translator/Body.fs b/src/Brahma.FSharp.OpenCL.Translator/Body.fs index dde5086e..f28efce6 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Body.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Body.fs @@ -15,6 +15,7 @@ module Brahma.FSharp.OpenCL.Translator.Body +open System.Reflection open Microsoft.FSharp.Quotations open Brahma.FSharp.OpenCL.AST open Microsoft.FSharp.Collections @@ -155,46 +156,62 @@ and private itemHelper exprs hostVar tContext = match exprs with | hd::_ -> TranslateAsExpr hd tContext | [] -> failwith "Array index missed!" - let hVar = - match hostVar with - | Some(v,_) -> v - | None -> failwith "Host var missed!" - idx,tContext,hVar - -and private transletaPropGet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs targetContext = - let hostVar = exprOpt |> Option.map(fun e -> TranslateAsExpr e targetContext) - match propInfo.Name.ToLowerInvariant() with - | "globalid0i" | "globalid0" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext - | "globalid1i" | "globalid1" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext - | "globalid2i" | "globalid2" -> new FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext - | "localid0" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext - | "localid1" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext - | "localid2" -> new FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext + let (hVar, _) = hostVar + + idx, tContext, hVar + +and private translateSpecificPropGet expr propName exprs targetContext = + // TODO: Refactoring: Safe pattern matching by expr type. + + let hostVar = TranslateAsExpr expr targetContext + match propName with + | "globalid0i" | "globalid0" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext + | "globalid1i" | "globalid1" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext + | "globalid2i" | "globalid2" -> FunCall<_>("get_global_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext + | "localid0" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"0")]) :> Expression<_>, targetContext + | "localid1" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"1")]) :> Expression<_>, targetContext + | "localid2" -> FunCall<_>("get_local_id",[Const(PrimitiveType<_>(Int),"2")]) :> Expression<_>, targetContext | "item" -> let idx,tContext,hVar = itemHelper exprs hostVar targetContext - new Item<_>(hVar,idx) :> Expression<_>, tContext - | x -> - match exprOpt with - | Some expr -> - let r,tContext = translateFieldGet expr propInfo.Name targetContext - r :> Expression<_>, tContext - | None -> failwithf "Unsupported property in kernel: %A" x - -and private transletaPropSet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs newVal targetContext = - let hostVar = exprOpt |> Option.map(fun e -> TranslateAsExpr e targetContext) - let newVal,tContext = TranslateAsExpr newVal (match hostVar with Some(v,c) -> c | None -> targetContext) - match propInfo.Name.ToLowerInvariant() with - | "item" -> - let idx,tContext,hVar = itemHelper exprs hostVar tContext - let item = new Item<_>(hVar,idx) - new Assignment<_>(new Property<_>(PropertyType.Item(item)),newVal) :> Statement<_> - , tContext - | x -> - match exprOpt with - | Some e -> - let r,tContext = translateFieldSet e propInfo.Name exprs.[0] targetContext + Item<_>(hVar,idx) :> Expression<_>, tContext + | _ -> failwithf "Unsupported property in kernel: %A" propName + +and private translatePropGet (exprOpt: Expr Option) (propInfo: PropertyInfo) exprs (targetContext: TargetContext<_, _>) = + let propName = propInfo.Name.ToLowerInvariant() + + match exprOpt with + | Some expr -> + let exprType = expr.Type + if targetContext.UserDefinedTypes.Contains exprType + then + let exprTypeName = expr.Type.Name.ToLowerInvariant() + if targetContext.UserDefinedStructsOpenCLDeclaration.ContainsKey exprTypeName + then + translateStructFieldGet expr propInfo.Name targetContext + else + translateUnionFieldGet expr propInfo targetContext + else + translateSpecificPropGet expr propName exprs targetContext + | None -> failwithf "Unsupported static property get in kernel: %A" propName + +and private translatePropSet exprOpt (propInfo:System.Reflection.PropertyInfo) exprs newVal targetContext = + // Todo: Safe pattern matching (item) by expr type + let propName = propInfo.Name.ToLowerInvariant() + + match exprOpt with + | Some expr -> + let hostVar = TranslateAsExpr expr targetContext + let newVal,tContext = TranslateAsExpr newVal (match hostVar with (v,c) -> c) + match propInfo.Name.ToLowerInvariant() with + | "item" -> + let idx,tContext,hVar = itemHelper exprs hostVar tContext + let item = new Item<_>(hVar,idx) + new Assignment<_>(new Property<_>(PropertyType.Item(item)),newVal) :> Statement<_> + , tContext + | _ -> + let r,tContext = translateFieldSet expr propInfo.Name exprs.[0] targetContext r :> Statement<_>,tContext - | None -> failwithf "Unsupported property in kernel: %A" x + | None -> failwithf "Unsupported static property set in kernel: %A" propName and TranslateAsExpr expr (targetContext:TargetContext<_,_>) = let (r:Node<_>),tc = Translate expr (targetContext:TargetContext<_,_>) @@ -376,12 +393,39 @@ and translateFieldSet host (*fldInfo:System.Reflection.FieldInfo*) name _val con -and translateFieldGet host (*fldInfo:System.Reflection.FieldInfo*)name context = +and translateStructFieldGet host (*fldInfo:System.Reflection.FieldInfo*) name context = let hostE, tc = TranslateAsExpr host context - let field = name//fldInfo.Name - let res = new FieldGet<_>(hostE,field) + let field = name //fldInfo.Name + let res = FieldGet<_>(hostE,field) :> Expression<_> res, tc +and translateUnionFieldGet expr (propInfo: PropertyInfo) targetContext = + let exprTypeName = expr.Type.Name.ToLowerInvariant() + let unionType = targetContext.UserDefinedUnionsOpenCLDeclaration.[exprTypeName] + + let unionValueExpr, targetContext = TranslateAsExpr expr targetContext + + let caseName = propInfo.DeclaringType.Name + let unionCaseField = unionType.GetCaseByName caseName + + match unionCaseField with + | None -> failwithf "Union field get translation error: + union %A doesn't have case %A" unionType.Name caseName + | Some unionCaseField -> + let r = + FieldGet<_> ( + FieldGet<_> ( + FieldGet<_> ( + unionValueExpr, + unionType.Data.Name + ), + unionCaseField.Name + ), + propInfo.Name + ) + :> Expression<_> + r, targetContext + and Translate expr (targetContext:TargetContext<_,_>) = //printfn "%A" expr match expr with @@ -402,7 +446,7 @@ and Translate expr (targetContext:TargetContext<_,_>) = | Patterns.FieldGet (exprOpt,fldInfo) -> match exprOpt with | Some expr -> - let r,tContext = translateFieldGet expr fldInfo.Name targetContext + let r,tContext = translateStructFieldGet expr fldInfo.Name targetContext r :> Node<_>,tContext | None -> failwithf "FieldGet for empty host is not suported. Field: %A" fldInfo.Name | Patterns.FieldSet (exprOpt,fldInfo,expr) -> @@ -469,7 +513,7 @@ and Translate expr (targetContext:TargetContext<_,_>) = let tag = Const(unionInfo.Tag.Type, string unionCaseInfo.Tag) :> Expression<_> let args = - match unionInfo.GetCase unionCaseInfo.Tag with + match unionInfo.GetCaseByTag unionCaseInfo.Tag with | None -> [] | Some field -> let structArgs = @@ -486,11 +530,11 @@ and Translate expr (targetContext:TargetContext<_,_>) = [data :> Expression<_>] NewStruct(unionInfo, tag :: args) :> Node<_>, targetContext - | Patterns.PropertyGet(exprOpt,propInfo,exprs) -> - let res, tContext = transletaPropGet exprOpt propInfo exprs targetContext + | Patterns.PropertyGet(exprOpt, propInfo, exprs) -> + let res, tContext = translatePropGet exprOpt propInfo exprs targetContext (res :> Node<_>), tContext | Patterns.PropertySet(exprOpt,propInfo,exprs,expr) -> - let res,tContext = transletaPropSet exprOpt propInfo exprs expr targetContext + let res,tContext = translatePropSet exprOpt propInfo exprs expr targetContext res :> Node<_>,tContext | Patterns.Sequential(expr1,expr2) -> let res,tContext = translateSeq expr1 expr2 targetContext @@ -498,7 +542,7 @@ and Translate expr (targetContext:TargetContext<_,_>) = | Patterns.TryFinally(tryExpr,finallyExpr) -> "TryFinally is not suported:" + string expr|> failwith | Patterns.TryWith(expr1,var1,expr2,var2,expr3) -> "TryWith is not suported:" + string expr|> failwith | Patterns.TupleGet(expr,i) -> - let r,tContext = translateFieldGet expr ("_" + (string (i + 1))) targetContext + let r,tContext = translateStructFieldGet expr ("_" + (string (i + 1))) targetContext r :> Node<_>,tContext | Patterns.TypeTest(expr, sType) -> "TypeTest is not suported:" + string expr|> failwith | Patterns.UnionCaseTest(expr, unionCaseInfo) -> diff --git a/tests/Brahma.FSharp.Tests/Brahma.FSharp.Tests.fsproj b/tests/Brahma.FSharp.Tests/Brahma.FSharp.Tests.fsproj index 80918df2..d0439cdd 100644 --- a/tests/Brahma.FSharp.Tests/Brahma.FSharp.Tests.fsproj +++ b/tests/Brahma.FSharp.Tests/Brahma.FSharp.Tests.fsproj @@ -199,6 +199,9 @@ Always + + Always + diff --git a/tests/Brahma.FSharp.Tests/Union.fs b/tests/Brahma.FSharp.Tests/Union.fs index 50074f47..8d1ac620 100644 --- a/tests/Brahma.FSharp.Tests/Union.fs +++ b/tests/Brahma.FSharp.Tests/Union.fs @@ -144,11 +144,28 @@ let unionTestCases = @> ] + let unionPropertyGetTestLists = + testList "UnionPropertyGet" [ + testGen testCase "Test 1: simple pattern matching bindings" "Union.Compile.Test7.gen" "Union.Compile.Test7.cl" + <@ + fun (range: _1D) -> + let t = Case1 + let mutable m = 5 + + match t with + | Case1 -> m <- 5 + | Case2(x) -> m <- x + | Case3(y, z) -> m <- y + z + @> + + ] + testList "Union Compile tests" [ newUnionTestList testUnionCaseTestLists + unionPropertyGetTestLists ] testList "Tests for translator" diff --git a/tests/Brahma.FSharp.Tests/UnionExpected/Union.Compile.Test7.cl b/tests/Brahma.FSharp.Tests/UnionExpected/Union.Compile.Test7.cl new file mode 100644 index 00000000..d6f5e121 --- /dev/null +++ b/tests/Brahma.FSharp.Tests/UnionExpected/Union.Compile.Test7.cl @@ -0,0 +1,25 @@ +typedef struct TranslateMatchTestUnion {int tag ; + union TranslateMatchTestUnion_Data {struct Case2Type {int + Item ;} + Case2 ; + struct Case3Type {int + Item1 + ; + int + Item2 + ;} + Case3 ;} data ;} TranslateMatchTestUnion + ; +__kernel void brahmaKernel () +{TranslateMatchTestUnion t = { 0 } ; + int m = 5 ; + if (((t) . tag == 1)) + {int x = (((t) . data) . Case2) . Item ; + m = x ;} + else + {if (((t) . tag == 2)) + {int z = (((t) . data) . Case3) . Item2 ; + int y = (((t) . data) . Case3) . Item1 ; + m = (y + z) ;} + else + {m = 5 ;} ;} ;}