Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merging to release-4-lts: [TT-10189/TT-10467] Add OAuthPurgeLapsedTokens (#5766) #5901

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
// The code below describes the Tyk Gateway API
// Version: 2.8.0
//
// Schemes: https, http
// Host: localhost
// BasePath: /tyk/
// Schemes: https, http
// Host: localhost
// BasePath: /tyk/
//
// Consumes:
// - application/json
// Consumes:
// - application/json
//
// Produces:
// - application/json
// Produces:
// - application/json
//
// Security:
// - api_key:
// Security:
// - api_key:
//
// SecurityDefinitions:
// api_key:
// type: apiKey
// name: X-Tyk-Authorization
// in: header
// SecurityDefinitions:
// api_key:
// type: apiKey
// name: X-Tyk-Authorization
// in: header
//
// swagger:meta
package gateway
Expand All @@ -34,7 +34,6 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
Expand All @@ -51,6 +50,7 @@ import (
"github.com/TykTechnologies/tyk/apidef"
"github.com/TykTechnologies/tyk/ctx"
"github.com/TykTechnologies/tyk/headers"
"github.com/TykTechnologies/tyk/internal/url"
"github.com/TykTechnologies/tyk/storage"
"github.com/TykTechnologies/tyk/user"

Expand All @@ -59,6 +59,14 @@ import (
"github.com/TykTechnologies/tyk/internal/uuid"
)

const (
oAuthClientTokensKeyPattern = "oauth-data.*oauth-client-tokens.*"
)

var (
ErrRequestMalformed = errors.New("request malformed")
)

// apiModifyKeySuccess represents when a Key modification was successful
//
// swagger:model apiModifyKeySuccess
Expand Down Expand Up @@ -340,7 +348,6 @@ func (gw *Gateway) doAddOrUpdate(keyName string, newSession *user.SessionState,
// remove from all stores, update to all stores, stores handle quotas separately though because they are localised! Keys will
// need to be managed by API, but only for GetDetail, GetList, UpdateKey and DeleteKey

//
func (gw *Gateway) setBasicAuthSessionPassword(session *user.SessionState) {
basicAuthHashAlgo := gw.basicAuthHashAlgo()

Expand Down Expand Up @@ -1319,7 +1326,6 @@ func (gw *Gateway) groupResetHandler(w http.ResponseWriter, r *http.Request) {
// was in the URL parameters, it will block until the reload is done.
// Otherwise, it won't block and fn will be called once the reload is
// finished.
//
func (gw *Gateway) resetHandler(fn func()) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var wg sync.WaitGroup
Expand Down Expand Up @@ -2046,6 +2052,26 @@ func (gw *Gateway) getOauthClientDetails(keyName, apiID string) (interface{}, in
return reportableClientData, http.StatusOK
}

func (gw *Gateway) oAuthTokensHandler(w http.ResponseWriter, r *http.Request) {
if !url.QueryHas(r.URL.Query(), "scope") {
doJSONWrite(w, http.StatusUnprocessableEntity, apiError("scope parameter is required"))
return
}

if r.URL.Query().Get("scope") != "lapsed" {
doJSONWrite(w, http.StatusBadRequest, apiError("unknown scope"))
return
}

err := gw.purgeLapsedOAuthTokens()
if err != nil {
doJSONWrite(w, http.StatusInternalServerError, apiError("error purging lapsed tokens"))
return
}

doJSONWrite(w, http.StatusOK, apiOk("lapsed tokens purged"))
}

// Delete Client
func (gw *Gateway) handleDeleteOAuthClient(keyName, apiID string) (interface{}, int) {
storageID := oauthClientStorageID(keyName)
Expand Down
68 changes: 68 additions & 0 deletions gateway/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2042,3 +2042,71 @@ func TestOrgKeyHandler_LastUpdated(t *testing.T) {
}},
}...)
}

func TestPurgeOAuthClientTokens(t *testing.T) {
conf := func(globalConf *config.Config) {
// set tokens to be expired after 1 second
globalConf.OauthTokenExpire = 1
// cleanup tokens older than 2 seconds
globalConf.OauthTokenExpiredRetainPeriod = 2
}

ts := StartTest(conf)
defer ts.Close()

t.Run("scope validation", func(t *testing.T) {
ts.Run(t, []test.TestCase{
{
AdminAuth: true,
Path: "/tyk/oauth/tokens/",
Method: http.MethodDelete,
Code: http.StatusUnprocessableEntity,
},
{
AdminAuth: true,
Path: "/tyk/oauth/tokens/",
QueryParams: map[string]string{"scope": "expired"},
Method: http.MethodDelete,
Code: http.StatusBadRequest,
},
}...)
})

assertTokensLen := func(t *testing.T, storageManager storage.Handler, storageKey string, expectedTokensLen int) {
nowTs := time.Now().Unix()
startScore := strconv.FormatInt(nowTs, 10)
tokens, _, err := storageManager.GetSortedSetRange(storageKey, startScore, "+inf")
assert.NoError(t, err)
assert.Equal(t, expectedTokensLen, len(tokens))
}

t.Run("scope=lapsed", func(t *testing.T) {
spec := ts.LoadTestOAuthSpec()

clientID1, clientID2 := uuid.New(), uuid.New()

ts.createOAuthClientIDAndTokens(t, spec, clientID1)
ts.createOAuthClientIDAndTokens(t, spec, clientID2)
storageKey1, storageKey2 := fmt.Sprintf("%s%s", prefixClientTokens, clientID1),
fmt.Sprintf("%s%s", prefixClientTokens, clientID2)

storageManager := ts.Gw.getGlobalStorageHandler(generateOAuthPrefix(spec.APIID), false)
storageManager.Connect()

assertTokensLen(t, storageManager, storageKey1, 3)
assertTokensLen(t, storageManager, storageKey2, 3)

time.Sleep(time.Second * 3)
ts.Run(t, test.TestCase{
ControlRequest: true,
AdminAuth: true,
Path: "/tyk/oauth/tokens",
QueryParams: map[string]string{"scope": "lapsed"},
Method: http.MethodDelete,
Code: http.StatusOK,
})

assertTokensLen(t, storageManager, storageKey1, 0)
assertTokensLen(t, storageManager, storageKey2, 0)
})
}
55 changes: 51 additions & 4 deletions gateway/oauth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/TykTechnologies/tyk/request"
"github.com/sirupsen/logrus"

"github.com/hashicorp/go-multierror"
"github.com/lonelycode/osin"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"

"github.com/TykTechnologies/tyk/internal/uuid"

"strconv"

"github.com/TykTechnologies/tyk/internal/uuid"

"github.com/TykTechnologies/tyk/headers"
tykerrors "github.com/TykTechnologies/tyk/internal/errors"
"github.com/TykTechnologies/tyk/storage"
"github.com/TykTechnologies/tyk/user"
)
Expand Down Expand Up @@ -1186,3 +1188,48 @@ func (r *RedisOsinStorageInterface) SetUser(username string, session *user.Sessi
return nil

}

func (gw *Gateway) purgeLapsedOAuthTokens() error {
if gw.GetConfig().OauthTokenExpiredRetainPeriod <= 0 {
return nil
}

redisCluster := &storage.RedisCluster{KeyPrefix: "", HashKeys: false, RedisController: gw.RedisController}
keys, err := redisCluster.ScanKeys(oAuthClientTokensKeyPattern)

if err != nil {
log.WithError(err).Debug("error while scanning for tokens")
return err
}

nowTs := time.Now().Unix()
// clean up expired tokens in sorted set (remove all tokens with score up to current timestamp minus retention)
cleanupStartScore := strconv.FormatInt(nowTs-int64(gw.GetConfig().OauthTokenExpiredRetainPeriod), 10)

var wg sync.WaitGroup

errs := make(chan error, len(keys))
for _, key := range keys {
wg.Add(1)
go func(k string) {
defer wg.Done()
if err := redisCluster.RemoveSortedSetRange(k, "-inf", cleanupStartScore); err != nil {
errs <- err
}
}(key)
}

// Wait for all goroutines to finish
wg.Wait()
close(errs)

combinedErr := &multierror.Error{
ErrorFormat: tykerrors.Formatter,
}

for err := range errs {
combinedErr = multierror.Append(combinedErr, err)
}

return combinedErr.ErrorOrNil()
}
86 changes: 86 additions & 0 deletions gateway/oauth_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"bytes"
"encoding/json"
"net/url"
"path"
"reflect"
"strconv"
"strings"
"testing"

Expand Down Expand Up @@ -150,6 +152,41 @@ func (ts *Test) createTestOAuthClient(spec *APISpec, clientID string) OAuthClien
return testClient
}

