In [3]:
%use kotlin-statistics, krangl, lets-plot, numpy(0.1.4)

@file:Repository("https://repo1.maven.org/maven2")
@file:DependsOn("de.sldk:kotbar:0.1.0")

In [4]:
class Obj(
    val grams: List<String>,
    val y: Int // 1 - legit, 0 - spam
)

class Batch(
    val objs: List<Obj>
)

In [5]:
class ConfInfo(
        var tp: Int = 0,
        var fp: Int = 0,
        var fn: Int = 0,
        var tn: Int = 0,
        var cnt: Int = 0,
        var prec: Double = 0.0,
        var recall: Double = 0.0,
        var fSc: Double = 0.0
)

class Stat(
    val accuracy: Double, 
    val tpr: Double,
    val fpr: Double
)

fun getStat(confMatrix: Array<IntArray>): Stat {
    val k = confMatrix.size
    val infos = Array(k) { ConfInfo() }
    var all = 0
    for (i in 0 until k) {
        for (j in 0 until k) {
            val cur = confMatrix[i][j]
            all += cur
            infos[i].cnt += cur
            if (i == j) {
                infos[i].tp = cur
            } else {
                infos[i].fp += cur
                infos[j].fn += cur
            }
        }
    }
    infos.forEach {
        it.tn = all - it.fp - it.fn - it.tp
        it.recall = if (it.tp + it.fn != 0) it.tp.toDouble() / (it.tp + it.fn) else 0.0
        it.prec = if (it.tp + it.fp != 0) it.tp.toDouble() / (it.tp + it.fp) else 0.0
        it.fSc = if (it.recall + it.prec != 0.0) 2 * it.recall * it.prec / (it.recall + it.prec) else 0.0
    }
    val acc = (infos.map{it.tn}.sum() + infos.map{it.tp}.sum()).toDouble() / 
    (infos.map {it.tp + it.fn}.sum() + infos.map {it.fp + it.tn}.sum())
    val tpr = infos.map{it.tp}.sum().toDouble() / infos.map {it.tp + it.fn}.sum()
    val fpr = infos.map{it.fp}.sum().toDouble() / infos.map {it.fp + it.tn}.sum()
    return Stat(acc, tpr, fpr)
}

In [6]:
import java.io.File

fun readData(gramN: Int): List<Batch> {
    val batches = ArrayList<Batch>()
    File("res/messages").listFiles().forEach { part ->
        val objs = ArrayList<Obj>()
        part.listFiles().forEach { f ->
            val y = if (f.name.contains("legit")) 1 else 0
           // println(f.readText())
            val words = f.readText().drop(8).replace("\n", " ").trim().split(" ").filter {it.isNotEmpty()}.toList()
            val grams = ArrayList<String>()
            for (i in 0 until words.size - gramN) {
                var gram = ""
                for (j in 0 until gramN) {
                    gram += "${words[i + j]} "
                }
                grams.add(gram.trim())
            }
            objs.add(Obj(grams, y))
        }
        batches.add(Batch(objs))
    }
    return batches
}



In [7]:
println(readData(1).flatMap{it.objs}.size)

1090


In [8]:
class Bayes(
    val n: Int,
    val lambda: List<Double>,
    val alpha: Double,
    val wordsInClass: List<HashMap<String, Int>> = List(n) { HashMap<String, Int>() },
    val classCount: MutableList<Int> = MutableList(n) {0},
    val allWords: HashSet<String> = HashSet()
    
)

fun fillBayes(b: Bayes, objs: List<Obj>) {
    for (i in objs.indices) {
        for (w in objs[i].grams) {
            b.wordsInClass[objs[i].y].putIfAbsent(w, 0)
            b.wordsInClass[objs[i].y][w] = b.wordsInClass[objs[i].y][w]!! + 1
            b.allWords.add(w)
        }
        b.classCount[objs[i].y]++
    }
}

