/
myprogram.go
58 lines (47 loc) · 1.39 KB
/
myprogram.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package main
import (
"fmt"
"log"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/naive"
)
func main() {
	// loadCSV parses the CSV file at path into golearn instances and
	// terminates the program if parsing reports an error.
	loadCSV := func(path string) base.FixedDataGrid {
		instances, err := base.ParseCSVToInstances(path, true)
		if err != nil {
			log.Fatal(err)
		}
		return instances
	}

	// Fit a Bernoulli Naive Bayes classifier on the converted training set.
	trainSet := loadCSV("training.csv")
	model := naive.NewBernoulliNBClassifier()
	model.Fit(convertToBinary(trainSet))

	// Run the fitted classifier over the converted test set.
	testSet := loadCSV("test.csv")
	predicted := model.Predict(convertToBinary(testSet))

	// Build the confusion matrix from the test instances and predictions.
	confusion, err := evaluation.GetConfusionMatrix(testSet, predicted)
	if err != nil {
		log.Fatal(err)
	}

	// Print the accuracy derived from the confusion matrix.
	fmt.Printf("\nAccuracy: %0.2f\n\n", evaluation.GetAccuracy(confusion))
}
// convertToBinary returns a filtered view of src, built with
// base.NewLazilyFilteredInstances, in which every non-class attribute has
// been registered with golearn's binary convert filter. The class attribute
// is not registered with the filter.
//
// Any error reported while registering an attribute or training the filter
// terminates the program via log.Fatal.
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
	b := filters.NewBinaryConvertFilter()

	// Only the feature (non-class) attributes go through the filter.
	for _, a := range base.NonClassAttributes(src) {
		if err := b.AddAttribute(a); err != nil {
			log.Fatal(err)
		}
	}

	// The filter is trained before it is attached to the instances.
	if err := b.Train(); err != nil {
		log.Fatal(err)
	}

	return base.NewLazilyFilteredInstances(src, b)
}