Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions application/auth/auth_application.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"github.com/labbs/nexo/application/ports"
"github.com/labbs/nexo/domain"
"github.com/labbs/nexo/infrastructure/config"
"github.com/rs/zerolog"
)
Expand All @@ -13,6 +14,9 @@ type AuthApplication struct {
SessionApplication ports.SessionPort
SpaceApplication ports.SpacePort
DocumentApplication ports.DocumentPort
OAuthProviderPers domain.OAuthProviderPers

oidcUserinfoEndpoint string // cached from OIDC discovery
}

func NewAuthApplication(config config.Config, logger zerolog.Logger) *AuthApplication {
Expand Down
18 changes: 18 additions & 0 deletions application/auth/dto/sso.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dto

import "github.com/gofiber/fiber/v2"

type SSORedirectOutput struct {
URL string
State string
}

type SSOCallbackInput struct {
Code string
State string
Context *fiber.Ctx
}

type SSOCallbackOutput struct {
Token string
}
218 changes: 218 additions & 0 deletions application/auth/sso_callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package auth

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/labbs/nexo/application/auth/dto"
d "github.com/labbs/nexo/application/document/dto"
s "github.com/labbs/nexo/application/session/dto"
spdto "github.com/labbs/nexo/application/space/dto"
u "github.com/labbs/nexo/application/user/dto"
"github.com/labbs/nexo/domain"
"github.com/labbs/nexo/infrastructure/helpers/tokenutil"
"golang.org/x/oauth2"
)

