Skip to content

Commit

Permalink
Improve S5773: Support secondary issue location for partial methods (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim-Pohlmann authored Jul 24, 2023
1 parent 753ea36 commit be9ca30
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ public override void VisitObjectCreationExpression(ObjectCreationExpressionSynta

protected override bool IsBindToTypeMethod(SyntaxNode methodDeclaration) =>
methodDeclaration is MethodDeclarationSyntax { Identifier.Text: nameof(SerializationBinder.BindToType), ParameterList.Parameters.Count: 2 } syntax
&& (syntax.Body is not null || syntax.ArrowExpressionBody() is not null)
&& syntax.EnsureCorrectSemanticModelOrDefault(SemanticModel) is { } semanticModel
&& syntax.ParameterList.Parameters[0].Type.IsKnownType(KnownType.System_String, semanticModel)
&& syntax.ParameterList.Parameters[1].Type.IsKnownType(KnownType.System_String, semanticModel);

protected override bool IsResolveTypeMethod(SyntaxNode methodDeclaration) =>
methodDeclaration is MethodDeclarationSyntax { Identifier.Text: "ResolveType", ParameterList.Parameters.Count: 1 } syntax
&& (syntax.Body is not null || syntax.ArrowExpressionBody() is not null)
&& syntax.EnsureCorrectSemanticModelOrDefault(SemanticModel) is { } semanticModel
&& syntax.ParameterList.Parameters[0].Type.IsKnownType(KnownType.System_String, semanticModel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ public abstract class RestrictDeserializedTypesBase : SymbolicRuleCheck
private const string RestrictTypesMessage = "Restrict types of objects allowed to be deserialized.";
private const string VerifyMacMessage = "Serialized data signature (MAC) should be verified.";

private static readonly KnownType[] FormattersWithBinder = new[]
private static readonly KnownType[] FormattersWithBinderProperty = new[]
{
KnownType.System_Runtime_Serialization_Formatters_Binary_BinaryFormatter,
KnownType.System_Runtime_Serialization_NetDataContractSerializer,
KnownType.System_Runtime_Serialization_Formatters_Soap_SoapFormatter
};
private static readonly KnownType JavaScriptSerializer = KnownType.System_Web_Script_Serialization_JavaScriptSerializer;
private static readonly KnownType LosFormatter = KnownType.System_Web_UI_LosFormatter;
private static readonly KnownType[] TypesWithDeserializeMethod = FormattersWithBinder.Append(JavaScriptSerializer).ToArray();
private static readonly KnownType[] TypesWithDeserializeMethod = FormattersWithBinderProperty.Append(JavaScriptSerializer).ToArray();

private readonly Dictionary<ISymbol, SyntaxNode> additionalLocationsForSymbols = new();
private readonly Dictionary<IOperation, SyntaxNode> additionalLocationsForOperations = new();
private readonly Dictionary<ISymbol, SyntaxNode> unsafeMethodsForSymbols = new();
private readonly Dictionary<IOperation, SyntaxNode> unsafeMethodsForOperations = new();

protected abstract bool IsBindToTypeMethod(SyntaxNode methodDeclaration);
protected abstract bool IsResolveTypeMethod(SyntaxNode methodDeclaration);
Expand All @@ -54,7 +54,7 @@ protected override ProgramState PreProcessSimple(SymbolicContext context)
var operation = context.Operation.Instance;
if (operation.Kind == OperationKindEx.ObjectCreation)
{
return operation.Type.IsAny(FormattersWithBinder)
return operation.Type.IsAny(FormattersWithBinderProperty)
? state.SetOperationConstraint(operation, SerializationConstraint.Unsafe)
: ProcessOtherSerializerCreations(state, operation.ToObjectCreation());
}
Expand All @@ -64,15 +64,19 @@ protected override ProgramState PreProcessSimple(SymbolicContext context)
{
return binderProcessedState;
}
else if (AdditionalLocation(state, assignment.Value) is { } methodDeclaration
else if (UnsafeMethodDeclaration(state, assignment.Value) is { } methodDeclaration
&& assignment.Target.TrackedSymbol() is { } symbol)
{
additionalLocationsForSymbols[symbol] = methodDeclaration;
// Assignments propagate constraints. The same needs to be done for method declarations.
// This is especially relevant, when the property is set in an object initializer:
/// var formatter = new BinaryFormatter { Binder = binder };
// The constraint will be learned on a FlowCaptureReference and propagated via the assignment.
unsafeMethodsForSymbols[symbol] = methodDeclaration;
}
}
else if (UnsafeDeserialization(state, operation) is { } invocation)
{
var methodDeclaration = AdditionalLocation(state, invocation.Instance);
var methodDeclaration = UnsafeMethodDeclaration(state, invocation.Instance);
var additionalLocations = methodDeclaration is not null
? new[] { GetIdentifier(methodDeclaration).GetLocation() }
: Array.Empty<Location>();
Expand All @@ -85,7 +89,7 @@ private ProgramState ProcessOtherSerializerCreations(ProgramState state, IObject
{
if (UnsafeJavaScriptSerializer(state, objectCreation, out var resolveTypeDeclaration))
{
additionalLocationsForOperations[objectCreation.WrappedOperation] = resolveTypeDeclaration;
unsafeMethodsForOperations[objectCreation.WrappedOperation] = resolveTypeDeclaration;
return state.SetOperationConstraint(objectCreation.WrappedOperation, SerializationConstraint.Unsafe);
}
else if (objectCreation.Type.Is(LosFormatter) && !EnableMacIsTrue(state, objectCreation))
Expand Down Expand Up @@ -136,15 +140,15 @@ private ProgramState ProcessBinderAssignment(ProgramState state, IAssignmentOper
: SerializationConstraint.Unsafe;
if (constraint == SerializationConstraint.Unsafe)
{
additionalLocationsForOperations[instance] = bindToTypeDeclaration;
unsafeMethodsForOperations[instance] = bindToTypeDeclaration;
}
state = state.SetOperationConstraint(instance, constraint);

if (instance.TrackedSymbol() is { } symbol)
{
if (constraint == SerializationConstraint.Unsafe)
{
additionalLocationsForSymbols[symbol] = bindToTypeDeclaration;
unsafeMethodsForSymbols[symbol] = bindToTypeDeclaration;
}
state = state.SetSymbolConstraint(symbol, constraint);
}
Expand All @@ -155,7 +159,7 @@ private ProgramState ProcessBinderAssignment(ProgramState state, IAssignmentOper

private static IOperation BinderAssignmentInstance(ProgramState state, IAssignmentOperationWrapper assignment) =>
state.ResolveCaptureAndUnwrapConversion(assignment.Target).AsPropertyReference() is { Property.Name: nameof(IFormatter.Binder), Instance: { } propertyInstance }
&& propertyInstance.Type.IsAny(FormattersWithBinder)
&& propertyInstance.Type.IsAny(FormattersWithBinderProperty)
? state.ResolveCaptureAndUnwrapConversion(propertyInstance)
: null;

Expand All @@ -176,11 +180,11 @@ private bool BinderIsSafe(ProgramState state, IAssignmentOperationWrapper assign
private static IEnumerable<SyntaxNode> DeclarationCandidates(IOperation operation) =>
operation.Type?.DeclaringSyntaxReferences.SelectMany(x => x.GetSyntax().ChildNodes());

private SyntaxNode AdditionalLocation(ProgramState state, IOperation operation)
private SyntaxNode UnsafeMethodDeclaration(ProgramState state, IOperation operation)
{
operation = state.ResolveCaptureAndUnwrapConversion(operation);
return additionalLocationsForOperations.TryGetValue(operation, out var methodDeclaration)
|| (operation.TrackedSymbol() is { } symbol && additionalLocationsForSymbols.TryGetValue(symbol, out methodDeclaration))
return unsafeMethodsForOperations.TryGetValue(operation, out var methodDeclaration)
|| (operation.TrackedSymbol() is { } symbol && unsafeMethodsForSymbols.TryGetValue(symbol, out methodDeclaration))
? methodDeclaration
: null;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Reflection;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;

Expand Down Expand Up @@ -81,8 +82,20 @@ public partial void Method(MemoryStream ms)
new BinaryFormatter().Deserialize(ms); // Noncompliant

var formatter = new BinaryFormatter();
formatter.Binder = new SafeBinderPartial();
formatter.Deserialize(ms); // Noncompliant FP: safe binder was used
formatter.Binder = new SafeBinderPartial1();
formatter.Deserialize(ms); // Compliant: safe binder was used

formatter = new BinaryFormatter();
formatter.Binder = new SafeBinderPartial2();
formatter.Deserialize(ms); // Compliant: safe binder was usedant: safe binder was used

formatter = new BinaryFormatter();
formatter.Binder = new UnsafeBinderPartial1();
formatter.Deserialize(ms); // Noncompliant: unsafe binder was used

formatter = new BinaryFormatter();
formatter.Binder = new UnsafeBinderPartial2();
formatter.Deserialize(ms); // Noncompliant: unsafe binder was used
}
}

Expand All @@ -98,15 +111,48 @@ public override Type BindToType(string assemblyName, string typeName) =>
typeName is string and "TypeT" ? typeof(TypeT) : null;
}

internal sealed partial class SafeBinderPartial : SerializationBinder
internal sealed partial class SafeBinderPartial1 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName);
}

internal sealed partial class SafeBinderPartial1 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName); // Secondary FP
public override partial Type BindToType(string assemblyName, string typeName) =>
typeName == "TypeT" ? typeof(TypeT) : null;
}

internal sealed partial class SafeBinderPartial : SerializationBinder
internal sealed partial class SafeBinderPartial2 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName) =>
typeName == "TypeT" ? typeof(TypeT) : null;
}

internal sealed partial class SafeBinderPartial2 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName);
}

internal sealed partial class UnsafeBinderPartial1 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName);
}

internal sealed partial class UnsafeBinderPartial1 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName) => // Secondary
Assembly.Load(assemblyName).GetType(typeName);
}

internal sealed partial class UnsafeBinderPartial2 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName) => // Secondary
Assembly.Load(assemblyName).GetType(typeName);
}

internal sealed partial class UnsafeBinderPartial2 : SerializationBinder
{
public override partial Type BindToType(string assemblyName, string typeName);
}

public class TypeT { }

0 comments on commit be9ca30

Please sign in to comment.