forked from cylondata/twister2
/
BTPartitionExample.java
104 lines (93 loc) · 4.27 KB
/
BTPartitionExample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package edu.iu.dsc.tws.examples.task.batch;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageType;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageTypes;
import edu.iu.dsc.tws.api.compute.TaskContext;
import edu.iu.dsc.tws.api.compute.nodes.BaseSource;
import edu.iu.dsc.tws.api.compute.nodes.ICompute;
import edu.iu.dsc.tws.api.compute.schedule.elements.TaskInstancePlan;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.examples.task.BenchTaskWorker;
import edu.iu.dsc.tws.examples.utils.bench.BenchmarkConstants;
import edu.iu.dsc.tws.examples.utils.bench.BenchmarkUtils;
import edu.iu.dsc.tws.examples.utils.bench.Timing;
import edu.iu.dsc.tws.examples.verification.ResultsVerifier;
import edu.iu.dsc.tws.examples.verification.comparators.IntArrayComparator;
import edu.iu.dsc.tws.examples.verification.comparators.IteratorComparator;
import edu.iu.dsc.tws.task.impl.ComputeGraphBuilder;
import edu.iu.dsc.tws.task.typed.batch.BPartitionCompute;
/**
 * Batch benchmark that connects a source task to a sink task through a partition edge.
 *
 * <p>Parallelism for the two stages comes from the job's task stages: element 0 is used for
 * the source and element 1 for the sink. The partition edge is named {@code "edge"} and
 * carries {@link MessageTypes#INTEGER_ARRAY} data.
 */
public class BTPartitionExample extends BenchTaskWorker {

  private static final Logger LOG = Logger.getLogger(BTPartitionExample.class.getName());

  /**
   * Registers one source and one partition sink and links them with a partition edge.
   *
   * @return the shared graph builder, populated with the source, the sink and the edge
   */
  @Override
  public ComputeGraphBuilder buildTaskGraph() {
    final List<Integer> stages = jobParameters.getTaskStages();
    final int sourceInstances = stages.get(0);
    final int sinkInstances = stages.get(1);
    final MessageType payloadType = MessageTypes.INTEGER_ARRAY;
    final String edgeName = "edge";

    final BaseSource generator = new SourceTask(edgeName);
    final ICompute receiver = new PartitionSinkTask();

    computeGraphBuilder.addSource(SOURCE, generator, sourceInstances);
    // The connection is stored in the inherited field, not in a local.
    computeConnection = computeGraphBuilder.addCompute(SINK, receiver, sinkInstances);
    computeConnection
        .partition(SOURCE)
        .viaEdge(edgeName)
        .withDataType(payloadType);

    return computeGraphBuilder;
  }

  /**
   * Sink side of the partition benchmark.
   *
   * <p>Each received batch is timed and recorded. When a verifier was set up in
   * {@link #prepare(Config, TaskContext)}, the batch is also checked against the expected data.
   */
  @SuppressWarnings({"rawtypes", "unchecked"})
  protected static class PartitionSinkTask extends BPartitionCompute<int[]> {
    private static final long serialVersionUID = -254264903510284798L;

    // Null when verification is disabled for the current configuration.
    private ResultsVerifier<int[], Iterator<int[]>> resultsVerifier;
    // Accumulated verification outcome, passed back into verifyResults on each call.
    private boolean verified = true;
    // Whether this instance takes part in timing, as reported by getTimingCondition.
    private boolean timingCondition;

    /**
     * Initialises timing and, when possible, the results verifier.
     *
     * <p>The verifier expects a number of copies of the input array equal to the count of
     * source tasks whose index modulo the sink count equals this sink's index, multiplied by
     * the total iterations. Verification is skipped, with a warning, when the total iterations
     * is not divisible by the number of sinks.
     *
     * @param cfg job configuration
     * @param ctx task context of this sink instance
     */
    @Override
    public void prepare(Config cfg, TaskContext ctx) {
      super.prepare(cfg, ctx);
      this.timingCondition = getTimingCondition(SINK, context);

      final int sinkCount = ctx.getTasksByName(SINK).size();
      final long sourcesForThisSink = ctx.getTasksByName(SOURCE)
          .stream()
          .map(TaskInstancePlan::getTaskIndex)
          .filter(index -> index % sinkCount == ctx.taskIndex())
          .count();

      if (jobParameters.getTotalIterations() % sinkCount == 0) {
        resultsVerifier = new ResultsVerifier<>(inputDataArray, (sample, verifierArgs) -> {
          final long expectedCount = sourcesForThisSink * jobParameters.getTotalIterations();
          final List<int[]> expected = new ArrayList<>();
          for (long copy = 0; copy < expectedCount; copy++) {
            expected.add(sample);
          }
          return expected.iterator();
        }, new IteratorComparator<>(IntArrayComparator.getInstance()));
      } else {
        LOG.warning("Total iterations is not divisible by total sinks. "
            + "Verification won't run for this configuration.");
      }
    }

    /**
     * Handles the partitioned data delivered to this sink.
     *
     * <p>Marks the receive timing, records and writes the total time, and runs verification
     * when a verifier is configured.
     *
     * @param content iterator over the arrays received by this sink
     * @return always {@code true}
     */
    @Override
    public boolean partition(Iterator<int[]> content) {
      Timing.mark(BenchmarkConstants.TIMING_ALL_RECV, this.timingCondition);
      LOG.info(String.format("%d received partition %d",
          context.getWorkerId(), context.globalTaskId()));
      BenchmarkUtils.markTotalTime(resultsRecorder, this.timingCondition);
      resultsRecorder.writeToCSV();

      if (resultsVerifier == null) {
        return true;
      }
      this.verified = verifyResults(resultsVerifier, content, null, verified);
      return true;
    }
  }
}