Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support mysql-client8 for sharding-proxy #5218

Merged
merged 6 commits into from
Apr 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@

import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationMethod;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLServerErrorCode;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConnectionPhase;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.database.protocol.mysql.payload.MySQLPacketPayload;
Expand All @@ -42,37 +47,65 @@ public final class MySQLAuthenticationEngine implements AuthenticationEngine {

private final MySQLAuthenticationHandler authenticationHandler = new MySQLAuthenticationHandler();

private MySQLConnectionPhase connectionPhase = MySQLConnectionPhase.INITIAL_HANDSHAKE;

private int sequenceId;

private String username;

private byte[] authResponse;

private String database;

@Override
public void handshake(final ChannelHandlerContext context, final BackendConnection backendConnection) {
int connectionId = ConnectionIdGenerator.getInstance().nextId();
backendConnection.setConnectionId(connectionId);
connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
context.writeAndFlush(new MySQLHandshakePacket(connectionId, authenticationHandler.getAuthPluginData()));
}

@Override
public boolean auth(final ChannelHandlerContext context, final PacketPayload payload, final BackendConnection backendConnection) {
MySQLHandshakeResponse41Packet response41 = new MySQLHandshakeResponse41Packet((MySQLPacketPayload) payload);
if (!Strings.isNullOrEmpty(response41.getDatabase()) && !LogicSchemas.getInstance().schemaExists(response41.getDatabase())) {
context.writeAndFlush(new MySQLErrPacket(response41.getSequenceId() + 1, MySQLServerErrorCode.ER_BAD_DB_ERROR, response41.getDatabase()));
return true;
if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
MySQLHandshakeResponse41Packet response41 = new MySQLHandshakeResponse41Packet((MySQLPacketPayload) payload);
username = response41.getUsername();
authResponse = response41.getAuthResponse();
database = response41.getDatabase();
sequenceId = response41.getSequenceId();
if (!Strings.isNullOrEmpty(database) && !LogicSchemas.getInstance().schemaExists(database)) {
context.writeAndFlush(new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_BAD_DB_ERROR, database));
return false;
}
if (0 != (response41.getCapabilityFlags() & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue())
&& !MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName().equals(response41.getAuthPluginName())) {
connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
context.writeAndFlush(new MySQLAuthSwitchRequestPacket(++sequenceId,
MySQLAuthenticationMethod.SECURE_PASSWORD_AUTHENTICATION.getMethodName(), authenticationHandler.getAuthPluginData()));
return false;
}
} else if (MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH == connectionPhase) {
MySQLAuthSwitchResponsePacket authSwitchResponsePacket = new MySQLAuthSwitchResponsePacket((MySQLPacketPayload) payload);
sequenceId = authSwitchResponsePacket.getSequenceId();
authResponse = authSwitchResponsePacket.getAuthPluginResponse();
}
Optional<MySQLServerErrorCode> errorCode = authenticationHandler.login(response41);
Optional<MySQLServerErrorCode> errorCode = authenticationHandler.login(username, authResponse, database);
if (errorCode.isPresent()) {
context.writeAndFlush(getMySQLErrPacket(errorCode.get(), context, response41));
context.writeAndFlush(getMySQLErrPacket(errorCode.get(), context));
} else {
backendConnection.setCurrentSchema(response41.getDatabase());
backendConnection.setUserName(response41.getUsername());
context.writeAndFlush(new MySQLOKPacket(response41.getSequenceId() + 1));
backendConnection.setCurrentSchema(database);
backendConnection.setUserName(username);
context.writeAndFlush(new MySQLOKPacket(++sequenceId));
}
return true;
}

