# Initialisation

In [None]:
import $ivy.`org.apache.spark::spark-sql:3.0.1`
import $ivy.`org.apache.spark::spark-mllib:3.0.1`
import $ivy.`org.scalanlp::breeze:1.1`
import $ivy.`org.scalanlp::breeze-natives:1.1`
import $ivy.`org.postgresql:postgresql:42.2.5`
import $ivy.`org.plotly-scala::plotly-almond:0.7.6`

In [None]:
// Resolve the mulot jar under ./lib (relative to the working directory)
// and add it to the interpreter's classpath.
val currentDirectory = new java.io.File(".").getCanonicalPath
val path = java.nio.file.Paths.get(s"$currentDirectory/lib/mulot_2.12-0.3.jar")
val x = ammonite.ops.Path(path)
interp.load.cp(x)

In [None]:
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.functions._
import mulot.Tensor
import mulot.tensordecomposition._
import mulot.tensordecomposition.CPALS._

In [None]:
// Spark session picked up implicitly by the functions defined in the cells below.
// The type is spelled out: an implicit without an explicit type is fragile for
// implicit resolution (and is rejected by newer Scala versions).
implicit val spark: SparkSession = {
    // NOTE(review): the Spark configuration docs state that spark.driver.memory cannot be
    // changed through the session config once the driver JVM is running; with a local
    // master these settings may therefore have no effect - confirm the kernel JVM options.
    val maxMemory = "64g"
    SparkSession.builder()
        .appName("Strata")
        .master("local[40]")
        .config("spark.executor.memory", maxMemory)
        .config("spark.driver.memory", maxMemory)
        .getOrCreate()
}

// Only log errors, and set the directory used for checkpoints
spark.sparkContext.setLogLevel("ERROR")
spark.sparkContext.setCheckpointDir("Checkpoint")
import spark.implicits._

# Stratification
To compute the stratification, two functions are needed:
* One to find the best rank of the CP decomposition (according to CORCONDIA)
* One to perform the stratification: the best rank is used, the elements of each resulting rank are kept only if they are higher than the average value of their vector, and the clusters are removed from the tensor (deflation) so that the method can be iterated

In [None]:
/**
 * Result of the rank search: a decomposition together with the rank it was computed for.
 *
 * @param kruskal the decomposition that was retained
 * @param rank    the rank that was requested to obtain `kruskal`
 */
case class BestDecomposition(kruskal: Kruskal, rank: Int)

/**
 * Searches for the rank whose CP decomposition has the best core consistency (CORCONDIA).
 *
 * The search proceeds in three phases:
 *  1. Rank 2 and `hintRank` are decomposed. While the score is not NaN and either does not
 *     decrease or stays at 80 or above, the rank is increased by steps of `hintRank`.
 *  2. The interval between the last accepted rank and the first rejected one is bisected.
 *  3. If the retained rank is still 2 with a score that is NaN or below 99, rank 1 is used.
 *
 * @param tensor        the [[Tensor]] on which to perform the decomposition
 * @param norm          forwarded unchanged to `CPALS.compute`
 * @param maxIterations forwarded unchanged to `CPALS.compute`
 * @param minFms        forwarded unchanged to `CPALS.compute`
 * @param highRank      forwarded unchanged to `CPALS.compute`
 * @param hintRank      first rank tried after rank 2, and step used while increasing the rank
 * @param spark         the active Spark session
 * @return the retained decomposition and its rank
 */
