Skip to content

Commit

Permalink
[FLINK-3338] [kafka] Use proper classloader when cloning the deserial…
Browse files Browse the repository at this point in the history
…ization schema.

This closes #1590
  • Loading branch information
StephanEwen committed Feb 5, 2016
1 parent fe0c3b5 commit 2eb2a0e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 12 deletions.
Expand Up @@ -310,8 +310,33 @@ public static byte[] serializeObject(Object o) throws IOException {
* @throws ClassNotFoundException
*/
public static <T extends Serializable> T clone(T obj) throws IOException, ClassNotFoundException {
final byte[] serializedObject = serializeObject(obj);
return deserializeObject(serializedObject, obj.getClass().getClassLoader());
if (obj == null) {
return null;
} else {
return clone(obj, obj.getClass().getClassLoader());
}
}

/**
* Clones the given serializable object using Java serialization, using the given classloader to
* resolve the cloned classes.
*
* @param obj Object to clone
* @param classLoader The classloader to resolve the classes during deserialization.
* @param <T> Type of the object to clone
*
* @return Cloned object
*
* @throws IOException
* @throws ClassNotFoundException
*/
public static <T extends Serializable> T clone(T obj, ClassLoader classLoader) throws IOException, ClassNotFoundException {
if (obj == null) {
return null;
} else {
final byte[] serializedObject = serializeObject(obj);
return deserializeObject(serializedObject, classLoader);
}
}

// --------------------------------------------------------------------------------------------
Expand Down
Expand Up @@ -251,7 +251,8 @@ public void open(Configuration parameters) throws Exception {
}

// create fetcher
fetcher = new LegacyFetcher(this.subscribedPartitions, props, getRuntimeContext().getTaskName());
fetcher = new LegacyFetcher(this.subscribedPartitions, props,
getRuntimeContext().getTaskName(), getRuntimeContext().getUserCodeClassLoader());

// offset handling
offsetHandler = new ZookeeperOffsetHandler(props);
Expand Down
Expand Up @@ -31,11 +31,12 @@
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer08;
import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;

import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.StringUtils;

import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.Node;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -49,7 +50,7 @@
import java.util.Properties;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.Objects.requireNonNull;

/**
* This fetcher uses Kafka's low-level API to pull data from a specific
Expand All @@ -70,6 +71,9 @@ public class LegacyFetcher implements Fetcher {

/** The first error that occurred in a connection thread */
private final AtomicReference<Throwable> error;

/** The classloader for dynamically loaded classes */
private final ClassLoader userCodeClassloader;

/** The partitions that the fetcher should read, with their starting offsets */
private Map<KafkaTopicPartitionLeader, Long> partitionsToRead;
Expand All @@ -86,8 +90,13 @@ public class LegacyFetcher implements Fetcher {
/** Flag to shot the fetcher down */
private volatile boolean running = true;

public LegacyFetcher(List<KafkaTopicPartitionLeader> partitions, Properties props, String taskName) {
this.config = checkNotNull(props, "The config properties cannot be null");
public LegacyFetcher(
List<KafkaTopicPartitionLeader> partitions, Properties props,
String taskName, ClassLoader userCodeClassloader) {

this.config = requireNonNull(props, "The config properties cannot be null");
this.userCodeClassloader = requireNonNull(userCodeClassloader);

//this.topic = checkNotNull(topic, "The topic cannot be null");
this.partitionsToRead = new HashMap<>();
for (KafkaTopicPartitionLeader p: partitions) {
Expand Down Expand Up @@ -200,7 +209,8 @@ public <T> void run(SourceFunction.SourceContext<T> sourceContext,

FetchPartition[] partitions = partitionsList.toArray(new FetchPartition[partitionsList.size()]);

final KeyedDeserializationSchema<T> clonedDeserializer = InstantiationUtil.clone(deserializer);
final KeyedDeserializationSchema<T> clonedDeserializer =
InstantiationUtil.clone(deserializer, userCodeClassloader);

SimpleConsumerThread<T> thread = new SimpleConsumerThread<>(this, config,
broker, partitions, sourceContext, clonedDeserializer, lastOffsets);
Expand Down Expand Up @@ -344,9 +354,9 @@ public SimpleConsumerThread(LegacyFetcher owner,
this.config = config;
this.broker = broker;
this.partitions = partitions;
this.sourceContext = checkNotNull(sourceContext);
this.deserializer = checkNotNull(deserializer);
this.offsetsState = checkNotNull(offsetsState);
this.sourceContext = requireNonNull(sourceContext);
this.deserializer = requireNonNull(deserializer);
this.offsetsState = requireNonNull(offsetsState);
}

@Override
Expand Down
Expand Up @@ -98,7 +98,7 @@ public ExecutionConfig getExecutionConfig() {

@Override
public ClassLoader getUserCodeClassLoader() {
throw new UnsupportedOperationException();
return getClass().getClassLoader();
}

@Override
Expand Down

0 comments on commit 2eb2a0e

Please sign in to comment.