Add Spark support for this parameter
liugddx committed Oct 17, 2022
1 parent 2b2dc2b commit 092adb1
Showing 4 changed files with 16 additions and 8 deletions.
@@ -31,8 +31,6 @@ public final class Constants {
 
     public static final String SOURCE_SERIALIZATION = "source.serialization";
 
-    public static final String SOURCE_PARALLELISM = "parallelism";
-
     public static final String HDFS_ROOT = "hdfs.root";
 
     public static final String HDFS_USER = "hdfs.user";
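
The key removed here does not disappear: the three files below switch to CollectionConstants.PARALLELISM instead. That class is not touched by this commit, so its contents are not shown; based on how it is used, the shared constant is assumed to look roughly like this:

// Assumed shape of the constant the remaining files reference; the real
// CollectionConstants class lives outside this diff.
public final class CollectionConstants {
    public static final String PARALLELISM = "parallelism";
}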

@@ -20,6 +20,7 @@
 import org.apache.seatunnel.api.common.JobContext;
 import org.apache.seatunnel.api.sink.SeaTunnelSink;
 import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.common.constants.CollectionConstants;
 import org.apache.seatunnel.core.starter.exception.TaskExecuteException;
 import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelSinkPluginDiscovery;

@@ -71,6 +72,13 @@ public List<Dataset<Row>> execute(List<Dataset<Row>> upstreamDataStreams) throws TaskExecuteException {
             Config sinkConfig = pluginConfigs.get(i);
             SeaTunnelSink<?, ?, ?, ?> seaTunnelSink = plugins.get(i);
             Dataset<Row> dataset = fromSourceTable(sinkConfig, sparkEnvironment).orElse(input);
+            int parallelism;
+            if (sinkConfig.hasPath(CollectionConstants.PARALLELISM)) {
+                parallelism = sinkConfig.getInt(CollectionConstants.PARALLELISM);
+            } else {
+                parallelism = sparkEnvironment.getSparkConf().getInt(CollectionConstants.PARALLELISM, 1);
+            }
+            dataset.sparkSession().read().option(CollectionConstants.PARALLELISM, parallelism);
             // TODO modify checkpoint location
             seaTunnelSink.setTypeInfo((SeaTunnelRowType) TypeConverterUtils.convert(dataset.schema()));
             SparkSinkInjector.inject(dataset.write(), seaTunnelSink).option("checkpointLocation", "/tmp").save();
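
The logic added above resolves the sink's parallelism by checking the plugin config first and falling back to the Spark configuration with a default of 1, mirroring what the source side already did. A minimal, self-contained sketch of that fallback order, with a plain Map standing in for the Typesafe Config object and an illustrative resolveParallelism helper (neither is SeaTunnel code):

import org.apache.spark.SparkConf;

import java.util.Map;

public final class ParallelismFallbackSketch {

    private static final String PARALLELISM = "parallelism";

    // Plugin-level setting wins; otherwise fall back to the engine-level SparkConf; otherwise 1.
    static int resolveParallelism(Map<String, String> pluginConfig, SparkConf sparkConf) {
        if (pluginConfig.containsKey(PARALLELISM)) {
            return Integer.parseInt(pluginConfig.get(PARALLELISM));
        }
        // SparkConf#getInt returns the supplied default when the key is absent.
        return sparkConf.getInt(PARALLELISM, 1);
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().set(PARALLELISM, "4");
        System.out.println(resolveParallelism(Map.of(), conf));                  // 4, from SparkConf
        System.out.println(resolveParallelism(Map.of(PARALLELISM, "8"), conf));  // 8, plugin config wins
    }
}

Letting the per-plugin setting override the engine-wide value means a single sink can deviate from the job default without changing the Spark submit configuration.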

@@ -20,6 +20,7 @@
 import org.apache.seatunnel.api.common.JobContext;
 import org.apache.seatunnel.api.source.SeaTunnelSource;
 import org.apache.seatunnel.common.Constants;
+import org.apache.seatunnel.common.constants.CollectionConstants;
 import org.apache.seatunnel.common.utils.SerializationUtils;
 import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelSourcePluginDiscovery;

@@ -56,15 +57,15 @@ public List<Dataset<Row>> execute(List<Dataset<Row>> upstreamDataStreams) {
             SeaTunnelSource<?, ?, ?> source = plugins.get(i);
             Config pluginConfig = pluginConfigs.get(i);
             int parallelism;
-            if (pluginConfig.hasPath(Constants.SOURCE_PARALLELISM)) {
-                parallelism = pluginConfig.getInt(Constants.SOURCE_PARALLELISM);
+            if (pluginConfig.hasPath(CollectionConstants.PARALLELISM)) {
+                parallelism = pluginConfig.getInt(CollectionConstants.PARALLELISM);
             } else {
-                parallelism = sparkEnvironment.getSparkConf().getInt(Constants.SOURCE_PARALLELISM, 1);
+                parallelism = sparkEnvironment.getSparkConf().getInt(CollectionConstants.PARALLELISM, 1);
             }
             Dataset<Row> dataset = sparkEnvironment.getSparkSession()
                 .read()
                 .format(SeaTunnelSource.class.getSimpleName())
-                .option(Constants.SOURCE_PARALLELISM, parallelism)
+                .option(CollectionConstants.PARALLELISM, parallelism)
                 .option(Constants.SOURCE_SERIALIZATION, SerializationUtils.objectToString(source))
                 .schema((StructType) TypeConverterUtils.convert(source.getProducedType())).load();
             sources.add(dataset);

@@ -20,6 +20,7 @@
 import org.apache.seatunnel.api.source.SeaTunnelSource;
 import org.apache.seatunnel.api.table.type.SeaTunnelRow;
 import org.apache.seatunnel.common.Constants;
+import org.apache.seatunnel.common.constants.CollectionConstants;
 import org.apache.seatunnel.common.utils.SerializationUtils;
 import org.apache.seatunnel.translation.spark.source.batch.BatchSourceReader;
 import org.apache.seatunnel.translation.spark.source.micro.MicroBatchSourceReader;

@@ -59,14 +60,14 @@ public DataSourceReader createReader(StructType rowType, DataSourceOptions options) {
     @Override
     public DataSourceReader createReader(DataSourceOptions options) {
         SeaTunnelSource<SeaTunnelRow, ?, ?> seaTunnelSource = getSeaTunnelSource(options);
-        int parallelism = options.getInt(Constants.SOURCE_PARALLELISM, 1);
+        int parallelism = options.getInt(CollectionConstants.PARALLELISM, 1);
         return new BatchSourceReader(seaTunnelSource, parallelism);
     }
 
     @Override
     public MicroBatchReader createMicroBatchReader(Optional<StructType> rowTypeOptional, String checkpointLocation, DataSourceOptions options) {
         SeaTunnelSource<SeaTunnelRow, ?, ?> seaTunnelSource = getSeaTunnelSource(options);
-        Integer parallelism = options.getInt(Constants.SOURCE_PARALLELISM, 1);
+        Integer parallelism = options.getInt(CollectionConstants.PARALLELISM, 1);
         Integer checkpointInterval = options.getInt(Constants.CHECKPOINT_INTERVAL, CHECKPOINT_INTERVAL_DEFAULT);
         String checkpointPath = StringUtils.replacePattern(checkpointLocation, "sources/\\d+", "sources-state");
         Configuration configuration = SparkSession.getActiveSession().get().sparkContext().hadoopConfiguration();
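
On the reader side, the value set with .option(CollectionConstants.PARALLELISM, parallelism) reaches this class through Spark's DataSource V2 DataSourceOptions, and getInt(key, 1) falls back to 1 when the option was never supplied. A small standalone illustration of that lookup against the Spark 2.4 sources.v2 API (the demo class itself is not part of the commit):

import org.apache.spark.sql.sources.v2.DataSourceOptions;

import java.util.HashMap;
import java.util.Map;

public class ParallelismOptionDemo {
    public static void main(String[] args) {
        // DataFrameReader#option(key, value) ends up in a case-insensitive map like this one.
        Map<String, String> raw = new HashMap<>();
        raw.put("parallelism", "4");

        DataSourceOptions options = new DataSourceOptions(raw);
        System.out.println(options.getInt("parallelism", 1)); // 4

        // When the option is absent, the supplied default carries the value, matching the 1 used above.
        System.out.println(new DataSourceOptions(new HashMap<>()).getInt("parallelism", 1)); // 1
    }
}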
