# Initialisation

In [None]:
import $ivy.`org.apache.spark::spark-sql:3.0.1`
import $ivy.`org.apache.spark::spark-mllib:3.0.1`
import $ivy.`org.plotly-scala::plotly-almond:0.7.6`

In [None]:
// Locate the TDM assembly jar under ./lib of the working directory and add it
// to the notebook's classpath so that the `tdm` packages can be imported.
val currentDirectory = new java.io.File(".").getCanonicalPath
val path = java.nio.file.Paths.get(s"$currentDirectory/lib/TDM-assembly-0.3.0.jar")
val x = ammonite.ops.Path(path)
interp.load.cp(x)

In [None]:
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.functions._
import tdm._
import tdm.core._
import tdm.core.decomposition.Norm

In [None]:
// Spark session used by the whole notebook. It is declared implicit, presumably
// so that the TDM calls below can pick it up without it being passed explicitly
// (TODO confirm against the TDM API). The explicit type keeps implicit
// resolution predictable.
implicit val spark: SparkSession =
    SparkSession.builder()
        .appName("PrimarySchoolTDM")
        .master("local[*]")
        .getOrCreate()

// Keep the notebook output readable: only report Spark errors.
spark.sparkContext.setLogLevel("ERROR")
import spark.implicits._

# Load data

In [None]:
// Load the tab-separated contact list and turn it into a symmetric relation:
// every (student1, student2) contact also appears as (student2, student1).
var df = {
    val contacts = spark.read
        .format("csv")
        .option("header","false")
        .option("sep", "\t")
        .load("datasets/primaryschool.csv")
        .toDF("time", "student1", "student2", "class1", "class2")
        // 3600 / 12 = 300: timestamps are bucketed into slots of 300 time units
        // (presumably seconds, i.e. 5-minute slots -- confirm against the dataset).
        .withColumn("time", floor(col("time") / (3600 / 12)).cast("integer"))
        // Shift the raw student ids down by 1425.
        .withColumn("student1", floor(col("student1") - 1425))
        .withColumn("student2", floor(col("student2") - 1425))
        .groupBy("time", "student1", "student2", "class1", "class2")
        .count
        .withColumnRenamed("count", "val")
        // The count is overwritten by a constant: "val" only flags that a contact exists.
        .withColumn("val", lit(1))

    // Same rows with the two endpoints (student and class) swapped, columns kept
    // in the original order because union matches columns by position.
    val mirrored = contacts.select(
        col("time"),
        col("student2").as("student1"),
        col("student1").as("student2"),
        col("class2").as("class1"),
        col("class1").as("class2"),
        col("val")
    )

    contacts.union(mirrored).dropDuplicates()
}

In [None]:
// Replace each student id by "<class>-<student>" and drop the class columns,
// which are now encoded in the student names.
df = {
    def classQualified(classColumn: String, studentColumn: String) =
        concat(col(classColumn), lit("-"), col(studentColumn))

    df.withColumn("student1", classQualified("class1", "student1"))
        .withColumn("student2", classQualified("class2", "student2"))
        .drop("class1", "class2")
}

# Create tensor

In [None]:
// Typed markers for the three tensor dimensions. Student1 and Student2 carry
// the "<class>-<student>" names as strings, Time carries the integer time slot.
object Student1 extends TensorDimension[String]
object Student2 extends TensorDimension[String]
object Time extends TensorDimension[Int]

In [None]:
// Build a tensor with three dimensions (student1, student2, time) from df.
// Each dimension is bound to the df column of the same name; build("val")
// presumably takes the tensor values from the "val" column (TODO confirm).
val tensor = {
    val builder = TensorBuilder[Int](df)
    builder
        .addDimension(Student1, "student1")
        .addDimension(Student2, "student2")
        .addDimension(Time, "time")
        .build("val")
}

# Decomposition

In [None]:
// Canonical polyadic (CP) decomposition of the contact tensor.
// NOTE(review): 13 is presumably the requested rank; Norm.L2 selects the norm
// used by the decomposition; computeCorcondia and minFms are TDM options whose
// exact semantics should be checked against the TDM documentation.
val kruskal = tensor.canonicalPolyadicDecomposition(13, norm = Norm.L2, computeCorcondia = true, minFms = 0.999)

In [None]:
// Pull out of the decomposition result the part attached to each dimension
// (presumably one factor matrix per dimension -- confirm against the TDM API).
val student1Tensor = kruskal.extract(Student1)
val student2Tensor = kruskal.extract(Student2)
val timeTensor = kruskal.extract(Time)

