Skip to content

Commit

Permalink
Add guardrail for SELECT IN terms and their cartesian product
Browse files Browse the repository at this point in the history
patch by Andrés de la Peña; reviewed by Ekaterina Dimitrova for CASSANDRA-17187

Co-authored-by: Aleksandr Sorokoumov <aleksandr.sorokoumov@gmail.com>
Co-authored-by: Andrés de la Peña <a.penya.garcia@gmail.com>
  • Loading branch information
adelapena and Gerrrr committed Mar 8, 2022
1 parent ac34f28 commit 3233c82
Show file tree
Hide file tree
Showing 30 changed files with 496 additions and 119 deletions.
1 change: 1 addition & 0 deletions CHANGES.txt
@@ -1,4 +1,5 @@
4.1
* Add guardrail for SELECT IN terms and their cartesian product (CASSANDRA-17187)
* remove unused imports in cqlsh.py and cqlshlib (CASSANDRA-17413)
* deprecate property windows_timer_interval (CASSANDRA-17404)
* Expose streaming as a vtable (CASSANDRA-17390)
Expand Down
5 changes: 5 additions & 0 deletions conf/cassandra.yaml
Expand Up @@ -1620,6 +1620,11 @@ drop_compact_storage_enabled: false
# The two thresholds default to -1 to disable.
# partition_keys_in_select_warn_threshold: -1
# partition_keys_in_select_fail_threshold: -1
# Guardrail to warn or fail when an IN query creates a cartesian product whose size exceeds the threshold,
# e.g. "a in (1,2,...10) and b in (1,2,...10)" results in a cartesian product of 100.
# The two thresholds default to -1 to disable.
# in_select_cartesian_product_warn_threshold: -1
# in_select_cartesian_product_fail_threshold: -1

# Startup checks are executed as part of the Cassandra startup process. Not all of them
# are configurable (so that you can disable them), but those enumerated below are.
Expand Down
3 changes: 3 additions & 0 deletions src/java/org/apache/cassandra/config/Config.java
Expand Up @@ -767,10 +767,13 @@ public static void setClientMode(boolean clientMode)
public volatile int page_size_fail_threshold = DISABLED_GUARDRAIL;
public volatile int partition_keys_in_select_warn_threshold = DISABLED_GUARDRAIL;
public volatile int partition_keys_in_select_fail_threshold = DISABLED_GUARDRAIL;
// Warn and fail thresholds for the guardrail on the size of the cartesian product created by IN restrictions.
// Both default to DISABLED_GUARDRAIL, i.e. the guardrail is off unless configured.
public volatile int in_select_cartesian_product_warn_threshold = DISABLED_GUARDRAIL;
public volatile int in_select_cartesian_product_fail_threshold = DISABLED_GUARDRAIL;
public volatile Set<String> table_properties_ignored = Collections.emptySet();
public volatile Set<String> table_properties_disallowed = Collections.emptySet();
public volatile boolean user_timestamps_enabled = true;
public volatile boolean read_before_write_list_operations_enabled = true;

public volatile DurationSpec streaming_state_expires = DurationSpec.inDays(3);
public volatile DataStorageSpec streaming_state_size = DataStorageSpec.inMebibytes(40);

Expand Down
26 changes: 26 additions & 0 deletions src/java/org/apache/cassandra/config/GuardrailsOptions.java
Expand Up @@ -69,6 +69,7 @@ public GuardrailsOptions(Config config)
validateIntThreshold(config.page_size_warn_threshold, config.page_size_fail_threshold, "page_size");
validateIntThreshold(config.partition_keys_in_select_warn_threshold,
config.partition_keys_in_select_fail_threshold, "partition_keys_in_select");
validateIntThreshold(config.in_select_cartesian_product_warn_threshold, config.in_select_cartesian_product_fail_threshold, "in_select_cartesian_product");
}

@Override
Expand Down Expand Up @@ -321,6 +322,31 @@ public void setReadBeforeWriteListOperationsEnabled(boolean enabled)
x -> config.read_before_write_list_operations_enabled = x);
}

/**
 * Returns the current warn threshold of the {@code in_select_cartesian_product} guardrail.
 * <p>
 * The value is read from the backing {@code config} on every call, so a change made through
 * {@code setInSelectCartesianProductThreshold} is reflected by subsequent calls.
 *
 * @return the value of {@code in_select_cartesian_product_warn_threshold}
 */
@Override
public int getInSelectCartesianProductWarnThreshold()
{
return config.in_select_cartesian_product_warn_threshold;
}

/**
 * Returns the current fail threshold of the {@code in_select_cartesian_product} guardrail.
 * <p>
 * The value is read from the backing {@code config} on every call, so a change made through
 * {@code setInSelectCartesianProductThreshold} is reflected by subsequent calls.
 *
 * @return the value of {@code in_select_cartesian_product_fail_threshold}
 */
