Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require explicit intention for empty password. #126

Merged
merged 1 commit into from
Aug 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 37 additions & 43 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package ldap
import (
"errors"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

// SimpleBindRequest represents a username/password bind operation
Expand All @@ -18,6 +18,9 @@ type SimpleBindRequest struct {
Password string
// Controls are optional controls to send with the bind request
Controls []Control
// AllowEmptyPassword sets whether the client allows binding with an empty password
// (normally used for unauthenticated bind).
AllowEmptyPassword bool
}

// SimpleBindResult contains the response from the server
Expand All @@ -28,9 +31,10 @@ type SimpleBindResult struct {
// NewSimpleBindRequest returns a bind request
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{
Username: username,
Password: password,
Controls: controls,
Username: username,
Password: password,
Controls: controls,
AllowEmptyPassword: false,
}
}

Expand All @@ -47,6 +51,10 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet {

// SimpleBind performs the simple bind operation defined in the given request
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
encodedBindRequest := simpleBindRequest.encode()
Expand Down Expand Up @@ -97,47 +105,33 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
return result, nil
}

// Bind performs a bind with the given username and password
// Bind performs a bind with the given username and password.
//
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
// for that.
func (l *Conn) Bind(username, password string) error {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password"))
packet.AppendChild(bindRequest)

if l.Debug {
ber.PrintPacket(packet)
}

msgCtx, err := l.sendMessage(packet)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)

packetResponse, ok := <-msgCtx.responses
if !ok {
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return err
}

if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return err
}
ber.PrintPacket(packet)
req := &SimpleBindRequest{
Username: username,
Password: password,
AllowEmptyPassword: false,
}
_, err := l.SimpleBind(req)
return err
}

resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode != 0 {
return NewError(resultCode, errors.New(resultDescription))
// UnauthenticatedBind performs an unauthenticated bind.
//
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
// authenticated or otherwise validated by the LDAP server.
//
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
func (l *Conn) UnauthenticatedBind(username string) error {
req := &SimpleBindRequest{
Username: username,
Password: "",
AllowEmptyPassword: true,
}

return nil
_, err := l.SimpleBind(req)
return err
}
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const (
ErrorDebugging = 203
ErrorUnexpectedMessage = 204
ErrorUnexpectedResponse = 205
ErrorEmptyPassword = 206
)

// LDAPResultCodeMap contains string descriptions for LDAP error codes
Expand Down Expand Up @@ -104,6 +105,7 @@ var LDAPResultCodeMap = map[uint8]string{
ErrorDebugging: "Debugging Error",
ErrorUnexpectedMessage: "Unexpected Message",
ErrorUnexpectedResponse: "Unexpected Response",
ErrorEmptyPassword: "Empty password not allowed by the client",
}

func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
Expand Down
64 changes: 31 additions & 33 deletions ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package ldap_test
package ldap

import (
"crypto/tls"
"fmt"
"testing"

"gopkg.in/ldap.v2"
)

var ldapServer = "ldap.itd.umich.edu"
Expand All @@ -23,7 +21,7 @@ var attributes = []string{

func TestDial(t *testing.T) {
fmt.Printf("TestDial: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -34,7 +32,7 @@ func TestDial(t *testing.T) {

func TestDialTLS(t *testing.T) {
fmt.Printf("TestDialTLS: starting...\n")
l, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
l, err := DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -45,7 +43,7 @@ func TestDialTLS(t *testing.T) {

func TestStartTLS(t *testing.T) {
fmt.Printf("TestStartTLS: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -60,16 +58,16 @@ func TestStartTLS(t *testing.T) {

func TestSearch(t *testing.T) {
fmt.Printf("TestSearch: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
Expand All @@ -85,16 +83,16 @@ func TestSearch(t *testing.T) {

func TestSearchStartTLS(t *testing.T) {
fmt.Printf("TestSearchStartTLS: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
Expand Down Expand Up @@ -125,22 +123,22 @@ func TestSearchStartTLS(t *testing.T) {

func TestSearchWithPaging(t *testing.T) {
fmt.Printf("TestSearchWithPaging: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

err = l.Bind("", "")
err = l.UnauthenticatedBind("")
if err != nil {
t.Errorf(err.Error())
return
}

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)
Expand All @@ -152,12 +150,12 @@ func TestSearchWithPaging(t *testing.T) {

fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))

searchRequest = ldap.NewSearchRequest(
searchRequest = NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
[]ldap.Control{ldap.NewControlPaging(5)})
[]Control{NewControlPaging(5)})
sr, err = l.SearchWithPaging(searchRequest, 5)
if err != nil {
t.Errorf(err.Error())
Expand All @@ -166,23 +164,23 @@ func TestSearchWithPaging(t *testing.T) {

fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))

searchRequest = ldap.NewSearchRequest(
searchRequest = NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
[]ldap.Control{ldap.NewControlPaging(500)})
[]Control{NewControlPaging(500)})
sr, err = l.SearchWithPaging(searchRequest, 5)
if err == nil {
t.Errorf("expected an error when paging size in control in search request doesn't match size given in call, got none")
return
}
}

func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult, i int) {
searchRequest := ldap.NewSearchRequest(
func searchGoroutine(t *testing.T, l *Conn, results chan *SearchResult, i int) {
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[i],
attributes,
nil)
Expand All @@ -197,17 +195,17 @@ func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult

func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) {
fmt.Printf("TestMultiGoroutineSearch: starting...\n")
var l *ldap.Conn
var l *Conn
var err error
if TLS {
l, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
l, err = DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
} else {
l, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err = Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -223,9 +221,9 @@ func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) {
}
}

results := make([]chan *ldap.SearchResult, len(filter))
results := make([]chan *SearchResult, len(filter))
for i := range filter {
results[i] = make(chan *ldap.SearchResult)
results[i] = make(chan *SearchResult)
go searchGoroutine(t, l, results[i], i)
}
for i := range filter {
Expand All @@ -245,17 +243,17 @@ func TestMultiGoroutineSearch(t *testing.T) {
}

func TestEscapeFilter(t *testing.T) {
if got, want := ldap.EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want {
if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
if got, want := ldap.EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want {
if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
}

func TestCompare(t *testing.T) {
fmt.Printf("TestCompare: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Fatal(err.Error())
}
Expand Down