Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
abates committed Feb 26, 2018
1 parent ec60633 commit 9da7e1b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 47 deletions.
26 changes: 13 additions & 13 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"image"
"io"

"github.com/disintegration/imaging"
)

var (
Expand All @@ -26,9 +28,8 @@ type Loadable interface {
}

type DB struct {
index Index
imageHasher func(image.Image) (PHash, error)
fileHasher func(io.Reader) (PHash, error)
index Index
hasher func(image.Image) (PHash, error)
}

func New() *DB {
Expand All @@ -37,25 +38,24 @@ func New() *DB {

func NewDB(index Index) *DB {
r := &DB{
index: index,
imageHasher: Hash,
fileHasher: HashFile,
index: index,
hasher: Hash,
}
return r
}

func (db *DB) Add(img image.Image) (PHash, error) {
hash, err := db.imageHasher(img)
hash, err := db.hasher(img)
if err == nil {
err = db.AddHash(hash)
}
return hash, err
}

func (db *DB) AddFile(reader io.Reader) (PHash, error) {
hash, err := db.fileHasher(reader)
func (db *DB) AddFile(reader io.Reader) (hash PHash, err error) {
img, err := imaging.Decode(reader)
if err == nil {
err = db.AddHash(hash)
hash, err = db.Add(img)
}
return hash, err
}
Expand All @@ -80,17 +80,17 @@ func (db *DB) Load(reader io.Reader) error {
}

func (db *DB) Search(img image.Image, maxDistance int) (matches []PHash, err error) {
hash, err := db.imageHasher(img)
hash, err := db.hasher(img)
if err == nil {
matches, err = db.SearchByHash(hash, maxDistance)
}
return matches, err
}

func (db *DB) SearchByFile(reader io.Reader, maxDistance int) (matches []PHash, err error) {
h, err := db.fileHasher(reader)
img, err := imaging.Decode(reader)
if err == nil {
matches, err = db.SearchByHash(h, maxDistance)
matches, err = db.Search(img, maxDistance)
}
return matches, err
}
Expand Down
30 changes: 17 additions & 13 deletions database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"fmt"
"image"
"io"
"image/png"
"reflect"
"testing"
)
Expand Down Expand Up @@ -32,7 +32,7 @@ func TestDBAdd(t *testing.T) {

for i, test := range tests {
db := NewDB(newTestIndex())
db.imageHasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedErr }
db.hasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedErr }

_, err := db.Add(test.img)
if err != test.expectedErr {
Expand All @@ -43,17 +43,19 @@ func TestDBAdd(t *testing.T) {

func TestDBAddFile(t *testing.T) {
tests := []struct {
reader io.Reader
expectedErr error
}{
{bytes.NewReader([]byte{}), nil},
{bytes.NewReader([]byte{}), fmt.Errorf("Just some test error")},
{nil},
{fmt.Errorf("Just some test error")},
}

for i, test := range tests {
db := NewDB(newTestIndex())
db.fileHasher = func(io.Reader) (PHash, error) { return PHash(0), test.expectedErr }
_, err := db.AddFile(test.reader)
db.hasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedErr }
image := image.NewAlpha(image.Rect(0, 0, 1, 1))
buf := bytes.NewBuffer([]byte{})
png.Encode(buf, image)
_, err := db.AddFile(buf)
if err != test.expectedErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedErr, err)
}
Expand All @@ -76,7 +78,7 @@ func TestDBSearch(t *testing.T) {
testIndex.err = test.expectedSearchErr
testIndex.matches = test.expectedMatches
db := NewDB(testIndex)
db.imageHasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedHashErr }
db.hasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedHashErr }

matches, err := db.Search(test.img, 0)
if test.expectedHashErr != nil && err != test.expectedHashErr {
Expand All @@ -93,23 +95,25 @@ func TestDBSearch(t *testing.T) {

func TestDBSearchByFile(t *testing.T) {
tests := []struct {
reader io.Reader
expectedMatches []PHash
expectedHashErr error
expectedSearchErr error
}{
{bytes.NewReader([]byte{}), []PHash{}, nil, nil},
{bytes.NewReader([]byte{}), []PHash{}, nil, fmt.Errorf("Some test error")},
{[]PHash{}, nil, nil},
{[]PHash{}, nil, fmt.Errorf("Some test error")},
}

for i, test := range tests {
testIndex := newTestIndex()
testIndex.err = test.expectedSearchErr
testIndex.matches = test.expectedMatches
db := NewDB(testIndex)
db.fileHasher = func(io.Reader) (PHash, error) { return PHash(0), test.expectedHashErr }
db.hasher = func(image.Image) (PHash, error) { return PHash(0), test.expectedHashErr }

matches, err := db.SearchByFile(test.reader, 0)
image := image.NewAlpha(image.Rect(0, 0, 1, 1))
buf := bytes.NewBuffer([]byte{})
png.Encode(buf, image)
matches, err := db.SearchByFile(buf, 0)
if test.expectedHashErr != nil && err != test.expectedHashErr {
t.Errorf("tests[%d] expected %v got %v", i, test.expectedHashErr, err)
} else if test.expectedSearchErr != nil && err != test.expectedSearchErr {
Expand Down
18 changes: 2 additions & 16 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package disgo

import (
"fmt"
"github.com/disintegration/imaging"
"image"
"io"

"github.com/disintegration/imaging"
)

type PHash uint64
Expand All @@ -19,26 +19,12 @@ func (p1 PHash) Distance(p2 PHash) (distance int) {
return
}

/*func intensity(img *image.NRGBA, row, column int) uint8 {
offset := (row-img.Rect.Min.Y)*img.Stride + (column-img.Rect.Min.X)*4
return uint8((uint16(img.Pix[offset]) + uint16(img.Pix[offset+1]) + uint16(img.Pix[offset+2])) / 3)
}*/

func intensity(img image.Image, row, column int) uint8 {
c := img.At(column, row)
r, g, b, _ := c.RGBA()
return uint8((r + g + b) / 3)
}

func HashFile(reader io.Reader) (hash PHash, err error) {
img, err := imaging.Decode(reader)
if err == nil {
hash, err = Hash(img)
}

return hash, err
}

func Hash(img image.Image) (PHash, error) {
rows := 8
columns := 9
Expand Down
6 changes: 1 addition & 5 deletions hash_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package disgo

import (
"bytes"
"image"
"image/png"
"testing"
)

Expand Down Expand Up @@ -94,9 +92,7 @@ func TestHash(t *testing.T) {
}

for i, test := range tests {
buf := bytes.NewBuffer([]byte{})
png.Encode(buf, test.img)
hash, _ := HashFile(buf)
hash, _ := Hash(test.img)
if hash != test.expectedHash {
t.Errorf("tests[%d] expected %x got %x", i, test.expectedHash, hash)
}
Expand Down

0 comments on commit 9da7e1b

Please sign in to comment.