Skip to content
This repository has been archived by the owner on Dec 22, 2023. It is now read-only.

Commit

Permalink
Check whether auth provider exists
Browse files Browse the repository at this point in the history
The name of the auth provider is checked for existence rather panicking.

refs #3
  • Loading branch information
cheungpat authored and Ben Lei committed Apr 11, 2016
1 parent 1fa0550 commit 9756437
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 6 deletions.
12 changes: 10 additions & 2 deletions handler/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ func (h *SignupHandler) Handle(payload *router.Payload, response *router.Respons
} else if p.Provider != "" {
// Get AuthProvider and authenticates the user
log.Debugf(`Client requested auth provider: "%v".`, p.Provider)
authProvider := h.ProviderRegistry.GetAuthProvider(p.Provider)
authProvider, err := h.ProviderRegistry.GetAuthProvider(p.Provider)
if err != nil {
response.Err = skyerr.NewInvalidArgument(err.Error(), []string{"provider"})
return
}
principalID, authData, err := authProvider.Login(p.AuthData)
if err != nil {
response.Err = skyerr.NewError(skyerr.InvalidCredentials, "unable to login with the given credentials")
Expand Down Expand Up @@ -251,7 +255,11 @@ func (h *LoginHandler) Handle(payload *router.Payload, response *router.Response
if p.Provider != "" {
// Get AuthProvider and authenticates the user
log.Debugf(`Client requested auth provider: "%v".`, p.Provider)
authProvider := h.ProviderRegistry.GetAuthProvider(p.Provider)
authProvider, err := h.ProviderRegistry.GetAuthProvider(p.Provider)
if err != nil {
response.Err = skyerr.NewInvalidArgument(err.Error(), []string{"provider"})
return
}
principalID, authData, err := authProvider.Login(p.AuthData)
if err != nil {
response.Err = skyerr.NewError(skyerr.InvalidCredentials, "invalid authentication information")
Expand Down
27 changes: 27 additions & 0 deletions handler/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -231,6 +232,19 @@ func TestLoginHandlerWithProvider(t *testing.T) {
p.Database = txdb
})

Convey("login in non-existent provider", func() {
resp := r.POST(`{"provider": "com.non-existent", "auth_data": {"name": "johndoe"}}`)
So(resp.Body.Bytes(), ShouldEqualJSON, `{
"error": {
"code": 108,
"name": "InvalidArgument",
"info": {"arguments": ["provider"]},
"message": "no auth provider of name \"com.non-existent\""
}
}`)
So(resp.Code, ShouldEqual, http.StatusBadRequest)
})

Convey("login in existing", func() {
userinfo := skydb.NewProvidedAuthUserInfo("com.example:johndoe", map[string]interface{}{"name": "boo"})
conn.userinfo = &userinfo
Expand Down Expand Up @@ -411,6 +425,19 @@ func TestSignupHandlerWithProvider(t *testing.T) {
p.Database = txdb
})

Convey("signs up with non-existent provider", func() {
resp := r.POST(`{"provider": "com.non-existent", "auth_data": {"name": "johndoe"}}`)
So(resp.Body.Bytes(), ShouldEqualJSON, `{
"error": {
"code": 108,
"name": "InvalidArgument",
"info": {"arguments": ["provider"]},
"message": "no auth provider of name \"com.non-existent\""
}
}`)
So(resp.Code, ShouldEqual, http.StatusBadRequest)
})

Convey("signs up with user", func() {
resp := r.POST(`{"provider": "com.example", "auth_data": {"name": "johndoe"}}`)

Expand Down
6 changes: 5 additions & 1 deletion handler/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,11 @@ func (h *UserLinkHandler) Handle(payload *router.Payload, response *router.Respo

// Get AuthProvider and authenticates the user
log.Debugf(`Client requested auth provider: "%v".`, p.Provider)
authProvider := h.ProviderRegistry.GetAuthProvider(p.Provider)
authProvider, err := h.ProviderRegistry.GetAuthProvider(p.Provider)
if err != nil {
response.Err = skyerr.NewInvalidArgument(err.Error(), []string{"provider"})
return
}
principalID, authData, err := authProvider.Login(p.AuthData)
if err != nil {
response.Err = skyerr.NewError(skyerr.InvalidCredentials, "unable to login with the given credentials")
Expand Down
14 changes: 14 additions & 0 deletions handler/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package handler

import (
"net/http"
"testing"

"github.com/skygeario/skygear-server/handler/handlertest"
Expand Down Expand Up @@ -358,6 +359,19 @@ func TestUserLinkHandler(t *testing.T) {
p.UserInfo = &userInfo
})

Convey("link account with non-existent provider", func() {
resp := r.POST(`{"provider": "com.non-existent", "auth_data": {"name": "johndoe"}}`)
So(resp.Body.Bytes(), ShouldEqualJSON, `{
"error": {
"code": 108,
"name": "InvalidArgument",
"info": {"arguments": ["provider"]},
"message": "no auth provider of name \"com.non-existent\""
}
}`)
So(resp.Code, ShouldEqual, http.StatusBadRequest)
})

Convey("link account", func() {
resp := r.POST(`{"provider": "com.example", "auth_data": {"name": "johndoe"}}`)
So(resp.Body.Bytes(), ShouldEqualJSON, `{}`)
Expand Down
10 changes: 7 additions & 3 deletions plugin/provider/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package provider

import (
"fmt"
"sync"
)

Expand All @@ -41,9 +42,12 @@ func (r *Registry) RegisterAuthProvider(name string, p AuthProvider) {
}

// GetAuthProvider gets an AuthProvider from the registry.
func (r *Registry) GetAuthProvider(name string) AuthProvider {
func (r *Registry) GetAuthProvider(name string) (AuthProvider, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
provider := r.authProviders[name]
return provider
provider, ok := r.authProviders[name]
if !ok {
return nil, fmt.Errorf(`no auth provider of name "%s"`, name)
}
return provider, nil
}

0 comments on commit 9756437

Please sign in to comment.