Skip to content

Commit

Permalink
#1264 Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Mar 11, 2021
1 parent 68764f5 commit 6066a58
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 126 deletions.
57 changes: 27 additions & 30 deletions thehive/app/org/thp/thehive/services/CaseSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,52 +336,50 @@ class CaseSrv @Inject() (
cases.flatMap(_.tags).distinct
)

val allProfilesOrgas = get(cases.head)
val allProfilesOrgas: Seq[(Profile with Entity, Organisation with Entity)] = get(cases.head)
.shares
.project(_.by(_.profile).by(_.organisation))
.toSeq

for {
user <- userSrv.get(EntityIdOrName(authContext.userId)).getOrFail("User")
orga <- organisationSrv.current.getOrFail("Organisation")
richCase <- create(mergedCase, Some(user), orga, Seq(), None, Seq())
user <- userSrv.current.getOrFail("User")
currentOrga <- organisationSrv.current.getOrFail("Organisation")
richCase <- create(mergedCase, Some(user), currentOrga, Seq(), None, Seq())
// Share case with all organisations except the one who created the merged case
_ <-
allProfilesOrgas
.filterNot(_._2._id == currentOrga._id)
.toTry(profileOrg => shareSrv.shareCase(owner = false, richCase.`case`, profileOrg._2, profileOrg._1))
_ <- cases.toTry { c =>
for {
// Share case with all organisations except the one who created the merged case
_ <-
allProfilesOrgas
.filter(_._2._id != organisationSrv.currentId)
.toTry(profileOrg => shareSrv.shareCase(owner = false, richCase.`case`, profileOrg._2, profileOrg._1))

_ <- shareMergedCaseTasks(allProfilesOrgas.map(_._2), c, richCase.`case`)
_ <- shareMergedCaseObservables(allProfilesOrgas.map(_._2), c, richCase.`case`)
_ <-
get(c)
.alert
.toList
.toSeq
.toTry(alertSrv.alertCaseSrv.create(AlertCase(), _, richCase.`case`))
_ <-
get(c)
.procedure
.toList
.toSeq
.toTry(caseProcedureSrv.create(CaseProcedure(), richCase.`case`, _))
_ <-
get(c)
.richCustomFields
.toList
.toSeq
.toTry(c => createCustomField(richCase.`case`, EntityIdOrName(c.customField.name), c.value, c.order))
} yield Success(())
}
_ = cases.map(remove(_))
_ <- cases.toTry(remove(_))
} yield richCase
} else
Failure(BadRequestError("To be able to merge, cases must have same organisation / profile pair and user must be org-admin"))

private def canMerge(cases: Seq[Case with Entity])(implicit graph: Graph, authContext: AuthContext): Boolean = {
val allOrgProfiles = getByIds(cases.map(_._id): _*)
.shares
.project(_.by(_.profile.value(_.name)).by(_.organisation._id))
.fold
.flatMap(_.shares.project(_.by(_.profile.value(_.name)).by(_.organisation._id)).fold)
.toSeq
.map(_.toSet)
.distinct
Expand All @@ -390,41 +388,40 @@ class CaseSrv @Inject() (
// case organisation must match current organisation and be of org-admin profile
allOrgProfiles.size == 1 && allOrgProfiles
.head
.find(_._2 == organisationSrv.currentId)
.map(_._1)
.contains(Profile.orgAdmin.name)
.exists {
case (profile, orgId) => orgId == organisationSrv.currentId && profile == Profile.orgAdmin.name
}
}

private def shareMergedCaseTasks(orgs: Seq[Organisation with Entity], fromCase: Case with Entity, mergedCase: Case with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] =
for {
_ <- orgs.toTry(org =>
orgs
.toTry { org =>
get(fromCase)
.share(org._id)
.tasks
.richTask
.toList
.toSeq
.toTry(shareSrv.shareTask(_, mergedCase, org._id))
)
} yield Success()
}
.map(_ => ())

private def shareMergedCaseObservables(orgs: Seq[Organisation with Entity], fromCase: Case with Entity, mergedCase: Case with Entity)(implicit
graph: Graph,
authContext: AuthContext
): Try[Unit] =
for {
_ <- orgs.toTry(org =>
orgs
.toTry { org =>
get(fromCase)
.share(org._id)
.observables
.richObservable
.toList
.toSeq
.toTry(shareSrv.shareObservable(_, mergedCase, org._id))
)
} yield Success()

}
.map(_ => ())
}

