
Proper handling of lifecycle
mproch committed Feb 5, 2021
1 parent ee15ce2 commit 424dcd7
Showing 31 changed files with 391 additions and 417 deletions.
@@ -31,9 +31,7 @@ abstract class EagerService extends Service {
   final override def close(): Unit = {}
 }
 
-//TODO: replace ServiceInvoker with this trait?
-//TODO: handle lifecycle methods...
-trait EagerServiceInvoker extends Lifecycle {
+trait ServiceInvoker extends Lifecycle {
 
   def invokeService(params: Map[String, Any])(implicit ec: ExecutionContext,
                                               collector: InvocationCollectors.ServiceInvocationCollector,
@@ -2,7 +2,7 @@ package pl.touk.nussknacker.engine.api.deployment
 
 import java.nio.charset.StandardCharsets
 
-import pl.touk.nussknacker.engine.api.Context
+import pl.touk.nussknacker.engine.api.{Context, ContextId}
 import pl.touk.nussknacker.engine.api.exception.EspExceptionInfo
 
 object TestProcess {
@@ -27,8 +27,8 @@ object TestProcess {
       copy(invocationResults = invocationResults + (nodeId -> addResults(invocationResult, invocationResults.getOrElse(nodeId, List()))))
     }
 
-    def updateMockedResult(nodeId: String, context: Context, name: String, result: Any) = {
-      val mockedResult = MockedResult(context.id, name, variableEncoder(result))
+    def updateMockedResult(nodeId: String, contextId: ContextId, name: String, result: Any) = {
+      val mockedResult = MockedResult(contextId.value, name, variableEncoder(result))
       copy(mockedResults = mockedResults + (nodeId -> (mockedResults.getOrElse(nodeId, List()) :+ mockedResult)))
     }
 
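
The two hunks above replace the full Context argument with just its id. ContextId's definition is not part of this diff; its usage here (contextId.value, ContextId(result.finalContext.id)) suggests a thin wrapper over the context's string id, roughly as sketched below (an assumption, not the actual source):

    // Assumed shape of ContextId, inferred from the contextId.value and ContextId(...) calls above.
    final case class ContextId(value: String)

    // Callers then pass only the id instead of the whole Context, e.g.:
    // results.updateMockedResult(nodeId, ContextId(context.id), serviceRef, testInvocation)
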
@@ -1,7 +1,8 @@
 package pl.touk.nussknacker.engine.api.test
 
+import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
 import pl.touk.nussknacker.engine.api.test.InvocationCollectors.CollectorMode.CollectorMode
-import pl.touk.nussknacker.engine.api.{Context, InterpretationResult}
+import pl.touk.nussknacker.engine.api.{ContextId, InterpretationResult}
 
 import scala.concurrent.{ExecutionContext, Future}
 
@@ -24,6 +25,8 @@ object InvocationCollectors {
     val Test, Query, Production = Value
   }
 
+  type ServiceInvocationCollectorForContext = ContextId => ServiceInvocationCollector
+
   trait ServiceInvocationCollector {
 
     def enable(runId: TestRunId): ServiceInvocationCollector
@@ -91,20 +94,22 @@ object InvocationCollectors {
 
   }
 
-  case class TestServiceInvocationCollector private(runIdOpt: Option[TestRunId], nodeContext: NodeContext) extends ServiceInvocationCollector {
+  case class TestServiceInvocationCollector private(runIdOpt: Option[TestRunId],
+                                                    contextId: ContextId,
+                                                    nodeId: NodeId, serviceRef: String) extends ServiceInvocationCollector {
     def enable(runId: TestRunId) = this.copy(runIdOpt = Some(runId))
     override protected def collectorMode: CollectorMode = CollectorMode.Test
 
     override protected def updateResult(runId: TestRunId, testInvocation: Any, name: String): Unit = {
       ResultsCollectingListenerHolder.updateResults(
-        runId, _.updateMockedResult(nodeContext.nodeId, Context(nodeContext.contextId), nodeContext.ref, testInvocation)
+        runId, _.updateMockedResult(nodeId.id, contextId, serviceRef, testInvocation)
       )
     }
   }
 
   object TestServiceInvocationCollector {
-    def apply(nodeContext: NodeContext): TestServiceInvocationCollector = {
-      TestServiceInvocationCollector(runIdOpt = None, nodeContext = nodeContext)
+    def apply(contextId: ContextId, nodeId: NodeId, serviceRef: String): TestServiceInvocationCollector = {
+      TestServiceInvocationCollector(runIdOpt = None, contextId, nodeId, serviceRef)
     }
   }
 
@@ -118,7 +123,7 @@
 
     def collect(result: InterpretationResult): Unit = {
       val mockedResult = outputPreparer(result.output)
-      ResultsCollectingListenerHolder.updateResults(runId, _.updateMockedResult(nodeId, result.finalContext, ref, mockedResult))
+      ResultsCollectingListenerHolder.updateResults(runId, _.updateMockedResult(nodeId, ContextId(result.finalContext.id), ref, mockedResult))
     }
   }
 
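
For orientation, a hedged sketch of how the reworked collector could be wired up. The ServiceInvocationCollectorForContext alias and the three-argument TestServiceInvocationCollector.apply both come from the hunks above; the node id and service name literals are made up for illustration only.

    import pl.touk.nussknacker.engine.api.ContextId
    import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
    import pl.touk.nussknacker.engine.api.test.InvocationCollectors.{ServiceInvocationCollectorForContext, TestServiceInvocationCollector}

    // One collector factory per node/service pair; the ContextId is supplied later, per record.
    val collectorFor: ServiceInvocationCollectorForContext =
      contextId => TestServiceInvocationCollector(contextId, NodeId("enricher1"), serviceRef = "clientService")
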
@@ -30,7 +30,7 @@ class InterpreterSetup[T:ClassTag] {
               listeners: Seq[ProcessListener]): (Context, ExecutionContext) => F[Either[List[InterpretationResult], EspExceptionInfo[_ <: Throwable]]] = {
     val compiledProcess = compile(services, process, listeners)
     val interpreter = compiledProcess.interpreter
-    val parts = failOnErrors(compiledProcess.compile())
+    val parts = failOnErrors(compiledProcess.compile().result)
 
     def compileNode(part: ProcessPart) =
       failOnErrors(compiledProcess.subPartCompiler.compile(part.node, part.validationContext)(process.metaData).result)
@@ -2,6 +2,8 @@ package pl.touk.nussknacker.engine.process
 
 import org.apache.flink.api.common.functions.RichFunction
 import org.apache.flink.configuration.Configuration
+import pl.touk.nussknacker.engine.api.context.ValidationContext
+import pl.touk.nussknacker.engine.compiledgraph.node.Node
 import pl.touk.nussknacker.engine.flink.api.exception.FlinkEspExceptionHandler
 import pl.touk.nussknacker.engine.graph.node.NodeData
 import pl.touk.nussknacker.engine.process.compiler.FlinkProcessCompilerData
@@ -13,18 +15,22 @@ trait ProcessPartFunction extends ExceptionHandlerFunction {
 
   protected def node: SplittedNode[_<:NodeData]
 
+  protected def validationContext: ValidationContext
+
+  protected lazy val (compiledNode, services) = compiledProcessWithDeps.compileSubPart(node, validationContext)
+
   private val nodesUsed = SplittedNodesCollector.collectNodes(node).map(_.data)
 
   override def close(): Unit = {
     super.close()
     if (compiledProcessWithDeps != null) {
-      compiledProcessWithDeps.close(nodesUsed)
+      compiledProcessWithDeps.close(services)
     }
   }
 
   override def open(parameters: Configuration): Unit = {
     super.open(parameters)
-    compiledProcessWithDeps.open(getRuntimeContext, nodesUsed)
+    compiledProcessWithDeps.open(getRuntimeContext, services)
   }
 
 }
@@ -7,15 +7,15 @@ import org.apache.flink.api.common.restartstrategy.RestartStrategies
 import pl.touk.nussknacker.engine.Interpreter
 import pl.touk.nussknacker.engine.api.context.{ProcessCompilationError, ValidationContext}
 import pl.touk.nussknacker.engine.api.process.AsyncExecutionContextPreparer
-import pl.touk.nussknacker.engine.api.{JobData, MetaData}
+import pl.touk.nussknacker.engine.api.{JobData, Lifecycle, MetaData}
 import pl.touk.nussknacker.engine.compile.ProcessCompilerData
 import pl.touk.nussknacker.engine.compiledgraph.CompiledProcessParts
 import pl.touk.nussknacker.engine.compiledgraph.node.Node
+import pl.touk.nussknacker.engine.compiledgraph.service.ServiceRef
 import pl.touk.nussknacker.engine.definition.LazyInterpreterDependencies
 import pl.touk.nussknacker.engine.flink.api.RuntimeContextLifecycle
 import pl.touk.nussknacker.engine.flink.api.exception.FlinkEspExceptionHandler
 import pl.touk.nussknacker.engine.flink.api.process.FlinkProcessSignalSenderProvider
-import pl.touk.nussknacker.engine.graph.node.NodeData
 import pl.touk.nussknacker.engine.splittedgraph.splittednode.SplittedNode
 
 import scala.concurrent.duration.FiniteDuration
@@ -36,21 +36,21 @@ class FlinkProcessCompilerData(compiledProcess: ProcessCompilerData,
                                val processTimeout: FiniteDuration
                               ) {
 
-  def open(runtimeContext: RuntimeContext, nodesToUse: List[_<:NodeData]) : Unit = {
-    val lifecycle = compiledProcess.lifecycle(nodesToUse)
-    lifecycle.foreach {_.open(jobData)}
-    lifecycle.collect{
-      case s:RuntimeContextLifecycle =>
+  def open(runtimeContext: RuntimeContext, nodesToUse: List[ServiceRef]) : Unit = {
+    compiledProcess.lifecycle(nodesToUse).open(jobData) {
+      case s:RuntimeContextLifecycle with Lifecycle =>
+        s.open(jobData)
         s.open(runtimeContext)
     }
   }
 
-  def close(nodesToUse: List[_<:NodeData]) : Unit = {
-    compiledProcess.lifecycle(nodesToUse).foreach(_.close())
+  def close(nodesToUse: List[ServiceRef]) : Unit = {
+    compiledProcess.lifecycle(nodesToUse).close()
   }
 
-  def compileSubPart(node: SplittedNode[_], validationContext: ValidationContext): Node = {
-    validateOrFail(compiledProcess.subPartCompiler.compile(node, validationContext)(compiledProcess.metaData).result)
+  def compileSubPart(node: SplittedNode[_], validationContext: ValidationContext): (Node, List[ServiceRef]) = {
+    val compilation = compiledProcess.subPartCompiler.compile(node, validationContext)(compiledProcess.metaData)
+    (validateOrFail(compilation.result), compilation.services)
   }
 
   private def validateOrFail[T](validated: ValidatedNel[ProcessCompilationError, T]): T = validated match {
@@ -64,7 +64,7 @@ class FlinkProcessCompilerData(compiledProcess: ProcessCompilerData,
 
   val lazyInterpreterDeps: LazyInterpreterDependencies = compiledProcess.lazyInterpreterDeps
 
-  def compileProcess(): CompiledProcessParts = validateOrFail(compiledProcess.compile())
+  def compileProcess(): CompiledProcessParts = validateOrFail(compiledProcess.compile().result)
 
   def restartStrategy: RestartStrategies.RestartStrategyConfiguration = exceptionHandler.restartStrategy
 
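
The core of the commit is visible in the open/close rewrite above: lifecycle handling is now driven by the ServiceRefs collected during compilation rather than by raw node data. Below is a minimal, self-contained sketch of the call ordering, using simplified stand-in traits rather than the real Nussknacker Lifecycle/RuntimeContextLifecycle API.

    // Stand-in traits for illustration only; the signatures are assumptions, not the real API.
    trait Lifecycle {
      def open(jobData: String): Unit = ()
      def close(): Unit = ()
    }
    trait RuntimeContextLifecycle {
      def open(runtimeContext: Object): Unit = ()
    }

    // Mirrors the intent of FlinkProcessCompilerData.open/close above: every part gets
    // open(jobData); parts that also need the Flink runtime context additionally get open(runtimeContext).
    def openAll(parts: List[Lifecycle], jobData: String, runtimeContext: Object): Unit =
      parts.foreach {
        case s: RuntimeContextLifecycle => s.open(jobData); s.open(runtimeContext)
        case s => s.open(jobData)
      }

    def closeAll(parts: List[Lifecycle]): Unit =
      parts.foreach(_.close())
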
@@ -22,12 +22,12 @@ import scala.util.{Failure, Success}
 import scala.util.control.NonFatal
 
 private[registrar] class AsyncInterpretationFunction(val compiledProcessWithDepsProvider: ClassLoader => FlinkProcessCompilerData,
-                                                     val node: SplittedNode[_<:NodeData], validationContext: ValidationContext,
-                                                     asyncExecutionContextPreparer: AsyncExecutionContextPreparer, useIOMonad: Boolean)
+                                                     val node: SplittedNode[_<:NodeData],
+                                                     val validationContext: ValidationContext,
+                                                     asyncExecutionContextPreparer: AsyncExecutionContextPreparer,
+                                                     useIOMonad: Boolean)
   extends RichAsyncFunction[Context, InterpretationResult] with LazyLogging with ProcessPartFunction {
 
-  private lazy val compiledNode = compiledProcessWithDeps.compileSubPart(node, validationContext)
-
   import compiledProcessWithDeps._
 
   private var executionContext: ExecutionContext = _
@@ -17,11 +17,11 @@ import scala.util.control.NonFatal
 
 private[registrar] class SyncInterpretationFunction(val compiledProcessWithDepsProvider: ClassLoader => FlinkProcessCompilerData,
                                                     val node: SplittedNode[_<:NodeData],
-                                                    validationContext: ValidationContext, useIOMonad: Boolean)
+                                                    val validationContext: ValidationContext,
+                                                    useIOMonad: Boolean)
   extends RichFlatMapFunction[Context, InterpretationResult] with ProcessPartFunction {
 
   private lazy implicit val ec: ExecutionContext = SynchronousExecutionContext.ctx
-  private lazy val compiledNode = compiledProcessWithDeps.compileSubPart(node, validationContext)
 
   import compiledProcessWithDeps._
 
@@ -2,11 +2,15 @@ package pl.touk.nussknacker.engine.util.service.query
 
 import com.typesafe.config.ConfigFactory
 import org.scalatest.{FunSuite, Matchers}
+import pl.touk.nussknacker.engine.ModelData
 import pl.touk.nussknacker.engine.api._
 import pl.touk.nussknacker.engine.api.process.{ProcessObjectDependencies, WithCategories}
 import pl.touk.nussknacker.engine.api.test.InvocationCollectors.ServiceInvocationCollector
 import pl.touk.nussknacker.engine.flink.util.service.TimeMeasuringService
+import pl.touk.nussknacker.engine.graph.expression.Expression
+import pl.touk.nussknacker.engine.spel.Implicits._
 import pl.touk.nussknacker.engine.testing.LocalModelData
+import pl.touk.nussknacker.engine.util.SynchronousExecutionContext
 import pl.touk.nussknacker.engine.util.process.EmptyProcessConfigCreator
 import pl.touk.nussknacker.engine.util.service.GenericTimeMeasuringService
 import pl.touk.nussknacker.test.PatientScalaFutures
@@ -20,14 +24,12 @@ class ServiceQueryOpenCloseSpec
 
   import ServiceQueryOpenCloseSpec._
 
-  private implicit val executionContext: ExecutionContext = ExecutionContext.Implicits.global
+  private implicit val executionContext: ExecutionContext = SynchronousExecutionContext.ctx
 
   test("open and close service") {
     val service = createService
     service.wasOpen shouldBe false
-    whenReady(invokeService(4, service)) { r =>
-      r.result shouldBe 4
-    }
+    invokeService(4, service) shouldBe 4
     service.wasOpen shouldBe true
     eventually {
       service.wasClose shouldBe true
@@ -40,24 +42,24 @@
         super.services(processObjectDependencies) ++ Map("cast" -> WithCategories(createService))
     })
 
-    whenReady(new ServiceQuery(modelData).invoke("cast", "integer" -> 4)) {
-      _.result shouldBe 4
-    }
-    whenReady(new ServiceQuery(modelData).invoke("cast", "integer" -> 5)) {
-      _.result shouldBe 5
-    }
+    invokeService(4, modelData) shouldBe 4
+    invokeService(5, modelData) shouldBe 5
   }
 
   private def createService = {
     new CastIntToLongService with TimeMeasuringService
   }
 
-  private def invokeService(arg: Int, service: Service) = {
-    new ServiceQuery(LocalModelData(ConfigFactory.empty, new EmptyProcessConfigCreator {
+  private def invokeService(arg: Int, modelData: ModelData): Any = {
+    new ServiceQuery(modelData).invoke("cast", "integer" -> (arg.toString: Expression))
+      .futureValue.result
+  }
+
+  private def invokeService(arg: Int, service: Service): Any = {
+    invokeService(arg, LocalModelData(ConfigFactory.empty, new EmptyProcessConfigCreator {
       override def services(processObjectDependencies: ProcessObjectDependencies): Map[String, WithCategories[Service]] =
         super.services(processObjectDependencies) ++ Map("cast" -> WithCategories(service))
     }))
-      .invoke("cast", "integer" -> arg)
   }
 }
 
@@ -8,7 +8,6 @@ import pl.touk.nussknacker.engine.api._
 import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
 import pl.touk.nussknacker.engine.api.exception.EspExceptionInfo
 import pl.touk.nussknacker.engine.api.expression.Expression
-import pl.touk.nussknacker.engine.api.test.InvocationCollectors.NodeContext
 import pl.touk.nussknacker.engine.compiledgraph.node.{Sink, Source, _}
 import pl.touk.nussknacker.engine.compiledgraph.service._
 import pl.touk.nussknacker.engine.compiledgraph.variable._
@@ -77,20 +76,20 @@ private class InterpreterInternal[F[_]](listeners: Seq[ProcessListener],
         }.getOrElse(parentContext)
         interpretNext(next, newParentContext)
       case Processor(_, ref, next, false) =>
-        invoke(ref, None, ctx).flatMap {
+        invoke(ref, ctx).flatMap {
           case ValueWithContext(_, newCtx) => interpretNext(next, newCtx)
         }
       case Processor(_, _, next, true) => interpretNext(next, ctx)
       case EndingProcessor(id, ref, false) =>
-        invoke(ref, None, ctx).map {
+        invoke(ref, ctx).map {
           case ValueWithContext(output, newCtx) =>
             List(InterpretationResult(EndReference(id), output, newCtx))
         }
       case EndingProcessor(id, _, true) =>
         //FIXME: null??
         monad.pure(List(InterpretationResult(EndReference(id), null, ctx)))
       case Enricher(_, ref, outName, next) =>
-        invoke(ref, Some(outName), ctx).flatMap {
+        invoke(ref, ctx).flatMap {
          case ValueWithContext(out, newCtx) =>
            interpretNext(next, newCtx.withVariable(outName, out))
        }
@@ -178,16 +177,13 @@
     }
   }
 
-  private def invoke(ref: ServiceRef, outputVariableNameOpt: Option[String], ctx: Context)(implicit node: Node): F[ValueWithContext[Any]] = {
-    val (newCtx, preparedParams) = expressionEvaluator.evaluateParameters(ref.parameters, ctx)
-    val resultFuture = ref.invoker
-      .invoke(preparedParams, NodeContext(ctx.id, node.id, ref.id, outputVariableNameOpt))
-
+  private def invoke(ref: ServiceRef, ctx: Context)(implicit node: Node): F[ValueWithContext[AnyRef]] = {
+    val (preparedParams, resultFuture) = ref.invoke(ctx, expressionEvaluator)
     resultFuture.onComplete { result =>
       //TODO: what about implicit??
       listeners.foreach(_.serviceInvoked(node.id, ref.id, ctx, metaData, preparedParams, result))
     }
-    interpreterShape.fromFuture(resultFuture.map(ValueWithContext(_, newCtx))(SynchronousExecutionContext.ctx))
+    interpreterShape.fromFuture(resultFuture.map(ValueWithContext(_, ctx))(SynchronousExecutionContext.ctx))
   }
 
   private def evaluateExpression[R](expr: Expression, ctx: Context, name: String)
@@ -6,25 +6,28 @@ import cats.instances.map._
 import cats.kernel.Semigroup
 import cats.{Applicative, Traverse}
 import com.typesafe.scalalogging.LazyLogging
+import pl.touk.nussknacker.engine.api.Lifecycle
 import pl.touk.nussknacker.engine.api.context.{ProcessCompilationError, ProcessUncanonizationError, ValidationContext}
 import pl.touk.nussknacker.engine.api.definition.Parameter
 import pl.touk.nussknacker.engine.api.expression.ExpressionTypingInfo
 import pl.touk.nussknacker.engine.api.typed.typing.TypingResult
 import pl.touk.nussknacker.engine.canonize.{MaybeArtificial, MaybeArtificialExtractor}
+import pl.touk.nussknacker.engine.compiledgraph.service.ServiceRef
 
 import scala.language.{higherKinds, reflectiveCalls}
 
 case class CompilationResult[+Result](typing: Map[String, NodeTypingInfo],
-                                      result: ValidatedNel[ProcessCompilationError, Result]) {
+                                      services: List[ServiceRef],
+                                      result: ValidatedNel[ProcessCompilationError, Result]) {
 
   import CompilationResult._
 
   def andThen[B](f: Result => CompilationResult[B]): CompilationResult[B] =
     result match {
       case Valid(a) =>
         val newResult = f(a)
-        newResult.copy(typing = Semigroup.combine(typing, newResult.typing))
-      case i @ Invalid(_) => CompilationResult(typing, i)
+        newResult.copy(typing = Semigroup.combine(typing, newResult.typing), services = services ++ newResult.services)
+      case i @ Invalid(_) => CompilationResult(typing, services, i)
     }
 
   def map[T](action: Result => T) : CompilationResult[T] = copy(result = result.map(action))
@@ -47,12 +50,12 @@ case class CompilationResult[+Result](typing: Map[String, NodeTypingInfo],
 //in fact, I'm not quite sure it's really, formally Applicative - but for our purposes it should be ok...
 object CompilationResult extends Applicative[CompilationResult] {
 
-  def apply[Result](validatedProcess: ValidatedNel[ProcessCompilationError, Result]) : CompilationResult[Result] = CompilationResult(Map(), validatedProcess)
+  def apply[Result](validatedProcess: ValidatedNel[ProcessCompilationError, Result]) : CompilationResult[Result] = CompilationResult(Map(), Nil, validatedProcess)
 
-  override def pure[A](x: A): CompilationResult[A] = CompilationResult(Map(), Valid(x))
+  override def pure[A](x: A): CompilationResult[A] = CompilationResult(Map(), Nil, Valid(x))
 
   override def ap[A, B](ff: CompilationResult[A => B])(fa: CompilationResult[A]): CompilationResult[B] =
-    CompilationResult(Semigroup.combine(fa.typing, ff.typing), fa.result.ap(ff.result))
+    CompilationResult(Semigroup.combine(fa.typing, ff.typing), fa.services ++ ff.services, fa.result.ap(ff.result))
 
   implicit class CompilationResultTraverseOps[T[_]: Traverse, B](traverse: T[CompilationResult[B]]) {
     def sequence: CompilationResult[T[B]] = {
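
Finally, the CompilationResult changes above thread a services: List[ServiceRef] field through every combinator, so downstream code (such as the Flink functions earlier in this diff) can open and close exactly the services a compiled part uses. A tiny standalone analogue of how that accumulation behaves (not the real classes):

    // Standalone analogue: combining two results concatenates the collected service lists,
    // just as CompilationResult.andThen and ap do above.
    final case class MiniResult[+A](services: List[String], value: A) {
      def andThen[B](f: A => MiniResult[B]): MiniResult[B] = {
        val next = f(value)
        MiniResult(services ++ next.services, next.value)
      }
    }

    val combined = MiniResult(List("svcA"), 1).andThen(v => MiniResult(List("svcB"), v + 1))
    // combined.services == List("svcA", "svcB"); combined.value == 2
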