@Override
public int getInSelectCartesianProductFailThreshold()
{
return config.in_select_cartesian_product_fail_threshold;
}

/**
 * Updates both thresholds of the {@code in_select_cartesian_product} guardrail.
 * <p>
 * The pair is first checked with {@code validateIntThreshold}; each value is then written to the
 * backing {@code config} through {@code updatePropertyWithLogging}, warn threshold first.
 *
 * @param warn the new value for {@code in_select_cartesian_product_warn_threshold}
 * @param fail the new value for {@code in_select_cartesian_product_fail_threshold}
 */
public void setInSelectCartesianProductThreshold(int warn, int fail)
{
// Same check, with the same guardrail name, that the constructor applies to the values loaded from Config.
validateIntThreshold(warn, fail, "in_select_cartesian_product");
// The two thresholds are stored in separate fields and are therefore updated by two separate calls.
updatePropertyWithLogging("in_select_cartesian_product_warn_threshold",
warn,
() -> config.in_select_cartesian_product_warn_threshold,
x -> config.in_select_cartesian_product_warn_threshold = x);
updatePropertyWithLogging("in_select_cartesian_product_fail_threshold",
fail,
() -> config.in_select_cartesian_product_fail_threshold,
x -> config.in_select_cartesian_product_fail_threshold = x);
}

private static <T> void updatePropertyWithLogging(String propertyName, T newValue, Supplier<T> getter, Consumer<T> setter)
{
T oldValue = getter.get();
Expand Down
Expand Up @@ -19,6 +19,7 @@

import java.util.*;

import org.apache.cassandra.db.guardrails.Guardrails;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.cql3.QueryOptions;
Expand All @@ -27,6 +28,7 @@
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.service.ClientState;
import org.apache.cassandra.utils.btree.BTreeSet;

import static org.apache.cassandra.cql3.statements.RequestValidations.checkFalse;
Expand Down Expand Up @@ -101,12 +103,16 @@ private boolean hasMultiColumnSlice()
return false;
}

