Skip to content

Commit

Permalink
Fix #12054 - Adds recursion guards to analyzer.
Browse files Browse the repository at this point in the history
- Adds visited node set to prevent re-visiting.
- Also added max recursion depth fail safe in case there are more cases.
  • Loading branch information
anpete committed May 29, 2018
1 parent a3783ac commit 5bf9cc3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 44 deletions.
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
Expand Down Expand Up @@ -65,11 +66,15 @@ var sqlArgumentExpressionSyntax
return;
}

var depth = 0;

CheckPossibleInjection(
analysisContext,
sqlArgumentExpressionSyntax,
identifierValueText,
invocationExpressionSyntax.GetLocation());
invocationExpressionSyntax.GetLocation(),
visited: new HashSet<SyntaxNode>(),
ref depth);
}
}
}
Expand Down
105 changes: 66 additions & 39 deletions src/EFCore.Analyzers/SqlInjectionDiagnosticAnalyzerBase.cs
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
Expand All @@ -16,6 +17,8 @@ public abstract class SqlInjectionDiagnosticAnalyzerBase : DiagnosticAnalyzer
= "The SQL expression passed to '{0}' embeds data that will not be parameterized."
+ " Review for potential SQL injection vulnerability. See https://go.microsoft.com/fwlink/?linkid=871170 for more information.";

public const int RecursionLimit = 30;

protected const string DefaultTitle = "Possible SQL injection vulnerability.";
protected const string Category = "Security";

Expand Down Expand Up @@ -46,7 +49,12 @@ private void AnalyzeSimpleMemberAccessExpressionSyntaxNode(SyntaxNodeAnalysisCon
MemberAccessExpressionSyntax memberAccessExpressionSyntax);

