Skip to content

Commit

Permalink
chore: address comments\
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Apr 16, 2024
1 parent 2477c9f commit c280a57
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 35 deletions.
7 changes: 6 additions & 1 deletion cmd/gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"os"
"os/signal"
"runtime/debug"
"time"

"github.com/SwissDataScienceCenter/renku-gateway/internal/config"
Expand Down Expand Up @@ -64,7 +65,11 @@ func main() {
return c.NoContent(http.StatusOK)
})
// Version endpoint
version := os.Getenv("VERSION")
buildInfo, ok := debug.ReadBuildInfo()
version := ""
if ok && buildInfo != nil {
version = buildInfo.Main.Version
}
e.GET("/version", func(c echo.Context) error {
return c.String(http.StatusOK, version)
})
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ login:
enabled: true
secretKey:
providers:
id1:
renku:
default: true
issuer: https://renkulab.io/auth/realms/Renku
clientID: renku
Expand Down
29 changes: 16 additions & 13 deletions internal/config/config_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ login:
tokenEncryption:
secretKey: secret-key-from-secret-file
providers:
id1:
renku:
clientSecret: client-secret-from-secret-file
cookieEncodingKey: enc-key-from-secret-file
cookieHashKey: hash-key-from-secret-file
Expand All @@ -30,25 +30,27 @@ func TestReadConfig(t *testing.T) {
t.Setenv("CONFIG_LOCATION", tmpDir)
err := createSecretFile(path.Join(tmpDir, "secret_config.yaml"))
require.NoError(t, err)
providerID := "renku"
ch := NewConfigHandler()
config, err := ch.Config()
require.NoError(t, err)
assert.NotEqual(t, config, Config{})
assert.Len(t, config.Login.Providers, 1)
assert.Equal(t, "https://renkulab.io", config.Revproxy.RenkuBaseURL.String())
assert.Equal(t, RedactedString("secret-key-from-secret-file"), config.Login.TokenEncryption.SecretKey)
assert.Equal(t, RedactedString("client-secret-from-secret-file"), config.Login.Providers["id1"].ClientSecret)
assert.Equal(t, RedactedString("enc-key-from-secret-file"), config.Login.Providers["id1"].CookieEncodingKey)
assert.Equal(t, RedactedString("hash-key-from-secret-file"), config.Login.Providers["id1"].CookieHashKey)
assert.Equal(t, true, config.Login.Providers["id1"].Default)
assert.Equal(t, RedactedString("client-secret-from-secret-file"), config.Login.Providers[providerID].ClientSecret)
assert.Equal(t, RedactedString("enc-key-from-secret-file"), config.Login.Providers[providerID].CookieEncodingKey)
assert.Equal(t, RedactedString("hash-key-from-secret-file"), config.Login.Providers[providerID].CookieHashKey)
assert.Equal(t, true, config.Login.Providers[providerID].Default)
}

