Skip to content

Commit

Permalink
Better tests and added ability to load/save database
Browse files Browse the repository at this point in the history
  • Loading branch information
abates committed Feb 26, 2018
1 parent f4b54c7 commit ec60633
Show file tree
Hide file tree
Showing 13 changed files with 651 additions and 288 deletions.
85 changes: 59 additions & 26 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,99 @@ package disgo

import (
"errors"
"image"
"io"
)

var (
ErrNotFound = errors.New("Image not found")
ErrNotFound = errors.New("Image not found")
ErrSaveNotSupported = errors.New("Underlying index does not support saving")
ErrLoadNotSupported = errors.New("Underlying index does not support loading")
)

type Index interface {
Insert(PHash) error
Search(PHash, int) ([]PHash, error)
}

type Saveable interface {
Save(io.Writer) error
}

type Loadable interface {
Load(io.Reader) error
}

type DB struct {
index Index
paths map[PHash][]string
index Index
imageHasher func(image.Image) (PHash, error)
fileHasher func(io.Reader) (PHash, error)
}

func New() *DB {
return NewDB(NewRadixIndex())
}

func NewDB(index Index) *DB {
r := new(DB)
r.index = index
r.paths = make(map[PHash][]string)
r := &DB{
index: index,
imageHasher: Hash,
fileHasher: HashFile,
}
return r
}