# Visualisation

In [None]:
import plotly._
import plotly.element._
import plotly.layout._
import plotly.Almond._

// Initialise plotly-almond for this notebook. NOTE(review): offline=true
// presumably embeds plotly.js instead of fetching it from a CDN -- confirm.
init(offline=true)

In [None]:
// Alphabetically sorted student names paired with their position in that order.
// The names are read from the Student1 dimension of the rank-0 projection.
val students = {
    val collectedStudents = student1Tensor.projection(Rank)(0).collect
    collectedStudents.indices
        .map(i => collectedStudents(Student1, i))
        .toList
        .sorted
        .zipWithIndex
}
// Lookups in both directions between a student name and its column index.
val studentsIdToName = students.map(_.swap).toMap
val studentsNameToId = students.toMap

In [None]:
/**
 * Reorders the rows of a rank x student matrix so that ranks dominated by a
 * single class end up grouped by class.
 *
 * A rank is attributed to a class when the weights of that class' students add
 * up to more than half of the row total. The class of a student is the part of
 * its name before "-", looked up through `studentsIdToName`. Attributed ranks
 * are taken in the order of `levels` (and by rank index within a class),
 * followed by all remaining ranks in index order. That sequence is laid out
 * from the last row of the result back to the first.
 *
 * @param classesToOrder one row per rank, one column per student id
 * @return the rows of `classesToOrder` (the same arrays, not copies) in the new
 *         order; an empty array when `classesToOrder` is empty
 */
def orderClasses(classesToOrder: Array[Array[Double]]): Array[Array[Double]] = {
    val levels = List("1A", "1B", "2A", "2B", "3A", "3B", "4A", "4B", "5A", "5B")
    val rankIds = classesToOrder.indices
    // Insertion-ordered set: records the emission order of the ranks and gives
    // constant-time "already placed?" checks.
    val ordered = scala.collection.mutable.LinkedHashSet[Int]()

    // Total weight of the students of `level` within one row.
    def levelWeight(row: Array[Double], level: String): Double =
        row.indices.foldLeft(0.0) { (acc, id) =>
            if (studentsIdToName(id).split("-")(0) == level) acc + row(id) else acc
        }

    // Put the communities that are classes
    for (level <- levels; rank <- rankIds if !ordered.contains(rank)) {
        val row = classesToOrder(rank)
        if (levelWeight(row, level) > row.sum / 2) ordered += rank
    }
    // Put the communities that are not classes (ranks already placed keep their position)
    rankIds.foreach(ordered += _)

    // The first rank placed goes to the last row, the second to the one before, and so on.
    ordered.toArray.reverse.map(classesToOrder(_))
}

In [None]:
// Heatmap of the student loadings: one row per rank, one column per student.
val nbRanks = kruskal.lambdas.size
println(nbRanks)
// One column per known student (sized from the lookup table rather than a
// hard-coded count, so the matrix always matches the heatmap's x axis).
val classes = Array.ofDim[Double](nbRanks, studentsNameToId.size)
for (rank <- 0 until nbRanks) {
    val collectedClass1 = student1Tensor.projection(Rank)(rank).collect
    val collectedClass2 = student2Tensor.projection(Rank)(rank).collect
    val row = classes(rank)

    // Accumulate the absolute loading of each student over both student modes.
    // Each collected tensor is walked with its own indices.
    for (i <- collectedClass1.indices) {
        row(studentsNameToId(collectedClass1(Student1, i))) += math.abs(collectedClass1(i))
    }
    for (i <- collectedClass2.indices) {
        row(studentsNameToId(collectedClass2(Student2, i))) += math.abs(collectedClass2(i))
    }

    // Scale the row to [0, 1]. The maximum is computed once, and an all-zero
    // row is left untouched instead of being turned into NaN by 0 / 0.
    val rowMax = if (row.isEmpty) 0.0 else row.max
    if (rowMax > 0) classes(rank) = row.map(_ / rowMax)
}
var plot = Seq(
        Heatmap(x = studentsNameToId.keys.toList.sorted.toSeq,
                y = (0 until nbRanks).toSeq,
                z = orderClasses(classes).map(_.toSeq).toSeq
        )
)
plot.plot()

