/
portcache.go
138 lines (114 loc) · 3.14 KB
/
portcache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package portcache
import (
"fmt"
"sync"
"github.com/aporeto-inc/trireme-lib/utils/cache"
"github.com/aporeto-inc/trireme-lib/utils/portspec"
)
// PortCache is a generic cache of port pairs or exact ports. It can store
// and do lookups of ports on exact matches or ranges. It returns the stored
// values.
type PortCache struct {
	// ports holds the specs whose Min equals their Max, keyed by that port.
	ports cache.DataStore
	// ranges holds the specs whose Min and Max differ, in insertion order.
	ranges []*portspec.PortSpec
	// Mutex is held by the methods of PortCache while they read or modify
	// ranges.
	sync.Mutex
}
// NewPortCache returns an empty PortCache. The exact-port store is created
// with cache.NewCache(name) and the list of ranges starts out empty.
func NewPortCache(name string) *PortCache {
	pc := new(PortCache)
	pc.ports = cache.NewCache(name)
	pc.ranges = make([]*portspec.PortSpec, 0)
	return pc
}
// AddPortSpec stores s in the cache without any uniqueness or overlap check.
// A spec whose Min equals its Max is handed to the exact-port store's
// AddOrUpdate, keyed by that port. Any other spec is appended to the list of
// ranges while the mutex is held.
func (p *PortCache) AddPortSpec(s *portspec.PortSpec) {
	if s.Min != s.Max {
		p.Lock()
		p.ranges = append(p.ranges, s)
		p.Unlock()
		return
	}

	p.ports.AddOrUpdate(s.Min, s)
}
// AddUnique adds a port spec to the cache only if it does not collide with
// what is already stored. The mutex is held for the whole call.
//
// For a single-port spec (Min == Max) an error is returned when the
// exact-port store already has an entry for that port. For every spec an
// error is returned when it overlaps a stored range; a stored range r is
// treated as disjoint from s when r.Max <= s.Min or r.Min >= s.Max. A new
// range is not checked against the exact ports already stored.
//
// On success a single port is inserted with the store's Add, whose error is
// returned, and a range is appended to the list of ranges.
func (p *PortCache) AddUnique(s *portspec.PortSpec) error {
	p.Lock()
	defer p.Unlock()

	single := s.Min == s.Max

	if single {
		// A nil error from Get means the port is present; this is the same
		// test the lookup methods use to detect a hit.
		if _, err := p.ports.Get(s.Min); err == nil {
			return fmt.Errorf("Port already exists: %d", s.Min)
		}
	}

	for _, r := range p.ranges {
		if r.Max <= s.Min || r.Min >= s.Max {
			continue
		}
		return fmt.Errorf("Overlap detected: %d %d", r.Max, r.Min)
	}

	if single {
		return p.ports.Add(s.Min, s)
	}

	p.ranges = append(p.ranges, s)
	return nil
}
// GetSpecValueFromPort returns the value of one spec that matches port.
//
// The exact-port store is consulted first; on a hit the value of the stored
// spec is returned. Otherwise the ranges are scanned in slice order under the
// mutex and the value of the first range with Min <= port < Max is returned.
// An error is returned when nothing matches.
func (p *PortCache) GetSpecValueFromPort(port uint16) (interface{}, error) {
	exact, err := p.ports.Get(port)
	if err == nil {
		return exact.(*portspec.PortSpec).Value(), nil
	}

	p.Lock()
	defer p.Unlock()

	for i := range p.ranges {
		if r := p.ranges[i]; port >= r.Min && port < r.Max {
			return r.Value(), nil
		}
	}

	return nil, fmt.Errorf("No match for port %d", port)
}
// GetAllSpecValueFromPort returns the values of every spec that matches port,
// which allows for overlapping ranges.
//
// The result starts with the value of the exact-port entry, if there is one,
// followed by the value of each range with Min <= port < Max in slice order.
// The ranges are scanned under the mutex. An error is returned when nothing
// matches.
func (p *PortCache) GetAllSpecValueFromPort(port uint16) ([]interface{}, error) {
	var matches []interface{}

	if exact, err := p.ports.Get(port); err == nil {
		matches = append(matches, exact.(*portspec.PortSpec).Value())
	}

	p.Lock()
	defer p.Unlock()

	for _, r := range p.ranges {
		if port < r.Min || port >= r.Max {
			continue
		}
		matches = append(matches, r.Value())
	}

	if len(matches) > 0 {
		return matches, nil
	}
	return nil, fmt.Errorf("No match for port %d", port)
}
// Remove deletes a port spec from the cache.
//
// A single-port spec (Min == Max) is removed from the exact-port store and
// that store's result is returned. Otherwise the first stored range whose Min
// and Max both equal those of s is removed under the mutex; an error is
// returned when no such range exists.
func (p *PortCache) Remove(s *portspec.PortSpec) error {
	if s.Min == s.Max {
		return p.ports.Remove(s.Min)
	}

	p.Lock()
	defer p.Unlock()

	for i, r := range p.ranges {
		if r.Min != s.Min || r.Max != s.Max {
			continue
		}

		// Shift the tail down by one and clear the vacated slot so the
		// backing array does not keep a stale pointer to a spec alive.
		last := len(p.ranges) - 1
		copy(p.ranges[i:], p.ranges[i+1:])
		p.ranges[last] = nil
		p.ranges = p.ranges[:last]
		return nil
	}

	return fmt.Errorf("port not found")
}
// RemoveStringPorts removes the spec described by the string ports. The
// string is parsed with portspec.NewPortSpecFromString (nil is passed as its
// second argument); a parse error is returned as is, otherwise the parsed
// spec is passed to Remove and its result is returned.
func (p *PortCache) RemoveStringPorts(ports string) error {
	spec, parseErr := portspec.NewPortSpecFromString(ports, nil)
	if parseErr == nil {
		return p.Remove(spec)
	}
	return parseErr
}