Skip to content

Commit

Permalink
add DTO for partition information
Browse files Browse the repository at this point in the history
  • Loading branch information
pitabwire committed Feb 22, 2024
1 parent 86a1d59 commit 25e4f40
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions go/common/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ import (
"google.golang.org/grpc/metadata"
)

const ctxKeyTenantId = CtxServiceKey("tenantIdKey")
const ctxKeyPartitionId = CtxServiceKey("partitionIdKey")
const ctxKeyAccessId = CtxServiceKey("accessIdKey")
const ctxKeyProfileId = CtxServiceKey("profileIdKey")
const ctxKeyPartitionInfo = CtxServiceKey("partitionInfoKey")

type PartitionInfo struct {
TenantId string
PartitionId string
AccessId string
ProfileId string
}

type GrpcClientBase struct {
// gRPC connection to the service.
Expand Down Expand Up @@ -72,8 +76,8 @@ func (gbc *GrpcClientBase) GetInfo() metadata.MD {
return gbc.xMetadata
}

func (gbc *GrpcClientBase) ToContext(ctx context.Context, key CtxServiceKey, val string) context.Context {
return context.WithValue(ctx, key, val)
func (gbc *GrpcClientBase) SetPartitionInfo(ctx context.Context, partitionInfo *PartitionInfo) context.Context {
return context.WithValue(ctx, ctxKeyPartitionInfo, partitionInfo)
}

func NewClientBase(ctx context.Context, opts ...ClientOption) (*GrpcClientBase, error) {
Expand All @@ -97,21 +101,25 @@ type JWTInterceptor struct {
mu sync.Mutex
}

func (jwt *JWTInterceptor) fromContext(ctx context.Context, key CtxServiceKey) string {
val, ok := ctx.Value(key).(string)
func (jwt *JWTInterceptor) fromContext(ctx context.Context, key CtxServiceKey) *PartitionInfo {
val, ok := ctx.Value(key).(*PartitionInfo)
if !ok {
return ""
return nil
}

return val
}

func (jwt *JWTInterceptor) setupPartitionData(ctx context.Context) context.Context {
finalCtx := metadata.AppendToOutgoingContext(ctx, "tenant_id", jwt.fromContext(ctx, ctxKeyTenantId))
finalCtx = metadata.AppendToOutgoingContext(finalCtx, "partition_id", jwt.fromContext(ctx, ctxKeyPartitionId))
finalCtx = metadata.AppendToOutgoingContext(finalCtx, "access_id", jwt.fromContext(ctx, ctxKeyAccessId))
finalCtx = metadata.AppendToOutgoingContext(finalCtx, "profile_id", jwt.fromContext(ctx, ctxKeyProfileId))
return finalCtx
partitionInfo := jwt.fromContext(ctx, ctxKeyPartitionInfo)

if partitionInfo == nil {
return ctx
}
finalCtx := metadata.AppendToOutgoingContext(ctx, "tenant_id", partitionInfo.TenantId)
finalCtx = metadata.AppendToOutgoingContext(finalCtx, "partition_id", partitionInfo.PartitionId)
finalCtx = metadata.AppendToOutgoingContext(finalCtx, "access_id", partitionInfo.AccessId)
return metadata.AppendToOutgoingContext(finalCtx, "profile_id", partitionInfo.ProfileId)
}

func (jwt *JWTInterceptor) getTokenStr(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -153,10 +161,12 @@ func (jwt *JWTInterceptor) UnaryClientInterceptor(
return err
}

finalCtx := ctx
var finalCtx context.Context
if tokenStr != "" {
// Create a new context with the token and make the first request
finalCtx = metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+jwt.token.AccessToken)
} else {
finalCtx = ctx
}

finalCtx = jwt.setupPartitionData(finalCtx)
Expand All @@ -177,11 +187,12 @@ func (jwt *JWTInterceptor) StreamClientInterceptor(
return nil, err
}

finalCtx := ctx

var finalCtx context.Context
if tokenStr != "" {
// Create a new context with the token and make the first request
finalCtx = metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+jwt.token.AccessToken)
} else {
finalCtx = ctx
}

finalCtx = jwt.setupPartitionData(finalCtx)
Expand Down

0 comments on commit 25e4f40

Please sign in to comment.