Skip to content
Permalink
Browse files
Serializing / deserialing task params using bean methods. Supporting …
…thread local parameters to avoid thread contention in parallel job execution
  • Loading branch information
DImuthuUpe committed May 21, 2021
1 parent a49c3d4 commit 97198401dac3149e69f7e875cbc0b3ad3b8c1e0b
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 39 deletions.
@@ -41,7 +41,7 @@
<dependency>
<groupId>org.apache.helix</groupId>
<artifactId>helix-core</artifactId>
<version>1.0.1</version>
<version>0.9.7</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
@@ -68,5 +68,10 @@
<artifactId>snakeyaml</artifactId>
<version>${yaml.version}</version>
</dependency>
<dependency>
<groupId>commons-beanutils</groupId>
<artifactId>commons-beanutils</artifactId>
<version>${commons.beanutils.version}</version>
</dependency>
</dependencies>
</project>
@@ -87,6 +87,7 @@ public void run(String... args) throws Exception {
InstanceConfig instanceConfig = new InstanceConfig(participantName);
instanceConfig.setHostName("localhost");
instanceConfig.setInstanceEnabled(true);
instanceConfig.setMaxConcurrentTask(30);
zkHelixAdmin.addInstance(clusterName, instanceConfig);
logger.info("Participant: " + participantName + " has been added to cluster: " + clusterName);

@@ -57,13 +57,24 @@ public void run(String... args) throws Exception {
ExampleBlockingTask bt2 = new ExampleBlockingTask();
bt2.setTaskId("bt2-" + UUID.randomUUID());

ExampleBlockingTask bt3 = new ExampleBlockingTask();
bt3.setTaskId("bt3-" + UUID.randomUUID());

ExampleBlockingTask bt4 = new ExampleBlockingTask();
bt4.setTaskId("bt4-" + UUID.randomUUID());

// Setting dependency
bt1.setOutPort(new OutPort().setNextTaskId(bt2.getTaskId()));
bt1.setOutPort(new OutPort().setNextTaskId(bt3.getTaskId()));
bt2.setOutPort(new OutPort().setNextTaskId(bt3.getTaskId()));
bt4.setOutPort(new OutPort().setNextTaskId(bt3.getTaskId()));

Map<String, AbstractTask> taskMap = new HashMap<>();
taskMap.put(bt1.getTaskId(), bt1);
taskMap.put(bt2.getTaskId(), bt2);
String workflowId = workflowOperator.buildAndRunWorkflow(taskMap, bt1.getTaskId());
taskMap.put(bt3.getTaskId(), bt3);
taskMap.put(bt4.getTaskId(), bt4);
String[] startTaskIds = {bt1.getTaskId(), bt2.getTaskId(), bt4.getTaskId()};
String workflowId = workflowOperator.buildAndRunWorkflow(taskMap, startTaskIds);
logger.info("Launched workflow {}", workflowId);
}

@@ -23,18 +23,25 @@
import org.apache.airavata.datalake.orchestrator.workflow.engine.task.annotation.BlockingTaskDef;
import org.apache.airavata.datalake.orchestrator.workflow.engine.task.annotation.TaskOutPort;
import org.apache.airavata.datalake.orchestrator.workflow.engine.task.annotation.TaskParam;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.helix.HelixManager;
import org.apache.helix.HelixManagerFactory;
import org.apache.helix.InstanceType;
import org.apache.helix.task.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.util.*;

public class WorkflowOperator {

private static final long WORKFLOW_EXPIRY_TIME = 1 * 1000;
private static final long TASK_EXPIRY_TIME = 24 * 60 * 60 * 1000;
private static final int PARALLEL_JOBS_PER_WORKFLOW = 20;

private final static Logger logger = LoggerFactory.getLogger(WorkflowOperator.class);

private TaskDriver taskDriver;
private HelixManager helixManager;
@@ -64,17 +71,23 @@ public void destroy() {
}
}

public String buildAndRunWorkflow(Map<String, AbstractTask> taskMap, String startTaskId) throws Exception {
public String buildAndRunWorkflow(Map<String, AbstractTask> taskMap, String[] startTaskIds) throws Exception {

if (taskDriver == null) {
throw new Exception("Workflow operator needs to be initialized");
}

String workflowName = UUID.randomUUID().toString();
Workflow.Builder workflowBuilder = new Workflow.Builder(workflowName).setExpiry(0);
buildWorkflowRecursively(workflowBuilder, startTaskId, taskMap);

WorkflowConfig.Builder config = new WorkflowConfig.Builder().setFailureThreshold(0);
for (String startTaskId: startTaskIds) {
buildWorkflowRecursively(workflowBuilder, startTaskId, taskMap);
}

WorkflowConfig.Builder config = new WorkflowConfig.Builder()
.setFailureThreshold(0)
.setAllowOverlapJobAssignment(true);

workflowBuilder.setWorkflowConfig(config.build());
workflowBuilder.setExpiry(WORKFLOW_EXPIRY_TIME);
Workflow workflow = workflowBuilder.build();
@@ -112,6 +125,7 @@ private void buildWorkflowRecursively(Workflow.Builder workflowBuilder, String n
for (OutPort outPort : outPorts) {
if (outPort != null) {
workflowBuilder.addParentChildDependency(currentTask.getTaskId(), outPort.getNextTaskId());
logger.info("Parent to child dependency {} -> {}", currentTask.getTaskId(), outPort.getNextTaskId());
buildWorkflowRecursively(workflowBuilder, outPort.getNextTaskId(), taskMap);
}
}
@@ -135,19 +149,19 @@ public void deleteWorkflow(String workflowName) {
taskDriver.delete(workflowName);
}

private <T extends AbstractTask> Map<String, String> serializeTaskData(T data) throws IllegalAccessException {
private <T extends AbstractTask> Map<String, String> serializeTaskData(T data) throws IllegalAccessException, InvocationTargetException, NoSuchMethodException {

Map<String, String> result = new HashMap<>();
for (Class<?> c = data.getClass(); c != null; c = c.getSuperclass()) {
Field[] fields = c.getDeclaredFields();
for (Field classField : fields) {
TaskParam parm = classField.getAnnotation(TaskParam.class);
if (parm != null) {
classField.setAccessible(true);
if (classField.get(data) instanceof TaskParamType) {
result.put(parm.name(), TaskParamType.class.cast(classField.get(data)).serialize());
Object propertyValue = PropertyUtils.getProperty(data, parm.name());
if (propertyValue instanceof TaskParamType) {
result.put(parm.name(), TaskParamType.class.cast(propertyValue).serialize());
} else {
result.put(parm.name(), classField.get(data).toString());
result.put(parm.name(), propertyValue.toString());
}
}

@@ -19,34 +19,40 @@

import org.apache.airavata.datalake.orchestrator.workflow.engine.task.annotation.TaskOutPort;
import org.apache.airavata.datalake.orchestrator.workflow.engine.task.annotation.TaskParam;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.helix.task.Task;
import org.apache.helix.task.TaskCallbackContext;
import org.apache.helix.task.TaskResult;
import org.apache.helix.task.UserContentStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.beans.PropertyDescriptor;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

public abstract class AbstractTask extends UserContentStore implements Task {

private final static Logger logger = LoggerFactory.getLogger(AbstractTask.class);

private TaskCallbackContext callbackContext;
private ThreadLocal<TaskCallbackContext> callbackContext = new ThreadLocal<>();
private BlockingQueue<TaskCallbackContext> callbackContextQueue = new LinkedBlockingQueue<>();

@TaskOutPort(name = "nextTask")
private OutPort outPort;

@TaskParam(name = "taskId")
private String taskId;
private ThreadLocal<String> taskId = new ThreadLocal<>();

@TaskParam(name = "retryCount")
private int retryCount = 3;
private ThreadLocal<Integer> retryCount = ThreadLocal.withInitial(()-> 3);

public AbstractTask() {

@@ -55,9 +61,17 @@ public AbstractTask() {
@Override
public TaskResult run() {
try {
String helixTaskId = this.callbackContext.getTaskConfig().getId();
TaskCallbackContext cbc = callbackContextQueue.poll();

if (cbc == null) {
logger.error("No callback context available");
throw new Exception("No callback context available");
}

this.callbackContext.set(cbc);
String helixTaskId = getCallbackContext().getTaskConfig().getId();
logger.info("Running task {}", helixTaskId);
deserializeTaskData(this, this.callbackContext.getTaskConfig().getConfigMap());
deserializeTaskData(this, getCallbackContext().getTaskConfig().getConfigMap());
} catch (Exception e) {
logger.error("Failed at deserializing task data", e);
return new TaskResult(TaskResult.Status.FAILED, "Failed in deserializing task data");
@@ -83,27 +97,32 @@ public void setOutPort(OutPort outPort) {
}

public int getRetryCount() {
return retryCount;
return retryCount.get();
}

public void setRetryCount(int retryCount) {
this.retryCount = retryCount;
this.retryCount.set(retryCount);
}

public TaskCallbackContext getCallbackContext() {
return callbackContext;
return callbackContext.get();
}

public String getTaskId() {
return taskId;
return taskId.get();
}

public void setTaskId(String taskId) {
this.taskId = taskId;
this.taskId.set(taskId);
}

public void setCallbackContext(TaskCallbackContext callbackContext) {
this.callbackContext = callbackContext;
logger.info("Setting callback context {}", callbackContext.getJobConfig().getId());
try {
this.callbackContextQueue.put(callbackContext);
} catch (InterruptedException e) {
logger.error("Failed to put callback context to the queue", e);
}
}

private <T extends AbstractTask> void deserializeTaskData(T instance, Map<String, String> params) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, InstantiationException {
@@ -124,23 +143,27 @@ private <T extends AbstractTask> void deserializeTaskData(T instance, Map<String
if (param != null) {
if (params.containsKey(param.name())) {
classField.setAccessible(true);
if (classField.getType().isAssignableFrom(String.class)) {
classField.set(instance, params.get(param.name()));
} else if (classField.getType().isAssignableFrom(Integer.class) ||
classField.getType().isAssignableFrom(Integer.TYPE)) {
classField.set(instance, Integer.parseInt(params.get(param.name())));
} else if (classField.getType().isAssignableFrom(Long.class) ||
classField.getType().isAssignableFrom(Long.TYPE)) {
classField.set(instance, Long.parseLong(params.get(param.name())));
} else if (classField.getType().isAssignableFrom(Boolean.class) ||
classField.getType().isAssignableFrom(Boolean.TYPE)) {
classField.set(instance, Boolean.parseBoolean(params.get(param.name())));
} else if (TaskParamType.class.isAssignableFrom(classField.getType())) {
Class<?> clazz = classField.getType();
Constructor<?> ctor = clazz.getConstructor();
PropertyDescriptor propertyDescriptor = PropertyUtils.getPropertyDescriptor(this, classField.getName());
Method writeMethod = PropertyUtils.getWriteMethod(propertyDescriptor);
Class<?>[] methodParamType = writeMethod.getParameterTypes();
Class<?> writeParameterType = methodParamType[0];

if (writeParameterType.isAssignableFrom(String.class)) {
writeMethod.invoke(instance, params.get(param.name()));
} else if (writeParameterType.isAssignableFrom(Integer.class) ||
writeParameterType.isAssignableFrom(Integer.TYPE)) {
writeMethod.invoke(instance, Integer.parseInt(params.get(param.name())));
} else if (writeParameterType.isAssignableFrom(Long.class) ||
writeParameterType.isAssignableFrom(Long.TYPE)) {
writeMethod.invoke(instance, Long.parseLong(params.get(param.name())));
} else if (writeParameterType.isAssignableFrom(Boolean.class) ||
writeParameterType.isAssignableFrom(Boolean.TYPE)) {
writeMethod.invoke(instance, Boolean.parseBoolean(params.get(param.name())));
} else if (TaskParamType.class.isAssignableFrom(writeParameterType)) {
Constructor<?> ctor = writeParameterType.getConstructor();
Object obj = ctor.newInstance();
((TaskParamType)obj).deserialize(params.get(param.name()));
classField.set(instance, obj);
writeMethod.invoke(instance, obj);
}
}
}
@@ -30,7 +30,22 @@ public class ExampleBlockingTask extends BlockingTask {

@Override
public TaskResult runBlockingCode() {
logger.info("Running example blocking task {}", getTaskId());
logger.info("Starting task {}", getTaskId());
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
if (getTaskId().startsWith("bt1")) {
try {
logger.info("Task {} is sleeping", getTaskId());
Thread.sleep(10000);
//return new TaskResult(TaskResult.Status.FAILED, "Fail");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
logger.info("Ending task {}", getTaskId());
return new TaskResult(TaskResult.Status.COMPLETED, "Success");
}
}
@@ -149,7 +149,7 @@
<spring-security.version>5.3.4.RELEASE</spring-security.version>
<yaml.version>1.15</yaml.version>
<spring.boot.version>2.2.1.RELEASE</spring.boot.version>

<commons.beanutils.version>1.9.4</commons.beanutils.version>
</properties>

</project>

0 comments on commit 9719840

Please sign in to comment.