func (ts *Test) createOAuthClientIDAndTokens(t *testing.T, spec *APISpec, clientID string) {
t.Helper()
ts.createTestOAuthClient(spec, clientID)

param := make(url.Values)
param.Set("response_type", "token")
param.Set("redirect_uri", authRedirectUri)
param.Set("client_id", clientID)
param.Set("client_secret", authClientSecret)
param.Set("key_rules", keyRules)

headers := map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
}

for i := 0; i < 3; i++ {
resp, err := ts.Run(t, test.TestCase{
Path: path.Join(spec.Proxy.ListenPath, "/tyk/oauth/authorize-client/"),
Data: param.Encode(),
AdminAuth: true,
Headers: headers,
Method: http.MethodPost,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

response := map[string]interface{}{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
t.Fatal(err)
}
}
}

func TestOauthMultipleAPIs(t *testing.T) {
ts := StartTest(nil)
defer ts.Close()
Expand Down Expand Up @@ -1269,3 +1306,52 @@ func TestJSONToFormValues(t *testing.T) {
}
})
}

func TestPurgeOAuthClientTokensEvent(t *testing.T) {
conf := func(globalConf *config.Config) {
// set tokens to be expired after 1 second
globalConf.OauthTokenExpire = 1
// cleanup tokens older than 2 seconds
globalConf.OauthTokenExpiredRetainPeriod = 2
}

ts := StartTest(conf)
defer ts.Close()

assertTokensLen := func(t *testing.T, storageManager storage.Handler, storageKey string, expectedTokensLen int) {
nowTs := time.Now().Unix()
startScore := strconv.FormatInt(nowTs, 10)
tokens, _, err := storageManager.GetSortedSetRange(storageKey, startScore, "+inf")
assert.NoError(t, err)
assert.Equal(t, expectedTokensLen, len(tokens))
}

spec := ts.LoadTestOAuthSpec()

clientID1, clientID2 := uuid.New(), uuid.New()

ts.createOAuthClientIDAndTokens(t, spec, clientID1)
ts.createOAuthClientIDAndTokens(t, spec, clientID2)
storageKey1, storageKey2 := fmt.Sprintf("%s%s", prefixClientTokens, clientID1),
fmt.Sprintf("%s%s", prefixClientTokens, clientID2)

storageManager := ts.Gw.getGlobalStorageHandler(generateOAuthPrefix(spec.APIID), false)
storageManager.Connect()

assertTokensLen(t, storageManager, storageKey1, 3)
assertTokensLen(t, storageManager, storageKey2, 3)

time.Sleep(time.Second * 3)

// emit event

n := Notification{
Command: OAuthPurgeLapsedTokens,
Gw: ts.Gw,
}
ts.Gw.MainNotifier.Notify(n)

assertTokensLen(t, storageManager, storageKey1, 0)
assertTokensLen(t, storageManager, storageKey2, 0)

}
5 changes: 5 additions & 0 deletions gateway/redis_signals.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
NoticeGatewayDRLNotification NotificationCommand = "NoticeGatewayDRLNotification"
NoticeGatewayLENotification NotificationCommand = "NoticeGatewayLENotification"
KeySpaceUpdateNotification NotificationCommand = "KeySpaceUpdateNotification"
OAuthPurgeLapsedTokens NotificationCommand = "OAuthPurgeLapsedTokens"
)

