Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions components/ambient-api-server/plugins/common/project_scope.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Package common provides shared helpers for api-server plugin handlers.
package common

import (
"fmt"
"net/http"
"regexp"

"github.com/openshift-online/rh-trex-ai/pkg/errors"
"github.com/openshift-online/rh-trex-ai/pkg/services"
)

var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

// ApplyProjectScope reads the project ID from the query parameter or the
// X-Ambient-Project header (query param takes precedence) and injects a
// project_id filter into listArgs.Search. Returns a validation error if the
// project ID contains unsafe characters.
func ApplyProjectScope(r *http.Request, listArgs *services.ListArguments) *errors.ServiceError {
projectID := r.URL.Query().Get("project_id")
if projectID == "" {
projectID = r.Header.Get("X-Ambient-Project")
}
if projectID == "" {
return nil
}
if !safeProjectIDPattern.MatchString(projectID) {
return errors.Validation("invalid project_id format")
}
projectFilter := fmt.Sprintf("project_id = '%s'", projectID)
if listArgs.Search != "" {
listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search)
} else {
listArgs.Search = projectFilter
}
return nil
}
151 changes: 151 additions & 0 deletions components/ambient-api-server/plugins/common/project_scope_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package common

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/openshift-online/rh-trex-ai/pkg/services"
)

func newRequest(queryParams, headerProject string) *http.Request {
reqURL := "/sessions"
if queryParams != "" {
reqURL += "?" + queryParams
}
r := httptest.NewRequest(http.MethodGet, reqURL, nil)
if headerProject != "" {
r.Header.Set("X-Ambient-Project", headerProject)
}
return r
}

func newRequestWithProjectParam(projectID, headerProject string) *http.Request {
reqURL := "/sessions?project_id=" + url.QueryEscape(projectID)
r := httptest.NewRequest(http.MethodGet, reqURL, nil)
if headerProject != "" {
r.Header.Set("X-Ambient-Project", headerProject)
}
return r
}

func TestApplyProjectScope_HeaderOnly(t *testing.T) {
r := newRequest("", "my-project")
listArgs := services.NewListArguments(r.URL.Query())

err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if listArgs.Search != "project_id = 'my-project'" {
t.Errorf("expected project filter in search, got %q", listArgs.Search)
}
}

func TestApplyProjectScope_QueryParamOnly(t *testing.T) {
r := newRequest("project_id=query-proj", "")
listArgs := services.NewListArguments(r.URL.Query())

err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if listArgs.Search != "project_id = 'query-proj'" {
t.Errorf("expected project filter in search, got %q", listArgs.Search)
}
}

func TestApplyProjectScope_QueryParamTakesPrecedence(t *testing.T) {
r := newRequest("project_id=from-param", "from-header")
listArgs := services.NewListArguments(r.URL.Query())

err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if listArgs.Search != "project_id = 'from-param'" {
t.Errorf("expected query param to take precedence, got %q", listArgs.Search)
}
}

func TestApplyProjectScope_NoProjectReturnsNoFilter(t *testing.T) {
r := newRequest("", "")
listArgs := services.NewListArguments(r.URL.Query())

err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if listArgs.Search != "" {
t.Errorf("expected empty search, got %q", listArgs.Search)
}
}

func TestApplyProjectScope_CombinesWithExistingSearch(t *testing.T) {
r := newRequest("search=name+%3D+%27test%27", "my-project")
listArgs := services.NewListArguments(r.URL.Query())

err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if listArgs.Search != "project_id = 'my-project' and (name = 'test')" {
t.Errorf("expected combined search, got %q", listArgs.Search)
}
}