def findBestDecomposition(tensor: Tensor, norm: String = NORM_L1, maxIterations: Int = 25, minFms: Double = 0.99,
                          highRank: Option[Boolean] = None, hintRank: Int = 20)
                          (implicit spark: SparkSession): BestDecomposition = {
    // Score from which a rank is considered acceptable even if it is not an improvement
    val acceptableScore = 80

    def decompose(rank: Int) = CPALS.compute(tensor, rank, norm, maxIterations, minFms, highRank, true)
    def attempt(rank: Int) = {
        println(s"Try rank $rank")
        decompose(rank)
    }
    def score(result: Kruskal) = result.corcondia.get

    var lowerBound = 2
    var upperBound = Integer.MAX_VALUE
    var rankUnderTest = 2

    var earlier = attempt(rankUnderTest)
    var best = BestDecomposition(earlier, rankUnderTest)
    var latest = attempt(hintRank)

    if (score(latest) >= score(best.kruskal) || score(latest) >= acceptableScore) {
        // The hint is not beyond the best rank: grow the rank by steps of hintRank
        best = BestDecomposition(latest, hintRank)
        rankUnderTest = hintRank
        while (!score(latest).isNaN &&
                (score(latest) >= score(earlier) || score(latest) >= acceptableScore)) {
            lowerBound = rankUnderTest
            rankUnderTest += hintRank
            earlier = latest
            latest = attempt(rankUnderTest)
            if (score(latest) >= score(best.kruskal)) {
                best = BestDecomposition(latest, rankUnderTest)
            }
        }
        // The first rejected rank bounds the search from above
        upperBound = rankUnderTest
    } else {
        // The hint is already beyond the best rank: it bounds the search from above
        upperBound = hintRank
        rankUnderTest = hintRank
    }

    // Bisect [lowerBound, upperBound] until the bounds are adjacent
    while ((upperBound - lowerBound) > 1) {
        val rankBefore = rankUnderTest
        rankUnderTest = lowerBound + ((upperBound - lowerBound) / 2)
        earlier = latest
        latest = attempt(rankUnderTest)

        val current = score(latest)
        if (current.isNaN || current < 0) {
            upperBound = rankUnderTest
        } else if (current >= acceptableScore) {
            lowerBound = rankUnderTest
        } else {
            val movedDown = rankBefore > rankUnderTest
            val notWorse = score(earlier).isNaN || current >= score(earlier)
            // Not worse after moving down, or worse after moving up: the rank is too high.
            // Otherwise the rank is too low.
            if (notWorse == movedDown) {
                upperBound = rankUnderTest
            } else {
                lowerBound = rankUnderTest
            }
        }

        if (current >= score(best.kruskal) ||
            (current >= acceptableScore && rankUnderTest > best.rank)) {
            best = BestDecomposition(latest, rankUnderTest)
        }
        println(s"Min rank: $lowerBound, max rank: $upperBound")
    }

    // Check if best rank is 1
    if (best.rank == 2 && (score(best.kruskal).isNaN || score(best.kruskal) < 99)) {
        println("Choose rank 1")
        best = BestDecomposition(decompose(1), 1)
    }

    best
}

In [None]:
import breeze.linalg.min
import breeze.numerics.abs
import breeze.stats.mean
import mulot.Tensor
import mulot.tensordecomposition.CPALS
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

/**
 * Stratification of a tensor: communities are extracted from a CP decomposition, removed
 * from the tensor (deflation), and the process is repeated on what remains.
 */
object Strata {
    /**
     * One community, extracted from a single component of the decomposition.
     *
     * @param data for each dimension name, the retained elements with their absolute factor value
     * @param size for each dimension name, the number of retained elements
     */
    case class Community(data: Map[String, DataFrame], size: Map[String, Int])

    /**
     * The communities found during one pass over the tensor.
     *
     * @param communities one community per component of the decomposition chosen for the pass
     * @param depth       zero-based index of the pass
     */
    case class Stratum(communities: List[Community], depth: Int)

