Skip to content

Commit

Permalink
Fixes SQL injection vulnerability (#1101)
Browse files Browse the repository at this point in the history
  • Loading branch information
wivern authored and pavgra committed May 22, 2019
1 parent 382fe40 commit d7b12b2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.math.RoundingMode;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringJoiner;
import java.util.regex.Pattern;
Expand All @@ -21,6 +22,8 @@
import org.apache.commons.lang3.StringUtils;
import org.json.JSONObject;
import org.ohdsi.circe.helper.ResourceHelper;
import org.ohdsi.webapi.util.PreparedStatementRenderer;
import org.ohdsi.webapi.util.QuoteUtils;
import org.ohdsi.webapi.util.SourceUtils;
import org.springframework.stereotype.Component;
import org.ohdsi.featureExtraction.FeatureExtraction;
Expand Down Expand Up @@ -111,20 +114,22 @@ private List<String> buildCriteriaClauses(String searchTerm, List<String> analys
ArrayList<String> clauses = new ArrayList<>();

if (searchTerm != null && searchTerm.length() > 0) {
clauses.add(String.format("lower(fr.covariate_name) like '%%%s%%'", searchTerm));
clauses.add(String.format("lower(fr.covariate_name) like '%%%s%%'", QuoteUtils.escapeSql(searchTerm)));
}

if (analysisIds != null && analysisIds.size() > 0) {
ArrayList<String> ids = new ArrayList<>();
ArrayList<Integer> ids = new ArrayList<>();
ArrayList<String> ranges = new ArrayList<>();

analysisIds.stream().map((analysisIdExpr) -> analysisIdExpr.split(":")).forEachOrdered((parsedIds) -> {
if (parsedIds.length > 1) {
ranges.add(String.format("(ar.analysis_id >= %s and ar.analysis_id <= %s)", parsedIds[0], parsedIds[1]));
} else {
ids.add(parsedIds[0]);
}
});
analysisIds.stream().map((analysisIdExpr) -> analysisIdExpr.split(":"))
.map(strArray -> Arrays.stream(strArray).map(Integer::parseInt).toArray(Integer[]::new))
.forEachOrdered((parsedIds) -> {
if (parsedIds.length > 1) {
ranges.add(String.format("(ar.analysis_id >= %s and ar.analysis_id <= %s)", parsedIds[0], parsedIds[1]));
} else {
ids.add(parsedIds[0]);
}
});

String idClause = "";
if (ids.size() > 0) {
Expand All @@ -141,7 +146,7 @@ private List<String> buildCriteriaClauses(String searchTerm, List<String> analys
if (timeWindows != null && timeWindows.size() > 0) {
ArrayList<String> timeWindowClauses = new ArrayList<>();
timeWindows.forEach((timeWindow) -> {
timeWindowClauses.add(String.format("ar.analysis_name like '%%%s'", timeWindow));
timeWindowClauses.add(String.format("ar.analysis_name like '%%%s'", QuoteUtils.escapeSql(timeWindow)));
});
clauses.add("(" + StringUtils.join(timeWindowClauses, " OR ") + ")");
}
Expand All @@ -152,7 +157,7 @@ private List<String> buildCriteriaClauses(String searchTerm, List<String> analys
if (domain.toLowerCase().equals("null")) {
domainClauses.add("ar.domain_id is null");
} else {
domainClauses.add(String.format("lower(ar.domain_id) = lower('%s')", domain));
domainClauses.add(String.format("lower(ar.domain_id) = lower('%s')", QuoteUtils.escapeSql(domain)));
}
});
clauses.add("(" + StringUtils.join(domainClauses, " OR ") + ")");
Expand Down Expand Up @@ -223,7 +228,7 @@ public List<PrevalenceStat> getCohortFeaturePrevalenceStats(
String resultsSchema = SourceUtils.getResultsQualifier(source);
String cdmSchema = SourceUtils.getCdmQualifier(source);
String tempSchema = SourceUtils.getTempQualifier(source);

String categoricalQuery = SqlRender.renderSql(
QUERY_COVARIATE_STATS,
new String[]{"cdm_database_schema", "cdm_results_schema", "cohort_definition_id", "criteria_clauses"},
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/ohdsi/webapi/util/QuoteUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,9 @@ public static String dequote(String val) {

return Objects.nonNull(val) ? val.replaceAll("(^\"|\"$|^'|'$)", "") : val;
}

public static String escapeSql(String val) {

return Objects.nonNull(val) ? val.replaceAll("'", "''") : val;
}
}

0 comments on commit d7b12b2

Please sign in to comment.