func (db *DB) Add(path string, hash PHash) error {
if db.paths[hash] == nil {
db.paths[hash] = []string{path}
} else {
db.paths[hash] = append(db.paths[hash], path)
func (db *DB) Add(img image.Image) (PHash, error) {
hash, err := db.imageHasher(img)
if err == nil {
err = db.AddHash(hash)
}
return hash, err
}

func (db *DB) AddFile(reader io.Reader) (PHash, error) {
hash, err := db.fileHasher(reader)
if err == nil {
err = db.AddHash(hash)
}
return hash, err
}

func (db *DB) AddHash(hash PHash) error {
db.index.Insert(hash)
return nil
}

func (db *DB) Find(hash PHash) (paths []string, err error) {
if db.paths[hash] == nil {
return nil, ErrNotFound
func (db *DB) Save(writer io.Writer) error {
if saver, ok := db.index.(Saveable); ok {
return saver.Save(writer)
}
return ErrSaveNotSupported
}

func (db *DB) Load(reader io.Reader) error {
if loader, ok := db.index.(Loadable); ok {
return loader.Load(reader)
}
return db.paths[hash], nil
return ErrLoadNotSupported
}

func (db *DB) SearchByFile(reader io.Reader, maxDistance int) (matches []string, err error) {
h, err := HashFile(reader)
func (db *DB) Search(img image.Image, maxDistance int) (matches []PHash, err error) {
hash, err := db.imageHasher(img)
if err == nil {
matches, err = db.Search(h, maxDistance)
matches, err = db.SearchByHash(hash, maxDistance)
}
return matches, err
}

func (db *DB) Search(hash PHash, maxDistance int) ([]string, error) {
var results []string
//fmt.Printf("Search: %v\n", hash)
hashes, _ := db.index.Search(hash, maxDistance)

for _, hash := range hashes {
results = append(results, db.paths[hash]...)
func (db *DB) SearchByFile(reader io.Reader, maxDistance int) (matches []PHash, err error) {
h, err := db.fileHasher(reader)
if err == nil {
matches, err = db.SearchByHash(h, maxDistance)
}
return results, nil
return matches, err
}

func (db *DB) SearchByHash(hash PHash, maxDistance int) ([]PHash, error) {
return db.index.Search(hash, maxDistance)
}
182 changes: 91 additions & 91 deletions database_test.go
Original file line number Diff line number Diff line change
@@ -1,129 +1,128 @@
package disgo

import (
"math/rand"
"os"
"sort"
"bytes"
"fmt"
"image"
"io"
"reflect"
"testing"
)

func addFile(db *DB, path string) PHash {
file, _ := os.Open(path)
phash, _ := HashFile(file)
db.Add(path, phash)
return phash
type testIndex struct {
err error
matches []PHash
}

func TestFind(t *testing.T) {
db := New()
hash1 := addFile(db, "images/gopher1.png")
hash2 := addFile(db, "images/gopher2.png")
func (ti *testIndex) Insert(PHash) error { return ti.err }
func (ti *testIndex) Search(PHash, int) ([]PHash, error) { return ti.matches, ti.err }

entries, _ := db.Find(hash1)
if len(entries) != 2 {
t.Logf("Expected to find two entries for hash %v but only got %d", hash1, len(entries))
t.Fail()
}
func newTestIndex() *testIndex {
return &testIndex{}
}

entries, _ = db.Find(hash2)
if len(entries) != 2 {
t.Logf("Expected to find two entries for hash %v but only got %d", hash2, len(entries))
t.Fail()
func TestDBAdd(t *testing.T) {
tests := []struct {
img image.Image
expectedErr error
}{
{image.NewRGBA(image.Rect(0, 0, 10, 10)), nil},
{image.NewRGBA(image.Rect(0, 0, 10, 10)), fmt.Errorf("Just some test error")},
}

file, _ := os.Open("images/ascendingGradient.png")
hash3, _ := HashFile(file)
entries, err := db.Find(hash3)
if len(entries) > 0 {
t.Logf("Expected not to find hash %v but found %d entries", hash3, len(entries))
t.Fail()
}
for i, test := range tests {
db := NewDB(newTestIndex())
db.imageHasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedErr }

if err != ErrNotFound {
t.Logf("Expected %v but got %v", ErrNotFound, err)
t.Fail()
_, err := db.Add(test.img)
if err != test.expectedErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedErr, err)
}
}
}

func TestSearchByFile(t *testing.T) {
for _, i := range []Index{NewRadixIndex(), NewLinearIndex()} {
db := NewDB(i)
addFile(db, "images/ascendingGradient.png")
addFile(db, "images/descendingGradient.png")
addFile(db, "images/alternatingGradient.png")
addFile(db, "images/gopher1.png")
addFile(db, "images/gopher2.png")

file, _ := os.Open("images/gopher3.png")
paths, err := db.SearchByFile(file, 5)
if err != nil {
t.Logf("Expected no error while searching by file. Got: %v", err)
t.Fail()
}

if len(paths) != 2 {
t.Logf("Expected exactly two matching images. Got: %d", len(paths))
t.FailNow()
}

sort.Strings(paths)
if paths[0] != "images/gopher1.png" {
t.Logf("Expected to match images/gopher1.png but got %s instead", paths[0])
t.Fail()
}
func TestDBAddFile(t *testing.T) {
tests := []struct {
reader io.Reader
expectedErr error
}{
{bytes.NewReader([]byte{}), nil},
{bytes.NewReader([]byte{}), fmt.Errorf("Just some test error")},
}

if paths[1] != "images/gopher2.png" {
t.Logf("Expected to match images/gopher2.png but got %s instead", paths[1])
t.Fail()
for i, test := range tests {
db := NewDB(newTestIndex())
db.fileHasher = func(io.Reader) (PHash, error) { return PHash(0), test.expectedErr }
_, err := db.AddFile(test.reader)
if err != test.expectedErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedErr, err)
}
}
}

var benchmarkHashes []PHash

func getHashes(numHashes int) []PHash {
if benchmarkHashes == nil {
benchmarkHashes = make([]PHash, 0)
func TestDBSearch(t *testing.T) {
tests := []struct {
img image.Image
expectedMatches []PHash
expectedHashErr error
expectedSearchErr error
}{
{image.NewRGBA(image.Rect(0, 0, 10, 10)), []PHash{}, nil, nil},
{image.NewRGBA(image.Rect(0, 0, 10, 10)), []PHash{}, nil, fmt.Errorf("Some test error")},
}

for len(benchmarkHashes) < numHashes {
randomNumber := PHash(rand.Int63())
if rand.NormFloat64() >= 0 {
randomNumber = randomNumber | 0x8000000000000000
for i, test := range tests {
testIndex := newTestIndex()
testIndex.err = test.expectedSearchErr
testIndex.matches = test.expectedMatches
db := NewDB(testIndex)
db.imageHasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedHashErr }

matches, err := db.Search(test.img, 0)
if test.expectedHashErr != nil && err != test.expectedHashErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedHashErr, err)
} else if test.expectedSearchErr != nil && err != test.expectedSearchErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedSearchErr, err)
}
benchmarkHashes = append(benchmarkHashes, randomNumber)
}
return benchmarkHashes
}

func benchmarkAdd(b *testing.B, index Index, numToAdd int) {
hashes := getHashes(numToAdd)
b.ResetTimer()

for n := 0; n < b.N; n++ {
db := NewDB(index)
for i := 0; i < numToAdd; i++ {
db.Add("filename", hashes[i])
if !reflect.DeepEqual(test.expectedMatches, matches) {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedMatches, matches)
}
}
}

func benchmarkSearch(b *testing.B, index Index, numToSearch int) {
hashes := getHashes(numToSearch)

db := NewDB(index)
for _, hash := range hashes {
db.Add("filename", hash)
func TestDBSearchByFile(t *testing.T) {
tests := []struct {
reader io.Reader
expectedMatches []PHash
expectedHashErr error
expectedSearchErr error
}{
{bytes.NewReader([]byte{}), []PHash{}, nil, nil},
{bytes.NewReader([]byte{}), []PHash{}, nil, fmt.Errorf("Some test error")},
}

b.ResetTimer()
for n := 0; n < b.N; n++ {
for i := 0; i < numToSearch; i++ {
db.Search(hashes[i], 5)
for i, test := range tests {
testIndex := newTestIndex()
testIndex.err = test.expectedSearchErr
testIndex.matches = test.expectedMatches
db := NewDB(testIndex)
db.fileHasher = func(io.Reader) (PHash, error) { return PHash(0), test.expectedHashErr }

matches, err := db.SearchByFile(test.reader, 0)
if test.expectedHashErr != nil && err != test.expectedHashErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedHashErr, err)
} else if test.expectedSearchErr != nil && err != test.expectedSearchErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedSearchErr, err)
}

if !reflect.DeepEqual(test.expectedMatches, matches) {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedMatches, matches)
}
}
}

/*
func BenchmarkLinearIndexAdd10(b *testing.B) { benchmarkAdd(b, NewLinearIndex(), 10) }
func BenchmarkLinearIndexAdd100(b *testing.B) { benchmarkAdd(b, NewLinearIndex(), 100) }
func BenchmarkLinearIndexAdd1000(b *testing.B) { benchmarkAdd(b, NewLinearIndex(), 1000) }
Expand All @@ -143,3 +142,4 @@ func BenchmarkRadixIndexSearch10(b *testing.B) { benchmarkSearch(b, NewRadixI
func BenchmarkRadixIndexSearch100(b *testing.B) { benchmarkSearch(b, NewRadixIndex(), 100) }
func BenchmarkRadixIndexSearch1000(b *testing.B) { benchmarkSearch(b, NewRadixIndex(), 1000) }
func BenchmarkRadixIndexSearch10000(b *testing.B) { benchmarkSearch(b, NewRadixIndex(), 10000) }
*/
16 changes: 11 additions & 5 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ func (p1 PHash) Distance(p2 PHash) (distance int) {
return
}

func intensity(img *image.NRGBA, row, column int) uint8 {
/*func intensity(img *image.NRGBA, row, column int) uint8 {
offset := (row-img.Rect.Min.Y)*img.Stride + (column-img.Rect.Min.X)*4
return uint8((uint16(img.Pix[offset]) + uint16(img.Pix[offset+1]) + uint16(img.Pix[offset+2])) / 3)
}*/

func intensity(img image.Image, row, column int) uint8 {
c := img.At(column, row)
r, g, b, _ := c.RGBA()
return uint8((r + g + b) / 3)
}

func HashFile(reader io.Reader) (hash PHash, err error) {
Expand All @@ -38,13 +44,13 @@ func Hash(img image.Image) (PHash, error) {
columns := 9
var hash PHash

grayscale := imaging.Grayscale(img)
grayscale = imaging.Resize(grayscale, columns, rows, imaging.Box)
img = imaging.Grayscale(img)
img = imaging.Resize(img, columns, rows, imaging.Box)

for row := 0; row < rows; row++ {
for column := 0; column < columns-1; column++ {
avg1 := intensity(grayscale, row, column)
avg2 := intensity(grayscale, row, column+1)
avg1 := intensity(img, row, column)
avg2 := intensity(img, row, column+1)
hash = hash << 1
if avg1 > avg2 {
hash = hash | 0x01
Expand Down
Loading

0 comments on commit ec60633

Please sign in to comment.