fun predict(obj: Obj, b: Bayes, nu: Double? = null): Pair<Int,Double> {
    val ps = MutableList(b.n) { 0.0 }
    var maxP = -10000000000000000000000.0
    var maxClas = 10
    val classP = b.classCount.map {it.toDouble() / b.classCount.sum()}
    val counts = b.wordsInClass.map { it.values.sum()}
    for (clas in 0 until b.n) {
        var p = ln(b.lambda[clas]) + ln(classP[clas])
        for (word in obj.grams) {
//             println("adad = ${b.pWordInClass[clas][word] ?: "null"}")
            val a = ln(b.alpha.toDouble() + (b.wordsInClass[clas][word]?.toDouble() ?: 0.0))
            val b = ln(
                (b.alpha * b.allWords.size + counts[clas])
            )
            if (a != 0.0 && b != 0.0) p += a - b
//             println("ai = ${ln(
//                 b.pWordInClass[clas][word] ?: (b.alpha.toDouble() / (b.wordsCountInClass[clas] + b.alpha * b.allWords.size))
//             )}")
        }
//         println("lambda = ${b.lambda[clas]}")
//         println("a = ${b.lambda[clas].toDouble() * b.classP[clas]}")
//         println("b = ${ln(b.lambda[clas].toDouble() * b.classP[clas])}")
//         println("c = ${p}")
//         println("^^^^^^")
        ps[clas] = p
        if (maxP < p) {
            maxP = p
            maxClas = clas
        }
    }
//        println("________________________")
    //println()
     //println("-----------------------------")
//     ps.map { E.pow(it) }.let { a ->
//         for (clas in 0 until b.n) {
//             val r = a[clas] / a.sum()
// //             print(" ${r} -- ")
//             if (maxP < r) {
//                 maxP = r
//                 maxClas = clas
//             }
//         }
//         println()
       
//     }
    if (nu != null) {
        val pr = (abs(ps.sum()) - abs(ps[1])) / (-ps.sum()) 
        if (pr >- nu) {
            return Pair(1, pr)
        } else {
            return Pair(0, pr)
        }
    }
    val pr = (abs(ps.sum()) - abs(ps[1])) / (-ps.sum()) 
    return Pair(maxClas, pr)
}

In [9]:
val nParam = listOf(1,2,3)
val bestAlphaFound = listOf(1e-7)
val alphaParam = listOf(0.0000000000001, 0.000000001,0.0000001, 0.000001, 0.001, 0.01, 0.1, 0.5, 0.9, 1.0)
val fixedLambdaSpam = 1.0
val legitLambdas = listOf(1.0,1e10,1e20,1e40, 1e60, 1e80, 1e100, 1e125,
                          1e150,1e170,1e180,1e190,1e200,1e210,1e240, 1e260)
fun calc(gramN: Int, alphaFound : List<Double>, lambdas : List<Double>):HashMap<Double,Double> {
    var bestN = 0
    var bestAlpha = 0.0
    var bestLambdaLegitToAcc = HashMap<Double,Double>()
    var bestAcc = 0.0
    var lamdaLegitNoErrors = -1
    loop@ for (n in gramN..gramN) {
        val batches = readData(n)
        for (alpha in alphaFound) {
            var lambdaLegitToAcc = HashMap<Double,Double>()    
            for (legitLambda in lambdas) {
                //println()
                //println("iter = ${a++}")
                val confMatrix = Array(2) {IntArray(2)}
                var legitErrorCount = 0
                for(i in batches.indices) {
                    val trainSet = batches.filterIndexed {ind, _ -> i != ind}.flatMap {it.objs}
                    val testSet = batches[i].objs
                    val bayes = Bayes(2, listOf(fixedLambdaSpam, legitLambda), alpha)
                    fillBayes(bayes, trainSet)
                    val testRes = testSet.map {
                        predict(it, bayes).first
                    }
                    for (j in testSet.indices) {
//                         println("actual: ${testSet[j].y}, got: ${testRes[j]}")
                        if (testRes[j] != testSet[j].y && testSet[j].y == 1) legitErrorCount++
                        confMatrix[testRes[j]][testSet[j].y] += 1
                    }
                }
                val acc = getStat(confMatrix).accuracy
                lambdaLegitToAcc[legitLambda] = acc
                if (legitErrorCount == 0) {
                     if (acc > bestAcc) {
                        bestAcc = acc
                        bestN = n
                        bestAlpha = alpha
                        bestLambdaLegitToAcc = lambdaLegitToAcc 
                    }
               }
                println()
                println("legit errors: ${legitErrorCount}")
                println("n: ${n}")
                println("legitLambda: ${legitLambda}")
                println("alpha: ${alpha}")
                println("acc: ${acc}")
                
            }
        }
    }
    println(""" best params:
    accuracy = ${bestAcc}
    n param = ${bestN}
    alpha param = ${bestAlpha}
    lambda spam = ${fixedLambdaSpam}
    lambda legit no errors = ${lamdaLegitNoErrors}
    """)
    return bestLambdaLegitToAcc
    
}

