Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(open-mysql-db): refactor mock code #3831

Merged
merged 3 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,21 @@ public class MySqlListener implements AutoCloseable {

public static final String VERSION = "8.0.29";
public static final String VERSION_COMMENT = "";
public static final String CHARACTER_SET_UTF8MB4 = "utf8mb4";
public static final String COLLATION_UTF8MB4_0900_AI_CI = "utf8mb4_0900_ai_ci";
public static final String SETTINGS_LOWER_CASE_TABLE_NAMES = "2";
public static final String SETTINGS_INTERACTIVE_TIMEOUT = "28800";
public static final String SETTINGS_WAIT_TIMEOUT = "28800";
private static final Pattern SETTINGS_PATTERN =
Pattern.compile("@@([\\w.]+)(?:\\sAS\\s)?(\\w+)?");
private static final Pattern USE_DB_PATTERN = Pattern.compile("(?i)use (.+)");
private final SqlEngine sqlEngine;
private final int port;
private final Channel channel;
private final io.netty.channel.EventLoopGroup parentGroup;
private final EventLoopGroup childGroup;
private final EventExecutorGroup eventExecutorGroup;

public MySqlListener(int port, int executorGroupSize, SqlEngine sqlEngine) {
this.port = port;
this.sqlEngine = sqlEngine;

parentGroup = new NioEventLoopGroup();
Expand Down Expand Up @@ -87,7 +90,7 @@ public MySqlListener(int port, int executorGroupSize, SqlEngine sqlEngine) {
.childHandler(
new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
protected void initChannel(NioSocketChannel ch) {
System.out.println("[mysql-protocol] Initializing child channel");
final ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new MysqlServerPacketEncoder());
Expand Down Expand Up @@ -160,14 +163,21 @@ private void handleHandshakeResponse(
Throwable cause = e.getCause();
int errorCode;
byte[] sqlState;
String errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
String errMsg;
if (cause != null) {
errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else {
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
} else {
errMsg = Utils.getLocalDateTimeNow() + " " + Objects.requireNonNullElse(e.getMessage(), "");
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
Expand Down Expand Up @@ -197,17 +207,14 @@ private void handleQuery(
+ userName
+ ", scramble411: "
+ scramble411.length);
Matcher useDbMatcher =
USE_DB_PATTERN.matcher(queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim());
String queryStringWithoutComment =
queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim();
Matcher useDbMatcher = USE_DB_PATTERN.matcher(queryStringWithoutComment);

if (isServerSettingsQuery(queryString)) {
sendSettingsResponse(ctx, query, remoteAddr);
} else if (queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim().startsWith("set ")
&& !queryString
.replaceAll("/\\*.*\\*/", "")
.toLowerCase()
.trim()
.startsWith("set @@execute_mode=")) {
} else if (queryStringWithoutComment.startsWith("set ")
&& !queryStringWithoutComment.startsWith("set @@execute_mode=")) {
// ignore SET command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (useDbMatcher.matches()) {
Expand All @@ -218,12 +225,9 @@ private void handleQuery(
} else {
// Generic response
int[] sequenceId = new int[] {query.getSequenceId()};

boolean[] columnsWritten = new boolean[1];

ResultSetWriter resultSetWriter =
new ResultSetWriter() {

@Override
public void writeColumns(List<QueryResultColumn> columns) {
ctx.write(new ColumnCount(++sequenceId[0], columns.size()));
Expand Down Expand Up @@ -272,9 +276,7 @@ public void writeColumns(List<QueryResultColumn> columns) {
.build());
}
ctx.write(new EofResponse(++sequenceId[0], 0));

System.out.println("[mysql-protocol] Columns done");

columnsWritten[0] = !columns.isEmpty();
}

Expand All @@ -290,7 +292,6 @@ public void writeRow(List<String> row) {
@Override
public void finish() {
ctx.writeAndFlush(new EofResponse(++sequenceId[0], 0));

System.out.println("[mysql-protocol] All done");
}
};
Expand All @@ -311,22 +312,29 @@ public void finish() {
Throwable cause = e.getCause();
int errorCode;
byte[] sqlState;
String errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else if (cause instanceof IllegalArgumentException) {
errorCode = 1064;
sqlState = "#42000".getBytes(StandardCharsets.US_ASCII);
} else if (e.getMessage()
.equalsIgnoreCase(
"java.sql.SQLException: executeSQL fail: [2000] please enter database first")) {
errorCode = 1046;
sqlState = "#3D000".getBytes(StandardCharsets.US_ASCII);
String errMsg;
if (cause != null) {
errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else if (cause instanceof IllegalArgumentException) {
errorCode = 1064;
sqlState = "#42000".getBytes(StandardCharsets.US_ASCII);
} else if (e.getMessage()
.equalsIgnoreCase(
"java.sql.SQLException: executeSQL fail: [2000] please enter database first")) {
errorCode = 1046;
sqlState = "#3D000".getBytes(StandardCharsets.US_ASCII);
} else {
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
} else {
errMsg = Utils.getLocalDateTimeNow() + " " + e.getMessage();
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
Expand Down Expand Up @@ -435,16 +443,18 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 12));
values.add("utf8mb4");
values.add(CHARACTER_SET_UTF8MB4);
break;
case "collation_server":
case "GLOBAL.collation_server":
case "collation_connection":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("utf8mb4_0900_ai_ci");
values.add(COLLATION_UTF8MB4_0900_AI_CI);
break;
case "init_connect":
case "language":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
Expand All @@ -454,13 +464,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 21));
values.add("28800");
break;
case "language":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
values.add("");
values.add(SETTINGS_INTERACTIVE_TIMEOUT);
break;
case "license":
columnDefinitions.add(
Expand All @@ -472,7 +476,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_LONGLONG, 63));
values.add("2");
values.add(SETTINGS_LOWER_CASE_TABLE_NAMES);
break;
case "max_allowed_packet":
case "global.max_allowed_packet":
Expand Down Expand Up @@ -528,7 +532,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_LONGLONG, 12));
values.add("28800");
values.add(SETTINGS_WAIT_TIMEOUT);
break;
case "query_cache_type":
columnDefinitions.add(
Expand All @@ -542,12 +546,6 @@ private void sendSettingsResponse(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
values.add(VERSION_COMMENT);
break;
case "collation_connection":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("utf8mb4_0900_ai_ci");
break;
case "query_cache_size":
columnDefinitions.add(
newColumnDefinition(
Expand Down Expand Up @@ -578,14 +576,12 @@ private void sendSettingsResponse(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("REPEATABLE-READ");
// values.add("READ-UNCOMMITTED");
break;
case "session.transaction_read_only":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_TINY, 1));
values.add("0");
// values.add("READ-UNCOMMITTED");
break;
default:
System.err.println("[mysql-protocol] Unknown system variable: " + systemVariable);
Expand Down Expand Up @@ -632,7 +628,7 @@ public ServerHandler() {
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
public void channelActive(ChannelHandlerContext ctx) {
// todo may java.lang.NullPointerException
this.remoteAddr =
((InetSocketAddress) ctx.channel().remoteAddress()).getAddress().getHostAddress();
Expand All @@ -650,7 +646,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
public void channelInactive(ChannelHandlerContext ctx) {
System.out.println("[mysql-protocol] Server channel inactive: " + new Date());
sqlEngine.close(getConnectionId(ctx));
}
Expand Down Expand Up @@ -682,6 +678,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
} else if (command.equals(Command.COM_PING)) {
ctx.writeAndFlush(OkResponse.builder().sequenceId(sequenceId + 1).build());
} else if (command.equals(Command.COM_FIELD_LIST)) {
// ToDo:
// https://dev.mysql.com/doc/dev/mysql-server/8.0.34/page_protocol_com_field_list.html
ctx.writeAndFlush(new EofResponse(sequenceId + 1, 0));
} else if (command.equals(Command.COM_STATISTICS)) {
String statString =
Expand All @@ -696,10 +694,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
sqlEngine.close(getConnectionId(ctx));
ctx.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,22 @@

/** An interface to callback events received from the MySQL server. */
public interface SqlEngine {
/**
* Execute query use database
*
* @param connectionId Connection id
* @param database Database name
* @throws IOException Thrown with SQLTimeoutException as the inner cause if when the driver has
* determined that the timeout value that was specified by the setQueryTimeout method has been
* exceeded and has at least attempted to cancel the currently running Statement, or
* SQLException as the inner cause if a database access error occurs.
*/
void useDatabase(int connectionId, String database) throws IOException;

/**
* Authenticating the user and password.
*
* @param connectionId Connection id
* @param database Database name
* @param userName User name
* @param scramble411 Encoded password
Expand All @@ -40,6 +51,7 @@ void authenticate(
/**
* Querying the SQL.
*
* @param connectionId Connection id
* @param resultSetWriter Response writer
* @param database Database name
* @param userName User name
Expand All @@ -59,5 +71,10 @@ void query(
String sql)
throws IOException;

void close(int connectionId) throws IOException;
/**
* Close resources of connection
*
* @param connectionId Connection id
*/
void close(int connectionId);
}
Loading
Loading