public NavigableSet<Clustering<?>> valuesAsClustering(QueryOptions options) throws InvalidRequestException
public NavigableSet<Clustering<?>> valuesAsClustering(QueryOptions options, ClientState state) throws InvalidRequestException
{
MultiCBuilder builder = MultiCBuilder.create(comparator, hasIN());
for (SingleRestriction r : restrictions)
{
r.appendTo(builder, options);

if (hasIN() && Guardrails.inSelectCartesianProduct.enabled(state))
Guardrails.inSelectCartesianProduct.guard(builder.buildSize(), "clustering key", state);

if (builder.hasMissingElements())
break;
}
Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.cql3.statements.Bound;
import org.apache.cassandra.service.ClientState;

/**
* A set of restrictions on the partition key.
Expand All @@ -32,7 +33,7 @@ interface PartitionKeyRestrictions extends Restrictions
{
public PartitionKeyRestrictions mergeWith(Restriction restriction);

public List<ByteBuffer> values(QueryOptions options);
public List<ByteBuffer> values(QueryOptions options, ClientState state);

public List<ByteBuffer> bounds(Bound b, QueryOptions options);

Expand Down
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.*;

import org.apache.cassandra.db.guardrails.Guardrails;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.cql3.statements.Bound;
Expand All @@ -28,6 +29,7 @@
import org.apache.cassandra.db.MultiCBuilder;
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.service.ClientState;

/**
* A set of single restrictions on the partition key.
Expand Down Expand Up @@ -78,12 +80,16 @@ public PartitionKeyRestrictions mergeWith(Restriction restriction)
}

@Override
public List<ByteBuffer> values(QueryOptions options)
public List<ByteBuffer> values(QueryOptions options, ClientState state)
{
MultiCBuilder builder = MultiCBuilder.create(comparator, hasIN());
for (SingleRestriction r : restrictions)
{
r.appendTo(builder, options);

if (hasIN() && Guardrails.inSelectCartesianProduct.enabled(state))
Guardrails.inSelectCartesianProduct.guard(builder.buildSize(), "partition key", state);

if (builder.hasMissingElements())
break;
}
Expand Down
Expand Up @@ -35,6 +35,7 @@
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.schema.ColumnMetadata;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.service.ClientState;
import org.apache.cassandra.utils.btree.BTreeSet;

import org.apache.commons.lang3.builder.ToStringBuilder;
Expand Down Expand Up @@ -619,11 +620,12 @@ public RowFilter getRowFilter(IndexRegistry indexRegistry, QueryOptions options)
* Returns the partition keys for which the data is requested.
*
* @param options the query options
* @param state the client state
* @return the partition keys for which the data is requested.
*/
public List<ByteBuffer> getPartitionKeys(final QueryOptions options)
public List<ByteBuffer> getPartitionKeys(final QueryOptions options, ClientState state)
{
return partitionKeyRestrictions.values(options);
return partitionKeyRestrictions.values(options, state);
}

/**
Expand Down Expand Up @@ -741,17 +743,18 @@ public boolean hasClusteringColumnsRestrictions()
* Returns the requested clustering columns.
*
* @param options the query options
* @param state the client state
* @return the requested clustering columns
*/
public NavigableSet<Clustering<?>> getClusteringColumns(QueryOptions options)
public NavigableSet<Clustering<?>> getClusteringColumns(QueryOptions options, ClientState state)
{
// If this is a names command and the table is a static compact one, then as far as CQL is concerned we have
// only a single row which internally correspond to the static parts. In which case we want to return an empty
// set (since that's what ClusteringIndexNamesFilter expects).
if (table.isStaticCompactTable())
return BTreeSet.empty(table.comparator);

return clusteringColumnsRestrictions.valuesAsClustering(options);
return clusteringColumnsRestrictions.valuesAsClustering(options, state);
}

/**
Expand Down
10 changes: 6 additions & 4 deletions src/java/org/apache/cassandra/cql3/restrictions/TokenFilter.java
Expand Up @@ -35,6 +35,7 @@
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.service.ClientState;

import static org.apache.cassandra.cql3.statements.Bound.END;
import static org.apache.cassandra.cql3.statements.Bound.START;
Expand Down Expand Up @@ -102,9 +103,9 @@ public TokenFilter(PartitionKeyRestrictions restrictions, TokenRestriction token
}

@Override
public List<ByteBuffer> values(QueryOptions options) throws InvalidRequestException
public List<ByteBuffer> values(QueryOptions options, ClientState state) throws InvalidRequestException
{
return filter(restrictions.values(options), options);
return filter(restrictions.values(options, state), options, state);
}

@Override
Expand Down Expand Up @@ -139,13 +140,14 @@ public List<ByteBuffer> bounds(Bound bound, QueryOptions options) throws Invalid
*
* @param values the values returned by the decorated restriction
* @param options the query options
* @param state the client state
* @return the values matching the token restriction
* @throws InvalidRequestException if the request is invalid
*/
private List<ByteBuffer> filter(List<ByteBuffer> values, QueryOptions options) throws InvalidRequestException
private List<ByteBuffer> filter(List<ByteBuffer> values, QueryOptions options, ClientState state) throws InvalidRequestException
{
RangeSet<Token> rangeSet = tokenRestriction.hasSlice() ? toRangeSet(tokenRestriction, options)
: toRangeSet(tokenRestriction.values(options));
: toRangeSet(tokenRestriction.values(options, state));

return filterWithRangeSet(rangeSet, values);
}
Expand Down
Expand Up @@ -31,6 +31,7 @@
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.service.ClientState;

import static org.apache.cassandra.cql3.statements.RequestValidations.invalidRequest;

Expand Down Expand Up @@ -205,7 +206,10 @@ protected PartitionKeyRestrictions doMergeWith(TokenRestriction otherRestriction
@Override
public List<ByteBuffer> bounds(Bound b, QueryOptions options) throws InvalidRequestException
{
return values(options);
// ClientState is used by inSelectCartesianProduct guardrail to skip non-ordinary users.
// Passing null here to avoid polluting too many methods, because in case of EQ token restriction,
// it won't generate high cartesian product.
return values(options, null);
}

@Override
Expand All @@ -221,7 +225,7 @@ public boolean isInclusive(Bound b)
}

@Override
public List<ByteBuffer> values(QueryOptions options) throws InvalidRequestException
public List<ByteBuffer> values(QueryOptions options, ClientState state) throws InvalidRequestException
{
return Collections.singletonList(value.bindAndGet(options));
}
Expand Down Expand Up @@ -254,7 +258,7 @@ public boolean hasSlice()
}

