Skip to content

Commit

Permalink
Merge b57db95 into fae116f
Browse files Browse the repository at this point in the history
  • Loading branch information
rgooch committed Jan 6, 2020
2 parents fae116f + b57db95 commit 67596f2
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 44 deletions.
4 changes: 3 additions & 1 deletion cmd/keymasterd/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import (
"sync"
"time"

"github.com/Cloud-Foundations/Dominator/lib/log"
"github.com/Cloud-Foundations/Dominator/lib/log/serverlogger"
"github.com/Cloud-Foundations/Dominator/lib/logbuf"
"github.com/Cloud-Foundations/Dominator/lib/srpc"
"github.com/Cloud-Foundations/golib/pkg/auth/userinfo/gitdb"
"github.com/Cloud-Foundations/golib/pkg/log"
"github.com/Cloud-Foundations/keymaster/keymasterd/admincache"
"github.com/Cloud-Foundations/keymaster/keymasterd/eventnotifier"
"github.com/Cloud-Foundations/keymaster/lib/authutil"
Expand Down Expand Up @@ -155,6 +156,7 @@ type RuntimeState struct {
SignerIsReady chan bool
oktaUsernameFilterRE *regexp.Regexp
Mutex sync.Mutex
gitDB *gitdb.UserInfo
pendingOauth2 map[string]pendingAuth2Request
storageRWMutex sync.RWMutex
db *sql.DB
Expand Down
49 changes: 41 additions & 8 deletions cmd/keymasterd/certgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ import (

const certgenPath = "/certgen/"

func prependGroups(groups []string, prefix string) []string {
if prefix == "" {
return groups
}
newGroups := make([]string, 0, len(groups))
for _, group := range groups {
newGroups = append(newGroups, prefix+group)
}
return newGroups
}

func (state *RuntimeState) certGenHandler(w http.ResponseWriter, r *http.Request) {
var signerIsNull bool
var keySigner crypto.Signer
Expand Down Expand Up @@ -216,11 +227,28 @@ func (state *RuntimeState) postAuthSSHCertHandler(
}(targetUser, "ssh")
}

func (state *RuntimeState) getUserGroups(username string) ([]string, error) {
func (state *RuntimeState) getGitDbUserGroups(username string) (
bool, []string, error) {
if state.gitDB == nil {
return false, nil, nil
}
groups, err := state.gitDB.GetUserGroups(username)
if err != nil {
return true, nil, err
}
return true,
prependGroups(groups, state.Config.UserInfo.GitDB.GroupPrepend),
nil
}

func (state *RuntimeState) getLdapUserGroups(username string) (
bool, []string, error) {
ldapConfig := state.Config.UserInfo.Ldap
var timeoutSecs uint
timeoutSecs = 2
//for _, ldapUrl := range ldapConfig.LDAPTargetURLs {
if ldapConfig.LDAPTargetURLs == "" {
return false, nil, nil
}
for _, ldapUrl := range strings.Split(ldapConfig.LDAPTargetURLs, ",") {
if len(ldapUrl) < 1 {
continue
Expand All @@ -238,15 +266,20 @@ func (state *RuntimeState) getUserGroups(username string) ([]string, error) {
if err != nil {
continue
}
return groups, nil
return true, prependGroups(groups, ldapConfig.GroupPrepend), nil

}
if ldapConfig.LDAPTargetURLs == "" {
var emptyGroup []string
return emptyGroup, nil
return true, nil, errors.New("error getting the groups")
}

func (state *RuntimeState) getUserGroups(username string) ([]string, error) {
if config, groups, err := state.getLdapUserGroups(username); config {
return groups, err
}
if config, groups, err := state.getGitDbUserGroups(username); config {
return groups, err
}
err := errors.New("error getting the groups")
return nil, err
return nil, nil
}

func (state *RuntimeState) postAuthX509CertHandler(
Expand Down
24 changes: 23 additions & 1 deletion cmd/keymasterd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"
"time"

"github.com/Cloud-Foundations/golib/pkg/auth/userinfo/gitdb"
"github.com/Cloud-Foundations/keymaster/keymasterd/admincache"
"github.com/Cloud-Foundations/keymaster/lib/pwauth/command"
"github.com/Cloud-Foundations/keymaster/lib/pwauth/ldap"
Expand Down Expand Up @@ -60,6 +61,14 @@ type baseConfig struct {
EnableLocalTOTP bool `yaml:"enable_local_totp"`
}

type GitDatabaseConfig struct {
Branch string `yaml:"branch"`
CheckInterval time.Duration `yaml:"check_interval"`
GroupPrepend string `yaml:"group_prepend"`
LocalRepositoryDirectory string `yaml:"local_repository_directory"`
RepositoryURL string `yaml:"repository_url"`
}

type LdapConfig struct {
BindPattern string `yaml:"bind_pattern"`
LDAPTargetURLs string `yaml:"ldap_target_urls"`
Expand All @@ -74,6 +83,7 @@ type OktaConfig struct {
type UserInfoLDAPSource struct {
BindUsername string `yaml:"bind_username"`
BindPassword string `yaml:"bind_password"`
GroupPrepend string `yaml:"group_prepend"`
LDAPTargetURLs string `yaml:"ldap_target_urls"`
UserSearchBaseDNs []string `yaml:"user_search_base_dns"`
UserSearchFilter string `yaml:"user_search_filter"`
Expand All @@ -82,7 +92,8 @@ type UserInfoLDAPSource struct {
}

type UserInfoSouces struct {
Ldap UserInfoLDAPSource
GitDB GitDatabaseConfig
Ldap UserInfoLDAPSource
}

type Oauth2Config struct {
Expand Down Expand Up @@ -424,6 +435,17 @@ func loadVerifyConfigFile(configFilename string) (*RuntimeState, error) {

logger.Debugf(1, "End of config initialization: %+v", &runtimeState)

// UserInfo setup.
if runtimeState.Config.UserInfo.GitDB.LocalRepositoryDirectory != "" {
gitdbConfig := runtimeState.Config.UserInfo.GitDB
runtimeState.gitDB, err = gitdb.New(gitdbConfig.RepositoryURL,
gitdbConfig.Branch, gitdbConfig.LocalRepositoryDirectory,
gitdbConfig.CheckInterval, logger)
if err != nil {
return nil, err
}
logger.Println("loaded UserInfo GitDB")
}
// DB initialization
err = initDB(&runtimeState)
if err != nil {
Expand Down
65 changes: 31 additions & 34 deletions cmd/keymasterd/idp_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@ package main

import (
"bytes"
//"crypto"
//"crypto/sha256"
"encoding/json"
"errors"
"fmt"
//"io/ioutil"
"log"
"net/http"
"net/url"
"regexp"
"strings"
"time"

//"golang.org/x/net/context"
"github.com/mendsley/gojwk"
//"gopkg.in/dgrijalva/jwt-go.v2"
"github.com/Cloud-Foundations/keymaster/lib/authutil"
"github.com/Cloud-Foundations/keymaster/lib/instrumentedwriter"
//"golang.org/x/crypto/ssh"
"github.com/mendsley/gojwk"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
Expand Down Expand Up @@ -451,11 +445,20 @@ func (state *RuntimeState) idpOpenIDCTokenHandler(w http.ResponseWriter, r *http

}

func (state *RuntimeState) getUserAttributes(username string, attributes []string) (map[string][]string, error) {
func (state *RuntimeState) getUserAttributes(username string,
attributes []string) (map[string][]string, error) {
if state.gitDB != nil {
groups, err := state.gitDB.GetUserGroups(username)
if err != nil {
return nil, err
}
return map[string][]string{"groups": prependGroups(
groups, state.Config.UserInfo.GitDB.GroupPrepend)},
nil
}
ldapConfig := state.Config.UserInfo.Ldap
var timeoutSecs uint
timeoutSecs = 2
//for _, ldapUrl := range ldapConfig.LDAPTargetURLs {
for _, ldapUrl := range strings.Split(ldapConfig.LDAPTargetURLs, ",") {
if len(ldapUrl) < 1 {
continue
Expand Down Expand Up @@ -505,15 +508,15 @@ type openidConnectUserInfo struct {
Groups []string `json:"groups,omitempty"`
}

func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter, r *http.Request) {

func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter,
r *http.Request) {
if !(r.Method == "GET" || r.Method == "POST") {
logger.Printf("Invalid Method for Userinfo Handler")
state.writeFailureResponse(w, r, http.StatusBadRequest, "Invalid Method for Userinfo Handler")
state.writeFailureResponse(w, r, http.StatusBadRequest,
"Invalid Method for Userinfo Handler")
return
}
logger.Debugf(2, "userinfo request=%+v", r)

var accessToken string
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
Expand All @@ -526,7 +529,6 @@ func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter, r *h
}
}
if accessToken == "" {
//logger.Printf("")
err := r.ParseForm()
if err != nil {
state.writeFailureResponse(w, r, http.StatusInternalServerError, "")
Expand All @@ -535,51 +537,49 @@ func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter, r *h
accessToken = r.Form.Get("access_token")
}
logger.Debugf(1, "access_token='%s'", accessToken)

if accessToken == "" {
logger.Printf("access_token='%s'", accessToken)
state.writeFailureResponse(w, r, http.StatusBadRequest, "Missing access token")
state.writeFailureResponse(w, r, http.StatusBadRequest,
"Missing access token")
return
}

tok, err := jwt.ParseSigned(accessToken)
if err != nil {
logger.Printf("err=%s", err)
state.writeFailureResponse(w, r, http.StatusBadRequest, "bad access token")
state.writeFailureResponse(w, r, http.StatusBadRequest,
"bad access token")
return
}
logger.Debugf(1, "tok=%+v", tok)

parsedAccessToken := userInfoToken{}
//if err := tok.Claims(state.Signer.Public(), &parsedAccessToken); err != nil {
if err := state.JWTClaims(tok, &parsedAccessToken); err != nil {
logger.Printf("err=%s", err)
state.writeFailureResponse(w, r, http.StatusBadRequest, "bad code")
return
}
logger.Debugf(1, "out=%+v", parsedAccessToken)

//now we check for validity
// Now we check for validity.
if parsedAccessToken.Expiration < time.Now().Unix() {
logger.Printf("expired token attempted to be used for bearer")
state.writeFailureResponse(w, r, http.StatusUnauthorized, "")
return
}
//now we check for validity
if parsedAccessToken.Type != "bearer" {
state.writeFailureResponse(w, r, http.StatusUnauthorized, "")
return
}

//Get email from ldap if available
// Get email from LDAP if available.
defaultEmailDomain := state.HostIdentity
if len(state.Config.OpenIDConnectIDP.DefaultEmailDomain) > 3 {
defaultEmailDomain = state.Config.OpenIDConnectIDP.DefaultEmailDomain
}
email := fmt.Sprintf("%s@%s", parsedAccessToken.Username, defaultEmailDomain)
userAttributeMap, err := state.getUserAttributes(parsedAccessToken.Username, []string{"mail"})
email := fmt.Sprintf("%s@%s", parsedAccessToken.Username,
defaultEmailDomain)
userAttributeMap, err := state.getUserAttributes(parsedAccessToken.Username,
[]string{"mail"})
if err != nil {
logger.Printf("warn: failed to get user attributes for %s, %s", parsedAccessToken.Username, err)
logger.Printf("warn: failed to get user attributes for %s, %s",
parsedAccessToken.Username, err)
}
var userGroups []string
if userAttributeMap != nil {
Expand All @@ -593,7 +593,6 @@ func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter, r *h
userGroups = groupList
}
}

userInfo := openidConnectUserInfo{
Subject: parsedAccessToken.Username,
Username: parsedAccessToken.Username,
Expand All @@ -602,21 +601,19 @@ func (state *RuntimeState) idpOpenIDCUserinfoHandler(w http.ResponseWriter, r *h
Login: parsedAccessToken.Username,
Groups: userGroups,
}

// and write the json output
// Write the json output.
b, err := json.Marshal(userInfo)
if err != nil {
log.Printf("error marshaling in idpOpenIDUserinfonHandler: %s", err)
state.writeFailureResponse(w, r, http.StatusInternalServerError, "Internal Error")
state.writeFailureResponse(w, r, http.StatusInternalServerError,
"Internal Error")
return
}
logger.Debugf(1, "userinfo=%+v\n b=%s", userInfo, b)

var out bytes.Buffer
json.Indent(&out, b, "", "\t")
w.Header().Set("Content-Type", "application/json")
out.WriteTo(w)

logger.Printf("200 Successful userinfo request")
logger.Debugf(0, " Userinfo response = %s", b)
}

0 comments on commit 67596f2

Please sign in to comment.