Skip to content

Commit

Permalink
rewrite early return function to functional style
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaechler committed Apr 29, 2024
1 parent c5098e0 commit 613dcfe
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

package com.normation.rudder.domain.policies

import cats.implicits.*
import com.normation.inventory.domain.NodeId
import com.normation.rudder.domain.nodes.NodeGroup
import com.normation.rudder.domain.nodes.NodeGroupId
Expand All @@ -47,6 +48,7 @@ import net.liftweb.json.JsonDSL.*
import scala.collection.MapView
import scala.util.matching.Regex
import zio.Chunk
import zio.interop.catz.chunkStdInstances

/**
* A target is either
Expand Down Expand Up @@ -247,52 +249,49 @@ object RuleTarget extends Loggable {
allNodesAreThere: Boolean = true // if we are working on a subset of node, set to false
): Set[NodeId] = {

targets.foldLeft(Set[NodeId]()) {
case (nodes, target) =>
target match {
case AllTarget => return allNodes.keySet.toSet
case AllTargetExceptPolicyServers => nodes ++ allNodes.collect { case (k, isPolicyServer) if (!isPolicyServer) => k }
case AllPolicyServers => nodes ++ allNodes.collect { case (k, isPolicyServer) if (isPolicyServer) => k }
case PolicyServerTarget(nodeId) =>
if (allNodesAreThere) {
nodes + nodeId
} else {
// nodeId may not be in allNodes
allNodes.keySet.contains(nodeId) match {
case true => nodes + nodeId
case _ => nodes
}
}
targets.toList.traverse {
case AllTarget => Left(allNodes.keySet.toSet)
case AllTargetExceptPolicyServers => Right(allNodes.collect { case (k, isPolicyServer) if (!isPolicyServer) => k }.toSet)
case AllPolicyServers => Right(allNodes.collect { case (k, isPolicyServer) if (isPolicyServer) => k }.toSet)
case PolicyServerTarget(nodeId) =>
if (allNodesAreThere) {
Right(Set(nodeId))
} else {
// nodeId may not be in allNodes
allNodes.keySet.contains(nodeId) match {
case true => Right(Set(nodeId))
case _ => Right(Set.empty)
}
}

// here, if we don't find the group, we consider it's an error in the
// target recording, but don't fail, just log it.
case GroupTarget(groupId) =>
nodes ++ groups.getOrElse(groupId, Set())
// here, if we don't find the group, we consider it's an error in the
// target recording, but don't fail, just log it.
case GroupTarget(groupId) =>
Right(groups.getOrElse(groupId, Set.empty))

case TargetIntersection(targets) =>
val nodeSets = targets.map(t => getNodeIds(Set(t), allNodes, groups, allNodesAreThere))
// Compute the intersection of the sets of Nodes
val intersection = nodeSets.foldLeft(allNodes.keySet) {
case (currentIntersection, nodes) => currentIntersection.intersect(nodes)
}
nodes ++ intersection

case TargetUnion(targets) =>
val nodeSets = targets.map(t => getNodeIds(Set(t), allNodes, groups, allNodesAreThere))
// Compute the union of the sets of Nodes
val union = nodeSets.foldLeft(Set[NodeId]()) { case (currentUnion, nodes) => currentUnion.union(nodes) }
nodes ++ union

case TargetExclusion(included, excluded) =>
// Compute the included Nodes
val includedNodes = getNodeIds(Set(included), allNodes, groups, allNodesAreThere)
// Compute the excluded Nodes
val excludedNodes = getNodeIds(Set(excluded), allNodes, groups, allNodesAreThere)
// Remove excluded nodes from included nodes
val result = includedNodes -- excludedNodes
nodes ++ result
case TargetIntersection(targets) =>
val nodeSets = targets.map(t => getNodeIds(Set(t), allNodes, groups, allNodesAreThere))
// Compute the intersection of the sets of Nodes
val intersection = nodeSets.foldLeft(allNodes.keySet.toSet) {
case (currentIntersection, nodes) => currentIntersection.intersect(nodes)
}
}
Right(intersection)

case TargetUnion(targets) =>
val nodeSets = targets.map(t => getNodeIds(Set(t), allNodes, groups, allNodesAreThere))
// Compute the union of the sets of Nodes
val union = nodeSets.foldLeft(Set[NodeId]()) { case (currentUnion, nodes) => currentUnion.union(nodes) }
Right(union)

case TargetExclusion(included, excluded) =>
// Compute the included Nodes
val includedNodes = getNodeIds(Set(included), allNodes, groups, allNodesAreThere)
// Compute the excluded Nodes
val excludedNodes = getNodeIds(Set(excluded), allNodes, groups, allNodesAreThere)
// Remove excluded nodes from included nodes
val result = includedNodes -- excludedNodes
Right(result)
}.map(_.toSet.flatten).merge
}

/**
Expand All @@ -319,48 +318,40 @@ object RuleTarget extends Loggable {
groups: Map[NodeGroupId, Chunk[NodeId]]
): Chunk[NodeId] = {

targets.foldLeft(Chunk[NodeId]()) {
case (nodes, target) =>
target match {
case AllTarget => return Chunk.fromIterable(allNodes.keys)
case AllTargetExceptPolicyServers => nodes ++ allNodes.collect { case (k, isPolicyServer) if (!isPolicyServer) => k }
case AllPolicyServers => nodes ++ allNodes.collect { case (k, isPolicyServer) if (isPolicyServer) => k }
case PolicyServerTarget(nodeId) =>
nodes :+ nodeId

// here, if we don't find the group, we consider it's an error in the
// target recording, but don't fail, just log it.
case GroupTarget(groupId) =>
val groupNodes = groups.getOrElse(groupId, Chunk.empty)
val filtered = {
groupNodes
}
nodes ++ filtered
targets.traverse {
case AllTarget => Left(Chunk.fromIterable(allNodes.keys))
case AllTargetExceptPolicyServers => Right(allNodes.collect { case (k, isPolicyServer) if (!isPolicyServer) => k })
case AllPolicyServers => Right(allNodes.collect { case (k, isPolicyServer) if (isPolicyServer) => k })
case PolicyServerTarget(nodeId) => Right(Set(nodeId))

case TargetIntersection(targets) =>
val nodeSets = targets.map(t => getNodeIdsChunkRec(Chunk(t), allNodes, groups))
// Compute the intersection of the sets of Nodes
val intersection = nodeSets.foldLeft(Chunk.fromIterable(allNodes.keys)) {
case (currentIntersection, nodes) => currentIntersection.intersect(nodes)
}
nodes ++ intersection

case TargetUnion(targets) =>
val nodeSets = targets.map(t => getNodeIdsChunkRec(Chunk(t), allNodes, groups))
// Compute the union of the sets of Nodes
val union = nodeSets.foldLeft(Chunk[NodeId]()) { case (currentUnion, nodes) => currentUnion.concat(nodes) }
nodes ++ union

case TargetExclusion(included, excluded) =>
// Compute the included Nodes
val includedNodes = getNodeIdsChunkRec(Chunk(included), allNodes, groups)
// Compute the excluded Nodes
val excludedNodes = getNodeIdsChunkRec(Chunk(excluded), allNodes, groups)
// Remove excluded nodes from included nodes
val result = includedNodes.filterNot(id => excludedNodes.contains(id))
nodes ++ result

// here, if we don't find the group, we consider it's an error in the
// target recording, but don't fail, just log it.
case GroupTarget(groupId) => Right(groups.getOrElse(groupId, Chunk.empty))

case TargetIntersection(targets) =>
val nodeSets = targets.map(t => getNodeIdsChunkRec(Chunk(t), allNodes, groups))
// Compute the intersection of the sets of Nodes
val intersection = nodeSets.foldLeft(Chunk.fromIterable(allNodes.keys)) {
case (currentIntersection, nodes) => currentIntersection.intersect(nodes)
}
}
Right(intersection)

case TargetUnion(targets) =>
val nodeSets = targets.map(t => getNodeIdsChunkRec(Chunk(t), allNodes, groups))
// Compute the union of the sets of Nodes
val union = nodeSets.foldLeft(Chunk[NodeId]()) { case (currentUnion, nodes) => currentUnion.concat(nodes) }
Right(union)

case TargetExclusion(included, excluded) =>
// Compute the included Nodes
val includedNodes = getNodeIdsChunkRec(Chunk(included), allNodes, groups)
// Compute the excluded Nodes
val excludedNodes = getNodeIdsChunkRec(Chunk(excluded), allNodes, groups)
// Remove excluded nodes from included nodes
val result = includedNodes.filterNot(id => excludedNodes.contains(id))
Right(result)
}.map(_.flatten).merge
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,8 @@ object NodePropertyMatcherUtils {
sealed trait ComparatorList {

def comparators: Seq[CriterionComparator]
def comparatorForString(s: String): Option[CriterionComparator] = {
for (comp <- comparators) {
if (s.equalsIgnoreCase(comp.id)) return Some(comp)
}
None
}
def comparatorForString(s: String): Option[CriterionComparator] =
comparators.find(comp => s.equalsIgnoreCase(comp.id))
}

object BaseComparators extends ComparatorList {
Expand Down Expand Up @@ -880,12 +876,8 @@ case class ObjectCriterion(val objectType: String, val criteria: Seq[Criterion])
require(criteria.nonEmpty, "You must at least have one criterion for the line")

// optionally retrieve the criterion from a "string" attribute
def criterionForName(name: String): (Option[Criterion]) = {
for (c <- criteria) {
if (name.equalsIgnoreCase(c.name)) return Some(c)
}
None
}
def criterionForName(name: String): (Option[Criterion]) =
criteria.find(c => name.equalsIgnoreCase(c.name))

def criterionComparatorForName(name: String, comparator: String): (Option[Criterion], Option[CriterionComparator]) = {
criterionForName(name) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ package com.normation.rudder.services.eventlog
import com.normation.box.*
import com.normation.rudder.domain.eventlog.InventoryEventLog
import com.normation.rudder.repository.EventLogRepository
import com.normation.utils.Control
import doobie.*
import net.liftweb.common.*

Expand All @@ -56,14 +57,10 @@ class InventoryEventLogServiceImpl(
.getEventLogByCriteria(Some(Fragment.const(" eventType in ('AcceptNode', 'RefuseNode', 'DeleteNode') ")))
.toBox match {
case Full(seq) =>
val result = scala.collection.mutable.Buffer[InventoryEventLog]()
for (log <- seq) {
log match {
case inventoryLog: InventoryEventLog => result += inventoryLog
case _ => return Failure("Wrong event log type, not an inventory")
}
Control.traverse(seq) {
case inventoryLog: InventoryEventLog => Full(inventoryLog)
case _ => Failure("Wrong event log type, not an inventory")
}
Full(result.toSeq)
case Empty => Empty
case _ => Failure("Could not retrieve eventLogs")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,38 +127,35 @@ trait DefaultStringQueryParser extends StringQueryParser {

def parseLine(line: StringCriterionLine): Box[CriterionLine] = {

val objectType = criterionObjects.getOrElse(
line.objectType,
return Failure(
s"The object type '${line.objectType}' is unknown in line 'line'. Possible object types: [${criterionObjects.keySet.toList.sorted
.mkString(",")}] ".format(line)
)
)

val criterion = objectType.criterionForName(line.attribute).getOrElse {
return Failure(
s"The attribute '${line.attribute}' is unknown for type '${line.objectType}' in line '${line}'. Possible attributes: [${objectType.criteria.map(_.name).sorted.mkString(", ")}]"
)
}

val comparator = criterion.cType.comparatorForString(line.comparator).getOrElse {
return Failure(
s"The comparator '${line.comparator}' is unknown for attribute '${line.attribute}' in line '${line}'. Possible comparators:: [${criterion.cType.comparators.map(_.id).sorted.mkString(", ")}]"
)
}
(for {
objectType <-
criterionObjects
.get(line.objectType)
.toRight(
s"The object type '${line.objectType}' is unknown in line 'line'. Possible object types: [${criterionObjects.keySet.toList.sorted
.mkString(",")}] ".format(line)
)
criterion <- objectType.criterionForName(line.attribute).toRight {
s"The attribute '${line.attribute}' is unknown for type '${line.objectType}' in line '${line}'. Possible attributes: [${objectType.criteria.map(_.name).sorted.mkString(", ")}]"
}

comparator <- criterion.cType.comparatorForString(line.comparator).toRight {
s"The comparator '${line.comparator}' is unknown for attribute '${line.attribute}' in line '${line}'. Possible comparators:: [${criterion.cType.comparators.map(_.id).sorted.mkString(", ")}]"
}

/*
* Only validate the fact that if the comparator requires a value, then a value is provided.
* Providing an error when none is required is not an error
*/
value <- line.value match {
case Some(x) => Right(x)
case None =>
if (comparator.hasValue)
Left("Missing required value for comparator '%s' in line '%s'".format(line.comparator, line))
else Right("")
}
} yield CriterionLine(objectType, criterion, comparator, value)).toBox

