Skip to content

Commit

Permalink
[#1998]fix a bug where filter was not currently working properly in t…
Browse files Browse the repository at this point in the history
…riple protocol and fix some spelling mistake (#1999)
  • Loading branch information
Mulavar committed Jul 30, 2022
1 parent 51be359 commit 013f0b2
Show file tree
Hide file tree
Showing 18 changed files with 88 additions and 65 deletions.
4 changes: 2 additions & 2 deletions config/service_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (s *ServiceConfig) Init(rc *RootConfig) error {
s.ProtocolIDs = rc.Provider.ProtocolIDs
}
if len(s.ProtocolIDs) <= 0 {
for k, _ := range rc.Protocols {
for k := range rc.Protocols {
s.ProtocolIDs = append(s.ProtocolIDs, k)
}
}
Expand Down Expand Up @@ -400,7 +400,7 @@ func (s *ServiceConfig) Unexport() {
s.exportersLock.Lock()
defer s.exportersLock.Unlock()
for _, exporter := range s.exporters {
exporter.Unexport()
exporter.UnExport()
}
s.exporters = nil
}()
Expand Down
32 changes: 27 additions & 5 deletions filter/token/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func init() {
extension.SetFilter(constant.TokenFilterKey, newTokenFilter)
}

const (
InValidTokenFormat = "[Token Filter]Invalid token! Forbid invoke remote service %v with method %s"
)

// tokenFilter will verify if the token is valid
type tokenFilter struct{}

Expand All @@ -60,13 +64,31 @@ func newTokenFilter() filter.Filter {
func (f *tokenFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
invokerTkn := invoker.GetURL().GetParam(constant.TokenKey, "")
if len(invokerTkn) > 0 {
attachs := invocation.Attachments()
remoteTkn, exist := attachs[constant.TokenKey]
if exist && remoteTkn != nil && strings.EqualFold(invokerTkn, remoteTkn.(string)) {
attas := invocation.Attachments()
var remoteTkn string
remoteTknIface, exist := attas[constant.TokenKey]
if !exist || remoteTknIface == nil {
return &protocol.RPCResult{Err: perrors.Errorf(InValidTokenFormat, invoker, invocation.MethodName())}
}
switch remoteTknIface.(type) {
case string:
// deal with dubbo protocol
remoteTkn = remoteTknIface.(string)
case []string:
// deal with triple protocol
remoteTkns := remoteTknIface.([]string)
if len(remoteTkns) != 1 {
return &protocol.RPCResult{Err: perrors.Errorf(InValidTokenFormat, invoker, invocation.MethodName())}
}
remoteTkn = remoteTkns[0]
default:
return &protocol.RPCResult{Err: perrors.Errorf(InValidTokenFormat, invoker, invocation.MethodName())}
}

if strings.EqualFold(invokerTkn, remoteTkn) {
return invoker.Invoke(ctx, invocation)
}
return &protocol.RPCResult{Err: perrors.Errorf("Invalid token! Forbid invoke remote service %v method %s ",
invoker, invocation.MethodName())}
return &protocol.RPCResult{Err: perrors.Errorf(InValidTokenFormat, invoker, invocation.MethodName())}
}

return invoker.Invoke(ctx, invocation)
Expand Down
6 changes: 3 additions & 3 deletions protocol/dubbo/dubbo_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ func NewDubboExporter(key string, invoker protocol.Invoker, exporterMap *sync.Ma
}

// Unexport unexport dubbo service exporter.
func (de *DubboExporter) Unexport() {
func (de *DubboExporter) UnExport() {
interfaceName := de.GetInvoker().GetURL().GetParam(constant.InterfaceKey, "")
de.BaseExporter.Unexport()
de.BaseExporter.UnExport()
err := common.ServiceMap.UnRegister(interfaceName, DUBBO, de.GetInvoker().GetURL().ServiceKey())
if err != nil {
logger.Errorf("[DubboExporter.Unexport] error: %v", err)
logger.Errorf("[DubboExporter.UnExport] error: %v", err)
}
}
26 changes: 13 additions & 13 deletions protocol/dubbo/dubbo_invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/config"
"dubbo.apache.org/dubbo-go/v3/protocol"
invocation_impl "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"dubbo.apache.org/dubbo-go/v3/remoting"
)

Expand Down Expand Up @@ -84,7 +84,7 @@ func (di *DubboInvoker) getClient() *remoting.ExchangeClient {
}

// Invoke call remoting.
func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocation) protocol.Result {
func (di *DubboInvoker) Invoke(ctx context.Context, ivc protocol.Invocation) protocol.Result {
var (
err error
result protocol.RPCResult
Expand Down Expand Up @@ -114,7 +114,7 @@ func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocati
return &result
}

inv := invocation.(*invocation_impl.RPCInvocation)
inv := ivc.(*invocation.RPCInvocation)
// init param
inv.SetAttachment(constant.PathKey, di.GetURL().GetParam(constant.InterfaceKey, ""))
for _, k := range attachmentKey {
Expand Down Expand Up @@ -142,15 +142,15 @@ func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocati
timeout := di.getTimeout(inv)
if async {
if callBack, ok := inv.CallBack().(func(response common.CallbackResponse)); ok {
result.Err = di.client.AsyncRequest(&invocation, url, timeout, callBack, rest)
result.Err = di.client.AsyncRequest(&ivc, url, timeout, callBack, rest)
} else {
result.Err = di.client.Send(&invocation, url, timeout)
result.Err = di.client.Send(&ivc, url, timeout)
}
} else {
if inv.Reply() == nil {
result.Err = protocol.ErrNoReply
} else {
result.Err = di.client.Request(&invocation, url, timeout, rest)
result.Err = di.client.Request(&ivc, url, timeout, rest)
}
}
if result.Err == nil {
Expand All @@ -162,21 +162,21 @@ func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocati
}

// get timeout including methodConfig
func (di *DubboInvoker) getTimeout(invocation *invocation_impl.RPCInvocation) time.Duration {
methodName := invocation.MethodName()
func (di *DubboInvoker) getTimeout(ivc *invocation.RPCInvocation) time.Duration {
methodName := ivc.MethodName()
if di.GetURL().GetParamBool(constant.GenericKey, false) {
methodName = invocation.Arguments()[0].(string)
methodName = ivc.Arguments()[0].(string)
}
timeout := di.GetURL().GetParam(strings.Join([]string{constant.MethodKeys, methodName, constant.TimeoutKey}, "."), "")
if len(timeout) != 0 {
if t, err := time.ParseDuration(timeout); err == nil {
// config timeout into attachment
invocation.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(t.Milliseconds())))
ivc.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(t.Milliseconds())))
return t
}
}
// set timeout into invocation at method level
invocation.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(di.timeout.Milliseconds())))
ivc.SetAttachment(constant.TimeoutKey, strconv.Itoa(int(di.timeout.Milliseconds())))
return di.timeout
}

Expand Down Expand Up @@ -207,11 +207,11 @@ func (di *DubboInvoker) Destroy() {

// Finally, I made the decision that I don't provide a general way to transfer the whole context
// because it could be misused. If the context contains to many key-value pairs, the performance will be much lower.
func (di *DubboInvoker) appendCtx(ctx context.Context, inv *invocation_impl.RPCInvocation) {
func (di *DubboInvoker) appendCtx(ctx context.Context, ivc *invocation.RPCInvocation) {
// inject opentracing ctx
currentSpan := opentracing.SpanFromContext(ctx)
if currentSpan != nil {
err := injectTraceCtx(currentSpan, inv)
err := injectTraceCtx(currentSpan, ivc)
if err != nil {
logger.Errorf("Could not inject the span context into attachments: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions protocol/dubbo/dubbo_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ func TestDubboProtocol_Export(t *testing.T) {
eq2 := exporter2.GetInvoker().GetURL().URLEqual(url2)
assert.True(t, eq2)

// make sure exporterMap after 'Unexport'
// make sure exporterMap after 'UnExport'
_, ok := proto.(*DubboProtocol).ExporterMap().Load(url2.ServiceKey())
assert.True(t, ok)
exporter2.Unexport()
exporter2.UnExport()
_, ok = proto.(*DubboProtocol).ExporterMap().Load(url2.ServiceKey())
assert.False(t, ok)

Expand Down
6 changes: 3 additions & 3 deletions protocol/dubbo3/dubbo3_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ func NewDubboExporter(key string, invoker protocol.Invoker, exporterMap *sync.Ma
}

// Unexport unexport dubbo3 service exporter.
func (de *DubboExporter) Unexport() {
func (de *DubboExporter) UnExport() {
url := de.GetInvoker().GetURL()
interfaceName := url.GetParam(constant.InterfaceKey, "")
de.BaseExporter.Unexport()
de.BaseExporter.UnExport()
err := common.ServiceMap.UnRegister(interfaceName, tripleConstant.TRIPLE, url.ServiceKey())
if err != nil {
logger.Errorf("[DubboExporter.Unexport] error: %v", err)
logger.Errorf("[DubboExporter.UnExport] error: %v", err)
}
de.serviceMap.Delete(interfaceName)
}
7 changes: 4 additions & 3 deletions protocol/dubbo3/dubbo3_invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func NewDubboInvoker(url *common.URL) (*DubboInvoker, error) {
interfaceKey := url.GetParam(constant.InterfaceKey, "")
consumerService := config.GetConsumerServiceByInterfaceName(interfaceKey)

dubboSerializaerType := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
triCodecType := tripleConstant.CodecType(dubboSerializaerType)
dubboSerializerType := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization)
triCodecType := tripleConstant.CodecType(dubboSerializerType)
// new triple client
opts := []triConfig.OptionFunction{
triConfig.WithClientTimeout(uint32(timeout.Seconds())),
Expand Down Expand Up @@ -181,6 +181,7 @@ func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocati

// append interface id to ctx
gRPCMD := make(metadata.MD, 0)
// triple will convert attachment value to []string
for k, v := range invocation.Attachments() {
if str, ok := v.(string); ok {
gRPCMD.Set(k, str)
Expand All @@ -190,7 +191,7 @@ func (di *DubboInvoker) Invoke(ctx context.Context, invocation protocol.Invocati
gRPCMD.Set(k, str...)
continue
}
logger.Warnf("triple attachment value with key = %s is invalid, which should be string or []string", k)
logger.Warnf("[Triple Protocol]Triple attachment value with key = %s is invalid, which should be string or []string", k)
}
ctx = metadata.NewOutgoingContext(ctx, gRPCMD)
ctx = context.WithValue(ctx, tripleConstant.InterfaceKey, di.BaseInvoker.GetURL().GetParam(constant.InterfaceKey, ""))
Expand Down
6 changes: 3 additions & 3 deletions protocol/dubbo3/dubbo3_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ func TestDubboProtocolExport(t *testing.T) {
eq := exporter.GetInvoker().GetURL().URLEqual(url)
assert.True(t, eq)

// make sure exporterMap after 'Unexport'
// make sure exporterMap after 'UnExport'
_, ok := proto.(*DubboProtocol).ExporterMap().Load(url.ServiceKey())
assert.True(t, ok)
exporter.Unexport()
exporter.UnExport()
_, ok = proto.(*DubboProtocol).ExporterMap().Load(url.ServiceKey())
assert.False(t, ok)

Expand Down Expand Up @@ -144,7 +144,7 @@ func TestDubbo3UnaryService_GetReqParamsInterfaces(t *testing.T) {

func subTest(t *testing.T, val, paramsInterfaces interface{}) {
list := paramsInterfaces.([]interface{})
for k, _ := range list {
for k := range list {
err := hessian.ReflectResponse(val, list[k])
assert.Nil(t, err)
}
Expand Down
6 changes: 3 additions & 3 deletions protocol/grpc/grpc_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ func NewGrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
}

// Unexport and unregister gRPC service from registry and memory.
func (gg *GrpcExporter) Unexport() {
func (gg *GrpcExporter) UnExport() {
interfaceName := gg.GetInvoker().GetURL().GetParam(constant.InterfaceKey, "")
gg.BaseExporter.Unexport()
gg.BaseExporter.UnExport()
err := common.ServiceMap.UnRegister(interfaceName, GRPC, gg.GetInvoker().GetURL().ServiceKey())
if err != nil {
logger.Errorf("[GrpcExporter.Unexport] error: %v", err)
logger.Errorf("[GrpcExporter.UnExport] error: %v", err)
}
}
4 changes: 2 additions & 2 deletions protocol/grpc/grpc_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ func TestGrpcProtocolExport(t *testing.T) {
eq := exporter.GetInvoker().GetURL().URLEqual(url)
assert.True(t, eq)

// make sure exporterMap after 'Unexport'
// make sure exporterMap after 'UnExport'
_, ok := proto.(*GrpcProtocol).ExporterMap().Load(url.ServiceKey())
assert.True(t, ok)
exporter.Unexport()
exporter.UnExport()
_, ok = proto.(*GrpcProtocol).ExporterMap().Load(url.ServiceKey())
assert.False(t, ok)

Expand Down
6 changes: 3 additions & 3 deletions protocol/jsonrpc/jsonrpc_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ func NewJsonrpcExporter(key string, invoker protocol.Invoker, exporterMap *sync.
}

// Unexport exported JSON RPC service.
func (je *JsonrpcExporter) Unexport() {
func (je *JsonrpcExporter) UnExport() {
interfaceName := je.GetInvoker().GetURL().GetParam(constant.InterfaceKey, "")
je.BaseExporter.Unexport()
je.BaseExporter.UnExport()
err := common.ServiceMap.UnRegister(interfaceName, JSONRPC, je.GetInvoker().GetURL().ServiceKey())
if err != nil {
logger.Errorf("[JsonrpcExporter.Unexport] error: %v", err)
logger.Errorf("[JsonrpcExporter.UnExport] error: %v", err)
}
}
4 changes: 2 additions & 2 deletions protocol/jsonrpc/jsonrpc_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func TestJsonrpcProtocolExport(t *testing.T) {
eq := exporter.GetInvoker().GetURL().URLEqual(url)
assert.True(t, eq)

// make sure exporterMap after 'Unexport'
// make sure exporterMap after 'UnExport'
fmt.Println(url.Path)
_, ok := proto.(*JsonrpcProtocol).ExporterMap().Load(strings.TrimPrefix(url.Path, "/"))
assert.True(t, ok)
exporter.Unexport()
exporter.UnExport()
_, ok = proto.(*JsonrpcProtocol).ExporterMap().Load(strings.TrimPrefix(url.Path, "/"))
assert.False(t, ok)

Expand Down
16 changes: 8 additions & 8 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/common"
)

// Protocol is the interface that wraps the basic Export Refer and Destroy method.
// Protocol is the interface that wraps the basic Export, Refer and Destroy method.
//
// Export method is to export service for remote invocation
//
Expand All @@ -42,14 +42,14 @@ type Protocol interface {
Destroy()
}

// Exporter is the interface that wraps the basic GetInvoker method and Destroy Unexport.
// Exporter is the interface that wraps the basic GetInvoker method and Destroy UnExport.
//
// GetInvoker method is to get invoker.
//
// Unexport method is to unexport a exported service
// UnExport is to un export an exported service
type Exporter interface {
GetInvoker() Invoker
Unexport()
UnExport()
}

// BaseProtocol is default protocol implement.
Expand Down Expand Up @@ -105,10 +105,10 @@ func (bp *BaseProtocol) Destroy() {
}
bp.invokers = []Invoker{}

// unexport exporters
// un export exporters
bp.exporterMap.Range(func(key, exporter interface{}) bool {
if exporter != nil {
exporter.(Exporter).Unexport()
exporter.(Exporter).UnExport()
} else {
bp.exporterMap.Delete(key)
}
Expand Down Expand Up @@ -137,8 +137,8 @@ func (de *BaseExporter) GetInvoker() Invoker {
return de.invoker
}

// Unexport exported service.
func (de *BaseExporter) Unexport() {
// UnExport un export service.
func (de *BaseExporter) UnExport() {
logger.Infof("Exporter unexport.")
de.invoker.Destroy()
de.exporterMap.Delete(de.key)
Expand Down
6 changes: 3 additions & 3 deletions protocol/rest/rest_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ func NewRestExporter(key string, invoker protocol.Invoker, exporterMap *sync.Map
}

// Unexport unexport the RestExporter
func (re *RestExporter) Unexport() {
func (re *RestExporter) UnExport() {
interfaceName := re.GetInvoker().GetURL().GetParam(constant.InterfaceKey, "")
re.BaseExporter.Unexport()
re.BaseExporter.UnExport()
err := common.ServiceMap.UnRegister(interfaceName, REST, re.GetInvoker().GetURL().ServiceKey())
if err != nil {
logger.Errorf("[RestExporter.Unexport] error: %v", err)
logger.Errorf("[RestExporter.UnExport] error: %v", err)
}
}
4 changes: 2 additions & 2 deletions protocol/rest/rest_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ package rest
// // make sure url
// eq := exporter.GetInvoker().GetURL().URLEqual(url)
// assert.True(t, eq)
// // make sure exporterMap after 'Unexport'
// // make sure exporterMap after 'UnExport'
// fmt.Println(url.Path)
// _, ok := proto.(*RestProtocol).ExporterMap().Load(strings.TrimPrefix(url.Path, "/"))
// assert.True(t, ok)
// exporter.Unexport()
// exporter.UnExport()
// _, ok = proto.(*RestProtocol).ExporterMap().Load(strings.TrimPrefix(url.Path, "/"))
// assert.False(t, ok)
//
Expand Down

0 comments on commit 013f0b2

Please sign in to comment.