Skip to content

Commit

Permalink
[DataStore] Use prepared statement for query (#126)
Browse files Browse the repository at this point in the history
* Re-factor query to use prepared statement

* Re-factor SQL predicate parser
  • Loading branch information
raphkim committed Dec 2, 2019
1 parent a8c9d91 commit 3b66c9b
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,22 @@ public void saveModelWithInvalidForeignKey() throws ParseException {
assertTrue(actualError.getCause().getMessage().contains(expectedError));
}

/**
* Test save with SQL injection.
*/
@Test
public void saveModelWithMaliciousInputs() {
final Person person = Person.builder()
.firstName("Jane'); DROP TABLE Person; --")
.lastName("Doe")
.build();
saveModel(person);

Iterator<Person> result = queryModel(Person.class);
assertTrue(result.hasNext());
assertEquals(person, result.next());
}

/**
* Test querying the saved item in the SQLite database.
*
Expand Down Expand Up @@ -406,22 +422,22 @@ public void querySavedDataWithNumericalPredicates() throws ParseException {
.or(Person.AGE.eq(1).and(Person.AGE.ne(7)));
Iterator<Person> result = queryModel(Person.class, predicate);

Set<Person> expectedPersons = new HashSet<>();
expectedPersons.add(savedModels.get(1));
expectedPersons.add(savedModels.get(4));
expectedPersons.add(savedModels.get(5));
expectedPersons.add(savedModels.get(6));
Set<Person> expectedPeople = new HashSet<>();
expectedPeople.add(savedModels.get(1));
expectedPeople.add(savedModels.get(4));
expectedPeople.add(savedModels.get(5));
expectedPeople.add(savedModels.get(6));

Set<Person> actualPersons = new HashSet<>();
Set<Person> actualPeople = new HashSet<>();
while (result.hasNext()) {
final Person person = result.next();
assertNotNull(person);
assertTrue("Unable to find expected item in the storage adapter.",
savedModels.contains(person));
actualPersons.add(person);
actualPeople.add(person);
}

assertEquals(expectedPersons, actualPersons);
assertEquals(expectedPeople, actualPeople);
}

/**
Expand All @@ -433,7 +449,7 @@ public void querySavedDataWithNumericalPredicates() throws ParseException {
@SuppressWarnings("magicnumber")
@Test
public void querySavedDataWithStringPredicates() throws ParseException {
final Set<Person> savedModels = new HashSet<>();
final List<Person> savedModels = new ArrayList<>();
final int numModels = 10;
for (int counter = 0; counter < numModels; counter++) {
final Person person = Person.builder()
Expand All @@ -453,17 +469,63 @@ public void querySavedDataWithStringPredicates() throws ParseException {
.or(Person.LAST_NAME.beginsWith("9"))
.and(not(Person.AGE.gt(8)));
Iterator<Person> result = queryModel(Person.class, predicate);
Set<Integer> ages = new HashSet<>();

Set<Person> expectedPeople = new HashSet<>();
expectedPeople.add(savedModels.get(4));
expectedPeople.add(savedModels.get(7));

Set<Person> actualPeople = new HashSet<>();
while (result.hasNext()) {
final Person person = result.next();
assertNotNull(person);
assertTrue("Unable to find expected item in the storage adapter.",
savedModels.contains(person));
ages.add(person.getAge());
actualPeople.add(person);
}
assertEquals(2, ages.size());
assertTrue(ages.contains(4));
assertTrue(ages.contains(7));
assertEquals(expectedPeople, actualPeople);
}

/**
* Test querying with predicate condition on connected model.
*/
@Test
public void querySavedDataWithPredicatesOnForeignKey() {
final Person person = Person.builder()
.firstName("Jane")
.lastName("Doe")
.build();
saveModel(person);

final Car car = Car.builder()
.vehicleModel("Toyota Prius")
.owner(person)
.build();
saveModel(car);

QueryPredicate predicate = Person.FIRST_NAME.eq("Jane");
Iterator<Car> result = queryModel(Car.class, predicate);
assertTrue(result.hasNext());
assertEquals(car, result.next());
}

/**
* Test query with SQL injection.
*/
@Test
public void queryWithMaliciousPredicates() {
final Person jane = Person.builder()
.firstName("Jane")
.lastName("Doe")
.build();
saveModel(jane);

QueryPredicate predicate = Person.FIRST_NAME.eq("Jane; DROP TABLE Person; --");
Iterator<Person> resultOfMaliciousQuery = queryModel(Person.class, predicate);
assertFalse(resultOfMaliciousQuery.hasNext());

Iterator<Person> resultAfterMaliciousQuery = queryModel(Person.class);
assertTrue(resultAfterMaliciousQuery.hasNext());
assertEquals(jane, resultAfterMaliciousQuery.next());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,8 @@
import com.amplifyframework.core.model.ModelSchema;
import com.amplifyframework.core.model.ModelSchemaRegistry;
import com.amplifyframework.core.model.PrimaryKey;
import com.amplifyframework.core.model.query.predicate.BeginsWithQueryOperator;
import com.amplifyframework.core.model.query.predicate.BetweenQueryOperator;
import com.amplifyframework.core.model.query.predicate.ContainsQueryOperator;
import com.amplifyframework.core.model.query.predicate.EqualQueryOperator;
import com.amplifyframework.core.model.query.predicate.GreaterOrEqualQueryOperator;
import com.amplifyframework.core.model.query.predicate.GreaterThanQueryOperator;
import com.amplifyframework.core.model.query.predicate.LessOrEqualQueryOperator;
import com.amplifyframework.core.model.query.predicate.LessThanQueryOperator;
import com.amplifyframework.core.model.query.predicate.NotEqualQueryOperator;
import com.amplifyframework.core.model.query.predicate.QueryOperator;
import com.amplifyframework.core.model.query.predicate.QueryPredicate;
import com.amplifyframework.core.model.query.predicate.QueryPredicateGroup;
import com.amplifyframework.core.model.query.predicate.QueryPredicateOperation;
import com.amplifyframework.datastore.storage.sqlite.adapter.SQLPredicate;
import com.amplifyframework.datastore.storage.sqlite.adapter.SQLiteColumn;
import com.amplifyframework.datastore.storage.sqlite.adapter.SQLiteTable;
import com.amplifyframework.util.CollectionUtils;
Expand Down Expand Up @@ -153,6 +142,7 @@ public SqlCommand queryFor(@NonNull ModelSchema modelSchema,
StringBuilder rawQuery = new StringBuilder();
StringBuilder selectColumns = new StringBuilder();
StringBuilder joinStatement = new StringBuilder();
List<String> selectionArgs = null;

// Track the list of columns to return
List<SQLiteColumn> columns = new LinkedList<>(table.getSortedColumns());
Expand Down Expand Up @@ -228,15 +218,17 @@ public SqlCommand queryFor(@NonNull ModelSchema modelSchema,
// Append predicates.
// WHERE condition
if (predicate != null) {
final SQLPredicate sqlPredicate = new SQLPredicate(predicate);
selectionArgs = sqlPredicate.getSelectionArgs();
rawQuery.append(SqlKeyword.DELIMITER)
.append(SqlKeyword.WHERE)
.append(SqlKeyword.DELIMITER)
.append(parsePredicate(predicate));
.append(sqlPredicate);
}

rawQuery.append(";");
final String queryString = rawQuery.toString();
return new SqlCommand(table.getName(), queryString);
return new SqlCommand(table.getName(), queryString, selectionArgs);
}

/**
Expand Down Expand Up @@ -395,134 +387,4 @@ private StringBuilder parseForeignKeys(SQLiteTable table) {
}
return builder;
}

// Utility method to recursively parse a given predicate.
private StringBuilder parsePredicate(QueryPredicate queryPredicate) {
if (queryPredicate instanceof QueryPredicateOperation) {
QueryPredicateOperation qpo = (QueryPredicateOperation) queryPredicate;
return parsePredicateOperation(qpo);
} else if (queryPredicate instanceof QueryPredicateGroup) {
QueryPredicateGroup qpg = (QueryPredicateGroup) queryPredicate;
return parsePredicateGroup(qpg);
} else {
throw new UnsupportedTypeException(
"Tried to parse an unsupported QueryPredicate",
null,
"Try changing to one of the supported values: " +
"QueryPredicateOperation, QueryPredicateGroup.",
false
);
}
}

// Utility method to recursively parse a given predicate operation.
private StringBuilder parsePredicateOperation(QueryPredicateOperation operation) {
final StringBuilder builder = new StringBuilder();
final String field = operation.field();
final QueryOperator op = operation.operator();
switch (op.type()) {
case BETWEEN:
BetweenQueryOperator betweenOp = (BetweenQueryOperator) op;
Object start = betweenOp.start();
Object end = betweenOp.end();
QueryPredicateOperation gt = new QueryPredicateOperation(field,
new GreaterThanQueryOperator(start));
QueryPredicateOperation lt = new QueryPredicateOperation(field,
new LessThanQueryOperator(end));
return parsePredicate(gt.and(lt));
case CONTAINS:
ContainsQueryOperator containsOp = (ContainsQueryOperator) op;
return builder.append(containsOp.value())
.append(SqlKeyword.DELIMITER)
.append(SqlKeyword.IN)
.append(SqlKeyword.DELIMITER)
.append(field);
case BEGINS_WITH:
BeginsWithQueryOperator beginsWithOp = (BeginsWithQueryOperator) op;
return builder.append(field)
.append(SqlKeyword.DELIMITER)
.append(SqlKeyword.LIKE)
.append(SqlKeyword.DELIMITER)
.append("\'")
.append(beginsWithOp.value().toString() + "%")
.append("\'");
case EQUAL:
case NOT_EQUAL:
case LESS_THAN:
case GREATER_THAN:
case LESS_OR_EQUAL:
case GREATER_OR_EQUAL:
return builder.append(field)
.append(SqlKeyword.DELIMITER)
.append(SqlKeyword.fromQueryOperator(op.type()))
.append(SqlKeyword.DELIMITER)
.append(getOperatorValue(op));
default:
throw new UnsupportedTypeException(
"Tried to parse an unsupported QueryPredicateOperation",
null,
"Try changing to one of the supported values from " +
"QueryPredicateOperation.Type enum.",
false
);
}
}

// Utility method to recursively parse a given predicate group.
private StringBuilder parsePredicateGroup(QueryPredicateGroup group) {
final StringBuilder builder = new StringBuilder();
switch (group.type()) {
case NOT:
return builder.append(SqlKeyword.fromQueryPredicateGroup(group.type()))
.append(SqlKeyword.DELIMITER)
.append(parsePredicate(group.predicates().get(0)));
case OR:
case AND:
builder.append("(");
Iterator<QueryPredicate> predicateIterator = group.predicates().iterator();
while (predicateIterator.hasNext()) {
builder.append(parsePredicate(predicateIterator.next()));
if (predicateIterator.hasNext()) {
builder.append(SqlKeyword.DELIMITER)
.append(SqlKeyword.fromQueryPredicateGroup(group.type()))
.append(SqlKeyword.DELIMITER);
}
}
return builder.append(")");
default:
throw new UnsupportedTypeException(
"Tried to parse an unsupported QueryPredicateGroup",
null,
"Try changing to one of the supported values from " +
"QueryPredicateGroup.Type enum.",
false
);
}
}

// Utility method to extract the parameter value from a given operator.
private Object getOperatorValue(QueryOperator qOp) throws UnsupportedTypeException {
switch (qOp.type()) {
case NOT_EQUAL:
return ((NotEqualQueryOperator) qOp).value();
case EQUAL:
return ((EqualQueryOperator) qOp).value();
case LESS_OR_EQUAL:
return ((LessOrEqualQueryOperator) qOp).value();
case LESS_THAN:
return ((LessThanQueryOperator) qOp).value();
case GREATER_OR_EQUAL:
return ((GreaterOrEqualQueryOperator) qOp).value();
case GREATER_THAN:
return ((GreaterThanQueryOperator) qOp).value();
default:
throw new UnsupportedTypeException(
"Tried to parse an unsupported QueryOperator type",
null,
"Check if a new QueryOperator.Type enum has been created which is not supported" +
"in the AppSyncGraphQLRequestFactory.",
false
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,11 @@ private <T extends Model> Map<String, Object> buildMapForModel(
}

final int columnIndex = cursor.getColumnIndexOrThrow(columnName);
// This check is necessary, because primitive values will return 0 even when null
if (cursor.isNull(columnIndex)) {
mapForModel.put(fieldName, null);
continue;
}

final String stringValueFromCursor;
switch (fieldJavaType) {
Expand Down Expand Up @@ -695,7 +700,9 @@ Cursor getQueryAllCursor(@NonNull String tableName,
@Nullable QueryPredicate predicate) {
final ModelSchema schema = ModelSchemaRegistry.singleton()
.getModelSchemaForModelClass(tableName);
final String rawQuery = sqlCommandFactory.queryFor(schema, predicate).sqlStatement();
return this.databaseConnectionHandle.rawQuery(rawQuery, null);
final SqlCommand sqlCommand = sqlCommandFactory.queryFor(schema, predicate);
final String rawQuery = sqlCommand.sqlStatement();
final String[] selectionArgs = sqlCommand.getSelectionArgsAsArray();
return this.databaseConnectionHandle.rawQuery(rawQuery, selectionArgs);
}
}

0 comments on commit 3b66c9b

Please sign in to comment.