Skip to content

Commit

Permalink
Merge pull request #8090: [BEAM-6861] Add support to specify a query …
Browse files Browse the repository at this point in the history
…in CassandraIO
  • Loading branch information
iemejia committed Mar 27, 2019
2 parents cea9c32 + 1ac0e32 commit c984f7e
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>>
@Nullable
abstract ValueProvider<List<String>> hosts();

// CQL query to execute; when null, the read query is built from keyspace/table/where.
@Nullable
abstract ValueProvider<String> query();

@Nullable
abstract ValueProvider<Integer> port();

Expand Down Expand Up @@ -176,7 +179,7 @@ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>>
/**
 * Specify the hosts of the Apache Cassandra instances.
 *
 * @param hosts non-null, non-empty list of contact points
 */
public Read<T> withHosts(List<String> hosts) {
  checkArgument(hosts != null, "hosts can not be null");
  checkArgument(!hosts.isEmpty(), "hosts can not be empty");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withHosts(ValueProvider.StaticValueProvider.of(hosts));
}

/** Specify the hosts of the Apache Cassandra instances. */
Expand All @@ -187,7 +190,7 @@ public Read<T> withHosts(ValueProvider<List<String>> hosts) {
/**
 * Specify the port number of the Apache Cassandra instances.
 *
 * @param port native-protocol port; must be strictly positive
 */
public Read<T> withPort(int port) {
  checkArgument(port > 0, "port must be > 0, but was: %s", port);
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withPort(ValueProvider.StaticValueProvider.of(port));
}

/** Specify the port number of the Apache Cassandra instances. */
Expand All @@ -198,7 +201,7 @@ public Read<T> withPort(ValueProvider<Integer> port) {
/**
 * Specify the Cassandra keyspace where to read data.
 *
 * @param keyspace non-null keyspace name
 */
public Read<T> withKeyspace(String keyspace) {
  checkArgument(keyspace != null, "keyspace can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withKeyspace(ValueProvider.StaticValueProvider.of(keyspace));
}

/** Specify the Cassandra keyspace where to read data. */
Expand All @@ -209,14 +212,25 @@ public Read<T> withKeyspace(ValueProvider<String> keyspace) {
/**
 * Specify the Cassandra table where to read data.
 *
 * <p>Ignored when a full query is supplied via {@code withQuery}.
 *
 * @param table non-null table name
 */
public Read<T> withTable(String table) {
  checkArgument(table != null, "table can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withTable(ValueProvider.StaticValueProvider.of(table));
}

/**
 * Specify the Cassandra table where to read data.
 *
 * <p>Ignored when a full query is supplied via {@code withQuery}, which takes precedence when
 * the read query is built.
 */
public Read<T> withTable(ValueProvider<String> table) {
return builder().setTable(table).build();
}

/**
 * Specify the query to read data.
 *
 * <p>When set, this query takes precedence over the keyspace/table/where configuration when
 * the read query is built.
 *
 * @param query non-null, non-empty CQL query
 */
public Read<T> withQuery(String query) {
  // Message now matches both conditions checked (non-null AND non-empty).
  checkArgument(query != null && query.length() > 0, "query cannot be null or empty");
  return withQuery(ValueProvider.StaticValueProvider.of(query));
}

/**
 * Specify the query to read data.
 *
 * <p>When set, this query takes precedence over the keyspace/table/where configuration when
 * the read query is built.
 */
public Read<T> withQuery(ValueProvider<String> query) {
return builder().setQuery(query).build();
}

/**
* Specify the entity class (annotated POJO). The {@link CassandraIO} will read the data and
* convert the data as entity instances. The {@link PCollection} resulting from the read will
Expand All @@ -236,7 +250,7 @@ public Read<T> withCoder(Coder<T> coder) {
/**
 * Specify the username for authentication.
 *
 * @param username non-null username
 */
public Read<T> withUsername(String username) {
  checkArgument(username != null, "username can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withUsername(ValueProvider.StaticValueProvider.of(username));
}

/** Specify the username for authentication. */
Expand All @@ -247,7 +261,7 @@ public Read<T> withUsername(ValueProvider<String> username) {
/**
 * Specify the password for authentication.
 *
 * @param password non-null clear-text password
 */
public Read<T> withPassword(String password) {
  checkArgument(password != null, "password can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withPassword(ValueProvider.StaticValueProvider.of(password));
}

/** Specify the clear password for authentication. */
Expand All @@ -258,7 +272,7 @@ public Read<T> withPassword(ValueProvider<String> password) {
/**
 * Specify the local DC used for the load balancing.
 *
 * @param localDc non-null local data-center name
 */
public Read<T> withLocalDc(String localDc) {
  checkArgument(localDc != null, "localDc can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withLocalDc(ValueProvider.StaticValueProvider.of(localDc));
}

/** Specify the local DC used for the load balancing. */
Expand All @@ -268,9 +282,7 @@ public Read<T> withLocalDc(ValueProvider<String> localDc) {

/**
 * Specify the consistency level used for the Cassandra reads.
 *
 * @param consistencyLevel non-null consistency level name
 */
public Read<T> withConsistencyLevel(String consistencyLevel) {
  checkArgument(consistencyLevel != null, "consistencyLevel can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withConsistencyLevel(ValueProvider.StaticValueProvider.of(consistencyLevel));
}

public Read<T> withConsistencyLevel(ValueProvider<String> consistencyLevel) {
Expand All @@ -292,7 +304,7 @@ public Read<T> withConsistencyLevel(ValueProvider<String> consistencyLevel) {
*/
/**
 * Specify a WHERE filter appended (parenthesized) to the generated read query.
 *
 * @param where non-null CQL filter expression, without the {@code WHERE} keyword
 */
public Read<T> withWhere(String where) {
  checkArgument(where != null, "where can not be null");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withWhere(ValueProvider.StaticValueProvider.of(where));
}

/**
Expand Down Expand Up @@ -320,9 +332,7 @@ public Read<T> withWhere(ValueProvider<String> where) {
/**
 * Specify the minimum number of splits used to parallelize the read.
 *
 * @param minNumberOfSplits non-null, strictly positive split count
 */
public Read<T> withMinNumberOfSplits(Integer minNumberOfSplits) {
  checkArgument(minNumberOfSplits != null, "minNumberOfSplits can not be null");
  checkArgument(minNumberOfSplits > 0, "minNumberOfSplits must be greater than 0");
  // Delegate to the ValueProvider overload so the builder call lives in one place.
  return withMinNumberOfSplits(ValueProvider.StaticValueProvider.of(minNumberOfSplits));
}

/**
Expand All @@ -334,6 +344,10 @@ public Read<T> withMinNumberOfSplits(ValueProvider<Integer> minNumberOfSplits) {
return builder().setMinNumberOfSplits(minNumberOfSplits).build();
}

/**
* A factory to create a specific {@link Mapper} for a given Cassandra Session. This is useful
* to provide mappers that don't rely in Cassandra annotated objects.
*/
public Read<T> withMapperFactoryFn(SerializableFunction<Session, Mapper> mapperFactory) {
checkArgument(
mapperFactory != null,
Expand All @@ -356,6 +370,8 @@ public PCollection<T> expand(PBegin input) {
abstract static class Builder<T> {
abstract Builder<T> setHosts(ValueProvider<List<String>> hosts);

abstract Builder<T> setQuery(ValueProvider<String> query);

abstract Builder<T> setPort(ValueProvider<Integer> port);

abstract Builder<T> setKeyspace(ValueProvider<String> keyspace);
Expand Down Expand Up @@ -435,18 +451,22 @@ public List<BoundedSource<T>> split(
LOG.warn(
"Only Murmur3Partitioner is supported for splitting, using an unique source for "
+ "the read");
String splitQuery =
String.format(
"SELECT * FROM %s.%s%s;",
spec.keyspace().get(),
spec.table().get(),
spec.where() != null ? "" : String.format(" WHERE %s", spec.where().get()));
return Collections.singletonList(
new CassandraIO.CassandraSource<>(spec, Collections.singletonList(splitQuery)));
new CassandraIO.CassandraSource<>(spec, Collections.singletonList(buildQuery(spec))));
}
}
}

/**
 * Builds the base CQL read query for {@code spec}: the user-supplied query when one was set,
 * otherwise {@code SELECT * FROM keyspace.table} with an optional parenthesized WHERE clause.
 */
private static String buildQuery(Read spec) {
  if (spec.query() != null) {
    // query().get() is already a String; the previous .toString() call was redundant.
    return spec.query().get();
  }
  String whereClause =
      (spec.where() == null) ? "" : " WHERE (" + spec.where().get() + ")";
  return String.format(
      "SELECT * FROM %s.%s%s", spec.keyspace().get(), spec.table().get(), whereClause);
}

/**
* Compute the number of splits based on the estimated size and the desired bundle size, and
* create several sources.
Expand Down Expand Up @@ -485,32 +505,11 @@ private List<BoundedSource<T>> splitWithTokenRanges(
// of
// the partitioner range, and the other from the start of the partitioner range to the
// end token of the split.
queries.add(
generateRangeQuery(
spec.keyspace(),
spec.table(),
spec.where(),
partitionKey,
range.getStart(),
null));
queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), null));
// Generation of the second query of the wrapping range
queries.add(
generateRangeQuery(
spec.keyspace(),
spec.table(),
spec.where(),
partitionKey,
null,
range.getEnd()));
queries.add(generateRangeQuery(spec, partitionKey, null, range.getEnd()));
} else {
queries.add(
generateRangeQuery(
spec.keyspace(),
spec.table(),
spec.where(),
partitionKey,
range.getStart(),
range.getEnd()));
queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), range.getEnd()));
}
}
sources.add(new CassandraIO.CassandraSource<>(spec, queries));
Expand All @@ -519,28 +518,22 @@ private List<BoundedSource<T>> splitWithTokenRanges(
}

private static String generateRangeQuery(
ValueProvider<String> keyspace,
ValueProvider<String> table,
ValueProvider<String> where,
String partitionKey,
BigInteger rangeStart,
BigInteger rangeEnd) {
String query =
String.format(
"SELECT * FROM %s.%s WHERE %s;",
keyspace.get(),
table.get(),
Joiner.on(" AND ")
.skipNulls()
.join(
where == null ? null : String.format("(%s)", where.get()),
rangeStart == null
? null
: String.format("(token(%s)>=%d)", partitionKey, rangeStart),
rangeEnd == null
? null
: String.format("(token(%s)<%d)", partitionKey, rangeEnd)));
LOG.debug("Cassandra generated read query : {}", query);
Read spec, String partitionKey, BigInteger rangeStart, BigInteger rangeEnd) {
final String rangeFilter =
Joiner.on(" AND ")
.skipNulls()
.join(
rangeStart == null
? null
: String.format("(token(%s) >= %d)", partitionKey, rangeStart),
rangeEnd == null
? null
: String.format("(token(%s) < %d)", partitionKey, rangeEnd));
final String query =
(spec.query() == null && spec.where() == null)
? buildQuery(spec) + " WHERE " + rangeFilter
: buildQuery(spec) + " AND " + rangeFilter;
LOG.debug("CassandraIO generated query : {}", query);
return query;
}

Expand Down Expand Up @@ -1132,7 +1125,6 @@ private static class Mutator<T> {
private List<Future<Void>> mutateFutures;
private final BiFunction<Mapper<T>, T, Future<Void>> mutator;
private final String operationName;
private final Class<T> entityClass;

Mutator(Write<T> spec, BiFunction<Mapper<T>, T, Future<Void>> mutator, String operationName) {
this.cluster =
Expand All @@ -1144,7 +1136,6 @@ private static class Mutator<T> {
spec.localDc(),
spec.consistencyLevel());
this.session = cluster.connect(spec.keyspace());
this.entityClass = spec.entity();
this.mapperFactoryFn = spec.mapperFactoryFn();
this.mutateFutures = new ArrayList<>();
this.mutator = mutator;
Expand Down

0 comments on commit c984f7e

Please sign in to comment.