func TestApplyProjectScope_RejectsInjection(t *testing.T) {
payloads := []struct {
name string
value string
}{
{"SQL injection single quote", "x' OR 1=1--"},
{"SQL injection drop", "x'; DROP TABLE sessions;--"},
{"space", "test project"},
{"quote", "test'quote"},
{"semicolon", "proj;evil"},
{"percent", "proj%20evil"},
}

for _, tt := range payloads {
t.Run(tt.name+" via header", func(t *testing.T) {
r := newRequest("", tt.value)
listArgs := services.NewListArguments(r.URL.Query())
err := ApplyProjectScope(r, listArgs)
if err == nil {
t.Errorf("expected validation error for %q, got nil", tt.value)
}
})

t.Run(tt.name+" via query param", func(t *testing.T) {
r := newRequestWithProjectParam(tt.value, "")
listArgs := services.NewListArguments(r.URL.Query())
err := ApplyProjectScope(r, listArgs)
if err == nil {
t.Errorf("expected validation error for %q, got nil", tt.value)
}
})
}
}

func TestApplyProjectScope_AcceptsValidPatterns(t *testing.T) {
valid := []string{
"my-project",
"project_123",
"ABC-DEF",
"a",
"test-cp-verify-2",
}

for _, v := range valid {
t.Run(v, func(t *testing.T) {
r := newRequest("", v)
listArgs := services.NewListArguments(r.URL.Query())
err := ApplyProjectScope(r, listArgs)
if err != nil {
t.Errorf("expected no error for %q, got %v", v, err)
}
})
}
}
17 changes: 3 additions & 14 deletions components/ambient-api-server/plugins/projectSettings/handler.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
package projectSettings

import (
"fmt"
"net/http"
"regexp"

"github.com/gorilla/mux"

"github.com/ambient-code/platform/components/ambient-api-server/pkg/api/openapi"
"github.com/ambient-code/platform/components/ambient-api-server/plugins/common"
"github.com/openshift-online/rh-trex-ai/pkg/api/presenters"
"github.com/openshift-online/rh-trex-ai/pkg/errors"
"github.com/openshift-online/rh-trex-ai/pkg/handlers"
"github.com/openshift-online/rh-trex-ai/pkg/services"
)

var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

var _ handlers.RestHandler = projectSettingsHandler{}

type projectSettingsHandler struct {
Expand Down Expand Up @@ -94,16 +91,8 @@ func (h projectSettingsHandler) List(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

listArgs := services.NewListArguments(r.URL.Query())
if projectID := r.URL.Query().Get("project_id"); projectID != "" {
if !safeProjectIDPattern.MatchString(projectID) {
return nil, errors.Validation("invalid project_id format")
}
projectFilter := fmt.Sprintf("project_id = '%s'", projectID)
if listArgs.Search != "" {
listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search)
} else {
listArgs.Search = projectFilter
}
if err := common.ApplyProjectScope(r, listArgs); err != nil {
return nil, err
}
var items []ProjectSettings
paging, err := h.generic.List(ctx, "id", listArgs, &items)
Expand Down
17 changes: 3 additions & 14 deletions components/ambient-api-server/plugins/sessions/handler.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
package sessions

import (
"fmt"
"net/http"
"regexp"

"github.com/gorilla/mux"

"github.com/ambient-code/platform/components/ambient-api-server/pkg/api/openapi"
"github.com/ambient-code/platform/components/ambient-api-server/plugins/common"
"github.com/openshift-online/rh-trex-ai/pkg/api/presenters"
"github.com/openshift-online/rh-trex-ai/pkg/auth"
"github.com/openshift-online/rh-trex-ai/pkg/errors"
"github.com/openshift-online/rh-trex-ai/pkg/handlers"
"github.com/openshift-online/rh-trex-ai/pkg/services"
)

var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

var _ handlers.RestHandler = sessionHandler{}

type sessionHandler struct {
Expand Down Expand Up @@ -192,16 +189,8 @@ func (h sessionHandler) List(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

listArgs := services.NewListArguments(r.URL.Query())
if projectID := r.URL.Query().Get("project_id"); projectID != "" {
if !safeProjectIDPattern.MatchString(projectID) {
return nil, errors.Validation("invalid project_id format")
}
projectFilter := fmt.Sprintf("project_id = '%s'", projectID)
if listArgs.Search != "" {
listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search)
} else {
listArgs.Search = projectFilter
}
if err := common.ApplyProjectScope(r, listArgs); err != nil {
return nil, err
}
var sessions []Session
paging, err := h.generic.List(ctx, "id", listArgs, &sessions)
Expand Down
Loading