Skip to content
Browse files

initial commit

  • Loading branch information...
0 parents commit 7f640775ada6a60e802482fa27b3935c2cfc0657 @ant0ine committed Oct 5, 2012
Showing with 302 additions and 0 deletions.
  1. 0 .gitignore
  2. +52 −0 README.txt
  3. +132 −0 mahalanobis.go
  4. +118 −0 mahalanobis_test.go
0 .gitignore
No changes.
52 README.txt
@@ -0,0 +1,52 @@
+PACKAGE
+
+package mahalanobis
+ import "github.com/ant0ine/go.mahalanobis"
+
+ Naive implementation of the Mahalanobis distance using go.matrix
+ (https://en.wikipedia.org/wiki/Mahalanobis_distance)
+
+ This is me learning Go, it's probably broken, don't use it.
+
+ Example:
+
+ package main
+ import (
+ "fmt"
+ "github.com/skelterjohn/go.matrix"
+ "github.com/ant0ine/go.mahalanobis"
+ )
+ func main() {
+ points, err := matrix.ParseMatlab("[1 4 3 4;4 2 3 4]")
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println("4 points:\n", points)
+ target, err := matrix.ParseMatlab("[3;4]")
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println("the target point:\n", target)
+ distance, err := mahalanobis.Distance(points, target)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println("Mahalanobis distance=", distance)
+ }
+
+FUNCTIONS
+
+func CovarianceMatrix(points *matrix.DenseMatrix) *matrix.DenseMatrix
+ Return the covariance matrix for this set of points (sample covariance
+ is used)
+
+func Distance(points, target *matrix.DenseMatrix) (float64, error)
+ Return the Mahalanobis distance
+
+func DistanceSquare(points, target *matrix.DenseMatrix) (float64, error)
+ Return the square of the Mahalanobis distance
+
+func MeanVector(points *matrix.DenseMatrix) *matrix.DenseMatrix
+ Given a set of points, return the mean vector.
+
+
132 mahalanobis.go
@@ -0,0 +1,132 @@
+// Naive implementation of the Mahalanobis distance using go.matrix
+// (https://en.wikipedia.org/wiki/Mahalanobis_distance)
+//
+// This is me learning Go, it's probably broken, don't use it.
+//
+// Example:
+//
+// package main
+//
+// import (
+// "fmt"
+// "github.com/skelterjohn/go.matrix"
+// "github.com/ant0ine/go.mahalanobis"
+// )
+//
+// func main() {
+//
+// points, err := matrix.ParseMatlab("[1 4 3 4;4 2 3 4]")
+// if err != nil {
+// panic(err)
+// }
+// fmt.Println("4 points:\n", points)
+//
+// target, err := matrix.ParseMatlab("[3;4]")
+// if err != nil {
+// panic(err)
+// }
+// fmt.Println("the target point:\n", target)
+//
+// distance, err := mahalanobis.Distance(points, target)
+// if err != nil {
+// panic(err)
+// }
+// fmt.Println("Mahalanobis distance=", distance)
+// }
+package mahalanobis
+
+import (
+// "fmt"
+ "math"
+ "github.com/skelterjohn/go.matrix"
+)
+
+// Given a set a points, return the mean vector.
+func MeanVector(points *matrix.DenseMatrix) *matrix.DenseMatrix {
+ mean := matrix.Zeros(points.Rows(), 1)
+ for i := 0; i < points.Rows(); i++ {
+ sum := 0.0
+ for j := 0; j < points.Cols(); j++ {
+ sum += points.Get(i, j)
+ }
+ mean.Set(i, 0, sum / float64(points.Cols()))
+ }
+ return mean
+}
+
+func sample_covariance_matrix(points, mean *matrix.DenseMatrix) *matrix.DenseMatrix {
+ dim := points.Rows()
+ cov := matrix.Zeros(dim, dim)
+ for i := 0; i < dim; i++ {
+ for j := 0; j < dim; j++ {
+ if i > j {
+ // symetric matrix
+ continue
+ }
+ // TODO in go routines ?
+ sum := 0.0
+ for k := 0; k < points.Cols(); k++ {
+ sum += (points.Get(i, k) - mean.Get(i, 0)) * (points.Get(j, k) - mean.Get(j, 0))
+ }
+
+ // this is the sample covariance, divide by (N - 1)
+ covariance := sum / ( float64(points.Cols() - 1))
+
+ cov.Set(i, j, covariance)
+ // symetric matrix
+ cov.Set(j, i, covariance)
+
+ }
+ }
+ return cov
+}
+
+// CovarianceMatrix returns the sample covariance matrix of a set of points,
+// computed around the mean vector of those same points.
+func CovarianceMatrix(points *matrix.DenseMatrix) *matrix.DenseMatrix {
+	return sample_covariance_matrix(points, MeanVector(points))
+}
+
+// DistanceSquare returns the square of the Mahalanobis distance between
+// target and the distribution estimated from points.
+//
+// points holds one point per column; target is expected to be a column vector
+// with the same number of rows. The value is computed as
+// (target-mean)^T * cov^-1 * (target-mean), where mean and cov are the mean
+// vector and the sample covariance matrix of points.
+//
+// A non-nil error is returned when any underlying matrix operation fails
+// (subtracting the mean, inverting the covariance matrix, or multiplying); the
+// returned value is then 0.
+func DistanceSquare(points, target *matrix.DenseMatrix) (float64, error) {
+	mean := MeanVector(points)
+
+	// delta = target - mean. The error must not be dropped: if the
+	// subtraction is rejected (e.g. target and mean have different shapes),
+	// delta still holds the raw target and everything computed from it would
+	// be meaningless.
+	delta := target.Copy()
+	if err := delta.SubtractDense(mean); err != nil {
+		return 0, err
+	}
+
+	cov := sample_covariance_matrix(points, mean)
+	inv, err := cov.Inverse()
+	if err != nil {
+		return 0, err
+	}
+
+	// With delta a column vector, delta^T * (inv * delta) is a 1x1 matrix
+	// whose single entry is the squared distance.
+	right, err := inv.TimesDense(delta)
+	if err != nil {
+		return 0, err
+	}
+	product, err := delta.Transpose().TimesDense(right)
+	if err != nil {
+		return 0, err
+	}
+
+	return product.Get(0, 0), nil
+}
+
+// Distance returns the Mahalanobis distance between target and the
+// distribution estimated from points: the square root of DistanceSquare.
+// Any error reported by DistanceSquare is returned unchanged, with a
+// distance of 0.
+func Distance(points, target *matrix.DenseMatrix) (float64, error) {
+	sq, err := DistanceSquare(points, target)
+	if err != nil {
+		return 0, err // XXX
+	}
+	return math.Sqrt(sq), nil
+}
118 mahalanobis_test.go
@@ -0,0 +1,118 @@
+package mahalanobis
+
+import (
+// "fmt"
+ "math"
+ "testing"
+ "github.com/skelterjohn/go.matrix"
+)
+
+// TestMeanVector checks MeanVector against hand-computed means. Every column
+// of the input matrix is one point; the expected value is a column vector.
+func TestMeanVector(t *testing.T) {
+	cases := []struct {
+		points   string
+		expected string
+	}{
+		{"[1 1 1;1 1 1]", "[1;1]"},
+		{"[0 1 2;0 2 4]", "[1;2]"},
+	}
+
+	for _, c := range cases {
+		points, err := matrix.ParseMatlab(c.points)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		expected, err := matrix.ParseMatlab(c.expected)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// Report both matrices so a failure can be diagnosed from the log.
+		if result := MeanVector(points); !matrix.Equals(result, expected) {
+			t.Errorf("MeanVector(%s) =\n%v\nexpected\n%v", c.points, result, expected)
+		}
+	}
+}
+
+// TestCovarianceMatrix compares CovarianceMatrix with reference values
+// obtained from R's var(); the R expression for each case is quoted next to
+// it.
+func TestCovarianceMatrix(t *testing.T) {
+	cases := []struct {
+		name      string
+		points    string
+		expected  []float64 // entries of the expected 2x2 matrix
+		tolerance float64   // 0 requests an exact comparison
+	}{
+		{
+			// R: var(cbind(c(1, 1), c(1, 1)))
+			name:     "no (co)variance",
+			points:   "[1 1;1 1]",
+			expected: []float64{0, 0, 0, 0},
+		},
+		{
+			// R: var(cbind(c(0, 4, 2, 2), c(2, 2, 0, 4)))
+			name:      "diagonal case",
+			points:    "[0 4 2 2;2 2 0 4]",
+			expected:  []float64{2.66, 0, 0, 2.66},
+			tolerance: 0.01,
+		},
+		{
+			// R: var(cbind(c(9, 3, 5), c(3, 4, 1)))
+			name:      "another case",
+			points:    "[9 3 5;3 4 1]",
+			expected:  []float64{9.33, -0.66, -0.66, 2.33},
+			tolerance: 0.01,
+		},
+	}
+
+	for _, c := range cases {
+		points, err := matrix.ParseMatlab(c.points)
+		if err != nil {
+			t.Fatalf("%s: %v", c.name, err)
+		}
+
+		expected := matrix.MakeDenseMatrix(c.expected, 2, 2)
+		result := CovarianceMatrix(points)
+
+		// The reference values are rounded to two decimals, so only the
+		// all-zero case can be compared exactly.
+		matches := matrix.Equals(result, expected)
+		if c.tolerance > 0 {
+			matches = matrix.ApproxEquals(result, expected, c.tolerance)
+		}
+		if !matches {
+			t.Errorf("%s: CovarianceMatrix(%s) =\n%v\nexpected\n%v", c.name, c.points, result, expected)
+		}
+	}
+}
+
+// TestDistance checks DistanceSquare against the value computed by R's
+// mahalanobis(), and checks that Distance is the square root of that value.
+func TestDistance(t *testing.T) {
+
+	// R:
+	// x = cbind(c(9, 3, 5), c(3, 4, 1))
+	// mahalanobis(c(1,1), colMeans(x), var(x))
+	points, err := matrix.ParseMatlab("[9 3 5;3 4 1]")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	target, err := matrix.ParseMatlab("[1;1]")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	square, err := DistanceSquare(points, target)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if math.Abs(square-4.08) > 0.01 {
+		t.Errorf("DistanceSquare = %v, expected 4.08 (+/- 0.01)", square)
+	}
+
+	// The exported Distance must agree with the squared value checked above.
+	distance, err := Distance(points, target)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if math.Abs(distance-math.Sqrt(square)) > 1e-9 {
+		t.Errorf("Distance = %v, expected sqrt(DistanceSquare) = %v", distance, math.Sqrt(square))
+	}
+}

0 comments on commit 7f64077

Please sign in to comment.
Something went wrong with that request. Please try again.