From 1f15c209c6800778b898d1fc8e16dfa0160f4790 Mon Sep 17 00:00:00 2001 From: Jeremy Mouton Date: Mon, 20 Apr 2026 02:03:36 +0200 Subject: [PATCH] feat: implement SSO functionality with OAuth2 integration, including redirect and callback handling --- application/auth/auth_application.go | 4 + application/auth/dto/sso.go | 18 ++ application/auth/sso_callback.go | 218 ++++++++++++++++++ application/auth/sso_config.go | 80 +++++++ application/auth/sso_redirect.go | 38 +++ application/ports/sso.go | 8 + config-example.yaml | 16 ++ domain/sso_provider.go | 24 ++ go.mod | 1 + go.sum | 2 + infrastructure/config/config.go | 15 ++ infrastructure/config/sso_flags.go | 76 ++++++ infrastructure/deps.go | 1 + .../files/20260420000000_sso_provider.go | 60 +++++ .../persistence/sso_provider_pers.go | 37 +++ interfaces/cli/server/server.go | 4 + interfaces/http/v1/auth/dtos/sso_dtos.go | 15 ++ interfaces/http/v1/auth/handlers.go | 32 +++ interfaces/http/v1/auth/router.go | 16 ++ interfaces/http/v1/router.go | 1 + .../http/v1/user/dtos/profile_request.go | 13 +- interfaces/http/v1/user/handlers.go | 11 + interfaces/http/v1/user/router.go | 2 + 23 files changed, 686 insertions(+), 6 deletions(-) create mode 100644 application/auth/dto/sso.go create mode 100644 application/auth/sso_callback.go create mode 100644 application/auth/sso_config.go create mode 100644 application/auth/sso_redirect.go create mode 100644 application/ports/sso.go create mode 100644 domain/sso_provider.go create mode 100644 infrastructure/config/sso_flags.go create mode 100644 infrastructure/migration/files/20260420000000_sso_provider.go create mode 100644 infrastructure/persistence/sso_provider_pers.go create mode 100644 interfaces/http/v1/auth/dtos/sso_dtos.go diff --git a/application/auth/auth_application.go b/application/auth/auth_application.go index 7c6285f..2007477 100644 --- a/application/auth/auth_application.go +++ b/application/auth/auth_application.go @@ -2,6 +2,7 @@ package auth import ( "github.com/labbs/nexo/application/ports" + "github.com/labbs/nexo/domain" "github.com/labbs/nexo/infrastructure/config" "github.com/rs/zerolog" ) @@ -13,6 +14,9 @@ type AuthApplication struct { SessionApplication ports.SessionPort SpaceApplication ports.SpacePort DocumentApplication ports.DocumentPort + OAuthProviderPers domain.OAuthProviderPers + + oidcUserinfoEndpoint string // cached from OIDC discovery } func NewAuthApplication(config config.Config, logger zerolog.Logger) *AuthApplication { diff --git a/application/auth/dto/sso.go b/application/auth/dto/sso.go new file mode 100644 index 0000000..70eb1c4 --- /dev/null +++ b/application/auth/dto/sso.go @@ -0,0 +1,18 @@ +package dto + +import "github.com/gofiber/fiber/v2" + +type SSORedirectOutput struct { + URL string + State string +} + +type SSOCallbackInput struct { + Code string + State string + Context *fiber.Ctx +} + +type SSOCallbackOutput struct { + Token string +} diff --git a/application/auth/sso_callback.go b/application/auth/sso_callback.go new file mode 100644 index 0000000..c6cf8c6 --- /dev/null +++ b/application/auth/sso_callback.go @@ -0,0 +1,218 @@ +package auth + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/labbs/nexo/application/auth/dto" + d "github.com/labbs/nexo/application/document/dto" + s "github.com/labbs/nexo/application/session/dto" + spdto "github.com/labbs/nexo/application/space/dto" + u "github.com/labbs/nexo/application/user/dto" + "github.com/labbs/nexo/domain" + "github.com/labbs/nexo/infrastructure/helpers/tokenutil" + "golang.org/x/oauth2" +) + +type oidcUserInfo struct { + Sub string `json:"sub"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Name string `json:"name"` +} + +func (c *AuthApplication) SSOCallback(input dto.SSOCallbackInput) (*dto.SSOCallbackOutput, error) { + logger := c.Logger.With().Str("component", "application.auth.sso_callback").Logger() + + if !c.Config.SSO.Enabled { + return nil, fmt.Errorf("SSO is not enabled") + } + + if err := c.verifyState(input.State); err != nil { + logger.Warn().Err(err).Msg("invalid SSO state") + return nil, fmt.Errorf("invalid state parameter") + } + + oauthCfg := c.buildOAuthConfig() + token, err := oauthCfg.Exchange(context.Background(), input.Code) + if err != nil { + logger.Error().Err(err).Msg("failed to exchange OAuth code") + return nil, fmt.Errorf("failed to exchange authorization code: %w", err) + } + + userInfo, err := c.fetchUserInfo(oauthCfg, token) + if err != nil { + logger.Error().Err(err).Msg("failed to fetch userinfo") + return nil, fmt.Errorf("failed to fetch user info: %w", err) + } + + if userInfo.Sub == "" { + return nil, fmt.Errorf("provider did not return a user identifier") + } + + user, err := c.findOrCreateSSOUser(userInfo) + if err != nil { + logger.Error().Err(err).Str("sub", userInfo.Sub).Msg("failed to find or create SSO user") + return nil, fmt.Errorf("failed to resolve user: %w", err) + } + + sessionResult, err := c.SessionApplication.Create(s.CreateSessionInput{ + UserId: user.Id, + UserAgent: input.Context.Get("User-Agent"), + IpAddress: input.Context.IP(), + ExpiresAt: time.Now().Add(time.Minute * time.Duration(c.Config.Session.ExpirationMinutes)), + }) + if err != nil { + logger.Error().Err(err).Str("user_id", user.Id).Msg("failed to create session") + return nil, fmt.Errorf("failed to create session: %w", err) + } + + accessToken, err := tokenutil.CreateAccessToken(user.Id, sessionResult.SessionId, c.Config) + if err != nil { + logger.Error().Err(err).Str("user_id", user.Id).Msg("failed to create access token") + return nil, fmt.Errorf("failed to create access token: %w", err) + } + + return &dto.SSOCallbackOutput{Token: accessToken}, nil +} + +// verifyState validates the HMAC-signed state parameter. +func (c *AuthApplication) verifyState(state string) error { + parts := strings.SplitN(state, ".", 2) + if len(parts) != 2 { + return fmt.Errorf("malformed state") + } + nonce, sig := parts[0], parts[1] + mac := hmac.New(sha256.New, []byte(c.Config.Session.SecretKey)) + mac.Write([]byte(nonce)) + expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + if !hmac.Equal([]byte(sig), []byte(expected)) { + return fmt.Errorf("state signature mismatch") + } + return nil +} + +// fetchUserInfo calls the provider's userinfo endpoint using the access token. +func (c *AuthApplication) fetchUserInfo(oauthCfg *oauth2.Config, token *oauth2.Token) (*oidcUserInfo, error) { + // Prefer the endpoint discovered via OIDC; fall back to /userinfo. + endpoints := []string{strings.TrimRight(c.Config.SSO.IssuerURL, "/") + "/userinfo"} + if c.oidcUserinfoEndpoint != "" && c.oidcUserinfoEndpoint != endpoints[0] { + endpoints = append([]string{c.oidcUserinfoEndpoint}, endpoints...) + } + + client := oauthCfg.Client(context.Background(), token) + var lastErr error + for _, url := range endpoints { + resp, err := client.Get(url) + if err != nil { + lastErr = err + continue + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("userinfo endpoint returned %d", resp.StatusCode) + continue + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var info oidcUserInfo + if err := json.Unmarshal(body, &info); err != nil { + return nil, err + } + return &info, nil + } + return nil, lastErr +} + +// findOrCreateSSOUser finds an existing user linked to the SSO provider, or creates a new one. +func (c *AuthApplication) findOrCreateSSOUser(info *oidcUserInfo) (domain.User, error) { + // 1. Check if provider link already exists + op, err := c.OAuthProviderPers.FindByProviderAndSubject("oidc", info.Sub) + if err == nil { + // Link exists — fetch the user + resp, err := c.UserApplication.GetByUserId(u.GetByUserIdInput{UserId: op.UserId}) + if err != nil { + return domain.User{}, err + } + return *resp.User, nil + } + + // 2. Try to link an existing user by email + var user domain.User + if info.Email != "" { + resp, err := c.UserApplication.GetByEmail(u.GetByEmailInput{Email: info.Email}) + if err == nil { + user = *resp.User + } + } + + // 3. No existing user — auto-create one + if user.Id == "" { + username := info.PreferredUsername + if username == "" { + username = strings.Split(info.Email, "@")[0] + } + if username == "" { + username = "user-" + info.Sub[:8] + } + + created, err := c.UserApplication.Create(u.CreateUserInput{ + User: domain.User{ + Username: username, + Email: info.Email, + Password: "", // no password for SSO users + Active: true, + }, + }) + if err != nil { + return domain.User{}, fmt.Errorf("failed to create SSO user: %w", err) + } + user = *created.User + + // Create private space + welcome document (mirrors Register use case) + space, err := c.SpaceApplication.CreatePrivateSpaceForUser(spdto.CreatePrivateSpaceForUserInput{UserId: user.Id}) + if err == nil { + welcomeContent := []d.Block{{ + ID: "welcome-1", + Type: d.BlockTypeParagraph, + Props: map[string]any{ + "textColor": "default", "backgroundColor": "default", "textAlignment": "left", + }, + Content: []d.InlineContent{{ + Type: "text", Text: "This is your private space. Start adding your notes and documents here!", + Styles: map[string]bool{}, + }}, + Children: []d.Block{}, + }} + _, _ = c.DocumentApplication.CreateDocument(d.CreateDocumentInput{ + Name: "Welcome to Your Private Space", + UserId: user.Id, + SpaceId: space.Space.Id, + Content: welcomeContent, + }) + } + } + + // 4. Create the provider link + _, err = c.OAuthProviderPers.Create(domain.OAuthProvider{ + UserId: user.Id, + Provider: "oidc", + ProviderUserId: info.Sub, + Email: info.Email, + }) + if err != nil { + return domain.User{}, fmt.Errorf("failed to store SSO provider link: %w", err) + } + + return user, nil +} diff --git a/application/auth/sso_config.go b/application/auth/sso_config.go new file mode 100644 index 0000000..acda2a2 --- /dev/null +++ b/application/auth/sso_config.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "golang.org/x/oauth2" +) + +type oidcDiscovery struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` +} + +// discoverOIDC fetches the provider's discovery document and caches the endpoints. +func (c *AuthApplication) discoverOIDC(ctx context.Context) (*oidcDiscovery, error) { + discoveryURL := strings.TrimRight(c.Config.SSO.IssuerURL, "/") + "/.well-known/openid-configuration" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil) + if err != nil { + return nil, err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("OIDC discovery request failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OIDC discovery returned %d", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var d oidcDiscovery + if err := json.Unmarshal(body, &d); err != nil { + return nil, err + } + return &d, nil +} + +func (c *AuthApplication) buildOAuthConfig() *oauth2.Config { + authURL := c.Config.SSO.AuthURL + tokenURL := c.Config.SSO.TokenURL + + // If explicit URLs not provided, attempt discovery at startup. + // Errors here are non-fatal; the callback will fail gracefully. + if authURL == "" || tokenURL == "" { + if disc, err := c.discoverOIDC(context.Background()); err == nil { + if authURL == "" { + authURL = disc.AuthorizationEndpoint + } + if tokenURL == "" { + tokenURL = disc.TokenEndpoint + } + // Cache the userinfo endpoint for use in fetchUserInfo. + c.oidcUserinfoEndpoint = disc.UserinfoEndpoint + } else { + c.Logger.Warn().Err(err).Msg("OIDC discovery failed — check SSO config") + } + } + + scopes := []string{"openid", "email", "profile"} + scopes = append(scopes, c.Config.SSO.Scopes...) + + return &oauth2.Config{ + ClientID: c.Config.SSO.ClientID, + ClientSecret: c.Config.SSO.ClientSecret, + RedirectURL: c.Config.SSO.RedirectURL, + Scopes: scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + } +} diff --git a/application/auth/sso_redirect.go b/application/auth/sso_redirect.go new file mode 100644 index 0000000..4dd97ec --- /dev/null +++ b/application/auth/sso_redirect.go @@ -0,0 +1,38 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + + "github.com/labbs/nexo/application/auth/dto" + "golang.org/x/oauth2" +) + +func (c *AuthApplication) SSORedirect() (*dto.SSORedirectOutput, error) { + if !c.Config.SSO.Enabled { + return nil, fmt.Errorf("SSO is not enabled") + } + + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // state = base64(nonce) + "." + base64(HMAC(nonce)) + nonce := base64.RawURLEncoding.EncodeToString(raw) + mac := hmac.New(sha256.New, []byte(c.Config.Session.SecretKey)) + mac.Write([]byte(nonce)) + sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + state := nonce + "." + sig + + oauthCfg := c.buildOAuthConfig() + url := oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOnline) + + return &dto.SSORedirectOutput{ + URL: url, + State: state, + }, nil +} diff --git a/application/ports/sso.go b/application/ports/sso.go new file mode 100644 index 0000000..944f9e9 --- /dev/null +++ b/application/ports/sso.go @@ -0,0 +1,8 @@ +package ports + +import "github.com/labbs/nexo/application/auth/dto" + +type SSOPort interface { + GetRedirectURL() (*dto.SSORedirectOutput, error) + HandleCallback(input dto.SSOCallbackInput) (*dto.SSOCallbackOutput, error) +} diff --git a/config-example.yaml b/config-example.yaml index 60f9900..b4181e7 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -23,3 +23,19 @@ session: secret_key: "CHANGE_ME_USE_openssl_rand_base64_48" expiration_minutes: 43200 # 30 days issuer: nexo + +sso: + enabled: false + # client_id and client_secret from your OIDC provider + client_id: "" + client_secret: "" + # issuer_url: base URL of the OIDC provider (must expose /.well-known/openid-configuration) + # Examples: + # Keycloak: https://keycloak.example.com/realms/myrealm + # Google: https://accounts.google.com + # Okta: https://dev-xxxxx.okta.com + issuer_url: "" + # redirect_url: must match what is registered in the provider + redirect_url: "http://localhost:5173/auth/callback" + # scopes: additional scopes beyond openid (email and profile are always requested) + scopes: [] diff --git a/domain/sso_provider.go b/domain/sso_provider.go new file mode 100644 index 0000000..ad8a6b9 --- /dev/null +++ b/domain/sso_provider.go @@ -0,0 +1,24 @@ +package domain + +import "time" + +// OAuthProvider stores the link between a user and an external SSO identity. +type OAuthProvider struct { + Id string + UserId string + Provider string // e.g. "oidc" + ProviderUserId string // subject (sub) from the provider + Email string + CreatedAt time.Time + UpdatedAt time.Time +} + +func (o *OAuthProvider) TableName() string { + return "oauth_provider" +} + +type OAuthProviderPers interface { + FindByProviderAndSubject(provider, subject string) (OAuthProvider, error) + FindByUserId(userId string) ([]OAuthProvider, error) + Create(op OAuthProvider) (OAuthProvider, error) +} diff --git a/go.mod b/go.mod index de1def1..ef1b56a 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect diff --git a/go.sum b/go.sum index 3a942bd..00080b0 100644 --- a/go.sum +++ b/go.sum @@ -187,6 +187,8 @@ golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAf golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= diff --git a/infrastructure/config/config.go b/infrastructure/config/config.go index 81a0ca5..d01b283 100644 --- a/infrastructure/config/config.go +++ b/infrastructure/config/config.go @@ -50,6 +50,21 @@ type Config struct { PasswordComplexity bool // Require complex passwords (uppercase, lowercase, numbers, symbols) } + SSO struct { + Enabled bool + ClientID string + ClientSecret string + // IssuerURL is the base URL of the OIDC provider. Used to build the userinfo URL + // and as a fallback when AuthURL/TokenURL are not set. + IssuerURL string + // AuthURL and TokenURL can be set explicitly to override OIDC discovery. + // If left empty, they are auto-discovered from IssuerURL + /.well-known/openid-configuration. + AuthURL string + TokenURL string + RedirectURL string + Scopes []string + } + ExportOapi struct { FileName string } diff --git a/infrastructure/config/sso_flags.go b/infrastructure/config/sso_flags.go new file mode 100644 index 0000000..a181050 --- /dev/null +++ b/infrastructure/config/sso_flags.go @@ -0,0 +1,76 @@ +package config + +import ( + altsrc "github.com/urfave/cli-altsrc/v3" + altsrcyaml "github.com/urfave/cli-altsrc/v3/yaml" + "github.com/urfave/cli/v3" +) + +func SSOFlags(cfg *Config) []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Name: "sso.enabled", + Destination: &cfg.SSO.Enabled, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_ENABLED"), + altsrcyaml.YAML("sso.enabled", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.client_id", + Destination: &cfg.SSO.ClientID, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_CLIENT_ID"), + altsrcyaml.YAML("sso.client_id", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.client_secret", + Destination: &cfg.SSO.ClientSecret, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_CLIENT_SECRET"), + altsrcyaml.YAML("sso.client_secret", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.issuer_url", + Destination: &cfg.SSO.IssuerURL, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_ISSUER_URL"), + altsrcyaml.YAML("sso.issuer_url", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.auth_url", + Destination: &cfg.SSO.AuthURL, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_AUTH_URL"), + altsrcyaml.YAML("sso.auth_url", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.token_url", + Destination: &cfg.SSO.TokenURL, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_TOKEN_URL"), + altsrcyaml.YAML("sso.token_url", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringFlag{ + Name: "sso.redirect_url", + Destination: &cfg.SSO.RedirectURL, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_REDIRECT_URL"), + altsrcyaml.YAML("sso.redirect_url", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + &cli.StringSliceFlag{ + Name: "sso.scopes", + Destination: &cfg.SSO.Scopes, + Sources: cli.NewValueSourceChain( + cli.EnvVar("SSO_SCOPES"), + altsrcyaml.YAML("sso.scopes", altsrc.NewStringPtrSourcer(&cfg.ConfigFile)), + ), + }, + } +} diff --git a/infrastructure/deps.go b/infrastructure/deps.go index 905fa00..207d99d 100644 --- a/infrastructure/deps.go +++ b/infrastructure/deps.go @@ -44,6 +44,7 @@ type Deps struct { FavoriteApplication *favorite.FavoriteApplication PermissionApplication *permission.PermissionApplication PermissionPers domain.PermissionPers + OAuthProviderPers domain.OAuthProviderPers CollaborationHub *collaboration.Hub } diff --git a/infrastructure/migration/files/20260420000000_sso_provider.go b/infrastructure/migration/files/20260420000000_sso_provider.go new file mode 100644 index 0000000..cdd3486 --- /dev/null +++ b/infrastructure/migration/files/20260420000000_sso_provider.go @@ -0,0 +1,60 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(upOAuthProvider, downOAuthProvider) +} + +func upOAuthProvider(ctx context.Context, tx *sql.Tx) error { + var query string + dialect, _ := ctx.Value("dbDialect").(string) + switch dialect { + case "sqlite": + query = ` + CREATE TABLE IF NOT EXISTS oauth_provider ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + email TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES user(id) ON DELETE CASCADE, + UNIQUE(provider, provider_user_id) + ); + CREATE INDEX IF NOT EXISTS idx_oauth_provider_user_id ON oauth_provider(user_id); + ` + case "postgres": + query = ` + CREATE TABLE IF NOT EXISTS oauth_provider ( + id UUID PRIMARY KEY, + user_id UUID NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + email TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE(provider, provider_user_id) + ); + CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_oauth_provider_user_id ON oauth_provider(user_id); + ` + default: + return fmt.Errorf("unsupported dialect: %s", dialect) + } + + _, err := tx.ExecContext(ctx, query) + return err +} + +func downOAuthProvider(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `DROP TABLE IF EXISTS oauth_provider;`) + return err +} diff --git a/infrastructure/persistence/sso_provider_pers.go b/infrastructure/persistence/sso_provider_pers.go new file mode 100644 index 0000000..e7c0fc0 --- /dev/null +++ b/infrastructure/persistence/sso_provider_pers.go @@ -0,0 +1,37 @@ +package persistence + +import ( + "errors" + + "github.com/labbs/nexo/domain" + "github.com/labbs/nexo/infrastructure/helpers/apperrors" + "gorm.io/gorm" +) + +type oauthProviderPers struct { + db *gorm.DB +} + +func NewOAuthProviderPers(db *gorm.DB) *oauthProviderPers { + return &oauthProviderPers{db: db} +} + +func (o *oauthProviderPers) FindByProviderAndSubject(provider, subject string) (domain.OAuthProvider, error) { + var op domain.OAuthProvider + err := o.db.Where("provider = ? AND provider_user_id = ?", provider, subject).First(&op).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return op, apperrors.ErrNotFound + } + return op, err +} + +func (o *oauthProviderPers) FindByUserId(userId string) ([]domain.OAuthProvider, error) { + var ops []domain.OAuthProvider + err := o.db.Where("user_id = ?", userId).Find(&ops).Error + return ops, err +} + +func (o *oauthProviderPers) Create(op domain.OAuthProvider) (domain.OAuthProvider, error) { + err := o.db.Create(&op).Error + return op, err +} diff --git a/interfaces/cli/server/server.go b/interfaces/cli/server/server.go index bb6493e..e56d1c9 100644 --- a/interfaces/cli/server/server.go +++ b/interfaces/cli/server/server.go @@ -57,6 +57,7 @@ func getFlags(cfg *config.Config) (list []cli.Flag) { list = append(list, config.DatabaseFlags(cfg)...) list = append(list, config.SessionFlags(cfg)...) list = append(list, config.RegistrationFlags(cfg)...) + list = append(list, config.SSOFlags(cfg)...) return } @@ -95,6 +96,7 @@ func runServer(cfg config.Config) error { // Initialize application services userPers := persistence.NewUserPers(deps.Database.Db) + oauthProviderPers := persistence.NewOAuthProviderPers(deps.Database.Db) groupPers := persistence.NewGroupPers(deps.Database.Db) sessionPers := persistence.NewSessionPers(deps.Database.Db) spacePers := persistence.NewSpacePers(deps.Database.Db) @@ -127,12 +129,14 @@ func runServer(cfg config.Config) error { deps.FavoriteApplication = favorite.NewFavoriteApplication(deps.Config, deps.Logger, favoritePers) deps.PermissionApplication = permission.NewPermissionApplication(deps.Config, deps.Logger, permissionPers) deps.PermissionPers = permissionPers + deps.OAuthProviderPers = oauthProviderPers // Inject port dependencies (after construction to avoid circular dependencies) deps.AuthApplication.UserApplication = deps.UserApplication deps.AuthApplication.SessionApplication = deps.SessionApplication deps.AuthApplication.SpaceApplication = deps.SpaceApplication deps.AuthApplication.DocumentApplication = deps.DocumentApplication + deps.AuthApplication.OAuthProviderPers = oauthProviderPers deps.UserApplication.GroupApplication = deps.GroupApplication deps.FavoriteApplication.DocumentApplication = deps.DocumentApplication deps.SpaceApplication.DocumentApplication = deps.DocumentApplication diff --git a/interfaces/http/v1/auth/dtos/sso_dtos.go b/interfaces/http/v1/auth/dtos/sso_dtos.go new file mode 100644 index 0000000..76375b6 --- /dev/null +++ b/interfaces/http/v1/auth/dtos/sso_dtos.go @@ -0,0 +1,15 @@ +package dtos + +type SSORedirectResponse struct { + URL string `json:"url"` + State string `json:"state"` +} + +type SSOCallbackRequest struct { + Code string `json:"code" validate:"required"` + State string `json:"state" validate:"required"` +} + +type SSOCallbackResponse struct { + Token string `json:"token"` +} diff --git a/interfaces/http/v1/auth/handlers.go b/interfaces/http/v1/auth/handlers.go index 6096dbf..0b4f1d2 100644 --- a/interfaces/http/v1/auth/handlers.go +++ b/interfaces/http/v1/auth/handlers.go @@ -90,4 +90,36 @@ func (ctrl Controller) Register(ctx *fiber.Ctx, req dtos.RegisterRequest) (*dtos }, nil } +func (ctrl Controller) SSORedirect(ctx *fiber.Ctx, input struct{}) (*dtos.SSORedirectResponse, *fiberoapi.ErrorResponse) { + out, err := ctrl.AuthApplication.SSORedirect() + if err != nil { + return nil, &fiberoapi.ErrorResponse{ + Code: fiber.StatusBadRequest, + Details: err.Error(), + Type: "SSO_DISABLED", + } + } + return &dtos.SSORedirectResponse{URL: out.URL, State: out.State}, nil +} + +func (ctrl Controller) SSOCallback(ctx *fiber.Ctx, req dtos.SSOCallbackRequest) (*dtos.SSOCallbackResponse, *fiberoapi.ErrorResponse) { + requestId := ctx.Locals("requestid").(string) + logger := ctrl.Logger.With().Str("request_id", requestId).Str("component", "http.api.v1.auth.sso_callback").Logger() + + out, err := ctrl.AuthApplication.SSOCallback(authDto.SSOCallbackInput{ + Code: req.Code, + State: req.State, + Context: ctx, + }) + if err != nil { + logger.Error().Err(err).Msg("SSO callback failed") + return nil, &fiberoapi.ErrorResponse{ + Code: fiber.StatusUnauthorized, + Details: err.Error(), + Type: "SSO_CALLBACK_FAILED", + } + } + return &dtos.SSOCallbackResponse{Token: out.Token}, nil +} + //TODO: implement password reset, email verification, ... diff --git a/interfaces/http/v1/auth/router.go b/interfaces/http/v1/auth/router.go index c29e269..0f6b0b3 100644 --- a/interfaces/http/v1/auth/router.go +++ b/interfaces/http/v1/auth/router.go @@ -37,4 +37,20 @@ func SetupAuthRouter(controller Controller) { Tags: []string{"Auth"}, Security: "disabled", }) + + fiberoapi.Get(controller.FiberOapi, "/sso/redirect", controller.SSORedirect, fiberoapi.OpenAPIOptions{ + Summary: "SSO redirect URL", + Description: "Returns the provider authorization URL for SSO login", + OperationID: "auth.sso.redirect", + Tags: []string{"Auth"}, + Security: "disabled", + }) + + fiberoapi.Post(controller.FiberOapi, "/sso/callback", controller.SSOCallback, fiberoapi.OpenAPIOptions{ + Summary: "SSO callback", + Description: "Exchange OAuth2 code for a Nexo session token", + OperationID: "auth.sso.callback", + Tags: []string{"Auth"}, + Security: "disabled", + }) } diff --git a/interfaces/http/v1/router.go b/interfaces/http/v1/router.go index c353170..b23250c 100644 --- a/interfaces/http/v1/router.go +++ b/interfaces/http/v1/router.go @@ -34,6 +34,7 @@ func SetupRouterV1(deps infrastructure.Deps) { UserApplication: deps.UserApplication, SpaceApplication: deps.SpaceApplication, FavoriteApplication: deps.FavoriteApplication, + OAuthProviderPers: deps.OAuthProviderPers, } user.SetupUserRouter(userCtrl) diff --git a/interfaces/http/v1/user/dtos/profile_request.go b/interfaces/http/v1/user/dtos/profile_request.go index 0076f22..2464339 100644 --- a/interfaces/http/v1/user/dtos/profile_request.go +++ b/interfaces/http/v1/user/dtos/profile_request.go @@ -1,10 +1,11 @@ package dtos type ProfileResponse struct { - Id string `json:"id"` - Username string `json:"username"` - Email string `json:"email"` - Avatar string `json:"avatar"` - Role string `json:"role"` - Preferences map[string]any `json:"preferences,omitempty"` + Id string `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + Avatar string `json:"avatar"` + Role string `json:"role"` + Preferences map[string]any `json:"preferences,omitempty"` + SsoProviders []string `json:"sso_providers,omitempty"` } diff --git a/interfaces/http/v1/user/handlers.go b/interfaces/http/v1/user/handlers.go index 273ad63..2fb7df1 100644 --- a/interfaces/http/v1/user/handlers.go +++ b/interfaces/http/v1/user/handlers.go @@ -60,6 +60,17 @@ func (ctrl *Controller) GetProfile(ctx *fiber.Ctx, input struct{}) (*dtos.Profil profile.Preferences = map[string]any(result.User.Preferences) } + // Add linked SSO providers if the repository is wired + if ctrl.OAuthProviderPers != nil { + if linked, err := ctrl.OAuthProviderPers.FindByUserId(authCtx.UserID); err == nil && len(linked) > 0 { + providers := make([]string, len(linked)) + for i, op := range linked { + providers[i] = op.Provider + } + profile.SsoProviders = providers + } + } + return &profile, nil } diff --git a/interfaces/http/v1/user/router.go b/interfaces/http/v1/user/router.go index d9f3f5e..de8022a 100644 --- a/interfaces/http/v1/user/router.go +++ b/interfaces/http/v1/user/router.go @@ -5,6 +5,7 @@ import ( "github.com/labbs/nexo/application/favorite" "github.com/labbs/nexo/application/space" "github.com/labbs/nexo/application/user" + "github.com/labbs/nexo/domain" "github.com/labbs/nexo/infrastructure/config" "github.com/rs/zerolog" ) @@ -16,6 +17,7 @@ type Controller struct { UserApplication *user.UserApplication FavoriteApplication *favorite.FavoriteApplication SpaceApplication *space.SpaceApplication + OAuthProviderPers domain.OAuthProviderPers } func SetupUserRouter(controller Controller) {