-
Notifications
You must be signed in to change notification settings - Fork 0
/
auth_switch_response.go
130 lines (121 loc) · 3.25 KB
/
auth_switch_response.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package server
import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"fmt"

	. "github.com/892294101/go-mysql/mysql"
	"github.com/pingcap/errors"
)
// handleAuthSwitchResponse reads the client's reply to an AuthSwitchRequest
// and verifies it using the plugin named in c.authPluginName.
//
// For caching_sha2_password the behavior depends on c.cachingSha2FullAuth:
//   - when it is already set (an AuthMoreData packet was sent), the reply is
//     treated as full-authentication data and, on success, the password cache
//     is refreshed via writeCachingSha2Cache;
//   - otherwise the reply is checked with compareCacheSha2PasswordAuthData; if
//     that check sets c.cachingSha2FullAuth, this method calls itself to read
//     and process the next reply.
//
// An unrecognized plugin name yields an error.
func (c *Conn) handleAuthSwitchResponse() error {
	authData, err := c.readAuthSwitchRequestResponse()
	if err != nil {
		return err
	}

	switch c.authPluginName {
	case AUTH_NATIVE_PASSWORD:
		if err = c.acquirePassword(); err != nil {
			return err
		}
		return c.compareNativePasswordAuthData(authData, c.password)

	case AUTH_CACHING_SHA2_PASSWORD:
		if c.cachingSha2FullAuth {
			// AuthMoreData already sent: this reply carries full-auth data.
			if err = c.handleCachingSha2PasswordFullAuth(authData); err != nil {
				return err
			}
			c.writeCachingSha2Cache()
			return nil
		}
		// Auth method was switched but no AuthMoreData packet was sent yet.
		if err = c.compareCacheSha2PasswordAuthData(authData); err != nil {
			return err
		}
		if !c.cachingSha2FullAuth {
			return nil
		}
		// The cache check requested full auth; process the client's next reply.
		return c.handleAuthSwitchResponse()

	case AUTH_SHA256_PASSWORD:
		proceed, err := c.handlePublicKeyRetrieval(authData)
		// Stop on error, or (with a nil error) when no password check should follow.
		if err != nil || !proceed {
			return err
		}
		if err = c.acquirePassword(); err != nil {
			return err
		}
		return c.compareSha256PasswordAuthData(authData, c.password)

	default:
		return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName)
	}
}
// handleCachingSha2PasswordFullAuth performs the full-authentication step of
// caching_sha2_password and verifies authData against c.password, which is
// obtained through acquirePassword.
//
// On a TLS connection (whose handshake must be complete) authData is the
// clear-text password, optionally followed by one NUL byte that is stripped
// before comparing.
//
// On a non-TLS connection a single 0x02 byte means the client asks for the
// public key. The key is sent with writeAuthMoreDataPubkey and the next reply
// is read as the encrypted password. That encrypted payload is decrypted with
// RSA-OAEP (SHA-1) using the private key of the first certificate in
// c.serverConf.tlsConfig. The result must equal (password + NUL) XORed
// byte-wise with c.salt, with the salt repeated cyclically.
//
// It returns nil on success and errAccessDenied on a mismatch. Other errors
// are returned for I/O, TLS-state, key-configuration, or decryption failures.
func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
	if err := c.acquirePassword(); err != nil {
		return err
	}

	if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok {
		if !tlsConn.ConnectionState().HandshakeComplete {
			return errors.New("incomplete TLS handshake")
		}
		// Connection is TLS: the client sent the plain password, possibly
		// terminated by a NUL byte that is not part of the password.
		authData = bytes.TrimSuffix(authData, []byte{0x00})
		// Constant-time compare: authData is attacker-controlled.
		if subtle.ConstantTimeCompare(authData, []byte(c.password)) == 1 {
			return nil
		}
		return errAccessDenied(c.password)
	}

	// Non-TLS connection: the client either requests the public key first or
	// sends the encrypted password directly.
	if len(authData) == 1 && authData[0] == 0x02 {
		if err := c.writeAuthMoreDataPubkey(); err != nil {
			return err
		}
		var err error
		if authData, err = c.readAuthSwitchRequestResponse(); err != nil {
			return err
		}
	}

	// Without these checks a missing TLS config, an empty certificate list or
	// a non-RSA private key would panic while handling client input.
	tlsConf := c.serverConf.tlsConfig
	if tlsConf == nil || len(tlsConf.Certificates) == 0 {
		return errors.New("caching_sha2_password full auth requires a server certificate with an RSA private key")
	}
	privKey, ok := tlsConf.Certificates[0].PrivateKey.(*rsa.PrivateKey)
	if !ok {
		return errors.New("caching_sha2_password full auth requires an RSA private key in the server certificate")
	}

	dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, privKey, authData, nil)
	if err != nil {
		return err
	}

	// Rebuild what the client should have encrypted: the password plus a
	// trailing NUL, XORed with the salt repeated cyclically.
	expected := make([]byte, len(c.password)+1)
	copy(expected, c.password)
	for i := range expected {
		expected[i] ^= c.salt[i%len(c.salt)]
	}
	if subtle.ConstantTimeCompare(expected, dbytes) == 1 {
		return nil
	}
	return errAccessDenied(c.password)
}
// writeCachingSha2Cache stores SHA256(SHA256(c.password)) in
// c.serverConf.cacheShaPassword for the current user. The value is a 32-byte
// []byte. Nothing is stored when the password is empty.
//
// The key has the form "<user>@<addr>", where addr is c.Conn.LocalAddr(),
// i.e. the local end of the connection.
// NOTE(review): the key uses the local, not the remote, address; presumably
// the cache lookup elsewhere builds the key the same way — confirm before
// changing either side.
func (c *Conn) writeCachingSha2Cache() {
	if len(c.password) == 0 {
		return
	}
	stage1 := sha256.Sum256([]byte(c.password)) // SHA256(PASSWORD)
	stage2 := sha256.Sum256(stage1[:])          // SHA256(SHA256(PASSWORD))
	key := fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr())
	c.serverConf.cacheShaPassword.Store(key, stage2[:])
}