type oidcUserInfo struct {
Sub string `json:"sub"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Name string `json:"name"`
}

func (c *AuthApplication) SSOCallback(input dto.SSOCallbackInput) (*dto.SSOCallbackOutput, error) {
logger := c.Logger.With().Str("component", "application.auth.sso_callback").Logger()

if !c.Config.SSO.Enabled {
return nil, fmt.Errorf("SSO is not enabled")
}

if err := c.verifyState(input.State); err != nil {
logger.Warn().Err(err).Msg("invalid SSO state")
return nil, fmt.Errorf("invalid state parameter")
}

oauthCfg := c.buildOAuthConfig()
token, err := oauthCfg.Exchange(context.Background(), input.Code)
if err != nil {
logger.Error().Err(err).Msg("failed to exchange OAuth code")
return nil, fmt.Errorf("failed to exchange authorization code: %w", err)
}

userInfo, err := c.fetchUserInfo(oauthCfg, token)
if err != nil {
logger.Error().Err(err).Msg("failed to fetch userinfo")
return nil, fmt.Errorf("failed to fetch user info: %w", err)
}

if userInfo.Sub == "" {
return nil, fmt.Errorf("provider did not return a user identifier")
}

user, err := c.findOrCreateSSOUser(userInfo)
if err != nil {
logger.Error().Err(err).Str("sub", userInfo.Sub).Msg("failed to find or create SSO user")
return nil, fmt.Errorf("failed to resolve user: %w", err)
}

sessionResult, err := c.SessionApplication.Create(s.CreateSessionInput{
UserId: user.Id,
UserAgent: input.Context.Get("User-Agent"),
IpAddress: input.Context.IP(),
ExpiresAt: time.Now().Add(time.Minute * time.Duration(c.Config.Session.ExpirationMinutes)),
})
if err != nil {
logger.Error().Err(err).Str("user_id", user.Id).Msg("failed to create session")
return nil, fmt.Errorf("failed to create session: %w", err)
}

accessToken, err := tokenutil.CreateAccessToken(user.Id, sessionResult.SessionId, c.Config)
if err != nil {
logger.Error().Err(err).Str("user_id", user.Id).Msg("failed to create access token")
return nil, fmt.Errorf("failed to create access token: %w", err)
}

return &dto.SSOCallbackOutput{Token: accessToken}, nil
}

// verifyState validates the HMAC-signed state parameter.
func (c *AuthApplication) verifyState(state string) error {
parts := strings.SplitN(state, ".", 2)
if len(parts) != 2 {
return fmt.Errorf("malformed state")
}
nonce, sig := parts[0], parts[1]
mac := hmac.New(sha256.New, []byte(c.Config.Session.SecretKey))
mac.Write([]byte(nonce))
expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
if !hmac.Equal([]byte(sig), []byte(expected)) {
return fmt.Errorf("state signature mismatch")
}
return nil
}

// fetchUserInfo calls the provider's userinfo endpoint using the access token.
func (c *AuthApplication) fetchUserInfo(oauthCfg *oauth2.Config, token *oauth2.Token) (*oidcUserInfo, error) {
// Prefer the endpoint discovered via OIDC; fall back to /userinfo.
endpoints := []string{strings.TrimRight(c.Config.SSO.IssuerURL, "/") + "/userinfo"}
if c.oidcUserinfoEndpoint != "" && c.oidcUserinfoEndpoint != endpoints[0] {
endpoints = append([]string{c.oidcUserinfoEndpoint}, endpoints...)
}

client := oauthCfg.Client(context.Background(), token)
var lastErr error
for _, url := range endpoints {
resp, err := client.Get(url)
if err != nil {
lastErr = err
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("userinfo endpoint returned %d", resp.StatusCode)
continue
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var info oidcUserInfo
if err := json.Unmarshal(body, &info); err != nil {
return nil, err
}
return &info, nil
}
return nil, lastErr
}

// findOrCreateSSOUser finds an existing user linked to the SSO provider, or creates a new one.
func (c *AuthApplication) findOrCreateSSOUser(info *oidcUserInfo) (domain.User, error) {
// 1. Check if provider link already exists
op, err := c.OAuthProviderPers.FindByProviderAndSubject("oidc", info.Sub)
if err == nil {
// Link exists — fetch the user
resp, err := c.UserApplication.GetByUserId(u.GetByUserIdInput{UserId: op.UserId})
if err != nil {
return domain.User{}, err
}
return *resp.User, nil
}

// 2. Try to link an existing user by email
var user domain.User
if info.Email != "" {
resp, err := c.UserApplication.GetByEmail(u.GetByEmailInput{Email: info.Email})
if err == nil {
user = *resp.User
}
}

// 3. No existing user — auto-create one
if user.Id == "" {
username := info.PreferredUsername
if username == "" {
username = strings.Split(info.Email, "@")[0]
}
if username == "" {
username = "user-" + info.Sub[:8]
}

created, err := c.UserApplication.Create(u.CreateUserInput{
User: domain.User{
Username: username,
Email: info.Email,
Password: "", // no password for SSO users
Active: true,
},
})
if err != nil {
return domain.User{}, fmt.Errorf("failed to create SSO user: %w", err)
}
user = *created.User

// Create private space + welcome document (mirrors Register use case)
space, err := c.SpaceApplication.CreatePrivateSpaceForUser(spdto.CreatePrivateSpaceForUserInput{UserId: user.Id})
if err == nil {
welcomeContent := []d.Block{{
ID: "welcome-1",
Type: d.BlockTypeParagraph,
Props: map[string]any{
"textColor": "default", "backgroundColor": "default", "textAlignment": "left",
},
Content: []d.InlineContent{{
Type: "text", Text: "This is your private space. Start adding your notes and documents here!",
Styles: map[string]bool{},
}},
Children: []d.Block{},
}}
_, _ = c.DocumentApplication.CreateDocument(d.CreateDocumentInput{
Name: "Welcome to Your Private Space",
UserId: user.Id,
SpaceId: space.Space.Id,
Content: welcomeContent,
})
}
}

// 4. Create the provider link
_, err = c.OAuthProviderPers.Create(domain.OAuthProvider{
UserId: user.Id,
Provider: "oidc",
ProviderUserId: info.Sub,
Email: info.Email,
})
if err != nil {
return domain.User{}, fmt.Errorf("failed to store SSO provider link: %w", err)
}

return user, nil
}
80 changes: 80 additions & 0 deletions application/auth/sso_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package auth

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"golang.org/x/oauth2"
)

type oidcDiscovery struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
}

// discoverOIDC fetches the provider's discovery document and caches the endpoints.
func (c *AuthApplication) discoverOIDC(ctx context.Context) (*oidcDiscovery, error) {
discoveryURL := strings.TrimRight(c.Config.SSO.IssuerURL, "/") + "/.well-known/openid-configuration"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("OIDC discovery request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OIDC discovery returned %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var d oidcDiscovery
if err := json.Unmarshal(body, &d); err != nil {
return nil, err
}
return &d, nil
}

func (c *AuthApplication) buildOAuthConfig() *oauth2.Config {
authURL := c.Config.SSO.AuthURL
tokenURL := c.Config.SSO.TokenURL

// If explicit URLs not provided, attempt discovery at startup.
// Errors here are non-fatal; the callback will fail gracefully.
if authURL == "" || tokenURL == "" {
if disc, err := c.discoverOIDC(context.Background()); err == nil {
if authURL == "" {
authURL = disc.AuthorizationEndpoint
}
if tokenURL == "" {
tokenURL = disc.TokenEndpoint
}
// Cache the userinfo endpoint for use in fetchUserInfo.
c.oidcUserinfoEndpoint = disc.UserinfoEndpoint
} else {
c.Logger.Warn().Err(err).Msg("OIDC discovery failed — check SSO config")
}
}

scopes := []string{"openid", "email", "profile"}
scopes = append(scopes, c.Config.SSO.Scopes...)

return &oauth2.Config{
ClientID: c.Config.SSO.ClientID,
ClientSecret: c.Config.SSO.ClientSecret,
RedirectURL: c.Config.SSO.RedirectURL,
Scopes: scopes,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
},
}
}
38 changes: 38 additions & 0 deletions application/auth/sso_redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package auth

import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"

"github.com/labbs/nexo/application/auth/dto"
"golang.org/x/oauth2"
)

func (c *AuthApplication) SSORedirect() (*dto.SSORedirectOutput, error) {
if !c.Config.SSO.Enabled {
return nil, fmt.Errorf("SSO is not enabled")
}

raw := make([]byte, 32)
if _, err := rand.Read(raw); err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}

// state = base64(nonce) + "." + base64(HMAC(nonce))
nonce := base64.RawURLEncoding.EncodeToString(raw)
mac := hmac.New(sha256.New, []byte(c.Config.Session.SecretKey))
mac.Write([]byte(nonce))
sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
state := nonce + "." + sig

oauthCfg := c.buildOAuthConfig()
url := oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOnline)

return &dto.SSORedirectOutput{
URL: url,
State: state,
}, nil
}
8 changes: 8 additions & 0 deletions application/ports/sso.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package ports

import "github.com/labbs/nexo/application/auth/dto"

type SSOPort interface {
GetRedirectURL() (*dto.SSORedirectOutput, error)
HandleCallback(input dto.SSOCallbackInput) (*dto.SSOCallbackOutput, error)
}
Loading
Loading