Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
*/
package org.sonar.python.checks.hotspots;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
Expand All @@ -42,6 +44,7 @@
import org.sonar.plugins.python.api.tree.KeyValuePair;
import org.sonar.plugins.python.api.tree.ListLiteral;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.StringLiteral;
import org.sonar.plugins.python.api.tree.SubscriptionExpression;
Expand All @@ -54,8 +57,7 @@
@Rule(key = "S4502")
public class CsrfDisabledCheck extends PythonSubscriptionCheck {

private static final String DISABLING_CSRF_MESSAGE = "Make sure disabling CSRF protection is safe here.";
private static final String CSRFPROTECT_MISSING_MESSAGE = "Make sure not using CSRFProtect is safe here.";
private static final String MESSAGE = "Make sure disabling CSRF protection is safe here.";

@Override
public void initialize(Context context) {
Expand Down Expand Up @@ -87,9 +89,7 @@ private static void djangoMiddlewareArrayCheck(SubscriptionContext subscriptionC
.test(asgn.assignedValue());

if (!containsCsrfViewMiddleware) {
subscriptionContext.addIssue(
asgn.lastToken(),
"Make sure not using CSRF protection (" + CSRF_VIEW_MIDDLEWARE + ") is safe here.");
subscriptionContext.addIssue(asgn.lastToken(), MESSAGE);
}
}
}
Expand Down Expand Up @@ -128,7 +128,7 @@ private static void decoratorCsrfExemptCheck(SubscriptionContext subscriptionCon
boolean isDangerous = names.stream().anyMatch(s -> s.toLowerCase(Locale.US).contains("csrf")) &&
names.stream().anyMatch(s -> s.toLowerCase(Locale.US).contains("exempt"));
if (isDangerous) {
subscriptionContext.addIssue(decorator.lastToken(), DISABLING_CSRF_MESSAGE);
subscriptionContext.addIssue(decorator.lastToken(), MESSAGE);
}
}

Expand All @@ -138,7 +138,7 @@ private static void functionCsrfExemptCheck(SubscriptionContext subscriptionCont
Optional.ofNullable(callExpr.calleeSymbol())
.map(Symbol::fullyQualifiedName)
.filter(DANGEROUS_DECORATORS::contains)
.ifPresent(fqn -> subscriptionContext.addIssue(callExpr.callee().lastToken(), DISABLING_CSRF_MESSAGE));
.ifPresent(fqn -> subscriptionContext.addIssue(callExpr.callee().lastToken(), MESSAGE));
}

/** Checks that <code>'WTF_CSRF_ENABLED'</code> setting is not switched off. */
Expand All @@ -154,7 +154,7 @@ private static void flaskWtfCsrfEnabledFalseCheck(SubscriptionContext subscripti
.flatMap(s -> ((SubscriptionExpression) s).subscripts().expressions().stream())
.anyMatch(isStringSatisfying(s -> "WTF_CSRF_ENABLED".equals(s) || "WTF_CSRF_CHECK_DEFAULT".equals(s)));
if (isWtfCsrfEnabledSubscription && Expressions.isFalsy(asgn.assignedValue())) {
subscriptionContext.addIssue(asgn.assignedValue(), DISABLING_CSRF_MESSAGE);
subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
}
}

Expand Down Expand Up @@ -182,7 +182,7 @@ private static void metaCheck(SubscriptionContext subscriptionContext) {
if (stmt.is(Tree.Kind.ASSIGNMENT_STMT)) {
AssignmentStatement asgn = (AssignmentStatement) stmt;
if (isLhsCalled("csrf").test(asgn) && Expressions.isFalsy(asgn.assignedValue())) {
subscriptionContext.addIssue(asgn.assignedValue(), DISABLING_CSRF_MESSAGE);
subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
}
}
});
Expand All @@ -204,7 +204,7 @@ private static void formInstantiationCheck(SubscriptionContext subscriptionConte
if (arg instanceof RegularArgument) {
RegularArgument regArg = (RegularArgument) arg;
searchForProblemsInFormInitializationArguments(regArg)
.ifPresent(badExpr -> subscriptionContext.addIssue(badExpr, DISABLING_CSRF_MESSAGE));
.ifPresent(badExpr -> subscriptionContext.addIssue(badExpr, MESSAGE));
}
});
}
Expand Down Expand Up @@ -252,7 +252,7 @@ private static void improperlyConfiguredFlaskApp(SubscriptionContext subscriptio
.flatMap(usages -> usages.stream().filter(CsrfDisabledCheck::isWithinCsrfEnablingStatement).findFirst()))
.isPresent();
if (!isCsrfEnabledInThisFile) {
subscriptionContext.addIssue(asgn.assignedValue(), CSRFPROTECT_MISSING_MESSAGE);
subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
}
}
}
Expand All @@ -266,19 +266,76 @@ private static boolean isFlaskAppInstantiation(Expression expr) {
return false;
}



/** Attempts to extract a list of name fragments from a nested qualified expressions. */
private static Optional<ArrayList<String>> extractQualifiedNameComponents(Expression expr) {
if (expr.is(Tree.Kind.NAME)) {
ArrayList<String> res = new ArrayList<>();
res.add(((Name) expr).name());
return Optional.of(res);
} else if (expr.is(Tree.Kind.QUALIFIED_EXPR)){
QualifiedExpression qe = (QualifiedExpression) expr;
return extractQualifiedNameComponents(qe.qualifier()).map(list -> { list.add(qe.name().name()); return list; });
} else {
return Optional.empty();
}
}

