diff --git a/lib/devices_test.go b/lib/devices_test.go index 0eb12e5..7f6e30a 100644 --- a/lib/devices_test.go +++ b/lib/devices_test.go @@ -19,6 +19,7 @@ package lib import ( "context" "fmt" + "github.com/SENERGY-Platform/permission-search/lib/auth" "github.com/SENERGY-Platform/permission-search/lib/model" "github.com/opensearch-project/opensearch-go/opensearchutil" "reflect" @@ -28,6 +29,55 @@ import ( "time" ) +func TestDeviceImmediatelySearchable(t *testing.T) { + if testing.Short() { + t.Skip("short") + } + wg := &sync.WaitGroup{} + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, q, w, err := getTestEnv(ctx, wg, t) + if err != nil { + fmt.Println(err) + return + } + + type Attribute struct { + Key string `json:"key"` + Value string `json:"value"` + Origin string `json:"origin"` + } + + type Device struct { + Id string `json:"id"` + LocalId string `json:"local_id"` + Name string `json:"name"` + Attributes []Attribute `json:"attributes"` + DeviceTypeId string `json:"device_type_id"` + } + + deviceMsg, deviceCmd, err := getDeviceTestObj("device1", Device{ + Id: "device1", + Name: "device1", + Attributes: nil, + }) + if err != nil { + t.Error(err) + return + } + err = w.UpdateFeatures("devices", deviceMsg, deviceCmd) + if err != nil { + t.Error(err) + return + } + + err = q.CheckUserOrGroupFromAuthToken(auth.Token{Sub: "testOwner"}, "devices", "device1", "r") + if err != nil { + t.Error(err) + } +} + func TestDeviceComplexPathAndAttribute(t *testing.T) { if testing.Short() { t.Skip("short") diff --git a/lib/query/query.go b/lib/query/query.go index 4a559bf..7cadd64 100644 --- a/lib/query/query.go +++ b/lib/query/query.go @@ -30,6 +30,7 @@ import ( "log" "net/http" "runtime/debug" + "slices" "sort" "strconv" "strings" @@ -232,42 +233,50 @@ func (this *Query) CheckUserOrGroup(tokenStr string, kind string, resource strin func (this *Query) CheckUserOrGroupFromAuthToken(token auth.Token, kind string, resource string, rights string) (err error) { pureId, _ := modifier.SplitModifier(resource) - ctx := this.getTimeout() - filter := getRightsQuery(rights, token.GetUserId(), token.GetRoles()) - filter = append(filter, map[string]interface{}{ - "term": map[string]interface{}{ - "resource": pureId, - }, - }) - resp, err := this.opensearchClient.Search(this.opensearchClient.Search.WithIndex(kind), - this.opensearchClient.Search.WithContext(ctx), - this.opensearchClient.Search.WithVersion(true), - this.opensearchClient.Search.WithSize(1), - this.opensearchClient.Search.WithBody(opensearchutil.NewJSONReader(map[string]interface{}{ - "query": map[string]interface{}{ - "bool": map[string]interface{}{ - "filter": filter, - }, - }, - })), - ) + e, _, err := this.GetResourceEntry(kind, pureId) + if errors.Is(err, model.ErrNotFound) { + return model.ErrAccessDenied + } if err != nil { return err } - defer resp.Body.Close() - if resp.IsError() { - return errors.New(resp.String()) + + if rights == "" { + rights = "r" } - pl := model.SearchResult[model.Entry]{} - err = json.NewDecoder(resp.Body).Decode(&pl) - if err != nil { - return err + + user := token.GetUserId() + groups := token.GetRoles() + for _, right := range rights { + switch right { + case 'a': + if !slices.Contains(e.AdminUsers, user) && !containsAny(e.AdminGroups, groups) { + return model.ErrAccessDenied + } + case 'r': + if !slices.Contains(e.ReadUsers, user) && !containsAny(e.ReadGroups, groups) { + return model.ErrAccessDenied + } + case 'w': + if !slices.Contains(e.WriteUsers, user) && !containsAny(e.WriteGroups, groups) { + return model.ErrAccessDenied + } + case 'x': + if !slices.Contains(e.ExecuteUsers, user) && !containsAny(e.ExecuteGroups, groups) { + return model.ErrAccessDenied + } + } } + return nil +} - if pl.Hits.Total.Value == 0 { - err = model.ErrAccessDenied +func containsAny(list []string, any []string) bool { + for _, e := range any { + if slices.Contains(list, e) { + return true + } } - return + return false } func (this *Query) CheckListUserOrGroup(token auth.Token, kind string, ids []string, rights string) (allowed map[string]bool, err error) { diff --git a/lib/worker/command.go b/lib/worker/command.go index d2eaa72..885d3dd 100644 --- a/lib/worker/command.go +++ b/lib/worker/command.go @@ -221,6 +221,7 @@ func (this *Worker) UpdateFeatures(kind string, msg []byte, command model.Comman client.Index.WithIfPrimaryTerm(int(version.PrimaryTerm)), client.Index.WithIfSeqNo(int(version.SeqNo)), client.Index.WithContext(ctx), + //client.Index.WithRefresh("wait_for"), //to slow, don't use ) if err != nil { debug.PrintStack() @@ -239,6 +240,7 @@ func (this *Worker) UpdateFeatures(kind string, msg []byte, command model.Comman opensearchutil.NewJSONReader(entry), client.Index.WithDocumentID(command.Id), client.Index.WithContext(ctx), + //client.Index.WithRefresh("wait_for"), //to slow, don't use ) if err != nil { debug.PrintStack()