This repository has been archived by the owner on Apr 21, 2023. It is now read-only.

Spot-128 Make every pipeline return results to the main program #47

Merged
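In short: each analysis pipeline (flow, DNS, proxy) now returns its results to the caller instead of persisting them itself, and `SuspiciousConnects` takes over saving, result-file merging, and invalid-record reporting. A toy sketch of the shape of the refactor (illustrative only; these types and names are not part of the diff):

```scala
// Toy illustration: pipelines compute and return data, main owns the I/O.
object RefactorSketch {
  case class Results(output: Seq[String], rejected: Seq[String])

  // Was: the pipeline saved its own output as a side effect.
  // Now: it simply returns both result sets to the caller.
  def pipeline(input: Seq[String]): Results =
    Results(input.filter(_.nonEmpty), input.filter(_.isEmpty))

  def main(args: Array[String]): Unit = {
    val Results(output, rejected) = pipeline(Seq("a", "", "b"))
    output.foreach(println)                                  // side effects live here
    rejected.foreach(r => Console.err.println(s"rejected: $r"))
  }
}
```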
38 changes: 31 additions & 7 deletions spot-ml/src/main/scala/org/apache/spot/SuspiciousConnects.scala
@@ -18,13 +18,14 @@
 package org.apache.spot
 
 import org.apache.log4j.{Level, LogManager, Logger}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spot.SuspiciousConnectsArgumentParser.SuspiciousConnectsConfig
 import org.apache.spot.dns.DNSSuspiciousConnectsAnalysis
 import org.apache.spot.netflow.FlowSuspiciousConnectsAnalysis
 import org.apache.spot.proxy.ProxySuspiciousConnectsAnalysis
 import org.apache.spot.utilities.data.InputOutputDataHandler
+import org.apache.spot.utilities.data.validation.InvalidDataHandler
 
 
 /**
@@ -68,14 +69,30 @@ object SuspiciousConnects {
       System.exit(0)
     }
 
-    analysis match {
-      case "flow" => FlowSuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger, inputDataFrame)
-      case "dns" => DNSSuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger, inputDataFrame)
-      case "proxy" => ProxySuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger, inputDataFrame)
-      case _ => logger.error("Unsupported (or misspelled) analysis: " + analysis)
+    val results: Option[SuspiciousConnectsAnalysisResults] = analysis match {
+      case "flow" => Some(FlowSuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger,
+        inputDataFrame))
+      case "dns" => Some(DNSSuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger,
+        inputDataFrame))
+      case "proxy" => Some(ProxySuspiciousConnectsAnalysis.run(config, sparkContext, sqlContext, logger,
+        inputDataFrame))
+      case _ => None
     }
 
-    InputOutputDataHandler.mergeResultsFiles(sparkContext, config.hdfsScoredConnect, analysis, logger)
+    results match {
+      case Some(SuspiciousConnectsAnalysisResults(resultRecords, invalidRecords)) => {
+
+        logger.info(s"$analysis suspicious connects analysis completed.")
+        logger.info("Saving results to : " + config.hdfsScoredConnect)
+        resultRecords.map(_.mkString(config.outputDelimiter)).saveAsTextFile(config.hdfsScoredConnect)
+
+        InputOutputDataHandler.mergeResultsFiles(sparkContext, config.hdfsScoredConnect, analysis, logger)
+
+        InvalidDataHandler.showAndSaveInvalidRecords(invalidRecords, config.hdfsScoredConnect, logger)
+      }
+
+      case None => logger.error("Unsupported (or misspelled) analysis: " + analysis)
+    }
 
     sparkContext.stop()
 
@@ -85,5 +102,12 @@ object SuspiciousConnects {
     System.exit(0)
   }
 
+  /**
+    * Results of a suspicious connects analysis.
+    * @param suspiciousConnects scored records deemed suspicious.
+    * @param invalidRecords records that failed input validation.
+    */
+  case class SuspiciousConnectsAnalysisResults(val suspiciousConnects: DataFrame, val invalidRecords: DataFrame)
+
+
 }
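One payoff of the new contract, sketched below: a caller such as a unit test can run a pipeline end to end and inspect the returned DataFrames without an HDFS round-trip. This is a hypothetical caller; the fixtures `testConfig`, `testLogger`, and `testDF` are assumed names, not part of this diff:

```scala
// Hypothetical test-style caller (Spark context and fixtures assumed in scope).
import org.apache.spot.SuspiciousConnects.SuspiciousConnectsAnalysisResults
import org.apache.spot.dns.DNSSuspiciousConnectsAnalysis

val results: SuspiciousConnectsAnalysisResults =
  DNSSuspiciousConnectsAnalysis.run(testConfig, sparkContext, sqlContext, testLogger, testDF)

results.suspiciousConnects.show(10)         // top-scored records, no filesystem I/O
assert(results.invalidRecords.count() == 0) // invalid rows also surface to the caller
```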
spot-ml/src/main/scala/org/apache/spot/dns/DNSSuspiciousConnectsAnalysis.scala
@@ -22,6 +22,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spot.SuspiciousConnects.SuspiciousConnectsAnalysisResults
 import org.apache.spot.SuspiciousConnectsArgumentParser.SuspiciousConnectsConfig
 import org.apache.spot.dns.DNSSchema._
 import org.apache.spot.dns.model.DNSSuspiciousConnectsModel
@@ -48,58 +49,38 @@ object DNSSuspiciousConnectsAnalysis {
    * @param logger
    */
   def run(config: SuspiciousConnectsConfig, sparkContext: SparkContext, sqlContext: SQLContext, logger: Logger,
-          inputDNSRecords: DataFrame) = {
+          inputDNSRecords: DataFrame): SuspiciousConnectsAnalysisResults = {
 
 
     logger.info("Starting DNS suspicious connects analysis.")
 
-    val cleanDNSRecords = filterAndSelectCleanDNSRecords(inputDNSRecords)
+    val dnsRecords = filterRecords(inputDNSRecords)
+      .select(InSchema: _*)
+      .na.fill(DefaultQueryClass, Seq(QueryClass))
+      .na.fill(DefaultQueryType, Seq(QueryType))
+      .na.fill(DefaultQueryResponseCode, Seq(QueryResponseCode))
 
-    val scoredDNSRecords = scoreDNSRecords(cleanDNSRecords, config, sparkContext, sqlContext, logger)
+    logger.info("Fitting probabilistic model to data")
+    val model =
+      DNSSuspiciousConnectsModel.trainModel(sparkContext, sqlContext, logger, config, dnsRecords)
 
-    val filteredDNSRecords = filterScoredDNSRecords(scoredDNSRecords, config.threshold)
+    logger.info("Identifying outliers")
+    val scoredDNSRecords = model.score(sparkContext, sqlContext, dnsRecords, config.userDomain)
 
-    val orderedDNSRecords = filteredDNSRecords.orderBy(Score)
+    val filteredScored = filterScoredRecords(scoredDNSRecords, config.threshold)
 
-    val mostSuspiciousDNSRecords = if(config.maxResults > 0) orderedDNSRecords.limit(config.maxResults) else orderedDNSRecords
+    val orderedDNSRecords = filteredScored.orderBy(Score)
 
-    val outputDNSRecords = mostSuspiciousDNSRecords.select(OutSchema:_*).sort(Score)
+    val mostSuspiciousDNSRecords =
+      if (config.maxResults > 0) orderedDNSRecords.limit(config.maxResults)
+      else orderedDNSRecords
 
-    logger.info("DNS suspicious connects analysis completed.")
-    logger.info("Saving results to : " + config.hdfsScoredConnect)
+    val outputDNSRecords = mostSuspiciousDNSRecords.select(OutSchema: _*)
 
-    outputDNSRecords.map(_.mkString(config.outputDelimiter)).saveAsTextFile(config.hdfsScoredConnect)
+    val invalidDNSRecords = filterInvalidRecords(inputDNSRecords).select(InSchema: _*)
 
-    val invalidDNSRecords = filterAndSelectInvalidDNSRecords(inputDNSRecords)
-    dataValidation.showAndSaveInvalidRecords(invalidDNSRecords, config.hdfsScoredConnect, logger)
+    SuspiciousConnectsAnalysisResults(outputDNSRecords, invalidDNSRecords)
 
-    val corruptDNSRecords = filterAndSelectCorruptDNSRecords(scoredDNSRecords)
-    dataValidation.showAndSaveCorruptRecords(corruptDNSRecords, config.hdfsScoredConnect, logger)
   }
 
 
-  /**
-    * Identify anomalous DNS log entries in the provided data frame.
-    *
-    * @param data Data frame of DNS entries
-    * @param config
-    * @param sparkContext
-    * @param sqlContext
-    * @param logger
-    * @return
-    */
-
-  def scoreDNSRecords(data: DataFrame, config: SuspiciousConnectsConfig,
-                      sparkContext: SparkContext,
-                      sqlContext: SQLContext,
-                      logger: Logger): DataFrame = {
-
-    logger.info("Fitting probabilistic model to data")
-    val model =
-      DNSSuspiciousConnectsModel.trainNewModel(sparkContext, sqlContext, logger, config, data, config.topicCount)
-
-    logger.info("Identifying outliers")
-    model.score(sparkContext, sqlContext, data, config.userDomain)
-  }
 
 
@@ -108,13 +89,13 @@ object DNSSuspiciousConnectsAnalysis {
    * @param inputDNSRecords raw DNS records.
    * @return
    */
-  def filterAndSelectCleanDNSRecords(inputDNSRecords: DataFrame): DataFrame ={
+  def filterRecords(inputDNSRecords: DataFrame): DataFrame = {
 
     val cleanDNSRecordsFilter = inputDNSRecords(Timestamp).isNotNull &&
       inputDNSRecords(Timestamp).notEqual("") &&
       inputDNSRecords(Timestamp).notEqual("-") &&
-      inputDNSRecords(UnixTimestamp).isNotNull &&
-      inputDNSRecords(FrameLength).isNotNull &&
+      inputDNSRecords(UnixTimestamp).geq(0) &&
+      inputDNSRecords(FrameLength).geq(0) &&
       inputDNSRecords(QueryName).isNotNull &&
       inputDNSRecords(QueryName).notEqual("") &&
       inputDNSRecords(QueryName).notEqual("-") &&
@@ -126,14 +107,10 @@ object DNSSuspiciousConnectsAnalysis {
       inputDNSRecords(QueryClass).notEqual("") &&
       inputDNSRecords(QueryClass).notEqual("-")) ||
       inputDNSRecords(QueryType).isNotNull ||
-      inputDNSRecords(QueryResponseCode).isNotNull)
+      inputDNSRecords(QueryResponseCode).geq(0))
 
     inputDNSRecords
       .filter(cleanDNSRecordsFilter)
-      .select(InSchema: _*)
-      .na.fill(DefaultQueryClass, Seq(QueryClass))
-      .na.fill(DefaultQueryType, Seq(QueryType))
-      .na.fill(DefaultQueryResponseCode, Seq(QueryResponseCode))
   }
 
 
@@ -142,7 +119,7 @@ object DNSSuspiciousConnectsAnalysis {
    * @param inputDNSRecords raw DNS records.
    * @return
    */
-  def filterAndSelectInvalidDNSRecords(inputDNSRecords: DataFrame): DataFrame ={
+  def filterInvalidRecords(inputDNSRecords: DataFrame): DataFrame = {
 
     val invalidDNSRecordsFilter = inputDNSRecords(Timestamp).isNull ||
       inputDNSRecords(Timestamp).equalTo("") ||
@@ -164,7 +141,6 @@ object DNSSuspiciousConnectsAnalysis {
 
     inputDNSRecords
       .filter(invalidDNSRecordsFilter)
-      .select(InSchema: _*)
   }
 
 
@@ -174,7 +150,7 @@ object DNSSuspiciousConnectsAnalysis {
    * @param threshold score tolerance.
    * @return
    */
-  def filterScoredDNSRecords(scoredDNSRecords: DataFrame, threshold: Double): DataFrame ={
+  def filterScoredRecords(scoredDNSRecords: DataFrame, threshold: Double): DataFrame = {
 
 
     val filteredDNSRecordsFilter = scoredDNSRecords(Score).leq(threshold) &&
@@ -183,22 +159,6 @@ object DNSSuspiciousConnectsAnalysis {
     scoredDNSRecords.filter(filteredDNSRecordsFilter)
   }
 
-  /**
-    *
-    * @param scoredDNSRecords scored DNS records.
-    * @return
-    */
-  def filterAndSelectCorruptDNSRecords(scoredDNSRecords: DataFrame): DataFrame = {
-
-    val corruptDNSRecordsFilter = scoredDNSRecords(Score).equalTo(dataValidation.ScoreError)
-
-    scoredDNSRecords
-      .filter(corruptDNSRecordsFilter)
-      .select(OutSchema: _*)
-
-  }
 
 
   val DefaultQueryClass = "unknown"
   val DefaultQueryType = -1
   val DefaultQueryResponseCode = -1
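A note on the filter pair above: `filterRecords` and `filterInvalidRecords` are written as near-complements over the raw input, so every record is either scored or reported as invalid. A self-contained sketch of that invariant on a toy schema (illustrative only; these column names are made up, not `DNSSchema`):

```scala
// Illustrative partition of a DataFrame into clean and invalid records.
import org.apache.spark.sql.DataFrame

def partitionRecords(df: DataFrame): (DataFrame, DataFrame) = {
  val clean = df("unixTs").geq(0) && df("frameLen").geq(0) && df("queryName").isNotNull
  (df.filter(clean), df.filter(!clean)) // invalid side is the exact negation
}

// Invariant worth asserting in a test: nothing is silently dropped.
// val (good, bad) = partitionRecords(input)
// assert(good.count() + bad.count() == input.count())
```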
spot-ml/src/main/scala/org/apache/spot/dns/model/DNSSuspiciousConnectsModel.scala
@@ -158,15 +158,13 @@ object DNSSuspiciousConnectsModel {
    * @param config Analysis configuration object containing CLI parameters.
    *               Contains the path to the feedback file in config.scoresFile
    * @param inputRecords Data used to train the model.
-   * @param topicCount Number of topics (traffic profiles) used to build the model.
    * @return A new [[DNSSuspiciousConnectsModel]] instance trained on the dataframe and feedback file.
    */
-  def trainNewModel(sparkContext: SparkContext,
-                    sqlContext: SQLContext,
-                    logger: Logger,
-                    config: SuspiciousConnectsConfig,
-                    inputRecords: DataFrame,
-                    topicCount: Int): DNSSuspiciousConnectsModel = {
+  def trainModel(sparkContext: SparkContext,
+                 sqlContext: SQLContext,
+                 logger: Logger,
+                 config: SuspiciousConnectsConfig,
+                 inputRecords: DataFrame): DNSSuspiciousConnectsModel = {
 
     logger.info("Training DNS suspicious connects model from " + config.inputPath)
 
@@ -297,7 +295,7 @@ object DNSSuspiciousConnectsModel {
       .toMap
 
 
-    new DNSSuspiciousConnectsModel(topicCount,
+    new DNSSuspiciousConnectsModel(config.topicCount,
       ipToTopicMix,
       wordToPerTopicProb,
       timeCuts,
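Finally, on `trainNewModel` becoming `trainModel`: the topic count is now read from `config.topicCount` inside the model, so call sites pass one argument fewer and cannot drift out of sync with the configuration. A before/after sketch (the names `sc` and `df` are assumed stand-ins):

```scala
// Before: the topic count was passed alongside the config that already held it.
// val model = DNSSuspiciousConnectsModel.trainNewModel(sc, sqlContext, logger, config, df, config.topicCount)

// After: a single source of truth for the topic count.
val model = DNSSuspiciousConnectsModel.trainModel(sc, sqlContext, logger, config, df)
val scoredRecords = model.score(sc, sqlContext, df, config.userDomain)
```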