private static final List<Pattern> CSRF_INIT_APP_CALLEE_PATTERNS = Arrays.asList(
Pattern.compile("(csrf|CSRF)"),
Pattern.compile("init_app")
);

/**
* Attempts to unpack the <code>expr</code> as nested <code>QualifiedExpression</code>s, and checks that
* every component of the name matches the corresponding regex pattern.
*/
private static boolean checkNestedQualifiedExpressions(List<Pattern> patternsToMatch, Expression expr) {
Optional<ArrayList<String>> nameFragmentsOpt = extractQualifiedNameComponents(expr);
return nameFragmentsOpt.filter(nameFragments -> {
if (nameFragments.size() == patternsToMatch.size()) {
for (int i = 0; i < nameFragments.size(); i++) {
Pattern p = patternsToMatch.get(i);
String s = nameFragments.get(i);
if (!p.matcher(s).matches()) {
return false;
}
}
return true;
} else {
return false;
}
}).isPresent();
}

/** Detects usages like <code>CSRFProtect(a)</code>. */
private static boolean isWithinCsrfEnablingStatement(Usage u) {
Tree t = u.tree();
return isWithinCall("flask_wtf.csrf.CSRFProtect", t) ||
isWithinCall("flask_wtf.csrf.CSRFProtect.init_app", t);
return isWithinCall(new HashSet<>(Arrays.asList(
"flask_wtf.csrf.CSRFProtect",
"flask_wtf.csrf.CSRFProtect.init_app",
"flask_wtf.CSRFProtect",
"flask_wtf.CSRFProtect.init_app"
)), CSRF_INIT_APP_CALLEE_PATTERNS, t);
}

/** Checks that the surroundings of <code>t</code> look like <code>expectedCalleeFqn(someExpr(t))</code>. */
private static boolean isWithinCall(String expectedCalleeFqn, Tree t) {
/**
* Checks that the surroundings of <code>t</code> look like <code>expectedCallee(someExpr(t))</code>,
* where the <code>expectedCallee</code> is either a symbol with an FQN from the specified set,
* or where at least the name of the callee matches a given regex.
*/
@SuppressWarnings("SameParameterValue")
private static boolean isWithinCall(Set<String> expectedCalleeFqns, List<Pattern> fallbackCalleeRegexes, Tree t) {
Tree callExprTree = TreeUtils.firstAncestorOfKind(t, Tree.Kind.CALL_EXPR);
if (callExprTree != null) {
Symbol callExprSymb = ((CallExpression) callExprTree).calleeSymbol();
return callExprSymb != null && expectedCalleeFqn.equals(callExprSymb.fullyQualifiedName());
if (callExprSymb != null && expectedCalleeFqns.contains(callExprSymb.fullyQualifiedName())) {
return true;
} else {
Expression callee = ((CallExpression) callExprTree).callee();
return checkNestedQualifiedExpressions(fallbackCalleeRegexes, callee);
}
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,11 @@ public void testExemptDecorators() {
public void testExemptAsFunction() {
testFile("flask/exemptAsFunction.py");
}

@Test
public void fixupTestsMoreRobustCSRFProtect() { testFile("flask/fixupTestsMoreRobustCSRFProtect.py"); }

@Test
public void fixupCsrfInGlobalScope() { testFile("flask/fixupCsrfInGlobalScope.py"); }

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
'django.middleware.security.SecurityMiddleware',
# 'django.middleware.csrf.CsrfViewMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
] # Noncompliant {{Make sure not using CSRF protection (django.middleware.csrf.CsrfViewMiddleware) is safe here.}}
] # Noncompliant {{Make sure disabling CSRF protection is safe here.}}

MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from flask import Flask
from flask_wtf import CSRFProtect
csrf = CSRFProtect()
def create_app():
app = Flask(__name__) # Compliant
csrf.init_app(app)

def create_app_noncompliant():
app2 = Flask(__name__) # Noncompliant

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
def csrfInitAppShouldWorkEvenIfCsrfSymbolIsNotFound():
from flask import Flask
app = Flask(__name__) # Compliant
csrf.init_app(app) # Unknown symbol, but looks similar enough to CSRFProtect.

def csrfInitAppShouldWorkEvenIfCsrfSymbolIsNotFoundUppercase():
from flask import Flask
app = Flask(__name__) # Compliant
CSRF.init_app(app) # Unknown symbol, but looks similar enough to CSRFProtect.

def csrfInitAppShouldCheckTheQualifier():
from flask import Flask
app = Flask(__name__) # Noncompliant {{Make sure disabling CSRF protection is safe here.}}
# ^^^^^^^^^^^^^^^
somethingUnrelated.init_app(app) # insufficient
tooLong.csrf.init_app(app) # insufficient
csrf.do_something_else(app) # insufficient

def csrfProtectCanBeImportedFromFlaskWtfDirectly():
from flask import Flask
from flask_wtf import CSRFProtect
app = Flask(__name__) # Compliant
CSRFProtect(app)

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ def misconfiguredFlaskExamples():
from flask_wtf.csrf import CSRFProtect
# from flask_wtf import csrf

app1 = Flask(__name__) # Noncompliant {{Make sure not using CSRFProtect is safe here.}}
app1 = Flask(__name__) # Noncompliant {{Make sure disabling CSRF protection is safe here.}}
# ^^^^^^^^^^^^^^^

app2 = Flask(__name__)
Expand Down