-
Notifications
You must be signed in to change notification settings - Fork 5
/
mux.go
138 lines (116 loc) · 3.2 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package srpc
import "sync"
// Mux holds a set of <service, method> handlers and can invoke them.
//
// Beyond invoking, a Mux supports registering handlers and querying which
// services and methods are currently registered.
type Mux interface {
	// Invoker is embedded so a Mux can invoke the registered methods.
	Invoker

	// HasService reports whether the service ID exists in the handlers.
	HasService(serviceID string) bool
	// HasServiceMethod reports whether <service-id, method-id> exists in the handlers.
	HasServiceMethod(serviceID, methodID string) bool
	// Register adds a new RPC method handler (service).
	Register(handler Handler) error
}
type (
	// muxMethods maps a method ID to the Handler registered for it.
	muxMethods map[string]Handler

	// mux is the default implementation of Mux.
	mux struct {
		// fallback holds the Invokers consulted when the mux itself has no
		// handler for the requested service/method.
		fallback []Invoker

		// rmtx guards the fields below it.
		rmtx sync.RWMutex
		// services maps a service ID to its table of method handlers.
		services map[string]muxMethods
	}
)
// NewMux returns a new, empty Mux backed by the default implementation.
//
// fallbackInvokers are consulted, in the order given, whenever a call
// targets a service/method that is not registered on this mux.
func NewMux(fallbackInvokers ...Invoker) Mux {
	m := new(mux)
	m.services = make(map[string]muxMethods)
	m.fallback = fallbackInvokers
	return m
}
// Register adds handler to the mux under the service ID it reports.
//
// Every non-empty method ID returned by handler.GetMethodIDs is mapped to
// handler, replacing any handler previously registered for that method of
// the same service. The service's method table is created on first
// registration, even if handler reports no non-empty method IDs.
// ErrEmptyServiceID is returned, and the mux is left unchanged, when handler
// reports an empty service ID.
func (m *mux) Register(handler Handler) error {
	svcID, ids := handler.GetServiceID(), handler.GetMethodIDs()
	if len(svcID) == 0 {
		return ErrEmptyServiceID
	}

	m.rmtx.Lock()
	defer m.rmtx.Unlock()

	methods := m.services[svcID]
	if methods == nil {
		methods = muxMethods{}
		m.services[svcID] = methods
	}
	for _, id := range ids {
		if id == "" {
			// Empty method IDs are never stored.
			continue
		}
		methods[id] = handler
	}
	return nil
}
// HasService reports whether at least one method has been registered under
// serviceID. An empty serviceID always yields false.
func (m *mux) HasService(serviceID string) bool {
	if serviceID == "" {
		return false
	}
	// This is a read-only lookup: take the shared lock so it does not
	// serialize against other readers such as InvokeMethod.
	m.rmtx.RLock()
	defer m.rmtx.RUnlock()
	return len(m.services[serviceID]) != 0
}
// HasServiceMethod reports whether a handler is registered for the
// <serviceID, methodID> pair. Empty IDs always yield false.
func (m *mux) HasServiceMethod(serviceID, methodID string) bool {
	if serviceID == "" || methodID == "" {
		return false
	}
	// Read-only lookup: the shared lock lets it run alongside other readers.
	m.rmtx.RLock()
	defer m.rmtx.RUnlock()
	// Register indexes every non-empty method ID in the service's table, so
	// one map lookup answers the question. It uses the same table that
	// InvokeMethod resolves against. Indexing a missing (nil) table is safe
	// and reports not-found.
	_, ok := m.services[serviceID][methodID]
	return ok
}
// InvokeMethod dispatches a call to the handler registered for
// <serviceID, methodID> and reports whether it was handled.
//
// When serviceID is empty, every registered service is searched for
// methodID. If more than one service registers that method, which handler
// is picked is unspecified (Go map iteration order). The handler still
// receives the empty serviceID.
//
// When a handler is found, its result is returned as-is and the fallback
// invokers are not consulted. Otherwise the non-nil fallback invokers are
// tried in order, stopping at the first one that reports handled or returns
// an error. If nothing handles the call, it returns false, nil.
func (m *mux) InvokeMethod(serviceID, methodID string, strm Stream) (bool, error) {
	// Resolve the handler under the read lock, which is released before the
	// handler runs so the lock is not held for the duration of the call.
	handler := func() Handler {
		m.rmtx.RLock()
		defer m.rmtx.RUnlock()

		if serviceID != "" {
			// Indexing a missing (nil) method table yields a nil Handler.
			return m.services[serviceID][methodID]
		}
		for _, methods := range m.services {
			if h := methods[methodID]; h != nil {
				return h
			}
		}
		return nil
	}()

	if handler != nil {
		return handler.InvokeMethod(serviceID, methodID, strm)
	}

	for _, fb := range m.fallback {
		if fb == nil {
			continue
		}
		if handled, err := fb.InvokeMethod(serviceID, methodID, strm); handled || err != nil {
			return handled, err
		}
	}
	return false, nil
}
// Compile-time check that *mux satisfies the Mux interface.
var _ Mux = (*mux)(nil)