object CaseOps {
Expand Down
171 changes: 84 additions & 87 deletions thehive/test/org/thp/thehive/services/CaseSrvTest.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package org.thp.thehive.services

import org.specs2.matcher.Matcher
import org.thp.scalligraph.EntityName
import org.thp.scalligraph.auth.AuthContext
import org.thp.scalligraph.controllers.FPathElem
import org.thp.scalligraph.models._
import org.thp.scalligraph.query.PropertyUpdater
import org.thp.scalligraph.traversal.{Graph, Traversal}
import org.thp.scalligraph.traversal.TraversalOps._
import org.thp.scalligraph.{BadRequestError, EntityName}
import org.thp.thehive.TestAppBuilder
import org.thp.thehive.models._
import org.thp.thehive.services.CaseOps._
import org.thp.thehive.services.ShareOps._
import play.api.libs.json.Json
import play.api.test.PlaySpecification

Expand Down Expand Up @@ -246,41 +248,41 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder {
}

"add an observable if not existing" in testApp { app =>
// app[Database].roTransaction { implicit graph =>
// val c1 = app[CaseSrv].get(EntityName("1")).getOrFail("Case").get
// val observables = app[ObservableSrv].startTraversal.richObservable.toList
//
// observables must not(beEmpty)
//
// val hfr = observables.find(_.message.contains("Some weird domain")).get
//
// app[Database].tryTransaction { implicit graph =>
//// app[CaseSrv].addObservable(c1, hfr)
// app[CaseSrv].createObservable(c1, hfr, hfr.data.get)
// }.get must throwA[CreateError]
//
// val newObs = app[Database].tryTransaction { implicit graph =>
// val organisation = app[OrganisationSrv].current.getOrFail("Organisation").get
// app[ObservableSrv].create(
// Observable(
// message = Some("if you feel lost"),
// tlp = 1,
// ioc = false,
// sighted = true,
// ignoreSimilarity = None,
// dataType = "domain",
// tags = Nil,
// organisationIds = Seq(organisation._id),
// relatedId = c1._id
// ),
// "lost.com"
// )
// }.get
//
// app[Database].tryTransaction { implicit graph =>
// app[CaseSrv].addObservable(c1, newObs)
// } must beSuccessfulTry
// }
// app[Database].roTransaction { implicit graph =>
// val c1 = app[CaseSrv].get(EntityName("1")).getOrFail("Case").get
// val observables = app[ObservableSrv].startTraversal.richObservable.toList
//
// observables must not(beEmpty)
//
// val hfr = observables.find(_.message.contains("Some weird domain")).get
//
// app[Database].tryTransaction { implicit graph =>
//// app[CaseSrv].addObservable(c1, hfr)
// app[CaseSrv].createObservable(c1, hfr, hfr.data.get)
// }.get must throwA[CreateError]
//
// val newObs = app[Database].tryTransaction { implicit graph =>
// val organisation = app[OrganisationSrv].current.getOrFail("Organisation").get
// app[ObservableSrv].create(
// Observable(
// message = Some("if you feel lost"),
// tlp = 1,
// ioc = false,
// sighted = true,
// ignoreSimilarity = None,
// dataType = "domain",
// tags = Nil,
// organisationIds = Seq(organisation._id),
// relatedId = c1._id
// ),
// "lost.com"
// )
// }.get
//
// app[Database].tryTransaction { implicit graph =>
// app[CaseSrv].addObservable(c1, newObs)
// } must beSuccessfulTry
// }
pending
}

Expand Down Expand Up @@ -442,24 +444,26 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder {
}

"show linked cases" in testApp { app =>
// app[Database].roTransaction { implicit graph =>
// app[CaseSrv].get(EntityName("1")).linkedCases must beEmpty
// val observables = app[ObservableSrv].startTraversal.richObservable.toList
// val hfr = observables.find(_.message.contains("Some weird domain")).get
//
// app[Database].tryTransaction { implicit graph =>
// app[CaseSrv].addObservable(app[CaseSrv].get(EntityName("2")).getOrFail("Case").get, hfr)
// }
//
// app[Database].roTransaction(implicit graph => app[CaseSrv].get(EntityName("1")).linkedCases must not(beEmpty))
// }
// app[Database].roTransaction { implicit graph =>
// app[CaseSrv].get(EntityName("1")).linkedCases must beEmpty
// val observables = app[ObservableSrv].startTraversal.richObservable.toList
// val hfr = observables.find(_.message.contains("Some weird domain")).get
//
// app[Database].tryTransaction { implicit graph =>
// app[CaseSrv].addObservable(app[CaseSrv].get(EntityName("2")).getOrFail("Case").get, hfr)
// }
//
// app[Database].roTransaction(implicit graph => app[CaseSrv].get(EntityName("1")).linkedCases must not(beEmpty))
// }
pending
}

"merge cases, happy path with one organisation" in testApp { app =>
app[Database].tryTransaction { implicit graph =>
def case21 = app[CaseSrv].get(EntityName("21")).clone()

def case22 = app[CaseSrv].get(EntityName("22")).clone()

def case23 = app[CaseSrv].get(EntityName("23")).clone()
// Procedures
case21.procedure.toSeq.size mustEqual 1
Expand Down Expand Up @@ -491,6 +495,7 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder {
} must beASuccessfulTry.which { richCase =>
app[Database].roTransaction { implicit graph =>
def mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)).clone()

mergedCase.procedure.toSeq.size mustEqual 3
mergedCase.customFields.toSeq.size mustEqual 2
mergedCase.tasks.toSeq.size mustEqual 3
Expand All @@ -504,58 +509,50 @@ class CaseSrvTest extends PlaySpecification with TestAppBuilder {
}
}

"refuse to merge cases with different shares" in testApp { app =>
app[Database].tryTransaction { implicit graph =>
val case21 = app[CaseSrv].getOrFail(EntityName("21")).get
val case24 = app[CaseSrv].getOrFail(EntityName("24")).get
val case26 = app[CaseSrv].getOrFail(EntityName("26")).get
app[CaseSrv].merge(Seq(case21, case24, case26))
} must beFailedTry.withThrowable[BadRequestError]
}

"merge cases, happy path with three organisations" in testApp { app =>
implicit val authContext: AuthContext =
DummyUserSrv(organisation = "soc", permissions = Profile.analyst.permissions).authContext

def getCase(number: Int)(implicit graph: Graph): Traversal.V[Case] = app[CaseSrv].getByName(number.toString)

app[Database].tryTransaction { implicit graph =>
def case21 = app[CaseSrv].get(EntityName("21")).clone()
def case24 = app[CaseSrv].get(EntityName("24")).clone()
def case26 = app[CaseSrv].get(EntityName("26")).clone()
// Tasks
case21.tasks.toSeq.size mustEqual 2
case24.tasks.toSeq.size mustEqual 0
case26.tasks.toSeq.size mustEqual 0
getCase(24).share(EntityName("cert")).tasks.getCount mustEqual 1
getCase(24).share(EntityName("soc")).tasks.getCount mustEqual 2
getCase(25).share(EntityName("cert")).tasks.getCount mustEqual 0
getCase(25).share(EntityName("soc")).tasks.getCount mustEqual 0

// Observables
case21.observables.toSeq.size mustEqual 1
case24.observables.toSeq.size mustEqual 0
case26.observables.toSeq.size mustEqual 0
getCase(24).share(EntityName("cert")).observables.getCount mustEqual 0
getCase(24).share(EntityName("soc")).observables.getCount mustEqual 0
getCase(25).share(EntityName("cert")).observables.getCount mustEqual 2
getCase(25).share(EntityName("soc")).observables.getCount mustEqual 1

for {
c21 <- case21.getOrFail("Case")
c24 <- case24.getOrFail("Case")
c26 <- case26.getOrFail("Case")
newCase <- app[CaseSrv].merge(Seq(c21, c24, c26))
c24 <- getCase(24).getOrFail("Case")
c25 <- getCase(25).getOrFail("Case")
newCase <- app[CaseSrv].merge(Seq(c24, c25))
} yield newCase
} must beASuccessfulTry.which { richCase =>
app[Database].roTransaction { implicit graph =>
def mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)).clone()
mergedCase.tasks.toSeq.size mustEqual 2
mergedCase.observables.toSeq.size mustEqual 1

app[CaseSrv].get(EntityName("21")).getOrFail("Case") must beAFailedTry
app[CaseSrv].get(EntityName("24")).getOrFail("Case") must beAFailedTry
app[CaseSrv].get(EntityName("26")).getOrFail("Case") must beAFailedTry
}
getCase(richCase.number).share(EntityName("cert")).tasks.getCount mustEqual 1
getCase(richCase.number).share(EntityName("soc")).tasks.getCount mustEqual 2
getCase(richCase.number).share(EntityName("cert")).observables.getCount mustEqual 2
getCase(richCase.number).share(EntityName("soc")).observables.getCount mustEqual 1

app[Database].roTransaction { implicit graph =>
implicit val authContext: AuthContext =
DummyUserSrv(userId = "socuser@thehive.local", organisation = "soc", permissions = Profile.analyst.permissions).authContext

def mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)).clone()
mergedCase.getOrFail("Case") must beASuccessfulTry
mergedCase.tasks.toSeq.size mustEqual 1
mergedCase.observables.toSeq.size mustEqual 1
}

