Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancel tasks API #3036

Merged
merged 3 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package com.sksamuel.elastic4s.api

import com.sksamuel.elastic4s.requests.task.{CancelTasksRequest, GetTask, ListTasks, PendingClusterTasksRequest}
import com.sksamuel.elastic4s.requests.task.{CancelTaskByIdRequest, CancelTasksRequest, GetTask, ListTasks, PendingClusterTasksRequest}

trait TaskApi {

def cancelTasks(): CancelTasksRequest = cancelTasks(Nil)
def cancelTasks(first: String, rest: String*): CancelTasksRequest = cancelTasks(first +: rest)
def cancelTasks(nodeIds: Seq[String]): CancelTasksRequest = CancelTasksRequest(nodeIds)
def cancelTaskById(nodeId: String, taskId: String): CancelTaskByIdRequest = CancelTaskByIdRequest(nodeId, taskId)

def pendingClusterTasks(local: Boolean): PendingClusterTasksRequest = PendingClusterTasksRequest(local)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ case class CancelTasksRequest(nodeIds: Seq[String],
def actions(first: String, rest: String*): CancelTasksRequest = actions(first +: rest)
def actions(actions: Iterable[String]): CancelTasksRequest = copy(actions = actions.toSeq)
}

case class CancelTaskByIdRequest(nodeId: String, taskId: String)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ case class Task(node: String,
description: String,
@JsonProperty("start_time_in_millis") private val start_time_in_millis: Long,
@JsonProperty("running_time_in_nanos") private val running_time_in_nanos: Long,
cancellable: Boolean) {
cancellable: Boolean,
cancelled: Option[Boolean]) {
def startTimeInMillis: Long = start_time_in_millis
def runningTime: FiniteDuration = running_time_in_nanos.nanos
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.sksamuel.elastic4s.handlers.task

import com.sksamuel.elastic4s.requests.task.{CancelTasksRequest, GetTask, GetTaskResponse, ListTaskResponse, ListTasks}
import com.sksamuel.elastic4s.{ElasticRequest, Handler, HttpResponse, ResponseHandler}
import com.sksamuel.elastic4s.ext.OptionImplicits.RichOption
import com.sksamuel.elastic4s.handlers.ElasticErrorParser
import com.sksamuel.elastic4s.requests.task.{CancelTaskByIdRequest, CancelTasksRequest, GetTask, GetTaskResponse, ListTaskResponse, ListTasks, Node, Task}
import com.sksamuel.elastic4s.{ElasticError, ElasticRequest, Handler, HttpResponse, ResponseHandler}

import scala.util.Try

trait TaskHandlers {

Expand Down Expand Up @@ -30,25 +34,49 @@ trait TaskHandlers {
}
}

implicit object CancelTaskHandler extends Handler[CancelTasksRequest, Boolean] {
abstract class AbstractCancelTaskHandler[U] extends Handler[U, Boolean] {

override def responseHandler: ResponseHandler[Boolean] = new ResponseHandler[Boolean] {
override def handle(response: HttpResponse) = Right(response.statusCode >= 200 && response.statusCode < 300)
override def handle(response: HttpResponse): Either[ElasticError, Boolean] = response.statusCode match {
case 200 | 201 | 202 | 203 | 204 => {
val entity = response.entity.getOrError("No entity defined")
// It can fail on a 200 by returning a response containing node_failures
if (entity.content.contains("node_failures")) Right[ElasticError, Boolean](false)
else {
Try(ResponseHandler.fromEntity[ListTaskResponse](entity)).map { (list: ListTaskResponse) =>
// Check that all the tasks on all the nodes are cancelled
list.nodes.forall { case (_, node: Node) =>
node.tasks.forall { case (_, task: Task) =>
task.cancelled.getOrElse(false)
}
}
}.toEither.fold((_: Throwable) => Right[ElasticError, Boolean](false), b => Right[ElasticError, Boolean](b))
}
}
case _ =>
Left[ElasticError, Boolean](ElasticErrorParser.parse(response))
}
}

override def build(request: CancelTasksRequest): ElasticRequest = {
}

implicit object CancelTaskHandler extends AbstractCancelTaskHandler[CancelTasksRequest] {

val endpoint =
if (request.nodeIds.isEmpty) s"/_tasks/cancel"
else s"/_tasks/task_id:${request.nodeIds.mkString(",")}/_cancel"
override def build(request: CancelTasksRequest): ElasticRequest = {

val params = scala.collection.mutable.Map.empty[String, String]
if (request.nodeIds.nonEmpty)
params.put("nodes", request.nodeIds.mkString(","))
if (request.actions.nonEmpty)
params.put("actions", request.actions.mkString(","))

ElasticRequest("POST", endpoint, params.toMap)
ElasticRequest("POST", "/_tasks/_cancel", params.toMap)
}
}

implicit object CancelTaskByIdHandler extends AbstractCancelTaskHandler[CancelTaskByIdRequest] {

override def build(request: CancelTaskByIdRequest): ElasticRequest =
ElasticRequest("POST", s"/_tasks/${request.nodeId}:${request.taskId}/_cancel")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.sksamuel.elastic4s.tasks

import com.sksamuel.elastic4s.{RequestFailure, Response}
import com.sksamuel.elastic4s.requests.common.RefreshPolicy
import com.sksamuel.elastic4s.requests.task.Retries
import com.sksamuel.elastic4s.testkit.DockerTests
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.concurrent.duration.DurationInt

class CancelTaskTest extends AnyWordSpec with Matchers with DockerTests {

cleanIndex("cancel_task_a")
cleanIndex("cancel_task_b")
cleanIndex("cancel_task_c")

client.execute {
bulk(
indexInto("cancel_task_a").fields(Map("foo" -> "far")),
indexInto("cancel_task_a").fields(Map("moo" -> "mar")),
indexInto("cancel_task_a").fields(Map("moo" -> "mar")),
indexInto("cancel_task_a").fields(Map("goo" -> "gar"))
).refresh(RefreshPolicy.Immediate)
}.await

"cancel task" should {
"cancel task by id" in {
// kick off a task
val resp = client.execute {
reindex("cancel_task_a", "cancel_task_b").waitForCompletion(false)
}.await.result.right.get

// use the task id from the above task
val response = client.execute {
cancelTaskById(resp.nodeId, resp.taskId)
}.await
response.result should be(true)
}

"cancel task by node and action" in {
// kick off a task
val resp = client.execute {
reindex("cancel_task_a", "cancel_task_c").waitForCompletion(false)
}.await.result.right.get

// use the task id from the above task
val response = client.execute {
cancelTasks(resp.nodeId).actions("*reindex")
}.await

response.result should be(true)
}
}
}


Loading