protected bool CheckPossibleInjection(
SyntaxNodeAnalysisContext analysisContext, SyntaxNode syntaxNode, string identifier, Location location)
SyntaxNodeAnalysisContext analysisContext,
SyntaxNode syntaxNode,
string identifier,
Location location,
ISet<SyntaxNode> visited,
ref int depth)
{
if (UsesUnsafeInterpolation(analysisContext, syntaxNode)
|| UsesUnsafeStringOperation(analysisContext, syntaxNode))
Expand All @@ -57,61 +65,80 @@ private void AnalyzeSimpleMemberAccessExpressionSyntaxNode(SyntaxNodeAnalysisCon
return true;
}

var rootSyntaxNode
= syntaxNode.Ancestors().First(n => n is MemberDeclarationSyntax);
if (visited.Contains(syntaxNode)
|| ++depth > RecursionLimit)
{
return false;
}

visited.Add(syntaxNode);

foreach (var identifierNameSyntax
in syntaxNode.DescendantNodesAndSelf().OfType<IdentifierNameSyntax>())
try
{
var symbol = analysisContext.GetSymbol(identifierNameSyntax);
var rootSyntaxNode
= syntaxNode.Ancestors().First(n => n is MemberDeclarationSyntax);

if (symbol is ILocalSymbol
|| symbol is IParameterSymbol)
foreach (var identifierNameSyntax
in syntaxNode.DescendantNodesAndSelf().OfType<IdentifierNameSyntax>())
{
foreach (var descendantNode in rootSyntaxNode.DescendantNodes())
{
if (descendantNode == syntaxNode)
{
break;
}
var symbol = analysisContext.GetSymbol(identifierNameSyntax);

switch (descendantNode)
if (symbol is ILocalSymbol
|| symbol is IParameterSymbol)
{
foreach (var descendantNode in rootSyntaxNode.DescendantNodes())
{
case AssignmentExpressionSyntax assignmentExpressionSyntax
when assignmentExpressionSyntax.Left is IdentifierNameSyntax
&& Equals(analysisContext.GetSymbol(assignmentExpressionSyntax.Left), symbol):
if (descendantNode == syntaxNode)
{
if (CheckPossibleInjection(
analysisContext,
assignmentExpressionSyntax.Right,
identifier,
location))
{
return true;
}

break;
}
case VariableDeclaratorSyntax variableDeclaratorSyntax
when Equals(analysisContext.SemanticModel.GetDeclaredSymbol(variableDeclaratorSyntax), symbol):

switch (descendantNode)
{
if (CheckPossibleInjection(
analysisContext,
variableDeclaratorSyntax.Initializer,
identifier,
location))
case AssignmentExpressionSyntax assignmentExpressionSyntax
when assignmentExpressionSyntax.Left is IdentifierNameSyntax
&& Equals(analysisContext.GetSymbol(assignmentExpressionSyntax.Left), symbol):
{
return true;
if (CheckPossibleInjection(
analysisContext,
assignmentExpressionSyntax.Right,
identifier,
location,
visited,
ref depth))
{
return true;
}

break;
}
case VariableDeclaratorSyntax variableDeclaratorSyntax
when Equals(analysisContext.SemanticModel.GetDeclaredSymbol(variableDeclaratorSyntax), symbol):
{
if (CheckPossibleInjection(
analysisContext,
variableDeclaratorSyntax.Initializer,
identifier,
location,
visited,
ref depth))
{
return true;
}

break;
}

break;
}
}
}
}

return false;
}
finally
{
--depth;
}

return false;
}

protected static bool UsesUnsafeInterpolation(
Expand Down
Expand Up @@ -21,6 +21,33 @@ var diagnostics

Assert.Empty(diagnostics);
}

[Fact]
public async Task Error_when_sql_expression_recursively_initialized()
{
var diagnostics
= await GetDiagnosticsAsync(
@"string q = null;
q = M2(q);
string M2(string _) { return null; }
RelationalDatabaseFacadeExtensions.ExecuteSqlCommand(null, q);");

Assert.Empty(diagnostics);
}

[Fact]
public async Task Error_when_sql_expression_recursively_initialized_multi()
{
var diagnostics
= await GetDiagnosticsAsync(
@"string q = null;
q = M2(q);
var s = q;
string M2(string _) { return null; }
RelationalDatabaseFacadeExtensions.ExecuteSqlCommand(null, s);");

Assert.Empty(diagnostics);
}

[Fact]
public async Task No_warning_when_string_literal_passed_to_execute_sql_command()
Expand Down Expand Up @@ -109,7 +136,7 @@ var diagnostics
Assert.Equal(1, diagnostic.Location.GetLineSpan().StartLinePosition.Line);
Assert.Equal(22, diagnostic.Location.GetLineSpan().StartLinePosition.Character);
Assert.Equal(string.Format(
RawSqlStringInjectionDiagnosticAnalyzer.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
SqlInjectionDiagnosticAnalyzerBase.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
}

[Fact]
Expand All @@ -128,7 +155,7 @@ var diagnostics
Assert.Equal(2, diagnostic.Location.GetLineSpan().StartLinePosition.Line);
Assert.Equal(22, diagnostic.Location.GetLineSpan().StartLinePosition.Character);
Assert.Equal(string.Format(
RawSqlStringInjectionDiagnosticAnalyzer.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
SqlInjectionDiagnosticAnalyzerBase.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
}

[Fact]
Expand All @@ -147,7 +174,7 @@ var diagnostics
Assert.Equal(2, diagnostic.Location.GetLineSpan().StartLinePosition.Line);
Assert.Equal(22, diagnostic.Location.GetLineSpan().StartLinePosition.Character);
Assert.Equal(string.Format(
RawSqlStringInjectionDiagnosticAnalyzer.MessageFormat, "FromSql"), diagnostic.GetMessage());
SqlInjectionDiagnosticAnalyzerBase.MessageFormat, "FromSql"), diagnostic.GetMessage());
}

[Fact]
Expand All @@ -167,7 +194,7 @@ var diagnostics
Assert.Equal(3, diagnostic.Location.GetLineSpan().StartLinePosition.Line);
Assert.Equal(22, diagnostic.Location.GetLineSpan().StartLinePosition.Character);
Assert.Equal(string.Format(
RawSqlStringInjectionDiagnosticAnalyzer.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
SqlInjectionDiagnosticAnalyzerBase.MessageFormat, "ExecuteSqlCommandAsync"), diagnostic.GetMessage());
}

protected override DiagnosticAnalyzer CreateDiagnosticAnalyzer()
Expand Down

0 comments on commit 5bf9cc3

Please sign in to comment.