    /**
     * Computes `nbStrata` successive strata of communities.
     *
     * For each stratum the rank is chosen by `findBestDecomposition`. For every component, an
     * element of a non-static dimension is kept when its absolute factor value is at least
     * the mean minus the minimum of the absolute factor vector; every element of a static
     * dimension is kept. The entries whose non-static coordinates all belong to a community
     * are then removed from the tensor before the next stratum is computed.
     *
     * @param _tensor          the tensor to stratify
     * @param nbStrata         number of strata to compute
     * @param staticDimensions names of the dimensions that are never filtered and never used to
     *                         select the entries to remove; at least one dimension of the tensor
     *                         must not be listed here
     * @param norm             forwarded to `findBestDecomposition`
     * @param spark            the active Spark session
     * @return the strata, in the order in which they were found
     * @throws IllegalArgumentException if every dimension of the tensor is static
     */
    def compute(_tensor: Tensor, nbStrata: Int = 2, staticDimensions: List[String] = List[String](),
                norm: String = CPALS.NORM_L1)
               (implicit spark: SparkSession): List[Stratum] = {
        val _valueColumnName = _tensor.valueColumnName
        // Schema of a community DataFrame before its index column is given the dimension name
        val memberSchema = StructType(Array(
            StructField("dimIndex", LongType, nullable = true),
            StructField("value", DoubleType, nullable = true)))

        var tensor = _tensor
        var strata = List[Stratum]()
        var numStratum = 0

        while (numStratum < nbStrata) {
            val begin = System.currentTimeMillis()

            // Positions of the dimensions used to select the entries to remove. Checked
            // before the (expensive) decomposition so that a bad call fails immediately.
            val dynamicDimensions = (0 until tensor.order)
                .filterNot(i => staticDimensions.contains(tensor.dimensionsName(i)))
            require(dynamicDimensions.nonEmpty,
                "At least one dimension of the tensor must not be listed in staticDimensions")

            // Choose the best rank for the decomposition with CORCONDIA
            val decomposition = findBestDecomposition(tensor, norm = norm)

            // One (community, removal condition) pair per component of the decomposition
            val extracted = for (rank <- 0 until decomposition.rank) yield {
                val factorVectors = decomposition.kruskal.A.map(m => abs(m.toSparseBreeze().toDense(::, rank)))
                val thresholds = for (vector <- factorVectors) yield { mean(vector) - min(vector) }
                println(thresholds.mkString(" "))

                val communityData = new Array[List[(Int, Double)]](tensor.order)
                for (i <- factorVectors.indices) {
                    communityData(i) =
                        if (staticDimensions.contains(tensor.dimensionsName(i))) {
                            // If the dimension is static, we get all the elements...
                            factorVectors(i).mapPairs((index, value) => (index, value)).toArray.toList
                        } else {
                            // ...if not, we keep only the elements with a value above the threshold
                            factorVectors(i).findAll(_ >= thresholds(i)).map(index => (index, factorVectors(i)(index))).toList
                        }
                }

                // Members of the community, one DataFrame per dimension
                val communityDf = (for (i <- tensor.dimensionsName.indices) yield {
                    val members = spark.createDataFrame(
                        spark.sparkContext.parallelize(communityData(i).map(e => Row(e._1.toLong, e._2))),
                        memberSchema)
                    val labelled = tensor.dimensionsIndex match {
                        case Some(index) =>
                            members.join(index(i), "dimIndex").drop("dimIndex")
                                .withColumnRenamed("dimValue", tensor.dimensionsName(i))
                        case None =>
                            members.withColumnRenamed("dimIndex", tensor.dimensionsName(i))
                    }
                    tensor.dimensionsName(i) -> labelled
                }).toMap

                val communitySize = communityData.indices
                    .map(i => tensor.dimensionsName(i) -> communityData(i).length)
                    .toMap

                // An entry belongs to the community when all its non-static coordinates do
                val removal = dynamicDimensions
                    .map(i => col(s"row_$i").isInCollection(communityData(i).map(_._1)))
                    .reduce(_ and _)

                (Community(communityDf, communitySize), removal)
            }
            val (found, removals) = extracted.unzip
            val communities = found.toList

            for (c <- communities) {
                val summary = for (i <- c.size.keys) yield s"$i: ${c.size(i)}"
                println(s"Community found: $summary")
            }

            // Add the stratum of communities
            strata :+= Stratum(communities, numStratum)
            numStratum += 1

            // Deflate the tensor, but only if another stratum will be computed: the
            // checkpoint below runs a Spark job whose result would otherwise be discarded.
            if (numStratum < nbStrata) {
                val remaining = removals.foldLeft(tensor.data)((data, removal) => data.filter(!removal))

                // Keep the same dimensions name
                val renamed = tensor.dimensionsName.indices.foldLeft(remaining) { (data, i) =>
                    tensor.dimensionsIndex match {
                        case Some(index) =>
                            data.withColumnRenamed(s"row_$i", "dimIndex")
                                .join(index(i), "dimIndex").drop("dimIndex")
                                .withColumnRenamed("dimValue", tensor.dimensionsName(i))
                        case None =>
                            data.withColumnRenamed(s"row_$i", tensor.dimensionsName(i))
                    }
                }
                tensor = Tensor(renamed.localCheckpoint().cache(), _valueColumnName)
            }

            println(s"Stratum found in ${(System.currentTimeMillis() - begin).toDouble / 1000.0}s")
        }
        strata
    }

}

# Application of the stratification
The tensor User-Hashtag-Time is loaded from the CSV file, and the stratification method is used on it.

In [None]:
// Read UHT.csv with a header row and inferred column types, then count its rows
val dfUserHashtagTime = spark.read
    .option("inferSchema", "true")
    .option("header", "true")
    .csv("UHT.csv")
dfUserHashtagTime.count()

In [None]:
// Wrap the DataFrame in a tensor and print each entry of its dimensionsSize
val tensorUserHashtagTime = Tensor(dfUserHashtagTime)
for (size <- tensorUserHashtagTime.dimensionsSize) println(size)