app[Database].roTransaction { implicit graph =>
implicit val authContext: AuthContext =
DummyUserSrv(userId = "puguser@thehive.local", organisation = "pug", permissions = Profile.analyst.permissions).authContext

def mergedCase = app[CaseSrv].get(EntityName(richCase.number.toString)).clone()
mergedCase.getOrFail("Case") must beASuccessfulTry
mergedCase.tasks.toSeq.size mustEqual 0
mergedCase.observables.toSeq.size mustEqual 0
getCase(24).getOrFail("Case") must beAFailedTry
getCase(25).getOrFail("Case") must beAFailedTry
}
}
}

}
}
20 changes: 20 additions & 0 deletions thehive/test/resources/data/Observable.json
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,25 @@
"data": "127.0.0.1",
"sighted": true,
"relatedId": ""
},
{
"id": "mergeObs251",
"message": "merge Obs 251",
"tlp": 4,
"ioc": true,
"dataType": "ip",
"data": "127.0.0.1",
"sighted": true,
"relatedId": ""
},
{
"id": "mergeObs252",
"message": "merge Obs 252",
"tlp": 4,
"ioc": true,
"dataType": "ip",
"data": "127.0.0.1",
"sighted": true,
"relatedId": ""
}
]
1 change: 0 additions & 1 deletion thehive/test/resources/data/ObservableType.json

This file was deleted.

3 changes: 2 additions & 1 deletion thehive/test/resources/data/OrganisationShare.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
{"from": "cert", "to": "case24-merge-cert"},
{"from": "soc", "to": "case24-merge-soc"},
{"from": "cert", "to": "case25-merge-cert"},
{"from": "soc", "to": "case25-merge-soc"},
{"from": "pug", "to": "case26-merge-pug"}
]
]
Loading

0 comments on commit 6066a58

Please sign in to comment.