In [None]:
// For every rank, draw three bar charts: the time profile of day 1, the time
// profile of day 2, and the students involved (largest weight first).
val nbRanks = kruskal.lambdas.size
for (rank <- 0 until nbRanks) {
    // Width of one time slot in minutes, used for the day boundaries and the axis labels.
    val nbMinutes = 5

    // Absolute loading of each student, summed over both student modes.
    // Loadings whose absolute value is not above 0.01 are ignored.
    val studentMap = scala.collection.mutable.Map[String, Double]()
    val student1Loadings = student1Tensor.projection(Rank)(rank)
        .selection(v => math.abs(v) > 0.01).collect.orderByValuesDesc
    for (i <- student1Loadings.indices) {
        val name = student1Loadings(Student1, i)
        studentMap(name) = studentMap.getOrElse(name, 0.0) + math.abs(student1Loadings(i))
    }
    val student2Loadings = student2Tensor.projection(Rank)(rank)
        .selection(v => math.abs(v) > 0.01).collect.orderByValuesDesc
    for (i <- student2Loadings.indices) {
        val name = student2Loadings(Student2, i)
        studentMap(name) = studentMap.getOrElse(name, 0.0) + math.abs(student2Loadings(i))
    }
    // Sorted once and shared by the x and y series of the "Students" chart.
    val rankedStudents = studentMap.toList.sortWith((e1, e2) => e1._2 > e2._2)

    // (slot, loading) pairs in chronological order.
    // NOTE(review): 62300 and 117240 look like the last timestamp of day 1 and
    // the first timestamp of day 2 -- confirm against the dataset.
    val days1 = {
        val days = timeTensor.projection(Rank)(rank)
            .restriction(Time.condition(v => v <= (62300 / (nbMinutes * 60)))).collect
        (for (i <- days.indices) yield {
            (days(Time, i), days(i))
        }).toList.sortWith((d1, d2) => d1._1 < d2._1)
    }
    val days2 = {
        val days = timeTensor.projection(Rank)(rank)
            .restriction(Time.condition(v => v >= (117240 / (nbMinutes * 60)))).collect
        (for (i <- days.indices) yield {
            (days(Time, i), days(i))
        }).toList.sortWith((d1, d2) => d1._1 < d2._1)
    }

    val plot = Seq(
        // Day 1
        Bar(
            days1.map(v => {
                    val hours = math.floor((v._1 * nbMinutes) / 60).toInt
                    s"${hours}h${(v._1 * nbMinutes) - hours * 60}"
                }).toSeq,
            days1.map(v => math.abs(v._2)).toSeq,
            name = "Day 1",
            xaxis = AxisReference.X1,
            yaxis = AxisReference.Y1
        ),
        // Day 2: 24 hours are subtracted so the label shows the time of day
        Bar(
            days2.map(v => {
                    val hours = math.floor((v._1 * nbMinutes) / 60).toInt
                    s"${hours - 24}h${(v._1 * nbMinutes) - hours * 60}"
                }).toSeq,
            days2.map(v => math.abs(v._2)).toSeq,
            name = "Day 2",
            xaxis = AxisReference.X2,
            yaxis = AxisReference.Y2
        ),
        // Students: the weight is halved, presumably to average the two student modes
        Bar(
            rankedStudents.map(v => v._1),
            rankedStudents.map(v => (v._2 / 2)),
            name = "Students",
            xaxis = AxisReference.X3,
            yaxis = AxisReference.Y3
        )
    )

    // Top half: the two day charts side by side. Bottom half: the students chart.
    val layout = Layout(
        title = s"Rank $rank",
        width = 1000,
        xaxis1 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y1), domain = (0.0, 0.49), automargin = true),
        xaxis2 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y2), domain = (0.51, 1.0), automargin = true),
        xaxis3 = Axis(anchor = AxisAnchor.Reference(AxisReference.Y3), domain = (0.0, 1.0), automargin = true),
        yaxis1 = Axis(anchor = AxisAnchor.Reference(AxisReference.X1), domain = (0.55, 1.0), automargin = true),
        yaxis2 = Axis(anchor = AxisAnchor.Reference(AxisReference.X2), domain = (0.55, 1.0), automargin = true),
        yaxis3 = Axis(anchor = AxisAnchor.Reference(AxisReference.X3), domain = (0.0, 0.45), automargin = true),
        legend = Legend(y = 1.1, x = .5, yanchor = Anchor.Top, xanchor = Anchor.Center, orientation = Orientation.Horizontal)
    )

    plot.plot(layout = layout, Config(), "")
}