-
Notifications
You must be signed in to change notification settings - Fork 0
/
search.go
117 lines (105 loc) · 2.2 KB
/
search.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package search
import (
"bytes"
_ "embed"
"encoding/gob"
"slices"
"github.com/PiotrKozimor/krkstops/pkg/search/merged"
"github.com/PiotrKozimor/krkstops/pkg/trie"
)
var (
	// stopsB holds the raw contents of score/stops/stops.gob, embedded at
	// build time. New decodes it with merged.Read.
	//
	//go:embed score/stops/stops.gob
	stopsB []byte
	// scoreB holds the raw contents of score/score.gob, embedded at build
	// time. New gob-decodes it into the Search.score map.
	//
	//go:embed score/score.gob
	scoreB []byte
)
// Search indexes stop names in a trie and ranks matches using a per-stop
// score table. Construct it with New.
type Search struct {
	// t is the trie over stop names; each entry's Id is the stop's key in stops.
	t trie.Trie
	// score maps a stop id to its ranking score (missing ids read as 0).
	score map[uint]uint
	// stops holds per-stop metadata (Name, Tram, Bus), looked up by stop id.
	stops merged.Stops
}
// Stop is one search result as returned by the Search methods.
type Stop struct {
	// Name is the stop's display name.
	Name string
	// Id is the stop identifier used to key the index.
	Id uint
	// Score is the ranking score the result was ordered by.
	Score uint
	// Tram is copied from the stop metadata's Tram flag.
	Tram bool
	// Bus is copied from the stop metadata's Bus flag.
	Bus bool
}
// scoredStop pairs a stop id with its score while results are being ranked.
type scoredStop struct {
	id    uint
	score uint
}
// New builds a Search from the embedded data files.
//
// It reads the stop metadata with merged.Read, gob-decodes the score table
// into the index, and inserts every stop name into a fresh trie. Each trie
// entry's Id is the key under which the stop appears in the stops
// collection. Any decoding error is returned as-is together with a nil
// *Search.
func New() (*Search, error) {
	stops, err := merged.Read(stopsB)
	if err != nil {
		return nil, err
	}

	idx := &Search{
		t:     trie.New(),
		stops: stops,
	}
	if err := gob.NewDecoder(bytes.NewBuffer(scoreB)).Decode(&idx.score); err != nil {
		return nil, err
	}

	words := make([]trie.Entry, 0, len(stops))
	for key, stop := range stops {
		words = append(words, trie.Entry{Word: stop.Name, Id: key})
	}
	idx.t.InsertWords(words...)

	return idx, nil
}
// Search returns up to limit stops matching term.
//
// Exact matches (ranked by score) come first. Only when they do not fill
// limit are the trie's SearchInDistance(term, 1) results ranked and
// appended to fill the remaining slots.
func (s *Search) Search(term string, limit int) []Stop {
	found := s.SearchExact(term, limit)
	if len(found) < limit {
		fuzzy := s.t.SearchInDistance(term, 1)
		found = append(found, s.sort(fuzzy, limit-len(found))...)
	}
	return found
}
// searchWithinDistance ranks the trie's SearchWithinDistance(term, 1)
// results by score and returns at most limit of them.
func (s *Search) searchWithinDistance(term string, limit int) []Stop {
	return s.sort(s.t.SearchWithinDistance(term, 1), limit)
}
// SearchExact ranks the trie's exact matches for term by score and returns
// at most limit of them.
func (s *Search) SearchExact(term string, limit int) []Stop {
	return s.sort(s.t.SearchExact(term), limit)
}
// sort ranks the given stop ids by score and returns at most limit of them
// as Stop values.
//
// Ids are ordered by ascending score. Ties are broken by ascending id so the
// output does not depend on the order of results. An id missing from the
// score table is ranked with score 0. A non-positive limit yields an empty,
// non-nil slice.
//
// NOTE(review): confirm that lower scores are meant to rank first.
func (s *Search) sort(results []uint, limit int) []Stop {
	scored := make([]scoredStop, len(results))
	for i, id := range results {
		scored[i] = scoredStop{id: id, score: s.score[id]}
	}

	// Compare explicitly rather than returning int(a.score-b.score). The
	// unsigned subtraction wraps around, and converting the wrapped value to
	// int yields the wrong sign once the difference is large enough.
	slices.SortFunc(scored, func(a, b scoredStop) int {
		switch {
		case a.score < b.score:
			return -1
		case a.score > b.score:
			return 1
		case a.id < b.id:
			return -1
		case a.id > b.id:
			return 1
		default:
			return 0
		}
	})

	// Clamp limit. make panics on a negative capacity, and a huge limit
	// should not allocate more than the number of available results.
	n := limit
	if n > len(scored) {
		n = len(scored)
	}
	if n < 0 {
		n = 0
	}

	res := make([]Stop, 0, n)
	for _, sc := range scored[:n] {
		stop := s.stops[sc.id]
		res = append(res, Stop{
			Name:  stop.Name,
			Bus:   stop.Bus,
			Tram:  stop.Tram,
			Score: sc.score,
			Id:    sc.id,
		})
	}
	return res
}
// Get returns the stop stored under id, together with its score.
//
// No existence check is made here. For an unknown id, s.score yields 0, and
// what s.stops yields depends on merged.Stops (TODO confirm: zero value for
// a map, panic for a slice).
func (s *Search) Get(id uint) Stop {
	meta := s.stops[id]

	var out Stop
	out.Id = id
	out.Name = meta.Name
	out.Tram = meta.Tram
	out.Bus = meta.Bus
	out.Score = s.score[id]
	return out
}