fun plot(lToA: HashMap<Double,Double>) {
    val dat = lToA.toList().sortedBy { it.first }
    val plotData = mapOf<String, Any>(
    "lambda" to dat.map {it.first}.map{log10(it)},
    "accuracy" to dat.map{it.second}
    )
    var p1 =  lets_plot(plotData){x = "lambda"; y = "accuracy"}
    p1+= geom_path() + geom_point() + scale_x_log10()
    p1.show()
}

fun buildROC(gramN: Int, alpha: Double, lambdaL: Double) {
    val batches = readData(gramN)
    val dataSet = batches.flatMap {it.objs}
    val yStep = 1.0 / dataSet.filter {it.y == 1}.count()
    val xStep = 1.0 / dataSet.filter {it.y != 1}.count()
    val bayes = Bayes(2, listOf(fixedLambdaSpam, lambdaL), alpha)
    fillBayes(bayes, dataSet)
    val testRes = dataSet.map {
        Pair(it, predict(it, bayes))
    }.sortedBy { it.second.second}
    //print(testRes.map{it.first.y})
    val xs = ArrayList<Double>()
    val ys = ArrayList<Double>()
    xs.add(0.0)
    ys.add(0.0)
    for (i in testRes.indices.reversed()) {
        if (testRes[i].first.y == 1) {
            xs.add(xs.last())
            ys.add(ys.last() + yStep)
        } else {
             xs.add(xs.last() + xStep)
            ys.add(ys.last())
        }
    }
    val plotData = mapOf<String, Any>(
        "x" to xs.toList(),
        "y" to ys.toList()
    )
    var p1 =  lets_plot(plotData){x = "x"; y = "y"}
    p1+= geom_path() + geom_point()
    p1.show()
}

In [8]:
val lambdas1 = listOf(1.0,1e10,1e20,1e40, 1e60, 1e80, 1e100, 1e125,
                          1e150,1e170)
val lambdaToAcc1 = calc(1, listOf(1e-7), lambdas1)


legit errors: 10
n: 1
legitLambda: 1.0
alpha: 1.0E-7
acc: 0.9779816513761468

legit errors: 9
n: 1
legitLambda: 1.0E10
alpha: 1.0E-7
acc: 0.9724770642201835

legit errors: 5
n: 1
legitLambda: 1.0E20
alpha: 1.0E-7
acc: 0.9587155963302753

legit errors: 3
n: 1
legitLambda: 1.0E40
alpha: 1.0E-7
acc: 0.926605504587156

legit errors: 2
n: 1
legitLambda: 1.0E60
alpha: 1.0E-7
acc: 0.8972477064220183

legit errors: 1
n: 1
legitLambda: 1.0E80
alpha: 1.0E-7
acc: 0.8770642201834863

legit errors: 1
n: 1
legitLambda: 1.0E100
alpha: 1.0E-7
acc: 0.8532110091743119

legit errors: 1
n: 1
legitLambda: 1.0E125
alpha: 1.0E-7
acc: 0.8311926605504587

legit errors: 1
n: 1
legitLambda: 1.0E150
alpha: 1.0E-7
acc: 0.8036697247706422