func TestReadConfigWithEnvVars(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("CONFIG_LOCATION", tmpDir)
err := createSecretFile(path.Join(tmpDir, "secret_config.yaml"))
require.NoError(t, err)
t.Setenv("GATEWAY_LOGIN_PROVIDERS_ID1_CLIENTSECRET", "env-var-secret")
providerID := "renku"
t.Setenv("GATEWAY_LOGIN_PROVIDERS_RENKU_CLIENTSECRET", "env-var-secret")
t.Setenv("GATEWAY_REVPROXY_RENKUBASEURL", "https://dev.renku.ch")
t.Setenv("GATEWAY_LOGIN_TOKENENCRYPTION_SECRETKEY", "token-encryption-key-12345678910")
ch := NewConfigHandler()
Expand All @@ -57,25 +59,26 @@ func TestReadConfigWithEnvVars(t *testing.T) {
assert.NotEqual(t, config, Config{})
assert.Len(t, config.Login.Providers, 1)
assert.Equal(t, "https://dev.renku.ch", config.Revproxy.RenkuBaseURL.String())
assert.Equal(t, RedactedString("env-var-secret"), config.Login.Providers["id1"].ClientSecret)
assert.Equal(t, RedactedString("enc-key-from-secret-file"), config.Login.Providers["id1"].CookieEncodingKey)
assert.Equal(t, RedactedString("hash-key-from-secret-file"), config.Login.Providers["id1"].CookieHashKey)
assert.Equal(t, RedactedString("env-var-secret"), config.Login.Providers[providerID].ClientSecret)
assert.Equal(t, RedactedString("enc-key-from-secret-file"), config.Login.Providers[providerID].CookieEncodingKey)
assert.Equal(t, RedactedString("hash-key-from-secret-file"), config.Login.Providers[providerID].CookieHashKey)
assert.Equal(t, RedactedString("token-encryption-key-12345678910"), config.Login.TokenEncryption.SecretKey)
assert.Equal(t, true, config.Login.Providers["id1"].Default)
assert.Equal(t, true, config.Login.Providers[providerID].Default)
}

func TestReadConfigWithEnvVarsNoSecretFile(t *testing.T) {
t.Setenv("GATEWAY_LOGIN_PROVIDERS_ID1_CLIENTSECRET", "env-var-secret")
t.Setenv("GATEWAY_LOGIN_PROVIDERS_RENKU_CLIENTSECRET", "env-var-secret")
t.Setenv("GATEWAY_LOGIN_TOKENENCRYPTION_SECRETKEY", "token-encryption-key-12345678910")
providerID := "renku"
ch := NewConfigHandler()
config, err := ch.Config()
require.NoError(t, err)
slog.Info("configuration data", "config", config)
assert.NotEqual(t, config, Config{})
assert.Len(t, config.Login.Providers, 1)
assert.Equal(t, "https://renkulab.io", config.Revproxy.RenkuBaseURL.String())
assert.Equal(t, RedactedString("env-var-secret"), config.Login.Providers["id1"].ClientSecret)
assert.Equal(t, RedactedString("env-var-secret"), config.Login.Providers[providerID].ClientSecret)
assert.Equal(t, RedactedString("token-encryption-key-12345678910"), config.Login.TokenEncryption.SecretKey)
assert.Equal(t, true, config.Login.Providers["id1"].Default)
assert.Equal(t, true, config.Login.Providers[providerID].Default)
}

10 changes: 6 additions & 4 deletions internal/models/serializable_int.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package models

import "strconv"
import (
"encoding/json"
"strconv"
)

type SerializableInt int

func (s SerializableInt) MarshalBinary() (data []byte, err error) {
return s.MarshalText()
return json.Marshal(s)
}

func (s *SerializableInt) UnmarshalBinary(data []byte) error {
return s.UnmarshalText(data)
return json.Unmarshal(data, s)
}

func (s SerializableInt) MarshalText() (data []byte, err error) {
Expand All @@ -24,4 +27,3 @@ func (s *SerializableInt) UnmarshalText(data []byte) error {
*s = SerializableInt(val)
return nil
}

2 changes: 1 addition & 1 deletion internal/models/serializable_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestSerializableIntText(t *testing.T) {

func TestSerializableIntBinary(t *testing.T) {
var a SerializableInt = 10
data, err := a.MarshalBinary()
data, err := SerializableInt.MarshalBinary(a)
require.NoError(t, err)
var b SerializableInt
err = b.UnmarshalBinary(data)
Expand Down
2 changes: 1 addition & 1 deletion internal/models/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *Session) SaveTokens(ctx context.Context, accessToken OauthToken, refres
if state != "" {
_, found := s.ProviderIDs.Delete(state)
if !found {
return fmt.Errorf("could not find a matching state parameter in the session")
return fmt.Errorf("the session does not contain the state parameter that was provided")
}
}
if accessToken.ID != refreshToken.ID || accessToken.ID != idToken.ID {
Expand Down
11 changes: 5 additions & 6 deletions internal/models/session_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/redis/go-redis/v9"
)


type SessionHandlerOption func(*SessionHandler)

type SessionHandler struct {
Expand All @@ -23,7 +22,7 @@ type SessionHandler struct {
recreateSessionIfExpired bool
contextKey string
headerKey string
idQueryKey string
idQueryKey string
}

func (s *SessionHandler) Cookie(session *Session) *http.Cookie {
Expand All @@ -43,7 +42,7 @@ func (s *SessionHandler) Remove(c echo.Context) error {
if s.sessionStore == nil {
return fmt.Errorf("cannot remove a session when the session store is not defined")
}
sessionIDs := mapset.NewSet[string]()
sessionIDs := mapset.NewSet[string]()
sessionID := c.Request().Header.Get(s.headerKey)
// remove the request header if set
if sessionID != "" {
Expand Down Expand Up @@ -76,7 +75,7 @@ func (s *SessionHandler) Remove(c echo.Context) error {
err = nil
}
}
return err
return err
}

func (s *SessionHandler) Load(c echo.Context) (Session, error) {
Expand All @@ -97,7 +96,7 @@ func (s *SessionHandler) Load(c echo.Context) (Session, error) {
}
var sessionID string = ""
// the CLI will pass in the session ID as Basic Auth to access Gitlab, try to see if that is the case
basicAuthUser, basicAuthPwd, ok := c.Request().BasicAuth()
basicAuthUser, basicAuthPwd, ok := c.Request().BasicAuth()
if ok {
switch basicAuthUser {
case "renku":
Expand Down Expand Up @@ -152,7 +151,7 @@ func (s *SessionHandler) Create(c echo.Context) (Session, error) {
c.Set(s.contextKey, session)
c.SetCookie(s.Cookie(&session))
c.Request().AddCookie(s.Cookie(&session))
return session, nil
return session, nil
}

func (s *SessionHandler) Middleware() echo.MiddlewareFunc {
Expand Down
5 changes: 5 additions & 0 deletions internal/revproxy/proxies.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func proxyFromURL(url *url.URL) echo.MiddlewareFunc {
return middleware.ProxyWithConfig(mwconfig)
}

// registerCoreSvcProxies creates and registers all proxies for the core service. The core service is special
// because it runs multiple API versions of itself at the same time and the gateway has to route between them.
// In addition, and even more importantly, the core service requires sticky sessions between different pods of
// a deployment that runs the same version of the API. So we have to implement our own custom load balancer that
// can distinguish between different pods that sit behind a K8s service and consistently send requests to the same pod.
func registerCoreSvcProxies(ctx context.Context, e *echo.Echo, revproxyConfig *config.RevproxyConfig, mwFuncs ...echo.MiddlewareFunc) {
if len(revproxyConfig.RenkuServices.Core.ServicePaths) != len(revproxyConfig.RenkuServices.Core.ServiceNames) {
e.Logger.Fatalf("Failed proxy setup for core service, number of paths (%d) and services (%d) provided does not match", len(revproxyConfig.RenkuServices.Core.ServicePaths), len(revproxyConfig.RenkuServices.Core.ServiceNames))
Expand Down
2 changes: 2 additions & 0 deletions internal/stickysessions/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func NewEndpointStoreFromEndpointItems(input []EndpointStoreItem, includeNonRead
func NewEndpointStoreFromEndpointSlices(input []*discoveryV1.EndpointSlice, containerPortName string) *EndpointStore {
items := []EndpointStoreItem{}
for _, endpointSlice := range input {
// the loop variable is a pointer in Go, so as the loop progresses the same pointer is used
// and it points to different things every iteration. That is why we have to make a copy here.
es := endpointSlice
items = append(items, NewEndpointStoreItems(es, containerPortName)...)
}
Expand Down
35 changes: 27 additions & 8 deletions internal/tokenrefresher/tokenrefresher.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,23 @@ func refreshExpiringTokens(
slog.Error("GetExpiringAccessTokenIDs failed", "error", err)
return err
}

errorTokenIDs := []string{}
// For each token id expiring in the next minsToExpiration minutes
for _, expiringTokenID := range expiringTokenIDs {

// Get the refresh and access tokens associated with the token ID
myRefreshToken, err := tokenStore.GetRefreshToken(ctx, expiringTokenID)
if err != nil {
slog.Error("GetRefreshToken failed", "error", err)
return err
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}

myAccessToken, err := tokenStore.GetAccessToken(ctx, expiringTokenID)
if err != nil {
slog.Error("GetAccessToken failed", "error", err)
return err
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}

// Set the parameters required to refresh the tokens
Expand All @@ -109,18 +111,22 @@ func refreshExpiringTokens(

// Send the POST request to refresh the tokens
resp, err := http.PostForm(myAccessToken.TokenURL, params)
if resp != nil {
defer resp.Body.Close()
}
if err != nil {
slog.Error("Request Failed", "error", err)
return err
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}
defer resp.Body.Close()

// Decode JSON returned from the POST refresh request into a tokenResponse
token := tokenResponse{}
err = json.NewDecoder(resp.Body).Decode(&token)
if err != nil {
slog.Error("Decoding body failed", "error", err)
return err
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}

slog.Info("New token received")
Expand Down Expand Up @@ -151,20 +157,33 @@ func refreshExpiringTokens(
TokenURL: myAccessToken.TokenURL,
Type: myAccessToken.Type,
})
if err != nil {
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}

err = tokenStore.SetRefreshToken(ctx, models.OauthToken{
ID: myRefreshToken.ID,
Value: token.RefreshToken,
ExpiresAt: refreshTokenExpiration,
})
if err != nil {
errorTokenIDs = append(errorTokenIDs, expiringTokenID)
continue
}
}

slog.Info(
fmt.Sprintf(
"%v expiring access tokens refreshed, evaluating again in %v minutes",
"%v/%v expiring access tokens refreshed, evaluating again in %v minutes",
len(expiringTokenIDs) - len(errorTokenIDs),
len(expiringTokenIDs),
minsToExpiration,
),
)
return err

if len(errorTokenIDs) != 0 {
return fmt.Errorf("some token IDs could not be refreshed %v", errorTokenIDs)
}
return nil
}

0 comments on commit c280a57

Please sign in to comment.