// Notification is a type that encodes a message published to a pub sub channel (shared between implementations)
Expand Down Expand Up @@ -119,6 +120,10 @@ func (gw *Gateway) handleRedisEvent(v interface{}, handled func(NotificationComm
gw.reloadURLStructure(reloaded)
case KeySpaceUpdateNotification:
gw.handleKeySpaceEventCacheFlush(notif.Payload)
case OAuthPurgeLapsedTokens:
if err := gw.purgeLapsedOAuthTokens(); err != nil {
log.WithError(err).Errorf("error while purging tokens for event %s", OAuthPurgeLapsedTokens)
}
default:
pubSubLog.Warnf("Unknown notification command: %q", notif.Command)
return
Expand Down
1 change: 1 addition & 0 deletions gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ func (gw *Gateway) loadControlAPIEndpoints(muxer *mux.Router) {
r.HandleFunc("/oauth/clients/{apiID}", gw.oAuthClientHandler).Methods("GET", "DELETE")
r.HandleFunc("/oauth/clients/{apiID}/{keyName:[^/]*}", gw.oAuthClientHandler).Methods("GET", "DELETE")
r.HandleFunc("/oauth/clients/{apiID}/{keyName}/tokens", gw.oAuthClientTokensHandler).Methods("GET")
r.HandleFunc("/oauth/tokens", gw.oAuthTokensHandler).Methods(http.MethodDelete)

mainLog.Debug("Loaded API Endpoints")
}
Expand Down
Loading
Loading