private MySQLErrPacket getMySQLErrPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet response41) {
private MySQLErrPacket getMySQLErrPacket(final MySQLServerErrorCode errorCode, final ChannelHandlerContext context) {
if (MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR == errorCode) {
return new MySQLErrPacket(response41.getSequenceId() + 1, MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR, response41.getUsername(), getHostAddress(context), response41.getDatabase());
return new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR, username, getHostAddress(context), database);
} else {
return new MySQLErrPacket(response41.getSequenceId() + 1, MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR, response41.getUsername(), getHostAddress(context),
0 == response41.getAuthResponse().length ? "NO" : "YES");
return new MySQLErrPacket(++sequenceId, MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR, username, getHostAddress(context),
0 == authResponse.length ? "NO" : "YES");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.shardingsphere.core.rule.ProxyUser;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLServerErrorCode;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthPluginData;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.shardingproxy.context.ShardingProxyContext;

import java.util.Arrays;
Expand All @@ -45,15 +44,17 @@ public final class MySQLAuthenticationHandler {
/**
* Login.
*
* @param response41 handshake response
* @param userName user name.
* @param authResponse auth response
* @param database database
* @return login success or failure
*/
public Optional<MySQLServerErrorCode> login(final MySQLHandshakeResponse41Packet response41) {
Optional<ProxyUser> user = getUser(response41.getUsername());
if (!user.isPresent() || !isPasswordRight(user.get().getPassword(), response41.getAuthResponse())) {
public Optional<MySQLServerErrorCode> login(final String userName, final byte[] authResponse, final String database) {
Optional<ProxyUser> user = getUser(userName);
if (!user.isPresent() || !isPasswordRight(user.get().getPassword(), authResponse)) {
return Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR);
}
if (!isAuthorizedSchema(user.get().getAuthorizedSchemas(), response41.getDatabase())) {
if (!isAuthorizedSchema(user.get().getAuthorizedSchemas(), database)) {
return Optional.of(MySQLServerErrorCode.ER_DBACCESS_DENIED_ERROR);
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
import lombok.SneakyThrows;
import org.apache.shardingsphere.core.rule.Authentication;
import org.apache.shardingsphere.core.rule.ProxyUser;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConnectionPhase;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.shardingproxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.shardingproxy.context.ShardingProxyContext;
import org.apache.shardingsphere.shardingproxy.frontend.ConnectionIdGenerator;
import org.apache.shardingsphere.shardingproxy.frontend.mysql.auth.MySQLAuthenticationEngine;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -79,6 +81,7 @@ public void assertHandshake() {

@Test
public void assertAuthWhenLoginSuccess() {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ProxyUser proxyUser = new ProxyUser("", Collections.singleton("db1"));
setAuthentication(proxyUser);
when(payload.readStringNul()).thenReturn("root");
Expand All @@ -88,19 +91,21 @@ public void assertAuthWhenLoginSuccess() {

@Test
public void assertAuthWhenLoginFailure() {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ProxyUser proxyUser = new ProxyUser("error", Collections.singleton("db1"));
setAuthentication(proxyUser);
when(payload.readStringNul()).thenReturn("root");
when(payload.readStringNulByBytes()).thenReturn("root".getBytes());
when(context.channel()).thenReturn(channel);
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 3307));
when(context.channel()).thenReturn(channel);
assertTrue(mysqlFrontendEngine.getAuthEngine().auth(context, payload, mock(BackendConnection.class)));
verify(context).writeAndFlush(isA(MySQLErrPacket.class));
}

@Test
@SneakyThrows
public void assertErrorMsgWhenLoginFailure() {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ProxyUser proxyUser = new ProxyUser("error", Collections.singleton("db1"));
setAuthentication(proxyUser);
when(payload.readStringNul()).thenReturn("root");
Expand All @@ -120,4 +125,11 @@ private void setAuthentication(final ProxyUser proxyUser) {
field.setAccessible(true);
field.set(ShardingProxyContext.getInstance(), authentication);
}

@SneakyThrows
private void setConnectionPhase(final MySQLConnectionPhase connectionPhase) {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("connectionPhase");
field.setAccessible(true);
field.set(mysqlFrontendEngine.getAuthEngine(), connectionPhase);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import lombok.SneakyThrows;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConnectionPhase;
import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLServerErrorCode;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.database.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.shardingproxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.shardingproxy.backend.schema.LogicSchema;
Expand All @@ -38,8 +39,11 @@
import java.util.Map;
import java.util.Optional;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -72,12 +76,33 @@ public void assertHandshake() {
verify(context).writeAndFlush(any(MySQLHandshakePacket.class));
verify(backendConnection).setConnectionId(anyInt());
}

@Test
public void assertAuthenticationMethodMismatch() {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readInt4()).thenReturn(MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class));
assertThat(getConnectionPhase(), is(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH));
}

@Test
public void assertAuthSwitchResponse() {
setConnectionPhase(MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH);
MySQLPacketPayload payload = mock(MySQLPacketPayload.class);
ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class);
when(payload.readStringEOFByBytes()).thenReturn(authResponse);
authenticationEngine.auth(channelHandlerContext, payload, mock(BackendConnection.class));
assertThat(getAuthResponse(), is(authResponse));
}

@Test
public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ChannelHandlerContext context = getContext();
setLogicSchemas(Collections.singletonMap("sharding_db", mock(LogicSchema.class)));
when(authenticationHandler.login(any(MySQLHandshakeResponse41Packet.class))).thenReturn(Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR));
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.of(MySQLServerErrorCode.ER_ACCESS_DENIED_ERROR));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}
Expand All @@ -86,19 +111,21 @@ public void assertAuthWithLoginFail() throws NoSuchFieldException, IllegalAccess
public void assertAuthWithAbsentDatabase() throws NoSuchFieldException, IllegalAccessException {
ChannelHandlerContext context = getContext();
setLogicSchemas(Collections.singletonMap("sharding_db", mock(LogicSchema.class)));
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
authenticationEngine.auth(context, getPayload("root", "ABSENT DATABASE", authResponse), mock(BackendConnection.class));
verify(context).writeAndFlush(any(MySQLErrPacket.class));
}

@Test
public void assertAuth() throws NoSuchFieldException, IllegalAccessException {
setConnectionPhase(MySQLConnectionPhase.AUTH_PHASE_FAST_PATH);
ChannelHandlerContext context = getContext();
when(authenticationHandler.login(any(MySQLHandshakeResponse41Packet.class))).thenReturn(Optional.empty());
when(authenticationHandler.login(anyString(), any(), anyString())).thenReturn(Optional.empty());
setLogicSchemas(Collections.singletonMap("sharding_db", mock(LogicSchema.class)));
authenticationEngine.auth(context, getPayload("root", "sharding_db", authResponse), mock(BackendConnection.class));
verify(context).writeAndFlush(any(MySQLOKPacket.class));
}

private void setLogicSchemas(final Map<String, LogicSchema> logicSchemas) throws NoSuchFieldException, IllegalAccessException {
Field field = LogicSchemas.class.getDeclaredField("logicSchemas");
field.setAccessible(true);
Expand Down Expand Up @@ -130,4 +157,25 @@ private SocketAddress getRemoteAddress() {
when(result.toString()).thenReturn("127.0.0.1");
return result;
}

@SneakyThrows
private void setConnectionPhase(final MySQLConnectionPhase connectionPhase) {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("connectionPhase");
field.setAccessible(true);
field.set(authenticationEngine, connectionPhase);
}

@SneakyThrows
private MySQLConnectionPhase getConnectionPhase() {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("connectionPhase");
field.setAccessible(true);
return (MySQLConnectionPhase) field.get(authenticationEngine);
}

@SneakyThrows
private byte[] getAuthResponse() {
Field field = MySQLAuthenticationEngine.class.getDeclaredField("authResponse");
field.setAccessible(true);
return (byte[]) field.get(authenticationEngine);
}
}
Loading