Skip to content

Commit

Permalink
refactor: use checkDirectCall for recursive invokes
Browse files Browse the repository at this point in the history
  • Loading branch information
tolgaOzen committed Sep 16, 2023
1 parent d7c7030 commit c94646b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 67 deletions.
85 changes: 42 additions & 43 deletions internal/engines/check.go
Expand Up @@ -123,11 +123,7 @@ func (engine *CheckEngine) check(
var fn CheckFunction

// Determine the type of the reference by name in the given entity definition.
tor, err := schema.GetTypeOfReferenceByNameInEntityDefinition(en, request.GetPermission())
if err != nil {
// If an error is encountered while determining the type, a CheckFunction is returned that always fails with this error.
return checkFail(err)
}
tor, _ := schema.GetTypeOfReferenceByNameInEntityDefinition(en, request.GetPermission())

// Based on the type of the reference, define the CheckFunction in different ways.
switch tor {
Expand Down Expand Up @@ -155,11 +151,7 @@ func (engine *CheckEngine) check(
// If the reference is a relation, check the direct relation.
fn = engine.checkDirectRelation(request)
default:
// If the reference is not a permission, attribute or relation, check the call.
fn = engine.checkCall(request, &base.Call{
RuleName: request.GetPermission(),
Arguments: request.GetArguments(),
})
fn = engine.checkDirectCall(request)
}

// If the CheckFunction is still undefined after the switch, return a CheckFunction that always fails with an error indicating an undefined child kind.
Expand Down Expand Up @@ -412,7 +404,7 @@ func (engine *CheckEngine) checkTupleToUserSet(
// ComputedUserSet data structure. It returns a CheckFunction closure that performs the check.
func (engine *CheckEngine) checkComputedUserSet(
request *base.PermissionCheckRequest, // The request containing details about the permission to be checked
cu *base.ComputedUserSet, // The computed user set containing user set information
cu *base.ComputedUserSet, // The computed user set containing user set information
) CheckFunction {
// The returned CheckFunction invokes a permission check with a new request that is almost the same
// as the incoming request, but changes the Permission to be the relation defined in the computed user set.
Expand Down Expand Up @@ -496,7 +488,7 @@ func (engine *CheckEngine) checkDirectAttribute(
}

// Unmarshal the attribute value into a BoolValue message.
var msg base.BoolValue
var msg base.BooleanValue
if err := val.GetValue().UnmarshalTo(&msg); err != nil {
// If there was an error unmarshaling, return a denied response and the error.
return denied(&base.PermissionCheckResponseMetadata{}), err
Expand All @@ -512,13 +504,30 @@ func (engine *CheckEngine) checkDirectAttribute(
}
}

// checkCall is a function that validates a call using the CheckEngine.
// It returns a function (CheckFunction) that when called, performs the permission check.
// checkCall creates and returns a CheckFunction based on the provided request and call details.
// It essentially constructs a new PermissionCheckRequest based on the call details and then invokes
// the permission check using the engine's invoke method.
func (engine *CheckEngine) checkCall(
request *base.PermissionCheckRequest, // The request containing the details for the permission check
call *base.Call, // The specific call to be checked
request *base.PermissionCheckRequest,
call *base.Call,
) CheckFunction {
// Construct a new permission check request based on the input request and call details.
return engine.invoke(&base.PermissionCheckRequest{
TenantId: request.GetTenantId(),
Entity: request.GetEntity(),
Permission: call.GetRuleName(),
Subject: request.GetSubject(),
Metadata: request.GetMetadata(),
Context: request.GetContext(),
Arguments: call.GetArguments(),
})
}

// checkDirectCall creates and returns a CheckFunction that performs direct permission checking.
// The function evaluates permissions based on rule definitions, arguments, and attributes.
func (engine *CheckEngine) checkDirectCall(
request *base.PermissionCheckRequest,
) CheckFunction {
// The function returned by checkCall
return func(ctx context.Context) (*base.PermissionCheckResponse, error) {
var err error

Expand All @@ -529,52 +538,44 @@ func (engine *CheckEngine) checkCall(

// Read the rule definition from the schema. If an error occurs, return the default denied response.
var ru *base.RuleDefinition
ru, _, err = engine.schemaReader.ReadRuleDefinition(ctx, request.GetTenantId(), call.GetRuleName(), request.GetMetadata().GetSchemaVersion())
ru, _, err = engine.schemaReader.ReadRuleDefinition(ctx, request.GetTenantId(), request.GetPermission(), request.GetMetadata().GetSchemaVersion())
if err != nil {
return emptyResp, err
}

// Prepare the arguments map to be used in the CEL evaluation
// Initialize an arguments map to hold argument values.
arguments := make(map[string]interface{})

// Prepare a slice for attributes
// List to store computed attributes.
attributes := make([]string, 0)

// Populate the arguments map based on the type of argument in the call
for _, arg := range call.GetArguments() {
// Iterate over request arguments to classify and process them.
for _, arg := range request.GetArguments() {
switch actualArg := arg.Type.(type) {
case *base.Argument_ComputedAttribute:
// Get the name of the computed attribute
// Handle computed attributes: Set them to a default empty value.
attrName := actualArg.ComputedAttribute.GetName()

// Get the empty value for this attribute type
emptyValue := getEmptyValueForType(ru.GetArguments()[attrName])

// Add the attribute with its empty value to the arguments map
arguments[attrName] = emptyValue

// Append the attribute to the attributes slice
attributes = append(attributes, attrName)

case *base.Argument_ContextAttribute:
// Get the name of the context attribute
// Handle context attributes: Use the value from context or default to an empty value.
attrName := actualArg.ContextAttribute.GetName()

// Get the value of the context attribute if exists, else get an empty value
value, exists := request.GetContext().GetData().AsMap()[attrName]
if !exists {
value = getEmptyValueForType(ru.GetArguments()[attrName])
}

// Add the attribute with its value to the arguments map
arguments[attrName] = value

default:
// Return an error for unhandled types
// Return an error for any unsupported argument types.
return denied(&base.PermissionCheckResponseMetadata{}), fmt.Errorf(base.ErrorCode_ERROR_CODE_INTERNAL.String())
}
}

// If there are computed attributes, fetch them from the data source.
if len(attributes) > 0 {
// Prepare the filter for querying attributes from the database
filter := &base.AttributeFilter{
Entity: &base.EntityFilter{
Type: request.GetEntity().GetType(),
Expand All @@ -583,7 +584,6 @@ func (engine *CheckEngine) checkCall(
Attributes: attributes,
}

// Query the database for the attributes and add them to the arguments map
ait, err := engine.dataReader.QueryAttributes(ctx, request.GetTenantId(), filter, request.GetMetadata().GetSnapToken())
if err != nil {
return denied(&base.PermissionCheckResponseMetadata{}), err
Expand All @@ -594,38 +594,37 @@ func (engine *CheckEngine) checkCall(
return denied(&base.PermissionCheckResponseMetadata{}), err
}

// Combine attributes from different sources ensuring uniqueness.
it := database.NewUniqueAttributeIterator(ait, cta)

for it.HasNext() {
next, ok := it.GetNext()
if !ok {
break
}

arguments[next.GetAttribute()] = utils.ConvertProtoAnyToInterface(next.GetValue())
}
}

// Prepare the CEL environment using the arguments in the rule definition
// Prepare the CEL environment with the argument values.
env, err := utils.ArgumentsAsCelEnv(ru.Arguments)
if err != nil {
return nil, err
}

// Prepare the CEL program using the rule's expression
// Compile the rule expression into an executable form.
exp := cel.CheckedExprToAst(ru.Expression)
prg, err := env.Program(exp)
if err != nil {
return nil, err
}

// Evaluate the CEL program with the arguments. If an error occurs, return a "denied" response.
// Evaluate the rule expression with the provided arguments.
out, _, err := prg.Eval(arguments)
if err != nil {
return denied(&base.PermissionCheckResponseMetadata{}), fmt.Errorf("failed to evaluate expression: %w", err)
}

// Check if the result of the CEL evaluation is a boolean
// Ensure the result of evaluation is boolean and decide on permission.
result, ok := out.Value().(bool)
if !ok {
return denied(&base.PermissionCheckResponseMetadata{}), fmt.Errorf("expected boolean result, but got %T", out.Value())
Expand Down
58 changes: 36 additions & 22 deletions internal/engines/expand.go
Expand Up @@ -81,11 +81,7 @@ func (engine *ExpandEngine) expand(ctx context.Context, request *base.Permission

var tor base.EntityDefinition_Reference
// Get the type of reference by name in the entity definition.
tor, err = schema.GetTypeOfReferenceByNameInEntityDefinition(en, request.GetPermission())
if err != nil {
// If an error occurred while getting the type of reference, return an ExpandResponse with the error.
return ExpandResponse{Err: err}
}
tor, _ = schema.GetTypeOfReferenceByNameInEntityDefinition(en, request.GetPermission())

// Depending on the type of reference, execute different branches of code.
switch tor {
Expand Down Expand Up @@ -114,10 +110,7 @@ func (engine *ExpandEngine) expand(ctx context.Context, request *base.Permission
fn = engine.expandDirectRelation(request)
default:
// If the reference is neither permission, attribute, nor relation, use the 'expandCall' method.
fn = engine.expandCall(request, &base.Call{
RuleName: request.GetPermission(),
Arguments: request.GetArguments(),
})
fn = engine.expandDirectCall(request)
}

if fn == nil {
Expand Down Expand Up @@ -519,7 +512,7 @@ func (engine *ExpandEngine) expandDirectAttribute(
},
Attribute: request.GetPermission(),
}
val.Value, err = anypb.New(&base.BoolValue{Data: false})
val.Value, err = anypb.New(&base.BooleanValue{Data: false})
if err != nil {
expandChan <- expandFailResponse(err)
return
Expand All @@ -546,11 +539,32 @@ func (engine *ExpandEngine) expandDirectAttribute(
}
}

// expandCall returns an ExpandFunction for the given request and call.
// The returned function, when executed, sends the expanded permission result
// to the provided result channel.
func (engine *ExpandEngine) expandCall(
request *base.PermissionExpandRequest,
call *base.Call,
) ExpandFunction {
return func(ctx context.Context, resultChan chan<- ExpandResponse) {
resultChan <- engine.expand(ctx, &base.PermissionExpandRequest{
TenantId: request.GetTenantId(),
Entity: &base.Entity{
Type: request.GetEntity().GetType(),
Id: request.GetEntity().GetId(),
},
Permission: call.GetRuleName(),
Metadata: request.GetMetadata(),
Context: request.GetContext(),
Arguments: call.GetArguments(),
})
}
}

// The function 'expandCall' is a method on the ExpandEngine struct.
// It takes a PermissionExpandRequest and a Call as parameters and returns an ExpandFunction.
func (engine *ExpandEngine) expandCall(
func (engine *ExpandEngine) expandDirectCall(
request *base.PermissionExpandRequest, // The request object containing information necessary for the expansion.
call *base.Call, // The call object that defines the rule name and its arguments.
) ExpandFunction { // The function returns an ExpandFunction.
return func(ctx context.Context, expandChan chan<- ExpandResponse) { // defining the returned function.

Expand All @@ -559,7 +573,7 @@ func (engine *ExpandEngine) expandCall(
var ru *base.RuleDefinition // variable to hold the rule definition.

// Read the rule definition based on the rule name in the call.
ru, _, err = engine.schemaReader.ReadRuleDefinition(ctx, request.GetTenantId(), call.GetRuleName(), request.GetMetadata().GetSchemaVersion())
ru, _, err = engine.schemaReader.ReadRuleDefinition(ctx, request.GetTenantId(), request.GetPermission(), request.GetMetadata().GetSchemaVersion())
if err != nil {
// If there's an error in reading the rule definition, send a failure response through the channel and return from the function.
expandChan <- expandFailResponse(err)
Expand All @@ -573,7 +587,7 @@ func (engine *ExpandEngine) expandCall(
attributes := make([]string, 0)

// For each argument in the call...
for _, arg := range call.GetArguments() {
for _, arg := range request.GetArguments() {
switch actualArg := arg.Type.(type) { // Switch on the type of the argument.
case *base.Argument_ComputedAttribute: // If the argument is a ComputedAttribute...
attrName := actualArg.ComputedAttribute.GetName() // get the name of the attribute.
Expand Down Expand Up @@ -667,8 +681,8 @@ func (engine *ExpandEngine) expandCall(
Response: &base.PermissionExpandResponse{
Tree: &base.Expand{
Entity: request.GetEntity(),
Permission: call.GetRuleName(),
Arguments: call.GetArguments(),
Permission: request.GetPermission(),
Arguments: request.GetArguments(),
Node: &base.Expand_Leaf{
Leaf: &base.ExpandLeaf{
Type: &base.ExpandLeaf_Values{
Expand All @@ -688,7 +702,7 @@ func (engine *ExpandEngine) expandCall(
// It takes a PermissionExpandRequest and a ComputedAttribute as parameters and returns an ExpandFunction.
func (engine *ExpandEngine) expandComputedAttribute(
request *base.PermissionExpandRequest, // The request object containing necessary information for the expansion.
ca *base.ComputedAttribute, // The computed attribute object that has the name of the attribute to be computed.
ca *base.ComputedAttribute, // The computed attribute object that has the name of the attribute to be computed.
) ExpandFunction { // The function returns an ExpandFunction.
return func(ctx context.Context, resultChan chan<- ExpandResponse) { // defining the returned function.

Expand All @@ -712,11 +726,11 @@ func (engine *ExpandEngine) expandComputedAttribute(
// a slice of arguments, slice of ExpandFunctions, and an operation of type base.ExpandTreeNode_Operation.
// It returns an ExpandResponse.
func expandOperation(
ctx context.Context, // The context of this operation, which may carry deadlines, cancellation signals, etc.
entity *base.Entity, // The entity on which the operation will be performed.
permission string, // The permission string required for the operation.
arguments []*base.Argument, // A slice of arguments required for the operation.
functions []ExpandFunction, // A slice of functions that will be used to expand the operation.
ctx context.Context, // The context of this operation, which may carry deadlines, cancellation signals, etc.
entity *base.Entity, // The entity on which the operation will be performed.
permission string, // The permission string required for the operation.
arguments []*base.Argument, // A slice of arguments required for the operation.
functions []ExpandFunction, // A slice of functions that will be used to expand the operation.
op base.ExpandTreeNode_Operation, // The operation to be performed.
) ExpandResponse { // The function returns an ExpandResponse.

Expand Down
4 changes: 2 additions & 2 deletions internal/engines/expand_test.go
Expand Up @@ -1320,8 +1320,8 @@ var _ = Describe("expand-engine", func() {
assertions map[string]*base.Expand
}

anyVal, _ := anypb.New(&base.Boolean{Value: true})
dow, _ := anypb.New(&base.String{Value: "monday"})
anyVal, _ := anypb.New(&base.BooleanValue{Data: true})
dow, _ := anypb.New(&base.StringValue{Data: "monday"})

tests := struct {
relationships []string
Expand Down

0 comments on commit c94646b

Please sign in to comment.