Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Session;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
import org.apache.beam.sdk.options.ValueProvider;

@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class ConnectionManager {

private static final ConcurrentHashMap<String, Cluster> clusterMap =
new ConcurrentHashMap<String, Cluster>();
private static final ConcurrentHashMap<String, Session> sessionMap =
new ConcurrentHashMap<String, Session>();

static {
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
for (Session session : sessionMap.values()) {
if (!session.isClosed()) {
session.close();
}
}
}));
}

private static String readToClusterHash(Read<?> read) {
return Objects.requireNonNull(read.hosts()).get().stream().reduce(",", (a, b) -> a + b)
+ Objects.requireNonNull(read.port()).get()
+ safeVPGet(read.localDc())
+ safeVPGet(read.consistencyLevel());
}

private static String readToSessionHash(Read<?> read) {
return readToClusterHash(read) + read.keyspace().get();
}

static Session getSession(Read<?> read) {
Cluster cluster =
clusterMap.computeIfAbsent(
readToClusterHash(read),
k ->
CassandraIO.getCluster(
Objects.requireNonNull(read.hosts()),
Objects.requireNonNull(read.port()),
read.username(),
read.password(),
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout()));
return sessionMap.computeIfAbsent(
readToSessionHash(read),
k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
}

private static String safeVPGet(ValueProvider<String> s) {
return s != null ? s.get() : "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
})
class DefaultObjectMapper<T> implements Mapper<T>, Serializable {

private transient com.datastax.driver.mapping.Mapper<T> mapper;
private final transient com.datastax.driver.mapping.Mapper<T> mapper;

DefaultObjectMapper(com.datastax.driver.mapping.Mapper mapper) {
this.mapper = mapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class DefaultObjectMapperFactory<T> implements SerializableFunction<Session, Mapper> {

private transient MappingManager mappingManager;
Class<T> entity;
final Class<T> entity;

DefaultObjectMapperFactory(Class<T> entity) {
this.entity = entity;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.PreparedStatement;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.Token;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
class ReadFn<T> extends DoFn<Read<T>, T> {

private static final Logger LOG = LoggerFactory.getLogger(ReadFn.class);

@ProcessElement
public void processElement(@Element Read<T> read, OutputReceiver<T> receiver) {
try {
Session session = ConnectionManager.getSession(read);
Mapper<T> mapper = read.mapperFactoryFn().apply(session);
String partitionKey =
session.getCluster().getMetadata().getKeyspace(read.keyspace().get())
.getTable(read.table().get()).getPartitionKey().stream()
.map(ColumnMetadata::getName)
.collect(Collectors.joining(","));

String query = generateRangeQuery(read, partitionKey, read.ringRanges() != null);
PreparedStatement preparedStatement = session.prepare(query);
Set<RingRange> ringRanges =
read.ringRanges() == null ? Collections.emptySet() : read.ringRanges().get();

for (RingRange rr : ringRanges) {
Token startToken = session.getCluster().getMetadata().newToken(rr.getStart().toString());
Token endToken = session.getCluster().getMetadata().newToken(rr.getEnd().toString());
ResultSet rs =
session.execute(preparedStatement.bind().setToken(0, startToken).setToken(1, endToken));
Iterator<T> iter = mapper.map(rs);
while (iter.hasNext()) {
T n = iter.next();
receiver.output(n);
}
}

if (read.ringRanges() == null) {
ResultSet rs = session.execute(preparedStatement.bind());
Iterator<T> iter = mapper.map(rs);
while (iter.hasNext()) {
receiver.output(iter.next());
}
}
} catch (Exception ex) {
LOG.error("error", ex);
}
}

private Session getSession(Read<T> read) {
Cluster cluster =
CassandraIO.getCluster(
read.hosts(),
read.port(),
read.username(),
read.password(),
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout());

return cluster.connect(read.keyspace().get());
}

private static String generateRangeQuery(
Read<?> spec, String partitionKey, Boolean hasRingRange) {
final String rangeFilter =
(hasRingRange)
? Joiner.on(" AND ")
.skipNulls()
.join(
String.format("(token(%s) >= ?)", partitionKey),
String.format("(token(%s) < ?)", partitionKey))
: "";
final String combinedQuery = buildInitialQuery(spec, hasRingRange) + rangeFilter;
LOG.debug("CassandraIO generated query : {}", combinedQuery);
return combinedQuery;
}

private static String buildInitialQuery(Read<?> spec, Boolean hasRingRange) {
return (spec.query() == null)
? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
+ " WHERE "
: spec.query().get() + (hasRingRange ? " AND " : "");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,31 @@
*/
package org.apache.beam.sdk.io.cassandra;

import java.io.Serializable;
import java.math.BigInteger;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;

/** Models a Cassandra token range. */
final class RingRange {
@Experimental(Kind.SOURCE_SINK)
@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public final class RingRange implements Serializable {
private final BigInteger start;
private final BigInteger end;

RingRange(BigInteger start, BigInteger end) {
private RingRange(BigInteger start, BigInteger end) {
this.start = start;
this.end = end;
}

BigInteger getStart() {
public BigInteger getStart() {
return start;
}

BigInteger getEnd() {
public BigInteger getEnd() {
return end;
}

Expand All @@ -55,4 +63,34 @@ public boolean isWrapping() {
public String toString() {
return String.format("(%s,%s]", start.toString(), end.toString());
}

public static RingRange of(BigInteger start, BigInteger end) {
return new RingRange(start, end);
}

@Override
public boolean equals(@Nullable Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

RingRange ringRange = (RingRange) o;

if (getStart() != null
? !getStart().equals(ringRange.getStart())
: ringRange.getStart() != null) {
return false;
}
return getEnd() != null ? getEnd().equals(ringRange.getEnd()) : ringRange.getEnd() == null;
}

@Override
public int hashCode() {
int result = getStart() != null ? getStart().hashCode() : 0;
result = 31 * result + (getEnd() != null ? getEnd().hashCode() : 0);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ final class SplitGenerator {
this.partitioner = partitioner;
}

private static BigInteger getRangeMin(String partitioner) {
static BigInteger getRangeMin(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return BigInteger.ZERO;
} else if (partitioner.endsWith("Murmur3Partitioner")) {
return new BigInteger("2").pow(63).negate();
return BigInteger.valueOf(2).pow(63).negate();
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
}
}

private static BigInteger getRangeMax(String partitioner) {
static BigInteger getRangeMax(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
return BigInteger.valueOf(2).pow(127).subtract(BigInteger.ONE);
} else if (partitioner.endsWith("Murmur3Partitioner")) {
return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
return BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
Expand Down Expand Up @@ -84,7 +84,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring
BigInteger start = ringTokens.get(i);
BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);

if (!inRange(start) || !inRange(stop)) {
if (!isInRange(start) || !isInRange(stop)) {
throw new RuntimeException(
String.format("Tokens (%s,%s) not in range of %s", start, stop, partitioner));
}
Expand Down Expand Up @@ -127,7 +127,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring

// Append the splits between the endpoints
for (int j = 0; j < splitCount; j++) {
splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 1)));
splits.add(RingRange.of(endpointTokens.get(j), endpointTokens.get(j + 1)));
LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), endpointTokens.get(j + 1));
}
}
Expand All @@ -144,7 +144,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring
return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
}

private boolean inRange(BigInteger token) {
private boolean isInRange(BigInteger token) {
return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
}

Expand Down
Loading