In [None]:
// Compute 3 strata; "time" is static, so all its elements are kept in every community
val communitiesUHT = Strata.compute(
    tensorUserHashtagTime,
    nbStrata = 3,
    staticDimensions = List("time")
)

# Visualisation

In [None]:
import plotly._
import plotly.element._
import plotly.layout._
import plotly.Almond._
import org.apache.spark.sql.functions.{abs => sparkAbs}

/**
 * Plots one community: a time series on top, and below it two bar charts showing the 20
 * elements with the largest absolute value in the dimensions `d1` and `d2`.
 *
 * The labels and the values of each trace are collected by a single query. Collecting them
 * with two separate sorted queries could pair a label with the wrong value when several
 * rows share the same sort key, and ran every Spark job twice.
 *
 * @param df          for each dimension name, the DataFrame of the community's elements
 * @param rank        number shown in the title of the plot
 * @param d1          dimension of the left bar chart; its column is read as an Int and
 *                    displayed with a "u" prefix
 * @param d2          dimension of the right bar chart; its column is read as a String
 * @param time        dimension of the time series
 * @param nbDays      number of days represented by one unit of the time column: the column
 *                    is multiplied by 3600 * 24 * nbDays and read as seconds since the epoch
 * @param lambda      shown in the title when strictly positive
 * @param valueColumn name of the column holding the value of each element
 * @return the result of rendering the plot
 */
def plot3D(df: Map[String, DataFrame], rank: Int, d1: String, d2: String, 
           time: String = "time", nbDays: Int = 1, lambda: Double = 0.0, valueColumn: String = "value") = {
    val magnitude = sparkAbs(col(valueColumn))

    // Dates and values of the time series, in chronological order
    val (dates, timeValues) = df(time)
        .sort(col(time))
        .select(to_date(from_unixtime(col(time) * 3600 * 24 * nbDays)), magnitude)
        .collect.toSeq
        .map(row => (row.get(0).toString, row.getDouble(1)))
        .unzip

    // Labels and values of the 20 elements of a dimension with the largest absolute value
    def strongest(dimension: String)(label: org.apache.spark.sql.Row => String): (Seq[String], Seq[Double]) =
        df(dimension)
            .sort(magnitude.desc).limit(20)
            .select(col(dimension), magnitude)
            .collect.toSeq
            .map(row => (label(row), row.getDouble(1)))
            .unzip

    val (d1Labels, d1Values) = strongest(d1)(row => "u" + row.getInt(0))
    val (d2Labels, d2Values) = strongest(d2)(row => row.getString(0))

    val traces = Seq(
        Scatter(
            dates,
            timeValues,
            name = "Time",
            xaxis = AxisReference.X1,
            yaxis = AxisReference.Y1
        ),
        Bar(
            d1Labels,
            d1Values,
            name = d1.capitalize,
            xaxis = AxisReference.X2,
            yaxis = AxisReference.Y2
        ),
        Bar(
            d2Labels,
            d2Values,
            name = d2.capitalize,
            xaxis = AxisReference.X3,
            yaxis = AxisReference.Y3
        )
    )
    
    val lambdaText = if (lambda > 0.0) s"lambda = $lambda" else ""
    
    // The time series spans the top of the figure; the two bar charts share the bottom
    val layout = Layout(
        title = s"Rank $rank $lambdaText",
        width = 1000,
        xaxis1 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y1), domain = (0.0, 1.0), automargin = true),
        xaxis2 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y2), domain = (0.0, 0.49), automargin = true),
        xaxis3 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y3), domain = (0.51, 1.0), automargin = true),
        yaxis1 = Axis(anchor = AxisAnchor.Reference(AxisReference.X1), domain = (0.55, 1.0), automargin = true),
        yaxis2 = Axis(anchor = AxisAnchor.Reference(AxisReference.X2), domain = (0.0, 0.45), automargin = true),
        yaxis3 = Axis(anchor = AxisAnchor.Reference(AxisReference.X3), domain = (0.0, 0.45), automargin = true),
        legend = Legend(y = 1.1, x = .5, yanchor = Anchor.Top, xanchor = Anchor.Center, orientation = Orientation.Horizontal)
    )
    
    traces.plot(layout = layout, Config(), "")
}

In [None]:
// Plot every community of the selected stratum, titled with its position in the stratum
val stratum = 0
communitiesUHT(stratum).communities.zipWithIndex.foreach { case (community, rank) =>
    plot3D(community.data, rank, "user", "hashtag")
}