/*
* Only validate the fact that if the comparator require a value, then a value is provided.
* Providing an error when none is required is not an error
*/
val value = line.value match {
case Some(x) => x
case None =>
if (comparator.hasValue)
return Failure("Missing required value for comparator '%s' in line '%s'".format(line.comparator, line))
else ""
}
Full(CriterionLine(objectType, criterion, comparator, value))
}

}
Expand Down Expand Up @@ -261,14 +258,6 @@ trait JsonQueryLexer extends QueryLexer {
}

def parseCriterion(json: Any): Box[StringCriterionLine] = {
def failureMissing(param: String, line: Map[String, String]) = Failure(
"Missing expected '%s' query parameter in criterion '%s'".format(OBJECT, line)
)
def failureEmpty(param: String, line: Map[String, String]) = Failure(
"Parameter '%s' must be non empty in criterion '%s'".format(OBJECT, line)
)
def failureBadParam(param: String, line: Map[String, String], x: Any) =
Failure("Bad query format for '%s' parameter in line '%s'. Expecting a string, found '%s'".format(OBJECT, line, x))

json match {
case l: Map[?, ?] =>
Expand All @@ -277,35 +266,22 @@ trait JsonQueryLexer extends QueryLexer {
val line = l.asInstanceOf[Map[String, String]] // is map always homogenous ?
// First, parse the line. Then, try to bind name with object

// mandatory object type, attribute, comparator ; optionnal value
// object type
val objectType = line.get(OBJECT) match {
case None => return failureMissing(OBJECT, line)
case Some(x: String) => if (x.nonEmpty) x else return failureEmpty(OBJECT, line)
case Some(x) => return failureBadParam(OBJECT, line, x)
def getMandatoryAttribute(attribute: String) = {
line.get(attribute) match {
case None => Failure("Missing expected '%s' query parameter in criterion '%s'".format(attribute, line))
case Some(x) if x.nonEmpty => Full(x)
case Some(x) => Failure("Parameter '%s' must be non empty in criterion '%s'".format(attribute, line))
}
}

// attribute
val attribute = line.get(ATTRIBUTE) match {
case None => return failureMissing(ATTRIBUTE, line)
case Some(x: String) => if (x.nonEmpty) x else return failureEmpty(ATTRIBUTE, line)
case Some(x) => return failureBadParam(ATTRIBUTE, line, x)
}
def getOptionalAttribute(attribute: String) = line.get(attribute).filter(_.nonEmpty)

// comparator
val comparator = line.get(COMPARATOR) match {
case None => return failureMissing(COMPARATOR, line)
case Some(x: String) => if (x.nonEmpty) x else return failureEmpty(COMPARATOR, line)
case Some(x) => return failureBadParam(COMPARATOR, line, x)
}

// value
val value = line.get(VALUE) match {
case None => None
case Some(x: String) => if (x.nonEmpty) Some(x) else None
case Some(x) => return failureBadParam(VALUE, line, x)
}
Full(StringCriterionLine(objectType, attribute, comparator, value))
for {
objectType <- getMandatoryAttribute(OBJECT)
attribute <- getMandatoryAttribute(ATTRIBUTE)
comparator <- getMandatoryAttribute(COMPARATOR)
value = getOptionalAttribute(VALUE)
} yield StringCriterionLine(objectType, attribute, comparator, value)

case _ => Failure("Bad query format for criterion line. Expecting an (string,string), found '%s'".format(l.head))
}
Expand Down

0 comments on commit 613dcfe

Please sign in to comment.