Skip to content

Commit

Permalink
Support numeric promotion when values in a binary expression of a WHE…
Browse files Browse the repository at this point in the history
…RE clause have different data types.
  • Loading branch information
DaveRMaltby committed Mar 27, 2024
1 parent 4b86c00 commit afa3b7c
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 9 deletions.
79 changes: 74 additions & 5 deletions src/Core/LogicalEntities/SqlBinaryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,15 @@ private Expression GetRegexIsMatchExpression(Expression column, Expression likeP

private BinaryExpression GetBinaryExpression(Expression left, Expression right, Func<Expression, Expression, BinaryExpression> binaryOperatorExpressionFunc)
{
var leftNullable = left.Type.IsGenericType && left.Type.GetGenericTypeDefinition() == typeof(Nullable<>);
var rightNullable = right.Type.IsGenericType && right.Type.GetGenericTypeDefinition() == typeof(Nullable<>);
var commonType = GetCommonType(left.Type, right.Type);
var leftCasted = CastExpression(left, commonType);
var rightCasted = CastExpression(right, commonType);

var leftNullable = IsNullable(left.Type);
var rightNullable = IsNullable(right.Type);

if (!leftNullable && !rightNullable)
return binaryOperatorExpressionFunc(left, right);
return binaryOperatorExpressionFunc(leftCasted, rightCasted);

var leftHasValue = leftNullable ? Expression.Property(left, "HasValue") : null;
var leftValue = leftNullable ? Expression.Property(left, "Value") : null;
Expand All @@ -91,7 +95,7 @@ private BinaryExpression GetBinaryExpression(Expression left, Expression right,
if (leftNullable && !rightNullable)
{
var ternary = Expression.Condition(Expression.Not(leftHasValue), leftNull, Expression.Convert(leftValue, left.Type));
var argument = Expression.Convert(right, left.Type);
var argument = Expression.Convert(rightCasted, left.Type);
var binaryOperatorExpression = binaryOperatorExpressionFunc(ternary, argument);
return Expression.And(leftHasValue, binaryOperatorExpression);
}
Expand All @@ -100,7 +104,7 @@ private BinaryExpression GetBinaryExpression(Expression left, Expression right,
if (rightNullable && !leftNullable)
{
var ternary = Expression.Condition(Expression.Not(rightHasValue), rightNull, Expression.Convert(rightValue, right.Type));
var argument = Expression.Convert(left, right.Type);
var argument = Expression.Convert(leftCasted, right.Type);
var binaryOperatorExpression = binaryOperatorExpressionFunc(argument, ternary);
return Expression.And(rightHasValue, binaryOperatorExpression);
}
Expand All @@ -114,6 +118,71 @@ private BinaryExpression GetBinaryExpression(Expression left, Expression right,
return resultExpression;
}

private Type GetCommonType(Type type1, Type type2)
{
// If both types are same, return that type
if (type1 == type2)
return type1;

// If one type is assignable from the other, return the assignable type
if (type1.IsAssignableFrom(type2))
return type1;
if (type2.IsAssignableFrom(type1))
return type2;

// If one of the types is nullable, get the underlying type and try again
if (IsNullable(type1))
return GetCommonType(Nullable.GetUnderlyingType(type1), type2);
if (IsNullable(type2))
return GetCommonType(type1, Nullable.GetUnderlyingType(type2));

// Example of handling some numeric promotions explicitly
Dictionary<Type, int> typePrecedence = new Dictionary<Type, int>
{
{ typeof(byte), 1 },
{ typeof(short), 2 },
{ typeof(int), 3 },
{ typeof(long), 4 },
{ typeof(float), 5 },
{ typeof(double), 6 },
{ typeof(decimal), 7 }
};

if (typePrecedence.TryGetValue(type1, out int type1Precedence) &&
typePrecedence.TryGetValue(type2, out int type2Precedence))
{
Type higherPrecedenceType = type1Precedence > type2Precedence ? type1 : type2;

// You might want to handle nullable types here as well
return higherPrecedenceType;
}

// If no common type found, throw an exception
throw new Exception($"No common type found for {type1} and {type2}");
}

private Expression CastExpression(Expression expression, Type targetType)
{
// If expression type matches target type, no need for casting
if (expression.Type == targetType)
return expression;

// If expression is nullable and target type is non-nullable, perform null check and cast
if (IsNullable(expression.Type) && !IsNullable(targetType))
{
var hasValue = Expression.Property(expression, "HasValue");
var value = Expression.Property(expression, "Value");
return Expression.Condition(Expression.Not(hasValue), Expression.Constant(null, targetType), Expression.Convert(value, targetType));
}

// Perform standard conversion
return Expression.Convert(expression, targetType);
}

private bool IsNullable(Type type)
{
return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
}

public string ToExpressionString() => $"{Left.ToExpressionString()} {Expr.CreateOperator(Operator)} {Right.ToExpressionString()}";
public override string ToString() => $"{Left} { Expr.CreateOperator(Operator)} {Right}";
Expand Down
23 changes: 23 additions & 0 deletions tests/Grammars/MySQL.Tests/QueryEngineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,27 @@ public void Results_ColumnAlias()
Assert.Equal("Variable_name", results[0].Table.Columns[0].ColumnName);
Assert.Equal("Value", results[0].Table.Columns[1].ColumnName);
}

[Fact]
public void Results_Convert_WhereClauseIntegerLiteral_ToLong()
{
SqlGrammarMySQL grammar = new();
var node = GrammarParser.Parse(grammar,
"SELECT * FROM performance_schema.events_stages_history_long WHERE THREAD_ID = 1");

FakeDatabaseConnectionProvider databaseConnectionProvider = new();
FakeTableDataProvider tableDataProvider = new();
SqlSelectDefinition selectDefinition = grammar.Create(node, databaseConnectionProvider, tableDataProvider, null);

Assert.False(selectDefinition.InvalidReferences);

AllTableDataProvider allTableDataProvider = new(new ITableDataProvider[] { tableDataProvider });
var queryEngine = new QueryEngine(allTableDataProvider, selectDefinition);

var queryResults = queryEngine.Query();

var results = queryResults.Results.ToList();

Assert.Equal(0, results.Count);

Check warning on line 59 in tests/Grammars/MySQL.Tests/QueryEngineTests.cs

View workflow job for this annotation

GitHub Actions / build

Do not use Assert.Equal() to check for collection size. Use Assert.Empty instead. (https://xunit.net/xunit.analyzers/rules/xUnit2013)

Check warning on line 59 in tests/Grammars/MySQL.Tests/QueryEngineTests.cs

View workflow job for this annotation

GitHub Actions / build

Do not use Assert.Equal() to check for collection size. Use Assert.Empty instead. (https://xunit.net/xunit.analyzers/rules/xUnit2013)
}
}
33 changes: 29 additions & 4 deletions tests/Grammars/MySQL.Tests/Utils/FakeTableDataProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,45 @@ public IEnumerable<DataColumn> GetColumns(SqlTable table)
yield break;
}

throw new NotImplementedException();
if (table.TableName == "events_stages_history_long")
{
yield return new DataColumn("THREAD_ID", typeof(int));
yield return new DataColumn("EVENT_NAME", typeof(string));
yield return new DataColumn("NESTING_EVENT_ID", typeof(long));
yield break;
}

throw new KeyNotFoundException();
}

public record variable(string VARIABLE_NAME, string VARIABLE_VALUE);

Check warning on line 30 in tests/Grammars/MySQL.Tests/Utils/FakeTableDataProvider.cs

View workflow job for this annotation

GitHub Actions / build

The type name 'variable' only contains lower-cased ascii characters. Such names may become reserved for the language.

Check warning on line 30 in tests/Grammars/MySQL.Tests/Utils/FakeTableDataProvider.cs

View workflow job for this annotation

GitHub Actions / build

The type name 'variable' only contains lower-cased ascii characters. Such names may become reserved for the language.
public record events_stage_history(long THREAD_ID, string EVENT_NAME, long NESTING_EVENT_ID);

private IEnumerable<variable> GetEnumerable()
private IEnumerable<variable> Variable_GetEnumerable()
{
yield return new("activate_all_roles_on_login", "OFF");
yield return new("sql_mode", "STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION");
}

private IEnumerable<events_stage_history> EventsStageHistory_GetEnumerable()
{
yield break;
}


public IQueryable GetTableData(SqlTable table)
{
var result = GetEnumerable().AsQueryable();
return result;
if (table.TableName == "session_variables")
{
return Variable_GetEnumerable().AsQueryable();
}

if (table.TableName == "events_stages_history_long")
{
return EventsStageHistory_GetEnumerable().AsQueryable();
}

throw new KeyNotFoundException();
}

public (bool DatabaseServiced, IEnumerable<SqlTableInfo> Tables) GetTables(string database)
Expand Down

0 comments on commit afa3b7c

Please sign in to comment.