From 16e508e32e82e2f86295f21bdda128da4e602b75 Mon Sep 17 00:00:00 2001 From: closetool <4closetool3@gmail.com> Date: Fri, 14 May 2021 15:29:52 +0800 Subject: [PATCH] fix: bug of GetAllUsersByDomain Signed-off-by: closetool <4closetool3@gmail.com> --- internal_api.go | 15 +++++++ rbac_api_with_domains.go | 74 ++++++++++++++++++++++------------- rbac_api_with_domains_test.go | 40 +++++++++++++++++++ 3 files changed, 101 insertions(+), 28 deletions(-) diff --git a/internal_api.go b/internal_api.go index c4cae7856..3659e18bb 100644 --- a/internal_api.go +++ b/internal_api.go @@ -15,6 +15,8 @@ package casbin import ( + "fmt" + Err "github.com/casbin/casbin/v2/errors" "github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/persist" @@ -358,3 +360,16 @@ func (e *Enforcer) updateFilteredPolicies(sec string, ptype string, newRules [][ return ruleChanged, nil } + +func (e *Enforcer) getDomainIndex(ptype string) int { + p := e.model["p"][ptype] + pattern := fmt.Sprintf("%s_dom", ptype) + index := len(p.Tokens) + for i, token := range p.Tokens { + if token == pattern { + index = i + break + } + } + return index +} diff --git a/rbac_api_with_domains.go b/rbac_api_with_domains.go index 404909874..e62649d6d 100644 --- a/rbac_api_with_domains.go +++ b/rbac_api_with_domains.go @@ -63,52 +63,70 @@ func (e *Enforcer) DeleteRolesForUserInDomain(user string, domain string) (bool, func (e *Enforcer) GetAllUsersByDomain(domain string) []string { m := make(map[string]struct{}) g := e.model["g"]["g"] - if len(g.Tokens) != 3 { - return []string{} - } + p := e.model["p"]["p"] users := make([]string, 0) - for _, policy := range g.Policy { - if _, ok := m[policy[2]]; policy[2] == domain && ok { - users = append(users, policy[0]) + index := e.getDomainIndex("p") + + getUser := func(index int, policies [][]string, domain string, m map[string]struct{}) []string { + if len(policies) == 0 || len(policies[0]) <= index { + return []string{} } + res := make([]string, 0) + for _, policy := range policies { + if _, ok := m[policy[0]]; policy[index] == domain && !ok { + res = append(res, policy[0]) + m[policy[0]] = struct{}{} + } + } + return res } + + users = append(users, getUser(2, g.Policy, domain, m)...) + users = append(users, getUser(index, p.Policy, domain, m)...) return users } // DeleteAllUsersByDomain would delete all users associated with the domain. func (e *Enforcer) DeleteAllUsersByDomain(domain string) (bool, error) { g := e.model["g"]["g"] - if len(g.Tokens) != 3 { - return false, nil - } - policies := make([][]string, 0) - for _, policy := range g.Policy { - if policy[3] == domain { - policies = append(policies, policy) + p := e.model["p"]["p"] + index := e.getDomainIndex("p") + + getUser := func(index int, policies [][]string, domain string) [][]string { + if len(policies) == 0 || len(policies[0]) <= index { + return [][]string{} } + res := make([][]string, 0) + for _, policy := range policies { + if policy[index] == domain { + res = append(res, policy) + } + } + return res + } + + users := getUser(2, g.Policy, domain) + if _, err := e.RemoveGroupingPolicies(users); err != nil { + return false, err + } + users = getUser(index, p.Policy, domain) + if _, err := e.RemovePolicies(users); err != nil { + return false, err } - return e.RemoveGroupingPolicies(policies) + return true, nil } // DeleteDomains would delete all associated users and roles. // It would delete all domains if parameter is not provided. -func (e *Enforcer) DeleteDomains(domains ...string) (bool, error) { - g := e.model["g"]["g"] - if len(g.Tokens) != 3 { - return false, nil - } +func (e *Enforcer) DeleteDomains(index int, domains ...string) (bool, error) { if len(domains) == 0 { - return e.RemoveGroupingPolicies(g.Policy) + e.ClearPolicy() + return true, nil } - m := make(map[string]struct{}) for _, domain := range domains { - m[domain] = struct{}{} - } - policies := make([][]string, 0) - for _, policy := range g.Policy { - if _, ok := m[policy[2]]; ok { - policies = append(policies, policy) + if _, err := e.DeleteAllUsersByDomain(domain); err != nil { + return false, err } } - return e.RemoveGroupingPolicies(policies) + return true, nil } diff --git a/rbac_api_with_domains_test.go b/rbac_api_with_domains_test.go index ad31befc2..65b78368e 100644 --- a/rbac_api_with_domains_test.go +++ b/rbac_api_with_domains_test.go @@ -210,3 +210,43 @@ func TestGetDomainsForUser(t *testing.T) { testGetDomainsForUser(t, e, []string{"domain2", "domain3"}, "bob") testGetDomainsForUser(t, e, []string{"domain3"}, "user") } + +func testGetAllUsersByDomain(t *testing.T, e *Enforcer, domain string, expected []string) { + if !util.ArrayEquals(e.GetAllUsersByDomain(domain), expected) { + t.Errorf("users in %s: %v, supposed to be %v\n", domain, e.GetAllUsersByDomain(domain), expected) + } +} + +func TestGetAllUsersByDomain(t *testing.T) { + e, _ := NewEnforcer("examples/rbac_with_domains_model.conf", "examples/rbac_with_domains_policy.csv") + + testGetAllUsersByDomain(t, e, "domain1", []string{"alice", "admin"}) + testGetAllUsersByDomain(t, e, "domain2", []string{"bob", "admin"}) +} + +func testDeleteAllUsersByDomain(t *testing.T, domain string, expectedPolicy, expectedGroupingPolicy [][]string) { + e, _ := NewEnforcer("examples/rbac_with_domains_model.conf", "examples/rbac_with_domains_policy.csv") + + _, _ = e.DeleteAllUsersByDomain(domain) + if !util.Array2DEquals(e.GetPolicy(), expectedPolicy) { + t.Errorf("policy in %s: %v, supposed to be %v\n", domain, e.GetPolicy(), expectedPolicy) + } + if !util.Array2DEquals(e.GetGroupingPolicy(), expectedGroupingPolicy) { + t.Errorf("grouping policy in %s: %v, supposed to be %v\n", domain, e.GetGroupingPolicy(), expectedGroupingPolicy) + } +} + +func TestDeleteAllUsersByDomain(t *testing.T) { + testDeleteAllUsersByDomain(t, "domain1", [][]string{ + {"admin", "domain2", "data2", "read"}, + {"admin", "domain2", "data2", "write"}, + }, [][]string{ + {"bob", "admin", "domain2"}, + }) + testDeleteAllUsersByDomain(t, "domain2", [][]string{ + {"admin", "domain1", "data1", "read"}, + {"admin", "domain1", "data1", "write"}, + }, [][]string{ + {"alice", "admin", "domain1"}, + }) +}