From 75fc3c76e30a0771bb39f1c26183f358eec29cbe Mon Sep 17 00:00:00 2001 From: kevinJCross Date: Wed, 15 Apr 2020 08:05:32 +0100 Subject: [PATCH] Bugfix: Allow TLS to work over socks proxy. --- broker.go | 21 +- broker_test.go | 560 +++++++++++++++++++++++---------------------- client_tls_test.go | 33 +-- config.go | 13 ++ 4 files changed, 327 insertions(+), 300 deletions(-) diff --git a/broker.go b/broker.go index 4f3991af7..bd329c4ca 100644 --- a/broker.go +++ b/broker.go @@ -154,25 +154,20 @@ func (b *Broker) Open(conf *Config) error { go withRecover(func() { defer b.lock.Unlock() - dialer := net.Dialer{ - Timeout: conf.Net.DialTimeout, - KeepAlive: conf.Net.KeepAlive, - LocalAddr: conf.Net.LocalAddr, - } - - if conf.Net.TLS.Enable { - b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config) - } else if conf.Net.Proxy.Enable { - b.conn, b.connErr = conf.Net.Proxy.Dialer.Dial("tcp", b.addr) - } else { - b.conn, b.connErr = dialer.Dial("tcp", b.addr) - } + dialer := conf.getDialer() + b.conn, b.connErr = dialer.Dial("tcp", b.addr) if b.connErr != nil { Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr) b.conn = nil atomic.StoreInt32(&b.opened, 0) return } + + if conf.Net.TLS.Enable { + Logger.Printf("Using tls") + b.conn = tls.Client(b.conn, conf.Net.TLS.Config) + } + b.conn = newBufConn(b.conn) b.conf = conf diff --git a/broker_test.go b/broker_test.go index e2b17462c..5807a2dc9 100644 --- a/broker_test.go +++ b/broker_test.go @@ -80,38 +80,40 @@ func TestBrokerAccessors(t *testing.T) { func TestSimpleBrokerCommunication(t *testing.T) { for _, tt := range brokerTestTable { - Logger.Printf("Testing broker communication for %s", tt.name) - mb := NewMockBroker(t, 0) - mb.Returns(&mockEncoder{tt.response}) - pendingNotify := make(chan brokerMetrics) - // Register a callback to be notified about successful requests - mb.SetNotifier(func(bytesRead, bytesWritten int) { - pendingNotify <- brokerMetrics{bytesRead, bytesWritten} + t.Run(tt.name, func(t *testing.T) { + Logger.Printf("Testing broker communication for %s", tt.name) + mb := NewMockBroker(t, 0) + mb.Returns(&mockEncoder{tt.response}) + pendingNotify := make(chan brokerMetrics) + // Register a callback to be notified about successful requests + mb.SetNotifier(func(bytesRead, bytesWritten int) { + pendingNotify <- brokerMetrics{bytesRead, bytesWritten} + }) + broker := NewBroker(mb.Addr()) + // Set the broker id in order to validate local broker metrics + broker.id = 0 + conf := NewConfig() + conf.Version = tt.version + err := broker.Open(conf) + if err != nil { + t.Fatal(err) + } + tt.runner(t, broker) + // Wait up to 500 ms for the remote broker to process the request and + // notify us about the metrics + timeout := 500 * time.Millisecond + select { + case mockBrokerMetrics := <-pendingNotify: + validateBrokerMetrics(t, broker, mockBrokerMetrics) + case <-time.After(timeout): + t.Errorf("No request received for: %s after waiting for %v", tt.name, timeout) + } + mb.Close() + err = broker.Close() + if err != nil { + t.Error(err) + } }) - broker := NewBroker(mb.Addr()) - // Set the broker id in order to validate local broker metrics - broker.id = 0 - conf := NewConfig() - conf.Version = tt.version - err := broker.Open(conf) - if err != nil { - t.Fatal(err) - } - tt.runner(t, broker) - // Wait up to 500 ms for the remote broker to process the request and - // notify us about the metrics - timeout := 500 * time.Millisecond - select { - case mockBrokerMetrics := <-pendingNotify: - validateBrokerMetrics(t, broker, mockBrokerMetrics) - case <-time.After(timeout): - t.Errorf("No request received for: %s after waiting for %v", tt.name, timeout) - } - mb.Close() - err = broker.Close() - if err != nil { - t.Error(err) - } } } @@ -204,58 +206,60 @@ func TestSASLOAuthBearer(t *testing.T) { } for i, test := range testTable { - // mockBroker mocks underlying network logic and broker responses - mockBroker := NewMockBroker(t, 0) + t.Run(test.name, func(t *testing.T) { + // mockBroker mocks underlying network logic and broker responses + mockBroker := NewMockBroker(t, 0) - mockBroker.SetHandlerByMap(map[string]MockResponse{ - "SaslAuthenticateRequest": test.mockSASLAuthResponse, - "SaslHandshakeRequest": test.mockSASLHandshakeResponse, - }) + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": test.mockSASLAuthResponse, + "SaslHandshakeRequest": test.mockSASLHandshakeResponse, + }) - // broker executes SASL requests against mockBroker - broker := NewBroker(mockBroker.Addr()) - broker.requestRate = metrics.NilMeter{} - broker.outgoingByteRate = metrics.NilMeter{} - broker.incomingByteRate = metrics.NilMeter{} - broker.requestSize = metrics.NilHistogram{} - broker.responseSize = metrics.NilHistogram{} - broker.responseRate = metrics.NilMeter{} - broker.requestLatency = metrics.NilHistogram{} - broker.requestsInFlight = metrics.NilCounter{} + // broker executes SASL requests against mockBroker + broker := NewBroker(mockBroker.Addr()) + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + broker.requestsInFlight = metrics.NilCounter{} - conf := NewConfig() - conf.Net.SASL.Mechanism = SASLTypeOAuth - conf.Net.SASL.TokenProvider = test.tokProvider + conf := NewConfig() + conf.Net.SASL.Mechanism = SASLTypeOAuth + conf.Net.SASL.TokenProvider = test.tokProvider - broker.conf = conf + broker.conf = conf - dialer := net.Dialer{ - Timeout: conf.Net.DialTimeout, - KeepAlive: conf.Net.KeepAlive, - LocalAddr: conf.Net.LocalAddr, - } + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + LocalAddr: conf.Net.LocalAddr, + } - conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) - if err != nil { - t.Fatal(err) - } + if err != nil { + t.Fatal(err) + } - broker.conn = conn + broker.conn = conn - err = broker.authenticateViaSASL() + err = broker.authenticateViaSASL() - if test.expectedBrokerError != ErrNoError { - if test.expectedBrokerError != err { - t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err) + if test.expectedBrokerError != ErrNoError { + if test.expectedBrokerError != err { + t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err) + } + } else if test.expectClientErr && err == nil { + t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) + } else if !test.expectClientErr && err != nil { + t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) } - } else if test.expectClientErr && err == nil { - t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) - } else if !test.expectClientErr && err != nil { - t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) - } - mockBroker.Close() + mockBroker.Close() + }) } } @@ -264,7 +268,7 @@ type MockSCRAMClient struct { done bool } -func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) { +func (m *MockSCRAMClient) Begin(_, _, _ string) (err error) { return nil } @@ -325,70 +329,72 @@ func TestSASLSCRAMSHAXXX(t *testing.T) { } for i, test := range testTable { - // mockBroker mocks underlying network logic and broker responses - mockBroker := NewMockBroker(t, 0) - broker := NewBroker(mockBroker.Addr()) - // broker executes SASL requests against mockBroker - broker.requestRate = metrics.NilMeter{} - broker.outgoingByteRate = metrics.NilMeter{} - broker.incomingByteRate = metrics.NilMeter{} - broker.requestSize = metrics.NilHistogram{} - broker.responseSize = metrics.NilHistogram{} - broker.responseRate = metrics.NilMeter{} - broker.requestLatency = metrics.NilHistogram{} - broker.requestsInFlight = metrics.NilCounter{} + t.Run(test.name, func(t *testing.T) { + // mockBroker mocks underlying network logic and broker responses + mockBroker := NewMockBroker(t, 0) + broker := NewBroker(mockBroker.Addr()) + // broker executes SASL requests against mockBroker + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + broker.requestsInFlight = metrics.NilCounter{} - mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp)) - mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512}) + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp)) + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512}) - if test.mockSASLAuthErr != ErrNoError { - mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr) - } - if test.mockHandshakeErr != ErrNoError { - mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) - } + if test.mockSASLAuthErr != ErrNoError { + mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr) + } + if test.mockHandshakeErr != ErrNoError { + mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) + } - mockBroker.SetHandlerByMap(map[string]MockResponse{ - "SaslAuthenticateRequest": mockSASLAuthResponse, - "SaslHandshakeRequest": mockSASLHandshakeResponse, - }) + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + }) - conf := NewConfig() - conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 - conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient } + conf := NewConfig() + conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512 + conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient } - broker.conf = conf - dialer := net.Dialer{ - Timeout: conf.Net.DialTimeout, - KeepAlive: conf.Net.KeepAlive, - LocalAddr: conf.Net.LocalAddr, - } + broker.conf = conf + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + LocalAddr: conf.Net.LocalAddr, + } - conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) - if err != nil { - t.Fatal(err) - } + if err != nil { + t.Fatal(err) + } - broker.conn = conn + broker.conn = conn - err = broker.authenticateViaSASL() + err = broker.authenticateViaSASL() - if test.mockSASLAuthErr != ErrNoError { - if test.mockSASLAuthErr != err { - t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err) - } - } else if test.mockHandshakeErr != ErrNoError { - if test.mockHandshakeErr != err { - t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + if test.mockSASLAuthErr != ErrNoError { + if test.mockSASLAuthErr != err { + t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } + } else if test.mockHandshakeErr != ErrNoError { + if test.mockHandshakeErr != err { + t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } + } else if test.expectClientErr && err == nil { + t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) + } else if !test.expectClientErr && err != nil { + t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) } - } else if test.expectClientErr && err == nil { - t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) - } else if !test.expectClientErr && err != nil { - t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) - } - mockBroker.Close() + mockBroker.Close() + }) } } @@ -424,96 +430,98 @@ func TestSASLPlainAuth(t *testing.T) { } for i, test := range testTable { - // mockBroker mocks underlying network logic and broker responses - mockBroker := NewMockBroker(t, 0) + t.Run(test.name, func(t *testing.T) { + // mockBroker mocks underlying network logic and broker responses + mockBroker := NewMockBroker(t, 0) - mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). - SetAuthBytes([]byte(`response_payload`)) + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`response_payload`)) - if test.mockAuthErr != ErrNoError { - mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr) - } + if test.mockAuthErr != ErrNoError { + mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr) + } - mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t). - SetEnabledMechanisms([]string{SASLTypePlaintext}) + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypePlaintext}) - if test.mockHandshakeErr != ErrNoError { - mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) - } + if test.mockHandshakeErr != ErrNoError { + mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) + } - mockBroker.SetHandlerByMap(map[string]MockResponse{ - "SaslAuthenticateRequest": mockSASLAuthResponse, - "SaslHandshakeRequest": mockSASLHandshakeResponse, - }) + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + }) - // broker executes SASL requests against mockBroker - broker := NewBroker(mockBroker.Addr()) - broker.requestRate = metrics.NilMeter{} - broker.outgoingByteRate = metrics.NilMeter{} - broker.incomingByteRate = metrics.NilMeter{} - broker.requestSize = metrics.NilHistogram{} - broker.responseSize = metrics.NilHistogram{} - broker.responseRate = metrics.NilMeter{} - broker.requestLatency = metrics.NilHistogram{} - broker.requestsInFlight = metrics.NilCounter{} + // broker executes SASL requests against mockBroker + broker := NewBroker(mockBroker.Addr()) + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + broker.requestsInFlight = metrics.NilCounter{} - conf := NewConfig() - conf.Net.SASL.Mechanism = SASLTypePlaintext - conf.Net.SASL.AuthIdentity = test.authidentity - conf.Net.SASL.User = "token" - conf.Net.SASL.Password = "password" - conf.Net.SASL.Version = SASLHandshakeV1 + conf := NewConfig() + conf.Net.SASL.Mechanism = SASLTypePlaintext + conf.Net.SASL.AuthIdentity = test.authidentity + conf.Net.SASL.User = "token" + conf.Net.SASL.Password = "password" + conf.Net.SASL.Version = SASLHandshakeV1 - broker.conf = conf - broker.conf.Version = V1_0_0_0 - dialer := net.Dialer{ - Timeout: conf.Net.DialTimeout, - KeepAlive: conf.Net.KeepAlive, - LocalAddr: conf.Net.LocalAddr, - } - - conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) - - if err != nil { - t.Fatal(err) - } - - broker.conn = conn - - err = broker.authenticateViaSASL() - if err == nil { - for _, rr := range mockBroker.History() { - switch r := rr.Request.(type) { - case *SaslAuthenticateRequest: - x := bytes.SplitN(r.SaslAuthBytes, []byte("\x00"), 3) - if string(x[0]) != conf.Net.SASL.AuthIdentity { - t.Errorf("[%d]:[%s] expected %s auth identity, got %s\n", i, test.name, conf.Net.SASL.AuthIdentity, x[0]) - } - if string(x[1]) != conf.Net.SASL.User { - t.Errorf("[%d]:[%s] expected %s user, got %s\n", i, test.name, conf.Net.SASL.User, x[1]) - } - if string(x[2]) != conf.Net.SASL.Password { - t.Errorf("[%d]:[%s] expected %s password, got %s\n", i, test.name, conf.Net.SASL.Password, x[2]) + broker.conf = conf + broker.conf.Version = V1_0_0_0 + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + LocalAddr: conf.Net.LocalAddr, + } + + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) + + if err != nil { + t.Fatal(err) + } + + broker.conn = conn + + err = broker.authenticateViaSASL() + if err == nil { + for _, rr := range mockBroker.History() { + switch r := rr.Request.(type) { + case *SaslAuthenticateRequest: + x := bytes.SplitN(r.SaslAuthBytes, []byte("\x00"), 3) + if string(x[0]) != conf.Net.SASL.AuthIdentity { + t.Errorf("[%d]:[%s] expected %s auth identity, got %s\n", i, test.name, conf.Net.SASL.AuthIdentity, x[0]) + } + if string(x[1]) != conf.Net.SASL.User { + t.Errorf("[%d]:[%s] expected %s user, got %s\n", i, test.name, conf.Net.SASL.User, x[1]) + } + if string(x[2]) != conf.Net.SASL.Password { + t.Errorf("[%d]:[%s] expected %s password, got %s\n", i, test.name, conf.Net.SASL.Password, x[2]) + } } } } - } - if test.mockAuthErr != ErrNoError { - if test.mockAuthErr != err { - t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err) - } - } else if test.mockHandshakeErr != ErrNoError { - if test.mockHandshakeErr != err { - t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + if test.mockAuthErr != ErrNoError { + if test.mockAuthErr != err { + t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err) + } + } else if test.mockHandshakeErr != ErrNoError { + if test.mockHandshakeErr != err { + t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } + } else if test.expectClientErr && err == nil { + t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) + } else if !test.expectClientErr && err != nil { + t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) } - } else if test.expectClientErr && err == nil { - t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) - } else if !test.expectClientErr && err != nil { - t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err) - } - mockBroker.Close() + mockBroker.Close() + }) } } @@ -616,73 +624,75 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) { }, } for i, test := range testTable { - mockBroker := NewMockBroker(t, 0) - // broker executes SASL requests against mockBroker + t.Run(test.name, func(t *testing.T) { + mockBroker := NewMockBroker(t, 0) + // broker executes SASL requests against mockBroker + + mockBroker.SetGSSAPIHandler(func(bytes []byte) []byte { + return nil + }) + broker := NewBroker(mockBroker.Addr()) + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + broker.requestsInFlight = metrics.NilCounter{} + conf := NewConfig() + conf.Net.SASL.Mechanism = SASLTypeGSSAPI + conf.Net.SASL.GSSAPI.ServiceName = "kafka" + conf.Net.SASL.GSSAPI.KerberosConfigPath = "krb5.conf" + conf.Net.SASL.GSSAPI.Realm = "EXAMPLE.COM" + conf.Net.SASL.GSSAPI.Username = "kafka" + conf.Net.SASL.GSSAPI.Password = "kafka" + conf.Net.SASL.GSSAPI.KeyTabPath = "kafka.keytab" + conf.Net.SASL.GSSAPI.AuthType = KRB5_USER_AUTH + broker.conf = conf + broker.conf.Version = V1_0_0_0 + dialer := net.Dialer{ + Timeout: conf.Net.DialTimeout, + KeepAlive: conf.Net.KeepAlive, + LocalAddr: conf.Net.LocalAddr, + } + + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) - mockBroker.SetGSSAPIHandler(func(bytes []byte) []byte { - return nil - }) - broker := NewBroker(mockBroker.Addr()) - broker.requestRate = metrics.NilMeter{} - broker.outgoingByteRate = metrics.NilMeter{} - broker.incomingByteRate = metrics.NilMeter{} - broker.requestSize = metrics.NilHistogram{} - broker.responseSize = metrics.NilHistogram{} - broker.responseRate = metrics.NilMeter{} - broker.requestLatency = metrics.NilHistogram{} - broker.requestsInFlight = metrics.NilCounter{} - conf := NewConfig() - conf.Net.SASL.Mechanism = SASLTypeGSSAPI - conf.Net.SASL.GSSAPI.ServiceName = "kafka" - conf.Net.SASL.GSSAPI.KerberosConfigPath = "krb5.conf" - conf.Net.SASL.GSSAPI.Realm = "EXAMPLE.COM" - conf.Net.SASL.GSSAPI.Username = "kafka" - conf.Net.SASL.GSSAPI.Password = "kafka" - conf.Net.SASL.GSSAPI.KeyTabPath = "kafka.keytab" - conf.Net.SASL.GSSAPI.AuthType = KRB5_USER_AUTH - broker.conf = conf - broker.conf.Version = V1_0_0_0 - dialer := net.Dialer{ - Timeout: conf.Net.DialTimeout, - KeepAlive: conf.Net.KeepAlive, - LocalAddr: conf.Net.LocalAddr, - } - - conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) - - if err != nil { - t.Fatal(err) - } - - gssapiHandler := KafkaGSSAPIHandler{ - client: &MockKerberosClient{}, - badResponse: test.badResponse, - badKeyChecksum: test.badKeyChecksum, - } - mockBroker.SetGSSAPIHandler(gssapiHandler.MockKafkaGSSAPI) - broker.conn = conn - if test.mockKerberosClient { - broker.kerberosAuthenticator.NewKerberosClientFunc = func(config *GSSAPIConfig) (KerberosClient, error) { - return &MockKerberosClient{ - mockError: test.error, - errorStage: test.errorStage, - }, nil - } - } else { - broker.kerberosAuthenticator.NewKerberosClientFunc = nil - } - - err = broker.authenticateViaSASL() - - if err != nil && test.error != nil { - if test.error.Error() != err.Error() { + if err != nil { + t.Fatal(err) + } + + gssapiHandler := KafkaGSSAPIHandler{ + client: &MockKerberosClient{}, + badResponse: test.badResponse, + badKeyChecksum: test.badKeyChecksum, + } + mockBroker.SetGSSAPIHandler(gssapiHandler.MockKafkaGSSAPI) + broker.conn = conn + if test.mockKerberosClient { + broker.kerberosAuthenticator.NewKerberosClientFunc = func(config *GSSAPIConfig) (KerberosClient, error) { + return &MockKerberosClient{ + mockError: test.error, + errorStage: test.errorStage, + }, nil + } + } else { + broker.kerberosAuthenticator.NewKerberosClientFunc = nil + } + + err = broker.authenticateViaSASL() + + if err != nil && test.error != nil { + if test.error.Error() != err.Error() { + t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err) + } + } else if (err == nil && test.error != nil) || (err != nil && test.error == nil) { t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err) } - } else if (err == nil && test.error != nil) || (err != nil && test.error == nil) { - t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err) - } - mockBroker.Close() + mockBroker.Close() + }) } } @@ -723,17 +733,19 @@ func TestBuildClientFirstMessage(t *testing.T) { } for i, test := range testTable { - actual, err := buildClientFirstMessage(test.token) - - if !reflect.DeepEqual(test.expected, actual) { - t.Errorf("Expected %s, got %s\n", test.expected, actual) - } - if test.expectError && err == nil { - t.Errorf("[%d]:[%s] Expected an error but did not get one", i, test.name) - } - if !test.expectError && err != nil { - t.Errorf("[%d]:[%s] Expected no error but got %s\n", i, test.name, err) - } + t.Run(test.name, func(t *testing.T) { + actual, err := buildClientFirstMessage(test.token) + + if !reflect.DeepEqual(test.expected, actual) { + t.Errorf("Expected %s, got %s\n", test.expected, actual) + } + if test.expectError && err == nil { + t.Errorf("[%d]:[%s] Expected an error but did not get one", i, test.name) + } + if !test.expectError && err != nil { + t.Errorf("[%d]:[%s] Expected no error but got %s\n", i, test.name, err) + } + }) } } diff --git a/client_tls_test.go b/client_tls_test.go index 0e47e17c2..e36612705 100644 --- a/client_tls_test.go +++ b/client_tls_test.go @@ -1,16 +1,15 @@ package sarama import ( - "math/big" - "net" - "testing" - "time" - "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "math/big" + "net" + "testing" + "time" ) func TestTLS(t *testing.T) { @@ -95,10 +94,12 @@ func TestTLS(t *testing.T) { } for _, tc := range []struct { + name string Succeed bool Server, Client *tls.Config }{ - { // Verify client fails if wrong CA cert pool is specified + { + name: "Verify client fails if wrong CA cert pool is specified", Succeed: false, Server: serverTLSConfig, Client: &tls.Config{ @@ -109,7 +110,8 @@ func TestTLS(t *testing.T) { }}, }, }, - { // Verify client fails if wrong key is specified + { + name: "Verify client fails if wrong key is specified", Succeed: false, Server: serverTLSConfig, Client: &tls.Config{ @@ -120,7 +122,8 @@ func TestTLS(t *testing.T) { }}, }, }, - { // Verify client fails if wrong cert is specified + { + name: "Verify client fails if wrong cert is specified", Succeed: false, Server: serverTLSConfig, Client: &tls.Config{ @@ -131,7 +134,8 @@ func TestTLS(t *testing.T) { }}, }, }, - { // Verify client fails if no CAs are specified + { + name: "Verify client fails if no CAs are specified", Succeed: false, Server: serverTLSConfig, Client: &tls.Config{ @@ -141,18 +145,21 @@ func TestTLS(t *testing.T) { }}, }, }, - { // Verify client fails if no keys are specified + { + name: "Verify client fails if no keys are specified", Succeed: false, Server: serverTLSConfig, Client: &tls.Config{ RootCAs: pool, }, }, - { // Finally, verify it all works happily with client and server cert in place + { + name: "Finally, verify it all works happily with client and server cert in place", Succeed: true, Server: serverTLSConfig, Client: &tls.Config{ - RootCAs: pool, + RootCAs: pool, + ServerName: "127.0.0.1", Certificates: []tls.Certificate{{ Certificate: [][]byte{clientDer}, PrivateKey: clientkey, @@ -160,7 +167,7 @@ func TestTLS(t *testing.T) { }, }, } { - doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client) + t.Run(tc.name, func(t *testing.T) { doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client) }) } } diff --git a/config.go b/config.go index e899820cb..0ce308f80 100644 --- a/config.go +++ b/config.go @@ -734,3 +734,16 @@ func (c *Config) Validate() error { return nil } + +func (c *Config) getDialer() proxy.Dialer { + if c.Net.Proxy.Enable { + Logger.Printf("using proxy %s", c.Net.Proxy.Dialer) + return c.Net.Proxy.Dialer + } else { + return &net.Dialer{ + Timeout: c.Net.DialTimeout, + KeepAlive: c.Net.KeepAlive, + LocalAddr: c.Net.LocalAddr, + } + } +}