diff --git a/api/users.go b/api/users.go index d87d7501..9a437072 100644 --- a/api/users.go +++ b/api/users.go @@ -85,7 +85,8 @@ type CreateServiceUserResponseDto struct { } type AssociateServiceUserPublicKeyRequestDto struct { - PublicKey string `json:"publicKey" validate:"required"` + PublicKey string `json:"publicKey" validate:"required"` + Kid *string `json:"kid,omitempty"` } type AssociateServiceUserPublicKeyResponseDto struct { diff --git a/client/user.go b/client/user.go index 94c136bd..85f4bb9f 100644 --- a/client/user.go +++ b/client/user.go @@ -23,7 +23,8 @@ type UserClient interface { Get(ctx context.Context, id uuid.UUID) (api.GetUserByIdResponseDto, error) Patch(ctx context.Context, id uuid.UUID, dto api.PatchUserRequestDto) error CreateServiceUser(ctx context.Context, username string) (uuid.UUID, error) - AssociateServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, publicKeyPEM string) (string, error) + AssociateServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, dto api.AssociateServiceUserPublicKeyRequestDto) (api.AssociateServiceUserPublicKeyResponseDto, error) + RemoveServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, kid string) error } func NewUserClient(transport *Transport) UserClient { @@ -159,28 +160,43 @@ func (c *userClient) CreateServiceUser(ctx context.Context, username string) (uu return responseDto.Id, nil } -func (c *userClient) AssociateServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, publicKeyPEM string) (string, error) { - jsonBytes, err := json.Marshal(api.AssociateServiceUserPublicKeyRequestDto{PublicKey: publicKeyPEM}) +func (c *userClient) AssociateServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, dto api.AssociateServiceUserPublicKeyRequestDto) (api.AssociateServiceUserPublicKeyResponseDto, error) { + jsonBytes, err := json.Marshal(dto) if err != nil { - return "", fmt.Errorf("marshaling dto: %w", err) + return api.AssociateServiceUserPublicKeyResponseDto{}, fmt.Errorf("marshaling dto: %w", err) } endpoint := fmt.Sprintf("/users/service-users/%s/keys", serviceUserID) request, err := c.transport.NewTenantRequest(ctx, http.MethodPost, endpoint, bytes.NewBuffer(jsonBytes)) if err != nil { - return "", fmt.Errorf("creating request: %w", err) + return api.AssociateServiceUserPublicKeyResponseDto{}, fmt.Errorf("creating request: %w", err) } response, err := c.transport.Do(request) if err != nil { - return "", fmt.Errorf("doing request: %w", err) + return api.AssociateServiceUserPublicKeyResponseDto{}, fmt.Errorf("doing request: %w", err) } defer response.Body.Close() //nolint:errcheck var responseDto api.AssociateServiceUserPublicKeyResponseDto if err := json.NewDecoder(response.Body).Decode(&responseDto); err != nil { - return "", fmt.Errorf("decoding response: %w", err) + return api.AssociateServiceUserPublicKeyResponseDto{}, fmt.Errorf("decoding response: %w", err) + } + + return responseDto, nil +} + +func (c *userClient) RemoveServiceUserPublicKey(ctx context.Context, serviceUserID uuid.UUID, kid string) error { + endpoint := fmt.Sprintf("/users/service-users/%s/keys/%s", serviceUserID, kid) + request, err := c.transport.NewTenantRequest(ctx, http.MethodDelete, endpoint, nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) } - return responseDto.Kid, nil + response, err := c.transport.Do(request) + if err != nil { + return fmt.Errorf("doing request: %w", err) + } + defer response.Body.Close() //nolint:errcheck + return nil } diff --git a/client/user_test.go b/client/user_test.go index ad5642b1..2a37e4a1 100644 --- a/client/user_test.go +++ b/client/user_test.go @@ -138,6 +138,88 @@ func (s *UserClientSuite) TestGetUser_HappyPath() { s.Equal(response, responseDto) } +func (s *UserClientSuite) TestAssociateServiceUserPublicKey_HappyPath() { + // arrange + serviceUserId := uuid.New() + request := api.AssociateServiceUserPublicKeyRequestDto{ + PublicKey: "-----BEGIN PUBLIC KEY-----\nabc\n-----END PUBLIC KEY-----", + Kid: utils.Ptr("my-kid"), + } + response := api.AssociateServiceUserPublicKeyResponseDto{Kid: "my-kid"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Equal(http.MethodPost, r.Method) + s.Equal(fmt.Sprintf("/api/virtual-servers/test/users/service-users/%s/keys", serviceUserId), r.URL.Path) + + var requestDto api.AssociateServiceUserPublicKeyRequestDto + err := json.NewDecoder(r.Body).Decode(&requestDto) + s.NoError(err) + s.Equal(request, requestDto) + + err = json.NewEncoder(w).Encode(response) + s.NoError(err) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").User() + + // act + responseDto, err := testee.AssociateServiceUserPublicKey(s.T().Context(), serviceUserId, request) + + // assert + s.Require().NoError(err) + s.Equal(response, responseDto) +} + +func (s *UserClientSuite) TestAssociateServiceUserPublicKey_NoKid() { + // arrange + serviceUserId := uuid.New() + request := api.AssociateServiceUserPublicKeyRequestDto{ + PublicKey: "-----BEGIN PUBLIC KEY-----\nabc\n-----END PUBLIC KEY-----", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var requestDto api.AssociateServiceUserPublicKeyRequestDto + err := json.NewDecoder(r.Body).Decode(&requestDto) + s.NoError(err) + s.Nil(requestDto.Kid) + + err = json.NewEncoder(w).Encode(api.AssociateServiceUserPublicKeyResponseDto{Kid: "server-generated-kid"}) + s.NoError(err) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").User() + + // act + responseDto, err := testee.AssociateServiceUserPublicKey(s.T().Context(), serviceUserId, request) + + // assert + s.Require().NoError(err) + s.Equal("server-generated-kid", responseDto.Kid) +} + +func (s *UserClientSuite) TestRemoveServiceUserPublicKey_HappyPath() { + // arrange + serviceUserId := uuid.New() + kid := "my-kid" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Equal(http.MethodDelete, r.Method) + s.Equal(fmt.Sprintf("/api/virtual-servers/test/users/service-users/%s/keys/%s", serviceUserId, kid), r.URL.Path) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").User() + + // act + err := testee.RemoveServiceUserPublicKey(s.T().Context(), serviceUserId, kid) + + // assert + s.Require().NoError(err) +} + func (s *UserClientSuite) TestPatchUser_HappyPath() { // arrange requestId := uuid.New() diff --git a/internal/commands/RemoveServiceUserPublicKey.go b/internal/commands/RemoveServiceUserPublicKey.go index ae0efd5f..fb385cec 100644 --- a/internal/commands/RemoveServiceUserPublicKey.go +++ b/internal/commands/RemoveServiceUserPublicKey.go @@ -17,7 +17,7 @@ import ( type RemoveServiceUserPublicKey struct { VirtualServerName string ServiceUserId uuid.UUID - PublicKey string + Kid string } func (a RemoveServiceUserPublicKey) LogRequest() bool { @@ -60,7 +60,7 @@ func HandleRemoveServiceUserPublicKey(ctx context.Context, command RemoveService credentialFilter := repositories.NewCredentialFilter(). UserId(user.Id()). Type(repositories.CredentialTypeServiceUserKey). - DetailPublicKey(command.PublicKey) + DetailKid(command.Kid) credential, err := dbContext.Credentials().FirstOrErr(ctx, credentialFilter) if err != nil { return nil, fmt.Errorf("getting credential: %w", err) diff --git a/internal/commands/RemoveServiceUserPublicKey_test.go b/internal/commands/RemoveServiceUserPublicKey_test.go index 451de9a2..2c877b06 100644 --- a/internal/commands/RemoveServiceUserPublicKey_test.go +++ b/internal/commands/RemoveServiceUserPublicKey_test.go @@ -87,7 +87,7 @@ func (s *RemoveServiceUserPublicKeyCommandSuite) TestHappyPath() { credential.Mock(now) credentialRepository := mocks.NewMockCredentialRepository(ctrl) credentialRepository.EXPECT().FirstOrErr(gomock.Any(), gomock.Cond(func(x *repositories.CredentialFilter) bool { - return x.GetDetailPublicKey() == "publicKey" && + return x.GetDetailKid() == "test-kid" && x.GetUserId() == serviceUser.Id() && x.GetType() == repositories.CredentialTypeServiceUserKey })).Return(credential, nil) @@ -97,7 +97,7 @@ func (s *RemoveServiceUserPublicKeyCommandSuite) TestHappyPath() { cmd := RemoveServiceUserPublicKey{ VirtualServerName: "virtualServer", ServiceUserId: serviceUser.Id(), - PublicKey: "publicKey", + Kid: "test-kid", } // act diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 3b4652bb..2ebe90e8 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -856,6 +856,7 @@ func AssociateServiceUserPublicKey(w http.ResponseWriter, r *http.Request) { VirtualServerName: vsName, ServiceUserId: serviceUserId, PublicKey: dto.PublicKey, + Kid: dto.Kid, }) if err != nil { utils.HandleHttpError(w, err) @@ -872,6 +873,55 @@ func AssociateServiceUserPublicKey(w http.ResponseWriter, r *http.Request) { } } +// RemoveServiceUserPublicKey removes a public key from a service user by kid. +// @Summary Remove a public key from a service user +// @Tags Users +// @Produce plain +// @Param virtualServerName path string true "Virtual server name" default(keyline) +// @Param serviceUserId path string true "Service user ID" +// @Param kid path string true "Key ID" +// @Success 204 {string} string "No Content" +// @Failure 400 {string} string +// @Failure 404 {string} string +// @Router /api/virtual-servers/{virtualServerName}/users/service-users/{serviceUserId}/keys/{kid} [delete] +func RemoveServiceUserPublicKey(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + vsName, err := middlewares.GetVirtualServerName(ctx) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + vars := mux.Vars(r) + serviceUserId, err := uuid.Parse(vars["serviceUserId"]) + if err != nil { + utils.HandleHttpError(w, utils.ErrInvalidUuid) + return + } + + kid := vars["kid"] + if kid == "" { + utils.HandleHttpError(w, fmt.Errorf("missing kid: %w", utils.ErrHttpBadRequest)) + return + } + + scope := middlewares.GetScope(ctx) + m := ioc.GetDependency[mediatr.Mediator](scope) + + _, err = mediatr.Send[*commands.RemoveServiceUserPublicKeyResponse](ctx, m, commands.RemoveServiceUserPublicKey{ + VirtualServerName: vsName, + ServiceUserId: serviceUserId, + Kid: kid, + }) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} + func PasskeyCreateChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() scope := middlewares.GetScope(ctx) diff --git a/internal/server/server.go b/internal/server/server.go index 09decad5..23b71840 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -181,6 +181,7 @@ func mapApiRoutes(r *mux.Router) { vsApiRouter.HandleFunc("/users/{userId}", handlers.PatchUser).Methods(http.MethodPatch, http.MethodOptions) vsApiRouter.HandleFunc("/users/service-users", handlers.CreateServiceUser).Methods(http.MethodPost, http.MethodOptions) vsApiRouter.HandleFunc("/users/service-users/{serviceUserId}/keys", handlers.AssociateServiceUserPublicKey).Methods(http.MethodPost, http.MethodOptions) + vsApiRouter.HandleFunc("/users/service-users/{serviceUserId}/keys/{kid}", handlers.RemoveServiceUserPublicKey).Methods(http.MethodDelete, http.MethodOptions) vsApiRouter.HandleFunc("/users/{userId}/passkeys/register/start", handlers.PasskeyCreateChallenge).Methods(http.MethodPost, http.MethodOptions) vsApiRouter.HandleFunc("/users/{userId}/passkeys/register/finish", handlers.PasskeyValidateCreateChallengeResponse).Methods(http.MethodPost, http.MethodOptions) vsApiRouter.HandleFunc("/users/{userId}/passkeys", handlers.ListPasskeys).Methods(http.MethodGet, http.MethodOptions) diff --git a/tests/e2e/serviceuserkey_test.go b/tests/e2e/serviceuserkey_test.go new file mode 100644 index 00000000..1e216eb1 --- /dev/null +++ b/tests/e2e/serviceuserkey_test.go @@ -0,0 +1,66 @@ +//go:build e2e + +package e2e + +import ( + "github.com/The127/Keyline/api" + "github.com/The127/Keyline/config" + "github.com/The127/Keyline/utils" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func init() { + for _, backend := range testBackends { + backend := backend + Describe("Service user key endpoints ["+backend.name+"]", Ordered, func() { + var h *harness + var serviceUserId = serviceUserUsername + + BeforeAll(func() { + if backend.dbMode == config.DatabaseModePostgres && !postgresBackendAvailable() { + Skip("Postgres not available") + } + h = newE2eTestHarness(backend.dbMode, serviceUserTokenSource) + }) + + AfterAll(func() { + if h != nil { + h.Close() + } + }) + + It("associates a public key with a caller-supplied kid and then removes it", func() { + // Create a fresh service user to attach keys to, so we do not clobber the + // harness's default service user. + suId, err := h.Client().User().CreateServiceUser(h.Ctx(), "key-flow-user-"+backend.name) + Expect(err).ToNot(HaveOccurred()) + Expect(suId).ToNot(BeZero()) + + wantKid := "e2e-caller-kid-" + backend.name + resp, err := h.Client().User().AssociateServiceUserPublicKey(h.Ctx(), suId, api.AssociateServiceUserPublicKeyRequestDto{ + PublicKey: "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAX3J/Yilw4CTcsOVW0BBasQwY9wuYwcJZkJliqAhNa5s=\n-----END PUBLIC KEY-----\n", + Kid: utils.Ptr(wantKid), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Kid).To(Equal(wantKid)) + + Expect(h.Client().User().RemoveServiceUserPublicKey(h.Ctx(), suId, resp.Kid)).To(Succeed()) + + _ = serviceUserId + }) + + It("associates a public key without a kid and server generates one", func() { + suId, err := h.Client().User().CreateServiceUser(h.Ctx(), "key-flow-autokid-user-"+backend.name) + Expect(err).ToNot(HaveOccurred()) + + resp, err := h.Client().User().AssociateServiceUserPublicKey(h.Ctx(), suId, api.AssociateServiceUserPublicKeyRequestDto{ + PublicKey: "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAX3J/Yilw4CTcsOVW0BBasQwY9wuYwcJZkJliqAhNa5t=\n-----END PUBLIC KEY-----\n", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Kid).ToNot(BeEmpty()) + }) + }) + } +} diff --git a/tests/integration/serviceuser_flow_test.go b/tests/integration/serviceuser_flow_test.go index d61662bc..66de11da 100644 --- a/tests/integration/serviceuser_flow_test.go +++ b/tests/integration/serviceuser_flow_test.go @@ -23,6 +23,7 @@ func init() { Describe("ServiceUser flow ["+backend.name+"]", Ordered, func() { var h *harness var serviceUserId uuid.UUID + var associatedKid string BeforeAll(func() { if backend.dbMode == config.DatabaseModePostgres && !postgresBackendAvailable() { @@ -55,8 +56,10 @@ func init() { ServiceUserId: serviceUserId, PublicKey: ed25519PublicKey, } - _, err := mediatr.Send[*commands.AssociateServiceUserPublicKeyResponse](h.Ctx(), h.Mediator(), req) + resp, err := mediatr.Send[*commands.AssociateServiceUserPublicKeyResponse](h.Ctx(), h.Mediator(), req) Expect(err).ToNot(HaveOccurred()) + Expect(resp.Kid).ToNot(BeEmpty()) + associatedKid = resp.Kid Expect(h.dbContext.SaveChanges(h.ctx)).ToNot(HaveOccurred()) }) @@ -65,7 +68,7 @@ func init() { req := commands.RemoveServiceUserPublicKey{ VirtualServerName: h.VirtualServer(), ServiceUserId: serviceUserId, - PublicKey: ed25519PublicKey, + Kid: associatedKid, } _, err := mediatr.Send[*commands.RemoveServiceUserPublicKeyResponse](h.Ctx(), h.Mediator(), req) Expect(err).ToNot(HaveOccurred())