legit errors: 0
n: 1
legitLambda: 1.0E170
alpha: 1.0E-7
acc: 0.7844036697247706
 best params:
    accuracy = 0.7844036697247706
    n param = 1
    alpha param = 1.0E-7
    lambda spam = 1.0
    lambda legit no errors = -1
    


In [9]:
plot(lambdaToAcc1)

In [10]:
buildROC(1, 1e-7, 1e170)

In [11]:
val lambdas2 = listOf(1.0,1e20,1e40, 1e60, 1e80, 1e100, 1e125,
                          1e150,1e170,1e190,1e210, 1e230)
val lambdaToAcc2 = calc(2, listOf(1e-3),lambdas2)


legit errors: 16
n: 2
legitLambda: 1.0
alpha: 0.001
acc: 0.9798165137614679

legit errors: 7
n: 2
legitLambda: 1.0E20
alpha: 0.001
acc: 0.9761467889908257

legit errors: 3
n: 2
legitLambda: 1.0E40
alpha: 0.001
acc: 0.9642201834862385

legit errors: 3
n: 2
legitLambda: 1.0E60
alpha: 0.001
acc: 0.9440366972477064

legit errors: 2
n: 2
legitLambda: 1.0E80
alpha: 0.001
acc: 0.9220183486238532

legit errors: 2
n: 2
legitLambda: 1.0E100
alpha: 0.001
acc: 0.9045871559633027

legit errors: 2
n: 2
legitLambda: 1.0E125
alpha: 0.001
acc: 0.8834862385321101

legit errors: 2
n: 2
legitLambda: 1.0E150
alpha: 0.001
acc: 0.8660550458715597

legit errors: 2
n: 2
legitLambda: 1.0E170
alpha: 0.001
acc: 0.8513761467889909

legit errors: 1
n: 2
legitLambda: 1.0E190
alpha: 0.001
acc: 0.8376146788990826

legit errors: 1
n: 2
legitLambda: 1.0E210
alpha: 0.001
acc: 0.826605504587156

legit errors: 0
n: 2
legitLambda: 1.0E230
alpha: 0.001
acc: 0.8192660550458716
 best params:
    accuracy = 0.8192660550458716


In [12]:
plot(lambdaToAcc2)

In [13]:
buildROC(2, 1e-3, 1e230)

In [21]:
val lambdas3 = listOf(1.0,1e30, 1e60, 1e90, 1e120, 1e160, 1e200,
                          1e240,1e270, 1e300)
val lambdaToAcc3 = calc(3, listOf(0.0005), lambdas3)


legit errors: 14
n: 3
legitLambda: 1.0
alpha: 5.0E-4
acc: 0.9807339449541285

legit errors: 7
n: 3
legitLambda: 1.0E30
alpha: 5.0E-4
acc: 0.9605504587155963

legit errors: 5
n: 3
legitLambda: 1.0E60
alpha: 5.0E-4
acc: 0.9238532110091743

legit errors: 4
n: 3
legitLambda: 1.0E90
alpha: 5.0E-4
acc: 0.8899082568807339

legit errors: 4
n: 3
legitLambda: 1.0E120
alpha: 5.0E-4
acc: 0.8761467889908257

legit errors: 3
n: 3
legitLambda: 1.0E160
alpha: 5.0E-4
acc: 0.8532110091743119

legit errors: 3
n: 3
legitLambda: 1.0E200
alpha: 5.0E-4
acc: 0.828440366972477

legit errors: 2
n: 3
legitLambda: 1.0E240
alpha: 5.0E-4
acc: 0.8128440366972477

legit errors: 2
n: 3
legitLambda: 1.0E270
alpha: 5.0E-4
acc: 0.8064220183486238

legit errors: 0
n: 3
legitLambda: 1.0E300
alpha: 5.0E-4
acc: 0.7889908256880734
 best params:
    accuracy = 0.7889908256880734
    n param = 3
    alpha param = 5.0E-4
    lambda spam = 1.0
    lambda legit no errors = -1
    


In [22]:
plot(lambdaToAcc3)

In [23]:
buildROC(3, 0.0005, 1e300)