@Override
public List<ByteBuffer> values(QueryOptions options) throws InvalidRequestException
public List<ByteBuffer> values(QueryOptions options, ClientState state) throws InvalidRequestException
{
throw new UnsupportedOperationException();
}
Expand Down
Expand Up @@ -282,7 +282,7 @@ public List<? extends IMutation> getMutations(ClientState state,
ModificationStatement stmt = statements.get(i);
if (metadata != null && !stmt.metadata.id.equals(metadata.id))
metadata = null;
List<ByteBuffer> stmtPartitionKeys = stmt.buildPartitionKeyNames(options.forStatement(i));
List<ByteBuffer> stmtPartitionKeys = stmt.buildPartitionKeyNames(options.forStatement(i), state);
partitionKeys.add(stmtPartitionKeys);
HashMultiset<ByteBuffer> perKeyCountsForTable = partitionCounts.computeIfAbsent(stmt.metadata.id, k -> HashMultiset.create());
for (int stmtIdx = 0, stmtSize = stmtPartitionKeys.size(); stmtIdx < stmtSize; stmtIdx++)
Expand Down Expand Up @@ -489,7 +489,7 @@ private Pair<CQL3CasRequest,Set<ColumnMetadata>> makeCasRequest(BatchQueryOption
ModificationStatement statement = statements.get(i);
QueryOptions statementOptions = options.forStatement(i);
long timestamp = attrs.getTimestamp(batchTimestamp, statementOptions);
List<ByteBuffer> pks = statement.buildPartitionKeyNames(statementOptions);
List<ByteBuffer> pks = statement.buildPartitionKeyNames(statementOptions, state.getClientState());
if (statement.getRestrictions().keyIsInRelation())
throw new IllegalArgumentException("Batch with conditions cannot span multiple partitions (you cannot use IN on the partition key)");
if (key == null)
Expand Down Expand Up @@ -524,7 +524,7 @@ else if (!key.getKey().equals(pks.get(0)))
}
else
{
Clustering<?> clustering = Iterables.getOnlyElement(statement.createClustering(statementOptions));
Clustering<?> clustering = Iterables.getOnlyElement(statement.createClustering(statementOptions, state.getClientState()));
if (statement.hasConditions())
{
statement.addConditions(clustering, casRequest, statementOptions);
Expand Down
Expand Up @@ -330,23 +330,23 @@ public boolean hasIfExistCondition()
return conditions.isIfExists();
}

public List<ByteBuffer> buildPartitionKeyNames(QueryOptions options)
public List<ByteBuffer> buildPartitionKeyNames(QueryOptions options, ClientState state)
throws InvalidRequestException
{
List<ByteBuffer> partitionKeys = restrictions.getPartitionKeys(options);
List<ByteBuffer> partitionKeys = restrictions.getPartitionKeys(options, state);
for (ByteBuffer key : partitionKeys)
QueryProcessor.validateKey(key);

return partitionKeys;
}

public NavigableSet<Clustering<?>> createClustering(QueryOptions options)
public NavigableSet<Clustering<?>> createClustering(QueryOptions options, ClientState state)
throws InvalidRequestException
{
if (appliesOnlyToStaticColumns() && !restrictions.hasClusteringColumnsRestrictions())
return FBUtilities.singleton(CBuilder.STATIC_BUILDER.build(), metadata().comparator);

return restrictions.getClusteringColumns(options);
return restrictions.getClusteringColumns(options, state);
}

/**
Expand Down Expand Up @@ -508,7 +508,8 @@ private ResultMessage executeWithCondition(QueryState queryState, QueryOptions o

private CQL3CasRequest makeCasRequest(QueryState queryState, QueryOptions options)
{
List<ByteBuffer> keys = buildPartitionKeyNames(options);
ClientState clientState = queryState.getClientState();
List<ByteBuffer> keys = buildPartitionKeyNames(options, clientState);
// We don't support IN for CAS operation so far
checkFalse(restrictions.keyIsInRelation(),
"IN on the partition key is not supported with conditional %s",
Expand All @@ -522,7 +523,7 @@ private CQL3CasRequest makeCasRequest(QueryState queryState, QueryOptions option
"IN on the clustering key columns is not supported with conditional %s",
type.isUpdate()? "updates" : "deletions");

Clustering<?> clustering = Iterables.getOnlyElement(createClustering(options));
Clustering<?> clustering = Iterables.getOnlyElement(createClustering(options, clientState));
CQL3CasRequest request = new CQL3CasRequest(metadata(), key, conditionColumns(), updatesRegularRows(), updatesStaticRow());

addConditions(clustering, request, options);
Expand Down Expand Up @@ -695,7 +696,7 @@ private List<? extends IMutation> getMutations(ClientState state,
int nowInSeconds,
long queryStartNanoTime)
{
List<ByteBuffer> keys = buildPartitionKeyNames(options);
List<ByteBuffer> keys = buildPartitionKeyNames(options, state);
HashMultiset<ByteBuffer> perPartitionKeyCounts = HashMultiset.create(keys);
SingleTableUpdatesCollector collector = new SingleTableUpdatesCollector(metadata, updatedColumns, perPartitionKeyCounts);
addUpdates(collector, keys, state, options, local, timestamp, nowInSeconds, queryStartNanoTime);
Expand Down Expand Up @@ -741,7 +742,7 @@ final void addUpdates(UpdatesCollector collector,
}
else
{
NavigableSet<Clustering<?>> clusterings = createClustering(options);
NavigableSet<Clustering<?>> clusterings = createClustering(options, state);

// If some of the restrictions were unspecified (e.g. empty IN restrictions) we do not need to do anything.
if (restrictions.hasClusteringColumnsRestrictions() && clusterings.isEmpty())
Expand Down

0 comments on commit 3233c82

Please sign in to comment.