Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

Commit

Permalink
reset workflow in primary datastore upon restart
Browse files Browse the repository at this point in the history
  • Loading branch information
apanicker-nflx committed May 17, 2019
1 parent a8faa2b commit 451eb97
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 85 deletions.
Expand Up @@ -167,10 +167,6 @@ public List<Task> createTasks(List<Task> tasks) {
@Override
public void updateTask(Task task) {
try {
task.setUpdateTime(System.currentTimeMillis());
if (task.getStatus().isTerminal() && task.getEndTime() == 0) {
task.setEndTime(System.currentTimeMillis());
}
// TODO: calculate the shard number the task belongs to
String taskPayload = toJson(task);
recordCassandraDaoRequests("updateTask", task.getTaskType(), task.getWorkflowType());
Expand Down Expand Up @@ -202,11 +198,6 @@ public boolean exceedsRateLimitPerFrequency(Task task) {
throw new UnsupportedOperationException("This method is not implemented in CassandraExecutionDAO. Please use ExecutionDAOFacade instead.");
}

@Override
public void updateTasks(List<Task> tasks) {
    // Persist each task individually; Cassandra has no multi-row update here,
    // so this simply delegates to the single-task update path.
    for (Task task : tasks) {
        updateTask(task);
    }
}

@Override
public boolean removeTask(String taskId) {
Task task = getTask(taskId);
Expand Down Expand Up @@ -273,7 +264,6 @@ public List<Task> getTasksForWorkflow(String workflowId) {
@Override
public String createWorkflow(Workflow workflow) {
try {
workflow.setCreateTime(System.currentTimeMillis());
List<Task> tasks = workflow.getTasks();
workflow.setTasks(new LinkedList<>());
String payload = toJson(workflow);
Expand All @@ -295,10 +285,6 @@ public String createWorkflow(Workflow workflow) {
@Override
public String updateWorkflow(Workflow workflow) {
try {
workflow.setUpdateTime(System.currentTimeMillis());
if (workflow.getStatus().isTerminal()) {
workflow.setEndTime(System.currentTimeMillis());
}
List<Task> tasks = workflow.getTasks();
workflow.setTasks(new LinkedList<>());
String payload = toJson(workflow);
Expand Down
Expand Up @@ -98,6 +98,7 @@ public void testWorkflowCRUD() {
workflow.setWorkflowId(workflowId);
workflow.setInput(new HashMap<>());
workflow.setStatus(Workflow.WorkflowStatus.RUNNING);
workflow.setCreateTime(System.currentTimeMillis());

// create a new workflow in the datastore
String id = executionDAO.createWorkflow(workflow);
Expand Down Expand Up @@ -130,6 +131,7 @@ public void testTasksCRUD() {
workflow.setWorkflowId(workflowId);
workflow.setInput(new HashMap<>());
workflow.setStatus(Workflow.WorkflowStatus.RUNNING);
workflow.setCreateTime(System.currentTimeMillis());

// add it to the datastore
executionDAO.createWorkflow(workflow);
Expand Down Expand Up @@ -200,18 +202,19 @@ public void testTasksCRUD() {
assertEquals(task2, found.getTaskByRefName("task2"));
assertEquals(task3, found.getTaskByRefName("task3"));

// update a task
// update tasks
task1.setStatus(Task.Status.IN_PROGRESS);
executionDAO.updateTask(task1);
task = executionDAO.getTask(task1Id);
assertEquals(task1, task);

// update multiple tasks
task2.setStatus(Task.Status.COMPLETED);
task3.setStatus(Task.Status.FAILED);
executionDAO.updateTasks(Arrays.asList(task2, task3));
executionDAO.updateTask(task2);
task = executionDAO.getTask(task2Id);
assertEquals(task2, task);

task3.setStatus(Task.Status.FAILED);
executionDAO.updateTask(task3);
task = executionDAO.getTask(task3Id);
assertEquals(task3, task);

Expand Down
Expand Up @@ -391,8 +391,8 @@ public void rewind(String workflowId, boolean useLatestDefinitions) {
throw new ApplicationException(CONFLICT, String.format("Workflow: %s is non-restartable", workflow));
}

// Remove all the tasks...
workflow.getTasks().forEach(task -> executionDAOFacade.removeTask(task.getTaskId()));
// Remove the workflow from the primary datastore (archive in indexer) and re-create it
executionDAOFacade.removeWorkflow(workflowId, true);
workflow.getTasks().clear();
workflow.setReasonForIncompletion(null);
workflow.setStartTime(System.currentTimeMillis());
Expand All @@ -401,7 +401,7 @@ public void rewind(String workflowId, boolean useLatestDefinitions) {
workflow.setStatus(WorkflowStatus.RUNNING);
workflow.setOutput(null);
workflow.setExternalOutputPayloadStoragePath(null);
executionDAOFacade.updateWorkflow(workflow);
executionDAOFacade.createWorkflow(workflow);
decide(workflowId);
}

Expand Down
Expand Up @@ -142,6 +142,7 @@ public long getPendingWorkflowCount(String workflowName) {
* @return the id of the created workflow
*/
public String createWorkflow(Workflow workflow) {
workflow.setCreateTime(System.currentTimeMillis());
executionDAO.createWorkflow(workflow);
indexDAO.indexWorkflow(workflow);
return workflow.getWorkflowId();
Expand All @@ -154,6 +155,10 @@ public String createWorkflow(Workflow workflow) {
* @return the id of the updated workflow
*/
public String updateWorkflow(Workflow workflow) {
workflow.setUpdateTime(System.currentTimeMillis());
if (workflow.getStatus().isTerminal()) {
workflow.setEndTime(System.currentTimeMillis());
}
executionDAO.updateWorkflow(workflow);
indexDAO.indexWorkflow(workflow);
return workflow.getWorkflowId();
Expand Down Expand Up @@ -231,6 +236,14 @@ public long getInProgressTaskCount(String taskDefName) {
*/
public void updateTask(Task task) {
try {
if (task.getStatus() != null) {
if (!task.getStatus().isTerminal() || (task.getStatus().isTerminal() && task.getUpdateTime() == 0)) {
task.setUpdateTime(System.currentTimeMillis());
}
if (task.getStatus().isTerminal() && task.getEndTime() == 0) {
task.setEndTime(System.currentTimeMillis());
}
}
executionDAO.updateTask(task);
indexDAO.indexTask(task);
} catch (Exception e) {
Expand Down
Expand Up @@ -85,13 +85,6 @@ public interface ExecutionDAO {
* false: If the {@link Task} is not rateLimited
*/
boolean exceedsRateLimitPerFrequency(Task task);

/**
 * Updates each of the given tasks in the underlying datastore.
 *
 * @param tasks Multiple tasks to be updated
 *
 */
void updateTasks(List<Task> tasks);

/**
*
Expand Down
Expand Up @@ -386,9 +386,9 @@ public void testRestartWorkflow() {
verify(metadataDAO, never()).getLatest(any());

ArgumentCaptor<Workflow> argumentCaptor = ArgumentCaptor.forClass(Workflow.class);
verify(executionDAOFacade, times(2)).updateWorkflow(argumentCaptor.capture());
assertEquals(workflow.getWorkflowId(), argumentCaptor.getAllValues().get(1).getWorkflowId());
assertEquals(workflow.getWorkflowDefinition(), argumentCaptor.getAllValues().get(1).getWorkflowDefinition());
verify(executionDAOFacade, times(1)).createWorkflow(argumentCaptor.capture());
assertEquals(workflow.getWorkflowId(), argumentCaptor.getAllValues().get(0).getWorkflowId());
assertEquals(workflow.getWorkflowDefinition(), argumentCaptor.getAllValues().get(0).getWorkflowDefinition());

// add a new version of the workflow definition and restart with latest
workflow.setStatus(Workflow.WorkflowStatus.COMPLETED);
Expand All @@ -404,9 +404,9 @@ public void testRestartWorkflow() {
verify(metadataDAO, times(1)).getLatest(anyString());

argumentCaptor = ArgumentCaptor.forClass(Workflow.class);
verify(executionDAOFacade, times(4)).updateWorkflow(argumentCaptor.capture());
assertEquals(workflow.getWorkflowId(), argumentCaptor.getAllValues().get(3).getWorkflowId());
assertEquals(workflowDef, argumentCaptor.getAllValues().get(3).getWorkflowDefinition());
verify(executionDAOFacade, times(2)).createWorkflow(argumentCaptor.capture());
assertEquals(workflow.getWorkflowId(), argumentCaptor.getAllValues().get(1).getWorkflowId());
assertEquals(workflowDef, argumentCaptor.getAllValues().get(1).getWorkflowDefinition());
}


Expand Down
Expand Up @@ -243,15 +243,13 @@ public void testTaskOps() {
Task matching = pending.stream().filter(task -> task.getTaskId().equals(tasks.get(0).getTaskId())).findAny().get();
assertTrue(EqualsBuilder.reflectionEquals(matching, tasks.get(0)));

List<Task> update = new LinkedList<>();
for (int i = 0; i < 3; i++) {
Task found = getExecutionDAO().getTask(workflowId + "_t" + i);
assertNotNull(found);
found.getOutputData().put("updated", true);
found.setStatus(Task.Status.COMPLETED);
update.add(found);
getExecutionDAO().updateTask(found);
}
getExecutionDAO().updateTasks(update);

List<String> taskIds = tasks.stream().map(Task::getTaskId).collect(Collectors.toList());
List<Task> found = getExecutionDAO().getTasks(taskIds);
Expand Down
Expand Up @@ -345,7 +345,7 @@ public void addTaskExecutionLogs(List<TaskExecLog> taskExecLogs) {
List<String> taskIds = taskExecLogs.stream()
.map(TaskExecLog::getTaskId)
.collect(Collectors.toList());
logger.error("Failed to index task execution logs for tasks: ", taskIds, e);
logger.error("Failed to index task execution logs for tasks: {}", taskIds, e);
}
}

Expand Down
Expand Up @@ -189,11 +189,6 @@ public boolean exceedsInProgressLimit(Task task) {
return rateLimited;
}

@Override
public void updateTasks(List<Task> tasks) {
    // All task updates share a single transaction so they commit or roll back together.
    withTransaction(connection -> {
        for (Task task : tasks) {
            updateTask(connection, task);
        }
    });
}

@Override
public boolean removeTask(String taskId) {
Task task = getTask(taskId);
Expand Down Expand Up @@ -251,13 +246,11 @@ public List<Task> getTasksForWorkflow(String workflowId) {

@Override
public String createWorkflow(Workflow workflow) {
    // Stamp the creation time before inserting the new execution record.
    final long now = System.currentTimeMillis();
    workflow.setCreateTime(now);
    return insertOrUpdateWorkflow(workflow, false);
}

@Override
public String updateWorkflow(Workflow workflow) {
    // Refresh the update timestamp, then persist via the shared upsert path.
    final long now = System.currentTimeMillis();
    workflow.setUpdateTime(now);
    return insertOrUpdateWorkflow(workflow, true);
}

Expand Down Expand Up @@ -477,10 +470,6 @@ private String insertOrUpdateWorkflow(Workflow workflow, boolean update) {

boolean terminal = workflow.getStatus().isTerminal();

if (terminal) {
workflow.setEndTime(System.currentTimeMillis());
}

List<Task> tasks = workflow.getTasks();
workflow.setTasks(Lists.newLinkedList());

Expand All @@ -504,11 +493,6 @@ private String insertOrUpdateWorkflow(Workflow workflow, boolean update) {
}

private void updateTask(Connection connection, Task task) {
task.setUpdateTime(System.currentTimeMillis());
if (task.getStatus() != null && task.getStatus().isTerminal() && task.getEndTime() == 0) {
task.setEndTime(System.currentTimeMillis());
}

Optional<TaskDef> taskDefinition = task.getTaskDefinition();

if (taskDefinition.isPresent() && taskDefinition.get().concurrencyLimit() > 0) {
Expand Down
Expand Up @@ -155,32 +155,20 @@ public List<Task> createTasks(List<Task> tasks) {

}

@Override
public void updateTasks(List<Task> tasks) {
    // Delegate each element to the single-task update; no batch primitive is used here.
    tasks.forEach(this::updateTask);
}

@Override
public void updateTask(Task task) {
task.setUpdateTime(System.currentTimeMillis());
if (task.getStatus() != null && task.getStatus().isTerminal() && task.getEndTime() == 0) {
task.setEndTime(System.currentTimeMillis());
}

Optional<TaskDef> taskDefinition = task.getTaskDefinition();

if(taskDefinition.isPresent() && taskDefinition.get().concurrencyLimit() > 0) {

if(task.getStatus() != null && task.getStatus().equals(Status.IN_PROGRESS)) {
dynoClient.sadd(nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName()), task.getTaskId());
logger.debug("Workflow Task added to TASKS_IN_PROGRESS_STATUS with tasksInProgressKey: {}, workflowId: {}, taskId: {}, taskType: {}, taskStatus: {} during updateTask",
nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName(), task.getWorkflowInstanceId(), task.getTaskId(), task.getTaskType(), task.getStatus().name()));
nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName(), task.getTaskId()), task.getWorkflowInstanceId(), task.getTaskId(), task.getTaskType(), task.getStatus().name());
}else {
dynoClient.srem(nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName()), task.getTaskId());
logger.debug("Workflow Task removed from TASKS_IN_PROGRESS_STATUS with tasksInProgressKey: {}, workflowId: {}, taskId: {}, taskType: {}, taskStatus: {} during updateTask",
nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName(), task.getWorkflowInstanceId(), task.getTaskId(), task.getTaskType(), task.getStatus().name()));
nsKey(TASKS_IN_PROGRESS_STATUS, task.getTaskDefName(), task.getTaskId()), task.getWorkflowInstanceId(), task.getTaskId(), task.getTaskType(), task.getStatus().name());
String key = nsKey(TASK_LIMIT_BUCKET, task.getTaskDefName());
dynoClient.zrem(key, task.getTaskId());
logger.debug("Workflow Task removed from TASK_LIMIT_BUCKET with taskLimitBucketKey: {}, workflowId: {}, taskId: {}, taskType: {}, taskStatus: {} during updateTask",
Expand Down Expand Up @@ -364,13 +352,11 @@ public List<Task> getPendingTasksForTaskType(String taskName) {

@Override
public String createWorkflow(Workflow workflow) {
    // Record when this execution was created, then insert it.
    final long now = System.currentTimeMillis();
    workflow.setCreateTime(now);
    return insertOrUpdateWorkflow(workflow, false);
}

@Override
public String updateWorkflow(Workflow workflow) {
    // Record when this execution was last modified, then upsert it.
    final long now = System.currentTimeMillis();
    workflow.setUpdateTime(now);
    return insertOrUpdateWorkflow(workflow, true);
}

Expand Down Expand Up @@ -496,9 +482,6 @@ public boolean canSearchAcrossWorkflows() {
private String insertOrUpdateWorkflow(Workflow workflow, boolean update) {
Preconditions.checkNotNull(workflow, "workflow object cannot be null");

if (workflow.getStatus().isTerminal()) {
workflow.setEndTime(System.currentTimeMillis());
}
List<Task> tasks = workflow.getTasks();
workflow.setTasks(new LinkedList<>());

Expand Down
Expand Up @@ -3528,8 +3528,8 @@ public void testTaskSkipping() {
// Check the tasks, at this time there should be 3 task
assertEquals(2, es.getTasks().size());

assertEquals(SCHEDULED, es.getTasks().stream().filter( task -> "t1".equals(task.getReferenceTaskName())).findFirst().orElse(null).getStatus());
assertEquals(Status.SKIPPED, es.getTasks().stream().filter( task -> "t2".equals(task.getReferenceTaskName())).findFirst().orElse(null).getStatus());
assertEquals(SCHEDULED, es.getTasks().stream().filter(task -> "t1".equals(task.getReferenceTaskName())).findFirst().orElse(null).getStatus());
assertEquals(Status.SKIPPED, es.getTasks().stream().filter(task -> "t2".equals(task.getReferenceTaskName())).findFirst().orElse(null).getStatus());

Task task = workflowExecutionService.poll("junit_task_1", "task1.junit.worker");
assertNotNull(task);
Expand Down Expand Up @@ -3714,7 +3714,7 @@ public void testSubWorkflow() {
assertNotNull("Output: " + task.getOutputData().toString() + ", status: " + task.getStatus(), task.getOutputData().get("subWorkflowId"));
assertNotNull(task.getInputData());
assertTrue(task.getInputData().containsKey("workflowInput"));
assertEquals(42, ((Map<String, Object>)task.getInputData().get("workflowInput")).get("param2"));
assertEquals(42, ((Map<String, Object>) task.getInputData().get("workflowInput")).get("param2"));
String subWorkflowId = task.getOutputData().get("subWorkflowId").toString();

es = workflowExecutionService.getExecutionStatus(subWorkflowId, true);
Expand Down Expand Up @@ -4036,8 +4036,8 @@ public void testLambda() {
assertNotNull(workflowDef);
metadataService.registerWorkflowDef(workflowDef);

Map<String, Object> inputs = new HashMap<>();
inputs.put("a",1);
Map<String, Object> inputs = new HashMap<>();
inputs.put("a", 1);
String workflowId = startOrLoadWorkflowExecution(workflowDef.getName(), workflowDef.getVersion(), "", inputs, null, null);
Workflow workflow = workflowExecutor.getWorkflow(workflowId, true);

Expand Down Expand Up @@ -4143,6 +4143,7 @@ public void testTerminateTaskWithFailedStatus() {
metadataService.registerWorkflowDef(workflowDef);

Map wfInput = Collections.singletonMap("a", 1);
//noinspection unchecked
String workflowId = startOrLoadWorkflowExecution(workflowDef.getName(), workflowDef.getVersion(), "", wfInput, null, null);
Workflow workflow = workflowExecutor.getWorkflow(workflowId, true);

Expand Down Expand Up @@ -4360,11 +4361,11 @@ public void testWorkflowUsingExternalPayloadStorage() {
}

@Test
public void testExecutionTimes(){
public void testExecutionTimes() {

String taskName = "junit_task_1";
TaskDef taskDef = notFoundSafeGetTaskDef(taskName);
taskDef.setTimeoutSeconds(1);
taskDef.setTimeoutSeconds(10);
metadataService.updateTaskDef(taskDef);

metadataService.registerTaskDef(Collections.singletonList(taskDef));
Expand Down Expand Up @@ -4419,6 +4420,7 @@ public void testExecutionTimes(){
metadataService.registerWorkflowDef(workflowDef);

Map workflowInput = Collections.emptyMap();
//noinspection unchecked
String workflowId = startOrLoadWorkflowExecution(workflowDef.getName(), workflowDef.getVersion(), "test", workflowInput, null, null);
Workflow workflow = workflowExecutor.getWorkflow(workflowId, true);

Expand All @@ -4440,7 +4442,7 @@ public void testExecutionTimes(){
assertNotNull(workflow);
assertEquals(WorkflowStatus.COMPLETED, workflow.getStatus());

workflow.getTasks().forEach( workflowTask -> {
workflow.getTasks().forEach(workflowTask -> {
assertTrue(workflowTask.getScheduledTime() <= workflowTask.getStartTime());
assertTrue(workflowTask.getStartTime() < workflowTask.getEndTime());
});
Expand Down Expand Up @@ -4768,7 +4770,7 @@ public void testSubWorkflowTaskToDomain() {
assertNotNull("Output: " + task.getOutputData().toString() + ", status: " + task.getStatus(), task.getOutputData().get("subWorkflowId"));
assertNotNull(task.getInputData());
assertTrue(task.getInputData().containsKey("workflowInput"));
assertEquals(42, ((Map<String, Object>)task.getInputData().get("workflowInput")).get("param2"));
assertEquals(42, ((Map<String, Object>) task.getInputData().get("workflowInput")).get("param2"));
String subWorkflowId = task.getOutputData().get("subWorkflowId").toString();

es = workflowExecutionService.getExecutionStatus(subWorkflowId, true);
Expand Down Expand Up @@ -4841,7 +4843,7 @@ public void testSubWorkflowTaskToDomainWildcard() {
assertNotNull("Output: " + task.getOutputData().toString() + ", status: " + task.getStatus(), task.getOutputData().get("subWorkflowId"));
assertNotNull(task.getInputData());
assertTrue(task.getInputData().containsKey("workflowInput"));
assertEquals(42, ((Map<String, Object>)task.getInputData().get("workflowInput")).get("param2"));
assertEquals(42, ((Map<String, Object>) task.getInputData().get("workflowInput")).get("param2"));
String subWorkflowId = task.getOutputData().get("subWorkflowId").toString();

es = workflowExecutionService.getExecutionStatus(subWorkflowId, true);
Expand Down

0 comments on commit 451eb97

Please sign in to comment.