Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions internal_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package casbin

import (
"fmt"

Err "github.com/casbin/casbin/v2/errors"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
Expand Down Expand Up @@ -358,3 +360,16 @@ func (e *Enforcer) updateFilteredPolicies(sec string, ptype string, newRules [][

return ruleChanged, nil
}

func (e *Enforcer) getDomainIndex(ptype string) int {
p := e.model["p"][ptype]
pattern := fmt.Sprintf("%s_dom", ptype)
index := len(p.Tokens)
for i, token := range p.Tokens {
if token == pattern {
index = i
break
}
}
return index
}
74 changes: 46 additions & 28 deletions rbac_api_with_domains.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,52 +63,70 @@ func (e *Enforcer) DeleteRolesForUserInDomain(user string, domain string) (bool,
func (e *Enforcer) GetAllUsersByDomain(domain string) []string {
m := make(map[string]struct{})
g := e.model["g"]["g"]
if len(g.Tokens) != 3 {
return []string{}
}
p := e.model["p"]["p"]
users := make([]string, 0)
for _, policy := range g.Policy {
if _, ok := m[policy[2]]; policy[2] == domain && ok {
users = append(users, policy[0])
index := e.getDomainIndex("p")

getUser := func(index int, policies [][]string, domain string, m map[string]struct{}) []string {
if len(policies) == 0 || len(policies[0]) <= index {
return []string{}
}
res := make([]string, 0)
for _, policy := range policies {
if _, ok := m[policy[0]]; policy[index] == domain && !ok {
res = append(res, policy[0])
m[policy[0]] = struct{}{}
}
}
return res
}

users = append(users, getUser(2, g.Policy, domain, m)...)
users = append(users, getUser(index, p.Policy, domain, m)...)
return users
}

// DeleteAllUsersByDomain would delete all users associated with the domain.
func (e *Enforcer) DeleteAllUsersByDomain(domain string) (bool, error) {
g := e.model["g"]["g"]
if len(g.Tokens) != 3 {
return false, nil
}
policies := make([][]string, 0)
for _, policy := range g.Policy {
if policy[3] == domain {
policies = append(policies, policy)
p := e.model["p"]["p"]
index := e.getDomainIndex("p")

getUser := func(index int, policies [][]string, domain string) [][]string {
if len(policies) == 0 || len(policies[0]) <= index {
return [][]string{}
}
res := make([][]string, 0)
for _, policy := range policies {
if policy[index] == domain {
res = append(res, policy)
}
}
return res
}

users := getUser(2, g.Policy, domain)
if _, err := e.RemoveGroupingPolicies(users); err != nil {
return false, err
}
users = getUser(index, p.Policy, domain)
if _, err := e.RemovePolicies(users); err != nil {
return false, err
}
return e.RemoveGroupingPolicies(policies)
return true, nil
}

// DeleteDomains would delete all associated users and roles.
// It would delete all domains if parameter is not provided.
func (e *Enforcer) DeleteDomains(domains ...string) (bool, error) {
g := e.model["g"]["g"]
if len(g.Tokens) != 3 {
return false, nil
}
func (e *Enforcer) DeleteDomains(index int, domains ...string) (bool, error) {
if len(domains) == 0 {
return e.RemoveGroupingPolicies(g.Policy)
e.ClearPolicy()
return true, nil
}
m := make(map[string]struct{})
for _, domain := range domains {
m[domain] = struct{}{}
}
policies := make([][]string, 0)
for _, policy := range g.Policy {
if _, ok := m[policy[2]]; ok {
policies = append(policies, policy)
if _, err := e.DeleteAllUsersByDomain(domain); err != nil {
return false, err
}
}
return e.RemoveGroupingPolicies(policies)
return true, nil
}
40 changes: 40 additions & 0 deletions rbac_api_with_domains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,43 @@ func TestGetDomainsForUser(t *testing.T) {
testGetDomainsForUser(t, e, []string{"domain2", "domain3"}, "bob")
testGetDomainsForUser(t, e, []string{"domain3"}, "user")
}

func testGetAllUsersByDomain(t *testing.T, e *Enforcer, domain string, expected []string) {
if !util.ArrayEquals(e.GetAllUsersByDomain(domain), expected) {
t.Errorf("users in %s: %v, supposed to be %v\n", domain, e.GetAllUsersByDomain(domain), expected)
}
}

func TestGetAllUsersByDomain(t *testing.T) {
e, _ := NewEnforcer("examples/rbac_with_domains_model.conf", "examples/rbac_with_domains_policy.csv")

testGetAllUsersByDomain(t, e, "domain1", []string{"alice", "admin"})
testGetAllUsersByDomain(t, e, "domain2", []string{"bob", "admin"})
}

func testDeleteAllUsersByDomain(t *testing.T, domain string, expectedPolicy, expectedGroupingPolicy [][]string) {
e, _ := NewEnforcer("examples/rbac_with_domains_model.conf", "examples/rbac_with_domains_policy.csv")

_, _ = e.DeleteAllUsersByDomain(domain)
if !util.Array2DEquals(e.GetPolicy(), expectedPolicy) {
t.Errorf("policy in %s: %v, supposed to be %v\n", domain, e.GetPolicy(), expectedPolicy)
}
if !util.Array2DEquals(e.GetGroupingPolicy(), expectedGroupingPolicy) {
t.Errorf("grouping policy in %s: %v, supposed to be %v\n", domain, e.GetGroupingPolicy(), expectedGroupingPolicy)
}
}

func TestDeleteAllUsersByDomain(t *testing.T) {
testDeleteAllUsersByDomain(t, "domain1", [][]string{
{"admin", "domain2", "data2", "read"},
{"admin", "domain2", "data2", "write"},
}, [][]string{
{"bob", "admin", "domain2"},
})
testDeleteAllUsersByDomain(t, "domain2", [][]string{
{"admin", "domain1", "data1", "read"},
{"admin", "domain1", "data1", "write"},
}, [][]string{
{"alice", "admin", "domain1"},
})
}